Merged

118 commits
b6a5be0
Bare bones of a JAX extension.
vbharadwaj-bk Nov 23, 2025
8efa3a3
Continued progress.
vbharadwaj-bk Nov 23, 2025
4be44d7
More changes.
vbharadwaj-bk Nov 23, 2025
b758e4b
Added all relevant parameters.
vbharadwaj-bk Nov 23, 2025
b0fb34c
Should be able to compile a kernel.
vbharadwaj-bk Nov 23, 2025
d932d21
Cleaned up code a bit.
vbharadwaj-bk Nov 23, 2025
1069ae3
Reorganized the repo.
vbharadwaj-bk Nov 24, 2025
0525b34
Refactored imports.
vbharadwaj-bk Nov 24, 2025
33f8232
Tests are passing after the first refactor.
vbharadwaj-bk Nov 24, 2025
121cb42
Nested directory structure by one level.
vbharadwaj-bk Nov 25, 2025
45348a2
Temp commit.
vbharadwaj-bk Nov 25, 2025
583c249
Got the editable install working again.
vbharadwaj-bk Nov 25, 2025
e1aa951
Extension module in progress.
vbharadwaj-bk Nov 25, 2025
d7aa1ea
More things working.
vbharadwaj-bk Nov 25, 2025
d2cec50
Began putting together a test rig.
vbharadwaj-bk Nov 25, 2025
6a75711
Things starting to work.
vbharadwaj-bk Nov 25, 2025
f07592f
Made LoopUnrollTP generic.
vbharadwaj-bk Nov 25, 2025
a450604
More things are working.
vbharadwaj-bk Nov 25, 2025
59bc57c
More plumbing.
vbharadwaj-bk Nov 25, 2025
8f7c675
More progress.
vbharadwaj-bk Nov 27, 2025
41e7cd7
Dispatch complete.
vbharadwaj-bk Nov 27, 2025
5c7a828
Forward call is working.
vbharadwaj-bk Nov 27, 2025
6790bd0
Added the backward pass.
vbharadwaj-bk Nov 27, 2025
c15f4f7
Encapsulated the forward call.
vbharadwaj-bk Nov 27, 2025
bfa52a5
Skeleton of rule implemented.
vbharadwaj-bk Nov 27, 2025
d1131fa
Backward call is working.
vbharadwaj-bk Dec 1, 2025
136c9f6
Zero'd buffer.
vbharadwaj-bk Dec 1, 2025
2dadb0f
Wrapped the double-backward pass.
vbharadwaj-bk Dec 1, 2025
e78f705
Added the forward convolution implementation.
vbharadwaj-bk Dec 1, 2025
8784dd4
Backward convolution implemented.
vbharadwaj-bk Dec 1, 2025
ac9b3db
Convolution double backward registered.
vbharadwaj-bk Dec 1, 2025
63ed1c0
Finished the double backward VJP registration.
vbharadwaj-bk Dec 1, 2025
673b5ee
Double backward pass seems to work.
vbharadwaj-bk Dec 1, 2025
7b1ce90
Did some extra testing.
vbharadwaj-bk Dec 1, 2025
38017b0
Reorg of LoopUnrollConv.py
vbharadwaj-bk Dec 1, 2025
865ca13
Convolution changed.
vbharadwaj-bk Dec 1, 2025
b9c9135
Finished prototype of TensorProductConv.
vbharadwaj-bk Dec 2, 2025
4ea49dc
Added some type annotations.
vbharadwaj-bk Dec 2, 2025
745e4e0
Finished the forward call.
vbharadwaj-bk Dec 2, 2025
19f284b
Ready to start JAX support.
vbharadwaj-bk Dec 2, 2025
0d07cd9
More plumbing.
vbharadwaj-bk Dec 2, 2025
ce68f69
Forward call is working.
vbharadwaj-bk Dec 2, 2025
d94db28
Registered the VJP rules for backward and double-backward.
vbharadwaj-bk Dec 2, 2025
2524f2a
Added __call__ functions.
vbharadwaj-bk Dec 2, 2025
a1b6248
Prepping to add tests.
vbharadwaj-bk Dec 2, 2025
427fdcb
Ran ruff.
vbharadwaj-bk Dec 2, 2025
f16c622
Moved tests back.
vbharadwaj-bk Dec 3, 2025
fa42654
1/3 tests is passing.
vbharadwaj-bk Dec 3, 2025
7f4ac06
Backward test is passing.
vbharadwaj-bk Dec 3, 2025
617d996
Backward convolution is failing, need to figure out why.
vbharadwaj-bk Dec 3, 2025
1b0deb0
Zerod gradient buffer.
vbharadwaj-bk Dec 4, 2025
d924503
Abstracted away reordering.
vbharadwaj-bk Dec 7, 2025
6452140
Added JAX reordering function.
vbharadwaj-bk Dec 7, 2025
64c5c56
Reordering starting to work...
vbharadwaj-bk Dec 7, 2025
d815424
Forward and backward are working.
vbharadwaj-bk Dec 7, 2025
c3f83ea
Batch test is working.
vbharadwaj-bk Dec 7, 2025
58b7957
Ready to modify the double backward correctness function.
vbharadwaj-bk Dec 8, 2025
4dc31dc
Correctness double backward works for existing code, need to extend t…
vbharadwaj-bk Dec 8, 2025
79522a1
Wrote double backward function for JAX.
vbharadwaj-bk Dec 8, 2025
71ca862
All double backward tests passing.
vbharadwaj-bk Dec 8, 2025
e140c07
Added the mixins.
vbharadwaj-bk Dec 8, 2025
8a0094a
Added double backward CPU function to jax TP conv.
vbharadwaj-bk Dec 8, 2025
61e0566
Almost there, need to get TensorProductConv working.
vbharadwaj-bk Dec 8, 2025
ab83aef
Double backward tests are passing.
vbharadwaj-bk Dec 25, 2025
50e0fcc
Updated documentation.
vbharadwaj-bk Dec 25, 2025
44af17b
Modified documentation.
vbharadwaj-bk Dec 25, 2025
8caa93e
Updated documentation.
vbharadwaj-bk Dec 25, 2025
9d6e30e
More documentation progress.
vbharadwaj-bk Dec 25, 2025
b7af425
Renamed.
vbharadwaj-bk Dec 25, 2025
1bcea33
Renaming + added JAX example.
vbharadwaj-bk Dec 25, 2025
259ea20
JAX example.
vbharadwaj-bk Dec 25, 2025
ab87185
Added examples.
vbharadwaj-bk Dec 25, 2025
9d8f5d8
Updated README.
vbharadwaj-bk Dec 25, 2025
7acad2e
Updated release file.
vbharadwaj-bk Dec 25, 2025
f659115
Linted.
vbharadwaj-bk Dec 25, 2025
dc45a8f
Updated the build verification.
vbharadwaj-bk Dec 25, 2025
140b39d
Merged.
vbharadwaj-bk Dec 25, 2025
1817a2a
Merge complete.
vbharadwaj-bk Dec 25, 2025
a030fb5
Updated README.
vbharadwaj-bk Dec 25, 2025
6edad89
Fixed some minor issues.
vbharadwaj-bk Dec 25, 2025
895ad78
Added symlinks.
vbharadwaj-bk Dec 25, 2025
2b2c156
Cleaning up the core.
vbharadwaj-bk Dec 25, 2025
c28f2b9
More core cleanup.
vbharadwaj-bk Dec 25, 2025
52cf2ce
Rename.
vbharadwaj-bk Dec 25, 2025
b2145cd
Example test is working.
vbharadwaj-bk Dec 26, 2025
3f59532
Sanded away some more issues.
vbharadwaj-bk Dec 26, 2025
e49ad88
Updated changelog.
vbharadwaj-bk Dec 26, 2025
d5650b4
Pre-commit.
vbharadwaj-bk Dec 26, 2025
14f642d
Download XLA directly.
vbharadwaj-bk Dec 26, 2025
0e71c54
Removed need for build isolation.
vbharadwaj-bk Dec 26, 2025
85c2e2f
Removed need for build isolation.
vbharadwaj-bk Dec 26, 2025
0efef5b
Updated README.
vbharadwaj-bk Dec 26, 2025
37f4891
Updated documentation slightly.
vbharadwaj-bk Dec 26, 2025
5d8e42f
Don't need extension source path anymore.
vbharadwaj-bk Dec 26, 2025
01e124d
Removed a spurious import.
vbharadwaj-bk Dec 27, 2025
d21b248
Update Python version and XLA Git tag in CMakeLists
vbharadwaj-bk Jan 4, 2026
c774f0e
Update XLA dir.
vbharadwaj-bk Jan 4, 2026
aa040f1
Removed dependency.
vbharadwaj-bk Jan 4, 2026
fa5236f
Went back to version that disables build isolation.
vbharadwaj-bk Jan 4, 2026
94490dd
Updated README.
vbharadwaj-bk Jan 4, 2026
aea1351
Updated error handling
vbharadwaj-bk Jan 4, 2026
903d597
Last bit of cleanup.
vbharadwaj-bk Jan 4, 2026
1ef97da
Merged with test.
vbharadwaj-bk Jan 4, 2026
6568876
Ruff.
vbharadwaj-bk Jan 4, 2026
24dbb11
Things working for HIP, just need to branch.
vbharadwaj-bk Jan 5, 2026
8c5eadf
Updated CMakeLists.
vbharadwaj-bk Jan 5, 2026
39ac3be
Added pyproject.toml define.
vbharadwaj-bk Jan 5, 2026
c3d4666
Plumbed logic.
vbharadwaj-bk Jan 5, 2026
9e6c56f
Made things compile with HIP.
vbharadwaj-bk Jan 5, 2026
bf4bb89
Updated READMEs.
vbharadwaj-bk Jan 5, 2026
abc409f
Highlight AMD support in changelog.
vbharadwaj-bk Jan 5, 2026
77ef6a9
Ruff.
vbharadwaj-bk Jan 5, 2026
0eeec58
Updated documentation.
vbharadwaj-bk Jan 6, 2026
90b3405
Updated installation instructions.
vbharadwaj-bk Jan 6, 2026
4acd545
More ruff.
vbharadwaj-bk Jan 6, 2026
6f61ee9
Added option for CI.
vbharadwaj-bk Jan 16, 2026
cfe9b67
Ready to go.
vbharadwaj-bk Jan 16, 2026
09d62c5
Enabled direct download XLA.
vbharadwaj-bk Jan 16, 2026
2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install sphinx furo
pip install -r docs/requirements.txt
- name: Build website
run: |
sphinx-build -M dirhtml docs docs/_build
54 changes: 47 additions & 7 deletions .github/workflows/release.yaml
@@ -6,7 +6,7 @@ on:
# ref: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/

jobs:
build:
build-oeq:
name: Build distribution
runs-on: ubuntu-latest
steps:
@@ -16,21 +16,20 @@ jobs:
with:
python-version: '3.10'
- name: install dependencies, then build source tarball
run: |
run: |
cd openequivariance
python3 -m pip install build --user
python3 -m build --sdist
- name: store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
path: openequivariance/dist/

pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
# build task to be completed first
needs: build
# Specifying a GitHub environment is optional, but strongly encouraged
needs: build-oeq
environment:
name: pypi
url: https://pypi.org/p/openequivariance
@@ -42,6 +41,47 @@ jobs:
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
path: openequivariance/dist/
- name: publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

# ------------------------------------

build-oeq-extjax:
name: Build distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: install dependencies, then build source tarball
run: |
cd openequivariance_extjax
python3 -m pip install build --user
python3 -m build --sdist
- name: store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: openequivariance_extjax/dist/

pypi-publish-extjax:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: build-oeq-extjax
environment:
name: pypi
url: https://pypi.org/p/openequivariance_extjax
permissions:
# IMPORTANT: this permission is mandatory for Trusted Publishing
id-token: write
steps:
- name: download the distributions
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: openequivariance_extjax/dist/
- name: publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
4 changes: 3 additions & 1 deletion .github/workflows/requirements_cuda_ci.txt
@@ -1,4 +1,6 @@
numpy==2.2.5
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
pytest==8.3.5
ninja==1.11.1.4
ninja==1.11.1.4
nanobind==2.10.2
scikit-build-core==0.11.6
12 changes: 8 additions & 4 deletions .github/workflows/verify_extension_build.yml
@@ -1,4 +1,4 @@
name: OEQ CUDA C++ Extension Build Verification
name: OEQ C++ Extension Build Verification

on:
push:
@@ -29,10 +29,14 @@ jobs:
sudo apt-get update
sudo apt install nvidia-cuda-toolkit
pip install -r .github/workflows/requirements_cuda_ci.txt
pip install -e .
pip install -e "./openequivariance"
- name: Test extension build via import
- name: Test CUDA extension build via import
run: |
pytest \
tests/import_test.py::test_extension_built \
tests/import_test.py::test_torch_extension_built
tests/import_test.py::test_torch_extension_built
- name: Test JAX extension build
run: |
XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation
1 change: 0 additions & 1 deletion .gitignore
@@ -38,7 +38,6 @@ triton_autotuning
paper_benchmarks
paper_benchmarks_v2
paper_benchmarks_v3
openequivariance/extlib/*.so

get_node.sh
*.egg-info
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,17 @@
## Latest Changes

### v0.5.0 (2025-12-25)
JAX support is now available in
OpenEquivariance for BOTH NVIDIA and
AMD GPUs! See the
[documentation](https://passionlab.github.io/OpenEquivariance/)
and README.md for instructions on installation
and usage.

Minor changes:
- Defer error reporting when CUDA is not available
until the first library use, rather than at library load.

### v0.4.1 (2025-09-04)
Minor update, fixes a bug loading JIT-compiled modules
with PyTorch 2.9.
10 changes: 0 additions & 10 deletions MANIFEST.in

This file was deleted.

74 changes: 66 additions & 8 deletions README.md
@@ -1,8 +1,9 @@
# OpenEquivariance
[![OEQ CUDA C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
[![OEQ C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
[![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)

[[Examples]](#show-me-some-examples)
[[PyTorch Examples]](#pytorch-examples)
[[JAX Examples]](#jax-examples)
[[Citation and Acknowledgements]](#citation-and-acknowledgements)

OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
@@ -12,8 +13,8 @@ that [e3nn](https://e3nn.org/) supports
commonly found in graph neural networks
(e.g. [Nequip](https://github.com/mir-group/nequip) or
[MACE](https://github.com/ACEsuit/mace)). To get
started, ensure that you have GCC 9+ on your system
and install our package via
started with PyTorch, ensure that you have PyTorch
and GCC 9+ available before installing our package via

```bash
pip install openequivariance
@@ -29,11 +30,26 @@ computation and memory consumption significantly.
For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
check out the [documentation](https://passionlab.github.io/OpenEquivariance).

📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025).
⭐️ **JAX**: Our latest update brings
support for JAX. For NVIDIA GPUs,
install it (after installing JAX)
with the following two commands, strictly in this order:

## Show me some examples
``` bash
pip install openequivariance[jax]
pip install openequivariance_extjax --no-build-isolation
```

For AMD GPUs:
``` bash
pip install openequivariance[jax]
JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation
```

See the section below for example usage and
our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details.
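
If you want to sanity-check the install before running a full example, a minimal smoke test might look like the following (a sketch, not from the package docs; it assumes only the `oeq.jax` entry point used in the examples below):

```python
import os

# Skip the PyTorch import path in the core package before importing it.
os.environ["OEQ_NOTORCH"] = "1"

import openequivariance as oeq

# If this prints the class without raising, the JAX extension is importable.
print(oeq.jax.TensorProductConv)
```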

## PyTorch Examples
Here's a CG tensor product implemented by e3nn:

```python
@@ -127,6 +143,48 @@ print(torch.norm(Z))
`deterministic=False`, the `sender` and `receiver` indices can have
arbitrary order.

## JAX Examples
After installation, use the library
as follows. Set `OEQ_NOTORCH=1`
in your environment to avoid the PyTorch import in
the regular `openequivariance` package.
```python
import jax
import os

os.environ["OEQ_NOTORCH"] = "1"
import openequivariance as oeq

seed = 42
key = jax.random.PRNGKey(seed)

batch_size = 1000
X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False)


node_ct, nonzero_ct = 3, 4
edge_index = jax.numpy.array(
[
[0, 1, 1, 2],
[1, 0, 2, 1],
],
dtype=jax.numpy.int32, # NOTE: This is int32, not int64
)

X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim),
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel),
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)

tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
Z = tp_conv.forward(
X, Y, W, edge_index[0], edge_index[1]
)
print(jax.numpy.linalg.norm(Z))
```
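
Because this PR registers VJP rules for the backward and double-backward passes, the forward call above composes with `jax.grad`. A minimal sketch continuing the example (the loss function and its argument layout are illustrative, not part of the library API):

```python
# Differentiate a scalar loss with respect to the node features, edge features,
# and weights. Relies on the VJP rules registered by the extension.
def loss(X, Y, W):
    Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
    return jax.numpy.sum(Z**2)

gX, gY, gW = jax.grad(loss, argnums=(0, 1, 2))(X, Y, W)
print(gX.shape, gY.shape, gW.shape)
```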

## Citation and Acknowledgements
If you find this code useful, please cite our paper:

38 changes: 33 additions & 5 deletions docs/api.rst
@@ -8,7 +8,7 @@ OpenEquivariance API
OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces
``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses
the CG tensor product with a subsequent graph convolution. Initializing either class triggers
JIT compilation of a custom kernel, which can take a few seconds.
JIT compilation of a custom kernel, which can take a few seconds.

Both classes require a configuration object specified
by :py:class:`openequivariance.TPProblem`, which has a constructor
@@ -17,6 +17,9 @@ We recommend reading the `e3nn documentation <https://docs.e3nn.org/en/latest/>`
trying our code. OpenEquivariance cannot accelerate all tensor products; see
:doc:`this page </supported_ops>` for a list of supported configurations.

PyTorch API
------------------------

.. autoclass:: openequivariance.TensorProduct
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
:undoc-members:
@@ -27,14 +30,39 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
:undoc-members:
:exclude-members: name

.. autoclass:: openequivariance.TPProblem
:members:
:undoc-members:

.. autofunction:: openequivariance.torch_to_oeq_dtype

.. autofunction:: openequivariance.torch_ext_so_path

JAX API
------------------------
The JAX API consists of ``TensorProduct`` and ``TensorProductConv``
classes that behave identically to their PyTorch counterparts. These classes
do not conform exactly to the e3nn-jax API, but perform the same computation.

If you plan to use ``oeq.jax`` without PyTorch installed,
you need to set ``OEQ_NOTORCH=1`` in your local environment (within Python,
``os.environ["OEQ_NOTORCH"] = "1"``). For the moment, we require this to avoid
breaking the PyTorch version of OpenEquivariance.


.. autoclass:: openequivariance.jax.TensorProduct
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
:undoc-members:
:exclude-members:

.. autoclass:: openequivariance.jax.TensorProductConv
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
:undoc-members:
:exclude-members:

Common API
---------------------

.. autoclass:: openequivariance.TPProblem
:members:
:undoc-members:

API Identical to e3nn
---------------------

18 changes: 12 additions & 6 deletions docs/conf.py
@@ -28,11 +28,17 @@
html_theme = "furo"
# html_static_path = ["_static"]

extensions = [
"sphinx.ext.autodoc",
extensions = ["sphinx.ext.autodoc", "sphinx_inline_tabs"]

sys.path.insert(0, str(Path("../openequivariance").resolve()))

autodoc_mock_imports = [
"torch",
"jax",
"openequivariance._torch.extlib",
"openequivariance.jax.extlib",
"openequivariance_extjax",
"jinja2",
"numpy",
]

sys.path.insert(0, str(Path("..").resolve()))

autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"]
autodoc_typehints = "description"