diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 49e1cb04..fa6c7363 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install sphinx furo + pip install -r docs/requirements.txt - name: Build website run: | sphinx-build -M dirhtml docs docs/_build diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 588eebce..2f351dba 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -6,7 +6,7 @@ on: # ref: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ jobs: - build: + build-oeq: name: Build distribution runs-on: ubuntu-latest steps: @@ -16,21 +16,20 @@ jobs: with: python-version: '3.10' - name: install dependencies, then build source tarball - run: | + run: | + cd openequivariance python3 -m pip install build --user python3 -m build --sdist - name: store the distribution packages uses: actions/upload-artifact@v4 with: name: python-package-distributions - path: dist/ + path: openequivariance/dist/ pypi-publish: name: Upload release to PyPI runs-on: ubuntu-latest - # build task to be completed first - needs: build - # Specifying a GitHub environment is optional, but strongly encouraged + needs: build-oeq environment: name: pypi url: https://pypi.org/p/openequivariance @@ -42,6 +41,47 @@ jobs: uses: actions/download-artifact@v4 with: name: python-package-distributions - path: dist/ + path: openequivariance/dist/ + - name: publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + # ------------------------------------ + + build-oeq-extjax: + name: Build distribution + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: install dependencies, then build source tarball + run: | + cd openequivariance_extjax + python3 -m pip install build --user + python3 -m build --sdist + - name: store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: openequivariance_extjax/dist/ + + pypi-publish-extjax: + name: Upload release to PyPI + runs-on: ubuntu-latest + needs: build-oeq-extjax + environment: + name: pypi + url: https://pypi.org/p/openequivariance_extjax + permissions: + # IMPORTANT: this permission is mandatory for Trusted Publishing + id-token: write + steps: + - name: download the distributions + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: openequivariance_extjax/dist/ - name: publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index da9fde80..0e04348c 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,4 +1,6 @@ numpy==2.2.5 torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 -ninja==1.11.1.4 \ No newline at end of file +ninja==1.11.1.4 +nanobind==2.10.2 +scikit-build-core==0.11.6 \ No newline at end of file diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 50b496b8..39888491 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml 
@@ -1,4 +1,4 @@ -name: OEQ CUDA C++ Extension Build Verification +name: OEQ C++ Extension Build Verification on: push: @@ -29,10 +29,14 @@ jobs: sudo apt-get update sudo apt install nvidia-cuda-toolkit pip install -r .github/workflows/requirements_cuda_ci.txt - pip install -e . + pip install -e "./openequivariance" - - name: Test extension build via import + - name: Test CUDA extension build via import run: | pytest \ tests/import_test.py::test_extension_built \ - tests/import_test.py::test_torch_extension_built \ No newline at end of file + tests/import_test.py::test_torch_extension_built + + - name: Test JAX extension build + run: | + XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5c878b1a..64fcaa8d 100644 --- a/.gitignore +++ b/.gitignore @@ -38,7 +38,6 @@ triton_autotuning paper_benchmarks paper_benchmarks_v2 paper_benchmarks_v3 -openequivariance/extlib/*.so get_node.sh *.egg-info \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index a120656b..e9bd29e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ ## Latest Changes +### v0.5.0 (2025-12-25) +JAX support is now available in +OpenEquivariance for BOTH NVIDIA and +AMD GPUs! See the +[documentation](https://passionlab.github.io/OpenEquivariance/) +and README.md for instructions on installation +and usage. + +Minor changes: +- Defer error reporting when CUDA is not available + to the first library usage in code, not library load. + ### v0.4.1 (2025-09-04) Minor update, fixes a bug loading JIT-compiled modules with PyTorch 2.9. diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 7eaa4d91..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,10 +0,0 @@ -include openequivariance/extlib/*.so -include openequivariance/extlib/*.empty - -include openequivariance/templates/*.cuh -include openequivariance/templates/*.jinja - -include openequivariance/extension/* -include openequivariance/extension/convolution/* -include openequivariance/extension/tensorproducts/* -include openequivariance/extension/util/* \ No newline at end of file diff --git a/README.md b/README.md index 288e1daf..c68e4fa9 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # OpenEquivariance -[![OEQ CUDA C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml) +[![OEQ C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml) [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) -[[Examples]](#show-me-some-examples) +[[PyTorch Examples]](#pytorch-examples) +[[JAX Examples]](#jax-examples) [[Citation and Acknowledgements]](#citation-and-acknowledgements) OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product, @@ -12,8 +13,8 @@ that [e3nn](https://e3nn.org/) supports commonly found in graph neural networks (e.g. [Nequip](https://github.com/mir-group/nequip) or [MACE](https://github.com/ACEsuit/mace)). 
To get -started, ensure that you have GCC 9+ on your system -and install our package via +started with PyTorch, ensure that you have PyTorch +and GCC 9+ available before installing our package via ```bash pip install openequivariance @@ -29,11 +30,26 @@ computation and memory consumption significantly. For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). -📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and -Computational Discrete Algorithms (Proceedings Track)! Catch the talk in -Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025). +⭐️ **JAX**: Our latest update brings +support for JAX. For NVIDIA GPUs, +install it (after installing JAX) +with the following two commands strictly in order: -## Show me some examples +``` bash +pip install openequivariance[jax] +pip install openequivariance_extjax --no-build-isolation +``` + +For AMD GPUs: +``` bash +pip install openequivariance[jax] +JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation +``` + +See the section below for example usage and +our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details. + +## PyTorch Examples Here's a CG tensor product implemented by e3nn: ```python @@ -127,6 +143,48 @@ print(torch.norm(Z)) `deterministic=False`, the `sender` and `receiver` indices can have arbitrary order. +## JAX Examples +After installation, use the library +as follows. Set `OEQ_NOTORCH=1` +in your environment to avoid the PyTorch import in +the regular `openequivariance` package. +```python +import jax +import os + +os.environ["OEQ_NOTORCH"] = "1" +import openequivariance as oeq + +seed = 42 +key = jax.random.PRNGKey(seed) + +batch_size = 1000 +X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e") +problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False) + + +node_ct, nonzero_ct = 3, 4 +edge_index = jax.numpy.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jax.numpy.int32, # NOTE: This is int32, not int64 +) + +X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) +Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) +W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + +tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) +Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] +) +print(jax.numpy.linalg.norm(Z)) +``` + ## Citation and Acknowledgements If you find this code useful, please cite our paper: diff --git a/docs/api.rst b/docs/api.rst index 3fac1764..c21b918f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -8,7 +8,7 @@ OpenEquivariance API OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces ``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses the CG tensor product with a subsequent graph convolution. Initializing either class triggers -JIT compilation of a custom kernel, which can take a few seconds. +JIT compilation of a custom kernel, which can take a few seconds.
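To make the JIT-compilation step concrete, here is a minimal PyTorch sketch; the irreps, instruction list, and batch size mirror the README example above and are illustrative only, and constructing ``TensorProduct`` is the call that compiles the kernel.

```python
# Minimal sketch (illustrative sizes, mirroring the README example).
# Constructing TensorProduct is what triggers JIT compilation of the kernel.
import torch
import openequivariance as oeq

X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)],
                        shared_weights=False, internal_weights=False)

tp = oeq.TensorProduct(problem)  # kernel is JIT-compiled here, once

batch_size = 1000
X = torch.rand(batch_size, X_ir.dim, device="cuda")
Y = torch.rand(batch_size, Y_ir.dim, device="cuda")
W = torch.rand(batch_size, problem.weight_numel, device="cuda")

Z = tp.forward(X, Y, W)  # shape [batch_size, Z_ir.dim]
print(torch.norm(Z))
```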
Both classes require a configuration object specified by :py:class:`openequivariance.TPProblem`, which has a constructor @@ -17,6 +17,9 @@ We recommend reading the `e3nn documentation ` trying our code. OpenEquivariance cannot accelerate all tensor products; see :doc:`this page ` for a list of supported configurations. +PyTorch API +------------------------ + .. autoclass:: openequivariance.TensorProduct :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to :undoc-members: @@ -27,14 +30,39 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see :undoc-members: :exclude-members: name -.. autoclass:: openequivariance.TPProblem - :members: - :undoc-members: - .. autofunction:: openequivariance.torch_to_oeq_dtype .. autofunction:: openequivariance.torch_ext_so_path +JAX API +------------------------ +The JAX API consists of ``TensorProduct`` and ``TensorProductConv`` +classes that behave identically to their PyTorch counterparts. These classes +do not conform exactly to the e3nn-jax API, but perform the same computation. + +If you plan to use ``oeq.jax`` without PyTorch installed, +you need to set ``OEQ_NOTORCH=1`` in your local environment (within Python, +``os.environ["OEQ_NOTORCH"] = "1"``). For the moment, we require this to avoid +breaking the PyTorch version of OpenEquivariance. + + +.. autoclass:: openequivariance.jax.TensorProduct + :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn + :undoc-members: + :exclude-members: + +.. autoclass:: openequivariance.jax.TensorProductConv + :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn + :undoc-members: + :exclude-members: + +Common API +--------------------- + +.. autoclass:: openequivariance.TPProblem + :members: + :undoc-members: + API Identical to e3nn --------------------- diff --git a/docs/conf.py b/docs/conf.py index 17707552..540cf37e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,11 +28,17 @@ html_theme = "furo" # html_static_path = ["_static"] -extensions = [ - "sphinx.ext.autodoc", +extensions = ["sphinx.ext.autodoc", "sphinx_inline_tabs"] + +sys.path.insert(0, str(Path("../openequivariance").resolve())) + +autodoc_mock_imports = [ + "torch", + "jax", + "openequivariance._torch.extlib", + "openequivariance.jax.extlib", + "openequivariance_extjax", + "jinja2", + "numpy", ] - -sys.path.insert(0, str(Path("..").resolve())) - -autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"] autodoc_typehints = "description" diff --git a/docs/installation.rst b/docs/installation.rst index 9c3588cb..5ade5c0c 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -8,94 +8,139 @@ Installation You need the following to install OpenEquivariance: - A Linux system equipped with an NVIDIA / AMD graphics card. -- PyTorch >= 2.4 (>= 2.8 for AOTI and export). +- Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX > 0.5.0 + with CUDA or ROCm support. - GCC 9+ and the CUDA / HIP toolkit. The command ``c++ --version`` should return >= 9.0; see below for details on setting an alternate compiler. -Installation is one easy command, followed by import verification: +.. tab:: PyTorch -.. code-block:: bash + Installation is one easy command, followed by import verification: - pip install openequivariance - python -c "import openequivariance" + .. code-block:: bash -The second line triggers a build of the C++ extension we use to compile -kernels, which can take a couple of minutes.
Subsequent imports are -much faster since this extension is cached. + pip install openequivariance + python -c "import openequivariance" -To get the nightly build, run + The second line triggers a build of the C++ extension we use to compile + kernels, which can take a couple of minutes. Subsequent imports are + much faster since this extension is cached. -.. code-block:: bash + To support ``torch.compile``, ``torch.export``, and + JITScript, OpenEquivariance needs to compile a C++ extension + tightly integrated with PyTorch. If you see a warning that + this extension could not be compiled, first check: - pip install git+https://github.com/PASSIONLab/OpenEquivariance + .. code-block:: bash + c++ --version + + To build the extension with an alternate compiler, set the + ``CC`` and ``CXX`` + environment variables and retry the import: -Compiling the Integrated PyTorch Extension ------------------------------------------- -To support ``torch.compile``, ``torch.export``, and -JITScript, OpenEquivariance needs to compile a C++ extension -tightly integrated with PyTorch. If you see a warning that -this extension could not be compiled, first check: + .. code-block:: bash -.. code-block:: bash + export CC=/path/to/your/gcc + export CXX=/path/to/your/g++ + python -c "import openequivariance" - c++ --version - -To build the extension with an alternate compiler, set the -``CC`` and ``CXX`` -environment variable and retry the import: + + These configuration steps are required only ONCE after + installation (or upgrade) with pip. + + +.. tab:: JAX NVIDIA GPUs + + First ensure the appropriate JAX Python + package is installed in your environment. Then + run the following two commands strictly in order: + + .. code-block:: bash + + pip install openequivariance[jax] + pip install openequivariance_extjax --no-build-isolation + +.. tab:: JAX AMD GPUs + + Ensure that JAX is installed correctly with ROCm support + before running, in order, + + .. code-block:: bash + + pip install openequivariance[jax] + JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation + + +.. tab:: Nightly (PT) + + .. code-block:: bash + + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance" + + +.. tab:: Nightly (JAX) + + .. code-block:: bash + + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax]" + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax" --no-build-isolation + + # Use the command below for JAX+AMD + # JAX_HIP=1 pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax" --no-build-isolation + + +If you're using JAX, set the environment variable +``OEQ_NOTORCH=1`` to avoid a PyTorch import: .. code-block:: bash - export CCC=/path/to/your/gcc - export CXX=/path/to/your/g++ - python -c "import openequivariance" + export OEQ_NOTORCH=1 + python -c "import openequivariance.jax" -These configuration steps are required only ONCE after -installation (or upgrade) with pip. Configurations on Major Platforms --------------------------------- OpenEquivariance has been tested on both supercomputers and lab clusters. -Here are some tested environment configuration files. If use OpenEquivariance -on a widely-used platform, send us a pull request to add your configuration! +Here are some tested environment configuration files. If you use OpenEquivariance +on a major cluster, send us a pull request to add your configuration!
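When debugging or contributing a platform configuration, the build-status flags exposed at the top of the ``openequivariance`` package give a quick read on whether the integrated PyTorch extension compiled. The names below come from ``openequivariance/__init__.py`` in this release; treat them as internal diagnostics rather than a stable API.

```python
# Rough diagnostic sketch; these flags are internal and may change.
import openequivariance as oeq

print("C++ extension built:       ", oeq.BUILT_EXTENSION)
print("Torch-integrated extension:", oeq.TORCH_COMPILE)
if not oeq.TORCH_COMPILE:
    print("Reason:", oeq.TORCH_COMPILE_ERROR)
```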
-NERSC Perlmutter (NVIDIA A100) -"""""""""""""""""""""""""""""" -.. code-block:: bash - :caption: env.sh (last updated June 2025) +.. tab:: NERSC Perlmutter (NVIDIA A100) - module load gcc - module load conda + .. code-block:: bash + :caption: env.sh (last updated June 2025) - # Deactivate any base environments - for i in $(seq ${CONDA_SHLVL}); do - conda deactivate - done + module load gcc + module load conda - conda activate + # Deactivate any base environments + for i in $(seq ${CONDA_SHLVL}); do + conda deactivate + done + conda activate -OLCF Frontier (AMD MI250x) -"""""""""""""""""""""""""" -You need to install a HIP-enabled verison of PyTorch to use our package. -To do this, follow the steps `here `_. +.. tab:: OLCF Frontier (AMD MI250x) -.. code-block:: bash - :caption: env.sh (last updated June 2025) + You need to install a HIP-enabled verison of PyTorch to use our package. + Follow the steps `here `_. + + + .. code-block:: bash + :caption: env.sh (last updated June 2025) - module load PrgEnv-gnu/8.6.0 - module load miniforge3/23.11.0-0 - module load rocm/6.4.0 - module load craype-accel-amd-gfx90a + module load PrgEnv-gnu/8.6.0 + module load miniforge3/23.11.0-0 + module load rocm/6.4.0 + module load craype-accel-amd-gfx90a - for i in $(seq ${CONDA_SHLVL}); do - conda deactivate - done + for i in $(seq ${CONDA_SHLVL}); do + conda deactivate + done - conda activate - export CC=cc - export CXX=CC \ No newline at end of file + conda activate + export CC=cc + export CXX=CC \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..1cc76517 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +furo +sphinx +sphinx-inline-tabs +sphinx-autobuild \ No newline at end of file diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst index 7f5ff78c..bcc11955 100644 --- a/docs/supported_ops.rst +++ b/docs/supported_ops.rst @@ -117,7 +117,7 @@ toplevel. You can use our implementation by running .. code-block:: - from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction + from openequivariance._torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction Some Github users report weak performance for the symmetric contraction backward pass; your mileage may vary. diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst index 7bc11b26..f602ab44 100644 --- a/docs/tests_and_benchmarks.rst +++ b/docs/tests_and_benchmarks.rst @@ -12,27 +12,41 @@ these; we provide instructions below. We recommend you clone our repository and use an editable install to run tests and benchmarks. -You can still test our code with a non-editable install; just -download the test folder and install the non-editable package and the dependencies with: +Correctness +------------------------------ +To set up an editable install and run our tests, use the following code: -.. code-block:: bash +.. tab:: PyTorch - pip install openequivariance[dev,bench] + .. code-block:: bash -Correctness ------------------------------- -To set up the editable install and run the entire testsuite, use: + git clone https://github.com/PASSIONLab/OpenEquivariance + cd OpenEquivariance + pip install -e "./openequivariance[dev]" + pytest tests/ -.. code-block:: bash +.. 
tab:: JAX - git clone https://github.com/PASSIONLab/OpenEquivariance - cd OpenEquivariance - pip install -e .[dev] - pytest + Note: To test correctness in JAX, we still require + an installation of PyTorch and e3nn in your environment. + + .. code-block:: bash + + git clone https://github.com/PASSIONLab/OpenEquivariance + cd OpenEquivariance + + pip install "./openequivariance[jax]" + pip install "./openequivariance[dev]" + pip install "./openequivariance_extjax" --no-build-isolation + + pytest --jax tests/example_test.py + pytest --jax tests/batch_test.py + pytest --jax tests/conv_test.py Browse the ``tests`` directory to run specific components. + Replicating our Benchmarks ------------------------------ We conducted our benchmarks on an NVIDIA A100-SXM-80GB GPU at Lawrence Berkeley National Laboratory. @@ -79,12 +93,13 @@ OpenEquivariance exhibits up to 2x speedup over FlashTP's fused kernels. List of GPUs Tested -------------------------------- -OpenEquivariance has been tested successfully the following GPUs. Submit a pull +OpenEquivariance runs successfully on the following GPUs. Submit a pull request if you'd like to add your own! - NVIDIA V100 (V. Bharadwaj, LBNL Einsteinium, June 2025) - NVIDIA A100-SXM-40GB and A100-SXM-80GB (A. Glover, NERSC Perlmutter, June 2025) - NVIDIA A5000 (V. Bharadwaj, UCB SLICE, June 2025) +- NVIDIA T4 (V. Bharadwaj, Google Colab, Jan 2026) - NVIDIA H100 (L. Larsen, P1 DTU HPC, June 2025) - AMD MI250x (V. Bharadwaj, OLCF Frontier, June 2025) - AMD MI300x (V. Bharadwaj, AMD Cloud, February 2025) \ No newline at end of file diff --git a/openequivariance/LICENSE b/openequivariance/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/openequivariance/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/openequivariance/MANIFEST.in b/openequivariance/MANIFEST.in new file mode 100644 index 00000000..1d4a8cce --- /dev/null +++ b/openequivariance/MANIFEST.in @@ -0,0 +1,4 @@ +include openequivariance/templates/*.cuh +include openequivariance/templates/*.jinja + +include openequivariance/extension/* \ No newline at end of file diff --git a/openequivariance/README.md b/openequivariance/README.md new file mode 120000 index 00000000..32d46ee8 --- /dev/null +++ b/openequivariance/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py deleted file mode 100644 index 5a8ab812..00000000 --- a/openequivariance/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -# ruff: noqa: F401 -import sys -import torch -import numpy as np -from pathlib import Path -from importlib.metadata import version - -import openequivariance.extlib - -from openequivariance.extlib import ( - LINKED_LIBPYTHON, - LINKED_LIBPYTHON_ERROR, - BUILT_EXTENSION, - BUILT_EXTENSION_ERROR, - TORCH_COMPILE, - TORCH_COMPILE_ERROR, ) - -from openequivariance.implementations.e3nn_lite import ( - TPProblem, - Irrep, - Irreps, - _MulIr, - Instruction, ) -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.implementations.convolution.TensorProductConv import ( - TensorProductConv, ) -from openequivariance.implementations.utils import torch_to_oeq_dtype - -__version__ = None -try: - __version__ = version("openequivariance") -except Exception as e: - print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr) - - -def _check_package_editable(): - import json - from importlib.metadata import Distribution - - direct_url = 
Distribution.from_name("openequivariance").read_text("direct_url.json") - return json.loads(direct_url).get("dir_info", {}).get("editable", False) - - -_editable_install_output_path = Path(__file__).parent.parent / "outputs" - - -def torch_ext_so_path(): - """ - :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance - from the PyTorch C++ Interface. - """ - return openequivariance.extlib.torch_module.__file__ - - -torch.serialization.add_safe_globals( - [ - TensorProduct, - TensorProductConv, - TPProblem, - Irrep, - Irreps, - _MulIr, - Instruction, - np.float32, - np.float64, - ] -) - -__all__ = [ - "TPProblem", - "Irreps", - "TensorProduct", - "TensorProductConv", - "torch_to_oeq_dtype", - "_check_package_editable", - "torch_ext_so_path", -] diff --git a/openequivariance/implementations/LoopUnrollTP.py b/openequivariance/implementations/LoopUnrollTP.py deleted file mode 100644 index ed6a5395..00000000 --- a/openequivariance/implementations/LoopUnrollTP.py +++ /dev/null @@ -1,311 +0,0 @@ -import numpy as np - -import openequivariance.extlib as extlib -from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.implementations.ComputationSchedule import ComputationSchedule - -from openequivariance.implementations.dtype_enum import dtype_to_enum -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.utils import ( - filter_and_analyze_problem, - count_cg_non_zero, -) - -logger = getLogger() - - -class LoopUnrollTP(TensorProductBase): - def __init__(self, config, torch_op=True): - super().__init__(config, torch_op=torch_op) - - env = get_jinja_environment() - template = env.get_template("loop_unroll_batch.cuh") - dp = extlib.DeviceProp(0) - - analysis = filter_and_analyze_problem(config) - self.is_uvw = analysis["is_uvw"] - - def generate_forward_schedule(warps_per_block): - self.forward_schedule = ComputationSchedule( - self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount * 4, - direction="forward", - irrep_dtype=config.irrep_dtype, - weight_dtype=config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - ) - - def generate_backward_schedule(warps_per_block): - self.backward_schedule = ComputationSchedule( - self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount * 4, - direction="backward", - irrep_dtype=config.irrep_dtype, - weight_dtype=config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - ) - - def generate_double_backward_schedule(warps_per_block): - self.double_backward_schedule = ComputationSchedule( - self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount, - direction="double_backward", - irrep_dtype=config.irrep_dtype, - weight_dtype=config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - schedule_type=3, - ) - - scheduler_generators = [ - generate_forward_schedule, - generate_backward_schedule, - generate_double_backward_schedule, - ] - - for generate_schedule in scheduler_generators: - warp_count = 8 - while warp_count > 0: - try: - generate_schedule(warp_count) - break - except Exception: - warp_count -= 2 - if warp_count == 0: - raise 
RuntimeError( - "Tensor product schedule generation failed, shared memory inadequate!" - ) - - self.jit_kernel = extlib.postprocess_kernel( - template.render( - forward_schedule=self.forward_schedule, - backward_schedule=self.backward_schedule, - double_backward_schedule=self.double_backward_schedule, - ) - ) - - # with open("scratch.txt", "w") as f: - # f.write(self.jit_kernel) - - internal_cls = None - if self.torch_op and extlib.TORCH_COMPILE: - global torch - import torch - - internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct - else: - internal_cls = extlib.JITTPImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - { - "L1_dim": self.L1.dim, - "L2_dim": self.L2.dim, - "L3_dim": self.L3.dim, - "weight_numel": self.config.weight_numel, - "shared_weights": int(self.config.shared_weights), - "opt_level": 3, - "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], - "weight_dtype": dtype_to_enum[self.config.weight_dtype], - }, - ) - logger.info("Kernel compiled!") - logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - - def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) - - def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) - - @classmethod - def register_torch_fakes(cls): - global torch - import torch - - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct") - class TorchJITProduct: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - self.kernel_plaintext = state["kernel_plaintext"] - self.fwd_config = state["fwd_config"] - self.bwd_config = state["bwd_config"] - self.dbl_bwd_config = state["dbl_bwd_config"] - self.kernel_dims = state["kernel_dims"] - - def exec_tensor_product_rawptr(*args, **kwargs): - pass - - def backward_rawptr(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") - def fake_forward(jit, L1_in, L2_in, W): - L3_dim = None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - else: - L3_dim = jit.L3_dim - - return L1_in.new_empty(L1_in.shape[0], L3_dim) - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") - def fake_backward(jit, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward - - def setup_context(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, 
ctx.L1_in, ctx.L2_in, ctx.weights, grad_output - ) - return None, L1_grad, L2_grad, W_grad - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, E, F, G): - result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G - ) - return None, result[0], result[1], result[2], result[3] - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch - - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 - ) - - @staticmethod - def name(): - return "LoopUnrollTP" - - def calculate_flops_forward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_forward(batch_size) - else: - tpp = self.config - flop_count = { - "CG_decomposition": 0, - "linear_combination": 0, - "outer_products": 0, - } - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - flop_count["linear_combination"] += ( - (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 - ) - - flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= ( - batch_size # Weights do not require FMA here - ) - flop_count["total"] = sum(flop_count.values()) - return flop_count - - def calculate_flops_backward(self, batch_size: int) -> dict: - if self.is_uvw: - return super().calculate_flops_backward(batch_size) - else: - tpp = self.config - flop_count = {"backward": 0} - for ins in tpp.instructions: - l1, l2, l3 = ( - tpp.irreps_in1[ins.i_in1].ir.l, - tpp.irreps_in2[ins.i_in2].ir.l, - tpp.irreps_out[ins.i_out].ir.l, - ) - flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( - ins.path_shape[0] * ins.path_shape[1] - ) - - flop_count["backward"] *= 9 * batch_size - flop_count["total"] = sum(flop_count.values()) - return flop_count - - -if extlib.TORCH_COMPILE: - LoopUnrollTP.register_torch_fakes() - LoopUnrollTP.register_autograd() - LoopUnrollTP.register_autocast() diff --git a/openequivariance/implementations/TensorProduct.py b/openequivariance/implementations/TensorProduct.py deleted file mode 100644 index 54fc4307..00000000 --- a/openequivariance/implementations/TensorProduct.py +++ /dev/null @@ -1,200 +0,0 @@ -from openequivariance.implementations.LoopUnrollTP import LoopUnrollTP -from openequivariance import TPProblem -from openequivariance import extlib -import torch -import typing -from openequivariance.implementations.utils import torch_to_oeq_dtype - - -class TensorProduct(torch.nn.Module, LoopUnrollTP): - r""" - Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward, - backward, and double-backward passes using JIT-compiled kernels. Initialization - fails if: - - * There are no visible GPUs. - * The provided tensor product specification is unsupported. 
- - :param problem: Specification of the tensor product. - :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``. - """ - - def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): - torch.nn.Module.__init__(self) - self.input_args = { - "problem": problem, - "torch_op": torch_op, - "use_opaque": use_opaque, - } - self._init_class() - - def _init_class(self): - LoopUnrollTP.__init__( - self, self.input_args["problem"], self.input_args["torch_op"] - ) - self.weight_numel = self.input_args["problem"].weight_numel - self._setup_notorchbind() - if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: - self.forward = self.forward_opaque - - def to(self, *args, **kwargs): - r""" - See `torch.nn.Module.to() `_. - """ - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( - *args, **kwargs - ) - - if dtype is not None: - updated_problem = self.input_args["problem"].clone() - updated_problem.irrep_dtype = torch_to_oeq_dtype(dtype) - updated_problem.weight_dtype = torch_to_oeq_dtype(dtype) - self.input_args["problem"] = updated_problem - self._init_class() - - torch.nn.Module.to(self, *args, **kwargs) - return self - - def __getstate__(self): - return self.input_args - - def __setstate__(self, state): - torch.nn.Module.__init__(self) - self.input_args = state - self._init_class() - - @staticmethod - def name(): - return LoopUnrollTP.name() - - def forward( - self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor - ) -> torch.Tensor: - r""" - Computes :math:`W (x \otimes_{\textrm{CG}} y)`, identical to - ``o3.TensorProduct.forward``. - - :param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim()]``, datatype - ``problem.irrep_dtype``. - :param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim()]``, datatype - ``problem.irrep_dtype``. - :param W: Tensor of datatype ``problem.weight_dtype`` and shape - - * ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False`` - * ``[problem.weight_numel]`` if ``problem.shared_weights=True`` - - :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. - """ - return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W) - - def _setup_notorchbind(self): - """ - In case TorchBind is not available (e.g. for torch.compile below PT2.8, etc.), - set up operations using custom ops. 
- """ - - @torch.library.custom_op( - f"openequivariance::tp_forward{self.tp_id}", - mutates_args=(), - device_types="cuda", - ) - def forward( - L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor - ) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = ( - L1_in.contiguous(), - L2_in.contiguous(), - weights.contiguous(), - ) - L3_out = torch.empty( - (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device=L1_in.device - ) - self.forward_raw( - L1_in_c.shape[0], - L1_in_c.data_ptr(), - L2_in_c.data_ptr(), - L3_out.data_ptr(), - weights_c.data_ptr(), - ) - return L3_out - - @forward.register_fake - def _(L1_in, L2_in, weights): - return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - - self.forward_opaque = forward - - # ---------------- Backward pass ----------------- - @torch.library.custom_op( - f"openequivariance::tp_grad_helper{self.tp_id}", - mutates_args=(), - device_types="cuda", - ) - def backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - L3_grad: torch.Tensor, - ) -> typing.List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = torch.zeros_like(L2_in) - weights_grad = torch.empty_like(weights) - - if self.config.shared_weights: - weights_grad[:] = 0.0 - - self.backward_raw( - L1_in.shape[0], - L1_in.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), - L2_grad.data_ptr(), - weights.contiguous().data_ptr(), - weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - ) - - return [L1_grad, L2_grad, weights_grad] - - @backward_helper.register_fake - def _(L1_in, L2_in, weights, L3_grad): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - weights.new_empty(*weights.shape), - ] - - def setup_context(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output) - return result[0], result[1], result[2] - - self.forward_opaque.register_autograd(backward, setup_context=setup_context) - - def setup_context_double_backward(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, grad_output): - A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights - E, F, G = grad_output[0], grad_output[1], grad_output[2] - - op1 = backward_helper(E, F, D, C) - op2 = backward_helper(A, B, G, C) - op3 = forward(E, B, D) - op4 = backward_helper(E, B, D, C) - op5 = backward_helper(A, F, D, C) - op6 = forward(A, F, D) - op7 = forward(A, B, G) - - return ( - op1[0] + op2[0], - op1[1] + op2[1], - (op4[2] + op5[2]), - (op3 + op6 + op7), - ) - - backward_helper.register_autograd( - double_backward, setup_context=setup_context_double_backward - ) diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py deleted file mode 100644 index a5e46ce3..00000000 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ /dev/null @@ -1,453 +0,0 @@ -import numpy as np - -from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase -from openequivariance.implementations.ComputationSchedule import ( - ComputationSchedule, - SMEMCapacityException, -) - -from openequivariance.implementations.dtype_enum import ( - dtype_to_enum, - enum_to_torch_dtype, -) -from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance import extlib -from openequivariance.extlib import JITConvImpl, 
postprocess_kernel, DeviceProp - -from openequivariance.implementations.utils import filter_and_analyze_problem -from openequivariance.benchmark.logging_utils import getLogger - -logger = getLogger() - - -class LoopUnrollConv(ConvolutionBase): - def __init__( - self, - config, - *, - idx_dtype: type[np.generic] = np.int64, - torch_op: bool = False, - deterministic: bool = False, - kahan: bool = False, - ): - super().__init__( - config, idx_dtype=idx_dtype, torch_op=torch_op, deterministic=deterministic - ) - - if kahan: - assert deterministic - - env = get_jinja_environment() - template = env.get_template("loop_unroll_conv_atomic.cuh") - dp = DeviceProp(0) - - analysis = filter_and_analyze_problem(config) - self.is_uvw = analysis["is_uvw"] - - if config.shared_weights: - assert not deterministic, ( - "Deterministic convolution does not support shared weights" - ) - - forward_schedule_type = 3 - backward_schedule_type = 2 - if deterministic: - backward_schedule_type = 3 - template = env.get_template("loop_unroll_conv_det.cuh") - - def generate_forward_schedule(warps_per_block): - self.forward_schedule = ComputationSchedule( - self.config, - smem_limit=dp.maxSharedMemPerBlock // 4 * 3, - warps_per_block=warps_per_block, - block_count=dp.multiprocessorCount, - direction="forward", - irrep_dtype=config.irrep_dtype, - weight_dtype=config.weight_dtype, - schedule_type=forward_schedule_type, - warp_size=dp.warpsize, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - kahan=kahan, - ) - - def generate_backward_schedule(warps_per_block): - self.backward_schedule = ComputationSchedule( - self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - block_count=dp.multiprocessorCount * 2, - direction="backward", - irrep_dtype=config.irrep_dtype, - weight_dtype=config.weight_dtype, - schedule_type=backward_schedule_type, - warp_size=dp.warpsize, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - kahan=kahan, - ) - - def generate_double_backward_schedule(warps_per_block): - self.double_backward_schedule = ComputationSchedule( - self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount, - direction="double_backward", - irrep_dtype=config.irrep_dtype, - weight_dtype=config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - schedule_type=3, - kahan=kahan, - ) - - scheduler_generators = [ - generate_forward_schedule, - generate_backward_schedule, - generate_double_backward_schedule, - ] - - for generate_schedule in scheduler_generators: - warp_count = 6 - while warp_count > 0: - try: - generate_schedule(warp_count) - break - except SMEMCapacityException: - warp_count -= 1 - if warp_count == 0: - raise SMEMCapacityException( - "Tensor product schedule generation failed, shared memory inadequate!" 
- ) - - if not deterministic: - for segment in self.forward_schedule.segments: - for key in segment.L3Map.storeback_procedure: - segment.L3Map.storeback_procedure[key] = "atomic_accumulate" - - for segment in self.backward_schedule.segments: - for key in segment.L1Map.storeback_procedure: - segment.L1Map.storeback_procedure[key] = "atomic_accumulate" - - for segment in self.double_backward_schedule.segments: - for key in segment.L1Map.storeback_procedure: - segment.L1Map.storeback_procedure[key] = "atomic_accumulate" - - idx_type_map = {np.int32: "int", np.int64: "long"} - - self.forward_workspace_offset = None - self.backward_workspace_offset = None - self.double_backwardB_offset = None - - workspace_size = 1 - if deterministic: - destination_index_bytes = 32 # Add extra to account for padding - workspace_size = max( - ( - self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize - + destination_index_bytes - ) - * self.forward_schedule.total_warps, - ( - self.backward_schedule.L1.dim - * np.dtype(config.irrep_dtype).itemsize - + destination_index_bytes - ) - * self.backward_schedule.total_warps, - ( - self.double_backward_schedule.L1.dim - * np.dtype(config.irrep_dtype).itemsize - + destination_index_bytes - ) - * self.double_backward_schedule.total_warps, - ) - - self.forward_workspace_offset = ( - self.forward_schedule.L3.dim - * np.dtype(config.irrep_dtype).itemsize - * self.forward_schedule.total_warps - ) - self.backward_workspace_offset = ( - self.backward_schedule.L1.dim - * np.dtype(config.irrep_dtype).itemsize - * self.backward_schedule.total_warps - ) - self.double_backwardB_offset = ( - self.double_backward_schedule.L1.dim - * np.dtype(config.irrep_dtype).itemsize - * self.double_backward_schedule.total_warps - ) - - self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8 - self.backward_workspace_offset = ( - (self.backward_workspace_offset + 7) // 8 * 8 - ) - self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8 - - self.allocate_workspace(workspace_size) - - self.jit_kernel = template.render( - forward_schedule=self.forward_schedule, - backward_schedule=self.backward_schedule, - double_backward_schedule=self.double_backward_schedule, - idx_type=idx_type_map[idx_dtype], - forward_workspace_offset=self.forward_workspace_offset, - backward_workspace_offset=self.backward_workspace_offset, - double_backwardB_offset=self.double_backwardB_offset, - ) - self.jit_kernel = postprocess_kernel(self.jit_kernel) - - if self.torch_op and extlib.TORCH_COMPILE: - global torch - import torch - - internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv - else: - internal_cls = JITConvImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - { - "L1_dim": self.L1.dim, - "L2_dim": self.L2.dim, - "L3_dim": self.L3.dim, - "weight_numel": self.config.weight_numel, - "workspace_size": self.workspace_size, - "opt_level": 3, - "shared_weights": int(config.shared_weights), - "deterministic": int(self.deterministic), - "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], - "weight_dtype": dtype_to_enum[self.config.weight_dtype], - "idx_dtype": dtype_to_enum[self.idx_dtype], - }, - ) - logger.info("Kernel compiled!") - - # with open("scratch.txt", "w") as f: - # f.write(self.jit_kernel) - - def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - 
return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) - - def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) - - @staticmethod - def name(): - return "LoopUnrollConv" - - @classmethod - def register_torch_fakes(cls): - global torch - import torch - - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv") - class TorchJITConv: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = state - - def exec_conv_rawptrs(*args, **kwargs): - pass - - def backward_rawptrs(*args, **kwargs): - pass - - def double_backward_rawptrs(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") - def fake_forward( - jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm - ): - L3_dim, irrep_dtype = None, None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"] - else: - L3_dim = jit.L3_dim - irrep_dtype = jit.irrep_dtype - - return torch.empty( - L1_in.shape[0], - L3_dim, - device="cuda", - dtype=enum_to_torch_dtype[irrep_dtype], - ) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") - def fake_backward( - jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm - ): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") - def fake_double_backward( - jit, - L1_in, - L2_in, - W, - L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - rows, - cols, - workspace_buffer, - transpose_perm=None, - ): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - W.new_empty(*W.shape), - L3_grad.new_empty(*L3_grad.shape), - ] - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward - double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward - - def setup_context(ctx, inputs, output): - ( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return None, L1_grad, L2_grad, W_grad, None, None, None, None - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs 
- ctx.inputs = inputs - - def double_backward(ctx, E, F, G): - result = double_backward_op( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - E, - F, - G, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return ( - None, - result[0], - result[1], - result[2], - result[3], - None, - None, - None, - None, - ) - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch - - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 - ) - - -if extlib.TORCH_COMPILE: - LoopUnrollConv.register_torch_fakes() - LoopUnrollConv.register_autograd() - LoopUnrollConv.register_autocast() diff --git a/openequivariance/implementations/dtype_enum.py b/openequivariance/implementations/dtype_enum.py deleted file mode 100644 index 292b7e4f..00000000 --- a/openequivariance/implementations/dtype_enum.py +++ /dev/null @@ -1,47 +0,0 @@ -from enum import IntEnum -from types import MappingProxyType -import numpy as np -import torch - - -class DTypeEnum(IntEnum): - FLOAT32 = 1 - FLOAT64 = 2 - INT32 = 3 - INT64 = 4 - UINT8 = 5 - - -dtype_to_enum = MappingProxyType( - { - torch.float32: DTypeEnum.FLOAT32, - torch.float64: DTypeEnum.FLOAT64, - torch.int32: DTypeEnum.INT32, - torch.int64: DTypeEnum.INT64, - torch.uint8: DTypeEnum.UINT8, - # torch - np.float32: DTypeEnum.FLOAT32, - np.float64: DTypeEnum.FLOAT64, - np.int32: DTypeEnum.INT32, - np.int64: DTypeEnum.INT64, - np.uint8: DTypeEnum.UINT8, - # numpy generic - np.dtype(np.float32): DTypeEnum.FLOAT32, - np.dtype(np.float64): DTypeEnum.FLOAT64, - np.dtype(np.int32): DTypeEnum.INT32, - np.dtype(np.int64): DTypeEnum.INT64, - np.dtype(np.uint8): DTypeEnum.UINT8, - # numpy dtype - } -) - - -enum_to_torch_dtype = MappingProxyType( - { - DTypeEnum.FLOAT32: torch.float32, - DTypeEnum.FLOAT64: torch.float64, - DTypeEnum.INT32: torch.int32, - DTypeEnum.INT64: torch.int64, - DTypeEnum.UINT8: torch.uint8, - } -) diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/implementations/symmetric_contraction/__init__.py deleted file mode 100644 index 75ac6cc8..00000000 --- a/openequivariance/implementations/symmetric_contraction/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from openequivariance.implementations.symmetric_contraction.symmetric_contraction import ( - SymmetricContraction, -) - -__all__ = ["SymmetricContraction"] diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py new file mode 100644 index 00000000..a842a7c9 --- /dev/null +++ b/openequivariance/openequivariance/__init__.py @@ -0,0 +1,105 @@ +# ruff: noqa: F401 +import sys +import os +import numpy as np + +from pathlib import Path +from importlib.metadata import version + +from openequivariance.core.e3nn_lite import ( + TPProblem, + Irrep, + Irreps, + _MulIr, + Instruction, +) + +__version__ = None +try: + __version__ = version("openequivariance") +except Exception as e: + print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr) + + +def _check_package_editable(): + import json + from importlib.metadata import Distribution + + direct_url = 
Distribution.from_name("openequivariance").read_text("direct_url.json") + return json.loads(direct_url).get("dir_info", {}).get("editable", False) + + +_editable_install_output_path = Path(__file__).parent.parent.parent / "outputs" + +if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1": + import torch + + from openequivariance._torch.TensorProduct import TensorProduct + from openequivariance._torch.TensorProductConv import TensorProductConv + + from openequivariance._torch.extlib import ( + torch_ext_so_path as torch_ext_so_path_internal, + ) + from openequivariance.core.utils import torch_to_oeq_dtype + + torch.serialization.add_safe_globals( + [ + TensorProduct, + TensorProductConv, + TPProblem, + Irrep, + Irreps, + _MulIr, + Instruction, + np.float32, + np.float64, + ] + ) + + from openequivariance._torch.extlib import ( + LINKED_LIBPYTHON, + LINKED_LIBPYTHON_ERROR, + BUILT_EXTENSION, + BUILT_EXTENSION_ERROR, + TORCH_COMPILE, + TORCH_COMPILE_ERROR, + ) + + +def torch_ext_so_path(): + """ + :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance + from the PyTorch C++ Interface. + """ + try: + return torch_ext_so_path_internal() + except NameError: + return None + + +jax = None +try: + import openequivariance_extjax + import openequivariance.jax as jax +except Exception as e: + error = e + + class JAX_ERR: + def TensorProduct(*args, **kwargs): + raise error + + def TensorProductConv(*args, **kwargs): + raise error + + jax = JAX_ERR() + +__all__ = [ + "TPProblem", + "Irreps", + "TensorProduct", + "TensorProductConv", + "torch_to_oeq_dtype", + "_check_package_editable", + "torch_ext_so_path", + "jax", +] diff --git a/openequivariance/implementations/convolution/CUEConv.py b/openequivariance/openequivariance/_torch/CUEConv.py similarity index 95% rename from openequivariance/implementations/convolution/CUEConv.py rename to openequivariance/openequivariance/_torch/CUEConv.py index 9287abe8..8500e39c 100644 --- a/openequivariance/implementations/convolution/CUEConv.py +++ b/openequivariance/openequivariance/_torch/CUEConv.py @@ -2,8 +2,8 @@ import itertools from typing import Iterator -from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.implementations.convolution.ConvolutionBase import ( +from openequivariance._torch.CUETensorProduct import CUETensorProduct +from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) diff --git a/openequivariance/implementations/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py similarity index 97% rename from openequivariance/implementations/CUETensorProduct.py rename to openequivariance/openequivariance/_torch/CUETensorProduct.py index a7d027f4..33b8db12 100644 --- a/openequivariance/implementations/CUETensorProduct.py +++ b/openequivariance/openequivariance/_torch/CUETensorProduct.py @@ -4,15 +4,15 @@ import itertools from typing import Iterator -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger from openequivariance.benchmark.tpp_creation_utils import ( ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction, ) -from openequivariance.implementations.utils import count_cg_non_zero +from openequivariance.core.utils 
import count_cg_non_zero os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1" diff --git a/openequivariance/implementations/convolution/E3NNConv.py b/openequivariance/openequivariance/_torch/E3NNConv.py similarity index 86% rename from openequivariance/implementations/convolution/E3NNConv.py rename to openequivariance/openequivariance/_torch/E3NNConv.py index 00b0faa8..4cc20662 100644 --- a/openequivariance/implementations/convolution/E3NNConv.py +++ b/openequivariance/openequivariance/_torch/E3NNConv.py @@ -1,13 +1,14 @@ import numpy as np -from openequivariance.implementations.convolution.ConvolutionBase import ( +from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct +from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv -class E3NNConv(ConvolutionBase): +class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv): def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): assert torch_op super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op) @@ -37,7 +38,7 @@ def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): if config.irrep_dtype == np.float64: torch.set_default_dtype(torch.float32) # Reset to default - def forward(self, L1_in, L2_in, weights, rows, cols): + def forward(self, L1_in, L2_in, weights, rows, cols, transpose_perm=None): messages = self.reference_tp(L1_in[cols], L2_in, weights) return scatter_add_wrapper(messages, rows, L1_in.size(0)) diff --git a/openequivariance/implementations/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py similarity index 94% rename from openequivariance/implementations/E3NNTensorProduct.py rename to openequivariance/openequivariance/_torch/E3NNTensorProduct.py index 334ba65c..067a7e6b 100644 --- a/openequivariance/implementations/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -9,16 +9,17 @@ import pathlib import numpy as np -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") logger = getLogger() -class E3NNTensorProduct(TensorProductBase): +class E3NNTensorProduct(TensorProductBase, NumpyDoubleBackwardMixin): def __init__(self, config: TPProblem, torch_op=True): super().__init__(config, torch_op=torch_op) assert self.torch_op diff --git a/openequivariance/implementations/convolution/FlashTPConv.py b/openequivariance/openequivariance/_torch/FlashTPConv.py similarity index 87% rename from openequivariance/implementations/convolution/FlashTPConv.py rename to openequivariance/openequivariance/_torch/FlashTPConv.py index 0302ef9c..9ec5c409 100644 --- a/openequivariance/implementations/convolution/FlashTPConv.py +++ b/openequivariance/openequivariance/_torch/FlashTPConv.py @@ -4,8 +4,8 @@ import torch import numpy as np -from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase -from openequivariance.implementations.utils import oeq_to_torch_dtype +from 
openequivariance.core.ConvolutionBase import ConvolutionBase +from openequivariance.core.utils import oeq_to_torch_dtype class FlashTPConv(ConvolutionBase): diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py new file mode 100644 index 00000000..caf94268 --- /dev/null +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -0,0 +1,97 @@ +import torch + + +class NumpyDoubleBackwardMixin: + """ + Adds a Numpy double backward method to any TensorProduct + with the forward pass defined in PyTorch and the relevant + derivatives registered. + """ + + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad + ): + assert self.torch_op + + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") + weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") + out_torch = self.forward(in1_torch, in2_torch, weights_torch) + + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True, + ) + + a, b, c, d = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], + ) + + return ( + a.detach().cpu().numpy(), + b.detach().cpu().numpy(), + c.detach().cpu().numpy(), + d.detach().cpu().numpy(), + ) + + +class NumpyDoubleBackwardMixinConv: + """ + Similar, but for fused graph convolution. 
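+    The forward call additionally takes the graph structure (``rows``,
+    ``cols``, and the transpose permutation) alongside the inputs and
+    weights.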
+ """ + + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph + ): + assert self.torch_op + + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") + weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") + + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") + + out_torch = self.forward( + in1_torch, + in2_torch, + weights_torch, + torch_rows, + torch_cols, + torch_transpose_perm, + ) + + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True, + ) + + a, b, c, d = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], + ) + + return ( + a.detach().cpu().numpy(), + b.detach().cpu().numpy(), + c.detach().cpu().numpy(), + d.detach().cpu().numpy(), + ) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py new file mode 100644 index 00000000..05ea54b5 --- /dev/null +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -0,0 +1,448 @@ +from openequivariance.core.LoopUnrollTP import LoopUnrollTP +from openequivariance import TPProblem +from openequivariance._torch import extlib +import torch +import typing +from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance._torch.utils import reorder_torch +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin + +import numpy as np +from openequivariance._torch.extlib import DeviceBuffer + +logger = getLogger() + + +class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin): + r""" + Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward, + backward, and double-backward passes using JIT-compiled kernels. Initialization + fails if: + + * There are no visible GPUs. + * The provided tensor product specification is unsupported. + + :param problem: Specification of the tensor product. + :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``. 
+ """ + + def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): + torch.nn.Module.__init__(self) + self.input_args = { + "problem": problem, + "torch_op": torch_op, + "use_opaque": use_opaque, + } + self._init_class() + + def _init_class(self): + dp = extlib.DeviceProp(0) + LoopUnrollTP.__init__( + self, + self.input_args["problem"], + dp, + extlib.postprocess_kernel, + self.input_args["torch_op"], + ) + + internal_cls = None + if extlib.TORCH_COMPILE: + internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct + else: + internal_cls = extlib.JITTPImpl + + logger.info("Starting kernel compiler...") + self.internal = internal_cls( + self.jit_kernel, + vars(self.forward_schedule.launch_config), + vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), + self.kernelProp, + ) + logger.info("Kernel compiled!") + logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") + + self.weight_numel = self.input_args["problem"].weight_numel + self._setup_notorchbind() + if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: + self.forward = self.forward_opaque + + def to(self, *args, **kwargs): + r""" + See `torch.nn.Module.to() `_. + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + + if dtype is not None: + updated_problem = self.input_args["problem"].clone() + updated_problem.irrep_dtype = torch_to_oeq_dtype(dtype) + updated_problem.weight_dtype = torch_to_oeq_dtype(dtype) + self.input_args["problem"] = updated_problem + self._init_class() + + torch.nn.Module.to(self, *args, **kwargs) + return self + + def __getstate__(self): + return self.input_args + + def __setstate__(self, state): + torch.nn.Module.__init__(self) + self.input_args = state + self._init_class() + + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return reorder_torch( + self.forward_schedule, weights, "forward", not self.config.shared_weights + ) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return reorder_torch( + self.forward_schedule, weights, "backward", not self.config.shared_weights + ) + + def forward( + self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor + ) -> torch.Tensor: + r""" + Computes :math:`W (x \otimes_{\textrm{CG}} y)`, identical to + ``o3.TensorProduct.forward``. + + :param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim()]``, datatype + ``problem.irrep_dtype``. + :param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim()]``, datatype + ``problem.irrep_dtype``. + :param W: Tensor of datatype ``problem.weight_dtype`` and shape + + * ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False`` + * ``[problem.weight_numel]`` if ``problem.shared_weights=True`` + + :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. + """ + return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W) + + def _setup_notorchbind(self): + """ + In case TorchBind is not available (e.g. for torch.compile below PT2.8, etc.), + set up operations using custom ops. 
+ """ + + @torch.library.custom_op( + f"openequivariance::tp_forward{self.tp_id}", + mutates_args=(), + device_types="cuda", + ) + def forward( + L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor + ) -> torch.Tensor: + L1_in_c, L2_in_c, weights_c = ( + L1_in.contiguous(), + L2_in.contiguous(), + weights.contiguous(), + ) + L3_out = torch.empty( + (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device=L1_in.device + ) + self.forward_raw( + L1_in_c.shape[0], + L1_in_c.data_ptr(), + L2_in_c.data_ptr(), + L3_out.data_ptr(), + weights_c.data_ptr(), + ) + return L3_out + + @forward.register_fake + def _(L1_in, L2_in, weights): + return L1_in.new_empty(L1_in.shape[0], self.L3.dim) + + self.forward_opaque = forward + + # ---------------- Backward pass ----------------- + @torch.library.custom_op( + f"openequivariance::tp_grad_helper{self.tp_id}", + mutates_args=(), + device_types="cuda", + ) + def backward_helper( + L1_in: torch.Tensor, + L2_in: torch.Tensor, + weights: torch.Tensor, + L3_grad: torch.Tensor, + ) -> typing.List[torch.Tensor]: + L1_grad = torch.zeros_like(L1_in) + L2_grad = torch.zeros_like(L2_in) + weights_grad = torch.empty_like(weights) + + if self.config.shared_weights: + weights_grad[:] = 0.0 + + self.backward_raw( + L1_in.shape[0], + L1_in.contiguous().data_ptr(), + L1_grad.data_ptr(), + L2_in.contiguous().data_ptr(), + L2_grad.data_ptr(), + weights.contiguous().data_ptr(), + weights_grad.data_ptr(), + L3_grad.contiguous().data_ptr(), + ) + + return [L1_grad, L2_grad, weights_grad] + + @backward_helper.register_fake + def _(L1_in, L2_in, weights, L3_grad): + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + weights.new_empty(*weights.shape), + ] + + def setup_context(ctx, inputs, output): + ctx.L1_in, ctx.L2_in, ctx.weights = inputs + + def backward(ctx, grad_output): + result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output) + return result[0], result[1], result[2] + + self.forward_opaque.register_autograd(backward, setup_context=setup_context) + + def setup_context_double_backward(ctx, inputs, output): + ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + + def double_backward(ctx, grad_output): + A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights + E, F, G = grad_output[0], grad_output[1], grad_output[2] + + op1 = backward_helper(E, F, D, C) + op2 = backward_helper(A, B, G, C) + op3 = forward(E, B, D) + op4 = backward_helper(E, B, D, C) + op5 = backward_helper(A, F, D, C) + op6 = forward(A, F, D) + op7 = forward(A, B, G) + + return ( + op1[0] + op2[0], + op1[1] + op2[1], + (op4[2] + op5[2]), + (op3 + op6 + op7), + ) + + backward_helper.register_autograd( + double_backward, setup_context=setup_context_double_backward + ) + + @classmethod + def register_torch_fakes(cls): + @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct") + class TorchJITProduct: + def __init__( + self, + kernel_plaintext: str, + fwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], + kernel_dims: dict[str, int], + ) -> None: + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = ( + kernel_plaintext, + fwd_config, + bwd_config, + dbl_bwd_config, + kernel_dims, + ) + + @classmethod + def __obj_unflatten__(cls, flattened_product): + return cls(**dict(flattened_product)) + + def __len__(self): + return 0 + + def __setstate__(self, state): + self.kernel_plaintext = state["kernel_plaintext"] + self.fwd_config = 
state["fwd_config"] + self.bwd_config = state["bwd_config"] + self.dbl_bwd_config = state["dbl_bwd_config"] + self.kernel_dims = state["kernel_dims"] + + def exec_tensor_product_rawptr(*args, **kwargs): + pass + + def backward_rawptr(*args, **kwargs): + pass + + def L3_dim_getter(self): + return self.kernel_dims["L3_dim"] + + def irrep_dtype_getter(self): + return self.kernel_dims["irrep_dtype"] + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") + def fake_forward(jit, L1_in, L2_in, W): + L3_dim = None + if hasattr(jit, "wrapped_obj"): + L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] + else: + L3_dim = jit.L3_dim + + return L1_in.new_empty(L1_in.shape[0], L3_dim) + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") + def fake_backward(jit, L1_in, L2_in, W, L3_grad): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @classmethod + def register_autograd(cls): + backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward + + def setup_context(ctx, inputs, output): + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output + ) + return None, L1_grad, L2_grad, W_grad + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + + def double_backward(ctx, E, F, G): + result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G + ) + return None, result[0], result[1], result[2], result[3] + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + + @classmethod + def register_autocast(cls): + global torch + import torch + + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 + ) + + @staticmethod + def name(): + return "LoopUnrollTP" + + def forward_raw( + self, + batch: np.uint64, + L1_in: np.uint64, + L2_in: np.uint64, + L3_out: np.uint64, + weights: np.uint64, + ) -> None: + self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) + + def backward_raw( + self, + batch_size: np.uint64, + L1_in: np.uint64, + L1_grad: np.uint64, + L2_in: np.uint64, + L2_grad: np.uint64, + weights: np.uint64, + weights_grad: np.uint64, + L3_grad: np.uint64, + ): + self.internal.backward_rawptr( + batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad + ) + + def forward_cpu( + self, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_out: np.ndarray, + weights: np.ndarray, + ) -> None: + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + batch = L1_in.shape[0] + L1_d = DeviceBuffer(L1_in) + L2_d = DeviceBuffer(L2_in) + L3_d = DeviceBuffer(L3_out) + weights_d = DeviceBuffer(weights_chunked) + self.internal.exec_tensor_product_rawptr( + batch, + L1_d.data_ptr(), + L2_d.data_ptr(), + L3_d.data_ptr(), + weights_d.data_ptr(), + ) + L3_d.copy_to_host() + + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad + ) -> None: + weights_chunked = 
self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + batch = L1_in.shape[0] + L1_d, L2_d, L3_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(L3_grad), + ) + L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) + weights_d, weights_grad_d = ( + DeviceBuffer(weights_chunked), + DeviceBuffer(weights_grad), + ) + + self.internal.backward_rawptr( + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) + + L1_grad_d.copy_to_host() + L2_grad_d.copy_to_host() + weights_grad_d.copy_to_host() + + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, not self.config.shared_weights + ) + + +if extlib.TORCH_COMPILE: + TensorProduct.register_torch_fakes() + TensorProduct.register_autograd() + TensorProduct.register_autocast() diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py similarity index 57% rename from openequivariance/implementations/convolution/TensorProductConv.py rename to openequivariance/openequivariance/_torch/TensorProductConv.py index 7e860944..f30c943c 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -3,18 +3,32 @@ import numpy as np import torch -from openequivariance import extlib -from openequivariance.implementations.convolution.ConvolutionBase import ( +import openequivariance._torch.extlib as extlib +from openequivariance._torch.extlib import ( + JITConvImpl, + postprocess_kernel, + DeviceProp, +) + +from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) -from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv -from openequivariance.implementations.TensorProduct import TensorProduct +from openequivariance.core.LoopUnrollConv import LoopUnrollConv +from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem -from openequivariance.implementations.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance._torch.utils import enum_to_torch_dtype +from openequivariance._torch.utils import reorder_torch + +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv +from openequivariance._torch.extlib import DeviceBuffer + +logger = getLogger() -class TensorProductConv(torch.nn.Module, LoopUnrollConv): +class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixinConv): r""" Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`, :math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes @@ -58,14 +72,34 @@ def __init__( self._init_class() def _init_class(self): + dp = DeviceProp(0) LoopUnrollConv.__init__( self, self.input_args["problem"], + dp, + postprocess_kernel, idx_dtype=np.int64, torch_op=self.input_args["torch_op"], deterministic=self.input_args["deterministic"], kahan=self.input_args["kahan"], ) + + self.allocate_workspace(self.workspace_size) + if extlib.TORCH_COMPILE: + internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv + else: + internal_cls = JITConvImpl + + logger.info("Starting kernel compiler...") + self.internal = internal_cls( + self.jit_kernel, + 
vars(self.forward_schedule.launch_config), + vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), + self.kernel_prop, + ) + logger.info("Kernel compiled!") + self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel self._setup_notorchbind() @@ -151,9 +185,16 @@ def forward( sender_perm, ) - @staticmethod - def name(): - return LoopUnrollConv.name() + def allocate_workspace(self, size_bytes): + self.workspace_size = size_bytes + if self.torch_op: + self.workspace_buffer = torch.zeros( + size_bytes, dtype=torch.uint8, device="cuda" + ) + else: + self.workspace_buffer = extlib.DeviceBuffer(size_bytes) + self.workspace_ptr = self.workspace_buffer.data_ptr() + logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") def _setup_notorchbind(self): @torch.library.custom_op( @@ -379,6 +420,315 @@ def double_backward(ctx, grad_output): double_backward, setup_context=setup_context_double_backward ) + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return reorder_torch( + self.forward_schedule, weights, "forward", not self.config.shared_weights + ) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return reorder_torch( + self.forward_schedule, weights, "backward", not self.config.shared_weights + ) + + @staticmethod + def name(): + return "LoopUnrollConv" + + @classmethod + def register_torch_fakes(cls): + global torch + import torch + + @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv") + class TorchJITConv: + def __init__( + self, + kernel_plaintext: str, + fwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], + kernel_dims: dict[str, int], + ) -> None: + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = ( + kernel_plaintext, + fwd_config, + bwd_config, + dbl_bwd_config, + kernel_dims, + ) + + @classmethod + def __obj_unflatten__(cls, flattened_product): + return cls(**dict(flattened_product)) + + def __len__(self): + return 0 + + def __setstate__(self, state): + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = state + + def exec_conv_rawptrs(*args, **kwargs): + pass + + def backward_rawptrs(*args, **kwargs): + pass + + def double_backward_rawptrs(*args, **kwargs): + pass + + def L3_dim_getter(self): + return self.kernel_dims["L3_dim"] + + def irrep_dtype_getter(self): + return self.kernel_dims["irrep_dtype"] + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") + def fake_forward( + jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm + ): + L3_dim, irrep_dtype = None, None + if hasattr(jit, "wrapped_obj"): + L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] + irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"] + else: + L3_dim = jit.L3_dim + irrep_dtype = jit.irrep_dtype + + return torch.empty( + L1_in.shape[0], + L3_dim, + device="cuda", + dtype=enum_to_torch_dtype[irrep_dtype], + ) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") + def fake_backward( + jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm + ): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") + def fake_double_backward( + jit, + L1_in, + L2_in, + W, + L3_grad, + L1_dgrad, + L2_dgrad, + w_dgrad, + rows, + 
cols, + workspace_buffer, + transpose_perm=None, + ): + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + W.new_empty(*W.shape), + L3_grad.new_empty(*L3_grad.shape), + ] + + @classmethod + def register_autograd(cls): + backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward + double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward + + def setup_context(ctx, inputs, output): + ( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return None, L1_grad, L2_grad, W_grad, None, None, None, None + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + ctx.inputs = inputs + + def double_backward(ctx, E, F, G): + result = double_backward_op( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + E, + F, + G, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return ( + None, + result[0], + result[1], + result[2], + result[3], + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + + @classmethod + def register_autocast(cls): + global torch + import torch + + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 + ) + + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype + + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + L1_d, L2_d, weights_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(weights_chunked), + ) + L3_d = DeviceBuffer(L3_out) + + rows_d = DeviceBuffer(graph.rows) + cols_d = DeviceBuffer(graph.cols) + + self.internal.exec_conv_rawptrs( + L1_d.data_ptr(), + L2_d.data_ptr(), + weights_d.data_ptr(), + L3_d.data_ptr(), + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, + self.workspace_ptr, + ) + + L3_d.copy_to_host() + + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph + ): + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype + + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + L1_d = DeviceBuffer(L1_in) + L2_d = DeviceBuffer(L2_in) + weights_d = DeviceBuffer(weights_chunked) + L3_d = DeviceBuffer(L3_grad) + rows_d = DeviceBuffer(graph.rows) + cols_d = DeviceBuffer(graph.cols) + + L1_grad_d = DeviceBuffer(L1_grad) + L2_grad_d = DeviceBuffer(L2_grad) + weights_grad_d = DeviceBuffer(weights_grad) + + transpose_perm_d = None + transpose_perm_ptr = 0 + if self.deterministic: + transpose_perm_d = DeviceBuffer(graph.transpose_perm) + transpose_perm_ptr = 
transpose_perm_d.data_ptr() + + self.internal.backward_rawptrs( + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, + self.workspace_ptr, + transpose_perm_ptr, + ) + + L1_grad_d.copy_to_host() + L2_grad_d.copy_to_host() + weights_grad_d.copy_to_host() + + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, not self.config.shared_weights + ) + + return L1_grad, L2_grad, weights_grad + + +if extlib.TORCH_COMPILE: + TensorProductConv.register_torch_fakes() + TensorProductConv.register_autograd() + TensorProductConv.register_autocast() + # ================================================================== # Reference implementations for benchmarking diff --git a/openequivariance/extlib/.empty b/openequivariance/openequivariance/_torch/extlib/.empty similarity index 100% rename from openequivariance/extlib/.empty rename to openequivariance/openequivariance/_torch/extlib/.empty diff --git a/openequivariance/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py similarity index 96% rename from openequivariance/extlib/__init__.py rename to openequivariance/openequivariance/_torch/extlib/__init__.py index ac6f15a3..72440872 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -9,7 +9,7 @@ from openequivariance.benchmark.logging_utils import getLogger -oeq_root = str(Path(__file__).parent.parent) +oeq_root = str(Path(__file__).parent.parent.parent) BUILT_EXTENSION = False BUILT_EXTENSION_ERROR = None @@ -23,7 +23,6 @@ torch_module, generic_module = None, None postprocess_kernel = lambda kernel: kernel # noqa : E731 - try: python_lib_dir = sysconfig.get_config_var("LIBDIR") major, minor = sys.version_info.major, sys.version_info.minor @@ -42,9 +41,10 @@ if BUILT_EXTENSION: - import openequivariance.extlib.generic_module + import openequivariance._torch.extlib.generic_module + + generic_module = openequivariance._torch.extlib.generic_module - generic_module = openequivariance.extlib.generic_module elif torch.version.cuda or torch.version.hip: try: from torch.utils.cpp_extension import library_paths, include_paths @@ -143,6 +143,10 @@ def _raise_import_error_helper(import_target: str): raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") +def torch_ext_so_path(): + return torch_module.__file__ + + if BUILT_EXTENSION: from generic_module import ( JITTPImpl, diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py new file mode 100644 index 00000000..00edefcb --- /dev/null +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py @@ -0,0 +1,5 @@ +from openequivariance._torch.symmetric_contraction.symmetric_contraction import ( + SymmetricContraction, +) + +__all__ = ["SymmetricContraction"] diff --git a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py similarity index 99% rename from openequivariance/implementations/symmetric_contraction/symmetric_contraction.py rename to openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py index 9790c2a2..504e788e 100644 --- 
a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py @@ -1,7 +1,7 @@ # ruff: noqa : E402 import torch -from openequivariance.extlib import GroupMM_F32, GroupMM_F64 +from openequivariance._torch.extlib import GroupMM_F32, GroupMM_F64 class GroupMM: diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py new file mode 100644 index 00000000..7538fb27 --- /dev/null +++ b/openequivariance/openequivariance/_torch/utils.py @@ -0,0 +1,68 @@ +import torch +from types import MappingProxyType +from openequivariance.core.utils import DTypeEnum + + +def reorder_helper(schedule, weights_in, direction, has_batch_dim): + assert direction in ["forward", "backward"] + + specs = schedule.weight_reordering_info(weights_in, has_batch_dim) + weights_out = torch.zeros_like(weights_in) + + for spec in specs: + parent_range = spec["parent_range"] + parent_shape = spec["parent_shape"] + weights_subrange = spec["weights_subrange"] + child_range = spec["child_range"] + transpose_perm = spec["transpose_perm"] + + if direction == "forward": + reshape_size = spec["reshape_size"] + + sliced_weights = weights_in[parent_range].reshape(parent_shape)[ + weights_subrange + ] + + weights_out[child_range] = sliced_weights.permute(transpose_perm).reshape( + reshape_size + ) + + elif direction == "backward": + transpose_child_shape = spec["transpose_child_shape"] + child_shape = spec["child_shape"] + + sliced_weights = ( + weights_in[child_range] + .reshape(transpose_child_shape) + .permute(transpose_perm) + ) + + weights_out[parent_range].reshape(parent_shape)[weights_subrange] = ( + sliced_weights.flatten().reshape(child_shape) + ) + + return weights_out + + +def reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim): + weights_in = torch.from_numpy(weights_in.copy()) + result = reorder_helper(schedule, weights_in, direction, has_batch_dim) + return result.detach().cpu().numpy().copy() + + +def reorder_torch(schedule, weights_in, direction, has_batch_dim): + if isinstance(weights_in, torch.Tensor): + return reorder_helper(schedule, weights_in, direction, has_batch_dim) + else: + return reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim) + + +enum_to_torch_dtype = MappingProxyType( + { + DTypeEnum.FLOAT32: torch.float32, + DTypeEnum.FLOAT64: torch.float64, + DTypeEnum.INT32: torch.int32, + DTypeEnum.INT64: torch.int64, + DTypeEnum.UINT8: torch.uint8, + } +) diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py similarity index 98% rename from openequivariance/benchmark/ConvBenchmarkSuite.py rename to openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index a4b7c982..499a33eb 100644 --- a/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -7,7 +7,7 @@ import openequivariance as oeq from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.convolution.ConvolutionBase import CoordGraph +from openequivariance.core.ConvolutionBase import CoordGraph logger = getLogger() diff --git a/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py similarity index 98% rename from openequivariance/benchmark/TestBenchmarkSuite.py rename to 
openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index d764be77..119c866c 100644 --- a/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -7,11 +7,11 @@ from dataclasses import dataclass import openequivariance as oeq -from openequivariance.extlib import DeviceProp -from openequivariance.implementations.TensorProductBase import TensorProductBase +from openequivariance._torch.extlib import DeviceProp +from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.benchmark.logging_utils import getLogger, bcolors -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, diff --git a/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py similarity index 96% rename from openequivariance/benchmark/benchmark_utils.py rename to openequivariance/openequivariance/benchmark/benchmark_utils.py index 4dfea422..377df3d6 100644 --- a/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -10,10 +10,10 @@ calculate_minimum_memory_streamed_forward, calculate_minimum_memory_streamed_backward, ) -from openequivariance.implementations.utils import calculate_total_nnz -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import TPProblem -from openequivariance.implementations.CUETensorProduct import CUETensorProduct +from openequivariance.core.utils import calculate_total_nnz +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.logging_utils import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py similarity index 75% rename from openequivariance/benchmark/correctness_utils.py rename to openequivariance/openequivariance/benchmark/correctness_utils.py index e2cf414b..788d209e 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -1,12 +1,14 @@ from typing import Optional, Union -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward, get_random_buffers_backward, + get_random_buffers_double_backward, ) + from openequivariance.benchmark.logging_utils import getLogger, bcolors import numpy as np import numpy.linalg as la @@ -71,7 +73,7 @@ def correctness_forward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct 
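The double-backward correctness check rewritten later in this file, like the
NumpyDoubleBackwardMixin added above, follows one pattern: a first
torch.autograd.grad call with create_graph=True to obtain differentiable
first-order gradients, then a second call that contracts those gradients with
seed tensors and also differentiates with respect to the upstream gradient.
A minimal, self-contained sketch of that pattern on a toy elementwise op
(toy_tp is illustrative only, not one of the library's kernels):

import torch

def toy_tp(x, y, w):
    # Stand-in for a tensor product: bilinear in (x, y), scaled by w.
    return x * y * w

batch, dim = 4, 8
x = torch.randn(batch, dim, requires_grad=True)
y = torch.randn(batch, dim, requires_grad=True)
w = torch.randn(batch, dim, requires_grad=True)
out_grad = torch.randn(batch, dim, requires_grad=True)

out = toy_tp(x, y, w)

# First-order gradients; create_graph=True keeps them differentiable.
x_grad, y_grad, w_grad = torch.autograd.grad(
    outputs=out,
    inputs=[x, y, w],
    grad_outputs=out_grad,
    create_graph=True,
)

# Second-order gradients: contract the first-order gradients with seed
# tensors and differentiate again, including with respect to out_grad.
x_dgrad = torch.randn_like(x)
y_dgrad = torch.randn_like(y)
w_dgrad = torch.randn_like(w)

dx, dy, dw, dout_grad = torch.autograd.grad(
    outputs=[x_grad, y_grad, w_grad],
    inputs=[x, y, w, out_grad],
    grad_outputs=[x_dgrad, y_dgrad, w_dgrad],
)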
@@ -115,7 +117,7 @@ def correctness_backward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -194,66 +196,49 @@ def correctness_double_backward( global torch import torch - in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward( - problem, batch_size, prng_seed + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + get_random_buffers_double_backward( + problem, batch_size=batch_size, prng_seed=prng_seed + ) ) - rng = np.random.default_rng(seed=prng_seed * 2) - dummy_grad = rng.standard_normal(1)[0] if reference_implementation is None: - from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct result = {"thresh": correctness_threshold, "batch_size": batch_size} tensors = [] - for i, impl in enumerate([test_implementation, reference_implementation]): + for _, impl in enumerate([test_implementation, reference_implementation]): tp = instantiate_implementation(impl, problem) - - if impl == CUETensorProduct and problem.shared_weights: - weights = weights[np.newaxis, :] - weights_reordered = tp.reorder_weights_from_e3nn( - weights, not tp.config.shared_weights - ) - - in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) - in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) - weights_torch = torch.tensor( - weights_reordered, device="cuda", requires_grad=True + weights, has_batch_dim=not problem.shared_weights ) - - out_torch = tp.forward(in1_torch, in2_torch, weights_torch) - out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True) - - in1_grad, in2_grad, w_grad = torch.autograd.grad( - outputs=[out_torch], - inputs=[in1_torch, in2_torch, weights_torch], - grad_outputs=[out_grad], - create_graph=True, + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights ) - dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) - dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True) - - dummy.backward( - dummy_grad, - retain_graph=True, - inputs=[out_grad, in1_torch, in2_torch, weights_torch], - ) - - weights_grad = weights_torch.grad.detach().cpu().numpy() - weights_grad = tp.reorder_weights_to_e3nn( - weights_grad, not tp.config.shared_weights + if impl == CUETensorProduct and problem.shared_weights: + weights_reordered = weights_reordered[np.newaxis, :] + + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, ) - tensors.append( ( - out_grad.grad.detach().cpu().numpy(), - in1_torch.grad.detach().cpu().numpy(), - in2_torch.grad.detach().cpu().numpy(), - weights_grad, + out_dgrad, + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not problem.shared_weights + ), ) ) diff --git a/openequivariance/benchmark/logging_utils.py b/openequivariance/openequivariance/benchmark/logging_utils.py similarity index 100% rename from openequivariance/benchmark/logging_utils.py rename to openequivariance/openequivariance/benchmark/logging_utils.py diff --git a/openequivariance/benchmark/perf_metrics_utils.py 
b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py similarity index 96% rename from openequivariance/benchmark/perf_metrics_utils.py rename to openequivariance/openequivariance/benchmark/perf_metrics_utils.py index 88a903ab..212f05f4 100644 --- a/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py @@ -1,11 +1,11 @@ import math -from openequivariance.implementations.utils import ( +from openequivariance.core.utils import ( count_cg_non_zero, sparse_outer_product_work, ) -from openequivariance.implementations.e3nn_lite import TPProblem, wigner_3j +from openequivariance.core.e3nn_lite import TPProblem, wigner_3j from openequivariance.benchmark.logging_utils import getLogger import numpy as np diff --git a/openequivariance/benchmark/plotting/__init__.py b/openequivariance/openequivariance/benchmark/plotting/__init__.py similarity index 100% rename from openequivariance/benchmark/plotting/__init__.py rename to openequivariance/openequivariance/benchmark/plotting/__init__.py diff --git a/openequivariance/benchmark/plotting/plot_convolution.py b/openequivariance/openequivariance/benchmark/plotting/plot_convolution.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_convolution.py rename to openequivariance/openequivariance/benchmark/plotting/plot_convolution.py diff --git a/openequivariance/benchmark/plotting/plot_double_backward.py b/openequivariance/openequivariance/benchmark/plotting/plot_double_backward.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_double_backward.py rename to openequivariance/openequivariance/benchmark/plotting/plot_double_backward.py diff --git a/openequivariance/benchmark/plotting/plot_roofline.py b/openequivariance/openequivariance/benchmark/plotting/plot_roofline.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_roofline.py rename to openequivariance/openequivariance/benchmark/plotting/plot_roofline.py diff --git a/openequivariance/benchmark/plotting/plot_uvu.py b/openequivariance/openequivariance/benchmark/plotting/plot_uvu.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_uvu.py rename to openequivariance/openequivariance/benchmark/plotting/plot_uvu.py diff --git a/openequivariance/benchmark/plotting/plot_uvw.py b/openequivariance/openequivariance/benchmark/plotting/plot_uvw.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_uvw.py rename to openequivariance/openequivariance/benchmark/plotting/plot_uvw.py diff --git a/openequivariance/benchmark/plotting/plotting_utils.py b/openequivariance/openequivariance/benchmark/plotting/plotting_utils.py similarity index 100% rename from openequivariance/benchmark/plotting/plotting_utils.py rename to openequivariance/openequivariance/benchmark/plotting/plotting_utils.py diff --git a/openequivariance/benchmark/problems.py b/openequivariance/openequivariance/benchmark/problems.py similarity index 100% rename from openequivariance/benchmark/problems.py rename to openequivariance/openequivariance/benchmark/problems.py diff --git a/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/random_buffer_utils.py similarity index 74% rename from openequivariance/benchmark/random_buffer_utils.py rename to openequivariance/openequivariance/benchmark/random_buffer_utils.py index 41fb7cb6..c657d5bc 100644 --- a/openequivariance/benchmark/random_buffer_utils.py +++ 
b/openequivariance/openequivariance/benchmark/random_buffer_utils.py @@ -1,6 +1,6 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.e3nn_lite import TPProblem def get_random_buffers_forward( @@ -104,10 +104,16 @@ def get_random_buffers_double_backward( ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) - weights_grad = np.zeros_like(weights) - in1_grad = np.zeros_like(in1) - in2_grad = np.zeros_like(in2) - out_double_grad = np.zeros_like(out_grad) + weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + in1_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_double_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) return ( in1, @@ -176,3 +182,46 @@ def get_random_buffers_backward_conv( in2_grad = np.zeros_like(in2) return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad + + +def get_random_buffers_double_backward_conv( + tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int +): + rng = np.random.default_rng(prng_seed) + in1 = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([edge_count, tpp.weight_numel]) + ) + + weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + in1_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_grad = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_double_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + return ( + in1, + in2, + out_grad, + weights, + weights_grad, + in1_grad, + in2_grad, + out_double_grad, + ) diff --git a/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py similarity index 97% rename from openequivariance/benchmark/tpp_creation_utils.py rename to openequivariance/openequivariance/benchmark/tpp_creation_utils.py index 18f3a84c..7637f412 100644 --- a/openequivariance/benchmark/tpp_creation_utils.py +++ b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py @@ -1,14 +1,12 @@ import numpy as np from typing import Iterator, Optional -from openequivariance.implementations.e3nn_lite import Irrep, Irreps, TPProblem +from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem """ This was taken from - https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py - -And adopted to create TPP's to avoid torch dependence +Adapted to create TPPs to avoid torch dependence. 
""" diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py similarity index 89% rename from openequivariance/implementations/ComputationSchedule.py rename to openequivariance/openequivariance/core/ComputationSchedule.py index 6d3b7215..9c3884c9 100644 --- a/openequivariance/implementations/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -1,5 +1,5 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import Irreps, TPProblem, wigner_3j +from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j from itertools import accumulate from openequivariance.benchmark.logging_utils import getLogger @@ -619,40 +619,33 @@ def calculate_backward_smem( smem=self.memory_per_warp * warps_per_block, ) - def reorder_weights(self, weights_in, direction, has_batch_dim): + def weight_reordering_info(self, weights_in, has_batch_dim): """ - Reorders weights from the canonical e3nn form to the - form that LoopUnrollTP can ingest. Can also reorder the parameters - of a dense neural network layer that produces the weight matrix. - - If has_batch_dim is true, the first dimension of the input weight matrix - is treated as the batch dimension. + Calculates all shapes, slices, and permutation info to reorder + weights. """ - import torch # TODO-someday: no need to specialize this to PyTorch + batch_dim = weights_in.shape[0] + reorder_specs = [] - weights_out = torch.zeros_like(weights_in) - assert direction in ["forward", "backward"] for i, child_inst in enumerate(self.problem_splitter.new_instructions): parent_start, parent_end = ( child_inst.parent_weights_start, child_inst.parent_weights_end, ) parent_shape = list(child_inst.parent_weights_shape) + parent_range = [slice(parent_start, parent_end)] child_start, child_end, child_shape = ( self.updated_config.weight_range_and_shape_for_instruction(i) ) + child_range = [slice(child_start, child_end)] - parent_range, child_range = ( - [slice(parent_start, parent_end)], - [slice(child_start, child_end)], - ) weights_subrange = child_inst.weights_subrange - batch_dim = weights_in.shape[0] + reshape_size = [-1] transpose_perm = None - connection_mode = self.updated_config.instructions[i].connection_mode + if connection_mode == "uvu": transpose_perm = [1, 0] elif connection_mode == "uvw": @@ -662,50 +655,29 @@ def reorder_weights(self, weights_in, direction, has_batch_dim): child_range = [slice(0, batch_dim)] + child_range parent_range = [slice(0, batch_dim)] + parent_range parent_shape = [batch_dim] + parent_shape + child_shape = [batch_dim] + list(child_shape) weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange reshape_size = [batch_dim] + reshape_size - transpose_perm = [0] + [i + 1 for i in transpose_perm] - - if direction == "forward": - sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[ - tuple(weights_subrange) - ] - weights_out[tuple(child_range)] = sliced_weights.permute( - transpose_perm - ).reshape(reshape_size) - elif direction == "backward": - transpose_child_shape = [child_shape[i] for i in transpose_perm] - sliced_weights = ( - weights_in[tuple(child_range)] - .reshape(transpose_child_shape) - .permute(transpose_perm) - ) - weights_out[tuple(parent_range)].reshape(parent_shape)[ - tuple(weights_subrange) - ] = sliced_weights.flatten().reshape(child_shape) - - return weights_out - - def reorder_weights_numpy(self, weights_in, direction, has_batch_dim): - import torch - - 
weights_in = torch.from_numpy(weights_in.copy()) - result = self.reorder_weights(weights_in, direction, has_batch_dim) - return result.detach().cpu().numpy().copy() - def reorder_weights_from_e3nn(self, weights_in, has_batch_dim): - import torch - - if isinstance(weights_in, np.ndarray): - return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim) - elif isinstance(weights_in, torch.Tensor): - return self.reorder_weights(weights_in, "forward", has_batch_dim) - - def reorder_weights_to_e3nn(self, weights_in, has_batch_dim): - import torch + if transpose_perm is not None: + transpose_perm = [0] + [k + 1 for k in transpose_perm] + + transpose_child_shape = None + if transpose_perm is not None: + transpose_child_shape = [child_shape[k] for k in transpose_perm] + + reorder_specs.append( + { + "parent_range": tuple(parent_range), + "parent_shape": parent_shape, + "weights_subrange": tuple(weights_subrange), + "child_range": tuple(child_range), + "child_shape": child_shape, + "transpose_perm": transpose_perm, + "reshape_size": reshape_size, + "transpose_child_shape": transpose_child_shape, + } + ) - if isinstance(weights_in, np.ndarray): - return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim) - elif isinstance(weights_in, torch.Tensor): - return self.reorder_weights(weights_in, "backward", has_batch_dim) + return reorder_specs diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py similarity index 73% rename from openequivariance/implementations/convolution/ConvolutionBase.py rename to openequivariance/openequivariance/core/ConvolutionBase.py index 7ed16571..a06b2c79 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -1,15 +1,15 @@ import copy import numpy as np -from openequivariance.extlib import DeviceBuffer from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward_conv, get_random_buffers_backward_conv, + get_random_buffers_double_backward_conv, ) from openequivariance.benchmark.logging_utils import getLogger, bcolors from openequivariance.benchmark.correctness_utils import check_similiarity -from openequivariance.implementations.e3nn_lite import wigner_3j -from openequivariance.implementations.utils import benchmark +from openequivariance.core.e3nn_lite import wigner_3j +from openequivariance.core.utils import benchmark logger = getLogger() @@ -130,106 +130,10 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): """ return weights - def allocate_workspace(self, size_bytes): - self.workspace_size = size_bytes - if self.torch_op: - self.workspace_buffer = torch.zeros( - size_bytes, dtype=torch.uint8, device="cuda" - ) - else: - self.workspace_buffer = DeviceBuffer(size_bytes) - self.workspace_ptr = self.workspace_buffer.data_ptr() - logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") - @staticmethod def name(): raise NotImplementedError() - def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): - assert graph.rows.dtype == self.idx_dtype - assert graph.cols.dtype == self.idx_dtype - - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - L1_d, L2_d, weights_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(weights_chunked), - ) - L3_d = DeviceBuffer(L3_out) - - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - 
self.internal.exec_conv_rawptrs( - L1_d.data_ptr(), - L2_d.data_ptr(), - weights_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - ) - - L3_d.copy_to_host() - - def backward_cpu( - self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph - ): - assert graph.rows.dtype == self.idx_dtype - assert graph.cols.dtype == self.idx_dtype - - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - weights_d = DeviceBuffer(weights_chunked) - L3_d = DeviceBuffer(L3_grad) - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - L1_grad_d = DeviceBuffer(L1_grad) - L2_grad_d = DeviceBuffer(L2_grad) - weights_grad_d = DeviceBuffer(weights_grad) - - transpose_perm_d = None - transpose_perm_ptr = 0 - if self.deterministic: - transpose_perm_d = DeviceBuffer(graph.transpose_perm) - transpose_perm_ptr = transpose_perm_d.data_ptr() - - self.internal.backward_rawptrs( - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - transpose_perm_ptr, - ) - - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() - - weights_grad[:] = self.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights - ) - - return L1_grad, L2_grad, weights_grad - def test_correctness_forward( self, graph, @@ -240,7 +144,7 @@ def test_correctness_forward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv + from openequivariance._torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -484,7 +388,7 @@ def test_correctness_backward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv + from openequivariance._torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -560,19 +464,12 @@ def test_correctness_double_backward( reference_implementation=None, high_precision_ref=False, ): - global torch - import torch - - assert self.torch_op - buffers = get_random_buffers_backward_conv( + buffers = get_random_buffers_double_backward_conv( self.config, graph.node_count, graph.nnz, prng_seed ) - rng = np.random.default_rng(seed=prng_seed * 2) - dummy_grad_value = rng.standard_normal(1)[0] - if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv + from openequivariance._torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -587,61 +484,41 @@ def test_correctness_double_backward( result = {"thresh": thresh} tensors = [] for i, tp in enumerate([self, reference_tp]): - in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers] + buffers_copy = [buf.copy() for buf in buffers] if i == 1 and high_precision_ref: - in1, in2, out_grad, weights, _, _, _ = [ - np.array(el, dtype=np.float64) for el in buffers - ] - - in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) - in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] - weights_reordered = tp.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - 
weights_torch = torch.tensor( - weights_reordered, device="cuda", requires_grad=True + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + buffers_copy ) - torch_rows = torch.tensor(graph.rows, device="cuda") - torch_cols = torch.tensor(graph.cols, device="cuda") - torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") - - fwd_args = [in1_torch, in2_torch, weights_torch, torch_rows, torch_cols] - if tp.deterministic: - fwd_args.append(torch_transpose_perm) - - out_torch = tp.forward(*fwd_args) - out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) - - in1_grad, in2_grad, w_grad = torch.autograd.grad( - outputs=[out_torch], - inputs=[in1_torch, in2_torch, weights_torch], - grad_outputs=[out_grad_torch], - create_graph=True, - ) - - dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) - dummy_grad = torch.tensor( - float(dummy_grad_value), device="cuda", requires_grad=True + weights_reordered = tp.reorder_weights_from_e3nn( + weights, not tp.config.shared_weights ) - dummy.backward( - dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch] + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, not tp.config.shared_weights ) - weights_grad = weights_torch.grad.detach().cpu().numpy() - weights_grad = tp.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + graph, ) tensors.append( ( - out_grad_torch.grad.detach().cpu().numpy().copy(), - in1_torch.grad.detach().cpu().numpy().copy(), - in2_torch.grad.detach().cpu().numpy().copy(), - weights_grad.copy(), + out_dgrad, + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not self.config.shared_weights + ), ) ) diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py new file mode 100644 index 00000000..35a9bc3e --- /dev/null +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -0,0 +1,207 @@ +import numpy as np + +from openequivariance.core.ConvolutionBase import ConvolutionBase +from openequivariance.core.ComputationSchedule import ( + ComputationSchedule, + SMEMCapacityException, +) + +from openequivariance.core.utils import dtype_to_enum +from openequivariance.templates.jinja_utils import get_jinja_environment +from openequivariance.core.utils import filter_and_analyze_problem + + +class LoopUnrollConv(ConvolutionBase): + def __init__( + self, + config, + dp, + postprocess_kernel, + *, + idx_dtype: type[np.generic] = np.int64, + torch_op: bool = False, + deterministic: bool = False, + kahan: bool = False, + ): + super().__init__( + config, idx_dtype=idx_dtype, torch_op=torch_op, deterministic=deterministic + ) + + if kahan: + assert deterministic + + env = get_jinja_environment() + template = env.get_template("loop_unroll_conv_atomic.cuh") + + analysis = filter_and_analyze_problem(config) + self.is_uvw = analysis["is_uvw"] + + if config.shared_weights: + assert not deterministic, ( + "Deterministic convolution does not support shared weights" + ) + + forward_schedule_type = 3 + backward_schedule_type = 2 + if deterministic: + backward_schedule_type = 3 + template = env.get_template("loop_unroll_conv_det.cuh") + + def generate_forward_schedule(warps_per_block): + self.forward_schedule = ComputationSchedule( + self.config, + 
smem_limit=dp.maxSharedMemPerBlock // 4 * 3, + warps_per_block=warps_per_block, + block_count=dp.multiprocessorCount, + direction="forward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + schedule_type=forward_schedule_type, + warp_size=dp.warpsize, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + kahan=kahan, + ) + + def generate_backward_schedule(warps_per_block): + self.backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + block_count=dp.multiprocessorCount * 2, + direction="backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + schedule_type=backward_schedule_type, + warp_size=dp.warpsize, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + kahan=kahan, + ) + + def generate_double_backward_schedule(warps_per_block): + self.double_backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount, + direction="double_backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + schedule_type=3, + kahan=kahan, + ) + + scheduler_generators = [ + generate_forward_schedule, + generate_backward_schedule, + generate_double_backward_schedule, + ] + + for generate_schedule in scheduler_generators: + warp_count = 6 + while warp_count > 0: + try: + generate_schedule(warp_count) + break + except SMEMCapacityException: + warp_count -= 1 + if warp_count == 0: + raise SMEMCapacityException( + "Tensor product schedule generation failed, shared memory inadequate!" + ) + + if not deterministic: + for segment in self.forward_schedule.segments: + for key in segment.L3Map.storeback_procedure: + segment.L3Map.storeback_procedure[key] = "atomic_accumulate" + + for segment in self.backward_schedule.segments: + for key in segment.L1Map.storeback_procedure: + segment.L1Map.storeback_procedure[key] = "atomic_accumulate" + + for segment in self.double_backward_schedule.segments: + for key in segment.L1Map.storeback_procedure: + segment.L1Map.storeback_procedure[key] = "atomic_accumulate" + + idx_type_map = {np.int32: "int", np.int64: "long"} + + self.forward_workspace_offset = None + self.backward_workspace_offset = None + self.double_backwardB_offset = None + + self.workspace_size = 1 + if deterministic: + destination_index_bytes = 32 # Add extra to account for padding + self.workspace_size = max( + ( + self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + + destination_index_bytes + ) + * self.forward_schedule.total_warps, + ( + self.backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + + destination_index_bytes + ) + * self.backward_schedule.total_warps, + ( + self.double_backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + + destination_index_bytes + ) + * self.double_backward_schedule.total_warps, + ) + + self.forward_workspace_offset = ( + self.forward_schedule.L3.dim + * np.dtype(config.irrep_dtype).itemsize + * self.forward_schedule.total_warps + ) + self.backward_workspace_offset = ( + self.backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + * self.backward_schedule.total_warps + ) + self.double_backwardB_offset = ( + self.double_backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + * self.double_backward_schedule.total_warps + ) + + self.forward_workspace_offset = 
(self.forward_workspace_offset + 7) // 8 * 8 + self.backward_workspace_offset = ( + (self.backward_workspace_offset + 7) // 8 * 8 + ) + self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8 + + self.kernel_prop = { + "L1_dim": self.L1.dim, + "L2_dim": self.L2.dim, + "L3_dim": self.L3.dim, + "weight_numel": self.config.weight_numel, + "workspace_size": self.workspace_size, + "opt_level": 3, + "shared_weights": int(config.shared_weights), + "deterministic": int(self.deterministic), + "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], + "weight_dtype": dtype_to_enum[self.config.weight_dtype], + "idx_dtype": dtype_to_enum[self.idx_dtype], + } + + self.jit_kernel = template.render( + forward_schedule=self.forward_schedule, + backward_schedule=self.backward_schedule, + double_backward_schedule=self.double_backward_schedule, + idx_type=idx_type_map[idx_dtype], + forward_workspace_offset=self.forward_workspace_offset, + backward_workspace_offset=self.backward_workspace_offset, + double_backwardB_offset=self.double_backwardB_offset, + ) + self.jit_kernel = postprocess_kernel(self.jit_kernel) + + # with open("scratch.txt", "w") as f: + # f.write(self.jit_kernel) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py new file mode 100644 index 00000000..12ad4536 --- /dev/null +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -0,0 +1,157 @@ +import numpy as np + +from openequivariance.templates.jinja_utils import get_jinja_environment +from openequivariance.core.ComputationSchedule import ComputationSchedule +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.utils import dtype_to_enum + +from openequivariance.core.utils import ( + filter_and_analyze_problem, + count_cg_non_zero, +) + + +class LoopUnrollTP(TensorProductBase): + def __init__(self, config, dp, postprocess_kernel, torch_op): + super().__init__(config, torch_op=torch_op) + + env = get_jinja_environment() + template = env.get_template("loop_unroll_batch.cuh") + + analysis = filter_and_analyze_problem(config) + self.is_uvw = analysis["is_uvw"] + + def generate_forward_schedule(warps_per_block): + self.forward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount * 4, + direction="forward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + ) + + def generate_backward_schedule(warps_per_block): + self.backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount * 4, + direction="backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + ) + + def generate_double_backward_schedule(warps_per_block): + self.double_backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount, + direction="double_backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + schedule_type=3, + ) + + scheduler_generators = [ + generate_forward_schedule, + generate_backward_schedule, + 
generate_double_backward_schedule, + ] + + for generate_schedule in scheduler_generators: + warp_count = 8 + while warp_count > 0: + try: + generate_schedule(warp_count) + break + except Exception: + warp_count -= 2 + if warp_count == 0: + raise RuntimeError( + "Tensor product schedule generation failed, shared memory inadequate!" + ) + + self.jit_kernel = postprocess_kernel( + template.render( + forward_schedule=self.forward_schedule, + backward_schedule=self.backward_schedule, + double_backward_schedule=self.double_backward_schedule, + ) + ) + + self.kernelProp = { + "L1_dim": self.L1.dim, + "L2_dim": self.L2.dim, + "L3_dim": self.L3.dim, + "weight_numel": self.config.weight_numel, + "shared_weights": int(self.config.shared_weights), + "opt_level": 3, + "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], + "weight_dtype": dtype_to_enum[self.config.weight_dtype], + # Not relevant, included for compatibility with convolution + "workspace_size": 0, + "deterministic": 1, + "idx_dtype": 0, + } + + def calculate_flops_forward(self, batch_size: int) -> dict: + if self.is_uvw: + return super().calculate_flops_forward(batch_size) + else: + tpp = self.config + flop_count = { + "CG_decomposition": 0, + "linear_combination": 0, + "outer_products": 0, + } + for ins in tpp.instructions: + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + flop_count["linear_combination"] += ( + (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 + ) + + flop_count["CG_decomposition"] *= 3 * batch_size + flop_count["linear_combination"] *= ( + batch_size # Weights do not require FMA here + ) + flop_count["total"] = sum(flop_count.values()) + return flop_count + + def calculate_flops_backward(self, batch_size: int) -> dict: + if self.is_uvw: + return super().calculate_flops_backward(batch_size) + else: + tpp = self.config + flop_count = {"backward": 0} + for ins in tpp.instructions: + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + + flop_count["backward"] *= 9 * batch_size + flop_count["total"] = sum(flop_count.values()) + return flop_count diff --git a/openequivariance/implementations/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py similarity index 71% rename from openequivariance/implementations/TensorProductBase.py rename to openequivariance/openequivariance/core/TensorProductBase.py index 043c5b77..b5d3831f 100644 --- a/openequivariance/implementations/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -1,9 +1,8 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.utils import benchmark -from openequivariance.extlib import DeviceBuffer +from openequivariance.core.utils import benchmark logger = getLogger() @@ -44,7 +43,7 @@ def reorder_weights_from_e3nn(self, weights, has_batch_dim: bool = True): Reorders weights from ``e3nn`` canonical order to the order used by ``oeq``. :param weights: Weights in ``e3nn`` canonical order, either an - np.ndarray or a torch.Tensor. 
Tensor of dimensions ``[B, problem.weight_numel]`` + np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]`` when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``. :param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``. @@ -57,8 +56,8 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True): r""" Reorders weights from ``oeq`` canonical order to the order used by ``e3nn``. - :param weights: Weights in ``oeq`` canonical order, either an - np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]`` + :param weights: Weights in ``oeq`` canonical order, either a + np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]`` when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``. :param has_batch_dim: If ``True``, treats the first dimension of wieghts as a batch dimension. Default: ``True``. @@ -67,94 +66,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True): """ return weights - def forward_raw( - self, - batch: np.uint64, - L1_in: np.uint64, - L2_in: np.uint64, - L3_out: np.uint64, - weights: np.uint64, - ) -> None: - self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) - - def backward_raw( - self, - batch_size: np.uint64, - L1_in: np.uint64, - L1_grad: np.uint64, - L2_in: np.uint64, - L2_grad: np.uint64, - weights: np.uint64, - weights_grad: np.uint64, - L3_grad: np.uint64, - ): - self.internal.backward_rawptr( - batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad - ) - - def forward_cpu( - self, - L1_in: np.ndarray, - L2_in: np.ndarray, - L3_out: np.ndarray, - weights: np.ndarray, - ) -> None: - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - batch = L1_in.shape[0] - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - L3_d = DeviceBuffer(L3_out) - weights_d = DeviceBuffer(weights_chunked) - self.internal.exec_tensor_product_rawptr( - batch, - L1_d.data_ptr(), - L2_d.data_ptr(), - L3_d.data_ptr(), - weights_d.data_ptr(), - ) - L3_d.copy_to_host() - - def backward_cpu( - self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad - ) -> None: - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - batch = L1_in.shape[0] - L1_d, L2_d, L3_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(L3_grad), - ) - L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = ( - DeviceBuffer(weights_chunked), - DeviceBuffer(weights_grad), - ) - - self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - ) - - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() - - weights_grad[:] = self.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights - ) - def benchmark_forward( self, num_warmup: int, diff --git a/openequivariance/implementations/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py similarity index 100% rename from openequivariance/implementations/e3nn_lite.py rename to openequivariance/openequivariance/core/e3nn_lite.py diff --git a/openequivariance/implementations/utils.py b/openequivariance/openequivariance/core/utils.py similarity index 84% rename from 
openequivariance/implementations/utils.py rename to openequivariance/openequivariance/core/utils.py index b90993c1..5fd8f81d 100644 --- a/openequivariance/implementations/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -3,11 +3,39 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import Instruction, TPProblem, wigner_3j +from openequivariance.core.e3nn_lite import Instruction, TPProblem, wigner_3j import json import tempfile -from openequivariance.extlib import GPUTimer +import hashlib + +from enum import IntEnum + + +class DTypeEnum(IntEnum): + """ + The C++ layer storess a copy of this map. + """ + + FLOAT32 = 1 + FLOAT64 = 2 + INT32 = 3 + INT64 = 4 + UINT8 = 5 + + +dtype_to_enum = { + np.float32: DTypeEnum.FLOAT32, + np.float64: DTypeEnum.FLOAT64, + np.int32: DTypeEnum.INT32, + np.int64: DTypeEnum.INT64, + np.uint8: DTypeEnum.UINT8, + np.dtype(np.float32): DTypeEnum.FLOAT32, + np.dtype(np.float64): DTypeEnum.FLOAT64, + np.dtype(np.int32): DTypeEnum.INT32, + np.dtype(np.int64): DTypeEnum.INT64, + np.dtype(np.uint8): DTypeEnum.UINT8, +} def sparse_outer_product_work(cg: np.ndarray) -> int: @@ -123,6 +151,8 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): mode=gpu_time may include PyTorch overhead mode=kernel_time measures runtime for only the specified kernels """ + from openequivariance._torch.extlib import GPUTimer + assert mode in ["gpu_time", "torch_kernel_time"] time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() @@ -170,3 +200,13 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = kernel_time return time_millis + + +def hash_attributes(attrs): + m = hashlib.sha256() + + for key in sorted(attrs.keys()): + m.update(attrs[key].__repr__().encode("utf-8")) + + hash = int(m.hexdigest()[:16], 16) >> 1 + attrs["hash"] = hash diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp similarity index 99% rename from openequivariance/extension/convolution.hpp rename to openequivariance/openequivariance/extension/convolution.hpp index 75b4f879..3b2ce1e6 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/openequivariance/extension/convolution.hpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include struct ConvData { diff --git a/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp similarity index 100% rename from openequivariance/extension/generic_module.cpp rename to openequivariance/openequivariance/extension/generic_module.cpp diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/openequivariance/extension/group_mm_cuda.hpp similarity index 100% rename from openequivariance/extension/group_mm_cuda.hpp rename to openequivariance/openequivariance/extension/group_mm_cuda.hpp diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/openequivariance/extension/group_mm_hip.hpp similarity index 100% rename from openequivariance/extension/group_mm_hip.hpp rename to openequivariance/openequivariance/extension/group_mm_hip.hpp diff --git a/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp similarity index 100% rename from openequivariance/extension/libtorch_tp_jit.cpp rename to openequivariance/openequivariance/extension/libtorch_tp_jit.cpp diff --git a/openequivariance/extension/tensorproducts.hpp 
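As a small illustration of how the utilities above combine (all attribute values here are hypothetical): kernel properties are encoded as DTypeEnum integers via dtype_to_enum, and hash_attributes stamps the dictionary with a stable hash that the extension layer can use as a compilation-cache key, mirroring the attrs dictionaries assembled by the JAX wrappers later in this patch.

    import numpy as np
    from openequivariance.core.utils import dtype_to_enum, hash_attributes

    attrs = {
        "kernel": "// generated kernel source would go here",
        "kernel_prop": {
            "irrep_dtype": dtype_to_enum[np.float32],   # DTypeEnum.FLOAT32 == 1
            "weight_dtype": dtype_to_enum[np.float64],  # DTypeEnum.FLOAT64 == 2
        },
    }
    hash_attributes(attrs)      # adds attrs["hash"], derived from repr() of each value
    print(hex(attrs["hash"]))   # identical attributes always produce the same hash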
b/openequivariance/openequivariance/extension/tensorproducts.hpp similarity index 100% rename from openequivariance/extension/tensorproducts.hpp rename to openequivariance/openequivariance/extension/tensorproducts.hpp diff --git a/openequivariance/extension/test/CMakeLists.txt b/openequivariance/openequivariance/extension/test/CMakeLists.txt similarity index 100% rename from openequivariance/extension/test/CMakeLists.txt rename to openequivariance/openequivariance/extension/test/CMakeLists.txt diff --git a/openequivariance/extension/test/load_jitscript.cpp b/openequivariance/openequivariance/extension/test/load_jitscript.cpp similarity index 100% rename from openequivariance/extension/test/load_jitscript.cpp rename to openequivariance/openequivariance/extension/test/load_jitscript.cpp diff --git a/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/util/backend_cuda.hpp similarity index 98% rename from openequivariance/extension/util/backend_cuda.hpp rename to openequivariance/openequivariance/extension/util/backend_cuda.hpp index 364186fc..4c79faed 100644 --- a/openequivariance/extension/util/backend_cuda.hpp +++ b/openequivariance/openequivariance/extension/util/backend_cuda.hpp @@ -349,7 +349,11 @@ class __attribute__((visibility("default"))) CUJITKernel { ~CUJITKernel() { if(compiled) { - CUDA_SAFE_CALL(cuLibraryUnload(library)); + auto result = cuLibraryUnload(library); + if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { + std::cout << "Failed to unload CUDA library, error code: " << ((int) result) << std::endl; + } + delete[] code; } NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); diff --git a/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/util/backend_hip.hpp similarity index 100% rename from openequivariance/extension/util/backend_hip.hpp rename to openequivariance/openequivariance/extension/util/backend_hip.hpp diff --git a/openequivariance/extension/util/buffer.hpp b/openequivariance/openequivariance/extension/util/buffer.hpp similarity index 100% rename from openequivariance/extension/util/buffer.hpp rename to openequivariance/openequivariance/extension/util/buffer.hpp diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py new file mode 100644 index 00000000..452e7bb7 --- /dev/null +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -0,0 +1,159 @@ +import jax +import numpy as np +from functools import partial +from openequivariance.jax import extlib +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.LoopUnrollTP import LoopUnrollTP +from openequivariance.core.utils import hash_attributes +from openequivariance.jax.utils import reorder_jax + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) +def forward(X, Y, W, L3_dim, irrep_dtype, attrs): + forward_call = jax.ffi.ffi_call( + "tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + ) + return forward_call(X, Y, W, **attrs) + + +def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs): + return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) + + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +def backward(X, Y, W, dZ, irrep_dtype, attrs): + backward_call = jax.ffi.ffi_call( + "tp_backward", + ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ), + ) + + return backward_call(X, Y, W, dZ, 
**attrs) + + +def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs): + return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ) + + +def double_backward(irrep_dtype, attrs, inputs, derivatives): + double_backward_call = jax.ffi.ffi_call( + "tp_double_backward", + ( + jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), + ), + ) + return double_backward_call(*inputs, *derivatives, **attrs) + + +def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ): + return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs) + + +forward.defvjp(forward_with_inputs, backward_autograd) +backward.defvjp(backward_with_inputs, double_backward) + + +class TensorProduct(LoopUnrollTP): + r""" + Identical to ``oeq.torch.TensorProduct`` with functionality in JAX. + + :param problem: Specification of the tensor product. + """ + + def __init__(self, problem: TPProblem): + dp = extlib.DeviceProp(0) + super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False) + + self.attrs = { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars(self.double_backward_schedule.launch_config), + "kernel_prop": self.kernelProp, + } + hash_attributes(self.attrs) + + self.weight_numel = problem.weight_numel + self.L3_dim = self.config.irreps_out.dim + + def forward( + self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray + ) -> jax.numpy.ndarray: + return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) + + def __call__( + self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray + ) -> jax.numpy.ndarray: + return self.forward(X, Y, W) + + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return reorder_jax( + self.forward_schedule, weights, "forward", not self.config.shared_weights + ) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return reorder_jax( + self.forward_schedule, weights, "backward", not self.config.shared_weights + ) + + def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) + result = self.forward( + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + ) + L3_out[:] = np.asarray(result) + + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad + ) -> None: + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) + backward_fn = jax.vjp( + lambda X, Y, W: self.forward(X, Y, W), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( + jax.numpy.asarray(L3_grad) + ) + L1_grad[:] = np.asarray(L1_grad_jax) + L2_grad[:] = np.asarray(L2_grad_jax) + weights_grad[:] = np.asarray(weights_grad_jax) + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not self.config.shared_weights + ) + + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad + ): + in1_jax = jax.numpy.asarray(in1) + in2_jax = jax.numpy.asarray(in2) + weights_jax = jax.numpy.asarray(weights) + out_grad_jax = jax.numpy.asarray(out_grad) + in1_dgrad_jax = 
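A minimal usage sketch for the oeq.jax.TensorProduct wrapper defined in this file. The irreps strings, instruction tuple, batch size, and float32 defaults are illustrative assumptions, and a CUDA or HIP device is required because DeviceProp is queried at construction.

    import jax
    import jax.numpy as jnp
    import openequivariance as oeq
    from openequivariance.jax import TensorProduct

    problem = oeq.TPProblem(
        "32x1e", "1x1e", "32x1e",
        [(0, 0, 0, "uvu", True)],
        shared_weights=False,
        internal_weights=False,
    )
    tp = TensorProduct(problem)

    batch = 1000
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    X = jax.random.normal(k1, (batch, problem.irreps_in1.dim), dtype=jnp.float32)
    Y = jax.random.normal(k2, (batch, problem.irreps_in2.dim), dtype=jnp.float32)
    W = jax.random.normal(k3, (batch, problem.weight_numel), dtype=jnp.float32)

    Z = tp(X, Y, W)   # dispatches the "tp_forward" FFI call

    def loss(x, y, w):
        return jnp.sum(tp(x, y, w))

    dX, dY, dW = jax.grad(loss, argnums=(0, 1, 2))(X, Y, W)   # exercises the custom VJP chain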
jax.numpy.asarray(in1_dgrad) + in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) + weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) + + in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( + lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[ + 1 + ](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, + )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + + return in1_grad, in2_grad, weights_grad, out_dgrad diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py new file mode 100644 index 00000000..3aaee28a --- /dev/null +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -0,0 +1,298 @@ +import numpy as np +from functools import partial +from typing import Optional +from openequivariance.jax import extlib + +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.LoopUnrollConv import LoopUnrollConv +from openequivariance.core.utils import hash_attributes +from openequivariance.jax.utils import reorder_jax + +import jax +import jax.numpy as jnp + +from openequivariance.benchmark.logging_utils import getLogger + +logger = getLogger() + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) +def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): + forward_call = jax.ffi.ffi_call( + "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + ) + return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs) + + +def forward_with_inputs( + X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs +): + return forward( + X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs + ), (X, Y, W, rows, cols, sender_perm, workspace) + + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9)) +def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): + backward_call = jax.ffi.ffi_call( + "conv_backward", + ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ), + ) + return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs) + + +def backward_with_inputs( + X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs +): + return backward( + X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs + ), (X, Y, W, dZ) # rows, cols, sender_perm, workspace) + + +def double_backward( + rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives +): + double_backward_call = jax.ffi.ffi_call( + "conv_double_backward", + ( + jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), + ), + ) + return double_backward_call( + *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs + ) + + +def backward_autograd( + rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ +): + return backward( + inputs[0], + inputs[1], + inputs[2], + dZ, + rows, + cols, + workspace, + sender_perm, + irrep_dtype, + attrs, + ) + + +forward.defvjp(forward_with_inputs, backward_autograd) +backward.defvjp(backward_with_inputs, double_backward) + + +class TensorProductConv(LoopUnrollConv): + r""" + Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one + key difference: integer arrays passed to this function 
must have dtype + ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version). + + :param problem: Specification of the tensor product. + :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic + fixup-based algorithm. `Default`: ``False``. + :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, + the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. + """ + + def __init__( + self, config: TPProblem, deterministic: bool = False, kahan: bool = False + ): + dp = extlib.DeviceProp(0) + super().__init__( + config, + dp, + extlib.postprocess_kernel, + idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version + torch_op=False, + deterministic=deterministic, + kahan=kahan, + ) + + self.attrs = { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars(self.double_backward_schedule.launch_config), + "kernel_prop": self.kernel_prop, + } + hash_attributes(self.attrs) + + self.weight_numel = config.weight_numel + self.L3_dim = self.config.irreps_out.dim + + self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8) + logger.info( + f"Convolution requires {self.workspace_size // (2**20)}MB of workspace." + ) + self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32) + + def forward( + self, + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray, + rows: jax.numpy.ndarray, + cols: jax.numpy.ndarray, + sender_perm: Optional[jax.numpy.ndarray] = None, + ) -> jax.numpy.ndarray: + r""" + Computes the fused CG tensor product + convolution. + + :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. + :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. + :param W: Tensor of datatype ``problem.weight_dtype`` and shape + + * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False`` + * ``[problem.weight_numel]`` if ``problem.shared_weights=True`` + + :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix, + datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``. + :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix, + datatype ``np.int32``. + :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a + permutation that transposes the adjacency matrix nonzeros from row-major to column-major order. + Must be provided when ``deterministic=True``. + + :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. + """ + if not self.deterministic: + sender_perm = self.dummy_transpose_perm + else: + assert sender_perm is not None, ( + "Must provide sender_perm for deterministic convolutions." 
+ ) + + return forward( + X, + Y, + W, + rows, + cols, + self.workspace, + sender_perm, + self.L3_dim, + self.config.irrep_dtype, + self.attrs, + ) + + def __call__( + self, + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray, + rows: jax.numpy.ndarray, + cols: jax.numpy.ndarray, + sender_perm: Optional[jax.numpy.ndarray] = None, + ) -> jax.numpy.ndarray: + return self.forward(X, Y, W, rows, cols, sender_perm) + + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return reorder_jax(self.forward_schedule, weights, "backward", has_batch_dim) + + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): + rows = graph.rows.astype(np.int32) + cols = graph.cols.astype(np.int32) + sender_perm = graph.transpose_perm.astype(np.int32) + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) + result = self.forward( + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + jax.numpy.asarray(rows), + jax.numpy.asarray(cols), + jax.numpy.asarray(sender_perm), + ) + L3_out[:] = np.asarray(result) + + def backward_cpu( + self, + L1_in, + L1_grad, + L2_in, + L2_grad, + L3_grad, + weights, + weights_grad, + graph, + ): + rows = graph.rows.astype(np.int32) + cols = graph.cols.astype(np.int32) + sender_perm = graph.transpose_perm.astype(np.int32) + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) + + backward_fn = jax.vjp( + lambda X, Y, W: self.forward( + X, + Y, + W, + jax.numpy.asarray(rows), + jax.numpy.asarray(cols), + jax.numpy.asarray(sender_perm), + ), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( + jax.numpy.asarray(L3_grad) + ) + L1_grad[:] = np.asarray(L1_grad_jax) + L2_grad[:] = np.asarray(L2_grad_jax) + weights_grad[:] = np.asarray(weights_grad_jax) + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not self.config.shared_weights + ) + + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph + ): + in1_jax = jax.numpy.asarray(in1) + in2_jax = jax.numpy.asarray(in2) + weights_jax = jax.numpy.asarray(weights) + out_grad_jax = jax.numpy.asarray(out_grad) + in1_dgrad_jax = jax.numpy.asarray(in1_dgrad) + in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) + weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) + + rows_jax = jax.numpy.asarray(graph.rows.astype(self.idx_dtype)) + cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype)) + sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype)) + + in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( + lambda x, y, w, o: jax.vjp( + lambda a, b, c: self.forward( + a, b, c, rows_jax, cols_jax, sender_perm_jax + ), + x, + y, + w, + )[1](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, + )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + + return ( + np.asarray(in1_grad), + np.asarray(in2_grad), + np.asarray(weights_grad), + np.asarray(out_dgrad), + ) diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py new file mode 100644 index 00000000..410e5dbf --- /dev/null +++ b/openequivariance/openequivariance/jax/__init__.py @@ -0,0 +1,6 @@ +from 
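A minimal usage sketch for the TensorProductConv wrapper defined in this file, on a toy random graph. Node and edge counts, irreps, and the instruction tuple are illustrative assumptions; as documented above, the index arrays must be np.int32, and sender_perm may be omitted because deterministic=False here.

    import jax.numpy as jnp
    import numpy as np
    import openequivariance as oeq
    from openequivariance.jax import TensorProductConv

    problem = oeq.TPProblem(
        "32x1e", "1x1e", "32x1e",
        [(0, 0, 0, "uvu", True)],
        shared_weights=False,
        internal_weights=False,
    )
    conv = TensorProductConv(problem, deterministic=False)   # atomic accumulation path

    num_nodes, num_edges = 100, 500
    rng = np.random.default_rng(0)
    rows = jnp.asarray(rng.integers(0, num_nodes, num_edges, dtype=np.int32))
    cols = jnp.asarray(rng.integers(0, num_nodes, num_edges, dtype=np.int32))

    X = jnp.asarray(rng.standard_normal((num_nodes, problem.irreps_in1.dim), dtype=np.float32))
    Y = jnp.asarray(rng.standard_normal((num_edges, problem.irreps_in2.dim), dtype=np.float32))
    W = jnp.asarray(rng.standard_normal((num_edges, problem.weight_numel), dtype=np.float32))

    Z = conv(X, Y, W, rows, cols)   # shape [num_nodes, problem.irreps_out.dim]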
openequivariance.jax.TensorProduct import TensorProduct as TensorProduct +from openequivariance.jax.TensorProductConv import ( + TensorProductConv as TensorProductConv, +) + +__all__ = ["TensorProduct", "TensorProductConv"] diff --git a/openequivariance/openequivariance/jax/extlib/__init__.py b/openequivariance/openequivariance/jax/extlib/__init__.py new file mode 100644 index 00000000..d0965502 --- /dev/null +++ b/openequivariance/openequivariance/jax/extlib/__init__.py @@ -0,0 +1,28 @@ +import jax +import openequivariance_extjax as oeq_extjax + + +def postprocess_kernel(kernel): + if oeq_extjax.is_hip(): + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel + else: + return kernel + + +platform = "CUDA" +if oeq_extjax.is_hip(): + platform = "ROCM" + +for name, target in oeq_extjax.registrations().items(): + jax.ffi.register_ffi_target(name, target, platform=platform) + +GPUTimer = oeq_extjax.GPUTimer +DeviceProp = oeq_extjax.DeviceProp + +__all__ = [ + "GPUTimer", + "DeviceProp", +] diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py new file mode 100644 index 00000000..ae15d1a6 --- /dev/null +++ b/openequivariance/openequivariance/jax/utils.py @@ -0,0 +1,63 @@ +import jax +import jax.numpy as jnp +import numpy as np + + +def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): + assert direction in ["forward", "backward"] + + specs = schedule.weight_reordering_info(weights_in, has_batch_dim) + weights_out = jnp.zeros_like(weights_in) + + for spec in specs: + parent_range = spec["parent_range"] + parent_shape = spec["parent_shape"] + weights_subrange = spec["weights_subrange"] + child_range = spec["child_range"] + transpose_perm = spec["transpose_perm"] + + if direction == "forward": + reshape_size = spec["reshape_size"] + + sliced_weights = weights_in[parent_range].reshape(parent_shape)[ + weights_subrange + ] + + value_to_assign = sliced_weights.transpose(transpose_perm).reshape( + reshape_size + ) + weights_out = weights_out.at[child_range].set(value_to_assign) + + elif direction == "backward": + transpose_child_shape = spec["transpose_child_shape"] + child_shape = spec["child_shape"] + + sliced_weights = ( + weights_in[child_range] + .reshape(transpose_child_shape) + .transpose(transpose_perm) + ) + + value_to_insert = sliced_weights.flatten().reshape(child_shape) + + slab = weights_out[parent_range] + slab_reshaped = slab.reshape(parent_shape) + slab_reshaped = slab_reshaped.at[weights_subrange].set(value_to_insert) + weights_out = weights_out.at[parent_range].set( + slab_reshaped.reshape(slab.shape) + ) + + return weights_out + + +def reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim): + weights_in_jax = jnp.array(weights_in) + result = reorder_jax_helper(schedule, weights_in_jax, direction, has_batch_dim) + return np.array(result) + + +def reorder_jax(schedule, weights_in, direction, has_batch_dim): + if isinstance(weights_in, (jnp.ndarray, jax.Array)): + return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim) + else: + return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) diff --git a/openequivariance/templates/common.cuh b/openequivariance/openequivariance/templates/common.cuh similarity index 100% rename from openequivariance/templates/common.cuh rename to 
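The two reorder_weights_* methods on the JAX wrappers are intended to be mutual inverses built on the reorder_jax helpers above; a short sketch of the expected round trip, reusing the illustrative problem from the earlier sketches (passing a NumPy array exercises the reorder_numpy_jax_helper path):

    import numpy as np
    import openequivariance as oeq
    from openequivariance.jax import TensorProduct

    problem = oeq.TPProblem(
        "32x1e", "1x1e", "32x1e",
        [(0, 0, 0, "uvu", True)],
        shared_weights=False,
        internal_weights=False,
    )
    tp = TensorProduct(problem)

    w_e3nn = np.random.default_rng(0).standard_normal(
        (16, problem.weight_numel), dtype=np.float32
    )

    w_oeq = tp.reorder_weights_from_e3nn(w_e3nn)   # "forward" reorder
    w_back = tp.reorder_weights_to_e3nn(w_oeq)     # "backward" reorder undoes it
    assert np.allclose(w_e3nn, w_back)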
openequivariance/openequivariance/templates/common.cuh diff --git a/openequivariance/templates/jinja_utils.py b/openequivariance/openequivariance/templates/jinja_utils.py similarity index 100% rename from openequivariance/templates/jinja_utils.py rename to openequivariance/openequivariance/templates/jinja_utils.py diff --git a/openequivariance/templates/loop_unroll_batch.cuh b/openequivariance/openequivariance/templates/loop_unroll_batch.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_batch.cuh rename to openequivariance/openequivariance/templates/loop_unroll_batch.cuh diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/openequivariance/templates/loop_unroll_conv_atomic.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_conv_atomic.cuh rename to openequivariance/openequivariance/templates/loop_unroll_conv_atomic.cuh diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/openequivariance/templates/loop_unroll_conv_det.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_conv_det.cuh rename to openequivariance/openequivariance/templates/loop_unroll_conv_det.cuh diff --git a/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_tp.cuh rename to openequivariance/openequivariance/templates/loop_unroll_tp.cuh diff --git a/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja similarity index 100% rename from openequivariance/templates/macros.jinja rename to openequivariance/openequivariance/templates/macros.jinja diff --git a/openequivariance/templates/wmm.cuh b/openequivariance/openequivariance/templates/wmm.cuh similarity index 100% rename from openequivariance/templates/wmm.cuh rename to openequivariance/openequivariance/templates/wmm.cuh diff --git a/pyproject.toml b/openequivariance/pyproject.toml similarity index 91% rename from pyproject.toml rename to openequivariance/pyproject.toml index 6f7b8771..a0ddd618 100644 --- a/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -17,8 +17,7 @@ dependencies = [ "setuptools", "ninja", "jinja2", - "numpy", - "torch >= 2.4", + "numpy" ] readme = "README.md" @@ -57,17 +56,20 @@ dev = [ "pytest-check", "pytest-subtests", "torch_geometric", - "cmake", - "furo", - "sphinx", - "sphinx-autobuild" + "cmake" +] + +jax = [ + "nanobind", + "scikit-build-core", + "setuptools-scm" ] [tool.setuptools.packages.find] include = ["openequivariance*"] [tool.setuptools_scm] -# Presence of this section necessary, even if empty +root = ".." [tool.pytest.ini_options] addopts = [ diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt new file mode 100644 index 00000000..25fec285 --- /dev/null +++ b/openequivariance_extjax/CMakeLists.txt @@ -0,0 +1,97 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) + +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) + +# --- XLA CONFIGURATION --- +if(XLA_DIRECT_DOWNLOAD) + message(STATUS "XLA_DIRECT_DOWNLOAD is ON. 
Fetching XLA source...") + include(ExternalProject) + ExternalProject_Add( + xla + PREFIX ${CMAKE_BINARY_DIR}/xla + GIT_REPOSITORY https://github.com/openxla/xla.git + GIT_TAG main + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + LOG_DOWNLOAD ON + ) + ExternalProject_Get_Property(xla source_dir) + set(XLA_DIR ${source_dir}) +else() + message(STATUS "XLA_DIRECT_DOWNLOAD is OFF. Locating XLA via installed JAX...") + execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" + "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR + ) +endif() + +message(STATUS "XLA include directory: ${XLA_DIR}") +# ------------------------- + +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT +) +message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") + +find_package(nanobind CONFIG REQUIRED) + +if(JAX_HIP) + message(STATUS "JAX_HIP is set. Building with HIP backend.") + find_package(hip REQUIRED) +else() + message(STATUS "JAX_HIP is not set (or zero). Building with CUDA backend.") + find_package(CUDAToolkit REQUIRED) +endif() + +set(HEADER_DIR "src/extension") +set(OEQ_JAX_SOURCES + src/libjax_tp_jit.cpp +) + +set(OEQ_JAX_HEADERS + ${HEADER_DIR}/convolution.hpp + ${HEADER_DIR}/tensorproducts.hpp + ${HEADER_DIR}/util/backend_cuda.hpp + ${HEADER_DIR}/util/backend_hip.hpp + ${HEADER_DIR}/util/buffer.hpp +) + +nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) +target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DIR}) +set_target_properties(openequivariance_extjax PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) + +# Ensure the module waits for XLA download if we are in direct download mode +if(XLA_DIRECT_DOWNLOAD) + add_dependencies(openequivariance_extjax xla) +endif() + +if(JAX_HIP) + target_link_libraries(openequivariance_extjax PRIVATE hiprtc) + target_compile_definitions(openequivariance_extjax PRIVATE HIP_BACKEND=1) + +else() + set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17) + + get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) + get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) + + set_target_properties(openequivariance_extjax PROPERTIES + BUILD_RPATH "${CUDA_LIB_DIR}" + INSTALL_RPATH "${CUDA_LIB_DIR}" + ) + + target_link_libraries(openequivariance_extjax PRIVATE + CUDA::cudart + CUDA::cuda_driver + CUDA::nvrtc) + target_compile_definitions(openequivariance_extjax PRIVATE CUDA_BACKEND=1) +endif() + +install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) 
\ No newline at end of file diff --git a/openequivariance_extjax/LICENSE b/openequivariance_extjax/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/openequivariance_extjax/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/openequivariance_extjax/MANIFEST.in b/openequivariance_extjax/MANIFEST.in new file mode 100644 index 00000000..fcd76e59 --- /dev/null +++ b/openequivariance_extjax/MANIFEST.in @@ -0,0 +1,2 @@ +include CMakeLists.txt +include src/* \ No newline at end of file diff --git a/openequivariance_extjax/README.md b/openequivariance_extjax/README.md new file mode 100644 index 00000000..ad7455ef --- /dev/null +++ b/openequivariance_extjax/README.md @@ -0,0 +1,3 @@ +# OpenEquivariance JAX Extension + +The JAX extension module for OpenEquivariance. \ No newline at end of file diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml new file mode 100644 index 00000000..47a21179 --- /dev/null +++ b/openequivariance_extjax/pyproject.toml @@ -0,0 +1,53 @@ +[build-system] +requires = [ + "setuptools-scm", + "scikit-build-core", + "nanobind" +] +build-backend = "scikit_build_core.build" + +[project] +name = "openequivariance_extjax" +dynamic = ["version"] +authors = [ + { name="Austin Glover" }, + { name="Vivek Bharadwaj" }, + { name="Aydin Buluc" }, + { name="James Demmel" } +] +description = "JAX C++ Extension for OpenEquivariance" +requires-python = ">=3.10" + +dependencies = [] +readme = "README.md" + +license = "BSD-3-Clause" +license-files = ["LICENSE"] + +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] + +[project.urls] +homepage = "https://passionlab.github.io/OpenEquivariance/" +source = "https://github.com/PASSIONLab/OpenEquivariance" +issues = "https://github.com/PASSIONLab/OpenEquivariance/issues" + +[tool.scikit-build.cmake.define] +JAX_HIP = {env="JAX_HIP", default="0"} +XLA_DIRECT_DOWNLOAD = {env="XLA_DIRECT_DOWNLOAD", default="0"} + +[tool.setuptools_scm] +root = ".." 
+ +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", +] + +[tool.ruff] +lint.ignore = ["E741"] \ No newline at end of file diff --git a/openequivariance_extjax/src/extension b/openequivariance_extjax/src/extension new file mode 120000 index 00000000..8370a418 --- /dev/null +++ b/openequivariance_extjax/src/extension @@ -0,0 +1 @@ +../../openequivariance/openequivariance/extension \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp new file mode 100644 index 00000000..ae2035e8 --- /dev/null +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -0,0 +1,773 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +#ifdef CUDA_BACKEND + #include + #include + + #include "util/backend_cuda.hpp" + #include "group_mm_cuda.hpp" + using JITKernel = CUJITKernel; + using GPU_Allocator = CUDA_Allocator; + + template + using GroupMM = GroupMMCUDA; + using stream_t = cudaStream_t; +#endif + +#ifdef HIP_BACKEND + #include "util/backend_hip.hpp" + #include "group_mm_hip.hpp" + using JITKernel = HIPJITKernel; + using GPU_Allocator = HIP_Allocator; + + template + using GroupMM = GroupMMHIP; + using stream_t = hipStream_t; +#endif + +#include "tensorproducts.hpp" +#include "convolution.hpp" + +xla::ffi::DataType enum_to_xla_dtype(int64_t i){ + switch(i) { + case 1: + return xla::ffi::DataType::F32; + case 2: + return xla::ffi::DataType::F64; + case 3: + return xla::ffi::DataType::S32; + case 4: + return xla::ffi::DataType::S64; + case 5: + return xla::ffi::DataType::U8; + } + throw logic_error("Unsupported tensor datatype!"); +} + +std::string xla_dtype_to_string(xla::ffi::DataType dtype) { + const std::unordered_map map = { + {xla::ffi::DataType::INVALID, "INVALID"}, + {xla::ffi::DataType::PRED, "PRED"}, + {xla::ffi::DataType::S8, "S8"}, + {xla::ffi::DataType::S16, "S16"}, + {xla::ffi::DataType::S32, "S32"}, + {xla::ffi::DataType::S64, "S64"}, + {xla::ffi::DataType::U8, "U8"}, + {xla::ffi::DataType::U16, "U16"}, + {xla::ffi::DataType::U32, "U32"}, + {xla::ffi::DataType::U64, "U64"}, + {xla::ffi::DataType::F16, "F16"}, + {xla::ffi::DataType::F32, "F32"}, + {xla::ffi::DataType::F64, "F64"}, + {xla::ffi::DataType::BF16, "BF16"}, + {xla::ffi::DataType::C64, "C64"}, + {xla::ffi::DataType::C128, "C128"}, + {xla::ffi::DataType::TOKEN, "TOKEN"}, + {xla::ffi::DataType::F8E5M2, "F8E5M2"}, + {xla::ffi::DataType::F8E4M3, "F8E4M3"}, + {xla::ffi::DataType::F8E4M3FN, "F8E4M3FN"}, + {xla::ffi::DataType::F8E4M3B11FNUZ, "F8E4M3B11FNUZ"}, + {xla::ffi::DataType::F8E5M2FNUZ, "F8E5M2FNUZ"}, + {xla::ffi::DataType::F8E4M3FNUZ, "F8E4M3FNUZ"}, + {xla::ffi::DataType::F8E3M4, "F8E3M4"}, + {xla::ffi::DataType::F4E2M1FN, "F4E2M1FN"}, + {xla::ffi::DataType::F8E8M0FNU, "F8E8M0FNU"}, + }; + return map.at(dtype); +} + +inline void* data_ptr(ffi::AnyBuffer &buffer) { + return buffer.untyped_data(); +} + +inline void* data_ptr(ffi::Result &buffer) { + return data_ptr(*buffer); +} + +inline int byte_count(ffi::AnyBuffer &buffer) { + switch (buffer.element_type()) { + case xla::ffi::DataType::U32: + case xla::ffi::DataType::S32: + case xla::ffi::DataType::F32: + return 4; + case xla::ffi::DataType::F64: + case xla::ffi::DataType::S64: + return 8; + case xla::ffi::DataType::U8: + return 1; + default: + throw logic_error("Unsupported tensor datatype!"); + } +} + +#ifdef CUDA_BACKEND +void 
zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { + cudaMemsetAsync( + data_ptr(buffer), + 0, + buffer.element_count() * byte_count(buffer), + stream); +} +#endif +#ifdef HIP_BACKEND +void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { + std::ignore = hipMemsetAsync( + data_ptr(buffer), + 0, + buffer.element_count() * byte_count(buffer), + stream); +} +#endif + +struct KernelProp { + int64_t L1_dim, L2_dim, L3_dim, weight_numel; + bool shared_weights; + xla::ffi::DataType irrep_dtype; + xla::ffi::DataType weight_dtype; + + int64_t workspace_size; // Convolution only + bool deterministic; + xla::ffi::DataType idx_dtype; + xla::ffi::DataType workspace_dtype; + + KernelProp() {} + + KernelProp( + std::unordered_map &kernel_dims, bool is_convolution): + L1_dim(kernel_dims.at("L1_dim")), + L2_dim(kernel_dims.at("L2_dim")), + L3_dim(kernel_dims.at("L3_dim")), + weight_numel(kernel_dims.at("weight_numel")), + shared_weights(kernel_dims.at("shared_weights")), + irrep_dtype(enum_to_xla_dtype(kernel_dims.at("irrep_dtype"))), + weight_dtype(enum_to_xla_dtype(kernel_dims.at("weight_dtype"))), + workspace_dtype(xla::ffi::DataType::U8) { + if(is_convolution) { + workspace_size = kernel_dims.at("workspace_size"); + deterministic = kernel_dims.at("deterministic"); + idx_dtype = enum_to_xla_dtype(kernel_dims.at("idx_dtype")); + } + } +}; + +std::unordered_map>, + KernelProp + >> tp_cache; + +std::unordered_map>, + KernelProp + >> conv_cache; +std::mutex mut; + +std::vector launch_config_keys = { + "num_blocks", + "num_threads", + "smem"}; +std::vector kernel_prop_keys = { + "L1_dim", + "L2_dim", + "L3_dim", + "weight_numel", + "shared_weights", + "opt_level", + "irrep_dtype", + "weight_dtype", + + // Convolution only + "workspace_size", + "deterministic", + "idx_dtype"}; + +std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { + std::unordered_map result; + for (const auto &key : keys) { + result[key] = dict.get(key).value(); + } + return result; +} + +std::pair*, KernelProp> + compile_tp_with_caching(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::Dictionary backward_config, + ffi::Dictionary double_backward_config, + ffi::Dictionary kernel_prop, + int64_t hash, + bool is_convolution) { + + { + const std::lock_guard lock(mut); + auto it = tp_cache.find(hash); + if (it == tp_cache.end()) { + auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + auto jit_tp_impl = std::make_unique>( + std::string(kernel), + parse_ffi_dict(forward_config, launch_config_keys), + parse_ffi_dict(backward_config, launch_config_keys), + parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_prop_map); + tp_cache.insert({hash, + std::make_pair(std::move(jit_tp_impl), + KernelProp(kernel_prop_map, is_convolution))}); + it = tp_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} + +std::pair*, KernelProp> + compile_conv_with_caching(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::Dictionary backward_config, + ffi::Dictionary double_backward_config, + ffi::Dictionary kernel_prop, + int64_t hash, + bool is_convolution) { + + { + const std::lock_guard lock(mut); + auto it = conv_cache.find(hash); + if (it == conv_cache.end()) { + auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + auto jit_conv_impl = std::make_unique>( + std::string(kernel), + parse_ffi_dict(forward_config, launch_config_keys), + parse_ffi_dict(backward_config, launch_config_keys), + 
parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_prop_map); + conv_cache.insert({hash, + std::make_pair(std::move(jit_conv_impl), + KernelProp(kernel_prop_map, is_convolution))}); + it = conv_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} + +inline void check_tensor(const ffi::AnyBuffer &buffer, + std::initializer_list expected_shape, + xla::ffi::DataType expected_dtype, + std::string tensor_name) { + const ffi::AnyBuffer::Dimensions dims = buffer.dimensions(); + if (dims.size() != expected_shape.size()) { + throw std::logic_error("Rank mismatch for tensor '" + + tensor_name + + "'. Expected rank " + + std::to_string(expected_shape.size()) + + ", got rank " + + std::to_string(dims.size())); + } + + for (size_t i = 0; i < dims.size(); i++) { + if (dims[i] != expected_shape.begin()[i]) { + throw std::logic_error("Shape mismatch for tensor '" + + tensor_name + + "'. Expected dimension " + + std::to_string(expected_shape.begin()[i]) + + " at index " + + std::to_string(i) + + ", got " + + std::to_string(dims[i])); + } + } + + if (buffer.element_type() != expected_dtype) { + throw std::logic_error("Datatype mismatch for tensor " + tensor_name + + ". Expected datatype " + xla_dtype_to_string(expected_dtype) + + ", got " + xla_dtype_to_string(buffer.element_type())); + } +} + +// --------------------- Tensor Products -------------------------- +ffi::Error tp_forward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::Result L3_out, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_tp_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + const int64_t num_batch = L1_in.dimensions()[0]; + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + + jit_kernel->exec_tensor_product( + num_batch, + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(L3_out), + data_ptr(W), + stream); + + return ffi::Error::Success(); +} + +ffi::Error tp_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_tp_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + const int64_t num_batch = L1_in.dimensions()[0]; + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad"); + } + else { + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {num_batch, k.weight_numel}, k.weight_dtype, "W_grad"); + } + + if (k.shared_weights) { + 
zero_buffer(*W_grad, stream); + } + + jit_kernel->backward( + num_batch, + data_ptr(L1_in), + data_ptr(L1_grad), + data_ptr(L2_in), + data_ptr(L2_grad), + data_ptr(W), + data_ptr(W_grad), + data_ptr(L3_grad), + stream); + return ffi::Error::Success(); +} + + +ffi::Error tp_double_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::AnyBuffer L1_dgrad, + ffi::AnyBuffer L2_dgrad, + ffi::AnyBuffer W_dgrad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + ffi::Result L3_dgrad, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_tp_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + const int64_t num_batch = L1_in.dimensions()[0]; + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + + if (k.shared_weights){ + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + + if (k.shared_weights) { + zero_buffer(*W_grad, stream); + } + + jit_kernel->double_backward( + num_batch, + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(W), + data_ptr(L3_grad), + data_ptr(L1_dgrad), + data_ptr(L2_dgrad), + data_ptr(W_dgrad), + data_ptr(L1_grad), + data_ptr(L2_grad), + data_ptr(W_grad), + data_ptr(L3_dgrad), + stream); + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + tp_forward, tp_forward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + tp_backward, tp_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + tp_double_backward, tp_double_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + +// --------------------- Convolution -------------------------- +ffi::Error conv_forward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer rows, + ffi::AnyBuffer cols, + ffi::AnyBuffer workspace, + ffi::AnyBuffer transpose_perm, + ffi::Result L3_out, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + 
int64_t hash) { + + auto [jit_kernel, k] = compile_conv_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + const int64_t nnz = rows.dimensions()[0]; + const int64_t node_count = L1_in.dimensions()[0]; + void* workspace_ptr = data_ptr(workspace); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic){ + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } + else { + workspace_ptr = nullptr; + } + zero_buffer(*L3_out, stream); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + + jit_kernel->exec_conv( + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(W), + data_ptr(L3_out), + data_ptr(rows), + data_ptr(cols), + nnz, node_count, + workspace_ptr, + stream); + + return ffi::Error::Success(); +} + +ffi::Error conv_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + ffi::AnyBuffer rows, + ffi::AnyBuffer cols, + ffi::AnyBuffer workspace, + ffi::AnyBuffer transpose_perm, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_conv_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + const int64_t nnz = rows.dimensions()[0]; + const int64_t node_count = L1_in.dimensions()[0]; + void* workspace_ptr = data_ptr(workspace); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } + else { + workspace_ptr = nullptr; + } + zero_buffer(*L1_grad, stream); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad"); + } + else { + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad"); + } + if(k.shared_weights) + zero_buffer(*W_grad, stream); + + jit_kernel->backward( + data_ptr(L1_in), + data_ptr(L1_grad), + data_ptr(L2_in), + data_ptr(L2_grad), + data_ptr(W), + data_ptr(W_grad), + data_ptr(L3_grad), + data_ptr(rows), + data_ptr(cols), + nnz, node_count, + workspace_ptr, + data_ptr(transpose_perm), + stream); + return ffi::Error::Success(); +} + +ffi::Error conv_double_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::AnyBuffer L1_dgrad, + ffi::AnyBuffer L2_dgrad, + ffi::AnyBuffer W_dgrad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + ffi::Result L3_dgrad, + ffi::AnyBuffer rows, + ffi::AnyBuffer cols, + 
ffi::AnyBuffer workspace, + ffi::AnyBuffer transpose_perm, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_conv_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + const int64_t nnz = rows.dimensions()[0]; + const int64_t node_count = L1_in.dimensions()[0]; + void* workspace_ptr = data_ptr(workspace); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } + else { + workspace_ptr = nullptr; + } + zero_buffer(*L1_grad, stream); + zero_buffer(*L3_dgrad, stream); + + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + if(k.shared_weights) + zero_buffer(*W_grad, stream); + + jit_kernel->double_backward( + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(W), + data_ptr(L3_grad), + data_ptr(L1_dgrad), + data_ptr(L2_dgrad), + data_ptr(W_dgrad), + data_ptr(L1_grad), + data_ptr(L2_grad), + data_ptr(W_grad), + data_ptr(L3_dgrad), + data_ptr(rows), + data_ptr(cols), + nnz, node_count, + workspace_ptr, + data_ptr(transpose_perm), + stream); + return ffi::Error::Success(); +} + +bool is_hip() { +#ifdef HIP_BACKEND + return true; +#else + return false; +#endif +} + +// --------------------- FFI Bindings -------------------------- + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + conv_forward, conv_forward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + conv_backward, conv_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Arg() + .Arg() + .Arg() + .Arg() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + conv_double_backward, conv_double_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Ret() + .Arg() + .Arg() + .Arg() + .Arg() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + +// --------------------- NB Module -------------------------- +NB_MODULE(openequivariance_extjax, m) { + m.def("registrations", []() { + nb::dict registrations; + 
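As an aside on how this module is consumed: the registrations() binding populated just below returns name-to-capsule pairs for each FFI handler in this file, and is_hip() above reports the compiled backend. On the Python side these would typically be handed to XLA roughly as sketched here using the standard jax.ffi API; the actual wiring inside the openequivariance.jax package may differ, and the platform strings are an assumption.

# Sketch only: register the FFI handlers exported by this extension with JAX.
# Older JAX releases expose register_ffi_target under jax.extend.ffi instead.
import jax
import openequivariance_extjax as extjax

platform = "ROCM" if extjax.is_hip() else "CUDA"
for name, capsule in extjax.registrations().items():
    jax.ffi.register_ffi_target(name, capsule, platform=platform)
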
registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); + registrations["tp_backward"] = nb::capsule(reinterpret_cast(tp_backward)); + registrations["tp_double_backward"] = nb::capsule(reinterpret_cast(tp_double_backward)); + + registrations["conv_forward"] = nb::capsule(reinterpret_cast(conv_forward)); + registrations["conv_backward"] = nb::capsule(reinterpret_cast(conv_backward)); + registrations["conv_double_backward"] = nb::capsule(reinterpret_cast(conv_double_backward)); + return registrations; + }); + m.def("is_hip", &is_hip); + + nb::class_(m, "DeviceProp") + .def(nb::init()) + .def_ro("name", &DeviceProp::name) + .def_ro("warpsize", &DeviceProp::warpsize) + .def_ro("major", &DeviceProp::major) + .def_ro("minor", &DeviceProp::minor) + .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + nb::class_(m, "GPUTimer") + .def(nb::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); + + /*nb::class_>(m, "DeviceBuffer") + .def(nb::init()) + .def(nb::init()) + .def("copy_to_host", &PyDeviceBuffer::copy_to_host) + .def("data_ptr", &PyDeviceBuffer::data_ptr);*/ +} diff --git a/tests/batch_test.py b/tests/batch_test.py index 3c7cdf27..f32f7b51 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -3,7 +3,6 @@ import numpy as np import openequivariance as oeq -from openequivariance.implementations.TensorProduct import TensorProduct from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, @@ -41,8 +40,17 @@ def extra_tp_constructor_args(self): return {} @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args): - tp = TensorProduct(problem, **extra_tp_constructor_args) + def with_jax(self, request): + return request.config.getoption("--jax") + + @pytest.fixture(scope="class") + def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): + cls = oeq.TensorProduct + if with_jax: + import openequivariance.jax.TensorProduct as jax_tp + + cls = jax_tp + tp = cls(problem, **extra_tp_constructor_args) return tp, problem def test_tp_fwd(self, tp_and_problem): @@ -247,7 +255,9 @@ def problem(self, request, dtype): class TestTorchbindDisable(TestProductionModels): @pytest.fixture(scope="class") - def extra_tp_constructor_args(self): + def extra_tp_constructor_args(self, with_jax): + if with_jax: + pytest.skip("N/A for JAX") return {"use_opaque": True} @@ -261,11 +271,14 @@ def problem(self, request, dtype): return problem @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args): - tp = TensorProduct(problem, **extra_tp_constructor_args) - switch_map = { - np.float32: torch.float64, - np.float64: torch.float32, - } - tp.to(switch_map[problem.irrep_dtype]) - return tp, tp.config + def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + else: + tp = oeq.TensorProduct(problem, **extra_tp_constructor_args) + switch_map = { + np.float32: torch.float64, + np.float64: torch.float32, + } + tp.to(switch_map[problem.irrep_dtype]) + return tp, tp.config diff --git a/tests/benchmark.py b/tests/benchmark.py index ab005cdf..829cc46c 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -11,14 +11,14 @@ import numpy as np from openequivariance.benchmark.logging_utils import getLogger -from 
openequivariance.extlib import DeviceProp -from openequivariance.implementations.E3NNTensorProduct import ( +from openequivariance._torch.extlib import DeviceProp +from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs, E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, ) -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.implementations.CUETensorProduct import CUETensorProduct +from openequivariance._torch.TensorProduct import TensorProduct +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.TestBenchmarkSuite import ( TestBenchmarkSuite, TestDefinition, @@ -30,15 +30,15 @@ SingleInstruction, ) -from openequivariance.implementations.convolution.TensorProductConv import ( +from openequivariance._torch.TensorProductConv import ( TensorProductConvAtomic, TensorProductConvDeterministic, TensorProductConvKahan, TensorProductConvScatterSum, ) -from openequivariance.implementations.convolution.CUEConv import CUEConv, CUEConvFused -from openequivariance.implementations.convolution.FlashTPConv import FlashTPConv +from openequivariance._torch.CUEConv import CUEConv, CUEConvFused +from openequivariance._torch.FlashTPConv import FlashTPConv from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph from openequivariance.benchmark.problems import ( diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..0e7098e0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +import os + +os.environ["JAX_ENABLE_X64"] = "True" +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" + + +def pytest_addoption(parser): + parser.addoption( + "--jax", + action="store_true", + default=False, + help="Test the JAX frontend instead of PyTorch", + ) diff --git a/tests/conv_test.py b/tests/conv_test.py index 14c8c3d2..9c6bb4c8 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -51,23 +51,29 @@ def graph(self, request): def extra_conv_constructor_args(self): return {} + @pytest.fixture(scope="class") + def with_jax(self, request): + return request.config.getoption("--jax") + @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") - def conv_object(self, request, problem, extra_conv_constructor_args): + def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): + cls = oeq.TensorProductConv + if with_jax: + from openequivariance.jax import TensorProductConv as jax_conv + + cls = jax_conv + if request.param == "atomic": - return oeq.TensorProductConv( - problem, deterministic=False, **extra_conv_constructor_args - ) + return cls(problem, deterministic=False, **extra_conv_constructor_args) elif request.param == "deterministic": if not problem.shared_weights: - return oeq.TensorProductConv( - problem, deterministic=True, **extra_conv_constructor_args - ) + return cls(problem, deterministic=True, **extra_conv_constructor_args) else: pytest.skip("Shared weights not supported with deterministic") elif request.param == "kahan": if problem.irrep_dtype == np.float32: if not problem.shared_weights: - return oeq.TensorProductConv( + return cls( problem, deterministic=True, kahan=True, @@ -228,7 +234,7 @@ def thresh(self, direction): return { "fwd": 1e-5, "bwd": 7.5e-2, # Expect higher errors for shared weights - "double_bwd": 5e-2, + "double_bwd": 5e-1, }[direction] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") @@ -245,7 +251,9 @@ def conv_object(self, 
request, problem): class TestTorchbindDisable(TestProductionModels): @pytest.fixture(scope="class") - def extra_conv_constructor_args(self): + def extra_conv_constructor_args(self, with_jax): + if with_jax: + pytest.skip("N/A for JAX") return {"use_opaque": True} @@ -253,7 +261,10 @@ class TestTorchTo(ConvCorrectness): problems = [mace_problems()[0]] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") - def problem(self, request, dtype): + def problem(self, request, dtype, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + problem = request.param problem.irrep_dtype, problem.weight_dtype = dtype, dtype return problem diff --git a/tests/example_test.py b/tests/example_test.py new file mode 100644 index 00000000..ae19f77e --- /dev/null +++ b/tests/example_test.py @@ -0,0 +1,163 @@ +import pytest +import os + + +@pytest.fixture +def with_jax(request): + return request.config.getoption("--jax") + + +def test_tutorial_torch(with_jax): + if with_jax: + pytest.skip("Skipping PyTorch tutorial when testing JAX") + + import torch + import e3nn.o3 as o3 + + gen = torch.Generator(device="cuda") + + batch_size = 1000 + X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") + X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) + + instructions = [(0, 0, 0, "uvu", True)] + + tp_e3nn = o3.TensorProduct( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ).to("cuda") + W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) + + Z = tp_e3nn(X, Y, W) + print(torch.norm(Z)) + # =============================== + + # =============================== + import openequivariance as oeq + + problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ) + tp_fast = oeq.TensorProduct(problem) + + Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier + print(torch.norm(Z)) + # =============================== + + # Graph Convolution + # =============================== + from torch_geometric import EdgeIndex + + node_ct, nonzero_ct = 3, 4 + + # Receiver, sender indices for message passing GNN + edge_index = EdgeIndex( + [ + [0, 1, 1, 2], # Receiver + [1, 0, 2, 1], + ], # Sender + device="cuda", + dtype=torch.long, + ) + + X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) + W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) + + tp_conv = oeq.TensorProductConv( + problem, deterministic=False + ) # Reuse problem from earlier + Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] + ) # Z has shape [node_ct, z_ir.dim] + print(torch.norm(Z)) + # =============================== + + # =============================== + _, sender_perm = edge_index.sort_by("col") # Sort by sender index + edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index + + # Now we can use the faster deterministic algorithm + tp_conv = oeq.TensorProductConv(problem, deterministic=True) + Z = tp_conv.forward( + X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm + ) + print(torch.norm(Z)) + # =============================== + assert True + + +def test_tutorial_jax(with_jax): + if not with_jax: + pytest.skip("Skipping JAX tutorial when testing PyTorch") + + os.environ["OEQ_NOTORCH"] = "1" + import openequivariance as oeq + import jax + + seed = 42 + key = 
jax.random.PRNGKey(seed) + + batch_size = 1000 + X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e") + instructions = [(0, 0, 0, "uvu", True)] + + problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ) + tp_fast = oeq.jax.TensorProduct(problem) + + X = jax.random.uniform( + key, + shape=(batch_size, X_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + Y = jax.random.uniform( + key, + shape=(batch_size, Y_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + W = jax.random.uniform( + key, + shape=(batch_size, tp_fast.weight_numel), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + + Z = tp_fast(X, Y, W) + print(jax.numpy.linalg.norm(Z)) + + edge_index = jax.numpy.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jax.numpy.int32, # NOTE: This int32, not int64 + ) + + node_ct, nonzero_ct = 3, 4 + X = jax.random.uniform( + key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32 + ) + Y = jax.random.uniform( + key, + shape=(nonzero_ct, Y_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + W = jax.random.uniform( + key, + shape=(nonzero_ct, problem.weight_numel), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) + Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) + print(jax.numpy.linalg.norm(Z)) diff --git a/tests/examples_test.py b/tests/examples_test.py deleted file mode 100644 index 3beaabb1..00000000 --- a/tests/examples_test.py +++ /dev/null @@ -1,75 +0,0 @@ -def test_tutorial(): - import torch - import e3nn.o3 as o3 - - gen = torch.Generator(device="cuda") - - batch_size = 1000 - X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") - X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) - Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) - - instructions = [(0, 0, 0, "uvu", True)] - - tp_e3nn = o3.TensorProduct( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False - ).to("cuda") - W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) - - Z = tp_e3nn(X, Y, W) - print(torch.norm(Z)) - # =============================== - - # =============================== - import openequivariance as oeq - - problem = oeq.TPProblem( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False - ) - tp_fast = oeq.TensorProduct(problem, torch_op=True) - - Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier - print(torch.norm(Z)) - # =============================== - - # Graph Convolution - # =============================== - from torch_geometric import EdgeIndex - - node_ct, nonzero_ct = 3, 4 - - # Receiver, sender indices for message passing GNN - edge_index = EdgeIndex( - [ - [0, 1, 1, 2], # Receiver - [1, 0, 2, 1], - ], # Sender - device="cuda", - dtype=torch.long, - ) - - X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) - Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) - W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) - - tp_conv = oeq.TensorProductConv( - problem, torch_op=True, deterministic=False - ) # Reuse problem from earlier - Z = tp_conv.forward( - X, Y, W, edge_index[0], edge_index[1] - ) # Z has shape [node_ct, z_ir.dim] - print(torch.norm(Z)) - # =============================== - - # =============================== - _, sender_perm = edge_index.sort_by("col") 
# Sort by sender index - edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index - - # Now we can use the faster deterministic algorithm - tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) - Z = tp_conv.forward( - X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm - ) - print(torch.norm(Z)) - # =============================== - assert True diff --git a/tests/export_test.py b/tests/export_test.py index e18b38b1..0fd23b2b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -11,7 +11,7 @@ from torch_geometric import EdgeIndex import importlib.resources -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct +from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session")
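The JAX half of the tutorial above stops at forward calls, but the extension also registers backward and double-backward handlers, so the intended use is to compose oeq.jax.TensorProduct with jax.jit and jax.grad. A minimal sketch along the lines of test_tutorial_jax follows; the squared-norm loss and the jit/grad wrapping are illustrative, not code from this commit.

# Assumes the same environment setup as the tests above (e.g. OEQ_NOTORCH=1
# when PyTorch is absent, JAX_ENABLE_X64 set by tests/conftest.py).
import jax
import jax.numpy as jnp
import openequivariance as oeq

X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
problem = oeq.TPProblem(
    X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)],
    shared_weights=False, internal_weights=False,
)
tp = oeq.jax.TensorProduct(problem)

key = jax.random.PRNGKey(0)
kx, ky, kw = jax.random.split(key, 3)
X = jax.random.uniform(kx, (1000, X_ir.dim), dtype=jnp.float32)
Y = jax.random.uniform(ky, (1000, Y_ir.dim), dtype=jnp.float32)
W = jax.random.uniform(kw, (1000, tp.weight_numel), dtype=jnp.float32)

def loss(X, Y, W):
    Z = tp(X, Y, W)  # forward pass through the tp_forward FFI handler
    return jnp.sum(Z * Z)

# Gradients w.r.t. all three inputs exercise tp_backward; differentiating the
# gradient again would reach tp_double_backward.
value, grads = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2)))(X, Y, W)
print(value, [g.shape for g in grads])

To run the suite itself against this frontend, the new --jax option added in tests/conftest.py switches the fixtures from the PyTorch classes to the JAX ones, e.g. pytest tests/example_test.py tests/batch_test.py --jax.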