1414
1515logger = getLogger ()
1616
17+
1718def zeros_like (x ):
1819 return jnp .zeros_like (x )
1920
21+
2022@partial (jax .custom_vjp , nondiff_argnums = (5 , 6 , 7 , 8 , 9 ))
2123def forward (X , Y , W , rows , cols , workspace , sender_perm , L3_dim , irrep_dtype , attrs ):
2224 forward_call = jax .ffi .ffi_call (
@@ -34,9 +36,7 @@ def forward_fwd(
3436 return out , (X , Y , W , rows , cols )
3537
3638
37- def forward_bwd (
38- workspace , sender_perm , L3_dim , irrep_dtype , attrs , res , dZ
39- ):
39+ def forward_bwd (workspace , sender_perm , L3_dim , irrep_dtype , attrs , res , dZ ):
4040 X , Y , W , rows , cols = res
4141 dX , dY , dW = backward (
4242 X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs
@@ -60,23 +60,29 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
6060 return backward_call (X , Y , W , dZ , rows , cols , workspace , sender_perm , ** attrs )
6161
6262
63- def backward_fwd (
64- X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs
65- ):
66- out = backward (
67- X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs
68- )
63+ def backward_fwd (X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs ):
64+ out = backward (X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs )
6965 return out , (X , Y , W , dZ , rows , cols )
7066
7167
72- def backward_bwd (
73- workspace , sender_perm , irrep_dtype , attrs , res , derivatives
74- ):
68+ def backward_bwd (workspace , sender_perm , irrep_dtype , attrs , res , derivatives ):
7569 X , Y , W , dZ , rows , cols = res
7670 ddX , ddY , ddW = derivatives
7771
7872 gX , gY , gW , gdZ = double_backward (
79- X , Y , W , dZ , ddX , ddY , ddW , rows , cols , workspace , sender_perm , irrep_dtype , attrs
73+ X ,
74+ Y ,
75+ W ,
76+ dZ ,
77+ ddX ,
78+ ddY ,
79+ ddW ,
80+ rows ,
81+ cols ,
82+ workspace ,
83+ sender_perm ,
84+ irrep_dtype ,
85+ attrs ,
8086 )
8187
8288 return gX , gY , gW , gdZ , None , None
@@ -340,4 +346,4 @@ def double_backward_cpu(
340346 np .asarray (in2_grad ),
341347 np .asarray (weights_grad ),
342348 np .asarray (out_dgrad ),
343- )
349+ )
0 commit comments