@@ -86,13 +86,11 @@ def create_nacl_supercell(supercell_size=10):
8686]
8787
8888
89- def benchmark_forward (
90- contracter , input1 , input2 , scatter_idxs , num_atoms , warmup = 3 , n_iter = 10
91- ):
89+ def benchmark_forward (contracter , input1 , input2 , scatter_idxs , warmup = 3 , n_iter = 10 ):
9290 """Benchmark forward pass."""
9391 # warmup
9492 for _ in range (warmup ):
95- _ = contracter (input1 , input2 , scatter_idxs , num_atoms )
93+ _ = contracter (input1 , input2 , scatter_idxs )
9694 torch .cuda .synchronize ()
9795
9896 # benchmark
@@ -101,7 +99,7 @@ def benchmark_forward(
10199
102100 start_event .record ()
103101 for _ in range (n_iter ):
104- _ = contracter (input1 , input2 , scatter_idxs , num_atoms )
102+ _ = contracter (input1 , input2 , scatter_idxs )
105103 end_event .record ()
106104
107105 torch .cuda .synchronize ()
@@ -110,9 +108,7 @@ def benchmark_forward(
110108 return total_time_ms / n_iter
111109
112110
113- def benchmark_backward (
114- contracter , input1 , input2 , scatter_idxs , num_atoms , warmup = 3 , n_iter = 10
115- ):
111+ def benchmark_backward (contracter , input1 , input2 , scatter_idxs , warmup = 3 , n_iter = 10 ):
116112 """Benchmark full forward+backward pass.
117113
118114 Returns:
@@ -122,7 +118,7 @@ def benchmark_backward(
122118 for _ in range (warmup ):
123119 input1 .grad = None
124120 input2 .grad = None
125- out = contracter (input1 , input2 , scatter_idxs , num_atoms )
121+ out = contracter (input1 , input2 , scatter_idxs )
126122 grad_out = torch .randn_like (out )
127123 out .backward (grad_out )
128124 torch .cuda .synchronize ()
@@ -135,7 +131,7 @@ def benchmark_backward(
135131 for _ in range (n_iter ):
136132 input1 .grad = None
137133 input2 .grad = None
138- out = contracter (input1 , input2 , scatter_idxs , num_atoms )
134+ out = contracter (input1 , input2 , scatter_idxs )
139135 grad_out = torch .randn_like (out )
140136 out .backward (grad_out )
141137 end_event .record ()
@@ -188,7 +184,6 @@ def autotune(
188184 num_nodes = AtomicDataDict .num_nodes (data )
189185 num_edges = AtomicDataDict .num_edges (data )
190186 scatter_idxs = data [AtomicDataDict .EDGE_INDEX_KEY ][1 ]
191- num_atoms_tensor = torch .tensor ([num_nodes ], dtype = torch .int64 , device = device )
192187
193188 print (f" num_nodes: { num_nodes } " )
194189 print (f" num_edges: { num_edges } " )
@@ -206,9 +201,9 @@ def autotune(
206201 irreps_in2 = model_config ["irreps_in2" ]
207202 mul = model_config ["mul" ]
208203
209- # both inputs are edge-indexed, enable grad for backward
204+ # input1 is edge-indexed, input2 is node-indexed
210205 input1 = irreps_in1 .randn (num_edges , mul , - 1 , dtype = dtype , device = device )
211- input2 = irreps_in2 .randn (num_edges , mul , - 1 , dtype = dtype , device = device )
206+ input2 = irreps_in2 .randn (num_nodes , mul , - 1 , dtype = dtype , device = device )
212207 input1 .requires_grad_ (True )
213208 input2 .requires_grad_ (True )
214209
@@ -237,7 +232,6 @@ def autotune(
237232 input1 ,
238233 input2 ,
239234 scatter_idxs ,
240- num_atoms_tensor ,
241235 warmup = 5 ,
242236 n_iter = 20 ,
243237 )
@@ -247,7 +241,6 @@ def autotune(
247241 input1 ,
248242 input2 ,
249243 scatter_idxs ,
250- num_atoms_tensor ,
251244 warmup = 5 ,
252245 n_iter = 20 ,
253246 )
@@ -285,7 +278,6 @@ def autotune(
285278 input1 ,
286279 input2 ,
287280 scatter_idxs ,
288- num_atoms_tensor ,
289281 warmup = 5 ,
290282 n_iter = 20 ,
291283 )
@@ -295,7 +287,6 @@ def autotune(
295287 input1 ,
296288 input2 ,
297289 scatter_idxs ,
298- num_atoms_tensor ,
299290 warmup = 5 ,
300291 n_iter = 20 ,
301292 )
0 commit comments