@@ -64,7 +64,7 @@ def ball_fall_phyjax2d(
6464 # 2. Initialize State
6565 rng = np .random .default_rng ()
6666 x_coords = rng .uniform (100 , 800 , n_balls )
67- y_coords = rng .uniform (50 , 400 , n_balls )
67+ y_coords = rng .uniform (150 , 500 , n_balls )
6868 pos_array = jnp .stack ([jnp .array (x_coords ), jnp .array (y_coords )], axis = - 1 )
6969
7070 sd = space .zeros_state ().nested_replace ("circle.p.xy" , pos_array )
@@ -75,12 +75,12 @@ def ball_fall_phyjax2d(
7575 if debug_vis :
7676 # We define the range based on the window size/container
7777 visualizer = MglVisualizer (
78- x_range = 600 .0 ,
79- y_range = 1000 .0 ,
78+ x_range = 900 .0 ,
79+ y_range = 600 .0 ,
8080 space = space ,
8181 stated = sd ,
8282 title = f"Phyjax2D Debug: { n_balls } balls" ,
83- figsize = (600 , 900 ),
83+ figsize = (900 , 600 ),
8484 )
8585 jit_step = jax .jit (step , static_argnums = (0 ,))
8686 start = datetime .now ()
@@ -91,16 +91,17 @@ def ball_fall_phyjax2d(
9191 visualizer .close ()
9292 return datetime .now () - start
9393 else :
94+
9495 @jax .jit
95- def step (sd , vs ):
96+ def n_step_fixed (sd , vs ):
9697 sd , vs , _ = nstep (5 , 0.6 , space , sd , vs )
9798 return sd , vs .replace (pn = vs .pn * 0.6 )
9899
99- step (sd , vs )
100+ n_step_fixed (sd , vs )
100101
101102 start = datetime .now ()
102103 for _ in range (n_iter // 5 ):
103- sd , vs = step (sd , vs )
104+ sd , vs = n_step_fixed (sd , vs )
104105 return datetime .now () - start
105106
106107
0 commit comments