diff --git a/README.md b/README.md
index c68e4fa9..3efe59a7 100644
--- a/README.md
+++ b/README.md
@@ -183,6 +183,10 @@ Z = tp_conv.forward(
     X, Y, W, edge_index[0], edge_index[1]
 )
 print(jax.numpy.linalg.norm(Z))
+
+# Test JAX JIT
+jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
+print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
 ```

 ## Citation and Acknowledgements
diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py
index 452e7bb7..05e4b097 100644
--- a/openequivariance/openequivariance/jax/TensorProduct.py
+++ b/openequivariance/openequivariance/jax/TensorProduct.py
@@ -1,4 +1,5 @@
 import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from openequivariance.jax import extlib
@@ -16,10 +17,18 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward_call(X, Y, W, **attrs)


-def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
+def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)


+def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ):
+    X, Y, W = inputs
+    return backward(X, Y, W, dZ, irrep_dtype, attrs)
+
+
+forward.defvjp(forward_fwd, forward_bwd)
+
+
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5))
 def backward(X, Y, W, dZ, irrep_dtype, attrs):
     backward_call = jax.ffi.ffi_call(
@@ -30,33 +39,78 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
             jax.ShapeDtypeStruct(W.shape, irrep_dtype),
         ),
     )
-
     return backward_call(X, Y, W, dZ, **attrs)


-def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
+def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs):
     return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)


-def double_backward(irrep_dtype, attrs, inputs, derivatives):
+def backward_bwd(irrep_dtype, attrs, inputs, derivs):
+    X, Y, W, dZ = inputs
+    ddX, ddY, ddW = derivs
+    return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+
+
+backward.defvjp(backward_fwd, backward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8))
+def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
     double_backward_call = jax.ffi.ffi_call(
         "tp_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
         ),
     )
-    return double_backward_call(*inputs, *derivatives, **attrs)
+    return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs)
+
+
+def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
+    out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+    return out, (X, Y, W, dZ, ddX, ddY, ddW)
+
+
+def zeros_like(x):
+    return jnp.zeros_like(x)
+
+
+def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs):
+    X, Y, W, dZ, ddX, ddY, ddW = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
+
+    op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+    g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs)
+
+    op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+    g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs)
+
+    op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+    g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs)
+
+    op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+    g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs)
+
+    g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs)
+    g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs)
+    g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs)
+
+    grad_X = g2_X + g4_X + g6_X + g7_X
+    grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+    grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+    grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
+    grad_ddX = g1_ddX + g3_ddX + g5_ddX
+    grad_ddY = g1_ddY + g4_ddY + g6_ddY
+    grad_ddW = g2_ddW + g7_ddW


-def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
-    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+    return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW


-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+double_backward.defvjp(double_backward_fwd, triple_backward)


 class TensorProduct(LoopUnrollTP):
diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
index 3aaee28a..7439cd4e 100644
--- a/openequivariance/openequivariance/jax/TensorProductConv.py
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -1,3 +1,5 @@
+import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from typing import Optional
@@ -8,15 +10,16 @@ from openequivariance.core.utils import hash_attributes
 from openequivariance.jax.utils import reorder_jax

-import jax
-import jax.numpy as jnp
-
 from openequivariance.benchmark.logging_utils import getLogger

 logger = getLogger()


-@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
+def zeros_like(x):
+    return jnp.zeros_like(x)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
 def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
         "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
     )
@@ -24,15 +27,27 @@ def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, at
     return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)


-def forward_with_inputs(
+def forward_fwd(
     X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
 ):
-    return forward(
+    out = forward(
         X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
-    ), (X, Y, W, rows, cols, sender_perm, workspace)
+    )
+    return out, (X, Y, W, rows, cols)
+
+
+def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ):
+    X, Y, W, rows, cols = res
+    dX, dY, dW = backward(
+        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+    )
+    return dX, dY, dW, None, None


-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
+forward.defvjp(forward_fwd, forward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9))
 def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
     backward_call = jax.ffi.ffi_call(
         "conv_backward",
@@ -45,39 +60,66 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
     return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)


-def backward_with_inputs(
-    X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-):
-    return backward(
-        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-    ), (X, Y, W, dZ)  # rows, cols, sender_perm, workspace)
+def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+    out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
+    return out, (X, Y, W, dZ, rows, cols)
+
+
+def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives):
+    X, Y, W, dZ, rows, cols = res
+    ddX, ddY, ddW = derivatives
+
+    gX, gY, gW, gdZ = double_backward(
+        X,
+        Y,
+        W,
+        dZ,
+        ddX,
+        ddY,
+        ddW,
+        rows,
+        cols,
+        workspace,
+        sender_perm,
+        irrep_dtype,
+        attrs,
+    )
+
+    return gX, gY, gW, gdZ, None, None
+

+backward.defvjp(backward_fwd, backward_bwd)
+

+@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12))
 def double_backward(
-    rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
     double_backward_call = jax.ffi.ffi_call(
         "conv_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
         ),
     )
     return double_backward_call(
-        *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs
+        X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs
     )


-def backward_autograd(
-    rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
+def double_backward_fwd(
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
-    return backward(
-        inputs[0],
-        inputs[1],
-        inputs[2],
+    out = double_backward(
+        X,
+        Y,
+        W,
         dZ,
+        ddX,
+        ddY,
+        ddW,
         rows,
         cols,
         workspace,
@@ -85,25 +127,54 @@ def backward_autograd(
         irrep_dtype,
         attrs,
     )
+    return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols)


-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+def triple_backward(
+    workspace,
+    sender_perm,
+    irrep_dtype,
+    attrs,
+    residuals,
+    tangent_outputs,
+):
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
+    common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs)


-class TensorProductConv(LoopUnrollConv):
-    r"""
-    Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one
-    key difference: integer arrays passed to this function must have dtype
-    ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
-
-    :param problem: Specification of the tensor product.
-    :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
-        fixup-based algorithm. `Default`: ``False``.
-    :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
-        the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
- """ + op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) + g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args) + op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) + g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args) + + op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) + g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args) + + op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) + g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args) + + g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args) + g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args) + g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args) + + grad_X = g2_X + g4_X + g6_X + g7_X + grad_Y = g2_Y + g3_Y + g5_Y + g7_Y + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ + + grad_ddX = g1_ddX + g3_ddX + g5_ddX + grad_ddY = g1_ddY + g4_ddY + g6_ddY + grad_ddW = g2_ddW + g7_ddW + + return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None + + +double_backward.defvjp(double_backward_fwd, triple_backward) + + +class TensorProductConv(LoopUnrollConv): def __init__( self, config: TPProblem, deterministic: bool = False, kahan: bool = False ): @@ -112,7 +183,7 @@ def __init__( config, dp, extlib.postprocess_kernel, - idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version + idx_dtype=np.int32, torch_op=False, deterministic=deterministic, kahan=kahan, @@ -145,26 +216,6 @@ def forward( cols: jax.numpy.ndarray, sender_perm: Optional[jax.numpy.ndarray] = None, ) -> jax.numpy.ndarray: - r""" - Computes the fused CG tensor product + convolution. - - :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. - :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. - :param W: Tensor of datatype ``problem.weight_dtype`` and shape - - * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False`` - * ``[problem.weight_numel]`` if ``problem.shared_weights=True`` - - :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix, - datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``. - :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix, - datatype ``np.int32``. - :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a - permutation that transposes the adjacency matrix nonzeros from row-major to column-major order. - Must be provided when ``deterministic=True``. - - :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. - """ if not self.deterministic: sender_perm = self.dummy_transpose_perm else: diff --git a/tests/example_test.py b/tests/example_test.py index ae19f77e..e8d23cb7 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -161,3 +161,6 @@ def test_tutorial_jax(with_jax): tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) print(jax.numpy.linalg.norm(Z)) + + jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)) + print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))