@@ -905,6 +905,7 @@ def substep(self, f):
905905 rigid_global_info = self ._rigid_global_info ,
906906 static_rigid_sim_config = self ._static_rigid_sim_config ,
907907 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
908+ errno = self ._errno ,
908909 )
909910 if self ._requires_grad :
910911 kernel_save_adjoint_cache (
@@ -916,27 +917,27 @@ def substep(self, f):
916917 )
917918
918919 def check_errno (self ):
919- # Note that errno must be evaluated BEFORE match because otherwise it will be evaluated for each case...
920- # See official documentation: https://docs.python.org/3.10/reference/compound_stmts.html#overview
921920 if gs .use_zerocopy :
922921 errno = ti_to_torch (self ._errno , copy = None ).item ()
923922 else :
924923 errno = kernel_get_errno (self ._errno )
925- match errno :
926- case 1 :
927- max_collision_pairs_broad = self .collider ._collider_info .max_collision_pairs_broad [None ]
928- gs .raise_exception (
929- f"Exceeding max number of broad phase candidate contact pairs ({ max_collision_pairs_broad } ). "
930- f"Please increase the value of RigidSolver's option 'multiplier_collision_broad_phase'."
931- )
932- case 2 :
933- max_contact_pairs = self .collider ._collider_info .max_contact_pairs [None ]
934- gs .raise_exception (
935- f"Exceeding max number of contact pairs ({ max_contact_pairs } ). Please increase the value of "
936- "RigidSolver's option 'max_collision_pairs'."
937- )
938- case 3 :
939- gs .raise_exception ("Invalid accelerations causing 'nan'. Please decrease Rigid simulation timestep." )
924+
925+ if errno & 0b00000000000000000000000000000001 :
926+ max_collision_pairs_broad = self .collider ._collider_info .max_collision_pairs_broad [None ]
927+ gs .raise_exception (
928+ f"Exceeding max number of broad phase candidate contact pairs ({ max_collision_pairs_broad } ). "
929+ f"Please increase the value of RigidSolver's option 'multiplier_collision_broad_phase'."
930+ )
931+ if errno & 0b00000000000000000000000000000010 :
932+ max_contact_pairs = self .collider ._collider_info .max_contact_pairs [None ]
933+ gs .raise_exception (
934+ f"Exceeding max number of contact pairs ({ max_contact_pairs } ). Please increase the value of "
935+ "RigidSolver's option 'max_collision_pairs'."
936+ )
937+ if errno & 0b00000000000000000000000000000100 :
938+ gs .raise_exception ("Invalid constraint forces causing 'nan'. Please decrease Rigid simulation timestep." )
939+ if errno & 0b00000000000000000000000000001000 :
940+ gs .raise_exception ("Invalid accelerations causing 'nan'. Please decrease Rigid simulation timestep." )
940941
941942 def _kernel_detect_collision (self ):
942943 self .collider .reset (cache_only = True )
@@ -1241,6 +1242,7 @@ def substep_pre_coupling_grad(self, f):
12411242 rigid_global_info = self ._rigid_global_info ,
12421243 static_rigid_sim_config = self ._static_rigid_sim_config ,
12431244 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
1245+ errno = self ._errno ,
12441246 )
12451247
12461248 # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel,
@@ -1336,6 +1338,7 @@ def substep_post_coupling(self, f):
13361338 rigid_global_info = self ._rigid_global_info ,
13371339 static_rigid_sim_config = self ._static_rigid_sim_config ,
13381340 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
1341+ errno = self ._errno ,
13391342 )
13401343 elif isinstance (self .sim .coupler , IPCCoupler ):
13411344 # For IPCCoupler, perform full rigid body computation in post-coupling phase
@@ -4399,6 +4402,7 @@ def kernel_step_2(
43994402 rigid_global_info : array_class .RigidGlobalInfo ,
44004403 static_rigid_sim_config : ti .template (),
44014404 contact_island_state : array_class .ContactIslandState ,
4405+ errno : array_class .V_ANNOTATION ,
44024406):
44034407 # Position, Velocity and Acceleration data must be consistent when computing links acceleration, otherwise it
44044408 # would not corresponds to anyting physical. There is no other way than doing this right before integration,
@@ -4457,6 +4461,7 @@ def kernel_step_2(
44574461 dofs_state = dofs_state ,
44584462 rigid_global_info = rigid_global_info ,
44594463 static_rigid_sim_config = static_rigid_sim_config ,
4464+ errno = errno ,
44604465 )
44614466
44624467 if ti .static (not static_rigid_sim_config .enable_mujoco_compatibility ):
@@ -5511,11 +5516,9 @@ def kernel_update_vgeoms(
55115516 _B = links_state .pos .shape [1 ]
55125517 ti .loop_config (serialize = ti .static (static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL ))
55135518 for i_g , i_b in ti .ndrange (n_vgeoms , _B ):
5519+ i_l = vgeoms_info .link_idx [i_g ]
55145520 vgeoms_state .pos [i_g , i_b ], vgeoms_state .quat [i_g , i_b ] = gu .ti_transform_pos_quat_by_trans_quat (
5515- vgeoms_info .pos [i_g ],
5516- vgeoms_info .quat [i_g ],
5517- links_state .pos [vgeoms_info .link_idx [i_g ], i_b ],
5518- links_state .quat [vgeoms_info .link_idx [i_g ], i_b ],
5521+ vgeoms_info .pos [i_g ], vgeoms_info .quat [i_g ], links_state .pos [i_l , i_b ], links_state .quat [i_l , i_b ]
55195522 )
55205523
55215524
@@ -6406,14 +6409,9 @@ def func_integrate(
64066409 else i_0
64076410 )
64086411
6409- # Prevent nan propagation
6410- is_valid = True
6411- if ti .static (not BW ):
6412- is_valid = ~ ti .math .isnan (dofs_state .acc [i_d , i_b ])
6413- if is_valid :
6414- dofs_state .vel_next [i_d , i_b ] = (
6415- dofs_state .vel [i_d , i_b ] + dofs_state .acc [i_d , i_b ] * rigid_global_info .substep_dt [None ]
6416- )
6412+ dofs_state .vel_next [i_d , i_b ] = (
6413+ dofs_state .vel [i_d , i_b ] + dofs_state .acc [i_d , i_b ] * rigid_global_info .substep_dt [None ]
6414+ )
64176415
64186416 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
64196417 for i_0 , i_b in (
@@ -6520,14 +6518,31 @@ def func_copy_next_to_curr(
65206518 dofs_state : array_class .DofsState ,
65216519 rigid_global_info : array_class .RigidGlobalInfo ,
65226520 static_rigid_sim_config : ti .template (),
6521+ errno : array_class .V_ANNOTATION ,
65236522):
6524- ti . loop_config ( serialize = static_rigid_sim_config . para_level < gs . PARA_LEVEL . ALL )
6525- for I in ti . grouped ( ti . ndrange ( * dofs_state .vel .shape )):
6526- dofs_state .vel [ I ] = dofs_state . vel_next [ I ]
6523+ n_qs = rigid_global_info . qpos . shape [ 0 ]
6524+ n_dofs = dofs_state .vel .shape [ 0 ]
6525+ _B = dofs_state .vel . shape [ 1 ]
65276526
65286527 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
6529- for I in ti .grouped (ti .ndrange (* rigid_global_info .qpos .shape )):
6530- rigid_global_info .qpos [I ] = rigid_global_info .qpos_next [I ]
6528+ for i_b in range (_B ):
6529+ # Prevent nan propagation
6530+ is_valid = True
6531+ for i_d in range (n_dofs ):
6532+ e = dofs_state .vel_next [i_d , i_b ]
6533+ is_valid &= not ti .math .isnan (e )
6534+ for i_q in range (n_qs ):
6535+ e = rigid_global_info .qpos_next [i_q , i_b ]
6536+ is_valid &= not ti .math .isnan (e )
6537+
6538+ if is_valid :
6539+ for i_d in range (n_dofs ):
6540+ dofs_state .vel [i_d , i_b ] = dofs_state .vel_next [i_d , i_b ]
6541+
6542+ for i_q in range (n_qs ):
6543+ rigid_global_info .qpos [i_q , i_b ] = rigid_global_info .qpos_next [i_q , i_b ]
6544+ else :
6545+ errno [None ] = errno [None ] | 0b00000000000000000000000000001000
65316546
65326547
65336548@ti .func
@@ -6923,8 +6938,9 @@ def kernel_update_geoms_render_T(
69236938 geom_T = gu .ti_trans_quat_to_T (
69246939 geoms_state .pos [i_g , i_b ] + rigid_global_info .envs_offset [i_b ], geoms_state .quat [i_g , i_b ], EPS
69256940 )
6926- for J in ti .static (ti .grouped (ti .ndrange (4 , 4 ))):
6927- geoms_render_T [(i_g , i_b , * J )] = ti .cast (geom_T [J ], ti .float32 )
6941+ if (ti .abs (geom_T ) < 1e20 ).all ():
6942+ for J in ti .static (ti .grouped (ti .ndrange (4 , 4 ))):
6943+ geoms_render_T [(i_g , i_b , * J )] = ti .cast (geom_T [J ], ti .float32 )
69286944
69296945
69306946@ti .kernel (fastcache = gs .use_fastcache )
@@ -6945,8 +6961,9 @@ def kernel_update_vgeoms_render_T(
69456961 geom_T = gu .ti_trans_quat_to_T (
69466962 vgeoms_state .pos [i_g , i_b ] + rigid_global_info .envs_offset [i_b ], vgeoms_state .quat [i_g , i_b ], EPS
69476963 )
6948- for J in ti .static (ti .grouped (ti .ndrange (4 , 4 ))):
6949- vgeoms_render_T [(i_g , i_b , * J )] = ti .cast (geom_T [J ], ti .float32 )
6964+ if (ti .abs (geom_T ) < 1e20 ).all ():
6965+ for J in ti .static (ti .grouped (ti .ndrange (4 , 4 ))):
6966+ vgeoms_render_T [(i_g , i_b , * J )] = ti .cast (geom_T [J ], ti .float32 )
69506967
69516968
69526969@ti .kernel (fastcache = gs .use_fastcache )
0 commit comments