
Commit bcb7fdb

asglover and Austin Glover authored
stronger criteria to attempt compilation (#166)
* stronger criteria to attempt compilation
* integrating PR feedback
* make TORCH_COMPILE False by default, remove duplicate setting of generic_module. remove constants for _compile_torch_extension() as it's referenced only once. torch_compile_exception -> TORCH_COMPILE_ERROR for consistency.
* format
* remove lru cache
* bad commit
* test extension built
* revert intentional mistake
* go back to simple checks for attempting build.
* add tests for the torch specific extension
* remove trailing whitespace

---------

Co-authored-by: Austin Glover <[email protected]>
1 parent 339d8ea commit bcb7fdb

4 files changed: +126 −104 lines

.github/workflows/verify_extension_build.yml

Lines changed: 3 additions & 1 deletion
@@ -33,4 +33,6 @@ jobs:
 
       - name: Test extension build via import
        run: |
-          pytest tests/import_test.py -k test_import
+          pytest \
+            tests/import_test.py::test_extension_built \
+            tests/import_test.py::test_torch_extension_built
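For reference, a minimal sketch of invoking the same two checks outside CI (assumes pytest and openequivariance are installed and the working directory is the repository root; this script is illustrative, not part of the commit):

# Run the two import tests that the updated workflow step selects.
import sys
import pytest

exit_code = pytest.main(
    [
        "tests/import_test.py::test_extension_built",
        "tests/import_test.py::test_torch_extension_built",
    ]
)
sys.exit(exit_code)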

openequivariance/__init__.py

Lines changed: 11 additions & 8 deletions
@@ -2,14 +2,20 @@
 import sys
 import torch
 import numpy as np
-
-try:
-    import openequivariance.extlib
-except Exception as e:
-    raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}")
 from pathlib import Path
 from importlib.metadata import version
 
+import openequivariance.extlib
+
+from openequivariance.extlib import (
+    LINKED_LIBPYTHON,
+    LINKED_LIBPYTHON_ERROR,
+    BUILT_EXTENSION,
+    BUILT_EXTENSION_ERROR,
+    TORCH_COMPILE,
+    TORCH_COMPILE_ERROR,
+)
+
 from openequivariance.implementations.e3nn_lite import (
     TPProblem,
     Irrep,
@@ -63,9 +69,6 @@ def torch_ext_so_path():
     ]
 )
 
-LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON
-LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR
-
 __all__ = [
     "TPProblem",
     "Irreps",

openequivariance/extlib/__init__.py

Lines changed: 98 additions & 95 deletions
@@ -5,21 +5,25 @@
 import sysconfig
 from pathlib import Path
 
-global torch
 import torch
 
 from openequivariance.benchmark.logging_utils import getLogger
 
 oeq_root = str(Path(__file__).parent.parent)
 
-build_ext = True
-TORCH_COMPILE = True
-TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip
-torch_module, generic_module = None, None
-postprocess_kernel = lambda kernel: kernel  # noqa : E731
+BUILT_EXTENSION = False
+BUILT_EXTENSION_ERROR = None
+
+TORCH_COMPILE = False
+TORCH_COMPILE_ERROR = None
 
 LINKED_LIBPYTHON = False
 LINKED_LIBPYTHON_ERROR = None
+
+torch_module, generic_module = None, None
+postprocess_kernel = lambda kernel: kernel  # noqa : E731
+
+
 try:
     python_lib_dir = sysconfig.get_config_var("LIBDIR")
     major, minor = sys.version_info.major, sys.version_info.minor
@@ -33,114 +37,113 @@
     )
 
     LINKED_LIBPYTHON = True
-
 except Exception as e:
     LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}"
 
-generic_module = None
-if not build_ext:
+
+if BUILT_EXTENSION:
     import openequivariance.extlib.generic_module
 
     generic_module = openequivariance.extlib.generic_module
-elif TORCH_VERSION_CUDA_OR_HIP:
-    from torch.utils.cpp_extension import library_paths, include_paths
-
-    extra_cflags = ["-O3"]
-    generic_sources = ["generic_module.cpp"]
-    torch_sources = ["libtorch_tp_jit.cpp"]
-
-    include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])
-
-    if LINKED_LIBPYTHON:
-        extra_link_args.pop()
-        extra_link_args.extend(
-            [
-                f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
-                f"-L{python_lib_dir}",
-                f"-l{python_lib_name}",
-            ],
-        )
-
-    if torch.version.cuda:
-        extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
-
-        try:
-            torch_libs, cuda_libs = library_paths("cuda")
+elif torch.version.cuda or torch.version.hip:
+    try:
+        from torch.utils.cpp_extension import library_paths, include_paths
+
+        extra_cflags = ["-O3"]
+        generic_sources = ["generic_module.cpp"]
+        torch_sources = ["libtorch_tp_jit.cpp"]
+
+        include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])
+
+        if LINKED_LIBPYTHON:
+            extra_link_args.pop()
+            extra_link_args.extend(
+                [
+                    f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
+                    f"-L{python_lib_dir}",
+                    f"-l{python_lib_name}",
+                ],
+            )
+        if torch.version.cuda:
+            extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
+
+            try:
+                torch_libs, cuda_libs = library_paths("cuda")
+                extra_link_args.append("-Wl,-rpath," + torch_libs)
+                extra_link_args.append("-L" + cuda_libs)
+                if os.path.exists(cuda_libs + "/stubs"):
+                    extra_link_args.append("-L" + cuda_libs + "/stubs")
+            except Exception as e:
+                getLogger().info(str(e))
+
+            extra_cflags.append("-DCUDA_BACKEND")
+        elif torch.version.hip:
+            extra_link_args.extend(["-lhiprtc"])
+            torch_libs = library_paths("cuda")[0]
             extra_link_args.append("-Wl,-rpath," + torch_libs)
-            extra_link_args.append("-L" + cuda_libs)
-            if os.path.exists(cuda_libs + "/stubs"):
-                extra_link_args.append("-L" + cuda_libs + "/stubs")
-        except Exception as e:
-            getLogger().info(str(e))
-
-        extra_cflags.append("-DCUDA_BACKEND")
-    elif torch.version.hip:
-        extra_link_args.extend(["-lhiprtc"])
-        torch_libs = library_paths("cuda")[0]
-        extra_link_args.append("-Wl,-rpath," + torch_libs)
-
-    def postprocess(kernel):
-        kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
-        kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
-        kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
-        return kernel
-
-    postprocess_kernel = postprocess
-
-    extra_cflags.append("-DHIP_BACKEND")
-
-    generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
-    torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
-    include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths(
-        "cuda"
-    )
 
-    torch_compile_exception = None
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-
-        try:
-            torch_module = torch.utils.cpp_extension.load(
-                "libtorch_tp_jit",
-                torch_sources,
+            def postprocess(kernel):
+                kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
+                kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
+                kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
+                return kernel
+
+            postprocess_kernel = postprocess
+
+            extra_cflags.append("-DHIP_BACKEND")
+
+        generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
+        torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
+        include_dirs = [
+            oeq_root + "/extension/" + d for d in include_dirs
+        ] + include_paths("cuda")
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+
+            try:
+                torch_module = torch.utils.cpp_extension.load(
+                    "libtorch_tp_jit",
+                    torch_sources,
+                    extra_cflags=extra_cflags,
+                    extra_include_paths=include_dirs,
+                    extra_ldflags=extra_link_args,
+                )
+                torch.ops.load_library(torch_module.__file__)
+                TORCH_COMPILE = True
+            except Exception as e:
+                # If compiling torch fails (e.g. low gcc version), we should fall back to the
+                # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
+                TORCH_COMPILE_ERROR = e
+
+            generic_module = torch.utils.cpp_extension.load(
+                "generic_module",
+                generic_sources,
                 extra_cflags=extra_cflags,
                 extra_include_paths=include_dirs,
                 extra_ldflags=extra_link_args,
             )
-            torch.ops.load_library(torch_module.__file__)
-        except Exception as e:
-            # If compiling torch fails (e.g. low gcc version), we should fall back to the
-            # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
-            TORCH_COMPILE = False
-            torch_compile_exception = e
-
-        generic_module = torch.utils.cpp_extension.load(
-            "generic_module",
-            generic_sources,
-            extra_cflags=extra_cflags,
-            extra_include_paths=include_dirs,
-            extra_ldflags=extra_link_args,
-        )
-    if "generic_module" not in sys.modules:
-        sys.modules["generic_module"] = generic_module
+        if "generic_module" not in sys.modules:
+            sys.modules["generic_module"] = generic_module
 
-    if not TORCH_COMPILE:
-        warnings.warn(
-            "Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
-            + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
-        )
+        if not TORCH_COMPILE:
+            warnings.warn(
+                "Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
+                + f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}"
+            )
+        BUILT_EXTENSION = True
+    except Exception as e:
+        BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}"
 else:
-    TORCH_COMPILE = False
+    BUILT_EXTENSION_ERROR = "OpenEquivariance extension build not attempted"
 
 
 def _raise_import_error_helper(import_target: str):
-    if not TORCH_VERSION_CUDA_OR_HIP:
-        raise ImportError(
-            f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false"
-        )
+    if not BUILT_EXTENSION:
+        raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}")
 
 
-if TORCH_VERSION_CUDA_OR_HIP:
+if BUILT_EXTENSION:
     from generic_module import (
         JITTPImpl,
         JITConvImpl,
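Distilled, the new build logic wraps the whole attempt in one outer try block, records errors in module-level flags instead of raising, and flips each flag to True only after its step succeeds. A minimal standalone sketch of that pattern (build_torch_wrapper and build_generic_module are hypothetical stand-ins for the torch.utils.cpp_extension.load calls above):

def attempt_build(gpu_backend_available, build_torch_wrapper, build_generic_module):
    # Status flags mirror the module-level ones above: False/None until proven otherwise.
    status = {
        "BUILT_EXTENSION": False,
        "BUILT_EXTENSION_ERROR": None,
        "TORCH_COMPILE": False,
        "TORCH_COMPILE_ERROR": None,
    }
    if not gpu_backend_available:
        status["BUILT_EXTENSION_ERROR"] = "OpenEquivariance extension build not attempted"
        return status
    try:
        try:
            build_torch_wrapper()  # preferred: PyTorch-traceable extension
            status["TORCH_COMPILE"] = True
        except Exception as e:
            status["TORCH_COMPILE_ERROR"] = e  # record the reason, then fall back
        build_generic_module()  # Pybind11 fallback
        status["BUILT_EXTENSION"] = True
    except Exception as e:
        status["BUILT_EXTENSION_ERROR"] = f"Error building OpenEquivariance Extension: {e}"
    return status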

tests/import_test.py

Lines changed: 14 additions & 0 deletions
@@ -7,3 +7,17 @@ def test_import():
     assert openequivariance.__version__ is not None
     assert openequivariance.__version__ != "0.0.0"
     assert openequivariance.__version__ == version("openequivariance")
+
+
+def test_extension_built():
+    from openequivariance import BUILT_EXTENSION, BUILT_EXTENSION_ERROR
+
+    assert BUILT_EXTENSION_ERROR is None
+    assert BUILT_EXTENSION
+
+
+def test_torch_extension_built():
+    from openequivariance import TORCH_COMPILE, TORCH_COMPILE_ERROR
+
+    assert TORCH_COMPILE_ERROR is None
+    assert TORCH_COMPILE
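These tests assert that both build paths succeeded on the CI runner. Code running where only the fallback built can use the same flags as a guard; a hypothetical example motivated by the warning in extlib/__init__.py (maybe_compile is not part of the repo):

import torch
from openequivariance import TORCH_COMPILE, TORCH_COMPILE_ERROR


def maybe_compile(model):
    # Only request fullgraph compilation when the traceable torch extension loaded.
    if TORCH_COMPILE:
        return torch.compile(model, fullgraph=True)
    print(f"Skipping fullgraph compile: {TORCH_COMPILE_ERROR}")
    return model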
