1919from typing import Sequence
2020
2121import mujoco
22+ import numpy as np
2223import warp as wp
2324from absl import app
2425from absl import flags
6263 "clear_kernel_cache" , False , "Clear kernel cache (to calculate full JIT time)"
6364)
6465_EVENT_TRACE = flags .DEFINE_bool ("event_trace" , False , "Provide a full event trace" )
66+ _MEASURE_ALLOC = flags .DEFINE_bool (
67+ "measure_alloc" , False , "Measure how much of nconmax, njmax is used."
68+ )
6569
6670
6771def _main (argv : Sequence [str ]):
@@ -93,9 +97,10 @@ def _main(argv: Sequence[str]):
9397 print (
9498 f"Model nbody: { m .nbody } nv: { m .nv } ngeom: { m .ngeom } is_sparse: { _IS_SPARSE .value } solver: { _SOLVER .value } "
9599 )
100+ print (f"Params nconmax: { _NCONMAX .value } njmax: { _NJMAX .value } " )
96101 print (f"Data ncon: { d .ncon } nefc: { d .nefc } keyframe: { _KEYFRAME .value } " )
97102 print (f"Rolling out { _NSTEP .value } steps at dt = { m .opt .timestep :.3f} ..." )
98- jit_time , run_time , trace , steps = mjwarp .benchmark (
103+ jit_time , run_time , trace , steps , ncon , nefc = mjwarp .benchmark (
99104 mjwarp .__dict__ [_FUNCTION .value ],
100105 m ,
101106 d ,
@@ -107,6 +112,7 @@ def _main(argv: Sequence[str]):
107112 _NCONMAX .value ,
108113 _NJMAX .value ,
109114 _EVENT_TRACE .value ,
115+ _MEASURE_ALLOC .value ,
110116 )
111117
112118 name = argv [0 ]
@@ -136,6 +142,38 @@ def _print_trace(trace, indent):
136142 _print_trace (sub_trace , indent + 1 )
137143
138144 _print_trace (trace , 0 )
145+ if ncon and nefc :
146+ num_buckets = 10
147+ idx = 0
148+ ncon_matrix , nefc_matrix = [], []
149+ for i in range (num_buckets ):
150+ size = _NSTEP .value // num_buckets + (i < (_NSTEP .value % num_buckets ))
151+ ncon_arr = np .array (ncon [idx : idx + size ])
152+ nefc_arr = np .array (nefc [idx : idx + size ])
153+ ncon_matrix .append (
154+ [np .mean (ncon_arr ), np .std (ncon_arr ), np .min (ncon_arr ), np .max (ncon_arr )]
155+ )
156+ nefc_matrix .append (
157+ [np .mean (nefc_arr ), np .std (nefc_arr ), np .min (nefc_arr ), np .max (nefc_arr )]
158+ )
159+ idx += size
160+
161+ def _print_table (matrix , headers ):
162+ num_cols = len (headers )
163+ col_widths = [
164+ max (len (f"{ row [i ]:g} " ) for row in matrix ) for i in range (num_cols )
165+ ]
166+ col_widths = [max (col_widths [i ], len (headers [i ])) for i in range (num_cols )]
167+
168+ print (" " .join (f"{ headers [i ]:<{col_widths [i ]}} " for i in range (num_cols )))
169+ print ("-" * sum (col_widths ) + "--" * 3 ) # Separator line
170+ for row in matrix :
171+ print (" " .join (f"{ row [i ]:{col_widths [i ]}g} " for i in range (num_cols )))
172+
173+ print ("\n ncon alloc:\n " )
174+ _print_table (ncon_matrix , ("mean" , "std" , "min" , "max" ))
175+ print ("\n nefc alloc:\n " )
176+ _print_table (nefc_matrix , ("mean" , "std" , "min" , "max" ))
139177
140178 elif _OUTPUT .value == "tsv" :
141179 name = name .split ("/" )[- 1 ].replace ("testspeed_" , "" )
0 commit comments