|
9 | 9 |
|
10 | 10 | from setuptools import find_packages, setup |
11 | 11 |
|
12 | | - |
| 12 | +# Minimum required python version |
13 | 13 | REQUIRED_MAJOR = 3 |
14 | 14 | REQUIRED_MINOR = 8 |
15 | 15 |
|
| 16 | +# Requirements for testing, formatting, and tutorials |
| 17 | +TEST_REQUIRES = ["pytest", "pytest-cov"] |
| 18 | +FMT_REQUIRES = ["flake8", "ufmt", "flake8-docstrings"] |
| 19 | +TUTORIALS_REQUIRES = [ |
| 20 | + "ax-platform", |
| 21 | + "cma", |
| 22 | + "jupyter", |
| 23 | + "kaleido", |
| 24 | + "matplotlib", |
| 25 | + "memory_profiler", |
| 26 | + "pykeops", |
| 27 | + "torchvision", |
| 28 | +] |
| 29 | + |
16 | 30 | # Check for python version |
17 | 31 | if sys.version_info < (REQUIRED_MAJOR, REQUIRED_MINOR): |
18 | 32 | error = ( |
|
26 | 40 | ) |
27 | 41 | sys.exit(error) |
28 | 42 |
|
| 43 | +# Assign root dir location for later use |
| 44 | +root_dir = os.path.dirname(__file__) |
29 | 45 |
|
30 | | -# Requirements |
31 | | -TEST_REQUIRES = ["pytest", "pytest-cov"] |
32 | 46 |
|
33 | | -FMT_REQUIRES = ["flake8", "ufmt", "flake8-docstrings"] |
| 47 | +def read_deps_from_file(filname): |
| 48 | + """Read in requirements file and return items as list of strings""" |
| 49 | + with open(os.path.join(root_dir, filname), "r") as fh: |
| 50 | + return [line.strip() for line in fh.readlines() if not line.startswith("#")] |
34 | 51 |
|
35 | | -# Read in the pinned versions of the formatting tools |
36 | | -root_dir = os.path.dirname(__file__) |
37 | | -with open(os.path.join(root_dir, "requirements-fmt.txt"), "r") as fh: |
38 | | - FMT_REQUIRES += [ |
39 | | - line.strip() for line in fh.readlines() if not line.startswith("#") |
| 52 | + |
| 53 | +# Read in the requirements from the requirements.txt file |
| 54 | +install_requires = read_deps_from_file("requirements.txt") |
| 55 | + |
| 56 | +# Allow non-pinned (usually dev) versions of gpytorch and linear_operator |
| 57 | +if os.environ.get("ALLOW_LATEST_GPYTORCH_LINOP"): |
| 58 | + # Allows more recent previously installed versions. If there is no |
| 59 | + # previously installed version, installs the latest release. |
| 60 | + install_requires = [ |
| 61 | + dep.replace("==", ">=") |
| 62 | + if "gpytorch" in dep or "linear_operator" in dep |
| 63 | + else dep |
| 64 | + for dep in install_requires |
40 | 65 | ] |
41 | 66 |
|
| 67 | +# Read in pinned versions of the formatting tools |
| 68 | +FMT_REQUIRES += read_deps_from_file("requirements-fmt.txt") |
| 69 | +# Dev is test + formatting + docs generation |
42 | 70 | DEV_REQUIRES = TEST_REQUIRES + FMT_REQUIRES + ["sphinx"] |
43 | 71 |
|
44 | | -TUTORIALS_REQUIRES = [ |
45 | | - "ax-platform", |
46 | | - "cma", |
47 | | - "jupyter", |
48 | | - "kaleido", |
49 | | - "matplotlib", |
50 | | - "memory_profiler", |
51 | | - "pykeops", |
52 | | - "torchvision", |
53 | | -] |
54 | | - |
55 | 72 | # read in README.md as the long description |
56 | 73 | with open(os.path.join(root_dir, "README.md"), "r") as fh: |
57 | 74 | long_description = fh.read() |
|
80 | 97 | long_description_content_type="text/markdown", |
81 | 98 | python_requires=">=3.8", |
82 | 99 | packages=find_packages(exclude=["test", "test.*"]), |
83 | | - install_requires=[ |
84 | | - "torch>=1.11", |
85 | | - "gpytorch==1.9.0", |
86 | | - "linear_operator==0.1.1", |
87 | | - "scipy", |
88 | | - "multipledispatch", |
89 | | - "pyro-ppl>=1.8.2", |
90 | | - ], |
| 100 | + install_requires=install_requires, |
91 | 101 | extras_require={ |
92 | 102 | "dev": DEV_REQUIRES, |
93 | 103 | "test": TEST_REQUIRES, |
|
0 commit comments