@@ -403,74 +403,68 @@ def factor_m(m: Model, d: Data):
403403def rne (m : Model , d : Data ):
404404 """Computes inverse dynamics using Newton-Euler algorithm."""
405405
406- cacc = wp .zeros (shape = (d .nworld , m .nbody ), dtype = wp .spatial_vector )
407- cfrc = wp .zeros (shape = (d .nworld , m .nbody ), dtype = wp .spatial_vector )
408-
409406 @kernel
410- def cacc_gravity (m : Model , cacc : wp . array ( dtype = wp . spatial_vector , ndim = 2 ) ):
407+ def cacc_gravity (m : Model , d : Data ):
411408 worldid = wp .tid ()
412- cacc [worldid , 0 ] = wp .spatial_vector (wp .vec3 (0.0 ), - m .opt .gravity )
409+ d . rne_cacc [worldid , 0 ] = wp .spatial_vector (wp .vec3 (0.0 ), - m .opt .gravity )
413410
414411 @kernel
415412 def cacc_level (
416413 m : Model ,
417414 d : Data ,
418- cacc : wp .array (dtype = wp .spatial_vector , ndim = 2 ),
419415 leveladr : int ,
420416 ):
421417 worldid , nodeid = wp .tid ()
422418 bodyid = m .body_tree [leveladr + nodeid ]
423419 dofnum = m .body_dofnum [bodyid ]
424420 pid = m .body_parentid [bodyid ]
425421 dofadr = m .body_dofadr [bodyid ]
426- local_cacc = cacc [worldid , pid ]
422+ local_cacc = d . rne_cacc [worldid , pid ]
427423 for i in range (dofnum ):
428424 local_cacc += d .cdof_dot [worldid , dofadr + i ] * d .qvel [worldid , dofadr + i ]
429- cacc [worldid , bodyid ] = local_cacc
425+ d . rne_cacc [worldid , bodyid ] = local_cacc
430426
431427 @kernel
432- def frc_fn (
433- d : Data ,
434- cfrc : wp .array (dtype = wp .spatial_vector , ndim = 2 ),
435- cacc : wp .array (dtype = wp .spatial_vector , ndim = 2 ),
436- ):
428+ def frc_fn (d : Data ):
437429 worldid , bodyid = wp .tid ()
438- frc = math .inert_vec (d .cinert [worldid , bodyid ], cacc [worldid , bodyid ])
430+ frc = math .inert_vec (d .cinert [worldid , bodyid ], d . rne_cacc [worldid , bodyid ])
439431 frc += math .motion_cross_force (
440432 d .cvel [worldid , bodyid ],
441433 math .inert_vec (d .cinert [worldid , bodyid ], d .cvel [worldid , bodyid ]),
442434 )
443- cfrc [worldid , bodyid ] + = frc
435+ d . rne_cfrc [worldid , bodyid ] = frc
444436
445437 @kernel
446- def cfrc_fn (m : Model , cfrc : wp . array ( dtype = wp . spatial_vector , ndim = 2 ) , leveladr : int ):
438+ def cfrc_fn (m : Model , d : Data , leveladr : int ):
447439 worldid , nodeid = wp .tid ()
448440 bodyid = m .body_tree [leveladr + nodeid ]
449441 pid = m .body_parentid [bodyid ]
450- wp .atomic_add (cfrc [worldid ], pid , cfrc [worldid , bodyid ])
442+ wp .atomic_add (d . rne_cfrc [worldid ], pid , d . rne_cfrc [worldid , bodyid ])
451443
452444 @kernel
453- def qfrc_bias (m : Model , d : Data , cfrc : wp . array ( dtype = wp . spatial_vector , ndim = 2 ) ):
445+ def qfrc_bias (m : Model , d : Data ):
454446 worldid , dofid = wp .tid ()
455447 bodyid = m .dof_bodyid [dofid ]
456- d .qfrc_bias [worldid , dofid ] = wp .dot (d .cdof [worldid , dofid ], cfrc [worldid , bodyid ])
448+ d .qfrc_bias [worldid , dofid ] = wp .dot (
449+ d .cdof [worldid , dofid ], d .rne_cfrc [worldid , bodyid ]
450+ )
457451
458- wp .launch (cacc_gravity , dim = [d .nworld ], inputs = [m , cacc ])
452+ wp .launch (cacc_gravity , dim = [d .nworld ], inputs = [m , d ])
459453
460454 body_treeadr = m .body_treeadr .numpy ()
461455 for i in range (len (body_treeadr )):
462456 beg = body_treeadr [i ]
463457 end = m .nbody if i == len (body_treeadr ) - 1 else body_treeadr [i + 1 ]
464- wp .launch (cacc_level , dim = (d .nworld , end - beg ), inputs = [m , d , cacc , beg ])
458+ wp .launch (cacc_level , dim = (d .nworld , end - beg ), inputs = [m , d , beg ])
465459
466- wp .launch (frc_fn , dim = [d .nworld , m .nbody ], inputs = [d , cfrc , cacc ])
460+ wp .launch (frc_fn , dim = [d .nworld , m .nbody ], inputs = [d ])
467461
468462 for i in reversed (range (len (body_treeadr ))):
469463 beg = body_treeadr [i ]
470464 end = m .nbody if i == len (body_treeadr ) - 1 else body_treeadr [i + 1 ]
471- wp .launch (cfrc_fn , dim = [d .nworld , end - beg ], inputs = [m , cfrc , beg ])
465+ wp .launch (cfrc_fn , dim = [d .nworld , end - beg ], inputs = [m , d , beg ])
472466
473- wp .launch (qfrc_bias , dim = [d .nworld , m .nv ], inputs = [m , d , cfrc ])
467+ wp .launch (qfrc_bias , dim = [d .nworld , m .nv ], inputs = [m , d ])
474468
475469
476470@event_scope
0 commit comments