Skip to content

Commit 2fb498e

Browse files
committed
XPBD (WIP)
1 parent f6e160e commit 2fb498e

File tree

6 files changed

+281
-43
lines changed

6 files changed

+281
-43
lines changed

benchmark/bench_phyjax2d.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import csv
2+
import functools
23
from datetime import datetime, timedelta
34
from pathlib import Path
45

@@ -7,7 +8,7 @@
78
import numpy as np
89
import typer
910

10-
from phyjax2d import SpaceBuilder, Vec2d, step
11+
from phyjax2d import SpaceBuilder, Vec2d, nstep, step
1112
from phyjax2d.moderngl_vis import MglVisualizer
1213

1314

@@ -23,11 +24,12 @@ def ball_fall_phyjax2d(
2324
builder = SpaceBuilder(
2425
gravity=(0.0, -900.0),
2526
dt=0.002,
26-
jacobi_damping=0.5,
27+
viscous_damping=0.6,
2728
n_velocity_iter=10,
28-
n_position_iter=2,
29-
bias_factor=0.02,
29+
n_position_iter=1,
30+
bias_factor=0.1,
3031
bounce_threshold=4,
32+
allowed_penetration=0.01,
3133
)
3234

3335
for _ in range(n_balls):
@@ -81,27 +83,20 @@ def ball_fall_phyjax2d(
8183
title=f"Phyjax2D Debug: {n_balls} balls",
8284
figsize=(600, 1000),
8385
)
84-
85-
# 4. Simulation Loop
86-
jit_step = jax.jit(step, static_argnums=(0,))
87-
88-
# Warm-up (Optional: only if you want to exclude first-run JIT from benchmark)
89-
sd, _, _ = jit_step(space, sd, vs)
90-
91-
start = datetime.now()
92-
for _ in range(n_iter):
93-
sd, _, _ = jit_step(space, sd, vs)
94-
95-
if visualizer is not None:
86+
jit_step = jax.jit(step, static_argnums=(0,))
87+
start = datetime.now()
88+
for _ in range(n_iter):
89+
sd, _, _ = jit_step(space, sd, vs)
9690
visualizer.render(state=sd)
9791
visualizer.show()
98-
99-
duration = datetime.now() - start
100-
101-
if visualizer is not None:
10292
visualizer.close()
103-
104-
return duration
93+
return datetime.now() - start
94+
else:
95+
nstep(100, 0.6, space, sd, vs)
96+
start = datetime.now()
97+
for _ in range(n_iter // 100):
98+
sd, _, _ = nstep(100, 0.6, space, sd, vs)
99+
return datetime.now() - start
105100

106101

107102
DEFAULT_COUNTS = [1000]

benchmark/bench_pymunk.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def ball_fall(n_balls: int, debug_vis: bool, n_iter: int = 1000) -> timedelta:
1313
space = pymunk.Space()
1414
# 1. Flip Gravity: Positive Y pulls "down" in PyGame coordinates
1515
space.gravity = (0, 900)
16+
space.iterations = 10
1617

1718
static_body = space.static_body
1819
# 2. Invert Container: Floor is now at Y=800, walls go up toward Y=50

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,5 @@ dev-dependencies = [
5858
"ipython >= 8.0",
5959
"pytest >= 8.3.3",
6060
"pymunk >= 6.0",
61-
"typer >= 0.12",
61+
"typer >= 0.21.1",
6262
]

src/phyjax2d/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
StateDict,
1212
Velocity,
1313
VelocitySolver,
14+
XpbdSolver,
1415
empty,
1516
get_relative_angle,
17+
nstep,
1618
step,
19+
step_xpbd,
1720
)
1821
from phyjax2d.raycast import (
1922
Raycast,

0 commit comments

Comments
 (0)