|
46 | 46 | diffdock_problems, |
47 | 47 | mace_problems, |
48 | 48 | nequip_problems, |
49 | | - nequix_problems, |
50 | | - seven_net_problems, |
51 | 49 | ) |
52 | 50 |
|
53 | 51 | from torch._functorch import config |
|
93 | 91 |
|
94 | 92 |
|
95 | 93 | def benchmark_uvu(params): |
96 | | - def get_problems(): |
97 | | - return ( |
98 | | - mace_problems() |
99 | | - + nequip_problems() |
100 | | - + nequix_problems() |
101 | | - + seven_net_problems() |
102 | | - ) |
103 | | - |
104 | | - float64_problems = get_problems() |
| 94 | + float64_problems = mace_problems() + nequip_problems() |
105 | 95 | for problem in float64_problems: |
106 | 96 | problem.irrep_dtype = np.float64 |
107 | 97 | problem.weight_dtype = np.float64 |
108 | | - problems = get_problems() + float64_problems |
| 98 | + problems = mace_problems() + nequip_problems() + float64_problems |
109 | 99 |
|
110 | 100 | implementations = [implementation_map_tp[impl] for impl in params.implementations] |
111 | 101 | directions = params.directions |
@@ -290,20 +280,21 @@ def benchmark_convolution(params): |
290 | 280 | graphs = download_graphs(params, filenames) |
291 | 281 |
|
292 | 282 | if not params.disable_bench: |
293 | | - |
294 | | - def get_problems(): |
295 | | - return ( |
296 | | - mace_problems() |
297 | | - + nequip_problems() |
298 | | - + nequix_problems() |
299 | | - + seven_net_problems() |
300 | | - ) |
301 | | - |
302 | | - float64_problems = get_problems() |
303 | | - for problem in float64_problems: |
304 | | - problem.irrep_dtype = np.float64 |
305 | | - problem.weight_dtype = np.float64 |
306 | | - configs = get_problems() + float64_problems |
| 283 | + configs = [ |
| 284 | + ChannelwiseTPP( |
| 285 | + "128x0e+128x1o+128x2e", |
| 286 | + "1x0e+1x1o+1x2e+1x3o", |
| 287 | + "128x0e+128x1o+128x2e+128x3o", |
| 288 | + ), |
| 289 | + ChannelwiseTPP( |
| 290 | + "128x0e+128x1o+128x2e", |
| 291 | + "1x0e+1x1o+1x2e+1x3o", |
| 292 | + "128x0e+128x1o+128x2e+128x3o", |
| 293 | + ), |
| 294 | + ] # MACE-large |
| 295 | + |
| 296 | + configs[1].irrep_dtype = np.float64 |
| 297 | + configs[1].weight_dtype = np.float64 |
307 | 298 |
|
308 | 299 | bench = ConvBenchmarkSuite(configs, test_name="convolution") |
309 | 300 |
|
|
0 commit comments