Skip to content

Commit 922c245

Browse files
authored
Merge pull request #111 from erikfrey/measure_alloc2
Add support to testspeed to measure ncon, nefc during unroll.
2 parents 51727df + 639902e commit 922c245

File tree

2 files changed: +49 -4 lines changed

2 files changed: +49 -4 lines changed

mujoco_warp/_src/test_util.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,8 @@ def benchmark(
6666
nconmax: int = -1,
6767
njmax: int = -1,
6868
event_trace: bool = False,
69-
) -> Tuple[float, float, dict, int]:
69+
measure_alloc: bool = False,
70+
) -> Tuple[float, float, dict, int, list, list]:
7071
"""Benchmark a model."""
7172

7273
if solver == "cg":
@@ -90,9 +91,11 @@ def benchmark(
9091
jit_duration = jit_end - jit_beg
9192
wp.synchronize()
9293
trace = {}
94+
ncon = []
95+
nefc = []
9396

9497
with warp_util.EventTracer(enabled=event_trace) as tracer:
95-
# capture the whole smooth.kinematic() function as a CUDA graph
98+
# capture the whole function as a CUDA graph
9699
with wp.ScopedCapture() as capture:
97100
fn(m, d)
98101
graph = capture.graph
@@ -104,8 +107,12 @@ def benchmark(
104107
trace = _sum(trace, tracer.trace())
105108
else:
106109
trace = tracer.trace()
110+
if measure_alloc:
111+
wp.synchronize()
112+
ncon.append(d.ncon.numpy()[0])
113+
nefc.append(d.nefc.numpy()[0])
107114
wp.synchronize()
108115
run_end = time.perf_counter()
109116
run_duration = run_end - run_beg
110117

111-
return jit_duration, run_duration, trace, batch_size * nstep
118+
return jit_duration, run_duration, trace, batch_size * nstep, ncon, nefc

mujoco_warp/testspeed.py

Lines changed: 39 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from typing import Sequence
2020

2121
import mujoco
22+
import numpy as np
2223
import warp as wp
2324
from absl import app
2425
from absl import flags
@@ -62,6 +63,9 @@
6263
"clear_kernel_cache", False, "Clear kernel cache (to calculate full JIT time)"
6364
)
6465
_EVENT_TRACE = flags.DEFINE_bool("event_trace", False, "Provide a full event trace")
66+
_MEASURE_ALLOC = flags.DEFINE_bool(
67+
"measure_alloc", False, "Measure how much of nconmax, njmax is used."
68+
)
6569

6670

6771
def _main(argv: Sequence[str]):
@@ -93,9 +97,10 @@ def _main(argv: Sequence[str]):
9397
print(
9498
f"Model nbody: {m.nbody} nv: {m.nv} ngeom: {m.ngeom} is_sparse: {_IS_SPARSE.value} solver: {_SOLVER.value}"
9599
)
100+
print(f"Params nconmax: {_NCONMAX.value} njmax: {_NJMAX.value}")
96101
print(f"Data ncon: {d.ncon} nefc: {d.nefc} keyframe: {_KEYFRAME.value}")
97102
print(f"Rolling out {_NSTEP.value} steps at dt = {m.opt.timestep:.3f}...")
98-
jit_time, run_time, trace, steps = mjwarp.benchmark(
103+
jit_time, run_time, trace, steps, ncon, nefc = mjwarp.benchmark(
99104
mjwarp.__dict__[_FUNCTION.value],
100105
m,
101106
d,
@@ -107,6 +112,7 @@ def _main(argv: Sequence[str]):
107112
_NCONMAX.value,
108113
_NJMAX.value,
109114
_EVENT_TRACE.value,
115+
_MEASURE_ALLOC.value,
110116
)
111117

112118
name = argv[0]
@@ -136,6 +142,38 @@ def _print_trace(trace, indent):
136142
_print_trace(sub_trace, indent + 1)
137143

138144
_print_trace(trace, 0)
145+
if ncon and nefc:
146+
num_buckets = 10
147+
idx = 0
148+
ncon_matrix, nefc_matrix = [], []
149+
for i in range(num_buckets):
150+
size = _NSTEP.value // num_buckets + (i < (_NSTEP.value % num_buckets))
151+
ncon_arr = np.array(ncon[idx : idx + size])
152+
nefc_arr = np.array(nefc[idx : idx + size])
153+
ncon_matrix.append(
154+
[np.mean(ncon_arr), np.std(ncon_arr), np.min(ncon_arr), np.max(ncon_arr)]
155+
)
156+
nefc_matrix.append(
157+
[np.mean(nefc_arr), np.std(nefc_arr), np.min(nefc_arr), np.max(nefc_arr)]
158+
)
159+
idx += size
160+
161+
def _print_table(matrix, headers):
162+
num_cols = len(headers)
163+
col_widths = [
164+
max(len(f"{row[i]:g}") for row in matrix) for i in range(num_cols)
165+
]
166+
col_widths = [max(col_widths[i], len(headers[i])) for i in range(num_cols)]
167+
168+
print(" ".join(f"{headers[i]:<{col_widths[i]}}" for i in range(num_cols)))
169+
print("-" * sum(col_widths) + "--" * 3) # Separator line
170+
for row in matrix:
171+
print(" ".join(f"{row[i]:{col_widths[i]}g}" for i in range(num_cols)))
172+
173+
print("\nncon alloc:\n")
174+
_print_table(ncon_matrix, ("mean", "std", "min", "max"))
175+
print("\nnefc alloc:\n")
176+
_print_table(nefc_matrix, ("mean", "std", "min", "max"))
139177

140178
elif _OUTPUT.value == "tsv":
141179
name = name.split("/")[-1].replace("testspeed_", "")

0 commit comments

Comments
 (0)