Skip to content

Commit c8b56a2

Browse files
alexis779duburcqa
andauthored
[BUG FIX] Fix 'RigidJoint.(get_anchor_pos | get_anchor_axis)' getters. (#2012)
Co-authored-by: Alexis DUBURCQ <[email protected]>
1 parent 44e3760 commit c8b56a2

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

genesis/engine/entities/rigid_entity/rigid_joint.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
import genesis as gs
55
import genesis.utils.geom as gu
6+
from genesis.utils import array_class
67
from genesis.utils.misc import DeprecationError
78
from genesis.repr_base import RBC
89

910

10-
@ti.data_oriented
1111
class RigidJoint(RBC):
1212
"""
1313
Joint class for rigid body entities. Each RigidLink is connected to its parent link via a RigidJoint.
@@ -102,18 +102,11 @@ def get_anchor_pos(self):
102102
the anchor point is the "output" of the joint transmission, on which the child body is welded.
103103
"""
104104
tensor = torch.empty((self._solver._B, 3), dtype=gs.tc_float, device=gs.device)
105-
self._kernel_get_anchor_pos(tensor)
105+
_kernel_get_anchor_pos(self._idx, tensor, self._solver.joints_state)
106106
if self._solver.n_envs == 0:
107107
tensor = tensor.squeeze(0)
108108
return tensor
109109

110-
@ti.kernel
111-
def _kernel_get_anchor_pos(self, tensor: ti.types.ndarray()):
112-
for i_b in range(self._solver._B):
113-
xpos = self._solver.joints_state.xanchor[self._idx, i_b]
114-
for i in ti.static(range(3)):
115-
tensor[i_b, i] = xpos[i]
116-
117110
@gs.assert_built
118111
def get_anchor_axis(self):
119112
"""
@@ -122,18 +115,11 @@ def get_anchor_axis(self):
122115
See `RigidJoint.get_anchor_pos` documentation for details about the notion on anchor point.
123116
"""
124117
tensor = torch.empty((self._solver._B, 3), dtype=gs.tc_float, device=gs.device)
125-
self._kernel_get_anchor_axis(tensor)
118+
_kernel_get_anchor_axis(self._idx, tensor, self._solver.joints_state)
126119
if self._solver.n_envs == 0:
127120
tensor = tensor.squeeze(0)
128121
return tensor
129122

130-
@ti.kernel
131-
def _kernel_get_anchor_axis(self, tensor: ti.types.ndarray()):
132-
for i_b in range(self._solver._B):
133-
xaxis = self._solver.joints_state.xaxis[self._idx, i_b]
134-
for i in ti.static(range(3)):
135-
tensor[i_b, i] = xaxis[i]
136-
137123
def set_sol_params(self, sol_params):
138124
"""
139125
Set the solver parameters of this joint.
@@ -454,3 +440,21 @@ def is_built(self):
454440

455441
def _repr_brief(self):
456442
return f"{(self._repr_type())}: {self._uid}, name: '{self._name}', idx: {self._idx}, type: {self._type}"
443+
444+
445+
@ti.kernel
446+
def _kernel_get_anchor_pos(joint_idx: ti.i32, tensor: ti.types.ndarray(), joints_state: array_class.JointsState):
447+
_B = joints_state.xanchor.shape[1]
448+
for i_b in range(_B):
449+
xpos = joints_state.xanchor[joint_idx, i_b]
450+
for i in ti.static(range(3)):
451+
tensor[i_b, i] = xpos[i]
452+
453+
454+
@ti.kernel
455+
def _kernel_get_anchor_axis(joint_idx: ti.i32, tensor: ti.types.ndarray(), joints_state: array_class.JointsState):
456+
_B = joints_state.xaxis.shape[1]
457+
for i_b in range(_B):
458+
xaxis = joints_state.xaxis[joint_idx, i_b]
459+
for i in ti.static(range(3)):
460+
tensor[i_b, i] = xaxis[i]

tests/test_rigid_physics.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3392,3 +3392,30 @@ def test_reset_control(robot_path, tol):
33923392
new_control_force = robot.get_dofs_control_force()
33933393
assert old_control_force.abs().max() > gs.EPS
33943394
assert_allclose(new_control_force, 0, tol=gs.EPS)
3395+
3396+
3397+
@pytest.mark.required
3398+
@pytest.mark.parametrize("n_envs", [0, 2])
3399+
def test_joint_get_anchor_pos_and_axis(n_envs):
3400+
scene = gs.Scene(
3401+
show_viewer=False,
3402+
show_FPS=False,
3403+
)
3404+
robot = scene.add_entity(
3405+
gs.morphs.MJCF(
3406+
file="xml/franka_emika_panda/panda.xml",
3407+
),
3408+
)
3409+
scene.build(n_envs=n_envs)
3410+
batch_shape = (n_envs,) if n_envs > 0 else ()
3411+
3412+
joint = robot.joints[1]
3413+
anchor_pos = joint.get_anchor_pos()
3414+
assert anchor_pos.shape == (*batch_shape, 3)
3415+
expected_pos = scene.rigid_solver.joints_state.xanchor.to_numpy()
3416+
assert_allclose(anchor_pos, expected_pos[joint.idx], tol=gs.EPS)
3417+
3418+
anchor_axis = joint.get_anchor_axis()
3419+
assert anchor_axis.shape == (*batch_shape, 3)
3420+
expected_axis = scene.rigid_solver.joints_state.xaxis.to_numpy()
3421+
assert_allclose(anchor_axis, expected_axis[joint.idx], tol=gs.EPS)

0 commit comments

Comments
 (0)