diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 0d8e62e00..1b6b1cf5c 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -16,8 +16,8 @@ requirements: - python>=3.9 - setuptools run: - - numpy<2.0 - - pytorch>=1.13 + - numpy + - pytorch>=2.3 - matplotlib-base - tqdm - packaging diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index f6cb7fab1..42bb3a708 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -12,16 +12,9 @@ jobs: tests: strategy: matrix: - pytorch_args: ["-v 1.13", "-v 2.0.0", "-v 2.1.0", "-v 2.2.0", "-v 2.3.0", "-v 2.4.0", "-v 2.5.0", "-v 2.6.0"] + pytorch_args: ["-v 2.3.0", "-v 2.4.0", "-v 2.5.0", "-v 2.6.0", "-v 2.7.0"] transformers_args: ["-t 4.38.0", "-t 4.39.0", "-t 4.41.0", "-t 4.43.0", "-t 4.45.2"] docker_img: ["cimg/python:3.9", "cimg/python:3.10", "cimg/python:3.11", "cimg/python:3.12"] - exclude: - - pytorch_args: "-v 1.13" - docker_img: "cimg/python:3.12" - - pytorch_args: "-v 2.0.0" - docker_img: "cimg/python:3.12" - - pytorch_args: "-v 2.1.0" - docker_img: "cimg/python:3.12" fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/README.md b/README.md index fede17dce..d7d87fd6d 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ Captum can also be used by application engineers who are using trained models in **Installation Requirements** - Python >= 3.9 -- PyTorch >= 1.13 +- PyTorch >= 2.3 ##### Installing the latest release diff --git a/environment.yml b/environment.yml index b0e51f273..4112558d7 100644 --- a/environment.yml +++ b/environment.yml @@ -2,8 +2,8 @@ name: captum channels: - pytorch dependencies: - - numpy<2.0 - - pytorch>=1.13 + - numpy + - pytorch>=2.3 - matplotlib-base - tqdm - packaging diff --git a/setup.py b/setup.py index 59c01bf97..32126301b 100644 --- a/setup.py +++ b/setup.py @@ -163,9 +163,9 @@ def get_package_files(root, subdirs): ), install_requires=[ "matplotlib", - "numpy<2.0", + "numpy", "packaging", - "torch>=1.13", + "torch>=2.3", "tqdm", ], packages=find_packages(exclude=("tests", "tests.*")),