4 changes: 3 additions & 1 deletion .github/workflows/verify_extension_build.yml
@@ -33,4 +33,6 @@ jobs:

- name: Test extension build via import
run: |
pytest tests/import_test.py -k test_import
pytest \
tests/import_test.py::test_extension_built \
tests/import_test.py::test_torch_extension_built
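
The workflow now selects the two new checks by explicit node ID rather than the old "-k test_import" substring filter, which would not have matched either new test name. A minimal sketch for mirroring the same invocation locally through pytest's programmatic entry point (assumes the repository root as the working directory):

# Local sketch of the CI step above; pytest.main takes the same arguments
# as the command line and returns the exit code.
import sys

import pytest

sys.exit(
    pytest.main(
        [
            "tests/import_test.py::test_extension_built",
            "tests/import_test.py::test_torch_extension_built",
        ]
    )
)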
19 changes: 11 additions & 8 deletions openequivariance/__init__.py
@@ -2,14 +2,20 @@
import sys
import torch
import numpy as np

try:
import openequivariance.extlib
except Exception as e:
raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}")
from pathlib import Path
from importlib.metadata import version

import openequivariance.extlib

from openequivariance.extlib import (
LINKED_LIBPYTHON,
LINKED_LIBPYTHON_ERROR,
BUILT_EXTENSION,
BUILT_EXTENSION_ERROR,
TORCH_COMPILE,
TORCH_COMPILE_ERROR,
)

from openequivariance.implementations.e3nn_lite import (
TPProblem,
Irrep,
@@ -63,9 +69,6 @@ def torch_ext_so_path():
]
)

LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON
LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR

__all__ = [
"TPProblem",
"Irreps",
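
With the re-exports above, callers can inspect build status from the package root without touching openequivariance.extlib directly. A hedged sketch of downstream use (the flag names come from this diff; the handling shown is illustrative, not part of the package):

import openequivariance as oeq

if not oeq.BUILT_EXTENSION:
    # The captured build failure, or a note that no build was attempted.
    print("extension unavailable:", oeq.BUILT_EXTENSION_ERROR)
elif not oeq.TORCH_COMPILE:
    # Pybind11 fallback path: JITScript, compile fullgraph, and export will fail.
    print("integrated wrapper unavailable:", oeq.TORCH_COMPILE_ERROR)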
193 changes: 98 additions & 95 deletions openequivariance/extlib/__init__.py
@@ -5,21 +5,25 @@
import sysconfig
from pathlib import Path

global torch
import torch

from openequivariance.benchmark.logging_utils import getLogger

oeq_root = str(Path(__file__).parent.parent)

build_ext = True
TORCH_COMPILE = True
TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip
torch_module, generic_module = None, None
postprocess_kernel = lambda kernel: kernel # noqa : E731
BUILT_EXTENSION = False
BUILT_EXTENSION_ERROR = None

TORCH_COMPILE = False
TORCH_COMPILE_ERROR = None

LINKED_LIBPYTHON = False
LINKED_LIBPYTHON_ERROR = None

torch_module, generic_module = None, None
postprocess_kernel = lambda kernel: kernel # noqa : E731


try:
python_lib_dir = sysconfig.get_config_var("LIBDIR")
major, minor = sys.version_info.major, sys.version_info.minor
@@ -33,114 +37,113 @@
)

LINKED_LIBPYTHON = True

except Exception as e:
LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}"

generic_module = None
if not build_ext:

if BUILT_EXTENSION:
import openequivariance.extlib.generic_module

generic_module = openequivariance.extlib.generic_module
elif TORCH_VERSION_CUDA_OR_HIP:
from torch.utils.cpp_extension import library_paths, include_paths

extra_cflags = ["-O3"]
generic_sources = ["generic_module.cpp"]
torch_sources = ["libtorch_tp_jit.cpp"]

include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])

if LINKED_LIBPYTHON:
extra_link_args.pop()
extra_link_args.extend(
[
f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
f"-L{python_lib_dir}",
f"-l{python_lib_name}",
],
)

if torch.version.cuda:
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])

try:
torch_libs, cuda_libs = library_paths("cuda")
elif torch.version.cuda or torch.version.hip:
try:
from torch.utils.cpp_extension import library_paths, include_paths

extra_cflags = ["-O3"]
generic_sources = ["generic_module.cpp"]
torch_sources = ["libtorch_tp_jit.cpp"]

include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])

if LINKED_LIBPYTHON:
extra_link_args.pop()
extra_link_args.extend(
[
f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
f"-L{python_lib_dir}",
f"-l{python_lib_name}",
],
)
if torch.version.cuda:
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])

try:
torch_libs, cuda_libs = library_paths("cuda")
extra_link_args.append("-Wl,-rpath," + torch_libs)
extra_link_args.append("-L" + cuda_libs)
if os.path.exists(cuda_libs + "/stubs"):
extra_link_args.append("-L" + cuda_libs + "/stubs")
except Exception as e:
getLogger().info(str(e))

extra_cflags.append("-DCUDA_BACKEND")
elif torch.version.hip:
extra_link_args.extend(["-lhiprtc"])
torch_libs = library_paths("cuda")[0]
extra_link_args.append("-Wl,-rpath," + torch_libs)
extra_link_args.append("-L" + cuda_libs)
if os.path.exists(cuda_libs + "/stubs"):
extra_link_args.append("-L" + cuda_libs + "/stubs")
except Exception as e:
getLogger().info(str(e))

extra_cflags.append("-DCUDA_BACKEND")
elif torch.version.hip:
extra_link_args.extend(["-lhiprtc"])
torch_libs = library_paths("cuda")[0]
extra_link_args.append("-Wl,-rpath," + torch_libs)

def postprocess(kernel):
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
return kernel

postprocess_kernel = postprocess

extra_cflags.append("-DHIP_BACKEND")

generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths(
"cuda"
)

torch_compile_exception = None
with warnings.catch_warnings():
warnings.simplefilter("ignore")

try:
torch_module = torch.utils.cpp_extension.load(
"libtorch_tp_jit",
torch_sources,
def postprocess(kernel):
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
return kernel

postprocess_kernel = postprocess

extra_cflags.append("-DHIP_BACKEND")

generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
include_dirs = [
oeq_root + "/extension/" + d for d in include_dirs
] + include_paths("cuda")

with warnings.catch_warnings():
warnings.simplefilter("ignore")

try:
torch_module = torch.utils.cpp_extension.load(
"libtorch_tp_jit",
torch_sources,
extra_cflags=extra_cflags,
extra_include_paths=include_dirs,
extra_ldflags=extra_link_args,
)
torch.ops.load_library(torch_module.__file__)
TORCH_COMPILE = True
except Exception as e:
# If compiling torch fails (e.g. low gcc version), we should fall back to the
# version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
TORCH_COMPILE_ERROR = e

generic_module = torch.utils.cpp_extension.load(
"generic_module",
generic_sources,
extra_cflags=extra_cflags,
extra_include_paths=include_dirs,
extra_ldflags=extra_link_args,
)
torch.ops.load_library(torch_module.__file__)
except Exception as e:
# If compiling torch fails (e.g. low gcc version), we should fall back to the
# version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
TORCH_COMPILE = False
torch_compile_exception = e

generic_module = torch.utils.cpp_extension.load(
"generic_module",
generic_sources,
extra_cflags=extra_cflags,
extra_include_paths=include_dirs,
extra_ldflags=extra_link_args,
)
if "generic_module" not in sys.modules:
sys.modules["generic_module"] = generic_module
if "generic_module" not in sys.modules:
sys.modules["generic_module"] = generic_module

if not TORCH_COMPILE:
warnings.warn(
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
+ f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
)
if not TORCH_COMPILE:
warnings.warn(
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
+ f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}"
)
BUILT_EXTENSION = True
except Exception as e:
BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}"
else:
TORCH_COMPILE = False
BUILT_EXTENSION_ERROR = "OpenEquivariance extension build not attempted"


def _raise_import_error_helper(import_target: str):
if not TORCH_VERSION_CUDA_OR_HIP:
raise ImportError(
f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false"
)
if not BUILT_EXTENSION:
raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}")


if TORCH_VERSION_CUDA_OR_HIP:
if BUILT_EXTENSION:
from generic_module import (
JITTPImpl,
JITConvImpl,
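
The restructuring above follows one pattern throughout: each capability flag starts as False with a companion *_ERROR of None, the build work runs inside a try block that sets the flag on success and stores the exception otherwise, and later imports fail loudly with the stored reason. A distilled, self-contained sketch of that pattern (the flag names match the diff; the build body is a placeholder, not the library's actual build steps):

import torch

TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip

BUILT_EXTENSION = False
BUILT_EXTENSION_ERROR = None

if TORCH_VERSION_CUDA_OR_HIP:
    try:
        ...  # compile and load the extension here
        BUILT_EXTENSION = True
    except Exception as e:
        BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}"
else:
    BUILT_EXTENSION_ERROR = "OpenEquivariance extension build not attempted"


def _raise_import_error_helper(import_target: str) -> None:
    # Surfaces the captured failure as an ImportError with the stored reason.
    if not BUILT_EXTENSION:
        raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}")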
14 changes: 14 additions & 0 deletions tests/import_test.py
@@ -7,3 +7,17 @@ def test_import():
assert openequivariance.__version__ is not None
assert openequivariance.__version__ != "0.0.0"
assert openequivariance.__version__ == version("openequivariance")


def test_extension_built():
from openequivariance import BUILT_EXTENSION, BUILT_EXTENSION_ERROR

assert BUILT_EXTENSION_ERROR is None
assert BUILT_EXTENSION


def test_torch_extension_built():
from openequivariance import TORCH_COMPILE, TORCH_COMPILE_ERROR

assert TORCH_COMPILE_ERROR is None
assert TORCH_COMPILE
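
Asserting the *_ERROR value before the boolean flag means a failing run prints the captured build or compile error in the assertion output rather than a bare False. A hypothetical companion test in the same shape (not part of this PR) could cover the libpython link status that __init__.py also re-exports:

def test_linked_libpython():
    # Hypothetical sketch mirroring the two tests above; names come from the diff.
    from openequivariance import LINKED_LIBPYTHON, LINKED_LIBPYTHON_ERROR

    assert LINKED_LIBPYTHON_ERROR is None
    assert LINKED_LIBPYTHON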