diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a28d79393..b2f6584687 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -204,7 +204,7 @@ jobs: fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi - if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi + if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi diff --git a/pytensor/link/__init__.py b/pytensor/link/__init__.py index c8c236a854..e69de29bb2 100644 --- a/pytensor/link/__init__.py +++ b/pytensor/link/__init__.py @@ -1 +0,0 @@ -from pytensor.link.pytorch.linker import PytorchLinker diff --git a/pytensor/link/pytorch/__init__.py b/pytensor/link/pytorch/__init__.py new file mode 100644 index 0000000000..c8c236a854 --- /dev/null +++ b/pytensor/link/pytorch/__init__.py @@ -0,0 +1 @@ +from pytensor.link.pytorch.linker import PytorchLinker