Skip to content

Commit 5672000

Browse files
authored
[FEATURE] Differentiable forward dynamics for rigid body sim. (#1808)
1 parent 8ccfbbe commit 5672000

File tree

13 files changed

+3297
-1381
lines changed

13 files changed

+3297
-1381
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from copy import copy
22
from itertools import chain
33
from typing import TYPE_CHECKING, Literal
4+
from functools import wraps
5+
import inspect
46

57
import gstaichi as ti
68
import numpy as np
@@ -19,6 +21,7 @@
1921
from genesis.utils import terrain as tu
2022
from genesis.utils import urdf as uu
2123
from genesis.utils.misc import DeprecationError, broadcast_tensor, sanitize_index, ti_to_torch
24+
from genesis.engine.states.entities import RigidEntityState
2225

2326
from ..base_entity import Entity
2427
from .rigid_equality import RigidEquality
@@ -31,6 +34,22 @@
3134
from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver
3235

3336

37+
# Wrapper to track the arguments of a function and save them in the target buffer
38+
def tracked(fun):
39+
sig = inspect.signature(fun)
40+
41+
@wraps(fun)
42+
def wrapper(self, *args, **kwargs):
43+
if self._update_tgt_while_set:
44+
bound = sig.bind(self, *args, **kwargs)
45+
bound.apply_defaults()
46+
args_dict = dict(tuple(bound.arguments.items())[1:])
47+
self._update_tgt(fun.__name__, args_dict)
48+
return fun(self, *args, **kwargs)
49+
50+
return wrapper
51+
52+
3453
@ti.data_oriented
3554
class RigidEntity(Entity):
3655
"""
@@ -97,6 +116,23 @@ def __init__(
97116

98117
self._load_model()
99118

119+
# Initialize target variables and checkpoint
120+
self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity")
121+
self._tgt = dict()
122+
self._tgt_buffer = list()
123+
self._ckpt = dict()
124+
self._update_tgt_while_set = self._solver._requires_grad
125+
126+
def _update_tgt(self, key, value):
127+
# Set [self._tgt] value while keeping the insertion order between keys. When a new key is inserted or an existing
128+
# key is updated, the new element should be inserted at the end of the dict. This is because we need to keep
129+
# the insertion order to correctly pass the gradients in the backward pass.
130+
self._tgt.pop(key, None)
131+
self._tgt[key] = value
132+
133+
def init_ckpt(self):
134+
pass
135+
100136
def _load_model(self):
101137
self._links = gs.List()
102138
self._joints = gs.List()
@@ -1598,6 +1634,92 @@ def plan_path(
15981634
# ------------------------------------------------------------------------------------
15991635
# ---------------------------------- control & io ------------------------------------
16001636
# ------------------------------------------------------------------------------------
1637+
def process_input(self, in_backward=False):
1638+
if in_backward:
1639+
# use negative index because buffer length might not be full
1640+
index = self._sim.cur_step_local - self._sim._steps_local
1641+
self._tgt = self._tgt_buffer[index].copy()
1642+
else:
1643+
self._tgt_buffer.append(self._tgt.copy())
1644+
1645+
update_tgt_while_set = self._update_tgt_while_set
1646+
# Apply targets in the order of insertion
1647+
for key in self._tgt.keys():
1648+
data_kwargs = self._tgt[key]
1649+
1650+
# We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would
1651+
# be in [tgt]
1652+
if "zero_velocity" in data_kwargs:
1653+
data_kwargs["zero_velocity"] = False
1654+
# Do not update [tgt], as input information is finalized at this point
1655+
self._update_tgt_while_set = False
1656+
1657+
match key:
1658+
case "set_pos":
1659+
self.set_pos(**data_kwargs)
1660+
case "set_quat":
1661+
self.set_quat(**data_kwargs)
1662+
case "set_dofs_velocity":
1663+
self.set_dofs_velocity(**data_kwargs)
1664+
case _:
1665+
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")
1666+
1667+
self._tgt = dict()
1668+
self._update_tgt_while_set = update_tgt_while_set
1669+
1670+
def process_input_grad(self):
1671+
index = self._sim.cur_step_local - self._sim._steps_local
1672+
for key in reversed(self._tgt_buffer[index].keys()):
1673+
data_kwargs = self._tgt_buffer[index][key]
1674+
1675+
match key:
1676+
# We need to unpack the data_kwargs because [_backward_from_ti] only supports positional arguments
1677+
case "set_pos":
1678+
pos = data_kwargs.pop("pos")
1679+
if pos.requires_grad:
1680+
pos._backward_from_ti(self.set_pos_grad, data_kwargs["envs_idx"], data_kwargs["relative"])
1681+
1682+
case "set_quat":
1683+
quat = data_kwargs.pop("quat")
1684+
if quat.requires_grad:
1685+
quat._backward_from_ti(self.set_quat_grad, data_kwargs["envs_idx"], data_kwargs["relative"])
1686+
1687+
case "set_dofs_velocity":
1688+
velocity = data_kwargs.pop("velocity")
1689+
# [velocity] could be None when we want to zero the velocity (see set_dofs_velocity of RigidSolver)
1690+
if velocity is not None and velocity.requires_grad:
1691+
velocity._backward_from_ti(
1692+
self.set_dofs_velocity_grad,
1693+
data_kwargs["dofs_idx_local"],
1694+
data_kwargs["envs_idx"],
1695+
)
1696+
case _:
1697+
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")
1698+
1699+
def save_ckpt(self, ckpt_name):
1700+
if ckpt_name not in self._ckpt:
1701+
self._ckpt[ckpt_name] = {}
1702+
self._ckpt[ckpt_name]["_tgt_buffer"] = self._tgt_buffer.copy()
1703+
self._tgt_buffer.clear()
1704+
1705+
def load_ckpt(self, ckpt_name):
1706+
self._tgt_buffer = self._ckpt[ckpt_name]["_tgt_buffer"].copy()
1707+
1708+
def reset_grad(self):
1709+
self._tgt_buffer.clear()
1710+
1711+
@gs.assert_built
1712+
def get_state(self):
1713+
state = RigidEntityState(self, self._sim.cur_step_global)
1714+
1715+
solver_state = self._solver.get_state()
1716+
pos = solver_state.links_pos[:, self.base_link_idx]
1717+
quat = solver_state.links_quat[:, self.base_link_idx]
1718+
1719+
state._pos = pos
1720+
state._quat = quat
1721+
1722+
return state
16011723

16021724
def _get_global_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False):
16031725
# Handling default argument and special cases
@@ -1967,6 +2089,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None):
19672089
return self._solver.get_links_invweight(links_idx, envs_idx)
19682090

19692091
@gs.assert_built
2092+
@tracked
19702093
def set_pos(self, pos, envs_idx=None, *, zero_velocity=True, relative=False):
19712094
"""
19722095
Set position of the entity's base link.
@@ -1989,6 +2112,16 @@ def set_pos(self, pos, envs_idx=None, *, zero_velocity=True, relative=False):
19892112
self._solver.set_base_links_pos(pos, self._base_links_idx_, envs_idx, relative=relative)
19902113

19912114
@gs.assert_built
2115+
def set_pos_grad(self, envs_idx, relative, pos_grad):
2116+
self._solver.set_base_links_pos_grad(
2117+
self._base_links_idx_,
2118+
envs_idx,
2119+
relative,
2120+
pos_grad.data,
2121+
)
2122+
2123+
@gs.assert_built
2124+
@tracked
19922125
def set_quat(self, quat, envs_idx=None, *, zero_velocity=True, relative=False):
19932126
"""
19942127
Set quaternion of the entity's base link.
@@ -2010,6 +2143,15 @@ def set_quat(self, quat, envs_idx=None, *, zero_velocity=True, relative=False):
20102143
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True)
20112144
self._solver.set_base_links_quat(quat, self._base_links_idx_, envs_idx, relative=relative)
20122145

2146+
@gs.assert_built
2147+
def set_quat_grad(self, envs_idx, relative, quat_grad):
2148+
self._solver.set_base_links_quat_grad(
2149+
self._base_links_idx_,
2150+
envs_idx,
2151+
relative,
2152+
quat_grad.data,
2153+
)
2154+
20132155
@gs.assert_built
20142156
def get_verts(self):
20152157
"""
@@ -2169,6 +2311,7 @@ def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None
21692311
self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx)
21702312

21712313
@gs.assert_built
2314+
@tracked
21722315
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, skip_forward=False):
21732316
"""
21742317
Set the entity's dofs' velocity.
@@ -2185,6 +2328,11 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
21852328
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
21862329
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=skip_forward)
21872330

2331+
@gs.assert_built
2332+
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad):
2333+
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
2334+
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data)
2335+
21882336
@gs.assert_built
21892337
def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zero_velocity=True):
21902338
"""

genesis/engine/simulator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,9 @@ def f_global_to_s_global(self, f_global):
266266
# ------------------------------------------------------------------------------------
267267

268268
def step(self, in_backward=False):
269-
if self._rigid_only: # "Only Advance!" --Thomas Wade :P
269+
if self._rigid_only and not self._requires_grad: # "Only Advance!" --Thomas Wade :P
270270
for _ in range(self._substeps):
271-
self.rigid_solver.substep()
271+
self.rigid_solver.substep(self.cur_substep_local)
272272
self._cur_substep_global += 1
273273
else:
274274
self.process_input(in_backward=in_backward)

genesis/engine/solvers/rigid/constraint_noslip.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def kernel_build_efc_AR_b(
3535
rigid_solver.func_solve_mass_batched(
3636
constraint_state.Mgrad,
3737
constraint_state.Mgrad,
38+
array_class.PLACEHOLDER,
3839
i_b,
3940
entities_info=entities_info,
4041
rigid_global_info=rigid_global_info,
@@ -191,6 +192,7 @@ def kernel_dual_finish(
191192
rigid_solver.func_solve_mass_batched(
192193
vec=constraint_state.qfrc_constraint,
193194
out=constraint_state.qacc,
195+
out_bw=array_class.PLACEHOLDER,
194196
i_b=i_b,
195197
entities_info=entities_info,
196198
rigid_global_info=rigid_global_info,
@@ -283,6 +285,7 @@ def compute_A_diag(
283285
rigid_solver.func_solve_mass_batched(
284286
constraint_state.Mgrad,
285287
constraint_state.Mgrad,
288+
array_class.PLACEHOLDER,
286289
i_b,
287290
entities_info=entities_info,
288291
rigid_global_info=rigid_global_info,

genesis/engine/solvers/rigid/constraint_solver_decomp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,7 @@ def func_update_gradient(
20122012
rigid_solver.func_solve_mass_batched(
20132013
constraint_state.grad,
20142014
constraint_state.Mgrad,
2015+
array_class.PLACEHOLDER,
20152016
i_b,
20162017
entities_info=entities_info,
20172018
rigid_global_info=rigid_global_info,

genesis/engine/solvers/rigid/constraint_solver_decomp_island.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ def _func_update_gradient(self, island, i_b):
990990
i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity
991991
i_e = self.contact_island.entity_id[i_e_, i_b]
992992
self._solver.mass_mat_mask[i_e_, i_b] = True
993-
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b)
993+
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, array_class.PLACEHOLDER, i_b)
994994
for i_e in range(self._solver.n_entities):
995995
self._solver.mass_mat_mask[i_e, i_b] = True
996996
elif ti.static(self._solver_type == gs.constraint_solver.Newton):

0 commit comments

Comments
 (0)