From e96ca85bc09d79add1be6ad16520cdfafbbdd3f3 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 28 Mar 2026 07:37:47 +0100 Subject: [PATCH 1/3] Add support for `torch.compile` --- sphericart-torch/include/sphericart/torch.hpp | 11 +- .../python/sphericart/torch/__init__.py | 111 +++++++++++++- .../sphericart/torch/spherical_hamonics.py | 48 +++--- sphericart-torch/python/tests/test_script.py | 29 ++++ sphericart-torch/src/autograd.cpp | 2 - sphericart-torch/src/torch.cpp | 145 ++++++++++-------- 6 files changed, 255 insertions(+), 91 deletions(-) diff --git a/sphericart-torch/include/sphericart/torch.hpp b/sphericart-torch/include/sphericart/torch.hpp index 0e5e04183..9db5d51d3 100644 --- a/sphericart-torch/include/sphericart/torch.hpp +++ b/sphericart-torch/include/sphericart/torch.hpp @@ -3,17 +3,14 @@ #include -#include - #include "sphericart.hpp" #include "sphericart_cuda.hpp" namespace sphericart_torch { class SphericartAutograd; -class SphericartAutogradBackward; -class SphericalHarmonics : public torch::CustomClassHolder { +class SphericalHarmonics { public: SphericalHarmonics(int64_t l_max, bool backward_second_derivatives = false); @@ -22,8 +19,6 @@ class SphericalHarmonics : public torch::CustomClassHolder { std::vector compute_with_gradients(torch::Tensor xyz); std::vector compute_with_hessians(torch::Tensor xyz); - int64_t get_l_max() const { return this->l_max_; } - bool get_backward_second_derivative_flag() const { return this->backward_second_derivatives_; } int64_t get_omp_num_threads() const { return this->omp_num_threads_; } private: @@ -49,7 +44,7 @@ class SphericalHarmonics : public torch::CustomClassHolder { std::unique_ptr> calculator_cuda_float_ptr; }; -class SolidHarmonics : public torch::CustomClassHolder { +class SolidHarmonics { public: SolidHarmonics(int64_t l_max, bool backward_second_derivatives = false); @@ -58,8 +53,6 @@ class SolidHarmonics : public torch::CustomClassHolder { std::vector compute_with_gradients(torch::Tensor xyz); std::vector compute_with_hessians(torch::Tensor xyz); - int64_t get_l_max() const { return this->l_max_; } - bool get_backward_second_derivative_flag() const { return this->backward_second_derivatives_; } int64_t get_omp_num_threads() const { return this->omp_num_threads_; } private: diff --git a/sphericart-torch/python/sphericart/torch/__init__.py b/sphericart-torch/python/sphericart/torch/__init__.py index a045d7353..ad42b9b66 100644 --- a/sphericart-torch/python/sphericart/torch/__init__.py +++ b/sphericart-torch/python/sphericart/torch/__init__.py @@ -68,5 +68,112 @@ def _lib_path(): ) -# load the C++ operators and custom classes -torch.classes.load_library(_lib_path()) +# load the C++ operators +torch.ops.load_library(_lib_path()) + + +def _sph_size(xyz, l_max): + return (xyz.shape[0], (l_max + 1) ** 2) + + +def _xyz_grad_from_dsph(grad_sph, dsph): + return torch.sum(grad_sph.unsqueeze(1) * dsph, dim=2) + + +def _xyz_grad_from_ddsph(grad_dsph, ddsph): + return torch.sum(grad_dsph.unsqueeze(2) * ddsph, dim=(1, 3)) + + +def _fake_compute(xyz, l_max): + return xyz.new_empty(_sph_size(xyz, l_max)) + + +def _fake_compute_with_gradients(xyz, l_max): + sph = xyz.new_empty(_sph_size(xyz, l_max)) + dsph = xyz.new_empty((xyz.shape[0], 3, sph.shape[1])) + return sph, dsph + + +def _fake_compute_with_hessians(xyz, l_max): + sph, dsph = _fake_compute_with_gradients(xyz, l_max) + ddsph = xyz.new_empty((xyz.shape[0], 3, 3, sph.shape[1])) + return sph, dsph, ddsph + + +def _setup_fake_impls(prefix): + @torch.library.register_fake(f"sphericart_torch::{prefix}") + def _fake_compute_op(xyz, l_max, backward_second_derivatives=False): + return _fake_compute(xyz, l_max) + + @torch.library.register_fake(f"sphericart_torch::{prefix}_with_gradients") + def _fake_compute_with_gradients_op(xyz, l_max): + return _fake_compute_with_gradients(xyz, l_max) + + @torch.library.register_fake(f"sphericart_torch::{prefix}_with_hessians") + def _fake_compute_with_hessians_op(xyz, l_max): + return _fake_compute_with_hessians(xyz, l_max) + + +def _setup_compute_context(ctx, inputs, output): + xyz, l_max, backward_second_derivatives = inputs + ctx.save_for_backward(xyz) + ctx.l_max = l_max + ctx.backward_second_derivatives = backward_second_derivatives + + +def _setup_gradient_context(ctx, inputs, output): + xyz, l_max = inputs + _, dsph = output + ctx.save_for_backward(xyz, dsph) + ctx.l_max = l_max + + +def _setup_autograd(prefix): + gradients_op = getattr(torch.ops.sphericart_torch, f"{prefix}_with_gradients") + hessians_op = getattr(torch.ops.sphericart_torch, f"{prefix}_with_hessians") + + def _compute_backward(ctx, grad_sph): + (xyz,) = ctx.saved_tensors + if not ctx.needs_input_grad[0]: + return None, None, None + + grad_context = ( + torch.enable_grad() if ctx.backward_second_derivatives else torch.no_grad() + ) + with grad_context: + _, dsph = gradients_op(xyz, ctx.l_max) + + return _xyz_grad_from_dsph(grad_sph, dsph), None, None + + def _gradients_backward(ctx, grad_sph, grad_dsph): + xyz, dsph = ctx.saved_tensors + if not ctx.needs_input_grad[0]: + return None, None + + xyz_grad = None + if grad_sph is not None: + xyz_grad = _xyz_grad_from_dsph(grad_sph, dsph) + + if grad_dsph is not None: + with torch.no_grad(): + _, _, ddsph = hessians_op(xyz, ctx.l_max) + ddsph_grad = _xyz_grad_from_ddsph(grad_dsph, ddsph) + xyz_grad = ddsph_grad if xyz_grad is None else xyz_grad + ddsph_grad + + return xyz_grad, None + + torch.library.register_autograd( + f"sphericart_torch::{prefix}", + _compute_backward, + setup_context=_setup_compute_context, + ) + torch.library.register_autograd( + f"sphericart_torch::{prefix}_with_gradients", + _gradients_backward, + setup_context=_setup_gradient_context, + ) + + +for _prefix in ("spherical_harmonics", "solid_harmonics"): + _setup_fake_impls(_prefix) + _setup_autograd(_prefix) diff --git a/sphericart-torch/python/sphericart/torch/spherical_hamonics.py b/sphericart-torch/python/sphericart/torch/spherical_hamonics.py index 9ca57bba6..9ee79d4ca 100644 --- a/sphericart-torch/python/sphericart/torch/spherical_hamonics.py +++ b/sphericart-torch/python/sphericart/torch/spherical_hamonics.py @@ -74,9 +74,8 @@ def __init__( backward_second_derivatives: bool = False, ): super().__init__() - self.calculator = torch.classes.sphericart_torch.SphericalHarmonics( - l_max, backward_second_derivatives - ) + self._l_max = l_max + self._backward_second_derivatives = backward_second_derivatives def forward(self, xyz: Tensor) -> Tensor: """ @@ -103,11 +102,13 @@ def forward(self, xyz: Tensor) -> Tensor: spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. """ - return self.calculator.compute(xyz) + return torch.ops.sphericart_torch.spherical_harmonics( + xyz, self._l_max, self._backward_second_derivatives + ) def compute(self, xyz: Tensor) -> Tensor: """Equivalent to ``forward``""" - return self.calculator.compute(xyz) + return self.forward(xyz) def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: """ @@ -139,7 +140,9 @@ def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: derivatives in the the x, y, and z directions, respectively. """ - return self.calculator.compute_with_gradients(xyz) + return torch.ops.sphericart_torch.spherical_harmonics_with_gradients( + xyz, self._l_max + ) def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ @@ -176,15 +179,19 @@ def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: hessian dimensions. """ - return self.calculator.compute_with_hessians(xyz) + return torch.ops.sphericart_torch.spherical_harmonics_with_hessians( + xyz, self._l_max + ) def omp_num_threads(self): """Returns the number of threads available for calculations on the CPU.""" - return self.calculator.omp_num_threads() + return torch.ops.sphericart_torch.spherical_harmonics_omp_num_threads( + self._l_max + ) def l_max(self): """Returns the maximum angular momentum setting for this calculator.""" - return self.calculator.l_max() + return self._l_max class SolidHarmonics(torch.nn.Module): @@ -216,30 +223,35 @@ def __init__( backward_second_derivatives: bool = False, ): super().__init__() - self.calculator = torch.classes.sphericart_torch.SolidHarmonics( - l_max, backward_second_derivatives - ) + self._l_max = l_max + self._backward_second_derivatives = backward_second_derivatives def forward(self, xyz: Tensor) -> Tensor: """See :py:meth:`SphericalHarmonics.forward`""" - return self.calculator.compute(xyz) + return torch.ops.sphericart_torch.solid_harmonics( + xyz, self._l_max, self._backward_second_derivatives + ) def compute(self, xyz: Tensor) -> Tensor: """Equivalent to ``forward``""" - return self.calculator.compute(xyz) + return self.forward(xyz) def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: """See :py:meth:`SphericalHarmonics.compute_with_gradients`""" - return self.calculator.compute_with_gradients(xyz) + return torch.ops.sphericart_torch.solid_harmonics_with_gradients( + xyz, self._l_max + ) def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """See :py:meth:`SphericalHarmonics.compute_with_hessians`""" - return self.calculator.compute_with_hessians(xyz) + return torch.ops.sphericart_torch.solid_harmonics_with_hessians( + xyz, self._l_max + ) def omp_num_threads(self): """Returns the number of threads available for calculations on the CPU.""" - return self.calculator.omp_num_threads() + return torch.ops.sphericart_torch.solid_harmonics_omp_num_threads(self._l_max) def l_max(self): """Returns the maximum angular momentum setting for this calculator.""" - return self.calculator.l_max() + return self._l_max diff --git a/sphericart-torch/python/tests/test_script.py b/sphericart-torch/python/tests/test_script.py index c4b7b9c73..69afa01c7 100644 --- a/sphericart-torch/python/tests/test_script.py +++ b/sphericart-torch/python/tests/test_script.py @@ -1,5 +1,6 @@ import pytest import torch +import torch._dynamo as dynamo import sphericart.torch @@ -39,3 +40,31 @@ def test_script(xyz, normalized): script = torch.jit.script(module) sh_script = script.forward(xyz_jit) sh_script.sum().backward() + + +@pytest.mark.parametrize("normalized", [True, False]) +def test_compile(xyz, normalized): + xyz_eager = xyz.detach().clone().requires_grad_() + xyz_compiled = xyz.detach().clone().requires_grad_() + module = SHModule(l_max=10, normalized=normalized) + + sh_eager = module(xyz_eager) + sh_eager.sum().backward() + eager_grad = xyz_eager.grad.detach().clone() + + compiled = torch.compile(module, fullgraph=True) + sh_compiled = compiled(xyz_compiled) + sh_compiled.sum().backward() + + assert torch.allclose(sh_compiled, sh_eager) + assert torch.allclose(xyz_compiled.grad, eager_grad) + + +@pytest.mark.parametrize("normalized", [True, False]) +def test_compile_has_no_graph_breaks(xyz, normalized): + module = SHModule(l_max=10, normalized=normalized) + explanation = dynamo.explain(module)(xyz.detach()) + + assert explanation.graph_count == 1 + assert explanation.graph_break_count == 0 + assert explanation.break_reasons == [] diff --git a/sphericart-torch/src/autograd.cpp b/sphericart-torch/src/autograd.cpp index c3632fa75..1b856990c 100644 --- a/sphericart-torch/src/autograd.cpp +++ b/sphericart-torch/src/autograd.cpp @@ -20,8 +20,6 @@ class CUDAStream { return instance; } - bool loaded() { return handle != nullptr; } - using get_stream_t = void* (*)(uint8_t); get_stream_t get_stream = nullptr; diff --git a/sphericart-torch/src/torch.cpp b/sphericart-torch/src/torch.cpp index 3314441a3..a6893908f 100644 --- a/sphericart-torch/src/torch.cpp +++ b/sphericart-torch/src/torch.cpp @@ -3,9 +3,62 @@ #include "sphericart/torch.hpp" #include "sphericart/autograd.hpp" -using namespace torch; +#include +#include +#include +#include + using namespace sphericart_torch; +namespace { + +template +using CacheMap = std::map, std::unique_ptr>; + +template +Calculator& get_or_create_calculator(int64_t l_max, bool backward_second_derivatives) { + static CacheMap cache; + static std::mutex cache_mutex; + + std::lock_guard lock(cache_mutex); + auto key = std::make_tuple(l_max, backward_second_derivatives); + auto it = cache.find(key); + if (it == cache.end()) { + it = + cache.insert({key, std::make_unique(l_max, backward_second_derivatives)}).first; + } + return *(it->second); +} + +template +torch::Tensor compute(torch::Tensor xyz, int64_t l_max, bool backward_second_derivatives) { + auto& calculator = get_or_create_calculator(l_max, backward_second_derivatives); + return calculator.compute(xyz); +} + +template +std::tuple compute_with_gradients(torch::Tensor xyz, int64_t l_max) { + auto& calculator = get_or_create_calculator(l_max, false); + auto result = calculator.compute_with_gradients(xyz); + return {result[0], result[1]}; +} + +template +std::tuple compute_with_hessians( + torch::Tensor xyz, int64_t l_max +) { + auto& calculator = get_or_create_calculator(l_max, false); + auto result = calculator.compute_with_hessians(xyz); + return {result[0], result[1], result[2]}; +} + +template int64_t omp_num_threads(int64_t l_max) { + auto& calculator = get_or_create_calculator(l_max, false); + return calculator.get_omp_num_threads(); +} + +} // namespace + SphericalHarmonics::SphericalHarmonics(int64_t l_max, bool backward_second_derivatives) : l_max_(l_max), backward_second_derivatives_(backward_second_derivatives), calculator_double_(l_max_), calculator_float_(l_max_) { @@ -59,63 +112,35 @@ std::vector SolidHarmonics::compute_with_hessians(torch::Tensor x } TORCH_LIBRARY(sphericart_torch, m) { - m.class_("SphericalHarmonics") - .def( - torch::init(), - "", - {torch::arg("l_max"), torch::arg("backward_second_derivatives") = false} - ) - .def("compute", &SphericalHarmonics::compute, "", {torch::arg("xyz")}) - .def( - "compute_with_gradients", - &SphericalHarmonics::compute_with_gradients, - "", - {torch::arg("xyz")} - ) - .def( - "compute_with_hessians", - &SphericalHarmonics::compute_with_hessians, - "", - {torch::arg("xyz")} - ) - .def("omp_num_threads", &SphericalHarmonics::get_omp_num_threads) - .def("l_max", &SphericalHarmonics::get_l_max) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& self) -> std::tuple { - return {self->get_l_max(), self->get_backward_second_derivative_flag()}; - }, - // __setstate__ - [](std::tuple state) -> c10::intrusive_ptr { - const auto l_max = std::get<0>(state); - const auto backward_second_derivatives = std::get<1>(state); - return c10::make_intrusive(l_max, backward_second_derivatives); - } - ); - - m.class_("SolidHarmonics") - .def( - torch::init(), - "", - {torch::arg("l_max"), torch::arg("backward_second_derivatives") = false} - ) - .def("compute", &SolidHarmonics::compute, "", {torch::arg("xyz")}) - .def( - "compute_with_gradients", &SolidHarmonics::compute_with_gradients, "", {torch::arg("xyz")} - ) - .def("compute_with_hessians", &SolidHarmonics::compute_with_hessians, "", {torch::arg("xyz")}) - .def("omp_num_threads", &SolidHarmonics::get_omp_num_threads) - .def("l_max", &SolidHarmonics::get_l_max) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& self) -> std::tuple { - return {self->get_l_max(), self->get_backward_second_derivative_flag()}; - }, - // __setstate__ - [](std::tuple state) -> c10::intrusive_ptr { - const auto l_max = std::get<0>(state); - const auto backward_second_derivatives = std::get<1>(state); - return c10::make_intrusive(l_max, backward_second_derivatives); - } - ); + m.def("spherical_harmonics(Tensor xyz, int l_max, bool backward_second_derivatives=False) -> Tensor"); + m.def("spherical_harmonics_with_gradients(Tensor xyz, int l_max) -> (Tensor, Tensor)"); + m.def("spherical_harmonics_with_hessians(Tensor xyz, int l_max) -> (Tensor, Tensor, Tensor)"); + m.def("solid_harmonics(Tensor xyz, int l_max, bool backward_second_derivatives=False) -> Tensor"); + m.def("solid_harmonics_with_gradients(Tensor xyz, int l_max) -> (Tensor, Tensor)"); + m.def("solid_harmonics_with_hessians(Tensor xyz, int l_max) -> (Tensor, Tensor, Tensor)"); + m.def("spherical_harmonics_omp_num_threads(int l_max) -> int"); + m.def("solid_harmonics_omp_num_threads(int l_max) -> int"); +} + +TORCH_LIBRARY_IMPL(sphericart_torch, CPU, m) { + m.impl("spherical_harmonics", &compute); + m.impl("spherical_harmonics_with_gradients", &compute_with_gradients); + m.impl("spherical_harmonics_with_hessians", &compute_with_hessians); + m.impl("solid_harmonics", &compute); + m.impl("solid_harmonics_with_gradients", &compute_with_gradients); + m.impl("solid_harmonics_with_hessians", &compute_with_hessians); +} + +TORCH_LIBRARY_IMPL(sphericart_torch, CUDA, m) { + m.impl("spherical_harmonics", &compute); + m.impl("spherical_harmonics_with_gradients", &compute_with_gradients); + m.impl("spherical_harmonics_with_hessians", &compute_with_hessians); + m.impl("solid_harmonics", &compute); + m.impl("solid_harmonics_with_gradients", &compute_with_gradients); + m.impl("solid_harmonics_with_hessians", &compute_with_hessians); +} + +TORCH_LIBRARY_IMPL(sphericart_torch, CompositeExplicitAutograd, m) { + m.impl("spherical_harmonics_omp_num_threads", &omp_num_threads); + m.impl("solid_harmonics_omp_num_threads", &omp_num_threads); } From fc83b2e816338048bb4c78945c1e3ead1a7a0f07 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 28 Mar 2026 07:42:43 +0100 Subject: [PATCH 2/3] Raise torch minimum to 2.4 --- sphericart-torch/CMakeLists.txt | 2 +- sphericart-torch/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sphericart-torch/CMakeLists.txt b/sphericart-torch/CMakeLists.txt index edf702439..37683cb66 100644 --- a/sphericart-torch/CMakeLists.txt +++ b/sphericart-torch/CMakeLists.txt @@ -70,7 +70,7 @@ else() endif() endif() -find_package(Torch 2.1 REQUIRED) +find_package(Torch 2.4 REQUIRED) file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'") diff --git a/sphericart-torch/setup.py b/sphericart-torch/setup.py index 3ecc6c2ce..9df4c078c 100644 --- a/sphericart-torch/setup.py +++ b/sphericart-torch/setup.py @@ -113,7 +113,7 @@ def run(self): torch_version = f"== {torch_v_major}.{torch_v_minor}.*" except ImportError: # otherwise we are building a sdist - torch_version = ">= 2.1" + torch_version = ">= 2.4" install_requires = [f"torch {torch_version}"] From fbe5adfab61f30c5f14413184f60b51c04884c56 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 28 Mar 2026 07:46:23 +0100 Subject: [PATCH 3/3] Fix minimum version in tox --- .github/workflows/build-torch-wheels.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-torch-wheels.yml b/.github/workflows/build-torch-wheels.yml index 1cdc4d1cf..763b29e12 100644 --- a/.github/workflows/build-torch-wheels.yml +++ b/.github/workflows/build-torch-wheels.yml @@ -51,7 +51,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - pytorch-version: ["2.3", "2.4", "2.5", "2.6", "2.7", "2.8", "2.9", "2.10", "2.11"] + pytorch-version: ["2.4", "2.5", "2.6", "2.7", "2.8", "2.9", "2.10", "2.11"] os: [ubuntu-22.04, macos-14, ubuntu-22.04-arm] include: - os: ubuntu-22.04 diff --git a/tox.ini b/tox.ini index 2157a23b8..c9e66418d 100644 --- a/tox.ini +++ b/tox.ini @@ -64,7 +64,7 @@ deps = {[testenv]packaging_deps} numpy - torch==2.3.0 + torch==2.4.0 pytest e3nn metatensor-torch