diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index 49e1cb04..fa6c7363 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install sphinx furo
+ pip install -r docs/requirements.txt
- name: Build website
run: |
sphinx-build -M dirhtml docs docs/_build
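For reference, the documentation build that this workflow runs can be reproduced locally with roughly the following commands (a sketch assuming the repository root as the working directory):

```bash
# Rough local equivalent of the docs workflow above
# (assumes the repository root as the working directory).
python -m pip install --upgrade pip
pip install -r docs/requirements.txt
sphinx-build -M dirhtml docs docs/_build
```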
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 588eebce..2f351dba 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -6,7 +6,7 @@ on:
# ref: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
jobs:
- build:
+ build-oeq:
name: Build distribution
runs-on: ubuntu-latest
steps:
@@ -16,21 +16,20 @@ jobs:
with:
python-version: '3.10'
- name: install dependencies, then build source tarball
- run: |
+ run: |
+ cd openequivariance
python3 -m pip install build --user
python3 -m build --sdist
- name: store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
- path: dist/
+ path: openequivariance/dist/
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
- # build task to be completed first
- needs: build
- # Specifying a GitHub environment is optional, but strongly encouraged
+ needs: build-oeq
environment:
name: pypi
url: https://pypi.org/p/openequivariance
@@ -42,6 +41,47 @@ jobs:
uses: actions/download-artifact@v4
with:
name: python-package-distributions
- path: dist/
+ path: openequivariance/dist/
+ - name: publish package distributions to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+
+ # ------------------------------------
+
+ build-oeq-extjax:
+ name: Build distribution
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+ - name: install dependencies, then build source tarball
+ run: |
+ cd openequivariance_extjax
+ python3 -m pip install build --user
+ python3 -m build --sdist
+ - name: store the distribution packages
+ uses: actions/upload-artifact@v4
+ with:
+ name: python-package-distributions
+ path: openequivariance_extjax/dist/
+
+ pypi-publish-extjax:
+ name: Upload release to PyPI
+ runs-on: ubuntu-latest
+ needs: build-oeq-extjax
+ environment:
+ name: pypi
+ url: https://pypi.org/p/openequivariance_extjax
+ permissions:
+ # IMPORTANT: this permission is mandatory for Trusted Publishing
+ id-token: write
+ steps:
+ - name: download the distributions
+ uses: actions/download-artifact@v4
+ with:
+ name: python-package-distributions
+ path: openequivariance_extjax/dist/
- name: publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
\ No newline at end of file
diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt
index da9fde80..0e04348c 100644
--- a/.github/workflows/requirements_cuda_ci.txt
+++ b/.github/workflows/requirements_cuda_ci.txt
@@ -1,4 +1,6 @@
numpy==2.2.5
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
pytest==8.3.5
-ninja==1.11.1.4
\ No newline at end of file
+ninja==1.11.1.4
+nanobind==2.10.2
+scikit-build-core==0.11.6
\ No newline at end of file
diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml
index 50b496b8..39888491 100644
--- a/.github/workflows/verify_extension_build.yml
+++ b/.github/workflows/verify_extension_build.yml
@@ -1,4 +1,4 @@
-name: OEQ CUDA C++ Extension Build Verification
+name: OEQ C++ Extension Build Verification
on:
push:
@@ -29,10 +29,14 @@ jobs:
sudo apt-get update
sudo apt install nvidia-cuda-toolkit
pip install -r .github/workflows/requirements_cuda_ci.txt
- pip install -e .
+ pip install -e "./openequivariance"
- - name: Test extension build via import
+ - name: Test CUDA extension build via import
run: |
pytest \
tests/import_test.py::test_extension_built \
- tests/import_test.py::test_torch_extension_built
\ No newline at end of file
+ tests/import_test.py::test_torch_extension_built
+
+ - name: Test JAX extension build
+ run: |
+ XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation
\ No newline at end of file
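For reference, the CI steps in this workflow can be reproduced locally with roughly the following commands (a sketch assuming a CUDA toolkit is already installed and the repository root is the working directory):

```bash
# Rough local equivalent of the extension-build CI job above
# (assumes a CUDA toolkit and the repository root as the working directory).
pip install -r .github/workflows/requirements_cuda_ci.txt
pip install -e "./openequivariance"

# CUDA / PyTorch extension build, verified via import tests
pytest tests/import_test.py::test_extension_built \
       tests/import_test.py::test_torch_extension_built

# JAX extension build
XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation
```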
diff --git a/.gitignore b/.gitignore
index 5c878b1a..64fcaa8d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -38,7 +38,6 @@ triton_autotuning
paper_benchmarks
paper_benchmarks_v2
paper_benchmarks_v3
-openequivariance/extlib/*.so
get_node.sh
*.egg-info
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a120656b..e9bd29e1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,17 @@
## Latest Changes
+### v0.5.0 (2025-12-25)
+JAX support is now available in
+OpenEquivariance for BOTH NVIDIA and
+AMD GPUs! See the
+[documentation](https://passionlab.github.io/OpenEquivariance/)
+and README.md for instructions on installation
+and usage.
+
+Minor changes:
+- Defer error reporting when CUDA is not available
+ to the first library usage rather than at library load.
+
### v0.4.1 (2025-09-04)
Minor update, fixes a bug loading JIT-compiled modules
with PyTorch 2.9.
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 7eaa4d91..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,10 +0,0 @@
-include openequivariance/extlib/*.so
-include openequivariance/extlib/*.empty
-
-include openequivariance/templates/*.cuh
-include openequivariance/templates/*.jinja
-
-include openequivariance/extension/*
-include openequivariance/extension/convolution/*
-include openequivariance/extension/tensorproducts/*
-include openequivariance/extension/util/*
\ No newline at end of file
diff --git a/README.md b/README.md
index 288e1daf..c68e4fa9 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,9 @@
# OpenEquivariance
-[](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
+[](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
[](https://opensource.org/licenses/BSD-3-Clause)
-[[Examples]](#show-me-some-examples)
+[[PyTorch Examples]](#pytorch-examples)
+[[JAX Examples]](#jax-examples)
[[Citation and Acknowledgements]](#citation-and-acknowledgements)
OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
@@ -12,8 +13,8 @@ that [e3nn](https://e3nn.org/) supports
commonly found in graph neural networks
(e.g. [Nequip](https://github.com/mir-group/nequip) or
[MACE](https://github.com/ACEsuit/mace)). To get
-started, ensure that you have GCC 9+ on your system
-and install our package via
+started with PyTorch, ensure that you have PyTorch
+and GCC 9+ available before installing our package via
```bash
pip install openequivariance
@@ -29,11 +30,26 @@ computation and memory consumption significantly.
For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
check out the [documentation](https://passionlab.github.io/OpenEquivariance).
-📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
-Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
-Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025).
+⭐️ **JAX**: Our latest update brings
+support for JAX. For NVIDIA GPUs, first
+install JAX, then run the following
+two commands strictly in order:
-## Show me some examples
+``` bash
+pip install openequivariance[jax]
+pip install openequivariance_extjax --no-build-isolation
+```
+
+For AMD GPUs:
+``` bash
+pip install openequivariance[jax]
+JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation
+```
+
+See the section below for example usage and
+our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details.
+
+## PyTorch Examples
Here's a CG tensor product implemented by e3nn:
```python
@@ -127,6 +143,48 @@ print(torch.norm(Z))
`deterministic=False`, the `sender` and `receiver` indices can have
arbitrary order.
+## JAX Examples
+After installation, use the library
+as follows. Set `OEQ_NOTORCH=1`
+in your environment to avoid the PyTorch import in
+the regular `openequivariance` package.
+```python
+import jax
+import os
+
+os.environ["OEQ_NOTORCH"] = "1"
+import openequivariance as oeq
+
+seed = 42
+key = jax.random.PRNGKey(seed)
+
+batch_size = 1000
+X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
+problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False)
+
+
+node_ct, nonzero_ct = 3, 4
+edge_index = jax.numpy.array(
+ [
+ [0, 1, 1, 2],
+ [1, 0, 2, 1],
+ ],
+ dtype=jax.numpy.int32, # NOTE: This is int32, not int64
+)
+
+X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
+Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim),
+ minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
+W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel),
+ minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
+
+tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
+Z = tp_conv.forward(
+ X, Y, W, edge_index[0], edge_index[1]
+)
+print(jax.numpy.linalg.norm(Z))
+```
+
## Citation and Acknowledgements
If you find this code useful, please cite our paper:
diff --git a/docs/api.rst b/docs/api.rst
index 3fac1764..c21b918f 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -8,7 +8,7 @@ OpenEquivariance API
OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces
``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses
the CG tensor product with a subsequent graph convolution. Initializing either class triggers
-JIT compilation of a custom kernel, which can take a few seconds.
+JIT compilation of a custom kernel, which can take a few seconds.
Both classes require a configuration object specified
by :py:class:`openequivariance.TPProblem`, which has a constructor
@@ -17,6 +17,9 @@ We recommend reading the `e3nn documentation `
trying our code. OpenEquivariance cannot accelerate all tensor products; see
:doc:`this page ` for a list of supported configurations.
+PyTorch API
+------------------------
+
.. autoclass:: openequivariance.TensorProduct
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
:undoc-members:
@@ -27,14 +30,39 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
:undoc-members:
:exclude-members: name
-.. autoclass:: openequivariance.TPProblem
- :members:
- :undoc-members:
-
.. autofunction:: openequivariance.torch_to_oeq_dtype
.. autofunction:: openequivariance.torch_ext_so_path
+JAX API
+------------------------
+The JAX API consists of ``TensorProduct`` and ``TensorProductConv``
+classes that behave identically to their PyTorch counterparts. These classes
+do not conform exactly to the e3nn-jax API, but perform the same computation.
+
+If you plan to use ``oeq.jax`` without PyTorch installed,
+you need to set ``OEQ_NOTORCH=1`` in your local environment (within Python,
+``os.environ["OEQ_NOTORCH"] = "1"``). For the moment, we require this to avoid
+breaking the PyTorch version of OpenEquivariance.
+
+
+.. autoclass:: openequivariance.jax.TensorProduct
+ :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
+ :undoc-members:
+ :exclude-members:
+
+.. autoclass:: openequivariance.jax.TensorProductConv
+ :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
+ :undoc-members:
+ :exclude-members:
+
+Common API
+---------------------
+
+.. autoclass:: openequivariance.TPProblem
+ :members:
+ :undoc-members:
+
API Identical to e3nn
---------------------
diff --git a/docs/conf.py b/docs/conf.py
index 17707552..540cf37e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -28,11 +28,17 @@
html_theme = "furo"
# html_static_path = ["_static"]
-extensions = [
- "sphinx.ext.autodoc",
+extensions = ["sphinx.ext.autodoc", "sphinx_inline_tabs"]
+
+sys.path.insert(0, str(Path("../openequivariance").resolve()))
+
+autodoc_mock_imports = [
+ "torch",
+ "jax",
+ "openequivariance._torch.extlib",
+ "openequivariance.jax.extlib",
+ "openequivariance_extjax",
+ "jinja2",
+ "numpy",
]
-
-sys.path.insert(0, str(Path("..").resolve()))
-
-autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"]
autodoc_typehints = "description"
diff --git a/docs/installation.rst b/docs/installation.rst
index 9c3588cb..5ade5c0c 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -8,94 +8,139 @@ Installation
You need the following to install OpenEquivariance:
- A Linux system equipped with an NVIDIA / AMD graphics card.
-- PyTorch >= 2.4 (>= 2.8 for AOTI and export).
+- Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX>0.5.0
+ with CUDA or ROCm support.
- GCC 9+ and the CUDA / HIP toolkit. The command
``c++ --version`` should return >= 9.0; see below for details on
setting an alternate compiler.
-Installation is one easy command, followed by import verification:
+.. tab:: PyTorch
-.. code-block:: bash
+ Installation is one easy command, followed by import verification:
- pip install openequivariance
- python -c "import openequivariance"
+ .. code-block:: bash
-The second line triggers a build of the C++ extension we use to compile
-kernels, which can take a couple of minutes. Subsequent imports are
-much faster since this extension is cached.
+ pip install openequivariance
+ python -c "import openequivariance"
-To get the nightly build, run
+ The second line triggers a build of the C++ extension we use to compile
+ kernels, which can take a couple of minutes. Subsequent imports are
+ much faster since this extension is cached.
-.. code-block:: bash
+ To support ``torch.compile``, ``torch.export``, and
+ JITScript, OpenEquivariance needs to compile a C++ extension
+ tightly integrated with PyTorch. If you see a warning that
+ this extension could not be compiled, first check:
- pip install git+https://github.com/PASSIONLab/OpenEquivariance
+ .. code-block:: bash
+ c++ --version
+
+ To build the extension with an alternate compiler, set the
+ ``CC`` and ``CXX``
+ environment variable and retry the import:
-Compiling the Integrated PyTorch Extension
-------------------------------------------
-To support ``torch.compile``, ``torch.export``, and
-JITScript, OpenEquivariance needs to compile a C++ extension
-tightly integrated with PyTorch. If you see a warning that
-this extension could not be compiled, first check:
+ .. code-block:: bash
-.. code-block:: bash
+ export CC=/path/to/your/gcc
+ export CXX=/path/to/your/g++
+ python -c "import openequivariance"
- c++ --version
-
-To build the extension with an alternate compiler, set the
-``CC`` and ``CXX``
-environment variable and retry the import:
+
+ These configuration steps are required only ONCE after
+ installation (or upgrade) with pip.
+
+
+.. tab:: JAX NVIDIA GPUs
+
+ First ensure the appropriate JAX Python
+ package is installed in your environment. Then
+ run the following two commands strictly in order:
+
+ .. code-block:: bash
+
+ pip install openequivariance[jax]
+ pip install openequivariance_extjax --no-build-isolation
+
+.. tab:: JAX AMD GPUs
+
+ Ensure that JAX is installed correctly with ROCm support
+ before running, in order,
+
+ .. code-block:: bash
+
+ pip install openequivariance[jax]
+ JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation
+
+
+.. tab:: Nightly (PT)
+
+ .. code-block:: bash
+
+ pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance"
+
+
+.. tab:: Nightly (JAX)
+
+ .. code-block:: bash
+
+ pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax]"
+ pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax --no-build-isolation"
+
+ # Use the command below for JAX+AMD
+ # JAX_HIP=1 pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax --no-build-isolation"
+
+
+If you're using JAX, set the environment variable
+``OEQ_NOTORCH=1`` to avoid a PyTorch import:
.. code-block:: bash
- export CCC=/path/to/your/gcc
- export CXX=/path/to/your/g++
- python -c "import openequivariance"
+ export OEQ_NOTORCH=1
+ python -c "import openequivariance.jax"
-These configuration steps are required only ONCE after
-installation (or upgrade) with pip.
Configurations on Major Platforms
---------------------------------
OpenEquivariance has been tested on both supercomputers and lab clusters.
-Here are some tested environment configuration files. If use OpenEquivariance
-on a widely-used platform, send us a pull request to add your configuration!
+Here are some tested environment configuration files. If you use OpenEquivariance
+on a major cluster, send us a pull request to add your configuration!
-NERSC Perlmutter (NVIDIA A100)
-""""""""""""""""""""""""""""""
-.. code-block:: bash
- :caption: env.sh (last updated June 2025)
+.. tab:: NERSC Perlmutter (NVIDIA A100)
- module load gcc
- module load conda
+ .. code-block:: bash
+ :caption: env.sh (last updated June 2025)
- # Deactivate any base environments
- for i in $(seq ${CONDA_SHLVL}); do
- conda deactivate
- done
+ module load gcc
+ module load conda
- conda activate
+ # Deactivate any base environments
+ for i in $(seq ${CONDA_SHLVL}); do
+ conda deactivate
+ done
+ conda activate
-OLCF Frontier (AMD MI250x)
-""""""""""""""""""""""""""
-You need to install a HIP-enabled verison of PyTorch to use our package.
-To do this, follow the steps `here `_.
+.. tab:: OLCF Frontier (AMD MI250x)
-.. code-block:: bash
- :caption: env.sh (last updated June 2025)
+ You need to install a HIP-enabled version of PyTorch to use our package.
+ Follow the steps `here `_.
+
+
+ .. code-block:: bash
+ :caption: env.sh (last updated June 2025)
- module load PrgEnv-gnu/8.6.0
- module load miniforge3/23.11.0-0
- module load rocm/6.4.0
- module load craype-accel-amd-gfx90a
+ module load PrgEnv-gnu/8.6.0
+ module load miniforge3/23.11.0-0
+ module load rocm/6.4.0
+ module load craype-accel-amd-gfx90a
- for i in $(seq ${CONDA_SHLVL}); do
- conda deactivate
- done
+ for i in $(seq ${CONDA_SHLVL}); do
+ conda deactivate
+ done
- conda activate
- export CC=cc
- export CXX=CC
\ No newline at end of file
+ conda activate
+ export CC=cc
+ export CXX=CC
\ No newline at end of file
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 00000000..1cc76517
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,4 @@
+furo
+sphinx
+sphinx-inline-tabs
+sphinx-autobuild
\ No newline at end of file
diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst
index 7f5ff78c..bcc11955 100644
--- a/docs/supported_ops.rst
+++ b/docs/supported_ops.rst
@@ -117,7 +117,7 @@ toplevel. You can use our implementation by running
.. code-block::
- from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
+ from openequivariance._torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
Some Github users report weak performance for the
symmetric contraction backward pass; your mileage may vary.
diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst
index 7bc11b26..f602ab44 100644
--- a/docs/tests_and_benchmarks.rst
+++ b/docs/tests_and_benchmarks.rst
@@ -12,27 +12,41 @@ these; we provide instructions below.
We recommend you clone our repository and use an editable install to run tests
and benchmarks.
-You can still test our code with a non-editable install; just
-download the test folder and install the non-editable package and the dependencies with:
+Correctness
+------------------------------
+To set up an editable install and run our tests, use the following code:
-.. code-block:: bash
+.. tab:: PyTorch
- pip install openequivariance[dev,bench]
+ .. code-block:: bash
-Correctness
-------------------------------
-To set up the editable install and run the entire testsuite, use:
+ git clone https://github.com/PASSIONLab/OpenEquivariance
+ cd OpenEquivariance
+ pip install -e "./openequivariance[dev]"
+ pytest tests/
-.. code-block:: bash
+.. tab:: JAX
- git clone https://github.com/PASSIONLab/OpenEquivariance
- cd OpenEquivariance
- pip install -e .[dev]
- pytest
+ Note: To test correctness in JAX, we still require
+ an installation of PyTorch and e3nn in your environment.
+
+ .. code-block:: bash
+
+ git clone https://github.com/PASSIONLab/OpenEquivariance
+ cd OpenEquivariance
+
+ pip install "./openequivariance[jax]"
+ pip install "./openequivariance[dev]"
+ pip install "./openequivariance_extjax" --no-build-isolation
+
+ pytest --jax tests/example_test.py
+ pytest --jax tests/batch_test.py
+ pytest --jax tests/conv_test.py
Browse the ``tests`` directory to run specific components.
+
Replicating our Benchmarks
------------------------------
We conducted our benchmarks on an NVIDIA A100-SXM-80GB GPU at Lawrence Berkeley National Laboratory.
@@ -79,12 +93,13 @@ OpenEquivariance exhibits up to 2x speedup over FlashTP's fused kernels.
List of GPUs Tested
--------------------------------
-OpenEquivariance has been tested successfully the following GPUs. Submit a pull
+OpenEquivariance runs successfully on the following GPUs. Submit a pull
request if you'd like to add your own!
- NVIDIA V100 (V. Bharadwaj, LBNL Einsteinium, June 2025)
- NVIDIA A100-SXM-40GB and A100-SXM-80GB (A. Glover, NERSC Perlmutter, June 2025)
- NVIDIA A5000 (V. Bharadwaj, UCB SLICE, June 2025)
+- NVIDIA T4 (V. Bharadwaj, Google Colab, Jan 2026)
- NVIDIA H100 (L. Larsen, P1 DTU HPC, June 2025)
- AMD MI250x (V. Bharadwaj, OLCF Frontier, June 2025)
- AMD MI300x (V. Bharadwaj, AMD Cloud, February 2025)
\ No newline at end of file
diff --git a/openequivariance/LICENSE b/openequivariance/LICENSE
new file mode 120000
index 00000000..ea5b6064
--- /dev/null
+++ b/openequivariance/LICENSE
@@ -0,0 +1 @@
+../LICENSE
\ No newline at end of file
diff --git a/openequivariance/MANIFEST.in b/openequivariance/MANIFEST.in
new file mode 100644
index 00000000..1d4a8cce
--- /dev/null
+++ b/openequivariance/MANIFEST.in
@@ -0,0 +1,4 @@
+include openequivariance/templates/*.cuh
+include openequivariance/templates/*.jinja
+
+include openequivariance/extension/*
\ No newline at end of file
diff --git a/openequivariance/README.md b/openequivariance/README.md
new file mode 120000
index 00000000..32d46ee8
--- /dev/null
+++ b/openequivariance/README.md
@@ -0,0 +1 @@
+../README.md
\ No newline at end of file
diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py
deleted file mode 100644
index 5a8ab812..00000000
--- a/openequivariance/__init__.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# ruff: noqa: F401
-import sys
-import torch
-import numpy as np
-from pathlib import Path
-from importlib.metadata import version
-
-import openequivariance.extlib
-
-from openequivariance.extlib import (
- LINKED_LIBPYTHON,
- LINKED_LIBPYTHON_ERROR,
- BUILT_EXTENSION,
- BUILT_EXTENSION_ERROR,
- TORCH_COMPILE,
- TORCH_COMPILE_ERROR,
-)
-
-from openequivariance.implementations.e3nn_lite import (
- TPProblem,
- Irrep,
- Irreps,
- _MulIr,
- Instruction,
-)
-from openequivariance.implementations.TensorProduct import TensorProduct
-from openequivariance.implementations.convolution.TensorProductConv import (
- TensorProductConv,
-)
-from openequivariance.implementations.utils import torch_to_oeq_dtype
-
-__version__ = None
-try:
- __version__ = version("openequivariance")
-except Exception as e:
- print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr)
-
-
-def _check_package_editable():
- import json
- from importlib.metadata import Distribution
-
- direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json")
- return json.loads(direct_url).get("dir_info", {}).get("editable", False)
-
-
-_editable_install_output_path = Path(__file__).parent.parent / "outputs"
-
-
-def torch_ext_so_path():
- """
- :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
- from the PyTorch C++ Interface.
- """
- return openequivariance.extlib.torch_module.__file__
-
-
-torch.serialization.add_safe_globals(
- [
- TensorProduct,
- TensorProductConv,
- TPProblem,
- Irrep,
- Irreps,
- _MulIr,
- Instruction,
- np.float32,
- np.float64,
- ]
-)
-
-__all__ = [
- "TPProblem",
- "Irreps",
- "TensorProduct",
- "TensorProductConv",
- "torch_to_oeq_dtype",
- "_check_package_editable",
- "torch_ext_so_path",
-]
diff --git a/openequivariance/implementations/LoopUnrollTP.py b/openequivariance/implementations/LoopUnrollTP.py
deleted file mode 100644
index ed6a5395..00000000
--- a/openequivariance/implementations/LoopUnrollTP.py
+++ /dev/null
@@ -1,311 +0,0 @@
-import numpy as np
-
-import openequivariance.extlib as extlib
-from openequivariance.templates.jinja_utils import get_jinja_environment
-from openequivariance.implementations.ComputationSchedule import ComputationSchedule
-
-from openequivariance.implementations.dtype_enum import dtype_to_enum
-from openequivariance.implementations.TensorProductBase import TensorProductBase
-from openequivariance.benchmark.logging_utils import getLogger
-from openequivariance.implementations.utils import (
- filter_and_analyze_problem,
- count_cg_non_zero,
-)
-
-logger = getLogger()
-
-
-class LoopUnrollTP(TensorProductBase):
- def __init__(self, config, torch_op=True):
- super().__init__(config, torch_op=torch_op)
-
- env = get_jinja_environment()
- template = env.get_template("loop_unroll_batch.cuh")
- dp = extlib.DeviceProp(0)
-
- analysis = filter_and_analyze_problem(config)
- self.is_uvw = analysis["is_uvw"]
-
- def generate_forward_schedule(warps_per_block):
- self.forward_schedule = ComputationSchedule(
- self.config,
- smem_limit=dp.maxSharedMemPerBlock,
- warps_per_block=warps_per_block,
- warp_size=dp.warpsize,
- block_count=dp.multiprocessorCount * 4,
- direction="forward",
- irrep_dtype=config.irrep_dtype,
- weight_dtype=config.weight_dtype,
- include_scratch=self.is_uvw,
- stream_weights=self.is_uvw,
- )
-
- def generate_backward_schedule(warps_per_block):
- self.backward_schedule = ComputationSchedule(
- self.config,
- smem_limit=dp.maxSharedMemPerBlock,
- warps_per_block=warps_per_block,
- warp_size=dp.warpsize,
- block_count=dp.multiprocessorCount * 4,
- direction="backward",
- irrep_dtype=config.irrep_dtype,
- weight_dtype=config.weight_dtype,
- include_scratch=self.is_uvw,
- stream_weights=self.is_uvw,
- )
-
- def generate_double_backward_schedule(warps_per_block):
- self.double_backward_schedule = ComputationSchedule(
- self.config,
- smem_limit=dp.maxSharedMemPerBlock,
- warps_per_block=warps_per_block,
- warp_size=dp.warpsize,
- block_count=dp.multiprocessorCount,
- direction="double_backward",
- irrep_dtype=config.irrep_dtype,
- weight_dtype=config.weight_dtype,
- include_scratch=self.is_uvw,
- stream_weights=self.is_uvw,
- schedule_type=3,
- )
-
- scheduler_generators = [
- generate_forward_schedule,
- generate_backward_schedule,
- generate_double_backward_schedule,
- ]
-
- for generate_schedule in scheduler_generators:
- warp_count = 8
- while warp_count > 0:
- try:
- generate_schedule(warp_count)
- break
- except Exception:
- warp_count -= 2
- if warp_count == 0:
- raise RuntimeError(
- "Tensor product schedule generation failed, shared memory inadequate!"
- )
-
- self.jit_kernel = extlib.postprocess_kernel(
- template.render(
- forward_schedule=self.forward_schedule,
- backward_schedule=self.backward_schedule,
- double_backward_schedule=self.double_backward_schedule,
- )
- )
-
- # with open("scratch.txt", "w") as f:
- # f.write(self.jit_kernel)
-
- internal_cls = None
- if self.torch_op and extlib.TORCH_COMPILE:
- global torch
- import torch
-
- internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct
- else:
- internal_cls = extlib.JITTPImpl
-
- logger.info("Starting kernel compiler...")
- self.internal = internal_cls(
- self.jit_kernel,
- vars(self.forward_schedule.launch_config),
- vars(self.backward_schedule.launch_config),
- vars(self.double_backward_schedule.launch_config),
- {
- "L1_dim": self.L1.dim,
- "L2_dim": self.L2.dim,
- "L3_dim": self.L3.dim,
- "weight_numel": self.config.weight_numel,
- "shared_weights": int(self.config.shared_weights),
- "opt_level": 3,
- "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
- "weight_dtype": dtype_to_enum[self.config.weight_dtype],
- },
- )
- logger.info("Kernel compiled!")
- logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")
-
- def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
- return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
-
- def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
- return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
-
- @classmethod
- def register_torch_fakes(cls):
- global torch
- import torch
-
- @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct")
- class TorchJITProduct:
- def __init__(
- self,
- kernel_plaintext: str,
- fwd_config: dict[str, int],
- bwd_config: dict[str, int],
- dbl_bwd_config: dict[str, int],
- kernel_dims: dict[str, int],
- ) -> None:
- (
- self.kernel_plaintext,
- self.fwd_config,
- self.bwd_config,
- self.dbl_bwd_config,
- self.kernel_dims,
- ) = (
- kernel_plaintext,
- fwd_config,
- bwd_config,
- dbl_bwd_config,
- kernel_dims,
- )
-
- @classmethod
- def __obj_unflatten__(cls, flattened_product):
- return cls(**dict(flattened_product))
-
- def __len__(self):
- return 0
-
- def __setstate__(self, state):
- self.kernel_plaintext = state["kernel_plaintext"]
- self.fwd_config = state["fwd_config"]
- self.bwd_config = state["bwd_config"]
- self.dbl_bwd_config = state["dbl_bwd_config"]
- self.kernel_dims = state["kernel_dims"]
-
- def exec_tensor_product_rawptr(*args, **kwargs):
- pass
-
- def backward_rawptr(*args, **kwargs):
- pass
-
- def L3_dim_getter(self):
- return self.kernel_dims["L3_dim"]
-
- def irrep_dtype_getter(self):
- return self.kernel_dims["irrep_dtype"]
-
- @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward")
- def fake_forward(jit, L1_in, L2_in, W):
- L3_dim = None
- if hasattr(jit, "wrapped_obj"):
- L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
- else:
- L3_dim = jit.L3_dim
-
- return L1_in.new_empty(L1_in.shape[0], L3_dim)
-
- @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward")
- def fake_backward(jit, L1_in, L2_in, W, L3_grad):
- return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
-
- @classmethod
- def register_autograd(cls):
- backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward
-
- def setup_context(ctx, inputs, output):
- ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs
-
- def backward(ctx, grad_output):
- L1_grad, L2_grad, W_grad = backward_op(
- ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output
- )
- return None, L1_grad, L2_grad, W_grad
-
- torch.library.register_autograd(
- "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context
- )
-
- def setup_context_double_backward(ctx, inputs, output):
- ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs
-
- def double_backward(ctx, E, F, G):
- result = torch.ops.libtorch_tp_jit.jit_tp_double_backward(
- ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G
- )
- return None, result[0], result[1], result[2], result[3]
-
- torch.library.register_autograd(
- "libtorch_tp_jit::jit_tp_backward",
- double_backward,
- setup_context=setup_context_double_backward,
- )
-
- @classmethod
- def register_autocast(cls):
- global torch
- import torch
-
- torch.library.register_autocast(
- "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32
- )
- torch.library.register_autocast(
- "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32
- )
- torch.library.register_autocast(
- "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32
- )
-
- @staticmethod
- def name():
- return "LoopUnrollTP"
-
- def calculate_flops_forward(self, batch_size: int) -> dict:
- if self.is_uvw:
- return super().calculate_flops_forward(batch_size)
- else:
- tpp = self.config
- flop_count = {
- "CG_decomposition": 0,
- "linear_combination": 0,
- "outer_products": 0,
- }
- for ins in tpp.instructions:
- l1, l2, l3 = (
- tpp.irreps_in1[ins.i_in1].ir.l,
- tpp.irreps_in2[ins.i_in2].ir.l,
- tpp.irreps_out[ins.i_out].ir.l,
- )
- flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
- ins.path_shape[0] * ins.path_shape[1]
- )
- flop_count["linear_combination"] += (
- (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0
- )
-
- flop_count["CG_decomposition"] *= 3 * batch_size
- flop_count["linear_combination"] *= (
- batch_size # Weights do not require FMA here
- )
- flop_count["total"] = sum(flop_count.values())
- return flop_count
-
- def calculate_flops_backward(self, batch_size: int) -> dict:
- if self.is_uvw:
- return super().calculate_flops_backward(batch_size)
- else:
- tpp = self.config
- flop_count = {"backward": 0}
- for ins in tpp.instructions:
- l1, l2, l3 = (
- tpp.irreps_in1[ins.i_in1].ir.l,
- tpp.irreps_in2[ins.i_in2].ir.l,
- tpp.irreps_out[ins.i_out].ir.l,
- )
- flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (
- ins.path_shape[0] * ins.path_shape[1]
- )
-
- flop_count["backward"] *= 9 * batch_size
- flop_count["total"] = sum(flop_count.values())
- return flop_count
-
-
-if extlib.TORCH_COMPILE:
- LoopUnrollTP.register_torch_fakes()
- LoopUnrollTP.register_autograd()
- LoopUnrollTP.register_autocast()
diff --git a/openequivariance/implementations/TensorProduct.py b/openequivariance/implementations/TensorProduct.py
deleted file mode 100644
index 54fc4307..00000000
--- a/openequivariance/implementations/TensorProduct.py
+++ /dev/null
@@ -1,200 +0,0 @@
-from openequivariance.implementations.LoopUnrollTP import LoopUnrollTP
-from openequivariance import TPProblem
-from openequivariance import extlib
-import torch
-import typing
-from openequivariance.implementations.utils import torch_to_oeq_dtype
-
-
-class TensorProduct(torch.nn.Module, LoopUnrollTP):
- r"""
- Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward,
- backward, and double-backward passes using JIT-compiled kernels. Initialization
- fails if:
-
- * There are no visible GPUs.
- * The provided tensor product specification is unsupported.
-
- :param problem: Specification of the tensor product.
- :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``.
- """
-
- def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False):
- torch.nn.Module.__init__(self)
- self.input_args = {
- "problem": problem,
- "torch_op": torch_op,
- "use_opaque": use_opaque,
- }
- self._init_class()
-
- def _init_class(self):
- LoopUnrollTP.__init__(
- self, self.input_args["problem"], self.input_args["torch_op"]
- )
- self.weight_numel = self.input_args["problem"].weight_numel
- self._setup_notorchbind()
- if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]:
- self.forward = self.forward_opaque
-
- def to(self, *args, **kwargs):
- r"""
- See `torch.nn.Module.to() `_.
- """
- device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
- *args, **kwargs
- )
-
- if dtype is not None:
- updated_problem = self.input_args["problem"].clone()
- updated_problem.irrep_dtype = torch_to_oeq_dtype(dtype)
- updated_problem.weight_dtype = torch_to_oeq_dtype(dtype)
- self.input_args["problem"] = updated_problem
- self._init_class()
-
- torch.nn.Module.to(self, *args, **kwargs)
- return self
-
- def __getstate__(self):
- return self.input_args
-
- def __setstate__(self, state):
- torch.nn.Module.__init__(self)
- self.input_args = state
- self._init_class()
-
- @staticmethod
- def name():
- return LoopUnrollTP.name()
-
- def forward(
- self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
- ) -> torch.Tensor:
- r"""
- Computes :math:`W (x \otimes_{\textrm{CG}} y)`, identical to
- ``o3.TensorProduct.forward``.
-
- :param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim()]``, datatype
- ``problem.irrep_dtype``.
- :param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim()]``, datatype
- ``problem.irrep_dtype``.
- :param W: Tensor of datatype ``problem.weight_dtype`` and shape
-
- * ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False``
- * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
-
- :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
- """
- return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W)
-
- def _setup_notorchbind(self):
- """
- In case TorchBind is not available (e.g. for torch.compile below PT2.8, etc.),
- set up operations using custom ops.
- """
-
- @torch.library.custom_op(
- f"openequivariance::tp_forward{self.tp_id}",
- mutates_args=(),
- device_types="cuda",
- )
- def forward(
- L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor
- ) -> torch.Tensor:
- L1_in_c, L2_in_c, weights_c = (
- L1_in.contiguous(),
- L2_in.contiguous(),
- weights.contiguous(),
- )
- L3_out = torch.empty(
- (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device=L1_in.device
- )
- self.forward_raw(
- L1_in_c.shape[0],
- L1_in_c.data_ptr(),
- L2_in_c.data_ptr(),
- L3_out.data_ptr(),
- weights_c.data_ptr(),
- )
- return L3_out
-
- @forward.register_fake
- def _(L1_in, L2_in, weights):
- return L1_in.new_empty(L1_in.shape[0], self.L3.dim)
-
- self.forward_opaque = forward
-
- # ---------------- Backward pass -----------------
- @torch.library.custom_op(
- f"openequivariance::tp_grad_helper{self.tp_id}",
- mutates_args=(),
- device_types="cuda",
- )
- def backward_helper(
- L1_in: torch.Tensor,
- L2_in: torch.Tensor,
- weights: torch.Tensor,
- L3_grad: torch.Tensor,
- ) -> typing.List[torch.Tensor]:
- L1_grad = torch.zeros_like(L1_in)
- L2_grad = torch.zeros_like(L2_in)
- weights_grad = torch.empty_like(weights)
-
- if self.config.shared_weights:
- weights_grad[:] = 0.0
-
- self.backward_raw(
- L1_in.shape[0],
- L1_in.contiguous().data_ptr(),
- L1_grad.data_ptr(),
- L2_in.contiguous().data_ptr(),
- L2_grad.data_ptr(),
- weights.contiguous().data_ptr(),
- weights_grad.data_ptr(),
- L3_grad.contiguous().data_ptr(),
- )
-
- return [L1_grad, L2_grad, weights_grad]
-
- @backward_helper.register_fake
- def _(L1_in, L2_in, weights, L3_grad):
- return [
- L1_in.new_empty(*L1_in.shape),
- L2_in.new_empty(*L2_in.shape),
- weights.new_empty(*weights.shape),
- ]
-
- def setup_context(ctx, inputs, output):
- ctx.L1_in, ctx.L2_in, ctx.weights = inputs
-
- def backward(ctx, grad_output):
- result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output)
- return result[0], result[1], result[2]
-
- self.forward_opaque.register_autograd(backward, setup_context=setup_context)
-
- def setup_context_double_backward(ctx, inputs, output):
- ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs
-
- def double_backward(ctx, grad_output):
- A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights
- E, F, G = grad_output[0], grad_output[1], grad_output[2]
-
- op1 = backward_helper(E, F, D, C)
- op2 = backward_helper(A, B, G, C)
- op3 = forward(E, B, D)
- op4 = backward_helper(E, B, D, C)
- op5 = backward_helper(A, F, D, C)
- op6 = forward(A, F, D)
- op7 = forward(A, B, G)
-
- return (
- op1[0] + op2[0],
- op1[1] + op2[1],
- (op4[2] + op5[2]),
- (op3 + op6 + op7),
- )
-
- backward_helper.register_autograd(
- double_backward, setup_context=setup_context_double_backward
- )
diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py
deleted file mode 100644
index a5e46ce3..00000000
--- a/openequivariance/implementations/convolution/LoopUnrollConv.py
+++ /dev/null
@@ -1,453 +0,0 @@
-import numpy as np
-
-from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
-from openequivariance.implementations.ComputationSchedule import (
- ComputationSchedule,
- SMEMCapacityException,
-)
-
-from openequivariance.implementations.dtype_enum import (
- dtype_to_enum,
- enum_to_torch_dtype,
-)
-from openequivariance.templates.jinja_utils import get_jinja_environment
-from openequivariance import extlib
-from openequivariance.extlib import JITConvImpl, postprocess_kernel, DeviceProp
-
-from openequivariance.implementations.utils import filter_and_analyze_problem
-from openequivariance.benchmark.logging_utils import getLogger
-
-logger = getLogger()
-
-
-class LoopUnrollConv(ConvolutionBase):
- def __init__(
- self,
- config,
- *,
- idx_dtype: type[np.generic] = np.int64,
- torch_op: bool = False,
- deterministic: bool = False,
- kahan: bool = False,
- ):
- super().__init__(
- config, idx_dtype=idx_dtype, torch_op=torch_op, deterministic=deterministic
- )
-
- if kahan:
- assert deterministic
-
- env = get_jinja_environment()
- template = env.get_template("loop_unroll_conv_atomic.cuh")
- dp = DeviceProp(0)
-
- analysis = filter_and_analyze_problem(config)
- self.is_uvw = analysis["is_uvw"]
-
- if config.shared_weights:
- assert not deterministic, (
- "Deterministic convolution does not support shared weights"
- )
-
- forward_schedule_type = 3
- backward_schedule_type = 2
- if deterministic:
- backward_schedule_type = 3
- template = env.get_template("loop_unroll_conv_det.cuh")
-
- def generate_forward_schedule(warps_per_block):
- self.forward_schedule = ComputationSchedule(
- self.config,
- smem_limit=dp.maxSharedMemPerBlock // 4 * 3,
- warps_per_block=warps_per_block,
- block_count=dp.multiprocessorCount,
- direction="forward",
- irrep_dtype=config.irrep_dtype,
- weight_dtype=config.weight_dtype,
- schedule_type=forward_schedule_type,
- warp_size=dp.warpsize,
- include_scratch=self.is_uvw,
- stream_weights=self.is_uvw,
- kahan=kahan,
- )
-
- def generate_backward_schedule(warps_per_block):
- self.backward_schedule = ComputationSchedule(
- self.config,
- smem_limit=dp.maxSharedMemPerBlock,
- warps_per_block=warps_per_block,
- block_count=dp.multiprocessorCount * 2,
- direction="backward",
- irrep_dtype=config.irrep_dtype,
- weight_dtype=config.weight_dtype,
- schedule_type=backward_schedule_type,
- warp_size=dp.warpsize,
- include_scratch=self.is_uvw,
- stream_weights=self.is_uvw,
- kahan=kahan,
- )
-
- def generate_double_backward_schedule(warps_per_block):
- self.double_backward_schedule = ComputationSchedule(
- self.config,
- smem_limit=dp.maxSharedMemPerBlock,
- warps_per_block=warps_per_block,
- warp_size=dp.warpsize,
- block_count=dp.multiprocessorCount,
- direction="double_backward",
- irrep_dtype=config.irrep_dtype,
- weight_dtype=config.weight_dtype,
- include_scratch=self.is_uvw,
- stream_weights=self.is_uvw,
- schedule_type=3,
- kahan=kahan,
- )
-
- scheduler_generators = [
- generate_forward_schedule,
- generate_backward_schedule,
- generate_double_backward_schedule,
- ]
-
- for generate_schedule in scheduler_generators:
- warp_count = 6
- while warp_count > 0:
- try:
- generate_schedule(warp_count)
- break
- except SMEMCapacityException:
- warp_count -= 1
- if warp_count == 0:
- raise SMEMCapacityException(
- "Tensor product schedule generation failed, shared memory inadequate!"
- )
-
- if not deterministic:
- for segment in self.forward_schedule.segments:
- for key in segment.L3Map.storeback_procedure:
- segment.L3Map.storeback_procedure[key] = "atomic_accumulate"
-
- for segment in self.backward_schedule.segments:
- for key in segment.L1Map.storeback_procedure:
- segment.L1Map.storeback_procedure[key] = "atomic_accumulate"
-
- for segment in self.double_backward_schedule.segments:
- for key in segment.L1Map.storeback_procedure:
- segment.L1Map.storeback_procedure[key] = "atomic_accumulate"
-
- idx_type_map = {np.int32: "int", np.int64: "long"}
-
- self.forward_workspace_offset = None
- self.backward_workspace_offset = None
- self.double_backwardB_offset = None
-
- workspace_size = 1
- if deterministic:
- destination_index_bytes = 32 # Add extra to account for padding
- workspace_size = max(
- (
- self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize
- + destination_index_bytes
- )
- * self.forward_schedule.total_warps,
- (
- self.backward_schedule.L1.dim
- * np.dtype(config.irrep_dtype).itemsize
- + destination_index_bytes
- )
- * self.backward_schedule.total_warps,
- (
- self.double_backward_schedule.L1.dim
- * np.dtype(config.irrep_dtype).itemsize
- + destination_index_bytes
- )
- * self.double_backward_schedule.total_warps,
- )
-
- self.forward_workspace_offset = (
- self.forward_schedule.L3.dim
- * np.dtype(config.irrep_dtype).itemsize
- * self.forward_schedule.total_warps
- )
- self.backward_workspace_offset = (
- self.backward_schedule.L1.dim
- * np.dtype(config.irrep_dtype).itemsize
- * self.backward_schedule.total_warps
- )
- self.double_backwardB_offset = (
- self.double_backward_schedule.L1.dim
- * np.dtype(config.irrep_dtype).itemsize
- * self.double_backward_schedule.total_warps
- )
-
- self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8
- self.backward_workspace_offset = (
- (self.backward_workspace_offset + 7) // 8 * 8
- )
- self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8
-
- self.allocate_workspace(workspace_size)
-
- self.jit_kernel = template.render(
- forward_schedule=self.forward_schedule,
- backward_schedule=self.backward_schedule,
- double_backward_schedule=self.double_backward_schedule,
- idx_type=idx_type_map[idx_dtype],
- forward_workspace_offset=self.forward_workspace_offset,
- backward_workspace_offset=self.backward_workspace_offset,
- double_backwardB_offset=self.double_backwardB_offset,
- )
- self.jit_kernel = postprocess_kernel(self.jit_kernel)
-
- if self.torch_op and extlib.TORCH_COMPILE:
- global torch
- import torch
-
- internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv
- else:
- internal_cls = JITConvImpl
-
- logger.info("Starting kernel compiler...")
- self.internal = internal_cls(
- self.jit_kernel,
- vars(self.forward_schedule.launch_config),
- vars(self.backward_schedule.launch_config),
- vars(self.double_backward_schedule.launch_config),
- {
- "L1_dim": self.L1.dim,
- "L2_dim": self.L2.dim,
- "L3_dim": self.L3.dim,
- "weight_numel": self.config.weight_numel,
- "workspace_size": self.workspace_size,
- "opt_level": 3,
- "shared_weights": int(config.shared_weights),
- "deterministic": int(self.deterministic),
- "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
- "weight_dtype": dtype_to_enum[self.config.weight_dtype],
- "idx_dtype": dtype_to_enum[self.idx_dtype],
- },
- )
- logger.info("Kernel compiled!")
-
- # with open("scratch.txt", "w") as f:
- # f.write(self.jit_kernel)
-
- def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
- return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
-
- def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
- return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
-
- @staticmethod
- def name():
- return "LoopUnrollConv"
-
- @classmethod
- def register_torch_fakes(cls):
- global torch
- import torch
-
- @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv")
- class TorchJITConv:
- def __init__(
- self,
- kernel_plaintext: str,
- fwd_config: dict[str, int],
- bwd_config: dict[str, int],
- dbl_bwd_config: dict[str, int],
- kernel_dims: dict[str, int],
- ) -> None:
- (
- self.kernel_plaintext,
- self.fwd_config,
- self.bwd_config,
- self.dbl_bwd_config,
- self.kernel_dims,
- ) = (
- kernel_plaintext,
- fwd_config,
- bwd_config,
- dbl_bwd_config,
- kernel_dims,
- )
-
- @classmethod
- def __obj_unflatten__(cls, flattened_product):
- return cls(**dict(flattened_product))
-
- def __len__(self):
- return 0
-
- def __setstate__(self, state):
- (
- self.kernel_plaintext,
- self.fwd_config,
- self.bwd_config,
- self.dbl_bwd_config,
- self.kernel_dims,
- ) = state
-
- def exec_conv_rawptrs(*args, **kwargs):
- pass
-
- def backward_rawptrs(*args, **kwargs):
- pass
-
- def double_backward_rawptrs(*args, **kwargs):
- pass
-
- def L3_dim_getter(self):
- return self.kernel_dims["L3_dim"]
-
- def irrep_dtype_getter(self):
- return self.kernel_dims["irrep_dtype"]
-
- @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward")
- def fake_forward(
- jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm
- ):
- L3_dim, irrep_dtype = None, None
- if hasattr(jit, "wrapped_obj"):
- L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
- irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"]
- else:
- L3_dim = jit.L3_dim
- irrep_dtype = jit.irrep_dtype
-
- return torch.empty(
- L1_in.shape[0],
- L3_dim,
- device="cuda",
- dtype=enum_to_torch_dtype[irrep_dtype],
- )
-
- @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward")
- def fake_backward(
- jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm
- ):
- return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
-
- @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward")
- def fake_double_backward(
- jit,
- L1_in,
- L2_in,
- W,
- L3_grad,
- L1_dgrad,
- L2_dgrad,
- w_dgrad,
- rows,
- cols,
- workspace_buffer,
- transpose_perm=None,
- ):
- return [
- L1_in.new_empty(*L1_in.shape),
- L2_in.new_empty(*L2_in.shape),
- W.new_empty(*W.shape),
- L3_grad.new_empty(*L3_grad.shape),
- ]
-
- @classmethod
- def register_autograd(cls):
- backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward
- double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward
-
- def setup_context(ctx, inputs, output):
- (
- ctx.jit,
- ctx.L1_in,
- ctx.L2_in,
- ctx.W,
- ctx.rows,
- ctx.cols,
- ctx.workspace_buffer,
- ctx.sender_perm,
- ) = inputs
-
- def backward(ctx, grad_output):
- L1_grad, L2_grad, W_grad = backward_op(
- ctx.jit,
- ctx.L1_in,
- ctx.L2_in,
- ctx.W,
- grad_output,
- ctx.rows,
- ctx.cols,
- ctx.workspace_buffer,
- ctx.sender_perm,
- )
- return None, L1_grad, L2_grad, W_grad, None, None, None, None
-
- torch.library.register_autograd(
- "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context
- )
-
- def setup_context_double_backward(ctx, inputs, output):
- (
- ctx.jit,
- ctx.L1_in,
- ctx.L2_in,
- ctx.W,
- ctx.grad_output,
- ctx.rows,
- ctx.cols,
- ctx.workspace_buffer,
- ctx.sender_perm,
- ) = inputs
- ctx.inputs = inputs
-
- def double_backward(ctx, E, F, G):
- result = double_backward_op(
- ctx.jit,
- ctx.L1_in,
- ctx.L2_in,
- ctx.W,
- ctx.grad_output,
- E,
- F,
- G,
- ctx.rows,
- ctx.cols,
- ctx.workspace_buffer,
- ctx.sender_perm,
- )
- return (
- None,
- result[0],
- result[1],
- result[2],
- result[3],
- None,
- None,
- None,
- None,
- )
-
- torch.library.register_autograd(
- "libtorch_tp_jit::jit_conv_backward",
- double_backward,
- setup_context=setup_context_double_backward,
- )
-
- @classmethod
- def register_autocast(cls):
- global torch
- import torch
-
- torch.library.register_autocast(
- "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32
- )
- torch.library.register_autocast(
- "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32
- )
- torch.library.register_autocast(
- "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32
- )
-
-
-if extlib.TORCH_COMPILE:
- LoopUnrollConv.register_torch_fakes()
- LoopUnrollConv.register_autograd()
- LoopUnrollConv.register_autocast()
diff --git a/openequivariance/implementations/dtype_enum.py b/openequivariance/implementations/dtype_enum.py
deleted file mode 100644
index 292b7e4f..00000000
--- a/openequivariance/implementations/dtype_enum.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from enum import IntEnum
-from types import MappingProxyType
-import numpy as np
-import torch
-
-
-class DTypeEnum(IntEnum):
- FLOAT32 = 1
- FLOAT64 = 2
- INT32 = 3
- INT64 = 4
- UINT8 = 5
-
-
-dtype_to_enum = MappingProxyType(
- {
- torch.float32: DTypeEnum.FLOAT32,
- torch.float64: DTypeEnum.FLOAT64,
- torch.int32: DTypeEnum.INT32,
- torch.int64: DTypeEnum.INT64,
- torch.uint8: DTypeEnum.UINT8,
- # torch
- np.float32: DTypeEnum.FLOAT32,
- np.float64: DTypeEnum.FLOAT64,
- np.int32: DTypeEnum.INT32,
- np.int64: DTypeEnum.INT64,
- np.uint8: DTypeEnum.UINT8,
- # numpy generic
- np.dtype(np.float32): DTypeEnum.FLOAT32,
- np.dtype(np.float64): DTypeEnum.FLOAT64,
- np.dtype(np.int32): DTypeEnum.INT32,
- np.dtype(np.int64): DTypeEnum.INT64,
- np.dtype(np.uint8): DTypeEnum.UINT8,
- # numpy dtype
- }
-)
-
-
-enum_to_torch_dtype = MappingProxyType(
- {
- DTypeEnum.FLOAT32: torch.float32,
- DTypeEnum.FLOAT64: torch.float64,
- DTypeEnum.INT32: torch.int32,
- DTypeEnum.INT64: torch.int64,
- DTypeEnum.UINT8: torch.uint8,
- }
-)
diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/implementations/symmetric_contraction/__init__.py
deleted file mode 100644
index 75ac6cc8..00000000
--- a/openequivariance/implementations/symmetric_contraction/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from openequivariance.implementations.symmetric_contraction.symmetric_contraction import (
- SymmetricContraction,
-)
-
-__all__ = ["SymmetricContraction"]
diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py
new file mode 100644
index 00000000..a842a7c9
--- /dev/null
+++ b/openequivariance/openequivariance/__init__.py
@@ -0,0 +1,105 @@
+# ruff: noqa: F401
+import sys
+import os
+import numpy as np
+
+from pathlib import Path
+from importlib.metadata import version
+
+from openequivariance.core.e3nn_lite import (
+ TPProblem,
+ Irrep,
+ Irreps,
+ _MulIr,
+ Instruction,
+)
+
+__version__ = None
+try:
+ __version__ = version("openequivariance")
+except Exception as e:
+ print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr)
+
+
+def _check_package_editable():
+ import json
+ from importlib.metadata import Distribution
+
+ direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json")
+ return json.loads(direct_url).get("dir_info", {}).get("editable", False)
+
+
+_editable_install_output_path = Path(__file__).parent.parent.parent / "outputs"
+
+if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1":
+ import torch
+
+ from openequivariance._torch.TensorProduct import TensorProduct
+ from openequivariance._torch.TensorProductConv import TensorProductConv
+
+ from openequivariance._torch.extlib import (
+ torch_ext_so_path as torch_ext_so_path_internal,
+ )
+ from openequivariance.core.utils import torch_to_oeq_dtype
+
+ torch.serialization.add_safe_globals(
+ [
+ TensorProduct,
+ TensorProductConv,
+ TPProblem,
+ Irrep,
+ Irreps,
+ _MulIr,
+ Instruction,
+ np.float32,
+ np.float64,
+ ]
+ )
+
+ from openequivariance._torch.extlib import (
+ LINKED_LIBPYTHON,
+ LINKED_LIBPYTHON_ERROR,
+ BUILT_EXTENSION,
+ BUILT_EXTENSION_ERROR,
+ TORCH_COMPILE,
+ TORCH_COMPILE_ERROR,
+ )
+
+
+def torch_ext_so_path():
+ """
+ :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
+ from the PyTorch C++ Interface.
+ """
+ try:
+ return torch_ext_so_path_internal()
+ except NameError:
+ return None
+
+
+jax = None
+try:
+ import openequivariance_extjax
+ import openequivariance.jax as jax
+except Exception as e:
+ error = e
+
+ class JAX_ERR:
+ def TensorProduct(*args, **kwargs):
+ raise error
+
+ def TensorProductConv(*args, **kwargs):
+ raise error
+
+ jax = JAX_ERR()
+
+__all__ = [
+ "TPProblem",
+ "Irreps",
+ "TensorProduct",
+ "TensorProductConv",
+ "torch_to_oeq_dtype",
+ "_check_package_editable",
+ "torch_ext_so_path",
+ "jax",
+]
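The `jax` fallback defined above makes `oeq.jax` importable even when `openequivariance_extjax` is missing, and only surfaces the captured import error when a JAX class is first constructed. A minimal sketch of that behavior from the user's side, assuming the extension package is not installed:

```python
# Sketch of the deferred-error fallback implemented above; assumes openequivariance
# is installed but openequivariance_extjax is not.
import os

os.environ["OEQ_NOTORCH"] = "1"  # skip the PyTorch-specific imports

import openequivariance as oeq  # succeeds even without the JAX extension

X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)],
                        shared_weights=False, internal_weights=False)

try:
    tp = oeq.jax.TensorProduct(problem)  # the captured import error surfaces here
except Exception as e:
    print(f"JAX backend unavailable: {e}")
```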
diff --git a/openequivariance/implementations/convolution/CUEConv.py b/openequivariance/openequivariance/_torch/CUEConv.py
similarity index 95%
rename from openequivariance/implementations/convolution/CUEConv.py
rename to openequivariance/openequivariance/_torch/CUEConv.py
index 9287abe8..8500e39c 100644
--- a/openequivariance/implementations/convolution/CUEConv.py
+++ b/openequivariance/openequivariance/_torch/CUEConv.py
@@ -2,8 +2,8 @@
import itertools
from typing import Iterator
-from openequivariance.implementations.CUETensorProduct import CUETensorProduct
-from openequivariance.implementations.convolution.ConvolutionBase import (
+from openequivariance._torch.CUETensorProduct import CUETensorProduct
+from openequivariance.core.ConvolutionBase import (
ConvolutionBase,
scatter_add_wrapper,
)
diff --git a/openequivariance/implementations/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py
similarity index 97%
rename from openequivariance/implementations/CUETensorProduct.py
rename to openequivariance/openequivariance/_torch/CUETensorProduct.py
index a7d027f4..33b8db12 100644
--- a/openequivariance/implementations/CUETensorProduct.py
+++ b/openequivariance/openequivariance/_torch/CUETensorProduct.py
@@ -4,15 +4,15 @@
import itertools
from typing import Iterator
-from openequivariance.implementations.TensorProductBase import TensorProductBase
-from openequivariance.implementations.e3nn_lite import TPProblem
+from openequivariance.core.TensorProductBase import TensorProductBase
+from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.benchmark.tpp_creation_utils import (
ChannelwiseTPP,
FullyConnectedTPProblem,
SingleInstruction,
)
-from openequivariance.implementations.utils import count_cg_non_zero
+from openequivariance.core.utils import count_cg_non_zero
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"
diff --git a/openequivariance/implementations/convolution/E3NNConv.py b/openequivariance/openequivariance/_torch/E3NNConv.py
similarity index 86%
rename from openequivariance/implementations/convolution/E3NNConv.py
rename to openequivariance/openequivariance/_torch/E3NNConv.py
index 00b0faa8..4cc20662 100644
--- a/openequivariance/implementations/convolution/E3NNConv.py
+++ b/openequivariance/openequivariance/_torch/E3NNConv.py
@@ -1,13 +1,14 @@
import numpy as np
-from openequivariance.implementations.convolution.ConvolutionBase import (
+from openequivariance.core.ConvolutionBase import (
ConvolutionBase,
scatter_add_wrapper,
)
-from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
+from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
+from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
-class E3NNConv(ConvolutionBase):
+class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv):
def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
assert torch_op
super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op)
@@ -37,7 +38,7 @@ def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
if config.irrep_dtype == np.float64:
torch.set_default_dtype(torch.float32) # Reset to default
- def forward(self, L1_in, L2_in, weights, rows, cols):
+ def forward(self, L1_in, L2_in, weights, rows, cols, transpose_perm=None):
messages = self.reference_tp(L1_in[cols], L2_in, weights)
return scatter_add_wrapper(messages, rows, L1_in.size(0))
diff --git a/openequivariance/implementations/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py
similarity index 94%
rename from openequivariance/implementations/E3NNTensorProduct.py
rename to openequivariance/openequivariance/_torch/E3NNTensorProduct.py
index 334ba65c..067a7e6b 100644
--- a/openequivariance/implementations/E3NNTensorProduct.py
+++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py
@@ -9,16 +9,17 @@
import pathlib
import numpy as np
-from openequivariance.implementations.TensorProductBase import TensorProductBase
-from openequivariance.implementations.e3nn_lite import TPProblem
+from openequivariance.core.TensorProductBase import TensorProductBase
+from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")
logger = getLogger()
-class E3NNTensorProduct(TensorProductBase):
+class E3NNTensorProduct(TensorProductBase, NumpyDoubleBackwardMixin):
def __init__(self, config: TPProblem, torch_op=True):
super().__init__(config, torch_op=torch_op)
assert self.torch_op
diff --git a/openequivariance/implementations/convolution/FlashTPConv.py b/openequivariance/openequivariance/_torch/FlashTPConv.py
similarity index 87%
rename from openequivariance/implementations/convolution/FlashTPConv.py
rename to openequivariance/openequivariance/_torch/FlashTPConv.py
index 0302ef9c..9ec5c409 100644
--- a/openequivariance/implementations/convolution/FlashTPConv.py
+++ b/openequivariance/openequivariance/_torch/FlashTPConv.py
@@ -4,8 +4,8 @@
import torch
import numpy as np
-from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
-from openequivariance.implementations.utils import oeq_to_torch_dtype
+from openequivariance.core.ConvolutionBase import ConvolutionBase
+from openequivariance.core.utils import oeq_to_torch_dtype
class FlashTPConv(ConvolutionBase):
diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py
new file mode 100644
index 00000000..caf94268
--- /dev/null
+++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py
@@ -0,0 +1,97 @@
+import torch
+
+
+class NumpyDoubleBackwardMixin:
+ """
+ Adds a NumPy-based double-backward method to any TensorProduct
+ whose forward pass is defined in PyTorch and has the relevant
+ derivatives registered.
+ """
+
+ def double_backward_cpu(
+ self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad
+ ):
+ assert self.torch_op
+
+ in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+ in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+ weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
+ out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+ in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+ in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+ weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
+ out_torch = self.forward(in1_torch, in2_torch, weights_torch)
+
+ in1_grad, in2_grad, weights_grad = torch.autograd.grad(
+ outputs=out_torch,
+ inputs=[in1_torch, in2_torch, weights_torch],
+ grad_outputs=out_grad_torch,
+ create_graph=True,
+ retain_graph=True,
+ )
+
+ a, b, c, d = torch.autograd.grad(
+ outputs=[in1_grad, in2_grad, weights_grad],
+ inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
+ grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
+ )
+
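+ # The two autograd.grad calls above compose the double backward: the returned
+ # tuple holds gradients with respect to (in1, in2, weights, out_grad), matching
+ # the order of the inputs list in the second call.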
+ return (
+ a.detach().cpu().numpy(),
+ b.detach().cpu().numpy(),
+ c.detach().cpu().numpy(),
+ d.detach().cpu().numpy(),
+ )
+
+
+class NumpyDoubleBackwardMixinConv:
+ """
+ Analogue of ``NumpyDoubleBackwardMixin`` for fused graph convolution, whose
+ forward pass additionally takes the graph structure (rows, cols, transpose_perm).
+ """
+
+ def double_backward_cpu(
+ self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph
+ ):
+ assert self.torch_op
+
+ in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+ in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+ weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
+ out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+ in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+ in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+ weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
+
+ torch_rows = torch.tensor(graph.rows, device="cuda")
+ torch_cols = torch.tensor(graph.cols, device="cuda")
+ torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")
+
+ out_torch = self.forward(
+ in1_torch,
+ in2_torch,
+ weights_torch,
+ torch_rows,
+ torch_cols,
+ torch_transpose_perm,
+ )
+
+ in1_grad, in2_grad, weights_grad = torch.autograd.grad(
+ outputs=out_torch,
+ inputs=[in1_torch, in2_torch, weights_torch],
+ grad_outputs=out_grad_torch,
+ create_graph=True,
+ retain_graph=True,
+ )
+
+ a, b, c, d = torch.autograd.grad(
+ outputs=[in1_grad, in2_grad, weights_grad],
+ inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
+ grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
+ )
+
+ return (
+ a.detach().cpu().numpy(),
+ b.detach().cpu().numpy(),
+ c.detach().cpu().numpy(),
+ d.detach().cpu().numpy(),
+ )
diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py
new file mode 100644
index 00000000..05ea54b5
--- /dev/null
+++ b/openequivariance/openequivariance/_torch/TensorProduct.py
@@ -0,0 +1,448 @@
+from openequivariance.core.LoopUnrollTP import LoopUnrollTP
+from openequivariance import TPProblem
+from openequivariance._torch import extlib
+import torch
+import typing
+from openequivariance.core.utils import torch_to_oeq_dtype
+from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance._torch.utils import reorder_torch
+from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
+
+import numpy as np
+from openequivariance._torch.extlib import DeviceBuffer
+
+logger = getLogger()
+
+
+class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin):
+ r"""
+ Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward,
+ backward, and double-backward passes using JIT-compiled kernels. Initialization
+ fails if:
+
+ * There are no visible GPUs.
+ * The provided tensor product specification is unsupported.
+
+ :param problem: Specification of the tensor product.
+ :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``.
+ """
+
+ def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False):
+ torch.nn.Module.__init__(self)
+ self.input_args = {
+ "problem": problem,
+ "torch_op": torch_op,
+ "use_opaque": use_opaque,
+ }
+ self._init_class()
+
+ def _init_class(self):
+ dp = extlib.DeviceProp(0)
+ LoopUnrollTP.__init__(
+ self,
+ self.input_args["problem"],
+ dp,
+ extlib.postprocess_kernel,
+ self.input_args["torch_op"],
+ )
+
+ internal_cls = None
+ if extlib.TORCH_COMPILE:
+ internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct
+ else:
+ internal_cls = extlib.JITTPImpl
+
+ logger.info("Starting kernel compiler...")
+ self.internal = internal_cls(
+ self.jit_kernel,
+ vars(self.forward_schedule.launch_config),
+ vars(self.backward_schedule.launch_config),
+ vars(self.double_backward_schedule.launch_config),
+ self.kernelProp,
+ )
+ logger.info("Kernel compiled!")
+ logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")
+
+ self.weight_numel = self.input_args["problem"].weight_numel
+ self._setup_notorchbind()
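+ # Without TorchBind support (or when use_opaque is requested), fall back to the
+ # custom-op forward defined in _setup_notorchbind; per the class docstring, that
+ # path cannot be traced symbolically but still supports autograd.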
+ if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]:
+ self.forward = self.forward_opaque
+
+ def to(self, *args, **kwargs):
+ r"""
+ See `torch.nn.Module.to() <https://pytorch.org/docs/stable/generated/torch.nn.Module.to.html>`_.
+ """
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+ *args, **kwargs
+ )
+
+ if dtype is not None:
+ updated_problem = self.input_args["problem"].clone()
+ updated_problem.irrep_dtype = torch_to_oeq_dtype(dtype)
+ updated_problem.weight_dtype = torch_to_oeq_dtype(dtype)
+ self.input_args["problem"] = updated_problem
+ self._init_class()
+
+ torch.nn.Module.to(self, *args, **kwargs)
+ return self
+
+ def __getstate__(self):
+ return self.input_args
+
+ def __setstate__(self, state):
+ torch.nn.Module.__init__(self)
+ self.input_args = state
+ self._init_class()
+
+ def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
+ return reorder_torch(
+ self.forward_schedule, weights, "forward", not self.config.shared_weights
+ )
+
+ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
+ return reorder_torch(
+ self.forward_schedule, weights, "backward", not self.config.shared_weights
+ )
+
+ def forward(
+ self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
+ ) -> torch.Tensor:
+ r"""
+ Computes :math:`W (x \otimes_{\textrm{CG}} y)`, identical to
+ ``o3.TensorProduct.forward``.
+
+ :param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim()]``, datatype
+ ``problem.irrep_dtype``.
+ :param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim()]``, datatype
+ ``problem.irrep_dtype``.
+ :param W: Tensor of datatype ``problem.weight_dtype`` and shape
+
+ * ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False``
+ * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
+
+ :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
+ """
+ return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W)
+
+ def _setup_notorchbind(self):
+ """
+ Set up forward/backward operations via torch.library custom ops for cases where
+ TorchBind is unavailable (e.g., torch.compile on PyTorch versions below 2.8).
+ """
+
+ @torch.library.custom_op(
+ f"openequivariance::tp_forward{self.tp_id}",
+ mutates_args=(),
+ device_types="cuda",
+ )
+ def forward(
+ L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor
+ ) -> torch.Tensor:
+ L1_in_c, L2_in_c, weights_c = (
+ L1_in.contiguous(),
+ L2_in.contiguous(),
+ weights.contiguous(),
+ )
+ L3_out = torch.empty(
+ (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device=L1_in.device
+ )
+ self.forward_raw(
+ L1_in_c.shape[0],
+ L1_in_c.data_ptr(),
+ L2_in_c.data_ptr(),
+ L3_out.data_ptr(),
+ weights_c.data_ptr(),
+ )
+ return L3_out
+
+ @forward.register_fake
+ def _(L1_in, L2_in, weights):
+ return L1_in.new_empty(L1_in.shape[0], self.L3.dim)
+
+ self.forward_opaque = forward
+
+ # ---------------- Backward pass -----------------
+ @torch.library.custom_op(
+ f"openequivariance::tp_grad_helper{self.tp_id}",
+ mutates_args=(),
+ device_types="cuda",
+ )
+ def backward_helper(
+ L1_in: torch.Tensor,
+ L2_in: torch.Tensor,
+ weights: torch.Tensor,
+ L3_grad: torch.Tensor,
+ ) -> typing.List[torch.Tensor]:
+ L1_grad = torch.zeros_like(L1_in)
+ L2_grad = torch.zeros_like(L2_in)
+ weights_grad = torch.empty_like(weights)
+
+ if self.config.shared_weights:
+ weights_grad[:] = 0.0
+
+ self.backward_raw(
+ L1_in.shape[0],
+ L1_in.contiguous().data_ptr(),
+ L1_grad.data_ptr(),
+ L2_in.contiguous().data_ptr(),
+ L2_grad.data_ptr(),
+ weights.contiguous().data_ptr(),
+ weights_grad.data_ptr(),
+ L3_grad.contiguous().data_ptr(),
+ )
+
+ return [L1_grad, L2_grad, weights_grad]
+
+ @backward_helper.register_fake
+ def _(L1_in, L2_in, weights, L3_grad):
+ return [
+ L1_in.new_empty(*L1_in.shape),
+ L2_in.new_empty(*L2_in.shape),
+ weights.new_empty(*weights.shape),
+ ]
+
+ def setup_context(ctx, inputs, output):
+ ctx.L1_in, ctx.L2_in, ctx.weights = inputs
+
+ def backward(ctx, grad_output):
+ result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output)
+ return result[0], result[1], result[2]
+
+ self.forward_opaque.register_autograd(backward, setup_context=setup_context)
+
+ def setup_context_double_backward(ctx, inputs, output):
+ ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs
+
+ def double_backward(ctx, grad_output):
+ A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights
+ E, F, G = grad_output[0], grad_output[1], grad_output[2]
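+ # A, B = first/second inputs, C = output gradient, D = weights; E, F, G are the
+ # upstream gradients of (L1_grad, L2_grad, weights_grad). The tuple returned
+ # below holds the second-order gradients w.r.t. (L1_in, L2_in, weights, L3_grad),
+ # assembled from extra forward/backward kernel calls.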
+
+ op1 = backward_helper(E, F, D, C)
+ op2 = backward_helper(A, B, G, C)
+ op3 = forward(E, B, D)
+ op4 = backward_helper(E, B, D, C)
+ op5 = backward_helper(A, F, D, C)
+ op6 = forward(A, F, D)
+ op7 = forward(A, B, G)
+
+ return (
+ op1[0] + op2[0],
+ op1[1] + op2[1],
+ (op4[2] + op5[2]),
+ (op3 + op6 + op7),
+ )
+
+ backward_helper.register_autograd(
+ double_backward, setup_context=setup_context_double_backward
+ )
+
+ @classmethod
+ def register_torch_fakes(cls):
+ @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct")
+ class TorchJITProduct:
+ def __init__(
+ self,
+ kernel_plaintext: str,
+ fwd_config: dict[str, int],
+ bwd_config: dict[str, int],
+ dbl_bwd_config: dict[str, int],
+ kernel_dims: dict[str, int],
+ ) -> None:
+ (
+ self.kernel_plaintext,
+ self.fwd_config,
+ self.bwd_config,
+ self.dbl_bwd_config,
+ self.kernel_dims,
+ ) = (
+ kernel_plaintext,
+ fwd_config,
+ bwd_config,
+ dbl_bwd_config,
+ kernel_dims,
+ )
+
+ @classmethod
+ def __obj_unflatten__(cls, flattened_product):
+ return cls(**dict(flattened_product))
+
+ def __len__(self):
+ return 0
+
+ def __setstate__(self, state):
+ self.kernel_plaintext = state["kernel_plaintext"]
+ self.fwd_config = state["fwd_config"]
+ self.bwd_config = state["bwd_config"]
+ self.dbl_bwd_config = state["dbl_bwd_config"]
+ self.kernel_dims = state["kernel_dims"]
+
+ def exec_tensor_product_rawptr(*args, **kwargs):
+ pass
+
+ def backward_rawptr(*args, **kwargs):
+ pass
+
+ def L3_dim_getter(self):
+ return self.kernel_dims["L3_dim"]
+
+ def irrep_dtype_getter(self):
+ return self.kernel_dims["irrep_dtype"]
+
+ @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward")
+ def fake_forward(jit, L1_in, L2_in, W):
+ L3_dim = None
+ if hasattr(jit, "wrapped_obj"):
+ L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
+ else:
+ L3_dim = jit.L3_dim
+
+ return L1_in.new_empty(L1_in.shape[0], L3_dim)
+
+ @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward")
+ def fake_backward(jit, L1_in, L2_in, W, L3_grad):
+ return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
+
+ @classmethod
+ def register_autograd(cls):
+ backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward
+
+ def setup_context(ctx, inputs, output):
+ ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs
+
+ def backward(ctx, grad_output):
+ L1_grad, L2_grad, W_grad = backward_op(
+ ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output
+ )
+ return None, L1_grad, L2_grad, W_grad
+
+ torch.library.register_autograd(
+ "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context
+ )
+
+ def setup_context_double_backward(ctx, inputs, output):
+ ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs
+
+ def double_backward(ctx, E, F, G):
+ result = torch.ops.libtorch_tp_jit.jit_tp_double_backward(
+ ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G
+ )
+ return None, result[0], result[1], result[2], result[3]
+
+ torch.library.register_autograd(
+ "libtorch_tp_jit::jit_tp_backward",
+ double_backward,
+ setup_context=setup_context_double_backward,
+ )
+
+ @classmethod
+ def register_autocast(cls):
+ global torch
+ import torch
+
+ torch.library.register_autocast(
+ "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32
+ )
+ torch.library.register_autocast(
+ "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32
+ )
+ torch.library.register_autocast(
+ "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32
+ )
+
+ @staticmethod
+ def name():
+ return "LoopUnrollTP"
+
+ def forward_raw(
+ self,
+ batch: np.uint64,
+ L1_in: np.uint64,
+ L2_in: np.uint64,
+ L3_out: np.uint64,
+ weights: np.uint64,
+ ) -> None:
+ self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights)
+
+ def backward_raw(
+ self,
+ batch_size: np.uint64,
+ L1_in: np.uint64,
+ L1_grad: np.uint64,
+ L2_in: np.uint64,
+ L2_grad: np.uint64,
+ weights: np.uint64,
+ weights_grad: np.uint64,
+ L3_grad: np.uint64,
+ ):
+ self.internal.backward_rawptr(
+ batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad
+ )
+
+ def forward_cpu(
+ self,
+ L1_in: np.ndarray,
+ L2_in: np.ndarray,
+ L3_out: np.ndarray,
+ weights: np.ndarray,
+ ) -> None:
+ weights_chunked = self.reorder_weights_from_e3nn(
+ weights, not self.config.shared_weights
+ )
+
+ batch = L1_in.shape[0]
+ L1_d = DeviceBuffer(L1_in)
+ L2_d = DeviceBuffer(L2_in)
+ L3_d = DeviceBuffer(L3_out)
+ weights_d = DeviceBuffer(weights_chunked)
+ self.internal.exec_tensor_product_rawptr(
+ batch,
+ L1_d.data_ptr(),
+ L2_d.data_ptr(),
+ L3_d.data_ptr(),
+ weights_d.data_ptr(),
+ )
+ L3_d.copy_to_host()
+
+ def backward_cpu(
+ self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
+ ) -> None:
+ weights_chunked = self.reorder_weights_from_e3nn(
+ weights, not self.config.shared_weights
+ )
+
+ batch = L1_in.shape[0]
+ L1_d, L2_d, L3_d = (
+ DeviceBuffer(L1_in),
+ DeviceBuffer(L2_in),
+ DeviceBuffer(L3_grad),
+ )
+ L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad)
+ weights_d, weights_grad_d = (
+ DeviceBuffer(weights_chunked),
+ DeviceBuffer(weights_grad),
+ )
+
+ self.internal.backward_rawptr(
+ batch,
+ L1_d.data_ptr(),
+ L1_grad_d.data_ptr(),
+ L2_d.data_ptr(),
+ L2_grad_d.data_ptr(),
+ weights_d.data_ptr(),
+ weights_grad_d.data_ptr(),
+ L3_d.data_ptr(),
+ )
+
+ L1_grad_d.copy_to_host()
+ L2_grad_d.copy_to_host()
+ weights_grad_d.copy_to_host()
+
+ weights_grad[:] = self.reorder_weights_to_e3nn(
+ weights_grad, not self.config.shared_weights
+ )
+
+
+if extlib.TORCH_COMPILE:
+ TensorProduct.register_torch_fakes()
+ TensorProduct.register_autograd()
+ TensorProduct.register_autocast()
diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py
similarity index 57%
rename from openequivariance/implementations/convolution/TensorProductConv.py
rename to openequivariance/openequivariance/_torch/TensorProductConv.py
index 7e860944..f30c943c 100644
--- a/openequivariance/implementations/convolution/TensorProductConv.py
+++ b/openequivariance/openequivariance/_torch/TensorProductConv.py
@@ -3,18 +3,32 @@
import numpy as np
import torch
-from openequivariance import extlib
-from openequivariance.implementations.convolution.ConvolutionBase import (
+import openequivariance._torch.extlib as extlib
+from openequivariance._torch.extlib import (
+ JITConvImpl,
+ postprocess_kernel,
+ DeviceProp,
+)
+
+from openequivariance.core.ConvolutionBase import (
ConvolutionBase,
scatter_add_wrapper,
)
-from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv
-from openequivariance.implementations.TensorProduct import TensorProduct
+from openequivariance.core.LoopUnrollConv import LoopUnrollConv
+from openequivariance._torch.TensorProduct import TensorProduct
from openequivariance import TPProblem
-from openequivariance.implementations.utils import torch_to_oeq_dtype
+from openequivariance.core.utils import torch_to_oeq_dtype
+from openequivariance._torch.utils import enum_to_torch_dtype
+from openequivariance._torch.utils import reorder_torch
+
+from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
+from openequivariance._torch.extlib import DeviceBuffer
+
+logger = getLogger()
-class TensorProductConv(torch.nn.Module, LoopUnrollConv):
+class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixinConv):
r"""
Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
:math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes
@@ -58,14 +72,34 @@ def __init__(
self._init_class()
def _init_class(self):
+ dp = DeviceProp(0)
LoopUnrollConv.__init__(
self,
self.input_args["problem"],
+ dp,
+ postprocess_kernel,
idx_dtype=np.int64,
torch_op=self.input_args["torch_op"],
deterministic=self.input_args["deterministic"],
kahan=self.input_args["kahan"],
)
+
+ self.allocate_workspace(self.workspace_size)
+ if extlib.TORCH_COMPILE:
+ internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv
+ else:
+ internal_cls = JITConvImpl
+
+ logger.info("Starting kernel compiler...")
+ self.internal = internal_cls(
+ self.jit_kernel,
+ vars(self.forward_schedule.launch_config),
+ vars(self.backward_schedule.launch_config),
+ vars(self.double_backward_schedule.launch_config),
+ self.kernel_prop,
+ )
+ logger.info("Kernel compiled!")
+
self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda")
self.weight_numel = self.config.weight_numel
self._setup_notorchbind()
@@ -151,9 +185,16 @@ def forward(
sender_perm,
)
- @staticmethod
- def name():
- return LoopUnrollConv.name()
+ def allocate_workspace(self, size_bytes):
+ self.workspace_size = size_bytes
+ if self.torch_op:
+ self.workspace_buffer = torch.zeros(
+ size_bytes, dtype=torch.uint8, device="cuda"
+ )
+ else:
+ self.workspace_buffer = extlib.DeviceBuffer(size_bytes)
+ self.workspace_ptr = self.workspace_buffer.data_ptr()
+ logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.")
def _setup_notorchbind(self):
@torch.library.custom_op(
@@ -379,6 +420,315 @@ def double_backward(ctx, grad_output):
double_backward, setup_context=setup_context_double_backward
)
+ def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
+ return reorder_torch(
+ self.forward_schedule, weights, "forward", not self.config.shared_weights
+ )
+
+ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
+ return reorder_torch(
+ self.forward_schedule, weights, "backward", not self.config.shared_weights
+ )
+
+ @staticmethod
+ def name():
+ return "LoopUnrollConv"
+
+ @classmethod
+ def register_torch_fakes(cls):
+ global torch
+ import torch
+
+ @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv")
+ class TorchJITConv:
+ def __init__(
+ self,
+ kernel_plaintext: str,
+ fwd_config: dict[str, int],
+ bwd_config: dict[str, int],
+ dbl_bwd_config: dict[str, int],
+ kernel_dims: dict[str, int],
+ ) -> None:
+ (
+ self.kernel_plaintext,
+ self.fwd_config,
+ self.bwd_config,
+ self.dbl_bwd_config,
+ self.kernel_dims,
+ ) = (
+ kernel_plaintext,
+ fwd_config,
+ bwd_config,
+ dbl_bwd_config,
+ kernel_dims,
+ )
+
+ @classmethod
+ def __obj_unflatten__(cls, flattened_product):
+ return cls(**dict(flattened_product))
+
+ def __len__(self):
+ return 0
+
+ def __setstate__(self, state):
+ (
+ self.kernel_plaintext,
+ self.fwd_config,
+ self.bwd_config,
+ self.dbl_bwd_config,
+ self.kernel_dims,
+ ) = state
+
+ def exec_conv_rawptrs(*args, **kwargs):
+ pass
+
+ def backward_rawptrs(*args, **kwargs):
+ pass
+
+ def double_backward_rawptrs(*args, **kwargs):
+ pass
+
+ def L3_dim_getter(self):
+ return self.kernel_dims["L3_dim"]
+
+ def irrep_dtype_getter(self):
+ return self.kernel_dims["irrep_dtype"]
+
+ @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward")
+ def fake_forward(
+ jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm
+ ):
+ L3_dim, irrep_dtype = None, None
+ if hasattr(jit, "wrapped_obj"):
+ L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
+ irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"]
+ else:
+ L3_dim = jit.L3_dim
+ irrep_dtype = jit.irrep_dtype
+
+ return torch.empty(
+ L1_in.shape[0],
+ L3_dim,
+ device="cuda",
+ dtype=enum_to_torch_dtype[irrep_dtype],
+ )
+
+ @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward")
+ def fake_backward(
+ jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm
+ ):
+ return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
+
+ @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward")
+ def fake_double_backward(
+ jit,
+ L1_in,
+ L2_in,
+ W,
+ L3_grad,
+ L1_dgrad,
+ L2_dgrad,
+ w_dgrad,
+ rows,
+ cols,
+ workspace_buffer,
+ transpose_perm=None,
+ ):
+ return [
+ L1_in.new_empty(*L1_in.shape),
+ L2_in.new_empty(*L2_in.shape),
+ W.new_empty(*W.shape),
+ L3_grad.new_empty(*L3_grad.shape),
+ ]
+
+ @classmethod
+ def register_autograd(cls):
+ backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward
+ double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward
+
+ def setup_context(ctx, inputs, output):
+ (
+ ctx.jit,
+ ctx.L1_in,
+ ctx.L2_in,
+ ctx.W,
+ ctx.rows,
+ ctx.cols,
+ ctx.workspace_buffer,
+ ctx.sender_perm,
+ ) = inputs
+
+ def backward(ctx, grad_output):
+ L1_grad, L2_grad, W_grad = backward_op(
+ ctx.jit,
+ ctx.L1_in,
+ ctx.L2_in,
+ ctx.W,
+ grad_output,
+ ctx.rows,
+ ctx.cols,
+ ctx.workspace_buffer,
+ ctx.sender_perm,
+ )
+ return None, L1_grad, L2_grad, W_grad, None, None, None, None
+
+ torch.library.register_autograd(
+ "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context
+ )
+
+ def setup_context_double_backward(ctx, inputs, output):
+ (
+ ctx.jit,
+ ctx.L1_in,
+ ctx.L2_in,
+ ctx.W,
+ ctx.grad_output,
+ ctx.rows,
+ ctx.cols,
+ ctx.workspace_buffer,
+ ctx.sender_perm,
+ ) = inputs
+ ctx.inputs = inputs
+
+ def double_backward(ctx, E, F, G):
+ result = double_backward_op(
+ ctx.jit,
+ ctx.L1_in,
+ ctx.L2_in,
+ ctx.W,
+ ctx.grad_output,
+ E,
+ F,
+ G,
+ ctx.rows,
+ ctx.cols,
+ ctx.workspace_buffer,
+ ctx.sender_perm,
+ )
+ return (
+ None,
+ result[0],
+ result[1],
+ result[2],
+ result[3],
+ None,
+ None,
+ None,
+ None,
+ )
+
+ torch.library.register_autograd(
+ "libtorch_tp_jit::jit_conv_backward",
+ double_backward,
+ setup_context=setup_context_double_backward,
+ )
+
+ @classmethod
+ def register_autocast(cls):
+ global torch
+ import torch
+
+ torch.library.register_autocast(
+ "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32
+ )
+ torch.library.register_autocast(
+ "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32
+ )
+ torch.library.register_autocast(
+ "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32
+ )
+
+ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
+ assert graph.rows.dtype == self.idx_dtype
+ assert graph.cols.dtype == self.idx_dtype
+
+ weights_chunked = self.reorder_weights_from_e3nn(
+ weights, not self.config.shared_weights
+ )
+
+ L1_d, L2_d, weights_d = (
+ DeviceBuffer(L1_in),
+ DeviceBuffer(L2_in),
+ DeviceBuffer(weights_chunked),
+ )
+ L3_d = DeviceBuffer(L3_out)
+
+ rows_d = DeviceBuffer(graph.rows)
+ cols_d = DeviceBuffer(graph.cols)
+
+ self.internal.exec_conv_rawptrs(
+ L1_d.data_ptr(),
+ L2_d.data_ptr(),
+ weights_d.data_ptr(),
+ L3_d.data_ptr(),
+ rows_d.data_ptr(),
+ cols_d.data_ptr(),
+ graph.nnz,
+ graph.node_count,
+ self.workspace_ptr,
+ )
+
+ L3_d.copy_to_host()
+
+ def backward_cpu(
+ self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph
+ ):
+ assert graph.rows.dtype == self.idx_dtype
+ assert graph.cols.dtype == self.idx_dtype
+
+ weights_chunked = self.reorder_weights_from_e3nn(
+ weights, not self.config.shared_weights
+ )
+
+ L1_d = DeviceBuffer(L1_in)
+ L2_d = DeviceBuffer(L2_in)
+ weights_d = DeviceBuffer(weights_chunked)
+ L3_d = DeviceBuffer(L3_grad)
+ rows_d = DeviceBuffer(graph.rows)
+ cols_d = DeviceBuffer(graph.cols)
+
+ L1_grad_d = DeviceBuffer(L1_grad)
+ L2_grad_d = DeviceBuffer(L2_grad)
+ weights_grad_d = DeviceBuffer(weights_grad)
+
+ transpose_perm_d = None
+ transpose_perm_ptr = 0
+ if self.deterministic:
+ transpose_perm_d = DeviceBuffer(graph.transpose_perm)
+ transpose_perm_ptr = transpose_perm_d.data_ptr()
+
+ self.internal.backward_rawptrs(
+ L1_d.data_ptr(),
+ L1_grad_d.data_ptr(),
+ L2_d.data_ptr(),
+ L2_grad_d.data_ptr(),
+ weights_d.data_ptr(),
+ weights_grad_d.data_ptr(),
+ L3_d.data_ptr(),
+ rows_d.data_ptr(),
+ cols_d.data_ptr(),
+ graph.nnz,
+ graph.node_count,
+ self.workspace_ptr,
+ transpose_perm_ptr,
+ )
+
+ L1_grad_d.copy_to_host()
+ L2_grad_d.copy_to_host()
+ weights_grad_d.copy_to_host()
+
+ weights_grad[:] = self.reorder_weights_to_e3nn(
+ weights_grad, not self.config.shared_weights
+ )
+
+ return L1_grad, L2_grad, weights_grad
+
+
+if extlib.TORCH_COMPILE:
+ TensorProductConv.register_torch_fakes()
+ TensorProductConv.register_autograd()
+ TensorProductConv.register_autocast()
+
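+# A hedged usage sketch (assumed call pattern mirroring the module's forward
+# signature; the graph tensors are placeholders): given node features X, edge
+# features Y, per-edge weights W, and int64 COO indices rows/cols on the GPU,
+#
+#   conv = TensorProductConv(problem, deterministic=False)
+#   out = conv(X, Y, W, rows, cols)
+#
+# where out aggregates the per-edge tensor products into node features.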
# ==================================================================
# Reference implementations for benchmarking
diff --git a/openequivariance/extlib/.empty b/openequivariance/openequivariance/_torch/extlib/.empty
similarity index 100%
rename from openequivariance/extlib/.empty
rename to openequivariance/openequivariance/_torch/extlib/.empty
diff --git a/openequivariance/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py
similarity index 96%
rename from openequivariance/extlib/__init__.py
rename to openequivariance/openequivariance/_torch/extlib/__init__.py
index ac6f15a3..72440872 100644
--- a/openequivariance/extlib/__init__.py
+++ b/openequivariance/openequivariance/_torch/extlib/__init__.py
@@ -9,7 +9,7 @@
from openequivariance.benchmark.logging_utils import getLogger
-oeq_root = str(Path(__file__).parent.parent)
+oeq_root = str(Path(__file__).parent.parent.parent)
BUILT_EXTENSION = False
BUILT_EXTENSION_ERROR = None
@@ -23,7 +23,6 @@
torch_module, generic_module = None, None
postprocess_kernel = lambda kernel: kernel # noqa : E731
-
try:
python_lib_dir = sysconfig.get_config_var("LIBDIR")
major, minor = sys.version_info.major, sys.version_info.minor
@@ -42,9 +41,10 @@
if BUILT_EXTENSION:
- import openequivariance.extlib.generic_module
+ import openequivariance._torch.extlib.generic_module
+
+ generic_module = openequivariance._torch.extlib.generic_module
- generic_module = openequivariance.extlib.generic_module
elif torch.version.cuda or torch.version.hip:
try:
from torch.utils.cpp_extension import library_paths, include_paths
@@ -143,6 +143,10 @@ def _raise_import_error_helper(import_target: str):
raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}")
+def torch_ext_so_path():
+ return torch_module.__file__
+
+
if BUILT_EXTENSION:
from generic_module import (
JITTPImpl,
diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py
new file mode 100644
index 00000000..00edefcb
--- /dev/null
+++ b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py
@@ -0,0 +1,5 @@
+from openequivariance._torch.symmetric_contraction.symmetric_contraction import (
+ SymmetricContraction,
+)
+
+__all__ = ["SymmetricContraction"]
diff --git a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py
similarity index 99%
rename from openequivariance/implementations/symmetric_contraction/symmetric_contraction.py
rename to openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py
index 9790c2a2..504e788e 100644
--- a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py
+++ b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py
@@ -1,7 +1,7 @@
# ruff: noqa : E402
import torch
-from openequivariance.extlib import GroupMM_F32, GroupMM_F64
+from openequivariance._torch.extlib import GroupMM_F32, GroupMM_F64
class GroupMM:
diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py
new file mode 100644
index 00000000..7538fb27
--- /dev/null
+++ b/openequivariance/openequivariance/_torch/utils.py
@@ -0,0 +1,68 @@
+import torch
+from types import MappingProxyType
+from openequivariance.core.utils import DTypeEnum
+
+
+def reorder_helper(schedule, weights_in, direction, has_batch_dim):
+ assert direction in ["forward", "backward"]
+
+ specs = schedule.weight_reordering_info(weights_in, has_batch_dim)
+ weights_out = torch.zeros_like(weights_in)
+
+ for spec in specs:
+ parent_range = spec["parent_range"]
+ parent_shape = spec["parent_shape"]
+ weights_subrange = spec["weights_subrange"]
+ child_range = spec["child_range"]
+ transpose_perm = spec["transpose_perm"]
+
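+ # "forward" maps weights from the canonical e3nn ordering into the chunked
+ # layout the JIT kernels consume; "backward" applies the inverse mapping.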
+ if direction == "forward":
+ reshape_size = spec["reshape_size"]
+
+ sliced_weights = weights_in[parent_range].reshape(parent_shape)[
+ weights_subrange
+ ]
+
+ weights_out[child_range] = sliced_weights.permute(transpose_perm).reshape(
+ reshape_size
+ )
+
+ elif direction == "backward":
+ transpose_child_shape = spec["transpose_child_shape"]
+ child_shape = spec["child_shape"]
+
+ sliced_weights = (
+ weights_in[child_range]
+ .reshape(transpose_child_shape)
+ .permute(transpose_perm)
+ )
+
+ weights_out[parent_range].reshape(parent_shape)[weights_subrange] = (
+ sliced_weights.flatten().reshape(child_shape)
+ )
+
+ return weights_out
+
+
+def reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim):
+ weights_in = torch.from_numpy(weights_in.copy())
+ result = reorder_helper(schedule, weights_in, direction, has_batch_dim)
+ return result.detach().cpu().numpy().copy()
+
+
+def reorder_torch(schedule, weights_in, direction, has_batch_dim):
+ if isinstance(weights_in, torch.Tensor):
+ return reorder_helper(schedule, weights_in, direction, has_batch_dim)
+ else:
+ return reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim)
+
+
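+# Read-only mapping from the kernel metadata's DTypeEnum values to torch dtypes,
+# used e.g. by the fake (meta) convolution forward to construct outputs of the
+# correct dtype.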
+enum_to_torch_dtype = MappingProxyType(
+ {
+ DTypeEnum.FLOAT32: torch.float32,
+ DTypeEnum.FLOAT64: torch.float64,
+ DTypeEnum.INT32: torch.int32,
+ DTypeEnum.INT64: torch.int64,
+ DTypeEnum.UINT8: torch.uint8,
+ }
+)
diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py
similarity index 98%
rename from openequivariance/benchmark/ConvBenchmarkSuite.py
rename to openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py
index a4b7c982..499a33eb 100644
--- a/openequivariance/benchmark/ConvBenchmarkSuite.py
+++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py
@@ -7,7 +7,7 @@
import openequivariance as oeq
from openequivariance.benchmark.logging_utils import getLogger
-from openequivariance.implementations.convolution.ConvolutionBase import CoordGraph
+from openequivariance.core.ConvolutionBase import CoordGraph
logger = getLogger()
diff --git a/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py
similarity index 98%
rename from openequivariance/benchmark/TestBenchmarkSuite.py
rename to openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py
index d764be77..119c866c 100644
--- a/openequivariance/benchmark/TestBenchmarkSuite.py
+++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py
@@ -7,11 +7,11 @@
from dataclasses import dataclass
import openequivariance as oeq
-from openequivariance.extlib import DeviceProp
-from openequivariance.implementations.TensorProductBase import TensorProductBase
+from openequivariance._torch.extlib import DeviceProp
+from openequivariance.core.TensorProductBase import TensorProductBase
from openequivariance.benchmark.logging_utils import getLogger, bcolors
-from openequivariance.implementations.e3nn_lite import TPProblem
+from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.correctness_utils import (
correctness_forward,
correctness_backward,
diff --git a/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py
similarity index 96%
rename from openequivariance/benchmark/benchmark_utils.py
rename to openequivariance/openequivariance/benchmark/benchmark_utils.py
index 4dfea422..377df3d6 100644
--- a/openequivariance/benchmark/benchmark_utils.py
+++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py
@@ -10,10 +10,10 @@
calculate_minimum_memory_streamed_forward,
calculate_minimum_memory_streamed_backward,
)
-from openequivariance.implementations.utils import calculate_total_nnz
-from openequivariance.implementations.TensorProductBase import TensorProductBase
-from openequivariance.implementations.e3nn_lite import TPProblem
-from openequivariance.implementations.CUETensorProduct import CUETensorProduct
+from openequivariance.core.utils import calculate_total_nnz
+from openequivariance.core.TensorProductBase import TensorProductBase
+from openequivariance.core.e3nn_lite import TPProblem
+from openequivariance._torch.CUETensorProduct import CUETensorProduct
from openequivariance.benchmark.logging_utils import getLogger, bcolors
logger = getLogger()
diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py
similarity index 75%
rename from openequivariance/benchmark/correctness_utils.py
rename to openequivariance/openequivariance/benchmark/correctness_utils.py
index e2cf414b..788d209e 100644
--- a/openequivariance/benchmark/correctness_utils.py
+++ b/openequivariance/openequivariance/benchmark/correctness_utils.py
@@ -1,12 +1,14 @@
from typing import Optional, Union
-from openequivariance.implementations.TensorProductBase import TensorProductBase
-from openequivariance.implementations.CUETensorProduct import CUETensorProduct
-from openequivariance.implementations.e3nn_lite import TPProblem
+from openequivariance.core.TensorProductBase import TensorProductBase
+from openequivariance.core.e3nn_lite import TPProblem
+from openequivariance._torch.CUETensorProduct import CUETensorProduct
from openequivariance.benchmark.random_buffer_utils import (
get_random_buffers_forward,
get_random_buffers_backward,
+ get_random_buffers_double_backward,
)
+
from openequivariance.benchmark.logging_utils import getLogger, bcolors
import numpy as np
import numpy.linalg as la
@@ -71,7 +73,7 @@ def correctness_forward(
prng_seed: int,
) -> dict:
if reference_implementation is None:
- from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
+ from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
reference_implementation = E3NNTensorProduct
@@ -115,7 +117,7 @@ def correctness_backward(
prng_seed: int,
) -> dict:
if reference_implementation is None:
- from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
+ from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
reference_implementation = E3NNTensorProduct
@@ -194,66 +196,49 @@ def correctness_double_backward(
global torch
import torch
- in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward(
- problem, batch_size, prng_seed
+ in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = (
+ get_random_buffers_double_backward(
+ problem, batch_size=batch_size, prng_seed=prng_seed
+ )
)
- rng = np.random.default_rng(seed=prng_seed * 2)
- dummy_grad = rng.standard_normal(1)[0]
if reference_implementation is None:
- from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
+ from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
reference_implementation = E3NNTensorProduct
result = {"thresh": correctness_threshold, "batch_size": batch_size}
tensors = []
- for i, impl in enumerate([test_implementation, reference_implementation]):
+ for _, impl in enumerate([test_implementation, reference_implementation]):
tp = instantiate_implementation(impl, problem)
-
- if impl == CUETensorProduct and problem.shared_weights:
- weights = weights[np.newaxis, :]
-
weights_reordered = tp.reorder_weights_from_e3nn(
- weights, not tp.config.shared_weights
- )
-
- in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
- in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
- weights_torch = torch.tensor(
- weights_reordered, device="cuda", requires_grad=True
+ weights, has_batch_dim=not problem.shared_weights
)
-
- out_torch = tp.forward(in1_torch, in2_torch, weights_torch)
- out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True)
-
- in1_grad, in2_grad, w_grad = torch.autograd.grad(
- outputs=[out_torch],
- inputs=[in1_torch, in2_torch, weights_torch],
- grad_outputs=[out_grad],
- create_graph=True,
+ weights_dgrad_reordered = tp.reorder_weights_from_e3nn(
+ weights_dgrad, has_batch_dim=not problem.shared_weights
)
- dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
- dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True)
-
- dummy.backward(
- dummy_grad,
- retain_graph=True,
- inputs=[out_grad, in1_torch, in2_torch, weights_torch],
- )
-
- weights_grad = weights_torch.grad.detach().cpu().numpy()
- weights_grad = tp.reorder_weights_to_e3nn(
- weights_grad, not tp.config.shared_weights
+ if impl == CUETensorProduct and problem.shared_weights:
+ weights_reordered = weights_reordered[np.newaxis, :]
+
+ in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
+ in1,
+ in2,
+ out_grad,
+ weights_reordered,
+ weights_dgrad_reordered,
+ in1_dgrad,
+ in2_dgrad,
)
-
tensors.append(
(
- out_grad.grad.detach().cpu().numpy(),
- in1_torch.grad.detach().cpu().numpy(),
- in2_torch.grad.detach().cpu().numpy(),
- weights_grad,
+ out_dgrad,
+ in1_grad,
+ in2_grad,
+ tp.reorder_weights_to_e3nn(
+ weights_grad, has_batch_dim=not problem.shared_weights
+ ),
)
)
diff --git a/openequivariance/benchmark/logging_utils.py b/openequivariance/openequivariance/benchmark/logging_utils.py
similarity index 100%
rename from openequivariance/benchmark/logging_utils.py
rename to openequivariance/openequivariance/benchmark/logging_utils.py
diff --git a/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py
similarity index 96%
rename from openequivariance/benchmark/perf_metrics_utils.py
rename to openequivariance/openequivariance/benchmark/perf_metrics_utils.py
index 88a903ab..212f05f4 100644
--- a/openequivariance/benchmark/perf_metrics_utils.py
+++ b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py
@@ -1,11 +1,11 @@
import math
-from openequivariance.implementations.utils import (
+from openequivariance.core.utils import (
count_cg_non_zero,
sparse_outer_product_work,
)
-from openequivariance.implementations.e3nn_lite import TPProblem, wigner_3j
+from openequivariance.core.e3nn_lite import TPProblem, wigner_3j
from openequivariance.benchmark.logging_utils import getLogger
import numpy as np
diff --git a/openequivariance/benchmark/plotting/__init__.py b/openequivariance/openequivariance/benchmark/plotting/__init__.py
similarity index 100%
rename from openequivariance/benchmark/plotting/__init__.py
rename to openequivariance/openequivariance/benchmark/plotting/__init__.py
diff --git a/openequivariance/benchmark/plotting/plot_convolution.py b/openequivariance/openequivariance/benchmark/plotting/plot_convolution.py
similarity index 100%
rename from openequivariance/benchmark/plotting/plot_convolution.py
rename to openequivariance/openequivariance/benchmark/plotting/plot_convolution.py
diff --git a/openequivariance/benchmark/plotting/plot_double_backward.py b/openequivariance/openequivariance/benchmark/plotting/plot_double_backward.py
similarity index 100%
rename from openequivariance/benchmark/plotting/plot_double_backward.py
rename to openequivariance/openequivariance/benchmark/plotting/plot_double_backward.py
diff --git a/openequivariance/benchmark/plotting/plot_roofline.py b/openequivariance/openequivariance/benchmark/plotting/plot_roofline.py
similarity index 100%
rename from openequivariance/benchmark/plotting/plot_roofline.py
rename to openequivariance/openequivariance/benchmark/plotting/plot_roofline.py
diff --git a/openequivariance/benchmark/plotting/plot_uvu.py b/openequivariance/openequivariance/benchmark/plotting/plot_uvu.py
similarity index 100%
rename from openequivariance/benchmark/plotting/plot_uvu.py
rename to openequivariance/openequivariance/benchmark/plotting/plot_uvu.py
diff --git a/openequivariance/benchmark/plotting/plot_uvw.py b/openequivariance/openequivariance/benchmark/plotting/plot_uvw.py
similarity index 100%
rename from openequivariance/benchmark/plotting/plot_uvw.py
rename to openequivariance/openequivariance/benchmark/plotting/plot_uvw.py
diff --git a/openequivariance/benchmark/plotting/plotting_utils.py b/openequivariance/openequivariance/benchmark/plotting/plotting_utils.py
similarity index 100%
rename from openequivariance/benchmark/plotting/plotting_utils.py
rename to openequivariance/openequivariance/benchmark/plotting/plotting_utils.py
diff --git a/openequivariance/benchmark/problems.py b/openequivariance/openequivariance/benchmark/problems.py
similarity index 100%
rename from openequivariance/benchmark/problems.py
rename to openequivariance/openequivariance/benchmark/problems.py
diff --git a/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/random_buffer_utils.py
similarity index 74%
rename from openequivariance/benchmark/random_buffer_utils.py
rename to openequivariance/openequivariance/benchmark/random_buffer_utils.py
index 41fb7cb6..c657d5bc 100644
--- a/openequivariance/benchmark/random_buffer_utils.py
+++ b/openequivariance/openequivariance/benchmark/random_buffer_utils.py
@@ -1,6 +1,6 @@
import numpy as np
-from openequivariance.implementations.e3nn_lite import TPProblem
+from openequivariance.core.e3nn_lite import TPProblem
def get_random_buffers_forward(
@@ -104,10 +104,16 @@ def get_random_buffers_double_backward(
)
weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
- weights_grad = np.zeros_like(weights)
- in1_grad = np.zeros_like(in1)
- in2_grad = np.zeros_like(in2)
- out_double_grad = np.zeros_like(out_grad)
+ weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
+ in1_grad = np.array(
+ rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
+ )
+ in2_grad = np.array(
+ rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
+ )
+ out_double_grad = np.array(
+ rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
+ )
return (
in1,
@@ -176,3 +182,46 @@ def get_random_buffers_backward_conv(
in2_grad = np.zeros_like(in2)
return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad
+
+
+def get_random_buffers_double_backward_conv(
+ tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int
+):
+ rng = np.random.default_rng(prng_seed)
+ in1 = np.array(
+ rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
+ )
+ in2 = np.array(
+ rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
+ )
+ out_grad = np.array(
+ rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
+ )
+
+ weights_size = (
+ tuple([tpp.weight_numel])
+ if tpp.shared_weights
+ else tuple([edge_count, tpp.weight_numel])
+ )
+
+ weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
+ weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
+ in1_grad = np.array(
+ rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
+ )
+ in2_grad = np.array(
+ rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
+ )
+ out_double_grad = np.array(
+ rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
+ )
+ return (
+ in1,
+ in2,
+ out_grad,
+ weights,
+ weights_grad,
+ in1_grad,
+ in2_grad,
+ out_double_grad,
+ )
diff --git a/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py
similarity index 97%
rename from openequivariance/benchmark/tpp_creation_utils.py
rename to openequivariance/openequivariance/benchmark/tpp_creation_utils.py
index 18f3a84c..7637f412 100644
--- a/openequivariance/benchmark/tpp_creation_utils.py
+++ b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py
@@ -1,14 +1,12 @@
import numpy as np
from typing import Iterator, Optional
-from openequivariance.implementations.e3nn_lite import Irrep, Irreps, TPProblem
+from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem
"""
This was taken from
-
https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py
-
-And adopted to create TPP's to avoid torch dependence
+Adapted to create TPPs to avoid torch dependence.
"""
diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py
similarity index 89%
rename from openequivariance/implementations/ComputationSchedule.py
rename to openequivariance/openequivariance/core/ComputationSchedule.py
index 6d3b7215..9c3884c9 100644
--- a/openequivariance/implementations/ComputationSchedule.py
+++ b/openequivariance/openequivariance/core/ComputationSchedule.py
@@ -1,5 +1,5 @@
import numpy as np
-from openequivariance.implementations.e3nn_lite import Irreps, TPProblem, wigner_3j
+from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j
from itertools import accumulate
from openequivariance.benchmark.logging_utils import getLogger
@@ -619,40 +619,33 @@ def calculate_backward_smem(
smem=self.memory_per_warp * warps_per_block,
)
- def reorder_weights(self, weights_in, direction, has_batch_dim):
+ def weight_reordering_info(self, weights_in, has_batch_dim):
"""
- Reorders weights from the canonical e3nn form to the
- form that LoopUnrollTP can ingest. Can also reorder the parameters
- of a dense neural network layer that produces the weight matrix.
-
- If has_batch_dim is true, the first dimension of the input weight matrix
- is treated as the batch dimension.
+ Calculates the shapes, slices, and permutation info needed to reorder
+ weights between the canonical e3nn form and the form that LoopUnrollTP ingests.
"""
- import torch # TODO-someday: no need to specialize this to PyTorch
+ batch_dim = weights_in.shape[0]
+ reorder_specs = []
- weights_out = torch.zeros_like(weights_in)
- assert direction in ["forward", "backward"]
for i, child_inst in enumerate(self.problem_splitter.new_instructions):
parent_start, parent_end = (
child_inst.parent_weights_start,
child_inst.parent_weights_end,
)
parent_shape = list(child_inst.parent_weights_shape)
+ parent_range = [slice(parent_start, parent_end)]
child_start, child_end, child_shape = (
self.updated_config.weight_range_and_shape_for_instruction(i)
)
+ child_range = [slice(child_start, child_end)]
- parent_range, child_range = (
- [slice(parent_start, parent_end)],
- [slice(child_start, child_end)],
- )
weights_subrange = child_inst.weights_subrange
- batch_dim = weights_in.shape[0]
+
reshape_size = [-1]
transpose_perm = None
-
connection_mode = self.updated_config.instructions[i].connection_mode
+
if connection_mode == "uvu":
transpose_perm = [1, 0]
elif connection_mode == "uvw":
@@ -662,50 +655,29 @@ def reorder_weights(self, weights_in, direction, has_batch_dim):
child_range = [slice(0, batch_dim)] + child_range
parent_range = [slice(0, batch_dim)] + parent_range
parent_shape = [batch_dim] + parent_shape
+
child_shape = [batch_dim] + list(child_shape)
weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange
reshape_size = [batch_dim] + reshape_size
- transpose_perm = [0] + [i + 1 for i in transpose_perm]
-
- if direction == "forward":
- sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[
- tuple(weights_subrange)
- ]
- weights_out[tuple(child_range)] = sliced_weights.permute(
- transpose_perm
- ).reshape(reshape_size)
- elif direction == "backward":
- transpose_child_shape = [child_shape[i] for i in transpose_perm]
- sliced_weights = (
- weights_in[tuple(child_range)]
- .reshape(transpose_child_shape)
- .permute(transpose_perm)
- )
- weights_out[tuple(parent_range)].reshape(parent_shape)[
- tuple(weights_subrange)
- ] = sliced_weights.flatten().reshape(child_shape)
-
- return weights_out
-
- def reorder_weights_numpy(self, weights_in, direction, has_batch_dim):
- import torch
-
- weights_in = torch.from_numpy(weights_in.copy())
- result = self.reorder_weights(weights_in, direction, has_batch_dim)
- return result.detach().cpu().numpy().copy()
- def reorder_weights_from_e3nn(self, weights_in, has_batch_dim):
- import torch
-
- if isinstance(weights_in, np.ndarray):
- return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim)
- elif isinstance(weights_in, torch.Tensor):
- return self.reorder_weights(weights_in, "forward", has_batch_dim)
-
- def reorder_weights_to_e3nn(self, weights_in, has_batch_dim):
- import torch
+ if transpose_perm is not None:
+ transpose_perm = [0] + [k + 1 for k in transpose_perm]
+
+ transpose_child_shape = None
+ if transpose_perm is not None:
+ transpose_child_shape = [child_shape[k] for k in transpose_perm]
+
+ reorder_specs.append(
+ {
+ "parent_range": tuple(parent_range),
+ "parent_shape": parent_shape,
+ "weights_subrange": tuple(weights_subrange),
+ "child_range": tuple(child_range),
+ "child_shape": child_shape,
+ "transpose_perm": transpose_perm,
+ "reshape_size": reshape_size,
+ "transpose_child_shape": transpose_child_shape,
+ }
+ )
- if isinstance(weights_in, np.ndarray):
- return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim)
- elif isinstance(weights_in, torch.Tensor):
- return self.reorder_weights(weights_in, "backward", has_batch_dim)
+ return reorder_specs
diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py
similarity index 73%
rename from openequivariance/implementations/convolution/ConvolutionBase.py
rename to openequivariance/openequivariance/core/ConvolutionBase.py
index 7ed16571..a06b2c79 100644
--- a/openequivariance/implementations/convolution/ConvolutionBase.py
+++ b/openequivariance/openequivariance/core/ConvolutionBase.py
@@ -1,15 +1,15 @@
import copy
import numpy as np
-from openequivariance.extlib import DeviceBuffer
from openequivariance.benchmark.random_buffer_utils import (
get_random_buffers_forward_conv,
get_random_buffers_backward_conv,
+ get_random_buffers_double_backward_conv,
)
from openequivariance.benchmark.logging_utils import getLogger, bcolors
from openequivariance.benchmark.correctness_utils import check_similiarity
-from openequivariance.implementations.e3nn_lite import wigner_3j
-from openequivariance.implementations.utils import benchmark
+from openequivariance.core.e3nn_lite import wigner_3j
+from openequivariance.core.utils import benchmark
logger = getLogger()
@@ -130,106 +130,10 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
"""
return weights
- def allocate_workspace(self, size_bytes):
- self.workspace_size = size_bytes
- if self.torch_op:
- self.workspace_buffer = torch.zeros(
- size_bytes, dtype=torch.uint8, device="cuda"
- )
- else:
- self.workspace_buffer = DeviceBuffer(size_bytes)
- self.workspace_ptr = self.workspace_buffer.data_ptr()
- logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.")
-
@staticmethod
def name():
raise NotImplementedError()
- def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
- assert graph.rows.dtype == self.idx_dtype
- assert graph.cols.dtype == self.idx_dtype
-
- weights_chunked = self.reorder_weights_from_e3nn(
- weights, not self.config.shared_weights
- )
-
- L1_d, L2_d, weights_d = (
- DeviceBuffer(L1_in),
- DeviceBuffer(L2_in),
- DeviceBuffer(weights_chunked),
- )
- L3_d = DeviceBuffer(L3_out)
-
- rows_d = DeviceBuffer(graph.rows)
- cols_d = DeviceBuffer(graph.cols)
-
- self.internal.exec_conv_rawptrs(
- L1_d.data_ptr(),
- L2_d.data_ptr(),
- weights_d.data_ptr(),
- L3_d.data_ptr(),
- rows_d.data_ptr(),
- cols_d.data_ptr(),
- graph.nnz,
- graph.node_count,
- self.workspace_ptr,
- )
-
- L3_d.copy_to_host()
-
- def backward_cpu(
- self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph
- ):
- assert graph.rows.dtype == self.idx_dtype
- assert graph.cols.dtype == self.idx_dtype
-
- weights_chunked = self.reorder_weights_from_e3nn(
- weights, not self.config.shared_weights
- )
-
- L1_d = DeviceBuffer(L1_in)
- L2_d = DeviceBuffer(L2_in)
- weights_d = DeviceBuffer(weights_chunked)
- L3_d = DeviceBuffer(L3_grad)
- rows_d = DeviceBuffer(graph.rows)
- cols_d = DeviceBuffer(graph.cols)
-
- L1_grad_d = DeviceBuffer(L1_grad)
- L2_grad_d = DeviceBuffer(L2_grad)
- weights_grad_d = DeviceBuffer(weights_grad)
-
- transpose_perm_d = None
- transpose_perm_ptr = 0
- if self.deterministic:
- transpose_perm_d = DeviceBuffer(graph.transpose_perm)
- transpose_perm_ptr = transpose_perm_d.data_ptr()
-
- self.internal.backward_rawptrs(
- L1_d.data_ptr(),
- L1_grad_d.data_ptr(),
- L2_d.data_ptr(),
- L2_grad_d.data_ptr(),
- weights_d.data_ptr(),
- weights_grad_d.data_ptr(),
- L3_d.data_ptr(),
- rows_d.data_ptr(),
- cols_d.data_ptr(),
- graph.nnz,
- graph.node_count,
- self.workspace_ptr,
- transpose_perm_ptr,
- )
-
- L1_grad_d.copy_to_host()
- L2_grad_d.copy_to_host()
- weights_grad_d.copy_to_host()
-
- weights_grad[:] = self.reorder_weights_to_e3nn(
- weights_grad, not self.config.shared_weights
- )
-
- return L1_grad, L2_grad, weights_grad
-
def test_correctness_forward(
self,
graph,
@@ -240,7 +144,7 @@ def test_correctness_forward(
high_precision_ref=False,
):
if reference_implementation is None:
- from openequivariance.implementations.convolution.E3NNConv import E3NNConv
+ from openequivariance._torch.E3NNConv import E3NNConv
reference_implementation = E3NNConv
@@ -484,7 +388,7 @@ def test_correctness_backward(
high_precision_ref=False,
):
if reference_implementation is None:
- from openequivariance.implementations.convolution.E3NNConv import E3NNConv
+ from openequivariance._torch.E3NNConv import E3NNConv
reference_implementation = E3NNConv
@@ -560,19 +464,12 @@ def test_correctness_double_backward(
reference_implementation=None,
high_precision_ref=False,
):
- global torch
- import torch
-
- assert self.torch_op
- buffers = get_random_buffers_backward_conv(
+ buffers = get_random_buffers_double_backward_conv(
self.config, graph.node_count, graph.nnz, prng_seed
)
- rng = np.random.default_rng(seed=prng_seed * 2)
- dummy_grad_value = rng.standard_normal(1)[0]
-
if reference_implementation is None:
- from openequivariance.implementations.convolution.E3NNConv import E3NNConv
+ from openequivariance._torch.E3NNConv import E3NNConv
reference_implementation = E3NNConv
@@ -587,61 +484,41 @@ def test_correctness_double_backward(
result = {"thresh": thresh}
tensors = []
for i, tp in enumerate([self, reference_tp]):
- in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers]
+ buffers_copy = [buf.copy() for buf in buffers]
if i == 1 and high_precision_ref:
- in1, in2, out_grad, weights, _, _, _ = [
- np.array(el, dtype=np.float64) for el in buffers
- ]
-
- in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
- in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
+ buffers_copy = [np.array(el, dtype=np.float64) for el in buffers]
- weights_reordered = tp.reorder_weights_from_e3nn(
- weights, not self.config.shared_weights
- )
-
- weights_torch = torch.tensor(
- weights_reordered, device="cuda", requires_grad=True
+ in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = (
+ buffers_copy
)
- torch_rows = torch.tensor(graph.rows, device="cuda")
- torch_cols = torch.tensor(graph.cols, device="cuda")
- torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")
-
- fwd_args = [in1_torch, in2_torch, weights_torch, torch_rows, torch_cols]
- if tp.deterministic:
- fwd_args.append(torch_transpose_perm)
-
- out_torch = tp.forward(*fwd_args)
- out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True)
-
- in1_grad, in2_grad, w_grad = torch.autograd.grad(
- outputs=[out_torch],
- inputs=[in1_torch, in2_torch, weights_torch],
- grad_outputs=[out_grad_torch],
- create_graph=True,
- )
-
- dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
- dummy_grad = torch.tensor(
- float(dummy_grad_value), device="cuda", requires_grad=True
+ weights_reordered = tp.reorder_weights_from_e3nn(
+ weights, not tp.config.shared_weights
)
- dummy.backward(
- dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch]
+ weights_dgrad_reordered = tp.reorder_weights_from_e3nn(
+ weights_dgrad, not tp.config.shared_weights
)
- weights_grad = weights_torch.grad.detach().cpu().numpy()
- weights_grad = tp.reorder_weights_to_e3nn(
- weights_grad, not self.config.shared_weights
+ in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
+ in1,
+ in2,
+ out_grad,
+ weights_reordered,
+ weights_dgrad_reordered,
+ in1_dgrad,
+ in2_dgrad,
+ graph,
)
tensors.append(
(
- out_grad_torch.grad.detach().cpu().numpy().copy(),
- in1_torch.grad.detach().cpu().numpy().copy(),
- in2_torch.grad.detach().cpu().numpy().copy(),
- weights_grad.copy(),
+ out_dgrad,
+ in1_grad,
+ in2_grad,
+ tp.reorder_weights_to_e3nn(
+ weights_grad, has_batch_dim=not self.config.shared_weights
+ ),
)
)
diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py
new file mode 100644
index 00000000..35a9bc3e
--- /dev/null
+++ b/openequivariance/openequivariance/core/LoopUnrollConv.py
@@ -0,0 +1,207 @@
+import numpy as np
+
+from openequivariance.core.ConvolutionBase import ConvolutionBase
+from openequivariance.core.ComputationSchedule import (
+ ComputationSchedule,
+ SMEMCapacityException,
+)
+
+from openequivariance.core.utils import dtype_to_enum
+from openequivariance.templates.jinja_utils import get_jinja_environment
+from openequivariance.core.utils import filter_and_analyze_problem
+
+
+class LoopUnrollConv(ConvolutionBase):
+ def __init__(
+ self,
+ config,
+ dp,
+ postprocess_kernel,
+ *,
+ idx_dtype: type[np.generic] = np.int64,
+ torch_op: bool = False,
+ deterministic: bool = False,
+ kahan: bool = False,
+ ):
+ super().__init__(
+ config, idx_dtype=idx_dtype, torch_op=torch_op, deterministic=deterministic
+ )
+
+ if kahan:
+ assert deterministic
+
+ env = get_jinja_environment()
+ template = env.get_template("loop_unroll_conv_atomic.cuh")
+
+ analysis = filter_and_analyze_problem(config)
+ self.is_uvw = analysis["is_uvw"]
+
+ if config.shared_weights:
+ assert not deterministic, (
+ "Deterministic convolution does not support shared weights"
+ )
+
+ forward_schedule_type = 3
+ backward_schedule_type = 2
+ if deterministic:
+ backward_schedule_type = 3
+ template = env.get_template("loop_unroll_conv_det.cuh")
+
+ def generate_forward_schedule(warps_per_block):
+ self.forward_schedule = ComputationSchedule(
+ self.config,
+ smem_limit=dp.maxSharedMemPerBlock // 4 * 3,
+ warps_per_block=warps_per_block,
+ block_count=dp.multiprocessorCount,
+ direction="forward",
+ irrep_dtype=config.irrep_dtype,
+ weight_dtype=config.weight_dtype,
+ schedule_type=forward_schedule_type,
+ warp_size=dp.warpsize,
+ include_scratch=self.is_uvw,
+ stream_weights=self.is_uvw,
+ kahan=kahan,
+ )
+
+ def generate_backward_schedule(warps_per_block):
+ self.backward_schedule = ComputationSchedule(
+ self.config,
+ smem_limit=dp.maxSharedMemPerBlock,
+ warps_per_block=warps_per_block,
+ block_count=dp.multiprocessorCount * 2,
+ direction="backward",
+ irrep_dtype=config.irrep_dtype,
+ weight_dtype=config.weight_dtype,
+ schedule_type=backward_schedule_type,
+ warp_size=dp.warpsize,
+ include_scratch=self.is_uvw,
+ stream_weights=self.is_uvw,
+ kahan=kahan,
+ )
+
+ def generate_double_backward_schedule(warps_per_block):
+ self.double_backward_schedule = ComputationSchedule(
+ self.config,
+ smem_limit=dp.maxSharedMemPerBlock,
+ warps_per_block=warps_per_block,
+ warp_size=dp.warpsize,
+ block_count=dp.multiprocessorCount,
+ direction="double_backward",
+ irrep_dtype=config.irrep_dtype,
+ weight_dtype=config.weight_dtype,
+ include_scratch=self.is_uvw,
+ stream_weights=self.is_uvw,
+ schedule_type=3,
+ kahan=kahan,
+ )
+
+ scheduler_generators = [
+ generate_forward_schedule,
+ generate_backward_schedule,
+ generate_double_backward_schedule,
+ ]
+
+ for generate_schedule in scheduler_generators:
+ warp_count = 6
+ while warp_count > 0:
+ try:
+ generate_schedule(warp_count)
+ break
+ except SMEMCapacityException:
+ warp_count -= 1
+ if warp_count == 0:
+ raise SMEMCapacityException(
+ "Tensor product schedule generation failed, shared memory inadequate!"
+ )
+
+ if not deterministic:
+ for segment in self.forward_schedule.segments:
+ for key in segment.L3Map.storeback_procedure:
+ segment.L3Map.storeback_procedure[key] = "atomic_accumulate"
+
+ for segment in self.backward_schedule.segments:
+ for key in segment.L1Map.storeback_procedure:
+ segment.L1Map.storeback_procedure[key] = "atomic_accumulate"
+
+ for segment in self.double_backward_schedule.segments:
+ for key in segment.L1Map.storeback_procedure:
+ segment.L1Map.storeback_procedure[key] = "atomic_accumulate"
+
+ idx_type_map = {np.int32: "int", np.int64: "long"}
+
+ self.forward_workspace_offset = None
+ self.backward_workspace_offset = None
+ self.double_backwardB_offset = None
+
+ self.workspace_size = 1
+ if deterministic:
+ destination_index_bytes = 32 # Add extra to account for padding
+ self.workspace_size = max(
+ (
+ self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize
+ + destination_index_bytes
+ )
+ * self.forward_schedule.total_warps,
+ (
+ self.backward_schedule.L1.dim
+ * np.dtype(config.irrep_dtype).itemsize
+ + destination_index_bytes
+ )
+ * self.backward_schedule.total_warps,
+ (
+ self.double_backward_schedule.L1.dim
+ * np.dtype(config.irrep_dtype).itemsize
+ + destination_index_bytes
+ )
+ * self.double_backward_schedule.total_warps,
+ )
+
+ self.forward_workspace_offset = (
+ self.forward_schedule.L3.dim
+ * np.dtype(config.irrep_dtype).itemsize
+ * self.forward_schedule.total_warps
+ )
+ self.backward_workspace_offset = (
+ self.backward_schedule.L1.dim
+ * np.dtype(config.irrep_dtype).itemsize
+ * self.backward_schedule.total_warps
+ )
+ self.double_backwardB_offset = (
+ self.double_backward_schedule.L1.dim
+ * np.dtype(config.irrep_dtype).itemsize
+ * self.double_backward_schedule.total_warps
+ )
+
+ self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8
+ self.backward_workspace_offset = (
+ (self.backward_workspace_offset + 7) // 8 * 8
+ )
+ self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8
+
+ self.kernel_prop = {
+ "L1_dim": self.L1.dim,
+ "L2_dim": self.L2.dim,
+ "L3_dim": self.L3.dim,
+ "weight_numel": self.config.weight_numel,
+ "workspace_size": self.workspace_size,
+ "opt_level": 3,
+ "shared_weights": int(config.shared_weights),
+ "deterministic": int(self.deterministic),
+ "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
+ "weight_dtype": dtype_to_enum[self.config.weight_dtype],
+ "idx_dtype": dtype_to_enum[self.idx_dtype],
+ }
+
+ self.jit_kernel = template.render(
+ forward_schedule=self.forward_schedule,
+ backward_schedule=self.backward_schedule,
+ double_backward_schedule=self.double_backward_schedule,
+ idx_type=idx_type_map[idx_dtype],
+ forward_workspace_offset=self.forward_workspace_offset,
+ backward_workspace_offset=self.backward_workspace_offset,
+ double_backwardB_offset=self.double_backwardB_offset,
+ )
+ self.jit_kernel = postprocess_kernel(self.jit_kernel)
+
+ # with open("scratch.txt", "w") as f:
+ # f.write(self.jit_kernel)
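
The constructor above falls back to smaller thread-block sizes when a schedule exceeds shared memory. The sketch below restates that retry pattern in isolation rather than adding new API; pick_warp_count is an illustrative name.

from openequivariance.core.ComputationSchedule import SMEMCapacityException

def pick_warp_count(generate_schedule, start: int = 6) -> int:
    # Try progressively fewer warps per block until the schedule fits in
    # shared memory, mirroring the constructor loop above.
    for warps in range(start, 0, -1):
        try:
            generate_schedule(warps)
            return warps
        except SMEMCapacityException:
            continue
    raise SMEMCapacityException(
        "Tensor product schedule generation failed, shared memory inadequate!"
    )
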
diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py
new file mode 100644
index 00000000..12ad4536
--- /dev/null
+++ b/openequivariance/openequivariance/core/LoopUnrollTP.py
@@ -0,0 +1,157 @@
+import numpy as np
+
+from openequivariance.templates.jinja_utils import get_jinja_environment
+from openequivariance.core.ComputationSchedule import ComputationSchedule
+from openequivariance.core.TensorProductBase import TensorProductBase
+from openequivariance.core.utils import dtype_to_enum
+
+from openequivariance.core.utils import (
+ filter_and_analyze_problem,
+ count_cg_non_zero,
+)
+
+
+class LoopUnrollTP(TensorProductBase):
+ def __init__(self, config, dp, postprocess_kernel, torch_op):
+ super().__init__(config, torch_op=torch_op)
+
+ env = get_jinja_environment()
+ template = env.get_template("loop_unroll_batch.cuh")
+
+ analysis = filter_and_analyze_problem(config)
+ self.is_uvw = analysis["is_uvw"]
+
+ def generate_forward_schedule(warps_per_block):
+ self.forward_schedule = ComputationSchedule(
+ self.config,
+ smem_limit=dp.maxSharedMemPerBlock,
+ warps_per_block=warps_per_block,
+ warp_size=dp.warpsize,
+ block_count=dp.multiprocessorCount * 4,
+ direction="forward",
+ irrep_dtype=config.irrep_dtype,
+ weight_dtype=config.weight_dtype,
+ include_scratch=self.is_uvw,
+ stream_weights=self.is_uvw,
+ )
+
+ def generate_backward_schedule(warps_per_block):
+ self.backward_schedule = ComputationSchedule(
+ self.config,
+ smem_limit=dp.maxSharedMemPerBlock,
+ warps_per_block=warps_per_block,
+ warp_size=dp.warpsize,
+ block_count=dp.multiprocessorCount * 4,
+ direction="backward",
+ irrep_dtype=config.irrep_dtype,
+ weight_dtype=config.weight_dtype,
+ include_scratch=self.is_uvw,
+ stream_weights=self.is_uvw,
+ )
+
+ def generate_double_backward_schedule(warps_per_block):
+ self.double_backward_schedule = ComputationSchedule(
+ self.config,
+ smem_limit=dp.maxSharedMemPerBlock,
+ warps_per_block=warps_per_block,
+ warp_size=dp.warpsize,
+ block_count=dp.multiprocessorCount,
+ direction="double_backward",
+ irrep_dtype=config.irrep_dtype,
+ weight_dtype=config.weight_dtype,
+ include_scratch=self.is_uvw,
+ stream_weights=self.is_uvw,
+ schedule_type=3,
+ )
+
+ scheduler_generators = [
+ generate_forward_schedule,
+ generate_backward_schedule,
+ generate_double_backward_schedule,
+ ]
+
+ for generate_schedule in scheduler_generators:
+ warp_count = 8
+ while warp_count > 0:
+ try:
+ generate_schedule(warp_count)
+ break
+ except Exception:
+ warp_count -= 2
+ if warp_count == 0:
+ raise RuntimeError(
+ "Tensor product schedule generation failed, shared memory inadequate!"
+ )
+
+ self.jit_kernel = postprocess_kernel(
+ template.render(
+ forward_schedule=self.forward_schedule,
+ backward_schedule=self.backward_schedule,
+ double_backward_schedule=self.double_backward_schedule,
+ )
+ )
+
+ self.kernelProp = {
+ "L1_dim": self.L1.dim,
+ "L2_dim": self.L2.dim,
+ "L3_dim": self.L3.dim,
+ "weight_numel": self.config.weight_numel,
+ "shared_weights": int(self.config.shared_weights),
+ "opt_level": 3,
+ "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
+ "weight_dtype": dtype_to_enum[self.config.weight_dtype],
+ # Not relevant, included for compatibility with convolution
+ "workspace_size": 0,
+ "deterministic": 1,
+ "idx_dtype": 0,
+ }
+
+ def calculate_flops_forward(self, batch_size: int) -> dict:
+ if self.is_uvw:
+ return super().calculate_flops_forward(batch_size)
+ else:
+ tpp = self.config
+ flop_count = {
+ "CG_decomposition": 0,
+ "linear_combination": 0,
+ "outer_products": 0,
+ }
+ for ins in tpp.instructions:
+ l1, l2, l3 = (
+ tpp.irreps_in1[ins.i_in1].ir.l,
+ tpp.irreps_in2[ins.i_in2].ir.l,
+ tpp.irreps_out[ins.i_out].ir.l,
+ )
+ flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
+ ins.path_shape[0] * ins.path_shape[1]
+ )
+ flop_count["linear_combination"] += (
+ (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0
+ )
+
+ flop_count["CG_decomposition"] *= 3 * batch_size
+ flop_count["linear_combination"] *= (
+ batch_size # Weights do not require FMA here
+ )
+ flop_count["total"] = sum(flop_count.values())
+ return flop_count
+
+ def calculate_flops_backward(self, batch_size: int) -> dict:
+ if self.is_uvw:
+ return super().calculate_flops_backward(batch_size)
+ else:
+ tpp = self.config
+ flop_count = {"backward": 0}
+ for ins in tpp.instructions:
+ l1, l2, l3 = (
+ tpp.irreps_in1[ins.i_in1].ir.l,
+ tpp.irreps_in2[ins.i_in2].ir.l,
+ tpp.irreps_out[ins.i_out].ir.l,
+ )
+ flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (
+ ins.path_shape[0] * ins.path_shape[1]
+ )
+
+ flop_count["backward"] *= 9 * batch_size
+ flop_count["total"] = sum(flop_count.values())
+ return flop_count
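
The FLOP estimators above return per-category counts plus a total. A short usage sketch, assuming tp is an already constructed LoopUnrollTP instance (construction is shown elsewhere in this diff):

flops = tp.calculate_flops_forward(batch_size=50_000)
print(flops["CG_decomposition"], flops["linear_combination"], flops["total"])

bwd = tp.calculate_flops_backward(batch_size=50_000)
print(bwd["backward"], bwd["total"])
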
diff --git a/openequivariance/implementations/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py
similarity index 71%
rename from openequivariance/implementations/TensorProductBase.py
rename to openequivariance/openequivariance/core/TensorProductBase.py
index 043c5b77..b5d3831f 100644
--- a/openequivariance/implementations/TensorProductBase.py
+++ b/openequivariance/openequivariance/core/TensorProductBase.py
@@ -1,9 +1,8 @@
import numpy as np
-from openequivariance.implementations.e3nn_lite import TPProblem
+from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.benchmark.logging_utils import getLogger
-from openequivariance.implementations.utils import benchmark
-from openequivariance.extlib import DeviceBuffer
+from openequivariance.core.utils import benchmark
logger = getLogger()
@@ -44,7 +43,7 @@ def reorder_weights_from_e3nn(self, weights, has_batch_dim: bool = True):
Reorders weights from ``e3nn`` canonical order to the order used by ``oeq``.
:param weights: Weights in ``e3nn`` canonical order, either an
- np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
+ np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]``
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.
:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.
@@ -57,8 +56,8 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True):
r"""
Reorders weights from ``oeq`` canonical order to the order used by ``e3nn``.
- :param weights: Weights in ``oeq`` canonical order, either an
- np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
+    :param weights: Weights in ``oeq`` canonical order, either an
+ np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]``
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.
:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.
@@ -67,94 +66,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True):
"""
return weights
- def forward_raw(
- self,
- batch: np.uint64,
- L1_in: np.uint64,
- L2_in: np.uint64,
- L3_out: np.uint64,
- weights: np.uint64,
- ) -> None:
- self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights)
-
- def backward_raw(
- self,
- batch_size: np.uint64,
- L1_in: np.uint64,
- L1_grad: np.uint64,
- L2_in: np.uint64,
- L2_grad: np.uint64,
- weights: np.uint64,
- weights_grad: np.uint64,
- L3_grad: np.uint64,
- ):
- self.internal.backward_rawptr(
- batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad
- )
-
- def forward_cpu(
- self,
- L1_in: np.ndarray,
- L2_in: np.ndarray,
- L3_out: np.ndarray,
- weights: np.ndarray,
- ) -> None:
- weights_chunked = self.reorder_weights_from_e3nn(
- weights, not self.config.shared_weights
- )
-
- batch = L1_in.shape[0]
- L1_d = DeviceBuffer(L1_in)
- L2_d = DeviceBuffer(L2_in)
- L3_d = DeviceBuffer(L3_out)
- weights_d = DeviceBuffer(weights_chunked)
- self.internal.exec_tensor_product_rawptr(
- batch,
- L1_d.data_ptr(),
- L2_d.data_ptr(),
- L3_d.data_ptr(),
- weights_d.data_ptr(),
- )
- L3_d.copy_to_host()
-
- def backward_cpu(
- self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
- ) -> None:
- weights_chunked = self.reorder_weights_from_e3nn(
- weights, not self.config.shared_weights
- )
-
- batch = L1_in.shape[0]
- L1_d, L2_d, L3_d = (
- DeviceBuffer(L1_in),
- DeviceBuffer(L2_in),
- DeviceBuffer(L3_grad),
- )
- L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad)
- weights_d, weights_grad_d = (
- DeviceBuffer(weights_chunked),
- DeviceBuffer(weights_grad),
- )
-
- self.internal.backward_rawptr(
- batch,
- L1_d.data_ptr(),
- L1_grad_d.data_ptr(),
- L2_d.data_ptr(),
- L2_grad_d.data_ptr(),
- weights_d.data_ptr(),
- weights_grad_d.data_ptr(),
- L3_d.data_ptr(),
- )
-
- L1_grad_d.copy_to_host()
- L2_grad_d.copy_to_host()
- weights_grad_d.copy_to_host()
-
- weights_grad[:] = self.reorder_weights_to_e3nn(
- weights_grad, not self.config.shared_weights
- )
-
def benchmark_forward(
self,
num_warmup: int,
diff --git a/openequivariance/implementations/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py
similarity index 100%
rename from openequivariance/implementations/e3nn_lite.py
rename to openequivariance/openequivariance/core/e3nn_lite.py
diff --git a/openequivariance/implementations/utils.py b/openequivariance/openequivariance/core/utils.py
similarity index 84%
rename from openequivariance/implementations/utils.py
rename to openequivariance/openequivariance/core/utils.py
index b90993c1..5fd8f81d 100644
--- a/openequivariance/implementations/utils.py
+++ b/openequivariance/openequivariance/core/utils.py
@@ -3,11 +3,39 @@
import numpy as np
-from openequivariance.implementations.e3nn_lite import Instruction, TPProblem, wigner_3j
+from openequivariance.core.e3nn_lite import Instruction, TPProblem, wigner_3j
import json
import tempfile
-from openequivariance.extlib import GPUTimer
+import hashlib
+
+from enum import IntEnum
+
+
+class DTypeEnum(IntEnum):
+ """
+    The C++ layer stores a copy of this map.
+ """
+
+ FLOAT32 = 1
+ FLOAT64 = 2
+ INT32 = 3
+ INT64 = 4
+ UINT8 = 5
+
+
+dtype_to_enum = {
+ np.float32: DTypeEnum.FLOAT32,
+ np.float64: DTypeEnum.FLOAT64,
+ np.int32: DTypeEnum.INT32,
+ np.int64: DTypeEnum.INT64,
+ np.uint8: DTypeEnum.UINT8,
+ np.dtype(np.float32): DTypeEnum.FLOAT32,
+ np.dtype(np.float64): DTypeEnum.FLOAT64,
+ np.dtype(np.int32): DTypeEnum.INT32,
+ np.dtype(np.int64): DTypeEnum.INT64,
+ np.dtype(np.uint8): DTypeEnum.UINT8,
+}
def sparse_outer_product_work(cg: np.ndarray) -> int:
@@ -123,6 +151,8 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
mode=gpu_time may include PyTorch overhead
mode=kernel_time measures runtime for only the specified kernels
"""
+ from openequivariance._torch.extlib import GPUTimer
+
assert mode in ["gpu_time", "torch_kernel_time"]
time_millis = np.zeros(num_iter, dtype=np.float32)
timer = GPUTimer()
@@ -170,3 +200,13 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
time_millis[i] = kernel_time
return time_millis
+
+
+def hash_attributes(attrs):
+ m = hashlib.sha256()
+
+ for key in sorted(attrs.keys()):
+ m.update(attrs[key].__repr__().encode("utf-8"))
+
+ hash = int(m.hexdigest()[:16], 16) >> 1
+ attrs["hash"] = hash
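
The new DTypeEnum / hash_attributes helpers are what the JAX bindings later in this diff use to key the kernel cache. A small sketch of their behavior, assuming the module layout introduced in this patch:

import numpy as np
from openequivariance.core.utils import DTypeEnum, dtype_to_enum, hash_attributes

# dtype_to_enum accepts both dtype classes and np.dtype instances.
assert dtype_to_enum[np.float32] == DTypeEnum.FLOAT32
assert dtype_to_enum[np.dtype("int64")] == DTypeEnum.INT64

# hash_attributes adds a "hash" key derived from the repr of every other entry,
# so identical attribute dicts map to the same 63-bit cache key.
attrs = {"kernel": "__global__ void k() {}", "opt_level": 3}
hash_attributes(attrs)
assert 0 <= attrs["hash"] < 2**63
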
diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp
similarity index 99%
rename from openequivariance/extension/convolution.hpp
rename to openequivariance/openequivariance/extension/convolution.hpp
index 75b4f879..3b2ce1e6 100644
--- a/openequivariance/extension/convolution.hpp
+++ b/openequivariance/openequivariance/extension/convolution.hpp
@@ -2,8 +2,6 @@
#include
#include
-#include
-#include
#include
struct ConvData {
diff --git a/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp
similarity index 100%
rename from openequivariance/extension/generic_module.cpp
rename to openequivariance/openequivariance/extension/generic_module.cpp
diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/openequivariance/extension/group_mm_cuda.hpp
similarity index 100%
rename from openequivariance/extension/group_mm_cuda.hpp
rename to openequivariance/openequivariance/extension/group_mm_cuda.hpp
diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/openequivariance/extension/group_mm_hip.hpp
similarity index 100%
rename from openequivariance/extension/group_mm_hip.hpp
rename to openequivariance/openequivariance/extension/group_mm_hip.hpp
diff --git a/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp
similarity index 100%
rename from openequivariance/extension/libtorch_tp_jit.cpp
rename to openequivariance/openequivariance/extension/libtorch_tp_jit.cpp
diff --git a/openequivariance/extension/tensorproducts.hpp b/openequivariance/openequivariance/extension/tensorproducts.hpp
similarity index 100%
rename from openequivariance/extension/tensorproducts.hpp
rename to openequivariance/openequivariance/extension/tensorproducts.hpp
diff --git a/openequivariance/extension/test/CMakeLists.txt b/openequivariance/openequivariance/extension/test/CMakeLists.txt
similarity index 100%
rename from openequivariance/extension/test/CMakeLists.txt
rename to openequivariance/openequivariance/extension/test/CMakeLists.txt
diff --git a/openequivariance/extension/test/load_jitscript.cpp b/openequivariance/openequivariance/extension/test/load_jitscript.cpp
similarity index 100%
rename from openequivariance/extension/test/load_jitscript.cpp
rename to openequivariance/openequivariance/extension/test/load_jitscript.cpp
diff --git a/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/util/backend_cuda.hpp
similarity index 98%
rename from openequivariance/extension/util/backend_cuda.hpp
rename to openequivariance/openequivariance/extension/util/backend_cuda.hpp
index 364186fc..4c79faed 100644
--- a/openequivariance/extension/util/backend_cuda.hpp
+++ b/openequivariance/openequivariance/extension/util/backend_cuda.hpp
@@ -349,7 +349,11 @@ class __attribute__((visibility("default"))) CUJITKernel {
~CUJITKernel() {
if(compiled) {
- CUDA_SAFE_CALL(cuLibraryUnload(library));
+ auto result = cuLibraryUnload(library);
+ if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
+ std::cout << "Failed to unload CUDA library, error code: " << ((int) result) << std::endl;
+ }
+
delete[] code;
}
NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
diff --git a/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/util/backend_hip.hpp
similarity index 100%
rename from openequivariance/extension/util/backend_hip.hpp
rename to openequivariance/openequivariance/extension/util/backend_hip.hpp
diff --git a/openequivariance/extension/util/buffer.hpp b/openequivariance/openequivariance/extension/util/buffer.hpp
similarity index 100%
rename from openequivariance/extension/util/buffer.hpp
rename to openequivariance/openequivariance/extension/util/buffer.hpp
diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py
new file mode 100644
index 00000000..452e7bb7
--- /dev/null
+++ b/openequivariance/openequivariance/jax/TensorProduct.py
@@ -0,0 +1,159 @@
+import jax
+import numpy as np
+from functools import partial
+from openequivariance.jax import extlib
+from openequivariance.core.e3nn_lite import TPProblem
+from openequivariance.core.LoopUnrollTP import LoopUnrollTP
+from openequivariance.core.utils import hash_attributes
+from openequivariance.jax.utils import reorder_jax
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
+def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
+ forward_call = jax.ffi.ffi_call(
+ "tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
+ )
+ return forward_call(X, Y, W, **attrs)
+
+
+def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
+ return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(4, 5))
+def backward(X, Y, W, dZ, irrep_dtype, attrs):
+ backward_call = jax.ffi.ffi_call(
+ "tp_backward",
+ (
+ jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+ jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+ jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+ ),
+ )
+
+ return backward_call(X, Y, W, dZ, **attrs)
+
+
+def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
+ return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
+
+
+def double_backward(irrep_dtype, attrs, inputs, derivatives):
+ double_backward_call = jax.ffi.ffi_call(
+ "tp_double_backward",
+ (
+ jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
+ jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
+ jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
+ jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+ ),
+ )
+ return double_backward_call(*inputs, *derivatives, **attrs)
+
+
+def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
+ return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+
+
+forward.defvjp(forward_with_inputs, backward_autograd)
+backward.defvjp(backward_with_inputs, double_backward)
+
+
+class TensorProduct(LoopUnrollTP):
+ r"""
+    Identical to ``oeq.torch.TensorProduct``, but implemented in JAX.
+
+ :param problem: Specification of the tensor product.
+ """
+
+ def __init__(self, problem: TPProblem):
+ dp = extlib.DeviceProp(0)
+ super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False)
+
+ self.attrs = {
+ "kernel": self.jit_kernel,
+ "forward_config": vars(self.forward_schedule.launch_config),
+ "backward_config": vars(self.backward_schedule.launch_config),
+ "double_backward_config": vars(self.double_backward_schedule.launch_config),
+ "kernel_prop": self.kernelProp,
+ }
+ hash_attributes(self.attrs)
+
+ self.weight_numel = problem.weight_numel
+ self.L3_dim = self.config.irreps_out.dim
+
+ def forward(
+ self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray
+ ) -> jax.numpy.ndarray:
+ return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)
+
+ def __call__(
+ self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray
+ ) -> jax.numpy.ndarray:
+ return self.forward(X, Y, W)
+
+ def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
+ return reorder_jax(
+ self.forward_schedule, weights, "forward", not self.config.shared_weights
+ )
+
+ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
+ return reorder_jax(
+ self.forward_schedule, weights, "backward", not self.config.shared_weights
+ )
+
+ def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None:
+ weights = self.reorder_weights_from_e3nn(
+ weights, has_batch_dim=not self.config.shared_weights
+ )
+ result = self.forward(
+ jax.numpy.asarray(L1_in),
+ jax.numpy.asarray(L2_in),
+ jax.numpy.asarray(weights),
+ )
+ L3_out[:] = np.asarray(result)
+
+ def backward_cpu(
+ self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
+ ) -> None:
+ weights = self.reorder_weights_from_e3nn(
+ weights, has_batch_dim=not self.config.shared_weights
+ )
+ backward_fn = jax.vjp(
+ lambda X, Y, W: self.forward(X, Y, W),
+ jax.numpy.asarray(L1_in),
+ jax.numpy.asarray(L2_in),
+ jax.numpy.asarray(weights),
+ )[1]
+ L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn(
+ jax.numpy.asarray(L3_grad)
+ )
+ L1_grad[:] = np.asarray(L1_grad_jax)
+ L2_grad[:] = np.asarray(L2_grad_jax)
+ weights_grad[:] = np.asarray(weights_grad_jax)
+ weights_grad[:] = self.reorder_weights_to_e3nn(
+ weights_grad, has_batch_dim=not self.config.shared_weights
+ )
+
+ def double_backward_cpu(
+ self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad
+ ):
+ in1_jax = jax.numpy.asarray(in1)
+ in2_jax = jax.numpy.asarray(in2)
+ weights_jax = jax.numpy.asarray(weights)
+ out_grad_jax = jax.numpy.asarray(out_grad)
+ in1_dgrad_jax = jax.numpy.asarray(in1_dgrad)
+ in2_dgrad_jax = jax.numpy.asarray(in2_dgrad)
+ weights_dgrad_jax = jax.numpy.asarray(weights_dgrad)
+
+ in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp(
+ lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[
+ 1
+ ](o),
+ in1_jax,
+ in2_jax,
+ weights_jax,
+ out_grad_jax,
+ )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
+
+ return in1_grad, in2_grad, weights_grad, out_dgrad
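
A hedged usage sketch of the new JAX tensor product: problem is a TPProblem constructed elsewhere (outside this diff) with shared_weights=False, and the shapes follow the docstrings above.

import jax
import jax.numpy as jnp
from openequivariance.jax import TensorProduct

tp = TensorProduct(problem)  # `problem` is an existing TPProblem

batch = 128
X = jnp.zeros((batch, problem.irreps_in1.dim), dtype=problem.irrep_dtype)
Y = jnp.zeros((batch, problem.irreps_in2.dim), dtype=problem.irrep_dtype)
W = jnp.zeros((batch, tp.weight_numel), dtype=problem.weight_dtype)

Z = tp(X, Y, W)  # shape [batch, problem.irreps_out.dim]

# The custom_vjp definitions above make the op differentiable end to end.
grads = jax.grad(lambda x, y, w: tp(x, y, w).sum(), argnums=(0, 1, 2))(X, Y, W)
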
diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
new file mode 100644
index 00000000..3aaee28a
--- /dev/null
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -0,0 +1,298 @@
+import numpy as np
+from functools import partial
+from typing import Optional
+from openequivariance.jax import extlib
+
+from openequivariance.core.e3nn_lite import TPProblem
+from openequivariance.core.LoopUnrollConv import LoopUnrollConv
+from openequivariance.core.utils import hash_attributes
+from openequivariance.jax.utils import reorder_jax
+
+import jax
+import jax.numpy as jnp
+
+from openequivariance.benchmark.logging_utils import getLogger
+
+logger = getLogger()
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
+def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
+ forward_call = jax.ffi.ffi_call(
+ "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
+ )
+ return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)
+
+
+def forward_with_inputs(
+ X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
+):
+ return forward(
+ X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
+ ), (X, Y, W, rows, cols, sender_perm, workspace)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
+def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+ backward_call = jax.ffi.ffi_call(
+ "conv_backward",
+ (
+ jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+ jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+ jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+ ),
+ )
+ return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
+
+
+def backward_with_inputs(
+ X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+):
+ return backward(
+ X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+ ), (X, Y, W, dZ) # rows, cols, sender_perm, workspace)
+
+
+def double_backward(
+ rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives
+):
+ double_backward_call = jax.ffi.ffi_call(
+ "conv_double_backward",
+ (
+ jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
+ jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
+ jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
+ jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+ ),
+ )
+ return double_backward_call(
+ *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs
+ )
+
+
+def backward_autograd(
+ rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
+):
+ return backward(
+ inputs[0],
+ inputs[1],
+ inputs[2],
+ dZ,
+ rows,
+ cols,
+ workspace,
+ sender_perm,
+ irrep_dtype,
+ attrs,
+ )
+
+
+forward.defvjp(forward_with_inputs, backward_autograd)
+backward.defvjp(backward_with_inputs, double_backward)
+
+
+class TensorProductConv(LoopUnrollConv):
+ r"""
+    Identical to ``oeq.torch.TensorProductConv``, but implemented in JAX, with one
+    key difference: integer arrays passed to this class must have dtype
+    ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
+
+ :param problem: Specification of the tensor product.
+    :param deterministic: If ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
+        fixup-based algorithm. *Default*: ``False``.
+ :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
+ the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
+ """
+
+ def __init__(
+ self, config: TPProblem, deterministic: bool = False, kahan: bool = False
+ ):
+ dp = extlib.DeviceProp(0)
+ super().__init__(
+ config,
+ dp,
+ extlib.postprocess_kernel,
+ idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version
+ torch_op=False,
+ deterministic=deterministic,
+ kahan=kahan,
+ )
+
+ self.attrs = {
+ "kernel": self.jit_kernel,
+ "forward_config": vars(self.forward_schedule.launch_config),
+ "backward_config": vars(self.backward_schedule.launch_config),
+ "double_backward_config": vars(self.double_backward_schedule.launch_config),
+ "kernel_prop": self.kernel_prop,
+ }
+ hash_attributes(self.attrs)
+
+ self.weight_numel = config.weight_numel
+ self.L3_dim = self.config.irreps_out.dim
+
+ self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8)
+ logger.info(
+ f"Convolution requires {self.workspace_size // (2**20)}MB of workspace."
+ )
+ self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32)
+
+ def forward(
+ self,
+ X: jax.numpy.ndarray,
+ Y: jax.numpy.ndarray,
+ W: jax.numpy.ndarray,
+ rows: jax.numpy.ndarray,
+ cols: jax.numpy.ndarray,
+ sender_perm: Optional[jax.numpy.ndarray] = None,
+ ) -> jax.numpy.ndarray:
+ r"""
+ Computes the fused CG tensor product + convolution.
+
+ :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
+    :param Y: Tensor of shape ``[|E|, problem.irreps_in2.dim()]``, datatype ``problem.irrep_dtype``.
+ :param W: Tensor of datatype ``problem.weight_dtype`` and shape
+
+ * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
+ * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
+
+ :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
+ datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
+ :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
+ datatype ``np.int32``.
+ :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
+ permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
+ Must be provided when ``deterministic=True``.
+
+ :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
+ """
+ if not self.deterministic:
+ sender_perm = self.dummy_transpose_perm
+ else:
+ assert sender_perm is not None, (
+ "Must provide sender_perm for deterministic convolutions."
+ )
+
+ return forward(
+ X,
+ Y,
+ W,
+ rows,
+ cols,
+ self.workspace,
+ sender_perm,
+ self.L3_dim,
+ self.config.irrep_dtype,
+ self.attrs,
+ )
+
+ def __call__(
+ self,
+ X: jax.numpy.ndarray,
+ Y: jax.numpy.ndarray,
+ W: jax.numpy.ndarray,
+ rows: jax.numpy.ndarray,
+ cols: jax.numpy.ndarray,
+ sender_perm: Optional[jax.numpy.ndarray] = None,
+ ) -> jax.numpy.ndarray:
+ return self.forward(X, Y, W, rows, cols, sender_perm)
+
+ def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
+ return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim)
+
+ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
+ return reorder_jax(self.forward_schedule, weights, "backward", has_batch_dim)
+
+ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
+ rows = graph.rows.astype(np.int32)
+ cols = graph.cols.astype(np.int32)
+ sender_perm = graph.transpose_perm.astype(np.int32)
+ weights = self.reorder_weights_from_e3nn(
+ weights, has_batch_dim=not self.config.shared_weights
+ )
+ result = self.forward(
+ jax.numpy.asarray(L1_in),
+ jax.numpy.asarray(L2_in),
+ jax.numpy.asarray(weights),
+ jax.numpy.asarray(rows),
+ jax.numpy.asarray(cols),
+ jax.numpy.asarray(sender_perm),
+ )
+ L3_out[:] = np.asarray(result)
+
+ def backward_cpu(
+ self,
+ L1_in,
+ L1_grad,
+ L2_in,
+ L2_grad,
+ L3_grad,
+ weights,
+ weights_grad,
+ graph,
+ ):
+ rows = graph.rows.astype(np.int32)
+ cols = graph.cols.astype(np.int32)
+ sender_perm = graph.transpose_perm.astype(np.int32)
+ weights = self.reorder_weights_from_e3nn(
+ weights, has_batch_dim=not self.config.shared_weights
+ )
+
+ backward_fn = jax.vjp(
+ lambda X, Y, W: self.forward(
+ X,
+ Y,
+ W,
+ jax.numpy.asarray(rows),
+ jax.numpy.asarray(cols),
+ jax.numpy.asarray(sender_perm),
+ ),
+ jax.numpy.asarray(L1_in),
+ jax.numpy.asarray(L2_in),
+ jax.numpy.asarray(weights),
+ )[1]
+ L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn(
+ jax.numpy.asarray(L3_grad)
+ )
+ L1_grad[:] = np.asarray(L1_grad_jax)
+ L2_grad[:] = np.asarray(L2_grad_jax)
+ weights_grad[:] = np.asarray(weights_grad_jax)
+ weights_grad[:] = self.reorder_weights_to_e3nn(
+ weights_grad, has_batch_dim=not self.config.shared_weights
+ )
+
+ def double_backward_cpu(
+ self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph
+ ):
+ in1_jax = jax.numpy.asarray(in1)
+ in2_jax = jax.numpy.asarray(in2)
+ weights_jax = jax.numpy.asarray(weights)
+ out_grad_jax = jax.numpy.asarray(out_grad)
+ in1_dgrad_jax = jax.numpy.asarray(in1_dgrad)
+ in2_dgrad_jax = jax.numpy.asarray(in2_dgrad)
+ weights_dgrad_jax = jax.numpy.asarray(weights_dgrad)
+
+ rows_jax = jax.numpy.asarray(graph.rows.astype(self.idx_dtype))
+ cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype))
+ sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype))
+
+ in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp(
+ lambda x, y, w, o: jax.vjp(
+ lambda a, b, c: self.forward(
+ a, b, c, rows_jax, cols_jax, sender_perm_jax
+ ),
+ x,
+ y,
+ w,
+ )[1](o),
+ in1_jax,
+ in2_jax,
+ weights_jax,
+ out_grad_jax,
+ )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
+
+ return (
+ np.asarray(in1_grad),
+ np.asarray(in2_grad),
+ np.asarray(weights_grad),
+ np.asarray(out_dgrad),
+ )
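
Similarly, a hedged sketch of the JAX convolution wrapper: problem is a TPProblem built elsewhere with shared_weights=False, and the index arrays are int32 as this class requires.

import jax.numpy as jnp
from openequivariance.jax import TensorProductConv

conv = TensorProductConv(problem, deterministic=False)

num_nodes, num_edges = 1_000, 10_000
X = jnp.zeros((num_nodes, problem.irreps_in1.dim), dtype=problem.irrep_dtype)
Y = jnp.zeros((num_edges, problem.irreps_in2.dim), dtype=problem.irrep_dtype)
W = jnp.zeros((num_edges, conv.weight_numel), dtype=problem.weight_dtype)
rows = jnp.zeros(num_edges, dtype=jnp.int32)  # adjacency nonzeros, row indices
cols = jnp.zeros(num_edges, dtype=jnp.int32)  # adjacency nonzeros, column indices

out = conv(X, Y, W, rows, cols)  # shape [num_nodes, problem.irreps_out.dim]
# With deterministic=True, also pass sender_perm (row-major to column-major permutation).
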
diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py
new file mode 100644
index 00000000..410e5dbf
--- /dev/null
+++ b/openequivariance/openequivariance/jax/__init__.py
@@ -0,0 +1,6 @@
+from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct
+from openequivariance.jax.TensorProductConv import (
+ TensorProductConv as TensorProductConv,
+)
+
+__all__ = ["TensorProduct", "TensorProductConv"]
diff --git a/openequivariance/openequivariance/jax/extlib/__init__.py b/openequivariance/openequivariance/jax/extlib/__init__.py
new file mode 100644
index 00000000..d0965502
--- /dev/null
+++ b/openequivariance/openequivariance/jax/extlib/__init__.py
@@ -0,0 +1,28 @@
+import jax
+import openequivariance_extjax as oeq_extjax
+
+
+def postprocess_kernel(kernel):
+ if oeq_extjax.is_hip():
+ kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
+ kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
+ kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
+ return kernel
+ else:
+ return kernel
+
+
+platform = "CUDA"
+if oeq_extjax.is_hip():
+ platform = "ROCM"
+
+for name, target in oeq_extjax.registrations().items():
+ jax.ffi.register_ffi_target(name, target, platform=platform)
+
+GPUTimer = oeq_extjax.GPUTimer
+DeviceProp = oeq_extjax.DeviceProp
+
+__all__ = [
+ "GPUTimer",
+ "DeviceProp",
+]
diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py
new file mode 100644
index 00000000..ae15d1a6
--- /dev/null
+++ b/openequivariance/openequivariance/jax/utils.py
@@ -0,0 +1,63 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim):
+ assert direction in ["forward", "backward"]
+
+ specs = schedule.weight_reordering_info(weights_in, has_batch_dim)
+ weights_out = jnp.zeros_like(weights_in)
+
+ for spec in specs:
+ parent_range = spec["parent_range"]
+ parent_shape = spec["parent_shape"]
+ weights_subrange = spec["weights_subrange"]
+ child_range = spec["child_range"]
+ transpose_perm = spec["transpose_perm"]
+
+ if direction == "forward":
+ reshape_size = spec["reshape_size"]
+
+ sliced_weights = weights_in[parent_range].reshape(parent_shape)[
+ weights_subrange
+ ]
+
+ value_to_assign = sliced_weights.transpose(transpose_perm).reshape(
+ reshape_size
+ )
+ weights_out = weights_out.at[child_range].set(value_to_assign)
+
+ elif direction == "backward":
+ transpose_child_shape = spec["transpose_child_shape"]
+ child_shape = spec["child_shape"]
+
+ sliced_weights = (
+ weights_in[child_range]
+ .reshape(transpose_child_shape)
+ .transpose(transpose_perm)
+ )
+
+ value_to_insert = sliced_weights.flatten().reshape(child_shape)
+
+ slab = weights_out[parent_range]
+ slab_reshaped = slab.reshape(parent_shape)
+ slab_reshaped = slab_reshaped.at[weights_subrange].set(value_to_insert)
+ weights_out = weights_out.at[parent_range].set(
+ slab_reshaped.reshape(slab.shape)
+ )
+
+ return weights_out
+
+
+def reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim):
+ weights_in_jax = jnp.array(weights_in)
+ result = reorder_jax_helper(schedule, weights_in_jax, direction, has_batch_dim)
+ return np.array(result)
+
+
+def reorder_jax(schedule, weights_in, direction, has_batch_dim):
+ if isinstance(weights_in, (jnp.ndarray, jax.Array)):
+ return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim)
+ else:
+ return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim)
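
Because JAX arrays are immutable, reorder_jax_helper cannot assign into slices the way the NumPy/PyTorch paths do; it rebuilds weights_out with functional updates. A two-line illustration of that idiom:

import jax.numpy as jnp

w = jnp.zeros(6)
w = w.at[2:4].set(jnp.array([1.0, 2.0]))  # returns a new array; w is never mutated in place
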
diff --git a/openequivariance/templates/common.cuh b/openequivariance/openequivariance/templates/common.cuh
similarity index 100%
rename from openequivariance/templates/common.cuh
rename to openequivariance/openequivariance/templates/common.cuh
diff --git a/openequivariance/templates/jinja_utils.py b/openequivariance/openequivariance/templates/jinja_utils.py
similarity index 100%
rename from openequivariance/templates/jinja_utils.py
rename to openequivariance/openequivariance/templates/jinja_utils.py
diff --git a/openequivariance/templates/loop_unroll_batch.cuh b/openequivariance/openequivariance/templates/loop_unroll_batch.cuh
similarity index 100%
rename from openequivariance/templates/loop_unroll_batch.cuh
rename to openequivariance/openequivariance/templates/loop_unroll_batch.cuh
diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/openequivariance/templates/loop_unroll_conv_atomic.cuh
similarity index 100%
rename from openequivariance/templates/loop_unroll_conv_atomic.cuh
rename to openequivariance/openequivariance/templates/loop_unroll_conv_atomic.cuh
diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/openequivariance/templates/loop_unroll_conv_det.cuh
similarity index 100%
rename from openequivariance/templates/loop_unroll_conv_det.cuh
rename to openequivariance/openequivariance/templates/loop_unroll_conv_det.cuh
diff --git a/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh
similarity index 100%
rename from openequivariance/templates/loop_unroll_tp.cuh
rename to openequivariance/openequivariance/templates/loop_unroll_tp.cuh
diff --git a/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja
similarity index 100%
rename from openequivariance/templates/macros.jinja
rename to openequivariance/openequivariance/templates/macros.jinja
diff --git a/openequivariance/templates/wmm.cuh b/openequivariance/openequivariance/templates/wmm.cuh
similarity index 100%
rename from openequivariance/templates/wmm.cuh
rename to openequivariance/openequivariance/templates/wmm.cuh
diff --git a/pyproject.toml b/openequivariance/pyproject.toml
similarity index 91%
rename from pyproject.toml
rename to openequivariance/pyproject.toml
index 6f7b8771..a0ddd618 100644
--- a/pyproject.toml
+++ b/openequivariance/pyproject.toml
@@ -17,8 +17,7 @@ dependencies = [
"setuptools",
"ninja",
"jinja2",
- "numpy",
- "torch >= 2.4",
+ "numpy"
]
readme = "README.md"
@@ -57,17 +56,20 @@ dev = [
"pytest-check",
"pytest-subtests",
"torch_geometric",
- "cmake",
- "furo",
- "sphinx",
- "sphinx-autobuild"
+ "cmake"
+]
+
+jax = [
+ "nanobind",
+ "scikit-build-core",
+ "setuptools-scm"
]
[tool.setuptools.packages.find]
include = ["openequivariance*"]
[tool.setuptools_scm]
-# Presence of this section necessary, even if empty
+root = ".."
[tool.pytest.ini_options]
addopts = [
diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt
new file mode 100644
index 00000000..25fec285
--- /dev/null
+++ b/openequivariance_extjax/CMakeLists.txt
@@ -0,0 +1,97 @@
+cmake_minimum_required(VERSION 3.15...3.30)
+project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
+
+find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
+
+# --- XLA CONFIGURATION ---
+if(XLA_DIRECT_DOWNLOAD)
+ message(STATUS "XLA_DIRECT_DOWNLOAD is ON. Fetching XLA source...")
+ include(ExternalProject)
+ ExternalProject_Add(
+ xla
+ PREFIX ${CMAKE_BINARY_DIR}/xla
+ GIT_REPOSITORY https://github.com/openxla/xla.git
+ GIT_TAG main
+ GIT_SHALLOW TRUE
+ GIT_PROGRESS TRUE
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ LOG_DOWNLOAD ON
+ )
+ ExternalProject_Get_Property(xla source_dir)
+ set(XLA_DIR ${source_dir})
+else()
+ message(STATUS "XLA_DIRECT_DOWNLOAD is OFF. Locating XLA via installed JAX...")
+ execute_process(
+ COMMAND "${Python_EXECUTABLE}" "-c"
+ "from jax import ffi; print(ffi.include_dir())"
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR
+ )
+endif()
+
+message(STATUS "XLA include directory: ${XLA_DIR}")
+# -------------------------
+
+execute_process(
+ COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT
+)
+message(STATUS "nanobind cmake directory: ${nanobind_ROOT}")
+
+find_package(nanobind CONFIG REQUIRED)
+
+if(JAX_HIP)
+ message(STATUS "JAX_HIP is set. Building with HIP backend.")
+ find_package(hip REQUIRED)
+else()
+ message(STATUS "JAX_HIP is not set (or zero). Building with CUDA backend.")
+ find_package(CUDAToolkit REQUIRED)
+endif()
+
+set(HEADER_DIR "src/extension")
+set(OEQ_JAX_SOURCES
+ src/libjax_tp_jit.cpp
+)
+
+set(OEQ_JAX_HEADERS
+ ${HEADER_DIR}/convolution.hpp
+ ${HEADER_DIR}/tensorproducts.hpp
+ ${HEADER_DIR}/util/backend_cuda.hpp
+ ${HEADER_DIR}/util/backend_hip.hpp
+ ${HEADER_DIR}/util/buffer.hpp
+)
+
+nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS})
+target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DIR})
+set_target_properties(openequivariance_extjax PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type)
+
+# Ensure the module waits for XLA download if we are in direct download mode
+if(XLA_DIRECT_DOWNLOAD)
+ add_dependencies(openequivariance_extjax xla)
+endif()
+
+if(JAX_HIP)
+ target_link_libraries(openequivariance_extjax PRIVATE hiprtc)
+ target_compile_definitions(openequivariance_extjax PRIVATE HIP_BACKEND=1)
+
+else()
+ set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17)
+
+ get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION)
+ get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY)
+
+ set_target_properties(openequivariance_extjax PROPERTIES
+ BUILD_RPATH "${CUDA_LIB_DIR}"
+ INSTALL_RPATH "${CUDA_LIB_DIR}"
+ )
+
+ target_link_libraries(openequivariance_extjax PRIVATE
+ CUDA::cudart
+ CUDA::cuda_driver
+ CUDA::nvrtc)
+ target_compile_definitions(openequivariance_extjax PRIVATE CUDA_BACKEND=1)
+endif()
+
+install(TARGETS openequivariance_extjax LIBRARY DESTINATION .)
\ No newline at end of file
diff --git a/openequivariance_extjax/LICENSE b/openequivariance_extjax/LICENSE
new file mode 120000
index 00000000..ea5b6064
--- /dev/null
+++ b/openequivariance_extjax/LICENSE
@@ -0,0 +1 @@
+../LICENSE
\ No newline at end of file
diff --git a/openequivariance_extjax/MANIFEST.in b/openequivariance_extjax/MANIFEST.in
new file mode 100644
index 00000000..fcd76e59
--- /dev/null
+++ b/openequivariance_extjax/MANIFEST.in
@@ -0,0 +1,2 @@
+include CMakeLists.txt
+include src/*
\ No newline at end of file
diff --git a/openequivariance_extjax/README.md b/openequivariance_extjax/README.md
new file mode 100644
index 00000000..ad7455ef
--- /dev/null
+++ b/openequivariance_extjax/README.md
@@ -0,0 +1,3 @@
+# OpenEquivariance JAX Extension
+
+The JAX extension module for OpenEquivariance.
\ No newline at end of file
diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml
new file mode 100644
index 00000000..47a21179
--- /dev/null
+++ b/openequivariance_extjax/pyproject.toml
@@ -0,0 +1,53 @@
+[build-system]
+requires = [
+ "setuptools-scm",
+ "scikit-build-core",
+ "nanobind"
+]
+build-backend = "scikit_build_core.build"
+
+[project]
+name = "openequivariance_extjax"
+dynamic = ["version"]
+authors = [
+ { name="Austin Glover" },
+ { name="Vivek Bharadwaj" },
+ { name="Aydin Buluc" },
+ { name="James Demmel" }
+]
+description = "JAX C++ Extension for OpenEquivariance"
+requires-python = ">=3.10"
+
+dependencies = []
+readme = "README.md"
+
+license = "BSD-3-Clause"
+license-files = ["LICENSE"]
+
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+]
+
+[project.urls]
+homepage = "https://passionlab.github.io/OpenEquivariance/"
+source = "https://github.com/PASSIONLab/OpenEquivariance"
+issues = "https://github.com/PASSIONLab/OpenEquivariance/issues"
+
+[tool.scikit-build.cmake.define]
+JAX_HIP = {env="JAX_HIP", default="0"}
+XLA_DIRECT_DOWNLOAD = {env="XLA_DIRECT_DOWNLOAD", default="0"}
+
+[tool.setuptools_scm]
+root = ".."
+
+[tool.pytest.ini_options]
+addopts = [
+ "--import-mode=importlib",
+]
+
+[tool.ruff]
+lint.ignore = ["E741"]
\ No newline at end of file
diff --git a/openequivariance_extjax/src/extension b/openequivariance_extjax/src/extension
new file mode 120000
index 00000000..8370a418
--- /dev/null
+++ b/openequivariance_extjax/src/extension
@@ -0,0 +1 @@
+../../openequivariance/openequivariance/extension
\ No newline at end of file
diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp
new file mode 100644
index 00000000..ae2035e8
--- /dev/null
+++ b/openequivariance_extjax/src/libjax_tp_jit.cpp
@@ -0,0 +1,773 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "nanobind/nanobind.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace nb = nanobind;
+namespace ffi = xla::ffi;
+
+#ifdef CUDA_BACKEND
+ #include
+ #include
+
+ #include "util/backend_cuda.hpp"
+ #include "group_mm_cuda.hpp"
+ using JITKernel = CUJITKernel;
+ using GPU_Allocator = CUDA_Allocator;
+
+    template<typename T>
+    using GroupMM = GroupMMCUDA<T>;
+ using stream_t = cudaStream_t;
+#endif
+
+#ifdef HIP_BACKEND
+ #include "util/backend_hip.hpp"
+ #include "group_mm_hip.hpp"
+ using JITKernel = HIPJITKernel;
+ using GPU_Allocator = HIP_Allocator;
+
+    template<typename T>
+    using GroupMM = GroupMMHIP<T>;
+ using stream_t = hipStream_t;
+#endif
+
+#include "tensorproducts.hpp"
+#include "convolution.hpp"
+
+xla::ffi::DataType enum_to_xla_dtype(int64_t i){
+ switch(i) {
+ case 1:
+ return xla::ffi::DataType::F32;
+ case 2:
+ return xla::ffi::DataType::F64;
+ case 3:
+ return xla::ffi::DataType::S32;
+ case 4:
+ return xla::ffi::DataType::S64;
+ case 5:
+ return xla::ffi::DataType::U8;
+ }
+ throw logic_error("Unsupported tensor datatype!");
+}
+
+std::string xla_dtype_to_string(xla::ffi::DataType dtype) {
+    const std::unordered_map<xla::ffi::DataType, std::string> map = {
+ {xla::ffi::DataType::INVALID, "INVALID"},
+ {xla::ffi::DataType::PRED, "PRED"},
+ {xla::ffi::DataType::S8, "S8"},
+ {xla::ffi::DataType::S16, "S16"},
+ {xla::ffi::DataType::S32, "S32"},
+ {xla::ffi::DataType::S64, "S64"},
+ {xla::ffi::DataType::U8, "U8"},
+ {xla::ffi::DataType::U16, "U16"},
+ {xla::ffi::DataType::U32, "U32"},
+ {xla::ffi::DataType::U64, "U64"},
+ {xla::ffi::DataType::F16, "F16"},
+ {xla::ffi::DataType::F32, "F32"},
+ {xla::ffi::DataType::F64, "F64"},
+ {xla::ffi::DataType::BF16, "BF16"},
+ {xla::ffi::DataType::C64, "C64"},
+ {xla::ffi::DataType::C128, "C128"},
+ {xla::ffi::DataType::TOKEN, "TOKEN"},
+ {xla::ffi::DataType::F8E5M2, "F8E5M2"},
+ {xla::ffi::DataType::F8E4M3, "F8E4M3"},
+ {xla::ffi::DataType::F8E4M3FN, "F8E4M3FN"},
+ {xla::ffi::DataType::F8E4M3B11FNUZ, "F8E4M3B11FNUZ"},
+ {xla::ffi::DataType::F8E5M2FNUZ, "F8E5M2FNUZ"},
+ {xla::ffi::DataType::F8E4M3FNUZ, "F8E4M3FNUZ"},
+ {xla::ffi::DataType::F8E3M4, "F8E3M4"},
+ {xla::ffi::DataType::F4E2M1FN, "F4E2M1FN"},
+ {xla::ffi::DataType::F8E8M0FNU, "F8E8M0FNU"},
+ };
+ return map.at(dtype);
+}
+
+inline void* data_ptr(ffi::AnyBuffer &buffer) {
+ return buffer.untyped_data();
+}
+
+inline void* data_ptr(ffi::Result<ffi::AnyBuffer> &buffer) {
+ return data_ptr(*buffer);
+}
+
+inline int byte_count(ffi::AnyBuffer &buffer) {
+ switch (buffer.element_type()) {
+ case xla::ffi::DataType::U32:
+ case xla::ffi::DataType::S32:
+ case xla::ffi::DataType::F32:
+ return 4;
+ case xla::ffi::DataType::F64:
+ case xla::ffi::DataType::S64:
+ return 8;
+ case xla::ffi::DataType::U8:
+ return 1;
+ default:
+            throw std::logic_error("Unsupported tensor datatype!");
+ }
+}
+
+#ifdef CUDA_BACKEND
+void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) {
+ cudaMemsetAsync(
+ data_ptr(buffer),
+ 0,
+ buffer.element_count() * byte_count(buffer),
+ stream);
+}
+#endif
+#ifdef HIP_BACKEND
+void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) {
+ std::ignore = hipMemsetAsync(
+ data_ptr(buffer),
+ 0,
+ buffer.element_count() * byte_count(buffer),
+ stream);
+}
+#endif
+
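+// Metadata for a compiled kernel, parsed once from the "kernel_prop" FFI
+// attribute dictionary and cached alongside the JIT-compiled kernel.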
+struct KernelProp {
+ int64_t L1_dim, L2_dim, L3_dim, weight_numel;
+ bool shared_weights;
+ xla::ffi::DataType irrep_dtype;
+ xla::ffi::DataType weight_dtype;
+
+ int64_t workspace_size; // Convolution only
+ bool deterministic;
+ xla::ffi::DataType idx_dtype;
+ xla::ffi::DataType workspace_dtype;
+
+ KernelProp() {}
+
+ KernelProp(
+        std::unordered_map<std::string, int64_t> &kernel_dims, bool is_convolution):
+ L1_dim(kernel_dims.at("L1_dim")),
+ L2_dim(kernel_dims.at("L2_dim")),
+ L3_dim(kernel_dims.at("L3_dim")),
+ weight_numel(kernel_dims.at("weight_numel")),
+ shared_weights(kernel_dims.at("shared_weights")),
+ irrep_dtype(enum_to_xla_dtype(kernel_dims.at("irrep_dtype"))),
+ weight_dtype(enum_to_xla_dtype(kernel_dims.at("weight_dtype"))),
+ workspace_dtype(xla::ffi::DataType::U8) {
+ if(is_convolution) {
+ workspace_size = kernel_dims.at("workspace_size");
+ deterministic = kernel_dims.at("deterministic");
+ idx_dtype = enum_to_xla_dtype(kernel_dims.at("idx_dtype"));
+ }
+ }
+};
+
+// JIT compilation caches, keyed by the caller-provided hash attribute and
+// guarded by a single mutex.
+std::unordered_map<int64_t, std::pair<
+        std::unique_ptr<JITTPImpl<JITKernel>>,
+        KernelProp
+    >> tp_cache;
+
+std::unordered_map<int64_t, std::pair<
+        std::unique_ptr<JITConvImpl<JITKernel>>,
+        KernelProp
+    >> conv_cache;
+std::mutex mut;
+
+std::vector<std::string> launch_config_keys = {
+ "num_blocks",
+ "num_threads",
+ "smem"};
+std::vector<std::string> kernel_prop_keys = {
+ "L1_dim",
+ "L2_dim",
+ "L3_dim",
+ "weight_numel",
+ "shared_weights",
+ "opt_level",
+ "irrep_dtype",
+ "weight_dtype",
+
+ // Convolution only
+ "workspace_size",
+ "deterministic",
+ "idx_dtype"};
+
+std::unordered_map<std::string, int64_t> parse_ffi_dict(ffi::Dictionary &dict, const std::vector<std::string> &keys) {
+    std::unordered_map<std::string, int64_t> result;
+    for (const auto &key : keys) {
+        result[key] = dict.get<int64_t>(key).value();
+ }
+ return result;
+}
+
+std::pair<JITTPImpl<JITKernel>*, KernelProp>
+ compile_tp_with_caching(std::string_view kernel,
+ ffi::Dictionary forward_config,
+ ffi::Dictionary backward_config,
+ ffi::Dictionary double_backward_config,
+ ffi::Dictionary kernel_prop,
+ int64_t hash,
+ bool is_convolution) {
+
+ {
+        const std::lock_guard<std::mutex> lock(mut);
+ auto it = tp_cache.find(hash);
+ if (it == tp_cache.end()) {
+ auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys);
+            auto jit_tp_impl = std::make_unique<JITTPImpl<JITKernel>>(
+ std::string(kernel),
+ parse_ffi_dict(forward_config, launch_config_keys),
+ parse_ffi_dict(backward_config, launch_config_keys),
+ parse_ffi_dict(double_backward_config, launch_config_keys),
+ kernel_prop_map);
+ tp_cache.insert({hash,
+ std::make_pair(std::move(jit_tp_impl),
+ KernelProp(kernel_prop_map, is_convolution))});
+ it = tp_cache.find(hash);
+ }
+ return {it->second.first.get(), it->second.second};
+ }
+}
+
+std::pair<JITConvImpl<JITKernel>*, KernelProp>
+ compile_conv_with_caching(std::string_view kernel,
+ ffi::Dictionary forward_config,
+ ffi::Dictionary backward_config,
+ ffi::Dictionary double_backward_config,
+ ffi::Dictionary kernel_prop,
+ int64_t hash,
+ bool is_convolution) {
+
+ {
+        const std::lock_guard<std::mutex> lock(mut);
+ auto it = conv_cache.find(hash);
+ if (it == conv_cache.end()) {
+ auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys);
+            auto jit_conv_impl = std::make_unique<JITConvImpl<JITKernel>>(
+ std::string(kernel),
+ parse_ffi_dict(forward_config, launch_config_keys),
+ parse_ffi_dict(backward_config, launch_config_keys),
+ parse_ffi_dict(double_backward_config, launch_config_keys),
+ kernel_prop_map);
+ conv_cache.insert({hash,
+ std::make_pair(std::move(jit_conv_impl),
+ KernelProp(kernel_prop_map, is_convolution))});
+ it = conv_cache.find(hash);
+ }
+ return {it->second.first.get(), it->second.second};
+ }
+}
+
+inline void check_tensor(const ffi::AnyBuffer &buffer,
+        std::initializer_list<int64_t> expected_shape,
+ xla::ffi::DataType expected_dtype,
+ std::string tensor_name) {
+ const ffi::AnyBuffer::Dimensions dims = buffer.dimensions();
+ if (dims.size() != expected_shape.size()) {
+ throw std::logic_error("Rank mismatch for tensor '"
+ + tensor_name
+ + "'. Expected rank "
+ + std::to_string(expected_shape.size())
+ + ", got rank "
+ + std::to_string(dims.size()));
+ }
+
+ for (size_t i = 0; i < dims.size(); i++) {
+ if (dims[i] != expected_shape.begin()[i]) {
+ throw std::logic_error("Shape mismatch for tensor '"
+ + tensor_name
+ + "'. Expected dimension "
+ + std::to_string(expected_shape.begin()[i])
+ + " at index "
+ + std::to_string(i)
+ + ", got "
+ + std::to_string(dims[i]));
+ }
+ }
+
+ if (buffer.element_type() != expected_dtype) {
+ throw std::logic_error("Datatype mismatch for tensor " + tensor_name +
+ ". Expected datatype " + xla_dtype_to_string(expected_dtype) +
+ ", got " + xla_dtype_to_string(buffer.element_type()));
+ }
+}
+
+// --------------------- Tensor Products --------------------------
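+// Each handler receives the operand buffers, the XLA stream, the kernel source
+// and per-pass launch configurations (passed as FFI attributes), and a 64-bit
+// hash that keys the JIT compilation cache.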
+ffi::Error tp_forward_impl(
+ ffi::AnyBuffer L1_in,
+ ffi::AnyBuffer L2_in,
+ ffi::AnyBuffer W,
+    ffi::Result<ffi::AnyBuffer> L3_out,
+ stream_t stream,
+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+ int64_t hash) {
+
+ auto [jit_kernel, k] = compile_tp_with_caching(
+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
+ const int64_t num_batch = L1_in.dimensions()[0];
+
+ check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
+ check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in");
+
+ if (k.shared_weights)
+ check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+ else
+ check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W");
+
+ jit_kernel->exec_tensor_product(
+ num_batch,
+ data_ptr(L1_in),
+ data_ptr(L2_in),
+ data_ptr(L3_out),
+ data_ptr(W),
+ stream);
+
+ return ffi::Error::Success();
+}
+
+ffi::Error tp_backward_impl(
+ ffi::AnyBuffer L1_in,
+ ffi::AnyBuffer L2_in,
+ ffi::AnyBuffer W,
+ ffi::AnyBuffer L3_grad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+ stream_t stream,
+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+ int64_t hash) {
+
+ auto [jit_kernel, k] = compile_tp_with_caching(
+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
+ const int64_t num_batch = L1_in.dimensions()[0];
+ check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
+ check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in");
+ check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad");
+
+ if (k.shared_weights) {
+ check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
+ }
+ else {
+ check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(*W_grad, {num_batch, k.weight_numel}, k.weight_dtype, "W_grad");
+ }
+
+ if (k.shared_weights) {
+ zero_buffer(*W_grad, stream);
+ }
+
+ jit_kernel->backward(
+ num_batch,
+ data_ptr(L1_in),
+ data_ptr(L1_grad),
+ data_ptr(L2_in),
+ data_ptr(L2_grad),
+ data_ptr(W),
+ data_ptr(W_grad),
+ data_ptr(L3_grad),
+ stream);
+ return ffi::Error::Success();
+}
+
+
+ffi::Error tp_double_backward_impl(
+ ffi::AnyBuffer L1_in,
+ ffi::AnyBuffer L2_in,
+ ffi::AnyBuffer W,
+ ffi::AnyBuffer L3_grad,
+ ffi::AnyBuffer L1_dgrad,
+ ffi::AnyBuffer L2_dgrad,
+ ffi::AnyBuffer W_dgrad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    ffi::Result<ffi::AnyBuffer> L3_dgrad,
+ stream_t stream,
+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+ int64_t hash) {
+
+ auto [jit_kernel, k] = compile_tp_with_caching(
+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
+ const int64_t num_batch = L1_in.dimensions()[0];
+ check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
+ check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in");
+ check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad");
+ check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad");
+ check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad");
+
+ if (k.shared_weights){
+ check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad");
+ } else {
+ check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad");
+ }
+
+ if (k.shared_weights) {
+ zero_buffer(*W_grad, stream);
+ }
+
+ jit_kernel->double_backward(
+ num_batch,
+ data_ptr(L1_in),
+ data_ptr(L2_in),
+ data_ptr(W),
+ data_ptr(L3_grad),
+ data_ptr(L1_dgrad),
+ data_ptr(L2_dgrad),
+ data_ptr(W_dgrad),
+ data_ptr(L1_grad),
+ data_ptr(L2_grad),
+ data_ptr(W_grad),
+ data_ptr(L3_dgrad),
+ stream);
+ return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ tp_forward, tp_forward_impl,
+ ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Ret<ffi::AnyBuffer>()   // L3_out
+        .Ctx<ffi::PlatformStream<stream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+ {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ tp_backward, tp_backward_impl,
+ ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Arg<ffi::AnyBuffer>()   // L3_grad
+        .Ret<ffi::AnyBuffer>()   // L1_grad
+        .Ret<ffi::AnyBuffer>()   // L2_grad
+        .Ret<ffi::AnyBuffer>()   // W_grad
+        .Ctx<ffi::PlatformStream<stream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+ {xla::ffi::Traits::kCmdBufferCompatible});
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ tp_double_backward, tp_double_backward_impl,
+ ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Arg<ffi::AnyBuffer>()   // L3_grad
+        .Arg<ffi::AnyBuffer>()   // L1_dgrad
+        .Arg<ffi::AnyBuffer>()   // L2_dgrad
+        .Arg<ffi::AnyBuffer>()   // W_dgrad
+        .Ret<ffi::AnyBuffer>()   // L1_grad
+        .Ret<ffi::AnyBuffer>()   // L2_grad
+        .Ret<ffi::AnyBuffer>()   // W_grad
+        .Ret<ffi::AnyBuffer>()   // L3_dgrad
+        .Ctx<ffi::PlatformStream<stream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+ {xla::ffi::Traits::kCmdBufferCompatible});
+
+// --------------------- Convolution --------------------------
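+// Convolution handlers additionally take COO indices (rows/cols), a scratch
+// workspace, and a transpose permutation; the permutation is only validated for
+// the deterministic algorithm, and the workspace pointer is passed as null when
+// the non-deterministic (atomic) algorithm is used.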
+ffi::Error conv_forward_impl(
+ ffi::AnyBuffer L1_in,
+ ffi::AnyBuffer L2_in,
+ ffi::AnyBuffer W,
+ ffi::AnyBuffer rows,
+ ffi::AnyBuffer cols,
+ ffi::AnyBuffer workspace,
+ ffi::AnyBuffer transpose_perm,
+    ffi::Result<ffi::AnyBuffer> L3_out,
+ stream_t stream,
+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+ int64_t hash) {
+
+ auto [jit_kernel, k] = compile_conv_with_caching(
+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
+ const int64_t nnz = rows.dimensions()[0];
+ const int64_t node_count = L1_in.dimensions()[0];
+ void* workspace_ptr = data_ptr(workspace);
+
+ check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
+ check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
+ check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
+ check_tensor(rows, {nnz}, k.idx_dtype, "rows");
+ check_tensor(cols, {nnz}, k.idx_dtype, "cols");
+
+ if (k.deterministic){
+ check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
+ }
+ else {
+ workspace_ptr = nullptr;
+ }
+ zero_buffer(*L3_out, stream);
+
+ if (k.shared_weights)
+ check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+ else
+ check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
+
+ jit_kernel->exec_conv(
+ data_ptr(L1_in),
+ data_ptr(L2_in),
+ data_ptr(W),
+ data_ptr(L3_out),
+ data_ptr(rows),
+ data_ptr(cols),
+ nnz, node_count,
+ workspace_ptr,
+ stream);
+
+ return ffi::Error::Success();
+}
+
+ffi::Error conv_backward_impl(
+ ffi::AnyBuffer L1_in,
+ ffi::AnyBuffer L2_in,
+ ffi::AnyBuffer W,
+ ffi::AnyBuffer L3_grad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+ ffi::AnyBuffer rows,
+ ffi::AnyBuffer cols,
+ ffi::AnyBuffer workspace,
+ ffi::AnyBuffer transpose_perm,
+ stream_t stream,
+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+ int64_t hash) {
+
+ auto [jit_kernel, k] = compile_conv_with_caching(
+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
+ const int64_t nnz = rows.dimensions()[0];
+ const int64_t node_count = L1_in.dimensions()[0];
+ void* workspace_ptr = data_ptr(workspace);
+
+ check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
+ check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
+ check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
+ check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
+ check_tensor(rows, {nnz}, k.idx_dtype, "rows");
+ check_tensor(cols, {nnz}, k.idx_dtype, "cols");
+
+ if (k.deterministic) {
+ check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
+ }
+ else {
+ workspace_ptr = nullptr;
+ }
+ zero_buffer(*L1_grad, stream);
+
+ if (k.shared_weights) {
+ check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
+ }
+ else {
+ check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad");
+ }
+ if(k.shared_weights)
+ zero_buffer(*W_grad, stream);
+
+ jit_kernel->backward(
+ data_ptr(L1_in),
+ data_ptr(L1_grad),
+ data_ptr(L2_in),
+ data_ptr(L2_grad),
+ data_ptr(W),
+ data_ptr(W_grad),
+ data_ptr(L3_grad),
+ data_ptr(rows),
+ data_ptr(cols),
+ nnz, node_count,
+ workspace_ptr,
+ data_ptr(transpose_perm),
+ stream);
+ return ffi::Error::Success();
+}
+
+ffi::Error conv_double_backward_impl(
+ ffi::AnyBuffer L1_in,
+ ffi::AnyBuffer L2_in,
+ ffi::AnyBuffer W,
+ ffi::AnyBuffer L3_grad,
+ ffi::AnyBuffer L1_dgrad,
+ ffi::AnyBuffer L2_dgrad,
+ ffi::AnyBuffer W_dgrad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    ffi::Result<ffi::AnyBuffer> L3_dgrad,
+ ffi::AnyBuffer rows,
+ ffi::AnyBuffer cols,
+ ffi::AnyBuffer workspace,
+ ffi::AnyBuffer transpose_perm,
+ stream_t stream,
+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+ int64_t hash) {
+
+ auto [jit_kernel, k] = compile_conv_with_caching(
+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
+ const int64_t nnz = rows.dimensions()[0];
+ const int64_t node_count = L1_in.dimensions()[0];
+ void* workspace_ptr = data_ptr(workspace);
+
+ check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
+ check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
+ check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
+ check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad");
+ check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad");
+ check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
+ check_tensor(rows, {nnz}, k.idx_dtype, "rows");
+ check_tensor(cols, {nnz}, k.idx_dtype, "cols");
+
+ if (k.deterministic) {
+ check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
+ }
+ else {
+ workspace_ptr = nullptr;
+ }
+ zero_buffer(*L1_grad, stream);
+ zero_buffer(*L3_dgrad, stream);
+
+
+ if (k.shared_weights) {
+ check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad");
+ } else {
+ check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
+ check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad");
+ }
+ if(k.shared_weights)
+ zero_buffer(*W_grad, stream);
+
+ jit_kernel->double_backward(
+ data_ptr(L1_in),
+ data_ptr(L2_in),
+ data_ptr(W),
+ data_ptr(L3_grad),
+ data_ptr(L1_dgrad),
+ data_ptr(L2_dgrad),
+ data_ptr(W_dgrad),
+ data_ptr(L1_grad),
+ data_ptr(L2_grad),
+ data_ptr(W_grad),
+ data_ptr(L3_dgrad),
+ data_ptr(rows),
+ data_ptr(cols),
+ nnz, node_count,
+ workspace_ptr,
+ data_ptr(transpose_perm),
+ stream);
+ return ffi::Error::Success();
+}
+
+bool is_hip() {
+#ifdef HIP_BACKEND
+ return true;
+#else
+ return false;
+#endif
+}
+
+// --------------------- FFI Bindings --------------------------
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ conv_forward, conv_forward_impl,
+ ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Arg<ffi::AnyBuffer>()   // rows
+        .Arg<ffi::AnyBuffer>()   // cols
+        .Arg<ffi::AnyBuffer>()   // workspace
+        .Arg<ffi::AnyBuffer>()   // transpose_perm
+        .Ret<ffi::AnyBuffer>()   // L3_out
+        .Ctx<ffi::PlatformStream<stream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+ {xla::ffi::Traits::kCmdBufferCompatible});
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ conv_backward, conv_backward_impl,
+ ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Arg<ffi::AnyBuffer>()   // L3_grad
+        .Ret<ffi::AnyBuffer>()   // L1_grad
+        .Ret<ffi::AnyBuffer>()   // L2_grad
+        .Ret<ffi::AnyBuffer>()   // W_grad
+        .Arg<ffi::AnyBuffer>()   // rows
+        .Arg<ffi::AnyBuffer>()   // cols
+        .Arg<ffi::AnyBuffer>()   // workspace
+        .Arg<ffi::AnyBuffer>()   // transpose_perm
+        .Ctx<ffi::PlatformStream<stream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+ {xla::ffi::Traits::kCmdBufferCompatible});
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ conv_double_backward, conv_double_backward_impl,
+ ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Arg<ffi::AnyBuffer>()   // L3_grad
+        .Arg<ffi::AnyBuffer>()   // L1_dgrad
+        .Arg<ffi::AnyBuffer>()   // L2_dgrad
+        .Arg<ffi::AnyBuffer>()   // W_dgrad
+        .Ret<ffi::AnyBuffer>()   // L1_grad
+        .Ret<ffi::AnyBuffer>()   // L2_grad
+        .Ret<ffi::AnyBuffer>()   // W_grad
+        .Ret<ffi::AnyBuffer>()   // L3_dgrad
+        .Arg<ffi::AnyBuffer>()   // rows
+        .Arg<ffi::AnyBuffer>()   // cols
+        .Arg<ffi::AnyBuffer>()   // workspace
+        .Arg<ffi::AnyBuffer>()   // transpose_perm
+        .Ctx<ffi::PlatformStream<stream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+ {xla::ffi::Traits::kCmdBufferCompatible});
+
+// --------------------- NB Module --------------------------
+NB_MODULE(openequivariance_extjax, m) {
+ m.def("registrations", []() {
+ nb::dict registrations;
+ registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward));
+ registrations["tp_backward"] = nb::capsule(reinterpret_cast(tp_backward));
+ registrations["tp_double_backward"] = nb::capsule(reinterpret_cast(tp_double_backward));
+
+ registrations["conv_forward"] = nb::capsule(reinterpret_cast(conv_forward));
+ registrations["conv_backward"] = nb::capsule(reinterpret_cast(conv_backward));
+ registrations["conv_double_backward"] = nb::capsule(reinterpret_cast(conv_double_backward));
+ return registrations;
+ });
+ m.def("is_hip", &is_hip);
+
+ nb::class_(m, "DeviceProp")
+ .def(nb::init())
+ .def_ro("name", &DeviceProp::name)
+ .def_ro("warpsize", &DeviceProp::warpsize)
+ .def_ro("major", &DeviceProp::major)
+ .def_ro("minor", &DeviceProp::minor)
+ .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount)
+ .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock);
+
+ nb::class_(m, "GPUTimer")
+ .def(nb::init<>())
+ .def("start", &GPUTimer::start)
+ .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed)
+ .def("clear_L2_cache", &GPUTimer::clear_L2_cache);
+
+    /*nb::class_<PyDeviceBuffer<GPU_Allocator>>(m, "DeviceBuffer")
+        .def(nb::init())
+        .def(nb::init())
+        .def("copy_to_host", &PyDeviceBuffer<GPU_Allocator>::copy_to_host)
+        .def("data_ptr", &PyDeviceBuffer<GPU_Allocator>::data_ptr);*/
+}
diff --git a/tests/batch_test.py b/tests/batch_test.py
index 3c7cdf27..f32f7b51 100644
--- a/tests/batch_test.py
+++ b/tests/batch_test.py
@@ -3,7 +3,6 @@
import numpy as np
import openequivariance as oeq
-from openequivariance.implementations.TensorProduct import TensorProduct
from openequivariance.benchmark.correctness_utils import (
correctness_forward,
correctness_backward,
@@ -41,8 +40,17 @@ def extra_tp_constructor_args(self):
return {}
@pytest.fixture(scope="class")
- def tp_and_problem(self, problem, extra_tp_constructor_args):
- tp = TensorProduct(problem, **extra_tp_constructor_args)
+ def with_jax(self, request):
+ return request.config.getoption("--jax")
+
+ @pytest.fixture(scope="class")
+ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
+ cls = oeq.TensorProduct
+ if with_jax:
+            from openequivariance.jax import TensorProduct as jax_tp
+
+ cls = jax_tp
+ tp = cls(problem, **extra_tp_constructor_args)
return tp, problem
def test_tp_fwd(self, tp_and_problem):
@@ -247,7 +255,9 @@ def problem(self, request, dtype):
class TestTorchbindDisable(TestProductionModels):
@pytest.fixture(scope="class")
- def extra_tp_constructor_args(self):
+ def extra_tp_constructor_args(self, with_jax):
+ if with_jax:
+ pytest.skip("N/A for JAX")
return {"use_opaque": True}
@@ -261,11 +271,14 @@ def problem(self, request, dtype):
return problem
@pytest.fixture(scope="class")
- def tp_and_problem(self, problem, extra_tp_constructor_args):
- tp = TensorProduct(problem, **extra_tp_constructor_args)
- switch_map = {
- np.float32: torch.float64,
- np.float64: torch.float32,
- }
- tp.to(switch_map[problem.irrep_dtype])
- return tp, tp.config
+ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
+ if with_jax:
+ pytest.skip("N/A for JAX")
+ else:
+ tp = oeq.TensorProduct(problem, **extra_tp_constructor_args)
+ switch_map = {
+ np.float32: torch.float64,
+ np.float64: torch.float32,
+ }
+ tp.to(switch_map[problem.irrep_dtype])
+ return tp, tp.config
diff --git a/tests/benchmark.py b/tests/benchmark.py
index ab005cdf..829cc46c 100644
--- a/tests/benchmark.py
+++ b/tests/benchmark.py
@@ -11,14 +11,14 @@
import numpy as np
from openequivariance.benchmark.logging_utils import getLogger
-from openequivariance.extlib import DeviceProp
-from openequivariance.implementations.E3NNTensorProduct import (
+from openequivariance._torch.extlib import DeviceProp
+from openequivariance._torch.E3NNTensorProduct import (
E3NNTensorProduct,
E3NNTensorProductCompiledCUDAGraphs,
E3NNTensorProductCompiledMaxAutotuneCUDAGraphs,
)
-from openequivariance.implementations.TensorProduct import TensorProduct
-from openequivariance.implementations.CUETensorProduct import CUETensorProduct
+from openequivariance._torch.TensorProduct import TensorProduct
+from openequivariance._torch.CUETensorProduct import CUETensorProduct
from openequivariance.benchmark.TestBenchmarkSuite import (
TestBenchmarkSuite,
TestDefinition,
@@ -30,15 +30,15 @@
SingleInstruction,
)
-from openequivariance.implementations.convolution.TensorProductConv import (
+from openequivariance._torch.TensorProductConv import (
TensorProductConvAtomic,
TensorProductConvDeterministic,
TensorProductConvKahan,
TensorProductConvScatterSum,
)
-from openequivariance.implementations.convolution.CUEConv import CUEConv, CUEConvFused
-from openequivariance.implementations.convolution.FlashTPConv import FlashTPConv
+from openequivariance._torch.CUEConv import CUEConv, CUEConvFused
+from openequivariance._torch.FlashTPConv import FlashTPConv
from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph
from openequivariance.benchmark.problems import (
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..0e7098e0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,13 @@
+import os
+
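+# These must be set before JAX is first imported: enable 64-bit floats and
+# disable XLA's up-front GPU memory preallocation.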
+os.environ["JAX_ENABLE_X64"] = "True"
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--jax",
+ action="store_true",
+ default=False,
+ help="Test the JAX frontend instead of PyTorch",
+ )
diff --git a/tests/conv_test.py b/tests/conv_test.py
index 14c8c3d2..9c6bb4c8 100644
--- a/tests/conv_test.py
+++ b/tests/conv_test.py
@@ -51,23 +51,29 @@ def graph(self, request):
def extra_conv_constructor_args(self):
return {}
+ @pytest.fixture(scope="class")
+ def with_jax(self, request):
+ return request.config.getoption("--jax")
+
@pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class")
- def conv_object(self, request, problem, extra_conv_constructor_args):
+ def conv_object(self, request, problem, extra_conv_constructor_args, with_jax):
+ cls = oeq.TensorProductConv
+ if with_jax:
+ from openequivariance.jax import TensorProductConv as jax_conv
+
+ cls = jax_conv
+
if request.param == "atomic":
- return oeq.TensorProductConv(
- problem, deterministic=False, **extra_conv_constructor_args
- )
+ return cls(problem, deterministic=False, **extra_conv_constructor_args)
elif request.param == "deterministic":
if not problem.shared_weights:
- return oeq.TensorProductConv(
- problem, deterministic=True, **extra_conv_constructor_args
- )
+ return cls(problem, deterministic=True, **extra_conv_constructor_args)
else:
pytest.skip("Shared weights not supported with deterministic")
elif request.param == "kahan":
if problem.irrep_dtype == np.float32:
if not problem.shared_weights:
- return oeq.TensorProductConv(
+ return cls(
problem,
deterministic=True,
kahan=True,
@@ -228,7 +234,7 @@ def thresh(self, direction):
return {
"fwd": 1e-5,
"bwd": 7.5e-2, # Expect higher errors for shared weights
- "double_bwd": 5e-2,
+ "double_bwd": 5e-1,
}[direction]
@pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")
@@ -245,7 +251,9 @@ def conv_object(self, request, problem):
class TestTorchbindDisable(TestProductionModels):
@pytest.fixture(scope="class")
- def extra_conv_constructor_args(self):
+ def extra_conv_constructor_args(self, with_jax):
+ if with_jax:
+ pytest.skip("N/A for JAX")
return {"use_opaque": True}
@@ -253,7 +261,10 @@ class TestTorchTo(ConvCorrectness):
problems = [mace_problems()[0]]
@pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")
- def problem(self, request, dtype):
+ def problem(self, request, dtype, with_jax):
+ if with_jax:
+ pytest.skip("N/A for JAX")
+
problem = request.param
problem.irrep_dtype, problem.weight_dtype = dtype, dtype
return problem
diff --git a/tests/example_test.py b/tests/example_test.py
new file mode 100644
index 00000000..ae19f77e
--- /dev/null
+++ b/tests/example_test.py
@@ -0,0 +1,163 @@
+import pytest
+import os
+
+
+@pytest.fixture
+def with_jax(request):
+ return request.config.getoption("--jax")
+
+
+def test_tutorial_torch(with_jax):
+ if with_jax:
+ pytest.skip("Skipping PyTorch tutorial when testing JAX")
+
+ import torch
+ import e3nn.o3 as o3
+
+ gen = torch.Generator(device="cuda")
+
+ batch_size = 1000
+ X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
+ X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
+ Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
+
+ instructions = [(0, 0, 0, "uvu", True)]
+
+ tp_e3nn = o3.TensorProduct(
+ X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
+ ).to("cuda")
+ W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
+
+ Z = tp_e3nn(X, Y, W)
+ print(torch.norm(Z))
+ # ===============================
+
+ # ===============================
+ import openequivariance as oeq
+
+ problem = oeq.TPProblem(
+ X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
+ )
+ tp_fast = oeq.TensorProduct(problem)
+
+ Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
+ print(torch.norm(Z))
+ # ===============================
+
+ # Graph Convolution
+ # ===============================
+ from torch_geometric import EdgeIndex
+
+ node_ct, nonzero_ct = 3, 4
+
+ # Receiver, sender indices for message passing GNN
+ edge_index = EdgeIndex(
+ [
+ [0, 1, 1, 2], # Receiver
+ [1, 0, 2, 1],
+ ], # Sender
+ device="cuda",
+ dtype=torch.long,
+ )
+
+ X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
+ Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
+ W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
+
+ tp_conv = oeq.TensorProductConv(
+ problem, deterministic=False
+ ) # Reuse problem from earlier
+ Z = tp_conv.forward(
+ X, Y, W, edge_index[0], edge_index[1]
+ ) # Z has shape [node_ct, z_ir.dim]
+ print(torch.norm(Z))
+ # ===============================
+
+ # ===============================
+ _, sender_perm = edge_index.sort_by("col") # Sort by sender index
+ edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
+
+ # Now we can use the faster deterministic algorithm
+ tp_conv = oeq.TensorProductConv(problem, deterministic=True)
+ Z = tp_conv.forward(
+ X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
+ )
+ print(torch.norm(Z))
+ # ===============================
+ assert True
+
+
+def test_tutorial_jax(with_jax):
+ if not with_jax:
+ pytest.skip("Skipping JAX tutorial when testing PyTorch")
+
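+    # OEQ_NOTORCH=1 must be set before the import so openequivariance is loaded
+    # without its PyTorch backend for this JAX-only tutorial.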
+ os.environ["OEQ_NOTORCH"] = "1"
+ import openequivariance as oeq
+ import jax
+
+ seed = 42
+ key = jax.random.PRNGKey(seed)
+
+ batch_size = 1000
+ X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
+ instructions = [(0, 0, 0, "uvu", True)]
+
+ problem = oeq.TPProblem(
+ X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
+ )
+ tp_fast = oeq.jax.TensorProduct(problem)
+
+ X = jax.random.uniform(
+ key,
+ shape=(batch_size, X_ir.dim),
+ minval=0.0,
+ maxval=1.0,
+ dtype=jax.numpy.float32,
+ )
+ Y = jax.random.uniform(
+ key,
+ shape=(batch_size, Y_ir.dim),
+ minval=0.0,
+ maxval=1.0,
+ dtype=jax.numpy.float32,
+ )
+ W = jax.random.uniform(
+ key,
+ shape=(batch_size, tp_fast.weight_numel),
+ minval=0.0,
+ maxval=1.0,
+ dtype=jax.numpy.float32,
+ )
+
+ Z = tp_fast(X, Y, W)
+ print(jax.numpy.linalg.norm(Z))
+
+ edge_index = jax.numpy.array(
+ [
+ [0, 1, 1, 2],
+ [1, 0, 2, 1],
+ ],
+        dtype=jax.numpy.int32,  # NOTE: int32 here, not int64
+ )
+
+ node_ct, nonzero_ct = 3, 4
+ X = jax.random.uniform(
+ key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32
+ )
+ Y = jax.random.uniform(
+ key,
+ shape=(nonzero_ct, Y_ir.dim),
+ minval=0.0,
+ maxval=1.0,
+ dtype=jax.numpy.float32,
+ )
+ W = jax.random.uniform(
+ key,
+ shape=(nonzero_ct, problem.weight_numel),
+ minval=0.0,
+ maxval=1.0,
+ dtype=jax.numpy.float32,
+ )
+ tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
+ Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
+ print(jax.numpy.linalg.norm(Z))
diff --git a/tests/examples_test.py b/tests/examples_test.py
deleted file mode 100644
index 3beaabb1..00000000
--- a/tests/examples_test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-def test_tutorial():
- import torch
- import e3nn.o3 as o3
-
- gen = torch.Generator(device="cuda")
-
- batch_size = 1000
- X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
- X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
- Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
-
- instructions = [(0, 0, 0, "uvu", True)]
-
- tp_e3nn = o3.TensorProduct(
- X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
- ).to("cuda")
- W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
-
- Z = tp_e3nn(X, Y, W)
- print(torch.norm(Z))
- # ===============================
-
- # ===============================
- import openequivariance as oeq
-
- problem = oeq.TPProblem(
- X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
- )
- tp_fast = oeq.TensorProduct(problem, torch_op=True)
-
- Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
- print(torch.norm(Z))
- # ===============================
-
- # Graph Convolution
- # ===============================
- from torch_geometric import EdgeIndex
-
- node_ct, nonzero_ct = 3, 4
-
- # Receiver, sender indices for message passing GNN
- edge_index = EdgeIndex(
- [
- [0, 1, 1, 2], # Receiver
- [1, 0, 2, 1],
- ], # Sender
- device="cuda",
- dtype=torch.long,
- )
-
- X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
- Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
- W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
-
- tp_conv = oeq.TensorProductConv(
- problem, torch_op=True, deterministic=False
- ) # Reuse problem from earlier
- Z = tp_conv.forward(
- X, Y, W, edge_index[0], edge_index[1]
- ) # Z has shape [node_ct, z_ir.dim]
- print(torch.norm(Z))
- # ===============================
-
- # ===============================
- _, sender_perm = edge_index.sort_by("col") # Sort by sender index
- edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
-
- # Now we can use the faster deterministic algorithm
- tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
- Z = tp_conv.forward(
- X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
- )
- print(torch.norm(Z))
- # ===============================
- assert True
diff --git a/tests/export_test.py b/tests/export_test.py
index e18b38b1..0fd23b2b 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -11,7 +11,7 @@
from torch_geometric import EdgeIndex
import importlib.resources
-from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
+from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
@pytest.fixture(scope="session")