Skip to content

Commit adee4b7

Browse files
committed
Modify bench
1 parent baf3588 commit adee4b7

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

benchmark/bench_phyjax2d.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,16 @@ def ball_fall_phyjax2d(
9191
visualizer.close()
9292
return datetime.now() - start
9393
else:
94-
nstep(100, 0.6, space, sd, vs)
94+
@jax.jit
95+
def step(sd, vs):
96+
sd, vs, _ = nstep(5, 0.6, space, sd, vs)
97+
return sd, vs.replace(pn=vs.pn * 0.6)
98+
99+
step(sd, vs)
100+
95101
start = datetime.now()
96-
for _ in range(n_iter // 100):
97-
sd, _, _ = nstep(100, 0.6, space, sd, vs)
102+
for _ in range(n_iter // 5):
103+
sd, vs = step(sd, vs)
98104
return datetime.now() - start
99105

100106

0 commit comments

Comments
 (0)