|
1 | 1 | --- a/setup.py |
2 | 2 | +++ b/setup.py |
3 | | -@@ -7,13 +7,9 @@ |
| 3 | +@@ -4,14 +4,20 @@ |
| 4 | + import glob |
| 5 | + import os |
| 6 | + import shutil |
| 7 | ++import sys |
4 | 8 | from os import path |
5 | 9 | from setuptools import find_packages, setup |
6 | 10 | from typing import List |
|
10 | 14 |
|
11 | 15 | -torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] |
12 | 16 | -assert torch_ver >= [1, 8], "Requires PyTorch >= 1.8" |
| 17 | ++ |
| 18 | ++ |
| 19 | ++def _get_torch_version(): |
| 20 | ++ import torch |
| 21 | ++ return [int(x) for x in torch.__version__.split(".")[:2]] |
| 22 | ++ |
| 23 | ++ |
| 24 | ++def _get_build_ext_class(): |
| 25 | ++ try: |
| 26 | ++ from torch.utils.cpp_extension import BuildExtension |
| 27 | ++ return BuildExtension |
| 28 | ++ except ImportError: |
| 29 | ++ return type('DummyBuildExt', (), {'run': lambda self: None}) |
13 | 30 |
|
14 | 31 | - |
15 | 32 | def get_version(): |
16 | | - init_py_path = path.join(path.abspath(path.dirname(__file__)), "detectron2", "__init__.py") |
17 | | - init_py = open(init_py_path, "r").readlines() |
18 | | -@@ -38,6 +34,15 @@ |
| 33 | +@@ -38,6 +44,21 @@ def get_version(): |
19 | 34 |
|
20 | 35 |
|
21 | 36 | def get_extensions(): |
| 37 | ++ # Skip extension building during build requirements detection |
| 38 | ++ # This is called by setuptools even when just determining build requirements |
| 39 | ++ if any(arg in sys.argv for arg in ['egg_info', 'dist_info', '--help-commands', '--name', '--version']): |
| 40 | ++ return [] |
| 41 | ++ |
| 42 | ++ # Import torch only when actually building extensions |
22 | 43 | + try: |
23 | 44 | + import torch |
24 | 45 | + from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension |
25 | | -+ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] |
| 46 | ++ torch_ver = _get_torch_version() |
26 | 47 | + assert torch_ver >= [1, 8], "Requires PyTorch >= 1.8" |
27 | 48 | + except ImportError: |
28 | | -+ # Return empty extensions list if torch is not available during build requirements detection |
29 | 49 | + return [] |
30 | | -+ |
| 50 | ++ except Exception: |
| 51 | ++ return [] |
| 52 | ++ |
31 | 53 | this_dir = path.dirname(path.abspath(__file__)) |
32 | 54 | extensions_dir = path.join(this_dir, "detectron2", "layers", "csrc") |
33 | 55 |
|
34 | | -@@ -204,5 +209,5 @@ |
| 56 | +@@ -204,9 +225,6 @@ setup( |
35 | 57 | ], |
36 | 58 | }, |
37 | 59 | ext_modules=get_extensions(), |
38 | 60 | - cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, |
39 | | -+ cmdclass={"build_ext": lambda dist: __import__('torch.utils.cpp_extension', fromlist=['BuildExtension']).BuildExtension(dist)}, |
| 61 | ++ cmdclass={"build_ext": _get_build_ext_class()}, |
40 | 62 | ) |
| 63 | +- |
| 64 | +- |
| 65 | +-def _get_build_ext_class(): |
| 66 | +- try: |
| 67 | +- from torch.utils.cpp_extension import BuildExtension |
| 68 | +- return BuildExtension |
| 69 | +- except ImportError: |
| 70 | +- return type('DummyBuildExt', (), {'run': lambda self: None}) |
0 commit comments