@@ -233,8 +233,6 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
233233 'collision_pair' : d ._impl .collision_pair .shape ,
234234 'collision_pairid' : d ._impl .collision_pairid .shape ,
235235 'collision_worldid' : d ._impl .collision_worldid .shape ,
236- 'geom_xmat' : d .geom_xmat .shape ,
237- 'geom_xpos' : d .geom_xpos .shape ,
238236 'nacon' : d ._impl .nacon .shape ,
239237 'ncollision' : d ._impl .ncollision .shape ,
240238 'contact__dim' : d ._impl .contact__dim .shape ,
@@ -253,15 +251,13 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
253251 }
254252 jf = ffi .jax_callable_variadic_tuple (
255253 _collision_shim ,
256- num_outputs = 20 ,
254+ num_outputs = 18 ,
257255 output_dims = output_dims ,
258256 vmap_method = None ,
259257 in_out_argnames = {
260258 'collision_pair' ,
261259 'collision_pairid' ,
262260 'collision_worldid' ,
263- 'geom_xmat' ,
264- 'geom_xpos' ,
265261 'nacon' ,
266262 'ncollision' ,
267263 'contact__dim' ,
@@ -278,6 +274,27 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
278274 'contact__type' ,
279275 'contact__worldid' ,
280276 },
277+ stage_in_argnames = {
278+ 'geom_aabb' ,
279+ 'geom_friction' ,
280+ 'geom_gap' ,
281+ 'geom_margin' ,
282+ 'geom_rbound' ,
283+ 'geom_size' ,
284+ 'geom_solimp' ,
285+ 'geom_solmix' ,
286+ 'geom_solref' ,
287+ 'geom_xmat' ,
288+ 'geom_xpos' ,
289+ 'hfield_data' ,
290+ 'pair_friction' ,
291+ 'pair_gap' ,
292+ 'pair_margin' ,
293+ 'pair_solimp' ,
294+ 'pair_solref' ,
295+ 'pair_solreffriction' ,
296+ },
297+ stage_out_argnames = {},
281298 graph_mode = m .opt ._impl .graph_mode ,
282299 )
283300 out = jf (
@@ -373,23 +390,21 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
373390 '_impl.collision_pair' : out [0 ],
374391 '_impl.collision_pairid' : out [1 ],
375392 '_impl.collision_worldid' : out [2 ],
376- 'geom_xmat' : out [3 ],
377- 'geom_xpos' : out [4 ],
378- '_impl.nacon' : out [5 ],
379- '_impl.ncollision' : out [6 ],
380- '_impl.contact__dim' : out [7 ],
381- '_impl.contact__dist' : out [8 ],
382- '_impl.contact__frame' : out [9 ],
383- '_impl.contact__friction' : out [10 ],
384- '_impl.contact__geom' : out [11 ],
385- '_impl.contact__geomcollisionid' : out [12 ],
386- '_impl.contact__includemargin' : out [13 ],
387- '_impl.contact__pos' : out [14 ],
388- '_impl.contact__solimp' : out [15 ],
389- '_impl.contact__solref' : out [16 ],
390- '_impl.contact__solreffriction' : out [17 ],
391- '_impl.contact__type' : out [18 ],
392- '_impl.contact__worldid' : out [19 ],
393+ '_impl.nacon' : out [3 ],
394+ '_impl.ncollision' : out [4 ],
395+ '_impl.contact__dim' : out [5 ],
396+ '_impl.contact__dist' : out [6 ],
397+ '_impl.contact__frame' : out [7 ],
398+ '_impl.contact__friction' : out [8 ],
399+ '_impl.contact__geom' : out [9 ],
400+ '_impl.contact__geomcollisionid' : out [10 ],
401+ '_impl.contact__includemargin' : out [11 ],
402+ '_impl.contact__pos' : out [12 ],
403+ '_impl.contact__solimp' : out [13 ],
404+ '_impl.contact__solref' : out [14 ],
405+ '_impl.contact__solreffriction' : out [15 ],
406+ '_impl.contact__type' : out [16 ],
407+ '_impl.contact__worldid' : out [17 ],
393408 })
394409 return d
395410
0 commit comments