Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build-torch-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
pytorch-version: ["2.3", "2.4", "2.5", "2.6", "2.7", "2.8", "2.9", "2.10", "2.11"]
pytorch-version: ["2.4", "2.5", "2.6", "2.7", "2.8", "2.9", "2.10", "2.11"]
os: [ubuntu-22.04, macos-14, ubuntu-22.04-arm]
include:
- os: ubuntu-22.04
Expand Down
2 changes: 1 addition & 1 deletion sphericart-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ else()
endif()
endif()

find_package(Torch 2.1 REQUIRED)
find_package(Torch 2.4 REQUIRED)

file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'")

Expand Down
11 changes: 2 additions & 9 deletions sphericart-torch/include/sphericart/torch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@

#include <torch/torch.h>

#include <mutex>

#include "sphericart.hpp"
#include "sphericart_cuda.hpp"

namespace sphericart_torch {

class SphericartAutograd;
class SphericartAutogradBackward;

class SphericalHarmonics : public torch::CustomClassHolder {
class SphericalHarmonics {
public:
SphericalHarmonics(int64_t l_max, bool backward_second_derivatives = false);

Expand All @@ -22,8 +19,6 @@ class SphericalHarmonics : public torch::CustomClassHolder {
std::vector<torch::Tensor> compute_with_gradients(torch::Tensor xyz);
std::vector<torch::Tensor> compute_with_hessians(torch::Tensor xyz);

int64_t get_l_max() const { return this->l_max_; }
bool get_backward_second_derivative_flag() const { return this->backward_second_derivatives_; }
int64_t get_omp_num_threads() const { return this->omp_num_threads_; }

private:
Expand All @@ -49,7 +44,7 @@ class SphericalHarmonics : public torch::CustomClassHolder {
std::unique_ptr<sphericart::cuda::SphericalHarmonics<float>> calculator_cuda_float_ptr;
};

class SolidHarmonics : public torch::CustomClassHolder {
class SolidHarmonics {
public:
SolidHarmonics(int64_t l_max, bool backward_second_derivatives = false);

Expand All @@ -58,8 +53,6 @@ class SolidHarmonics : public torch::CustomClassHolder {
std::vector<torch::Tensor> compute_with_gradients(torch::Tensor xyz);
std::vector<torch::Tensor> compute_with_hessians(torch::Tensor xyz);

int64_t get_l_max() const { return this->l_max_; }
bool get_backward_second_derivative_flag() const { return this->backward_second_derivatives_; }
int64_t get_omp_num_threads() const { return this->omp_num_threads_; }

private:
Expand Down
111 changes: 109 additions & 2 deletions sphericart-torch/python/sphericart/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,112 @@ def _lib_path():
)


# load the C++ operators and custom classes
torch.classes.load_library(_lib_path())
# load the C++ operators
torch.ops.load_library(_lib_path())


def _sph_size(xyz, l_max):
return (xyz.shape[0], (l_max + 1) ** 2)


def _xyz_grad_from_dsph(grad_sph, dsph):
return torch.sum(grad_sph.unsqueeze(1) * dsph, dim=2)


def _xyz_grad_from_ddsph(grad_dsph, ddsph):
return torch.sum(grad_dsph.unsqueeze(2) * ddsph, dim=(1, 3))


def _fake_compute(xyz, l_max):
return xyz.new_empty(_sph_size(xyz, l_max))


def _fake_compute_with_gradients(xyz, l_max):
sph = xyz.new_empty(_sph_size(xyz, l_max))
dsph = xyz.new_empty((xyz.shape[0], 3, sph.shape[1]))
return sph, dsph


def _fake_compute_with_hessians(xyz, l_max):
sph, dsph = _fake_compute_with_gradients(xyz, l_max)
ddsph = xyz.new_empty((xyz.shape[0], 3, 3, sph.shape[1]))
return sph, dsph, ddsph


def _setup_fake_impls(prefix):
    """Register fake (meta) kernels for the three custom ops of one
    harmonics family (``prefix`` is ``spherical_harmonics`` or
    ``solid_harmonics``) so FakeTensor tracing / torch.compile can infer
    output shapes without running the real kernels."""

    def fake_plain(xyz, l_max, backward_second_derivatives=False):
        # The flag only affects backward behavior, not output shape.
        return _fake_compute(xyz, l_max)

    def fake_with_gradients(xyz, l_max):
        return _fake_compute_with_gradients(xyz, l_max)

    def fake_with_hessians(xyz, l_max):
        return _fake_compute_with_hessians(xyz, l_max)

    torch.library.register_fake(f"sphericart_torch::{prefix}", fake_plain)
    torch.library.register_fake(
        f"sphericart_torch::{prefix}_with_gradients", fake_with_gradients
    )
    torch.library.register_fake(
        f"sphericart_torch::{prefix}_with_hessians", fake_with_hessians
    )


def _setup_compute_context(ctx, inputs, output):
xyz, l_max, backward_second_derivatives = inputs
ctx.save_for_backward(xyz)
ctx.l_max = l_max
ctx.backward_second_derivatives = backward_second_derivatives


def _setup_gradient_context(ctx, inputs, output):
xyz, l_max = inputs
_, dsph = output
ctx.save_for_backward(xyz, dsph)
ctx.l_max = l_max


def _setup_autograd(prefix):
    # Register backward formulas for one harmonics family (`prefix` is
    # "spherical_harmonics" or "solid_harmonics"). The plain op and its
    # *_with_gradients variant each get an autograd rule; the hessian
    # variant is only ever called inside these backwards, under no_grad.
    gradients_op = getattr(torch.ops.sphericart_torch, f"{prefix}_with_gradients")
    hessians_op = getattr(torch.ops.sphericart_torch, f"{prefix}_with_hessians")

    def _compute_backward(ctx, grad_sph):
        # Backward of the plain compute op. Inputs were
        # (xyz, l_max, backward_second_derivatives), so three gradients
        # are returned; only xyz is differentiable.
        (xyz,) = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None, None, None

        # Recompute dsph here rather than saving it in forward. Running
        # under enable_grad when second derivatives were requested lets
        # autograd record this call, making double-backward possible.
        grad_context = (
            torch.enable_grad() if ctx.backward_second_derivatives else torch.no_grad()
        )
        with grad_context:
            _, dsph = gradients_op(xyz, ctx.l_max)

        return _xyz_grad_from_dsph(grad_sph, dsph), None, None

    def _gradients_backward(ctx, grad_sph, grad_dsph):
        # Backward of the *_with_gradients op, which has two outputs
        # (sph, dsph) and inputs (xyz, l_max) — hence two incoming
        # gradients and two returned gradients.
        xyz, dsph = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None, None

        xyz_grad = None
        if grad_sph is not None:
            # Contribution from the sph output, via the saved dsph.
            xyz_grad = _xyz_grad_from_dsph(grad_sph, dsph)

        if grad_dsph is not None:
            # Contribution from the dsph output requires the second
            # derivatives; compute them lazily and outside autograd.
            with torch.no_grad():
                _, _, ddsph = hessians_op(xyz, ctx.l_max)
            ddsph_grad = _xyz_grad_from_ddsph(grad_dsph, ddsph)
            xyz_grad = ddsph_grad if xyz_grad is None else xyz_grad + ddsph_grad

        return xyz_grad, None

    torch.library.register_autograd(
        f"sphericart_torch::{prefix}",
        _compute_backward,
        setup_context=_setup_compute_context,
    )
    torch.library.register_autograd(
        f"sphericart_torch::{prefix}_with_gradients",
        _gradients_backward,
        setup_context=_setup_gradient_context,
    )


# Install fake kernels and autograd formulas for both harmonics families
# exposed by the C++ extension.
for _prefix in ["spherical_harmonics", "solid_harmonics"]:
    _setup_fake_impls(_prefix)
    _setup_autograd(_prefix)
48 changes: 30 additions & 18 deletions sphericart-torch/python/sphericart/torch/spherical_hamonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ def __init__(
backward_second_derivatives: bool = False,
):
super().__init__()
self.calculator = torch.classes.sphericart_torch.SphericalHarmonics(
l_max, backward_second_derivatives
)
self._l_max = l_max
self._backward_second_derivatives = backward_second_derivatives

def forward(self, xyz: Tensor) -> Tensor:
"""
Expand All @@ -103,11 +102,13 @@ def forward(self, xyz: Tensor) -> Tensor:
spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1,
1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order.
"""
return self.calculator.compute(xyz)
return torch.ops.sphericart_torch.spherical_harmonics(
xyz, self._l_max, self._backward_second_derivatives
)

def compute(self, xyz: Tensor) -> Tensor:
"""Equivalent to ``forward``"""
return self.calculator.compute(xyz)
return self.forward(xyz)

def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]:
"""
Expand Down Expand Up @@ -139,7 +140,9 @@ def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]:
derivatives in the the x, y, and z directions, respectively.

"""
return self.calculator.compute_with_gradients(xyz)
return torch.ops.sphericart_torch.spherical_harmonics_with_gradients(
xyz, self._l_max
)

def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Expand Down Expand Up @@ -176,15 +179,19 @@ def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
hessian dimensions.

"""
return self.calculator.compute_with_hessians(xyz)
return torch.ops.sphericart_torch.spherical_harmonics_with_hessians(
xyz, self._l_max
)

def omp_num_threads(self):
"""Returns the number of threads available for calculations on the CPU."""
return self.calculator.omp_num_threads()
return torch.ops.sphericart_torch.spherical_harmonics_omp_num_threads(
self._l_max
)

def l_max(self):
"""Returns the maximum angular momentum setting for this calculator."""
return self.calculator.l_max()
return self._l_max


class SolidHarmonics(torch.nn.Module):
Expand Down Expand Up @@ -216,30 +223,35 @@ def __init__(
backward_second_derivatives: bool = False,
):
super().__init__()
self.calculator = torch.classes.sphericart_torch.SolidHarmonics(
l_max, backward_second_derivatives
)
self._l_max = l_max
self._backward_second_derivatives = backward_second_derivatives

def forward(self, xyz: Tensor) -> Tensor:
"""See :py:meth:`SphericalHarmonics.forward`"""
return self.calculator.compute(xyz)
return torch.ops.sphericart_torch.solid_harmonics(
xyz, self._l_max, self._backward_second_derivatives
)

def compute(self, xyz: Tensor) -> Tensor:
"""Equivalent to ``forward``"""
return self.calculator.compute(xyz)
return self.forward(xyz)

def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]:
"""See :py:meth:`SphericalHarmonics.compute_with_gradients`"""
return self.calculator.compute_with_gradients(xyz)
return torch.ops.sphericart_torch.solid_harmonics_with_gradients(
xyz, self._l_max
)

def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""See :py:meth:`SphericalHarmonics.compute_with_hessians`"""
return self.calculator.compute_with_hessians(xyz)
return torch.ops.sphericart_torch.solid_harmonics_with_hessians(
xyz, self._l_max
)

def omp_num_threads(self):
"""Returns the number of threads available for calculations on the CPU."""
return self.calculator.omp_num_threads()
return torch.ops.sphericart_torch.solid_harmonics_omp_num_threads(self._l_max)

def l_max(self):
"""Returns the maximum angular momentum setting for this calculator."""
return self.calculator.l_max()
return self._l_max
29 changes: 29 additions & 0 deletions sphericart-torch/python/tests/test_script.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
import torch._dynamo as dynamo

import sphericart.torch

Expand Down Expand Up @@ -39,3 +40,31 @@ def test_script(xyz, normalized):
script = torch.jit.script(module)
sh_script = script.forward(xyz_jit)
sh_script.sum().backward()


@pytest.mark.parametrize("normalized", [True, False])
def test_compile(xyz, normalized):
    """A torch.compile'd module must reproduce eager outputs and gradients."""
    inputs_ref = xyz.detach().clone().requires_grad_()
    inputs_cmp = xyz.detach().clone().requires_grad_()
    module = SHModule(l_max=10, normalized=normalized)

    # Eager reference pass.
    ref_out = module(inputs_ref)
    ref_out.sum().backward()
    ref_grad = inputs_ref.grad.detach().clone()

    # Compiled pass; fullgraph=True makes any graph break a hard error.
    cmp_out = torch.compile(module, fullgraph=True)(inputs_cmp)
    cmp_out.sum().backward()

    assert torch.allclose(cmp_out, ref_out)
    assert torch.allclose(inputs_cmp.grad, ref_grad)


@pytest.mark.parametrize("normalized", [True, False])
def test_compile_has_no_graph_breaks(xyz, normalized):
    # Tracing through dynamo must yield exactly one graph with no breaks:
    # a break would mean the custom op registration falls back to eager
    # somewhere, silently defeating fullgraph compilation.
    module = SHModule(l_max=10, normalized=normalized)
    explanation = dynamo.explain(module)(xyz.detach())

    assert explanation.graph_count == 1
    assert explanation.graph_break_count == 0
    assert explanation.break_reasons == []
2 changes: 1 addition & 1 deletion sphericart-torch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def run(self):
torch_version = f"== {torch_v_major}.{torch_v_minor}.*"
except ImportError:
# otherwise we are building a sdist
torch_version = ">= 2.1"
torch_version = ">= 2.4"

install_requires = [f"torch {torch_version}"]

Expand Down
2 changes: 0 additions & 2 deletions sphericart-torch/src/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ class CUDAStream {
return instance;
}

bool loaded() { return handle != nullptr; }

using get_stream_t = void* (*)(uint8_t);
get_stream_t get_stream = nullptr;

Expand Down
Loading
Loading