Skip to content

Commit 0c2a68f

Browse files
authored
[BUG FIX] Fix link velocity computation for kinematic entity. (#2508)
* Fix FEM get pos. * Fix link velocity computation for kinematic entity. * Fix kinematic solver step not properly tracking dirty state. * Remove unsupported link reference arg for kinematic entity.
1 parent af8cefd commit 0c2a68f

File tree

7 files changed

+180
-109
lines changed

7 files changed

+180
-109
lines changed

genesis/engine/entities/fem_entity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,10 +971,10 @@ def remove_vertex_constraints(self, verts_idx_local=None, envs_idx=None):
971971
@qd.kernel
972972
def _kernel_get_verts_pos(self, f: qd.i32, pos: qd.types.ndarray(), verts_idx: qd.types.ndarray()):
973973
# get current position of vertices
974-
for i_v, i_b in qd.ndrange(verts_idx.shape[0], verts_idx.shape[1]):
975-
i_global = verts_idx[i_v, i_b] + self.v_start
974+
for i_b, i_v_ in qd.ndrange(verts_idx.shape[0], verts_idx.shape[1]):
975+
i_v = verts_idx[i_b, i_v_] + self.v_start
976976
for j in qd.static(range(3)):
977-
pos[i_b, i_v, j] = self._solver.elements_v[f, i_global, i_b].pos[j]
977+
pos[i_b, i_v_, j] = self._solver.elements_v[f, i_v, i_b].pos[j]
978978

979979
def get_el2v(self):
980980
"""

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,14 +1230,7 @@ def get_ang(self, envs_idx=None):
12301230
return self._solver.get_links_ang(self.base_link_idx, envs_idx)[..., 0, :]
12311231

12321232
@gs.assert_built
1233-
def get_links_pos(
1234-
self,
1235-
links_idx_local=None,
1236-
envs_idx=None,
1237-
*,
1238-
ref: Literal["link_origin", "link_com", "root_com"] = "link_origin",
1239-
unsafe=False,
1240-
):
1233+
def get_links_pos(self, links_idx_local=None, envs_idx=None):
12411234
"""
12421235
Returns the position of a given reference point for all the entity's links.
12431236
@@ -1247,19 +1240,14 @@ def get_links_pos(
12471240
The indices of the links. Defaults to None.
12481241
envs_idx : None | array_like, optional
12491242
The indices of the environments. If None, all environments will be considered. Defaults to None.
1250-
ref: "link_origin" | "link_com" | "root_com"
1251-
The reference point being used to express the position of each link.
1252-
* "root_com": center of mass of the sub-entities to which the link belongs. As a reminder, a single
1253-
kinematic tree (aka. 'RigidEntity') may compromise multiple "physical" entities, i.e. a kinematic tree
1254-
that may have at most one free joint, at its root.
12551243
12561244
Returns
12571245
-------
12581246
pos : torch.Tensor, shape (n_links, 3) or (n_envs, n_links, 3)
12591247
The position of all the entity's links.
12601248
"""
12611249
links_idx = self._get_global_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
1262-
return self._solver.get_links_pos(links_idx, envs_idx, ref=ref)
1250+
return self._solver.get_links_pos(links_idx, envs_idx)
12631251

12641252
@gs.assert_built
12651253
def get_links_quat(self, links_idx_local=None, envs_idx=None):
@@ -1318,14 +1306,7 @@ def get_vAABB(self, envs_idx=None):
13181306
return torch.stack((aabbs[..., 0, :].min(dim=-2).values, aabbs[..., 1, :].max(dim=-2).values), dim=-2)
13191307

13201308
@gs.assert_built
1321-
def get_links_vel(
1322-
self,
1323-
links_idx_local=None,
1324-
envs_idx=None,
1325-
*,
1326-
ref: Literal["link_origin", "link_com"] = "link_origin",
1327-
unsafe=False,
1328-
):
1309+
def get_links_vel(self, links_idx_local=None, envs_idx=None):
13291310
"""
13301311
Returns linear velocity of all the entity's links expressed at a given reference position in world coordinates.
13311312
@@ -1335,16 +1316,14 @@ def get_links_vel(
13351316
The indices of the links. Defaults to None.
13361317
envs_idx : None | array_like, optional
13371318
The indices of the environments. If None, all environments will be considered. Defaults to None.
1338-
ref: "link_origin" | "link_com"
1339-
The reference point being used to expressed the velocity of each link.
13401319
13411320
Returns
13421321
-------
13431322
vel : torch.Tensor, shape (n_links, 3) or (n_envs, n_links, 3)
13441323
The linear velocity of all the entity's links.
13451324
"""
13461325
links_idx = self._get_global_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
1347-
return self._solver.get_links_vel(links_idx, envs_idx, ref=ref)
1326+
return self._solver.get_links_vel(links_idx, envs_idx)
13481327

13491328
@gs.assert_built
13501329
def get_links_ang(self, links_idx_local=None, envs_idx=None):
@@ -1800,6 +1779,7 @@ class RigidEntity(KinematicEntity):
18001779

18011780
if TYPE_CHECKING:
18021781
material: gs.materials.Rigid
1782+
_solver: "RigidSolver"
18031783

18041784
def __init__(
18051785
self,
@@ -3176,6 +3156,61 @@ def get_AABB(self, envs_idx=None, *, allow_fast_approx: bool = False):
31763156
def get_aabb(self):
31773157
raise DeprecationError("This method has been removed. Please use 'get_AABB()' instead.")
31783158

3159+
@gs.assert_built
3160+
def get_links_pos(
3161+
self,
3162+
links_idx_local=None,
3163+
envs_idx=None,
3164+
*,
3165+
ref: Literal["link_origin", "link_com", "root_com"] = "link_origin",
3166+
):
3167+
"""
3168+
Returns the position of a given reference point for all the entity's links.
3169+
3170+
Parameters
3171+
----------
3172+
links_idx_local : None | array_like
3173+
The indices of the links. Defaults to None.
3174+
envs_idx : None | array_like, optional
3175+
The indices of the environments. If None, all environments will be considered. Defaults to None.
3176+
ref: "link_origin" | "link_com" | "root_com"
3177+
The reference point being used to express the position of each link.
3178+
* "root_com": center of mass of the sub-entities to which the link belongs. As a reminder, a single
3179+
kinematic tree (aka. 'RigidEntity') may compromise multiple "physical" entities, i.e. a kinematic tree
3180+
that may have at most one free joint, at its root.
3181+
3182+
Returns
3183+
-------
3184+
pos : torch.Tensor, shape (n_links, 3) or (n_envs, n_links, 3)
3185+
The position of all the entity's links.
3186+
"""
3187+
links_idx = self._get_global_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
3188+
return self._solver.get_links_pos(links_idx, envs_idx, ref=ref)
3189+
3190+
@gs.assert_built
3191+
def get_links_vel(
3192+
self, links_idx_local=None, envs_idx=None, *, ref: Literal["link_origin", "link_com"] = "link_origin"
3193+
):
3194+
"""
3195+
Returns linear velocity of all the entity's links expressed at a given reference position in world coordinates.
3196+
3197+
Parameters
3198+
----------
3199+
links_idx_local : None | array_like
3200+
The indices of the links. Defaults to None.
3201+
envs_idx : None | array_like, optional
3202+
The indices of the environments. If None, all environments will be considered. Defaults to None.
3203+
ref: "link_origin" | "link_com"
3204+
The reference point being used to expressed the velocity of each link.
3205+
3206+
Returns
3207+
-------
3208+
vel : torch.Tensor, shape (n_links, 3) or (n_envs, n_links, 3)
3209+
The linear velocity of all the entity's links.
3210+
"""
3211+
links_idx = self._get_global_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
3212+
return self._solver.get_links_vel(links_idx, envs_idx, ref=ref)
3213+
31793214
@gs.assert_built
31803215
def get_links_acc(self, links_idx_local=None, envs_idx=None):
31813216
links_idx = self._get_global_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)

genesis/engine/solvers/kinematic_solver.py

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,8 @@ def substep_post_coupling(self, f):
485485
rigid_global_info=self._rigid_global_info,
486486
static_rigid_sim_config=self._static_rigid_sim_config,
487487
)
488+
self._is_forward_pos_updated = True
489+
self._is_forward_vel_updated = True
488490

489491
def substep_post_coupling_grad(self, f):
490492
pass
@@ -942,72 +944,33 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None):
942944
self._is_forward_pos_updated = True
943945
self._is_forward_vel_updated = True
944946

945-
@staticmethod
946-
def _convert_ref_to_idx(ref: Literal["link_origin", "link_com", "root_com"]):
947-
if ref == "root_com":
948-
return 0
949-
elif ref == "link_com":
950-
return 1
951-
elif ref == "link_origin":
952-
return 2
953-
else:
954-
gs.raise_exception("'ref' must be either 'link_origin', 'link_com', or 'root_com'.")
955-
956-
def get_links_pos(
957-
self,
958-
links_idx=None,
959-
envs_idx=None,
960-
*,
961-
ref: Literal["link_origin", "link_com", "root_com"] = "link_origin",
962-
):
947+
def get_links_pos(self, links_idx=None, envs_idx=None):
963948
if not gs.use_zerocopy:
964949
_, links_idx, envs_idx = self._sanitize_io_variables(
965950
None, links_idx, self.n_links, "links_idx", envs_idx, (3,), skip_allocation=True
966951
)
967-
968-
ref = self._convert_ref_to_idx(ref)
969-
if ref == 0:
970-
tensor = qd_to_torch(self.links_state.root_COM, envs_idx, links_idx, transpose=True, copy=True)
971-
elif ref == 1:
972-
i_pos = qd_to_torch(self.links_state.i_pos, envs_idx, links_idx, transpose=True)
973-
root_COM = qd_to_torch(self.links_state.root_COM, envs_idx, links_idx, transpose=True)
974-
tensor = i_pos + root_COM
975-
elif ref == 2:
976-
tensor = qd_to_torch(self.links_state.pos, envs_idx, links_idx, transpose=True, copy=True)
977-
else:
978-
gs.raise_exception("'ref' must be either 'link_origin', 'link_com', or 'root_com'.")
979-
952+
tensor = qd_to_torch(self.links_state.pos, envs_idx, links_idx, transpose=True, copy=True)
980953
return tensor[0] if self.n_envs == 0 else tensor
981954

982955
def get_links_quat(self, links_idx=None, envs_idx=None):
983956
tensor = qd_to_torch(self.links_state.quat, envs_idx, links_idx, transpose=True, copy=True)
984957
return tensor[0] if self.n_envs == 0 else tensor
985958

986-
def get_links_vel(
987-
self, links_idx=None, envs_idx=None, *, ref: Literal["link_origin", "link_com", "root_com"] = "link_origin"
988-
):
959+
def get_links_vel(self, links_idx=None, envs_idx=None):
989960
if gs.use_zerocopy:
990961
mask = (0, *indices_to_mask(links_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, links_idx)
991962
cd_vel = qd_to_torch(self.links_state.cd_vel, transpose=True)
992-
if ref == "root_com":
993-
return cd_vel[mask]
994963
cd_ang = qd_to_torch(self.links_state.cd_ang, transpose=True)
995-
if ref == "link_com":
996-
i_pos = qd_to_torch(self.links_state.i_pos, transpose=True)
997-
delta = i_pos[mask]
998-
else:
999-
pos = qd_to_torch(self.links_state.pos, transpose=True)
1000-
root_COM = qd_to_torch(self.links_state.root_COM, transpose=True)
1001-
delta = pos[mask] - root_COM[mask]
1002-
return cd_vel[mask] + cd_ang[mask].cross(delta, dim=-1)
964+
pos = qd_to_torch(self.links_state.pos, transpose=True)
965+
root_COM = qd_to_torch(self.links_state.root_COM, transpose=True)
966+
return cd_vel[mask] + cd_ang[mask].cross(pos[mask] - root_COM[mask], dim=-1)
1003967

1004968
_tensor, links_idx, envs_idx = self._sanitize_io_variables(
1005969
None, links_idx, self.n_links, "links_idx", envs_idx, (3,)
1006970
)
1007971
assert _tensor is not None
1008972
tensor = _tensor[None] if self.n_envs == 0 else _tensor
1009-
ref = self._convert_ref_to_idx(ref)
1010-
kernel_get_links_vel(tensor, links_idx, envs_idx, ref, self.links_state, self._static_rigid_sim_config)
973+
kernel_get_links_vel(tensor, links_idx, envs_idx, 2, self.links_state, self._static_rigid_sim_config)
1011974
return _tensor
1012975

1013976
def get_links_ang(self, links_idx=None, envs_idx=None):

genesis/engine/solvers/rigid/abd/forward_kinematics.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,19 @@ def kernel_forward_kinematics(
144144
static_rigid_sim_config=static_rigid_sim_config,
145145
is_backward=False,
146146
)
147+
func_COM_links(
148+
i_b=i_b,
149+
links_state=links_state,
150+
links_info=links_info,
151+
joints_state=joints_state,
152+
joints_info=joints_info,
153+
dofs_state=dofs_state,
154+
dofs_info=dofs_info,
155+
entities_info=entities_info,
156+
rigid_global_info=rigid_global_info,
157+
static_rigid_sim_config=static_rigid_sim_config,
158+
is_backward=False,
159+
)
147160
func_forward_velocity_batch(
148161
i_b=i_b,
149162
entities_info=entities_info,
@@ -185,6 +198,19 @@ def kernel_masked_forward_kinematics(
185198
static_rigid_sim_config=static_rigid_sim_config,
186199
is_backward=False,
187200
)
201+
func_COM_links(
202+
i_b=i_b,
203+
links_state=links_state,
204+
links_info=links_info,
205+
joints_state=joints_state,
206+
joints_info=joints_info,
207+
dofs_state=dofs_state,
208+
dofs_info=dofs_info,
209+
entities_info=entities_info,
210+
rigid_global_info=rigid_global_info,
211+
static_rigid_sim_config=static_rigid_sim_config,
212+
is_backward=False,
213+
)
188214
func_forward_velocity_batch(
189215
i_b=i_b,
190216
entities_info=entities_info,
@@ -321,6 +347,7 @@ def func_COM_links_entity(
321347
static_rigid_sim_config: qd.template(),
322348
is_backward: qd.template(),
323349
):
350+
EPS = rigid_global_info.EPS[None]
324351
BW = qd.static(is_backward)
325352

326353
# Becomes static loop in backward pass, because we assume this loop is an inner loop
@@ -372,7 +399,11 @@ def func_COM_links_entity(
372399

373400
i_r = links_info.root_idx[I_l]
374401
if i_l == i_r:
375-
links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b]
402+
mass_sum = links_state.mass_sum[i_l, i_b]
403+
if mass_sum > EPS:
404+
links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b]
405+
else:
406+
links_state.root_COM[i_l, i_b] = links_state.i_pos_bw[i_r, i_b]
376407

377408
for i_l_ in (
378409
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
@@ -500,8 +531,6 @@ def func_COM_links_entity(
500531
i_j = i_j_ if qd.static(not BW) else (i_j_ + links_info.joint_start[I_l])
501532

502533
if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW):
503-
EPS = rigid_global_info.EPS[None]
504-
505534
offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b]
506535
I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j
507536
joint_type = joints_info.type[I_j]
@@ -571,7 +600,7 @@ def func_forward_kinematics_entity(
571600
if qd.static(not BW)
572601
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
573602
):
574-
i_l = i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e])
603+
i_l = gs.qd_int(i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e]))
575604

576605
if func_check_index_range(i_l, entities_info.link_start[i_e], entities_info.link_end[i_e], BW):
577606
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
@@ -977,7 +1006,7 @@ def func_forward_velocity_entity(
9771006
if qd.static(not BW)
9781007
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
9791008
):
980-
i_l = i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e])
1009+
i_l = gs.qd_int(i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e]))
9811010

9821011
if func_check_index_range(i_l, entities_info.link_start[i_e], entities_info.link_end[i_e], BW):
9831012
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l

0 commit comments

Comments
 (0)