@@ -57,31 +57,23 @@ def test_contract_jit(
5757 c_opt_mod .weights .copy_ (c_base .weights )
5858 c_opt_mod = torch .jit .script (c_opt_mod )
5959
60- def c_opt (x , y , idx , dim , w = None ):
61- args = (x , y , idx , dim , w )
62- if w is None :
63- args = args [:- 1 ]
64- return c_opt_mod (* args )
60+ def c_opt (x , y , idx ):
61+ return c_opt_mod (x , y , idx )
6562
66- batchdim = 17
67- scatter_dim = torch . tensor ([ batchdim ], dtype = torch . long , device = device )
68- scatter_idxs = torch .arange ( batchdim , device = device )
63+ num_edges = 17
64+ num_nodes = 11
65+ scatter_idxs = torch .randint ( 0 , num_nodes , ( num_edges ,) , device = device )
6966 args_in = (
70- irreps_in1 .randn (batchdim , mul , - 1 , device = device ),
71- irreps_in2 .randn (batchdim , mul , - 1 , device = device ),
67+ irreps_in1 .randn (num_edges , mul , - 1 , device = device ),
68+ irreps_in2 .randn (num_nodes , mul , - 1 , device = device ),
7269 scatter_idxs ,
73- scatter_dim ,
74- torch .randn (
75- tuple (batchdim if e == - 1 else e for e in c_base .weights .shape )
76- ),
7770 )
78- args_in = args_in [:- 1 ]
7971
8072 for c in (c_base , c_opt ):
8173 assert_equivariant (
8274 c ,
8375 args_in = args_in ,
84- irreps_in = [irreps_in1 , irreps_in2 , None , None ],
76+ irreps_in = [irreps_in1 , irreps_in2 , None ],
8577 irreps_out = irreps_out ,
8678 # e3nn uses 1e-3, 1e-9
8779 tolerance = {torch .float32 : 1e-3 , torch .float64 : 1e-8 }[
@@ -164,7 +156,6 @@ def test_like_tp(
164156 # make input data
165157 batchdim = 1
166158 scatter_idxs = torch .arange (batchdim , device = device )
167- scatter_dim = torch .tensor ([batchdim ], dtype = torch .long , device = device )
168159 tensor1 = torch .randn (batchdim , mul , c .base_dim1 , device = device )
169160 tensor2 = torch .randn (batchdim , mul , c .base_dim2 , device = device )
170161
@@ -193,9 +184,7 @@ def test_like_tp(
193184 weights_tp = weights_tp .T
194185 # else weights are just (u,)
195186 weights_tp = weights_tp .reshape (- 1 )
196- c_out = _strided_to_cat (
197- irreps_out , mul , c (tensor1 , tensor2 , scatter_idxs , scatter_dim )
198- )
187+ c_out = _strided_to_cat (irreps_out , mul , c (tensor1 , tensor2 , scatter_idxs ))
199188 tp_out = tp (
200189 _strided_to_cat (irreps_in1 , mul , tensor1 ),
201190 _strided_to_cat (irreps_in2 , mul , tensor2 ),
0 commit comments