4949_IS_SPARSE = flags .DEFINE_bool (
5050 "is_sparse" , True , "if model should create sparse mass matrices"
5151)
52- _NEFC_TOTAL = flags .DEFINE_integer (
53- "nefc_total" , 0 , "total number of efc for batch of worlds"
52+ _NCONMAX = flags .DEFINE_integer (
53+ "nconmax" , - 1 , "Maximum number of contacts in a batch physics step."
54+ )
55+ _NJMAX = flags .DEFINE_integer (
56+ "njmax" , - 1 , "Maximum number of constraints in a batch physics step."
5457)
5558_OUTPUT = flags .DEFINE_enum (
5659 "output" , "text" , ["text" , "tsv" ], "format to print results"
5760)
61+ _CLEAR_KERNEL_CACHE = flags .DEFINE_bool (
62+ "clear_kernel_cache" , False , "Clear kernel cache (to calculate full JIT time)"
63+ )
5864
5965
6066def _main (argv : Sequence [str ]):
@@ -74,20 +80,32 @@ def _main(argv: Sequence[str]):
7480 else :
7581 m .opt .jacobian = mujoco .mjtJacobian .mjJAC_DENSE
7682
83+ d = mujoco .MjData (m )
84+ if m .nkey > 0 :
85+ mujoco .mj_resetDataKeyframe (m , d , 0 )
86+ # populate some constraints
87+ mujoco .mj_forward (m , d )
88+
89+ if _CLEAR_KERNEL_CACHE .value :
90+ wp .clear_kernel_cache ()
91+
7792 print (
7893 f"Model nbody: { m .nbody } nv: { m .nv } ngeom: { m .ngeom } is_sparse: { _IS_SPARSE .value } "
7994 )
95+ print (f"Data ncon: { d .ncon } nefc: { d .nefc } " )
8096 print (f"Rolling out { _NSTEP .value } steps at dt = { m .opt .timestep :.3f} ..." )
8197 jit_time , run_time , steps = mjx .benchmark (
8298 mjx .__dict__ [_FUNCTION .value ],
8399 m ,
100+ d ,
84101 _NSTEP .value ,
85102 _BATCH_SIZE .value ,
86103 _UNROLL .value ,
87104 _SOLVER .value ,
88105 _ITERATIONS .value ,
89106 _LS_ITERATIONS .value ,
90- _NEFC_TOTAL .value ,
107+ _NCONMAX .value ,
108+ _NJMAX .value ,
91109 )
92110
93111 name = argv [0 ]
@@ -99,7 +117,7 @@ def _main(argv: Sequence[str]):
99117 Total simulation time: { run_time :.2f} s
100118 Total steps per second: { steps / run_time :,.0f}
101119 Total realtime factor: { steps * m .opt .timestep / run_time :,.2f} x
102- Total time per step: { 1e6 * run_time / steps :.2f} µs """ )
120+ Total time per step: { 1e9 * run_time / steps :.2f} ns """ )
103121 elif _OUTPUT .value == "tsv" :
104122 name = name .split ("/" )[- 1 ].replace ("testspeed_" , "" )
105123 print (f"{ name } \t jit: { jit_time :.2f} s\t steps/second: { steps / run_time :.0f} " )
0 commit comments