Skip to content

Commit 071981f

Browse files
committed
Make the benchmark wider
1 parent 6434ad8 commit 071981f

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

benchmark/bench_phyjax2d.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def ball_fall_phyjax2d(
6464
# 2. Initialize State
6565
rng = np.random.default_rng()
6666
x_coords = rng.uniform(100, 800, n_balls)
67-
y_coords = rng.uniform(50, 400, n_balls)
67+
y_coords = rng.uniform(150, 500, n_balls)
6868
pos_array = jnp.stack([jnp.array(x_coords), jnp.array(y_coords)], axis=-1)
6969

7070
sd = space.zeros_state().nested_replace("circle.p.xy", pos_array)
@@ -75,12 +75,12 @@ def ball_fall_phyjax2d(
7575
if debug_vis:
7676
# We define the range based on the window size/container
7777
visualizer = MglVisualizer(
78-
x_range=600.0,
79-
y_range=1000.0,
78+
x_range=900.0,
79+
y_range=600.0,
8080
space=space,
8181
stated=sd,
8282
title=f"Phyjax2D Debug: {n_balls} balls",
83-
figsize=(600, 900),
83+
figsize=(900, 600),
8484
)
8585
jit_step = jax.jit(step, static_argnums=(0,))
8686
start = datetime.now()
@@ -91,16 +91,17 @@ def ball_fall_phyjax2d(
9191
visualizer.close()
9292
return datetime.now() - start
9393
else:
94+
9495
@jax.jit
95-
def step(sd, vs):
96+
def n_step_fixed(sd, vs):
9697
sd, vs, _ = nstep(5, 0.6, space, sd, vs)
9798
return sd, vs.replace(pn=vs.pn * 0.6)
9899

99-
step(sd, vs)
100+
n_step_fixed(sd, vs)
100101

101102
start = datetime.now()
102103
for _ in range(n_iter // 5):
103-
sd, vs = step(sd, vs)
104+
sd, vs = n_step_fixed(sd, vs)
104105
return datetime.now() - start
105106

106107

0 commit comments

Comments
 (0)