Skip to content

Commit dd7c6c9

Browse files
author
Austin Glover
committed
remove changes to benchmark.py
1 parent d7af1b0 commit dd7c6c9

File tree

1 file changed

+17
-26
lines changed

1 file changed

+17
-26
lines changed

tests/benchmark.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@
4646
diffdock_problems,
4747
mace_problems,
4848
nequip_problems,
49-
nequix_problems,
50-
seven_net_problems,
5149
)
5250

5351
from torch._functorch import config
@@ -93,19 +91,11 @@
9391

9492

9593
def benchmark_uvu(params):
96-
def get_problems():
97-
return (
98-
mace_problems()
99-
+ nequip_problems()
100-
+ nequix_problems()
101-
+ seven_net_problems()
102-
)
103-
104-
float64_problems = get_problems()
94+
float64_problems = mace_problems() + nequip_problems()
10595
for problem in float64_problems:
10696
problem.irrep_dtype = np.float64
10797
problem.weight_dtype = np.float64
108-
problems = get_problems() + float64_problems
98+
problems = mace_problems() + nequip_problems() + float64_problems
10999

110100
implementations = [implementation_map_tp[impl] for impl in params.implementations]
111101
directions = params.directions
@@ -290,20 +280,21 @@ def benchmark_convolution(params):
290280
graphs = download_graphs(params, filenames)
291281

292282
if not params.disable_bench:
293-
294-
def get_problems():
295-
return (
296-
mace_problems()
297-
+ nequip_problems()
298-
+ nequix_problems()
299-
+ seven_net_problems()
300-
)
301-
302-
float64_problems = get_problems()
303-
for problem in float64_problems:
304-
problem.irrep_dtype = np.float64
305-
problem.weight_dtype = np.float64
306-
configs = get_problems() + float64_problems
283+
configs = [
284+
ChannelwiseTPP(
285+
"128x0e+128x1o+128x2e",
286+
"1x0e+1x1o+1x2e+1x3o",
287+
"128x0e+128x1o+128x2e+128x3o",
288+
),
289+
ChannelwiseTPP(
290+
"128x0e+128x1o+128x2e",
291+
"1x0e+1x1o+1x2e+1x3o",
292+
"128x0e+128x1o+128x2e+128x3o",
293+
),
294+
] # MACE-large
295+
296+
configs[1].irrep_dtype = np.float64
297+
configs[1].weight_dtype = np.float64
307298

308299
bench = ConvBenchmarkSuite(configs, test_name="convolution")
309300

0 commit comments

Comments
 (0)