Skip to content

Commit 4f12d6b

Browse files
committed
Benchmark
1 parent 35bf575 commit 4f12d6b

File tree

5 files changed

+242
-4
lines changed

5 files changed

+242
-4
lines changed

benchmark/bench_phyjax2d.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import csv
2+
from datetime import datetime, timedelta
3+
from pathlib import Path
4+
5+
import jax
6+
import jax.numpy as jnp
7+
import numpy as np
8+
import typer
9+
10+
from phyjax2d import SpaceBuilder, Vec2d, step
11+
from phyjax2d.moderngl_vis import MglVisualizer
12+
13+
14+
def ball_fall_phyjax2d(
15+
n_balls: int,
16+
debug_vis: bool,
17+
n_iter: int = 1000,
18+
) -> timedelta:
19+
"""
20+
Simulates n_balls falling using phyjax2d.
21+
If debug_vis is True, uses MglVisualizer for rendering.
22+
"""
23+
builder = SpaceBuilder(
24+
gravity=(0.0, -900.0),
25+
dt=0.002,
26+
jacobi_damping=0.5,
27+
n_velocity_iter=10,
28+
n_position_iter=2,
29+
bias_factor=0.02,
30+
bounce_threshold=4,
31+
)
32+
33+
for _ in range(n_balls):
34+
builder.add_circle(
35+
radius=4.0,
36+
density=1.0 / (16 * np.pi),
37+
elasticity=0.5,
38+
friction=0.5,
39+
)
40+
41+
# Container setup
42+
builder.add_segment(
43+
p1=Vec2d(50.0, 50.0),
44+
p2=Vec2d(550.0, 50.0),
45+
elasticity=0.4,
46+
friction=0.5,
47+
)
48+
builder.add_segment(
49+
p1=Vec2d(50.0, 50.0),
50+
p2=Vec2d(50.0, 800.0),
51+
elasticity=0.4,
52+
friction=0.5,
53+
)
54+
builder.add_segment(
55+
p1=Vec2d(550.0, 50.0),
56+
p2=Vec2d(550.0, 800.0),
57+
elasticity=0.4,
58+
friction=0.5,
59+
)
60+
61+
space = builder.build()
62+
63+
# 2. Initialize State
64+
rng = np.random.default_rng()
65+
x_coords = rng.uniform(70, 530, n_balls)
66+
y_coords = rng.uniform(400, 1000, n_balls)
67+
pos_array = jnp.stack([jnp.array(x_coords), jnp.array(y_coords)], axis=-1)
68+
69+
sd = space.zeros_state().nested_replace("circle.p.xy", pos_array)
70+
vs = space.init_solver()
71+
72+
# 3. Initialize Visualizer
73+
visualizer = None
74+
if debug_vis:
75+
# We define the range based on the window size/container
76+
visualizer = MglVisualizer(
77+
x_range=600.0,
78+
y_range=1000.0,
79+
space=space,
80+
stated=sd,
81+
title=f"Phyjax2D Debug: {n_balls} balls",
82+
figsize=(600, 1000),
83+
)
84+
85+
# 4. Simulation Loop
86+
jit_step = jax.jit(step, static_argnums=(0,))
87+
88+
# Warm-up (Optional: only if you want to exclude first-run JIT from benchmark)
89+
sd, _, _ = jit_step(space, sd, vs)
90+
91+
start = datetime.now()
92+
for _ in range(n_iter):
93+
sd, _, _ = jit_step(space, sd, vs)
94+
95+
if visualizer is not None:
96+
visualizer.render(state=sd)
97+
visualizer.show()
98+
99+
duration = datetime.now() - start
100+
101+
if visualizer is not None:
102+
visualizer.close()
103+
104+
return duration
105+
106+
107+
def main(
108+
count_min: int = 1000,
109+
count_max: int = 100000,
110+
count_interval: int = 1000,
111+
debug_vis: bool = False,
112+
filename: Path = Path("bench.csv"),
113+
) -> None:
114+
results = []
115+
116+
for count in [x for x in range(count_min, count_max + 1, count_interval)]:
117+
duration = ball_fall_phyjax2d(count, debug_vis)
118+
# Convert timedelta to total seconds as a float for the CSV
119+
seconds = duration.total_seconds()
120+
results.append((count, seconds))
121+
122+
if not debug_vis:
123+
with open(filename, mode="w", newline="") as file:
124+
writer = csv.writer(file)
125+
writer.writerow(["n_balls", "duration_seconds"]) # Header
126+
writer.writerows(results)
127+
128+
129+
if __name__ == "__main__":
130+
typer.run(main)

benchmark/bench_pymunk.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import csv
2+
from datetime import datetime, timedelta
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import pygame
7+
import pymunk
8+
import pymunk.pygame_util
9+
import typer
10+
11+
12+
def ball_fall(n_balls: int, debug_vis: bool, n_iter: int = 1000) -> timedelta:
13+
space = pymunk.Space()
14+
# 1. Flip Gravity: Positive Y pulls "down" in PyGame coordinates
15+
space.gravity = (0, 900)
16+
17+
static_body = space.static_body
18+
# 2. Invert Container: Floor is now at Y=800, walls go up toward Y=50
19+
segments = [
20+
pymunk.Segment(static_body, (50, 800), (550, 800), 5), # Floor (Bottom)
21+
pymunk.Segment(static_body, (50, 800), (50, 50), 5), # Left Wall
22+
pymunk.Segment(static_body, (550, 800), (550, 50), 5), # Right Wall
23+
]
24+
for seg in segments:
25+
seg.elasticity = 0.4
26+
seg.friction = 0.5
27+
space.add(*segments)
28+
29+
radius = 4
30+
mass = 1
31+
moment = pymunk.moment_for_circle(mass, 0, radius)
32+
33+
rng = np.random.default_rng()
34+
# 3. Invert Spawn: Balls start near the top (Y=100 to 400)
35+
x_coords = rng.uniform(70, 530, n_balls)
36+
y_coords = rng.uniform(50, 400, n_balls)
37+
38+
for i in range(n_balls):
39+
body = pymunk.Body(mass, moment)
40+
body.position = (float(x_coords[i]), float(y_coords[i]))
41+
shape = pymunk.Circle(body, radius)
42+
shape.elasticity = 0.5
43+
shape.friction = 0.5
44+
space.add(body, shape)
45+
46+
if debug_vis:
47+
pygame.init()
48+
screen = pygame.display.set_mode((600, 850))
49+
pygame.display.set_caption(f"Inverted Gravity Benchmark: {n_balls} balls")
50+
draw_options = pymunk.pygame_util.DrawOptions(screen)
51+
else:
52+
screen = None
53+
draw_options = None
54+
55+
start = datetime.now()
56+
for _ in range(n_iter):
57+
if screen is not None and draw_options is not None:
58+
for event in pygame.event.get():
59+
if event.type == pygame.QUIT:
60+
pygame.quit()
61+
return datetime.now() - start
62+
63+
screen.fill((255, 255, 255))
64+
space.debug_draw(draw_options)
65+
pygame.display.flip()
66+
67+
space.step(0.002)
68+
69+
if debug_vis:
70+
pygame.quit()
71+
72+
return datetime.now() - start
73+
74+
75+
def main(
76+
count_min: int = 1000,
77+
count_max: int = 100000,
78+
count_interval: int = 1000,
79+
debug_vis: bool = False,
80+
filename: Path = Path("bench.csv"),
81+
) -> None:
82+
results = []
83+
84+
for count in [x for x in range(count_min, count_max + 1, count_interval)]:
85+
duration = ball_fall(count, debug_vis)
86+
# Convert timedelta to total seconds as a float for the CSV
87+
seconds = duration.total_seconds()
88+
results.append((count, seconds))
89+
90+
if not debug_vis:
91+
with open(filename, mode="w", newline="") as file:
92+
writer = csv.writer(file)
93+
writer.writerow(["n_balls", "duration_seconds"]) # Header
94+
writer.writerows(results)
95+
96+
97+
if __name__ == "__main__":
98+
typer.run(main)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ requires-python = ">= 3.10"
2121
dependencies = [
2222
"chex >= 0.1.86",
2323
"jax >= 0.4.26",
24+
"pygame>=2.6.1",
2425
]
2526
dynamic = ["version"]
2627

src/phyjax2d/impl.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ class Space:
861861
dt: float = 0.1
862862
linear_damping: float = 0.95
863863
angular_damping: float = 0.95
864+
jacobi_damping: float = 1.0
864865
bias_factor: float = 0.2
865866
n_velocity_iter: int = 6
866867
n_position_iter: int = 2
@@ -1014,7 +1015,7 @@ def __str__(self) -> str:
10141015
f" Timestep (dt): {self.dt}",
10151016
f" Damping: Linear={self.linear_damping}, Angular={self.angular_damping}",
10161017
f" Solver Iterations: Velocity={self.n_velocity_iter}, Position={self.n_position_iter}",
1017-
f" Safety: Linear Slop={self.linear_slop}, Max Velocity={self.max_velocity}"
1018+
f" Safety: Linear Slop={self.linear_slop}, Max Velocity={self.max_velocity}",
10181019
]
10191020
return "\n".join(lines)
10201021

@@ -1282,8 +1283,13 @@ def solve_constraints(
12821283
"""Resolve collisions by Sequential Impulse method"""
12831284
idx1, idx2 = space._ci_total.index1, space._ci_total.index2
12841285

1285-
def gather(a: jax.Array, b: jax.Array, orig: jax.Array) -> jax.Array:
1286-
return orig.at[idx1].add(a).at[idx2].add(b)
1286+
def gather(
1287+
a: jax.Array,
1288+
b: jax.Array,
1289+
orig: jax.Array,
1290+
damping: float = 1.0,
1291+
) -> jax.Array:
1292+
return orig.at[idx1].add(a * damping).at[idx2].add(b * damping)
12871293

12881294
p1, p2 = p.get_slice(idx1), p.get_slice(idx2)
12891295
v1, v2 = v.get_slice(idx1), v.get_slice(idx2)
@@ -1317,7 +1323,7 @@ def vstep(
13171323
(v.into_axy(), solver),
13181324
)
13191325
bv1, bv2 = apply_bounce(contact, helper, solver)
1320-
v_axy = gather(bv1, bv2, v_axy)
1326+
v_axy = gather(bv1, bv2, v_axy, damping=space.jacobi_damping)
13211327

13221328
def pstep(
13231329
_: int,

src/phyjax2d/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ class SpaceBuilder:
195195
dt: float = 0.1
196196
linear_damping: float = 0.9
197197
angular_damping: float = 0.9
198+
jacobi_damping: float = 1.0
198199
bias_factor: float = 0.2
199200
n_velocity_iter: int = 6
200201
n_position_iter: int = 2
@@ -525,10 +526,12 @@ def build(self) -> Space:
525526
jnp.inf if self.max_angular_velocity is None else self.max_angular_velocity
526527
)
527528
return Space(
529+
dt=self.dt,
528530
gravity=jnp.array(self.gravity),
529531
shaped=shaped,
530532
linear_damping=linear_damping,
531533
angular_damping=angular_damping,
534+
jacobi_damping=self.jacobi_damping,
532535
bias_factor=self.bias_factor,
533536
n_velocity_iter=self.n_velocity_iter,
534537
n_position_iter=self.n_position_iter,

0 commit comments

Comments
 (0)