Skip to content

Commit 83f4fed

Browse files
authored
[BUG FIX] Prevent nan to propagate in position and raise exception. (#2033)
* Clean up torch version error message. * Change default logging level to INFO systematically in unit tests. * Detect nan in constraint solver, stop state integration and raise exception.
1 parent d74f02c commit 83f4fed

File tree

5 files changed

+142
-123
lines changed

5 files changed

+142
-123
lines changed

genesis/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717

1818
try:
1919
import torch
20-
21-
if tuple(map(int, torch.__version__.split(".")[:2])) < (2, 8):
22-
raise ImportError(
23-
"'torch<2.8.0' is not supported. Please update pytorch manually: https://pytorch.org/get-started/locally/"
24-
)
2520
except ImportError as e:
2621
raise ImportError(
2722
"'torch' module not available. Please install pytorch manually: https://pytorch.org/get-started/locally/"
2823
) from e
24+
if tuple(map(int, torch.__version__.split(".")[:2])) < (2, 8):
25+
raise ImportError(
26+
"'torch<2.8.0' is not supported. Please update pytorch manually: https://pytorch.org/get-started/locally/"
27+
)
28+
2929
import numpy as np
3030

3131
from .constants import GS_ARCH, TI_ARCH

genesis/engine/solvers/mpm_solver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ def g2p(
455455
@ti.kernel
456456
def _is_state_valid(self, f: ti.i32) -> ti.i32:
457457
is_success = True
458-
for i_p, i_b in ti.ndrange(self._n_particles, self._B):
459-
if ti.math.isnan(self.particles[f, i_p, i_b].pos).any():
458+
for i_p, i_b, i_3 in ti.ndrange(self._n_particles, self._B, 3):
459+
if ti.math.isnan(self.particles[f, i_p, i_b].pos[i_3]):
460460
is_success = False
461461
return is_success
462462

@@ -506,6 +506,7 @@ def substep_post_coupling(self, f):
506506
self.sim.coupler.rigid_solver.links_state,
507507
self.sim.coupler.rigid_solver._rigid_global_info,
508508
)
509+
# FIXME: Use existing errno mechanism for this.
509510
if not self._is_state_valid(f):
510511
gs.raise_exception(
511512
"NaN detected in MPM states. Try reducing the time step size or adjusting simulation parameters."

genesis/engine/solvers/rigid/constraint_solver_decomp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def resolve(self):
196196
dofs_state=solver.dofs_state,
197197
constraint_state=self.constraint_state,
198198
static_rigid_sim_config=solver._static_rigid_sim_config,
199+
errno=solver._errno,
199200
)
200201

201202
if solver._options.noslip_iterations > 0:
@@ -1403,6 +1404,7 @@ def func_update_qacc(
14031404
dofs_state: array_class.DofsState,
14041405
constraint_state: array_class.ConstraintState,
14051406
static_rigid_sim_config: ti.template(),
1407+
errno: array_class.V_ANNOTATION,
14061408
):
14071409
n_dofs = dofs_state.acc.shape[0]
14081410
_B = dofs_state.acc.shape[1]
@@ -1412,6 +1414,8 @@ def func_update_qacc(
14121414
dofs_state.qf_constraint[i_d, i_b] = constraint_state.qfrc_constraint[i_d, i_b]
14131415
dofs_state.force[i_d, i_b] = dofs_state.qf_smooth[i_d, i_b] + constraint_state.qfrc_constraint[i_d, i_b]
14141416
constraint_state.qacc_ws[i_d, i_b] = constraint_state.qacc[i_d, i_b]
1417+
if ti.math.isnan(constraint_state.qacc[i_d, i_b]):
1418+
errno[None] = 3
14151419

14161420

14171421
@ti.kernel(fastcache=gs.use_fastcache)

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 128 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,8 @@ def check_errno(self):
875875
f"Exceeding max number of contact pairs ({max_contact_pairs}). Please increase the value of "
876876
"RigidSolver's option 'max_collision_pairs'."
877877
)
878+
case 3:
879+
gs.raise_exception("Invalid accelerations causing 'nan'. Please decrease Rigid simulation timestep.")
878880

879881
def _kernel_detect_collision(self):
880882
self.collider.reset(cache_only=True)
@@ -5534,50 +5536,135 @@ def func_integrate(
55345536
_B = dofs_state.ctrl_mode.shape[1]
55355537
n_dofs = dofs_state.ctrl_mode.shape[0]
55365538
n_links = links_info.root_idx.shape[0]
5539+
55375540
if ti.static(static_rigid_sim_config.use_hibernation):
55385541
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
55395542
for i_b in range(_B):
5543+
is_valid = True
55405544
for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]):
55415545
i_d = rigid_global_info.awake_dofs[i_d_, i_b]
5542-
dofs_state.vel[i_d, i_b] = (
5543-
dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None]
5544-
)
5546+
if ti.math.isnan(dofs_state.acc[i_d, i_b]):
5547+
is_valid = False
5548+
5549+
if is_valid:
5550+
for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]):
5551+
i_d = rigid_global_info.awake_dofs[i_d_, i_b]
5552+
dofs_state.vel[i_d, i_b] = (
5553+
dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None]
5554+
)
5555+
5556+
for i_l_ in range(rigid_global_info.n_awake_links[i_b]):
5557+
i_l = rigid_global_info.awake_links[i_l_, i_b]
5558+
I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
55455559

5560+
for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]):
5561+
I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
5562+
dof_start = joints_info.dof_start[I_j]
5563+
q_start = joints_info.q_start[I_j]
5564+
q_end = joints_info.q_end[I_j]
5565+
5566+
joint_type = joints_info.type[I_j]
5567+
5568+
if joint_type == gs.JOINT_TYPE.FREE:
5569+
rot = ti.Vector(
5570+
[
5571+
rigid_global_info.qpos[q_start + 3, i_b],
5572+
rigid_global_info.qpos[q_start + 4, i_b],
5573+
rigid_global_info.qpos[q_start + 5, i_b],
5574+
rigid_global_info.qpos[q_start + 6, i_b],
5575+
]
5576+
)
5577+
ang = (
5578+
ti.Vector(
5579+
[
5580+
dofs_state.vel[dof_start + 3, i_b],
5581+
dofs_state.vel[dof_start + 4, i_b],
5582+
dofs_state.vel[dof_start + 5, i_b],
5583+
]
5584+
)
5585+
* rigid_global_info.substep_dt[None]
5586+
)
5587+
qrot = gu.ti_rotvec_to_quat(ang, EPS)
5588+
rot = gu.ti_transform_quat_by_quat(qrot, rot)
5589+
pos = ti.Vector(
5590+
[
5591+
rigid_global_info.qpos[q_start, i_b],
5592+
rigid_global_info.qpos[q_start + 1, i_b],
5593+
rigid_global_info.qpos[q_start + 2, i_b],
5594+
]
5595+
)
5596+
vel = ti.Vector(
5597+
[
5598+
dofs_state.vel[dof_start, i_b],
5599+
dofs_state.vel[dof_start + 1, i_b],
5600+
dofs_state.vel[dof_start + 2, i_b],
5601+
]
5602+
)
5603+
pos = pos + vel * rigid_global_info.substep_dt[None]
5604+
for j in ti.static(range(3)):
5605+
rigid_global_info.qpos[q_start + j, i_b] = pos[j]
5606+
for j in ti.static(range(4)):
5607+
rigid_global_info.qpos[q_start + j + 3, i_b] = rot[j]
5608+
elif joint_type == gs.JOINT_TYPE.FIXED:
5609+
pass
5610+
elif joint_type == gs.JOINT_TYPE.SPHERICAL:
5611+
rot = ti.Vector(
5612+
[
5613+
rigid_global_info.qpos[q_start + 0, i_b],
5614+
rigid_global_info.qpos[q_start + 1, i_b],
5615+
rigid_global_info.qpos[q_start + 2, i_b],
5616+
rigid_global_info.qpos[q_start + 3, i_b],
5617+
]
5618+
)
5619+
ang = (
5620+
ti.Vector(
5621+
[
5622+
dofs_state.vel[dof_start + 3, i_b],
5623+
dofs_state.vel[dof_start + 4, i_b],
5624+
dofs_state.vel[dof_start + 5, i_b],
5625+
]
5626+
)
5627+
* rigid_global_info.substep_dt[None]
5628+
)
5629+
qrot = gu.ti_rotvec_to_quat(ang, EPS)
5630+
rot = gu.ti_transform_quat_by_quat(qrot, rot)
5631+
for j in ti.static(range(4)):
5632+
rigid_global_info.qpos[q_start + j, i_b] = rot[j]
5633+
5634+
else:
5635+
for j in range(q_end - q_start):
5636+
rigid_global_info.qpos[q_start + j, i_b] = (
5637+
rigid_global_info.qpos[q_start + j, i_b]
5638+
+ dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None]
5639+
)
5640+
else:
55465641
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
55475642
for i_b in range(_B):
5548-
for i_l_ in range(rigid_global_info.n_awake_links[i_b]):
5549-
i_l = rigid_global_info.awake_links[i_l_, i_b]
5550-
I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
5643+
is_valid = True
5644+
for i_d in range(n_dofs):
5645+
if ti.math.isnan(dofs_state.acc[i_d, i_b]):
5646+
is_valid = False
5647+
5648+
if is_valid:
5649+
for i_d in range(n_dofs):
5650+
dofs_state.vel[i_d, i_b] = (
5651+
dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None]
5652+
)
55515653

5552-
for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]):
5553-
I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
5554-
dof_start = joints_info.dof_start[I_j]
5555-
q_start = joints_info.q_start[I_j]
5556-
q_end = joints_info.q_end[I_j]
5654+
for i_l in range(n_links):
5655+
I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
5656+
if links_info.n_dofs[I_l] == 0:
5657+
continue
55575658

5659+
dof_start = links_info.dof_start[I_l]
5660+
q_start = links_info.q_start[I_l]
5661+
q_end = links_info.q_end[I_l]
5662+
5663+
i_j = links_info.joint_start[I_l]
5664+
I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
55585665
joint_type = joints_info.type[I_j]
55595666

55605667
if joint_type == gs.JOINT_TYPE.FREE:
5561-
rot = ti.Vector(
5562-
[
5563-
rigid_global_info.qpos[q_start + 3, i_b],
5564-
rigid_global_info.qpos[q_start + 4, i_b],
5565-
rigid_global_info.qpos[q_start + 5, i_b],
5566-
rigid_global_info.qpos[q_start + 6, i_b],
5567-
]
5568-
)
5569-
ang = (
5570-
ti.Vector(
5571-
[
5572-
dofs_state.vel[dof_start + 3, i_b],
5573-
dofs_state.vel[dof_start + 4, i_b],
5574-
dofs_state.vel[dof_start + 5, i_b],
5575-
]
5576-
)
5577-
* rigid_global_info.substep_dt[None]
5578-
)
5579-
qrot = gu.ti_rotvec_to_quat(ang, EPS)
5580-
rot = gu.ti_transform_quat_by_quat(qrot, rot)
55815668
pos = ti.Vector(
55825669
[
55835670
rigid_global_info.qpos[q_start, i_b],
@@ -5595,111 +5682,37 @@ def func_integrate(
55955682
pos = pos + vel * rigid_global_info.substep_dt[None]
55965683
for j in ti.static(range(3)):
55975684
rigid_global_info.qpos[q_start + j, i_b] = pos[j]
5598-
for j in ti.static(range(4)):
5599-
rigid_global_info.qpos[q_start + j + 3, i_b] = rot[j]
5600-
elif joint_type == gs.JOINT_TYPE.FIXED:
5601-
pass
5602-
elif joint_type == gs.JOINT_TYPE.SPHERICAL:
5685+
if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE:
5686+
rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0
56035687
rot = ti.Vector(
56045688
[
5605-
rigid_global_info.qpos[q_start + 0, i_b],
5606-
rigid_global_info.qpos[q_start + 1, i_b],
5607-
rigid_global_info.qpos[q_start + 2, i_b],
5608-
rigid_global_info.qpos[q_start + 3, i_b],
5689+
rigid_global_info.qpos[q_start + rot_offset + 0, i_b],
5690+
rigid_global_info.qpos[q_start + rot_offset + 1, i_b],
5691+
rigid_global_info.qpos[q_start + rot_offset + 2, i_b],
5692+
rigid_global_info.qpos[q_start + rot_offset + 3, i_b],
56095693
]
56105694
)
56115695
ang = (
56125696
ti.Vector(
56135697
[
5614-
dofs_state.vel[dof_start + 3, i_b],
5615-
dofs_state.vel[dof_start + 4, i_b],
5616-
dofs_state.vel[dof_start + 5, i_b],
5698+
dofs_state.vel[dof_start + rot_offset + 0, i_b],
5699+
dofs_state.vel[dof_start + rot_offset + 1, i_b],
5700+
dofs_state.vel[dof_start + rot_offset + 2, i_b],
56175701
]
56185702
)
56195703
* rigid_global_info.substep_dt[None]
56205704
)
56215705
qrot = gu.ti_rotvec_to_quat(ang, EPS)
56225706
rot = gu.ti_transform_quat_by_quat(qrot, rot)
56235707
for j in ti.static(range(4)):
5624-
rigid_global_info.qpos[q_start + j, i_b] = rot[j]
5625-
5708+
rigid_global_info.qpos[q_start + j + rot_offset, i_b] = rot[j]
56265709
else:
56275710
for j in range(q_end - q_start):
56285711
rigid_global_info.qpos[q_start + j, i_b] = (
56295712
rigid_global_info.qpos[q_start + j, i_b]
56305713
+ dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None]
56315714
)
56325715

5633-
else:
5634-
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
5635-
for i_d, i_b in ti.ndrange(n_dofs, _B):
5636-
dofs_state.vel[i_d, i_b] = (
5637-
dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None]
5638-
)
5639-
5640-
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
5641-
for i_l, i_b in ti.ndrange(n_links, _B):
5642-
I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
5643-
if links_info.n_dofs[I_l] == 0:
5644-
continue
5645-
5646-
dof_start = links_info.dof_start[I_l]
5647-
q_start = links_info.q_start[I_l]
5648-
q_end = links_info.q_end[I_l]
5649-
5650-
i_j = links_info.joint_start[I_l]
5651-
I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
5652-
joint_type = joints_info.type[I_j]
5653-
5654-
if joint_type == gs.JOINT_TYPE.FREE:
5655-
pos = ti.Vector(
5656-
[
5657-
rigid_global_info.qpos[q_start, i_b],
5658-
rigid_global_info.qpos[q_start + 1, i_b],
5659-
rigid_global_info.qpos[q_start + 2, i_b],
5660-
]
5661-
)
5662-
vel = ti.Vector(
5663-
[
5664-
dofs_state.vel[dof_start, i_b],
5665-
dofs_state.vel[dof_start + 1, i_b],
5666-
dofs_state.vel[dof_start + 2, i_b],
5667-
]
5668-
)
5669-
pos = pos + vel * rigid_global_info.substep_dt[None]
5670-
for j in ti.static(range(3)):
5671-
rigid_global_info.qpos[q_start + j, i_b] = pos[j]
5672-
if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE:
5673-
rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0
5674-
rot = ti.Vector(
5675-
[
5676-
rigid_global_info.qpos[q_start + rot_offset + 0, i_b],
5677-
rigid_global_info.qpos[q_start + rot_offset + 1, i_b],
5678-
rigid_global_info.qpos[q_start + rot_offset + 2, i_b],
5679-
rigid_global_info.qpos[q_start + rot_offset + 3, i_b],
5680-
]
5681-
)
5682-
ang = (
5683-
ti.Vector(
5684-
[
5685-
dofs_state.vel[dof_start + rot_offset + 0, i_b],
5686-
dofs_state.vel[dof_start + rot_offset + 1, i_b],
5687-
dofs_state.vel[dof_start + rot_offset + 2, i_b],
5688-
]
5689-
)
5690-
* rigid_global_info.substep_dt[None]
5691-
)
5692-
qrot = gu.ti_rotvec_to_quat(ang, EPS)
5693-
rot = gu.ti_transform_quat_by_quat(qrot, rot)
5694-
for j in ti.static(range(4)):
5695-
rigid_global_info.qpos[q_start + j + rot_offset, i_b] = rot[j]
5696-
else:
5697-
for j in range(q_end - q_start):
5698-
rigid_global_info.qpos[q_start + j, i_b] = (
5699-
rigid_global_info.qpos[q_start + j, i_b]
5700-
+ dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None]
5701-
)
5702-
57035716

57045717
@ti.func
57055718
def func_integrate_dq_entity(

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import ctypes
33
import gc
4+
import logging
45
import os
56
import re
67
import subprocess
@@ -495,7 +496,7 @@ def initialize_genesis(request, monkeypatch, tmp_path, backend, precision, perfo
495496
yield
496497
return
497498

498-
logging_level = request.config.getoption("--log-cli-level")
499+
logging_level = request.config.getoption("--log-cli-level", logging.INFO)
499500
debug = request.config.getoption("--dev")
500501

501502
if not taichi_offline_cache:

0 commit comments

Comments
 (0)