
Commit 0f45a24

Precommit.
1 parent 2aa70a7 commit 0f45a24

4 files changed: +25 −20 lines


README.md

Lines changed: 1 addition & 2 deletions
@@ -185,8 +185,7 @@ Z = tp_conv.forward(
 print(jax.numpy.linalg.norm(Z))
 
 # Test JAX JIT
-func = lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)
-jitted = jax.jit(func)
+jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
 print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
 ```
 
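The consolidated line works because `jax.jit` traces any callable, so a lambda closing over a bound method needs no named wrapper. A minimal runnable sketch of the same pattern, using a hypothetical toy class rather than the library's `TensorProductConv`:

```python
import jax
import jax.numpy as jnp


class ToyConv:
    # Illustrative stand-in for tp_conv; forward just scales its input.
    def forward(self, X, W):
        return X * W


tp = ToyConv()

# Jit a lambda that closes over the bound method, as the README now does.
jitted = jax.jit(lambda X, W: tp.forward(X, W))
print(jnp.linalg.norm(jitted(jnp.ones((8, 4)), 2.0)))
```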

openequivariance/openequivariance/jax/TensorProduct.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 from openequivariance.core.utils import hash_attributes
 from openequivariance.jax.utils import reorder_jax
 
+
 @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
 def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
@@ -209,4 +210,4 @@ def double_backward_cpu(
         out_grad_jax,
     )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
 
-    return in1_grad, in2_grad, weights_grad, out_dgrad
\ No newline at end of file
+    return in1_grad, in2_grad, weights_grad, out_dgrad
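The `@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))` decorator in this hunk marks `L3_dim`, `irrep_dtype`, and `attrs` as non-differentiable static arguments. A minimal sketch of that API on a toy function (not the library's FFI kernels): the static argument is delivered to the backward rule ahead of the residuals and cotangent.

```python
import jax
from functools import partial


# Toy jax.custom_vjp with one static argument, mirroring how forward()
# above excludes L3_dim, irrep_dtype, and attrs from differentiation.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, factor):
    return x * factor


def scale_fwd(x, factor):
    # Return the primal output plus residuals (none needed here).
    return scale(x, factor), None


def scale_bwd(factor, res, g):
    # Non-diff args come first, then residuals, then the cotangent;
    # return one gradient per differentiable input.
    return (g * factor,)


scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale)(3.0, 2.0))  # 2.0
```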

openequivariance/openequivariance/jax/TensorProductConv.py

Lines changed: 20 additions & 14 deletions
@@ -14,9 +14,11 @@
 
 logger = getLogger()
 
+
 def zeros_like(x):
     return jnp.zeros_like(x)
 
+
 @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
 def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
@@ -34,9 +36,7 @@ def forward_fwd(
     return out, (X, Y, W, rows, cols)
 
 
-def forward_bwd(
-    workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ
-):
+def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ):
     X, Y, W, rows, cols = res
     dX, dY, dW = backward(
         X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
@@ -60,23 +60,29 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
     return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
 
 
-def backward_fwd(
-    X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-):
-    out = backward(
-        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-    )
+def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+    out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
     return out, (X, Y, W, dZ, rows, cols)
 
 
-def backward_bwd(
-    workspace, sender_perm, irrep_dtype, attrs, res, derivatives
-):
+def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives):
     X, Y, W, dZ, rows, cols = res
     ddX, ddY, ddW = derivatives
 
     gX, gY, gW, gdZ = double_backward(
-        X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+        X,
+        Y,
+        W,
+        dZ,
+        ddX,
+        ddY,
+        ddW,
+        rows,
+        cols,
+        workspace,
+        sender_perm,
+        irrep_dtype,
+        attrs,
     )
 
     return gX, gY, gW, gdZ, None, None
@@ -340,4 +346,4 @@ def double_backward_cpu(
         np.asarray(in2_grad),
         np.asarray(weights_grad),
         np.asarray(out_dgrad),
-    )
\ No newline at end of file
+    )
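For readers tracing the control flow: `backward` is itself a `custom_vjp` function (`backward_fwd`/`backward_bwd`), so differentiating the backward pass dispatches to `double_backward`, which is what lets `jax.grad(jax.grad(...))` run through the FFI kernels. A toy sketch of the same nesting, with explicit derivative rules standing in for the FFI calls:

```python
import jax


@jax.custom_vjp
def f(x):
    return x**3


@jax.custom_vjp
def f_grad(x):
    # Toy analogue of backward(): the first-derivative rule.
    return 3.0 * x**2


def f_fwd(x):
    return f(x), x


def f_bwd(x, g):
    # Calling f_grad, itself a custom_vjp function, keeps the chain
    # differentiable a second time, as backward() does above.
    return (g * f_grad(x),)


def f_grad_fwd(x):
    return f_grad(x), x


def f_grad_bwd(x, g):
    # Toy analogue of double_backward(): the second-derivative rule.
    return (g * 6.0 * x,)


f.defvjp(f_fwd, f_bwd)
f_grad.defvjp(f_grad_fwd, f_grad_bwd)

print(jax.grad(f)(3.0))            # 27.0 = 3 * 3.0**2
print(jax.grad(jax.grad(f))(3.0))  # 18.0 = 6 * 3.0
```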

tests/example_test.py

Lines changed: 2 additions & 3 deletions
@@ -162,6 +162,5 @@ def test_tutorial_jax(with_jax):
     Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
     print(jax.numpy.linalg.norm(Z))
 
-    func = lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)
-    jitted = jax.jit(func)
-    print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
+    jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
+    print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
