|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -"""The package setup script for modelopt customizing certain aspects of the installation process.""" |
| 16 | +"""The package setup script for modelopt customizing certain aspects of the installation process. |
| 17 | +
|
| 18 | +If installing from source, the CUDA version is detected and the appropriate cupy package is selected. |
| 19 | +If installing from a wheel, cupy for CUDA 13 is installed by default. |
| 20 | + If you have CUDA 12, you need to run `pip uninstall -y cupy-cuda13x` and `pip install cupy-cuda12x` separately. |
| 21 | +""" |
| 22 | + |
| 23 | +import re |
| 24 | +import subprocess |
17 | 25 |
|
18 | 26 | import setuptools |
19 | 27 | from setuptools_scm import get_version |
20 | 28 |
|
21 | | -# TODO: Set fallback_version to X.Y.Z release version when creating the release branch |
22 | 29 | version = get_version(root=".", fallback_version="0.0.0") |
23 | 30 |
|
| 31 | + |
| 32 | +def get_cuda_major_version() -> int | None: |
| 33 | + """Return CUDA major version installed on the system or None if detection fails.""" |
| 34 | + # Check nvcc version |
| 35 | + try: |
| 36 | + result = subprocess.run( |
| 37 | + ["nvcc", "--version"], |
| 38 | + capture_output=True, |
| 39 | + text=True, |
| 40 | + timeout=5, |
| 41 | + ) |
| 42 | + if result.returncode == 0: |
| 43 | + # Parse output like "release 12.0, V12.0.140" or "release 13.0, V13.0.0" |
| 44 | + for line in result.stdout.split("\n"): |
| 45 | + if "release" in line.lower(): |
| 46 | + match = re.search(r"release (\d+)\.", line) |
| 47 | + if match: |
| 48 | + return int(match.group(1)) |
| 49 | + except Exception: |
| 50 | + pass |
| 51 | + |
| 52 | + return None |
| 53 | + |
| 54 | + |
24 | 55 | # Required and optional dependencies ############################################################### |
25 | 56 | required_deps = [ |
26 | 57 | # Common |
|
43 | 74 | optional_deps = { |
44 | 75 | "onnx": [ |
45 | 76 | "cppimport", |
46 | | - "cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'", |
47 | 77 | "ml_dtypes", # for bfloat16 conversion |
48 | 78 | "onnx-graphsurgeon", |
49 | 79 | "onnx~=1.19.0", |
|
93 | 123 | "sphinx-rtd-theme~=3.0.0", # 3.0 does not show version, which we want as Linux & Windows have separate releases |
94 | 124 | "sphinx-togglebutton>=0.3.2", |
95 | 125 | ], |
96 | | - # build/packaging tools |
97 | | - "dev-build": [ |
98 | | - "cython", |
99 | | - "setuptools>=80", |
100 | | - "setuptools-scm>=8", |
101 | | - ], |
102 | 126 | } |
103 | 127 |
|
| 128 | +# Select the appropriate cupy package based on the detected CUDA version or fallback to cupy-cuda12x |
| 129 | +cuda_version = get_cuda_major_version() |
| 130 | + |
| 131 | +if cuda_version is None: |
| 132 | + # Default to CUDA 13 if detection fails |
| 133 | + cuda_version = 13 |
| 134 | + |
| 135 | +optional_deps["onnx"].append( |
| 136 | + f"cupy-cuda{cuda_version}x ; platform_machine != 'aarch64' and platform_system != 'Darwin'" |
| 137 | +) |
| 138 | + |
104 | 139 | # create "compound" optional dependencies |
105 | 140 | optional_deps["all"] = [ |
106 | 141 | deps for k in optional_deps if not k.startswith("dev") for deps in optional_deps[k] |
|
0 commit comments