Skip to content

Commit 038772d

Browse files
authored
Merge pull request #44 from erikfrey/testspeed_fixes
Testspeed fixes
2 parents 746025e + 55b0bb4 commit 038772d

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

mujoco/mjx/_src/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,11 +353,11 @@ def put_data(
353353
# TODO(team): move to Model?
354354
if nconmax == -1:
355355
# TODO(team): heuristic for nconmax
356-
nconmax = 512
356+
nconmax = max(512, mjd.ncon * nworld)
357357
d.nconmax = nconmax
358358
if njmax == -1:
359359
# TODO(team): heuristic for njmax
360-
njmax = 512
360+
njmax = max(512, mjd.nefc * nworld)
361361
d.njmax = njmax
362362

363363
if nworld * mjd.nefc > njmax:

mujoco/mjx/_src/test_util.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,41 +46,41 @@ def fixture(fname: str, keyframe: int = -1, sparse: bool = True):
4646

4747
def benchmark(
4848
fn: Callable[[types.Model, types.Data], None],
49-
m: mujoco.MjModel,
49+
mjm: mujoco.MjModel,
50+
mjd: mujoco.MjData,
5051
nstep: int = 1000,
5152
batch_size: int = 1024,
5253
unroll_steps: int = 1,
5354
solver: str = "cg",
5455
iterations: int = 1,
5556
ls_iterations: int = 4,
56-
nefc_total: int = 0,
57+
nconmax: int = -1,
58+
njmax: int = -1,
5759
) -> Tuple[float, float, int]:
5860
"""Benchmark a model."""
5961

6062
if solver == "cg":
61-
m.opt.solver = mujoco.mjtSolver.mjSOL_CG
63+
mjm.opt.solver = mujoco.mjtSolver.mjSOL_CG
6264
elif solver == "newton":
63-
m.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
65+
mjm.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
6466

65-
m.opt.iterations = iterations
66-
m.opt.ls_iterations = ls_iterations
67+
mjm.opt.iterations = iterations
68+
mjm.opt.ls_iterations = ls_iterations
6769

68-
mx = io.put_model(m)
69-
dx = io.make_data(m, nworld=batch_size, njmax=nefc_total)
70-
dx.nefc_total = wp.array([nefc_total], dtype=wp.int32, ndim=1)
70+
m = io.put_model(mjm)
71+
d = io.put_data(mjm, mjd, nworld=batch_size, nconmax=nconmax, njmax=njmax)
7172

72-
wp.clear_kernel_cache()
7373
jit_beg = time.perf_counter()
74-
fn(mx, dx)
74+
fn(m, d)
7575
# double warmup to work around issues with compilation during graph capture:
76-
fn(mx, dx)
76+
fn(m, d)
7777
jit_end = time.perf_counter()
7878
jit_duration = jit_end - jit_beg
7979
wp.synchronize()
8080

8181
# capture the whole smooth.kinematic() function as a CUDA graph
8282
with wp.ScopedCapture() as capture:
83-
fn(mx, dx)
83+
fn(m, d)
8484
graph = capture.graph
8585

8686
run_beg = time.perf_counter()

mujoco/mjx/testspeed.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,18 @@
4949
_IS_SPARSE = flags.DEFINE_bool(
5050
"is_sparse", True, "if model should create sparse mass matrices"
5151
)
52-
_NEFC_TOTAL = flags.DEFINE_integer(
53-
"nefc_total", 0, "total number of efc for batch of worlds"
52+
_NCONMAX = flags.DEFINE_integer(
53+
"nconmax", -1, "Maximum number of contacts in a batch physics step."
54+
)
55+
_NJMAX = flags.DEFINE_integer(
56+
"njmax", -1, "Maximum number of constraints in a batch physics step."
5457
)
5558
_OUTPUT = flags.DEFINE_enum(
5659
"output", "text", ["text", "tsv"], "format to print results"
5760
)
61+
_CLEAR_KERNEL_CACHE = flags.DEFINE_bool(
62+
"clear_kernel_cache", False, "Clear kernel cache (to calculate full JIT time)"
63+
)
5864

5965

6066
def _main(argv: Sequence[str]):
@@ -74,20 +80,32 @@ def _main(argv: Sequence[str]):
7480
else:
7581
m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
7682

83+
d = mujoco.MjData(m)
84+
if m.nkey > 0:
85+
mujoco.mj_resetDataKeyframe(m, d, 0)
86+
# populate some constraints
87+
mujoco.mj_forward(m, d)
88+
89+
if _CLEAR_KERNEL_CACHE.value:
90+
wp.clear_kernel_cache()
91+
7792
print(
7893
f"Model nbody: {m.nbody} nv: {m.nv} ngeom: {m.ngeom} is_sparse: {_IS_SPARSE.value}"
7994
)
95+
print(f"Data ncon: {d.ncon} nefc: {d.nefc}")
8096
print(f"Rolling out {_NSTEP.value} steps at dt = {m.opt.timestep:.3f}...")
8197
jit_time, run_time, steps = mjx.benchmark(
8298
mjx.__dict__[_FUNCTION.value],
8399
m,
100+
d,
84101
_NSTEP.value,
85102
_BATCH_SIZE.value,
86103
_UNROLL.value,
87104
_SOLVER.value,
88105
_ITERATIONS.value,
89106
_LS_ITERATIONS.value,
90-
_NEFC_TOTAL.value,
107+
_NCONMAX.value,
108+
_NJMAX.value,
91109
)
92110

93111
name = argv[0]
@@ -99,7 +117,7 @@ def _main(argv: Sequence[str]):
99117
Total simulation time: {run_time:.2f} s
100118
Total steps per second: {steps / run_time:,.0f}
101119
Total realtime factor: {steps * m.opt.timestep / run_time:,.2f} x
102-
Total time per step: {1e6 * run_time / steps:.2f} µs""")
120+
Total time per step: {1e9 * run_time / steps:.2f} ns""")
103121
elif _OUTPUT.value == "tsv":
104122
name = name.split("/")[-1].replace("testspeed_", "")
105123
print(f"{name}\tjit: {jit_time:.2f}s\tsteps/second: {steps / run_time:.0f}")

0 commit comments

Comments
 (0)