4 changes: 4 additions & 0 deletions README.md
@@ -183,6 +183,10 @@ Z = tp_conv.forward(
X, Y, W, edge_index[0], edge_index[1]
)
print(jax.numpy.linalg.norm(Z))

# Test JAX JIT
jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
```
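Since this PR wires custom VJPs through the FFI kernels, the same snippet should also be differentiable end to end. A minimal sketch (not part of the README; the loss below is made up for illustration, while `tp_conv`, `X`, `Y`, `W`, and `edge_index` come from the snippet above):

```python
# Sketch: differentiate through the fused conv kernel via the new custom VJPs.
def loss(X, W):
    Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
    return jax.numpy.sum(Z ** 2)

dX, dW = jax.grad(loss, argnums=(0, 1))(X, W)
print(jax.numpy.linalg.norm(dX), jax.numpy.linalg.norm(dW))
```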

## Citation and Acknowledgements
80 changes: 67 additions & 13 deletions openequivariance/openequivariance/jax/TensorProduct.py
@@ -1,4 +1,5 @@
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from openequivariance.jax import extlib
@@ -16,10 +17,18 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
return forward_call(X, Y, W, **attrs)


def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs):
return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)


def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ):
X, Y, W = inputs
return backward(X, Y, W, dZ, irrep_dtype, attrs)


forward.defvjp(forward_fwd, forward_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(4, 5))
def backward(X, Y, W, dZ, irrep_dtype, attrs):
backward_call = jax.ffi.ffi_call(
@@ -30,33 +39,78 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
jax.ShapeDtypeStruct(W.shape, irrep_dtype),
),
)

return backward_call(X, Y, W, dZ, **attrs)


def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs):
return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)


def double_backward(irrep_dtype, attrs, inputs, derivatives):
def backward_bwd(irrep_dtype, attrs, inputs, derivs):
X, Y, W, dZ = inputs
ddX, ddY, ddW = derivs
return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)


backward.defvjp(backward_fwd, backward_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(7, 8))
def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
double_backward_call = jax.ffi.ffi_call(
"tp_double_backward",
(
jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
jax.ShapeDtypeStruct(X.shape, irrep_dtype),
jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
jax.ShapeDtypeStruct(W.shape, irrep_dtype),
jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
),
)
return double_backward_call(*inputs, *derivatives, **attrs)
return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs)


def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
return out, (X, Y, W, dZ, ddX, ddY, ddW)


def zeros_like(x):
return jnp.zeros_like(x)


def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs):
X, Y, W, dZ, ddX, ddY, ddW = residuals
t_dX, t_dY, t_dW, t_ddZ = tangent_outputs

op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs)

op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs)

op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs)

op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs)

g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs)
g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs)
g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs)

grad_X = g2_X + g4_X + g6_X + g7_X
grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ

grad_ddX = g1_ddX + g3_ddX + g5_ddX
grad_ddY = g1_ddY + g4_ddY + g6_ddY
grad_ddW = g2_ddW + g7_ddW

def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW


forward.defvjp(forward_with_inputs, backward_autograd)
backward.defvjp(backward_with_inputs, double_backward)
double_backward.defvjp(double_backward_fwd, triple_backward)
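The chain above (forward → backward → double_backward → triple_backward) follows the usual JAX recipe for higher-order derivatives through opaque kernels: each backward rule is itself a `custom_vjp` primitive, so one more derivative is always available. A stripped-down toy version of the same pattern, with made-up names and no FFI, is sketched below:

```python
import jax

# Toy sketch of the defvjp-chaining pattern, applied to f(x, y) = x * y.
# `f` and `f_bwd_op` are illustrative names, not part of this PR.

@jax.custom_vjp
def f(x, y):
    return x * y

def f_fwd(x, y):
    return f(x, y), (x, y)

def f_bwd(res, dz):
    x, y = res
    return f_bwd_op(x, y, dz)  # the backward pass is itself a custom_vjp primitive

f.defvjp(f_fwd, f_bwd)

@jax.custom_vjp
def f_bwd_op(x, y, dz):
    return (dz * y, dz * x)  # cotangents (dX, dY)

def f_bwd_op_fwd(x, y, dz):
    return f_bwd_op(x, y, dz), (x, y, dz)

def f_bwd_op_bwd(res, cots):
    x, y, dz = res
    ddx, ddy = cots  # cotangents of (dX, dY): the "double backward" inputs
    return (dz * ddy, dz * ddx, ddx * y + ddy * x)

f_bwd_op.defvjp(f_bwd_op_fwd, f_bwd_op_bwd)

# Second derivative d/dy d/dx of x * y is 1:
print(jax.grad(lambda y: jax.grad(f, argnums=0)(2.0, y))(3.0))  # 1.0
```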


class TensorProduct(LoopUnrollTP):
173 changes: 112 additions & 61 deletions openequivariance/openequivariance/jax/TensorProductConv.py
@@ -1,3 +1,5 @@
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Optional
@@ -8,31 +10,44 @@
from openequivariance.core.utils import hash_attributes
from openequivariance.jax.utils import reorder_jax

import jax
import jax.numpy as jnp

from openequivariance.benchmark.logging_utils import getLogger

logger = getLogger()


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
def zeros_like(x):
return jnp.zeros_like(x)


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
forward_call = jax.ffi.ffi_call(
"conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
)
return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)


def forward_with_inputs(
def forward_fwd(
X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
):
return forward(
out = forward(
X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
), (X, Y, W, rows, cols, sender_perm, workspace)
)
return out, (X, Y, W, rows, cols)


def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ):
X, Y, W, rows, cols = res
dX, dY, dW = backward(
X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
)
return dX, dY, dW, None, None


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
forward.defvjp(forward_fwd, forward_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9))
def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
backward_call = jax.ffi.ffi_call(
"conv_backward",
@@ -45,65 +60,121 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)


def backward_with_inputs(
X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
):
return backward(
X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
), (X, Y, W, dZ) # rows, cols, sender_perm, workspace)
def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
return out, (X, Y, W, dZ, rows, cols)


def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives):
X, Y, W, dZ, rows, cols = res
ddX, ddY, ddW = derivatives

gX, gY, gW, gdZ = double_backward(
X,
Y,
W,
dZ,
ddX,
ddY,
ddW,
rows,
cols,
workspace,
sender_perm,
irrep_dtype,
attrs,
)

return gX, gY, gW, gdZ, None, None


backward.defvjp(backward_fwd, backward_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12))
def double_backward(
rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives
X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
):
double_backward_call = jax.ffi.ffi_call(
"conv_double_backward",
(
jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
jax.ShapeDtypeStruct(X.shape, irrep_dtype),
jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
jax.ShapeDtypeStruct(W.shape, irrep_dtype),
jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
),
)
return double_backward_call(
*inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs
X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs
)


def backward_autograd(
rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
def double_backward_fwd(
X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
):
return backward(
inputs[0],
inputs[1],
inputs[2],
out = double_backward(
X,
Y,
W,
dZ,
ddX,
ddY,
ddW,
rows,
cols,
workspace,
sender_perm,
irrep_dtype,
attrs,
)
return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols)


forward.defvjp(forward_with_inputs, backward_autograd)
backward.defvjp(backward_with_inputs, double_backward)
def triple_backward(
workspace,
sender_perm,
irrep_dtype,
attrs,
residuals,
tangent_outputs,
):
X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals
t_dX, t_dY, t_dW, t_ddZ = tangent_outputs

common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs)

class TensorProductConv(LoopUnrollConv):
r"""
Identical to ``oeq.torch.TensorProductConv``, but implemented in JAX, with one
key difference: integer arrays passed to this function must have dtype
``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).

:param problem: Specification of the tensor product.
:param deterministic: If ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
fixup-based algorithm. *Default*: ``False``.
:param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
"""
op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args)

op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args)

op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args)

op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args)

g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args)
g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args)
g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args)

grad_X = g2_X + g4_X + g6_X + g7_X
grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ

grad_ddX = g1_ddX + g3_ddX + g5_ddX
grad_ddY = g1_ddY + g4_ddY + g6_ddY
grad_ddW = g2_ddW + g7_ddW

return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None


double_backward.defvjp(double_backward_fwd, triple_backward)


class TensorProductConv(LoopUnrollConv):
def __init__(
self, config: TPProblem, deterministic: bool = False, kahan: bool = False
):
Expand All @@ -112,7 +183,7 @@ def __init__(
config,
dp,
extlib.postprocess_kernel,
idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version
idx_dtype=np.int32,
torch_op=False,
deterministic=deterministic,
kahan=kahan,
@@ -145,26 +216,6 @@ def forward(
cols: jax.numpy.ndarray,
sender_perm: Optional[jax.numpy.ndarray] = None,
) -> jax.numpy.ndarray:
r"""
Computes the fused CG tensor product + convolution.

:param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
:param Y: Tensor of shape ``[|E|, problem.irreps_in2.dim()]``, datatype ``problem.irrep_dtype``.
:param W: Tensor of datatype ``problem.weight_dtype`` and shape

* ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
* ``[problem.weight_numel]`` if ``problem.shared_weights=True``

:param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
:param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
datatype ``np.int32``.
:param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
Must be provided when ``deterministic=True``.

:return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
"""
if not self.deterministic:
sender_perm = self.dummy_transpose_perm
else:
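For the deterministic path described in the docstrings above, the edge arrays need some preparation before calling `forward`. A rough sketch of that setup, with `problem`, `X`, `Y`, `W`, and `edge_index` reused from the earlier snippets and the sorting/permutation conventions inferred from the docstring text rather than from this diff:

```python
# Sketch only: prepare inputs for deterministic=True as described above.
import numpy as np

rows, cols = np.asarray(edge_index[0]), np.asarray(edge_index[1])

order = np.lexsort((cols, rows))            # row-major order of the nonzeros
rows = rows[order].astype(np.int32)
cols = cols[order].astype(np.int32)
Y_s, W_s = Y[order], W[order]               # keep per-edge tensors aligned (assumes per-edge weights)

# Permutation taking the row-major nonzeros to column-major order.
sender_perm = np.lexsort((rows, cols)).astype(np.int32)

# kahan=True would additionally require float32 inputs (and deterministic=True).
tp_conv_det = oeq.jax.TensorProductConv(problem, deterministic=True)
Z = tp_conv_det.forward(X, Y_s, W_s, rows, cols, sender_perm)
```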
3 changes: 3 additions & 0 deletions tests/example_test.py
@@ -161,3 +161,6 @@ def test_tutorial_jax(with_jax):
tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
print(jax.numpy.linalg.norm(Z))

jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
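A hypothetical follow-up check, not part of this PR, that would exercise the second-order path added by the new backward/double_backward VJPs, reusing `tp_conv`, `X`, `Y`, `W`, and `edge_index` from the test above:

```python
# Hypothetical: differentiate the gradient norm to hit the double-backward path.
def grad_norm(X_):
    g = jax.grad(
        lambda Xi: jax.numpy.sum(
            tp_conv.forward(Xi, Y, W, edge_index[0], edge_index[1]) ** 2
        )
    )(X_)
    return jax.numpy.sum(g ** 2)

print(jax.numpy.linalg.norm(jax.grad(grad_norm)(X)))
```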