From 9a2791d5b8ef973c0cc79cb91794e7898d00c368 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 18 Jan 2026 17:01:16 -0800 Subject: [PATCH 1/9] Ball joints have 4 qpos (quaternion) and 3 dof values. Added: - expand_to_q_indices/expand_to_v_indices for joint-to-qpos/dof mapping - nq, nv properties for total qpos/dof dimensions - joint_qpos_widths, joint_dof_widths tensors --- src/mjlab/entity/data.py | 17 +-- src/mjlab/entity/entity.py | 209 ++++++++++++++++++++++++++++++++----- 2 files changed, 192 insertions(+), 34 deletions(-) diff --git a/src/mjlab/entity/data.py b/src/mjlab/entity/data.py index 0333241de..400a639df 100644 --- a/src/mjlab/entity/data.py +++ b/src/mjlab/entity/data.py @@ -39,6 +39,11 @@ class EntityData: root_link_pose_w) require sim.forward() to be current. If you write then read, call sim.forward() in between. Event order matters when mixing reads and writes. All inputs/outputs use world frame. + + Ball Joint Support: + Position tensors use nq dimensions (4 per ball, 1 per hinge/slide). + Velocity/effort tensors use nv dimensions (3 per ball, 1 per hinge/slide). + Joint limits remain per-joint: (nworld, num_joints, 2). 
""" indexing: EntityIndexing @@ -155,9 +160,9 @@ def write_joint_position( raise ValueError("Cannot write joint position for non-articulated entity.") env_ids = self._resolve_env_ids(env_ids) - joint_ids = joint_ids if joint_ids is not None else slice(None) - q_slice = self.indexing.joint_q_adr[joint_ids] - self.data.qpos[env_ids, q_slice] = position + q_indices = self.indexing.expand_to_q_indices(joint_ids) + q_adr = self.indexing.joint_q_adr[q_indices] + self.data.qpos[env_ids, q_adr] = position def write_joint_velocity( self, @@ -169,9 +174,9 @@ def write_joint_velocity( raise ValueError("Cannot write joint velocity for non-articulated entity.") env_ids = self._resolve_env_ids(env_ids) - joint_ids = joint_ids if joint_ids is not None else slice(None) - v_slice = self.indexing.joint_v_adr[joint_ids] - self.data.qvel[env_ids, v_slice] = velocity + v_indices = self.indexing.expand_to_v_indices(joint_ids) + v_adr = self.indexing.joint_v_adr[v_indices] + self.data.qvel[env_ids, v_adr] = velocity def write_external_wrench( self, diff --git a/src/mjlab/entity/entity.py b/src/mjlab/entity/entity.py index 360528e6d..584e0a261 100644 --- a/src/mjlab/entity/entity.py +++ b/src/mjlab/entity/entity.py @@ -17,7 +17,7 @@ from mjlab.utils.lab_api.string import resolve_matching_names from mjlab.utils.mujoco import dof_width, qpos_width from mjlab.utils.spec import auto_wrap_fixed_base_mocap -from mjlab.utils.string import resolve_expr +from mjlab.utils.string import resolve_expr_with_widths @dataclass(frozen=False) @@ -47,10 +47,63 @@ class EntityIndexing: free_joint_q_adr: torch.Tensor free_joint_v_adr: torch.Tensor + # Ball joint support: per-joint dimensions and cumulative offsets. 
+ joint_qpos_widths: torch.Tensor # (num_joints,) - 4 for ball, 1 for hinge/slide + joint_dof_widths: torch.Tensor # (num_joints,) - 3 for ball, 1 for hinge/slide + joint_types: torch.Tensor # (num_joints,) - mjtJoint enum values + q_offsets: torch.Tensor # (num_joints,) - cumulative qpos offsets per joint + v_offsets: torch.Tensor # (num_joints,) - cumulative qvel offsets per joint + nq: int # Total qpos dimension + nv: int # Total dof dimension + @property def root_body_id(self) -> int: return self.bodies[0].id + def expand_to_q_indices( + self, joint_ids: torch.Tensor | Sequence[int] | slice | None + ) -> torch.Tensor | slice: + """Expand joint IDs to qpos indices for ball joint support.""" + return self._expand_indices(joint_ids, self.joint_qpos_widths, self.q_offsets) + + def expand_to_v_indices( + self, joint_ids: torch.Tensor | Sequence[int] | slice | None + ) -> torch.Tensor | slice: + """Expand joint IDs to qvel indices for ball joint support.""" + return self._expand_indices(joint_ids, self.joint_dof_widths, self.v_offsets) + + def _expand_indices( + self, + joint_ids: torch.Tensor | Sequence[int] | slice | None, + widths: torch.Tensor, + offsets: torch.Tensor, + ) -> torch.Tensor | slice: + """Expand joint IDs to DOF indices using tensor operations.""" + if joint_ids is None: + return slice(None) + + device = widths.device + if isinstance(joint_ids, slice): + start = joint_ids.start or 0 + stop = joint_ids.stop or len(widths) + joint_ids = torch.arange(start, stop, device=device) + elif not isinstance(joint_ids, torch.Tensor): + joint_ids = torch.tensor(joint_ids, dtype=torch.long, device=device) + + selected_widths = widths[joint_ids] + selected_offsets = offsets[joint_ids] + + # Expand using repeat_interleave: starts + local_offsets + repeated_offsets = selected_offsets.repeat_interleave(selected_widths) + total = int(selected_widths.sum().item()) + cumwidths = torch.zeros_like(selected_widths) + cumwidths[1:] = selected_widths[:-1].cumsum(0) + 
local_offsets = torch.arange(total, device=device) - cumwidths.repeat_interleave( + selected_widths + ) + + return repeated_offsets + local_offsets + @dataclass class EntityCfg: @@ -64,7 +117,10 @@ class InitialStateCfg: ang_vel: tuple[float, float, float] = (0.0, 0.0, 0.0) # Articulation (only for articulated entities). # Set to None to use the model's existing keyframe (errors if none exists). - joint_pos: dict[str, float] | None = field(default_factory=lambda: {".*": 0.0}) + # Ball joints use tuple of 4 floats for quaternion (w, x, y, z). + joint_pos: dict[str, float | tuple[float, ...]] | None = field( + default_factory=lambda: {".*": 0.0} + ) joint_vel: dict[str, float] = field(default_factory=lambda: {".*": 0.0}) init_state: InitialStateCfg = field(default_factory=InitialStateCfg) @@ -193,21 +249,37 @@ def _add_initial_state_keyframe(self) -> None: self.root_body.quat[:] = self.cfg.init_state.rot return - qpos_components = [] + qpos_components: list[tuple[float, ...]] = [] if self._free_joint is not None: qpos_components.extend([self.cfg.init_state.pos, self.cfg.init_state.rot]) joint_pos = None if self._non_free_joints: - joint_pos = resolve_expr(self.cfg.init_state.joint_pos, self.joint_names, 0.0) + # Per-joint widths and defaults (ball: 4 qpos, hinge/slide: 1). + widths = tuple(qpos_width(j.type) for j in self._non_free_joints) + defaults = tuple( + (1.0, 0.0, 0.0, 0.0) if j.type == mujoco.mjtJoint.mjJNT_BALL else (0.0,) + for j in self._non_free_joints + ) + joint_pos = resolve_expr_with_widths( + self.cfg.init_state.joint_pos, self.joint_names, widths, defaults + ) qpos_components.append(joint_pos) key_qpos = np.hstack(qpos_components) if qpos_components else np.array([]) key = self._spec.add_key(name="init_state", qpos=key_qpos) if self.is_actuated and joint_pos is not None: - name_to_pos = {name: joint_pos[i] for i, name in enumerate(self.joint_names)} + # Map joint names to position values for ctrl (skip ball joints). 
+ joint_q_offset = 0 + name_to_pos: dict[str, float] = {} + for j, name in zip(self._non_free_joints, self.joint_names, strict=True): + width = qpos_width(j.type) + if width == 1: + name_to_pos[name] = joint_pos[joint_q_offset] + joint_q_offset += width + ctrl = [] for act in self._spec.actuators: joint_name = act.target @@ -304,6 +376,20 @@ def actuator_names(self) -> tuple[str, ...]: def num_joints(self) -> int: return len(self.joint_names) + @property + def nq(self) -> int: + """Total qpos dimension (excludes free joint). Accounts for ball joints (4 qpos).""" + if hasattr(self, "indexing"): + return self.indexing.nq + return sum(qpos_width(j.type) for j in self._non_free_joints) + + @property + def nv(self) -> int: + """Total dof dimension (excludes free joint). Accounts for ball joints (3 dof).""" + if hasattr(self, "indexing"): + return self.indexing.nv + return sum(dof_width(j.type) for j in self._non_free_joints) + @property def num_bodies(self) -> int: return len(self.body_names) @@ -452,6 +538,13 @@ def initialize( # Joint state. if self.is_articulated: + # Use indexing tensors for widths; build defaults based on joint types. + qpos_widths = tuple(indexing.joint_qpos_widths.tolist()) + dof_widths_tuple = tuple(indexing.joint_dof_widths.tolist()) + is_ball = (indexing.joint_types == int(mujoco.mjtJoint.mjJNT_BALL)).tolist() + qpos_defaults = tuple((1.0, 0.0, 0.0, 0.0) if b else (0.0,) for b in is_ball) + vel_defaults = tuple((0.0, 0.0, 0.0) if b else (0.0,) for b in is_ball) + if self.cfg.init_state.joint_pos is None: # Use keyframe joint positions. 
key_qpos = mj_model.key("init_state").qpos @@ -461,19 +554,40 @@ def initialize( ].repeat(nworld, 1) else: default_joint_pos = torch.tensor( - resolve_expr(self.cfg.init_state.joint_pos, self.joint_names, 0.0), + resolve_expr_with_widths( + self.cfg.init_state.joint_pos, + self.joint_names, + qpos_widths, + qpos_defaults, + ), device=device, )[None].repeat(nworld, 1) default_joint_vel = torch.tensor( - resolve_expr(self.cfg.init_state.joint_vel, self.joint_names, 0.0), + resolve_expr_with_widths( + self.cfg.init_state.joint_vel, + self.joint_names, + dof_widths_tuple, + vel_defaults, + ), device=device, )[None].repeat(nworld, 1) - # Joint limits. + # Joint limits: hinge/slide have real limits, ball joints don't. + # MuJoCo joint types: free=0, ball=1, slide=2, hinge=3. joint_ids_global = torch.tensor( [j.id for j in self._non_free_joints], device=device ) - dof_limits = model.jnt_range[:, joint_ids_global] + dof_limits = model.jnt_range[:, joint_ids_global] # (1, num_joints, 2) + has_real_limits = indexing.joint_types >= 2 # slide=2, hinge=3 + ball_mask = ~has_real_limits + if ball_mask.any(): + # Ball joints: substitute with large range since they have no position limits. + large_range = 1e6 + dof_limits = dof_limits.clone() + dof_limits[:, ball_mask, :] = torch.tensor( + [[-large_range, large_range]], device=device + ) + default_joint_pos_limits = dof_limits.clone() joint_pos_limits = default_joint_pos_limits.clone() joint_pos_mean = (joint_pos_limits[..., 0] + joint_pos_limits[..., 1]) / 2 @@ -505,14 +619,15 @@ def initialize( ) if self.is_actuated: + # Use nq for position targets, nv for velocity/effort targets. 
joint_pos_target = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nq), dtype=torch.float, device=device ) joint_vel_target = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nv), dtype=torch.float, device=device ) joint_effort_target = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nv), dtype=torch.float, device=device ) else: joint_pos_target = torch.empty(nworld, 0, dtype=torch.float, device=device) @@ -546,10 +661,10 @@ def initialize( site_effort_target = torch.empty(nworld, 0, dtype=torch.float, device=device) # Encoder bias for simulating encoder calibration errors. - # Shape: (num_envs, num_joints). Defaults to zero (no bias). + # Shape: (num_envs, nq). Defaults to zero (no bias). if self.is_articulated: encoder_bias = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nq), dtype=torch.float, device=device ) else: encoder_bias = torch.empty(nworld, 0, dtype=torch.float, device=device) @@ -738,15 +853,15 @@ def set_joint_position_target( """Set joint position targets. Args: - position: Target joint poisitions with shape (N, num_joints). + position: Target joint positions with shape (N, nq) or (N, selected_nq). + For ball joints, this should include all 4 quaternion values per joint. joint_ids: Optional joint indices to set. If None, set all joints. env_ids: Optional environment indices. If None, set all environments. """ if env_ids is None: env_ids = slice(None) - if joint_ids is None: - joint_ids = slice(None) - self._data.joint_pos_target[env_ids, joint_ids] = position + q_indices = self.indexing.expand_to_q_indices(joint_ids) + self._data.joint_pos_target[env_ids, q_indices] = position def set_joint_velocity_target( self, @@ -757,15 +872,15 @@ def set_joint_velocity_target( """Set joint velocity targets. Args: - velocity: Target joint velocities with shape (N, num_joints). 
+ velocity: Target joint velocities with shape (N, nv) or (N, selected_nv). + For ball joints, this should include all 3 angular velocity values per joint. joint_ids: Optional joint indices to set. If None, set all joints. env_ids: Optional environment indices. If None, set all environments. """ if env_ids is None: env_ids = slice(None) - if joint_ids is None: - joint_ids = slice(None) - self._data.joint_vel_target[env_ids, joint_ids] = velocity + v_indices = self.indexing.expand_to_v_indices(joint_ids) + self._data.joint_vel_target[env_ids, v_indices] = velocity def set_joint_effort_target( self, @@ -776,15 +891,15 @@ def set_joint_effort_target( """Set joint effort targets. Args: - effort: Target joint efforts with shape (N, num_joints). + effort: Target joint efforts with shape (N, nv) or (N, selected_nv). + For ball joints, this should include all 3 torque values per joint. joint_ids: Optional joint indices to set. If None, set all joints. env_ids: Optional environment indices. If None, set all environments. """ if env_ids is None: env_ids = slice(None) - if joint_ids is None: - joint_ids = slice(None) - self._data.joint_effort_target[env_ids, joint_ids] = effort + v_indices = self.indexing.expand_to_v_indices(joint_ids) + self._data.joint_effort_target[env_ids, v_indices] = effort def set_tendon_len_target( self, @@ -930,6 +1045,11 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin joint_v_adr = [] free_joint_q_adr = [] free_joint_v_adr = [] + + # Per-joint dimension tracking for ball joint support. 
+ joint_qpos_widths_list: list[int] = [] + joint_dof_widths_list: list[int] = [] + joint_types_list: list[int] = [] for joint in self.spec.joints: jnt = model.joint(joint.name) jnt_type = jnt.type[0] @@ -939,13 +1059,39 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin free_joint_v_adr.extend(range(vadr, vadr + 6)) free_joint_q_adr.extend(range(qadr, qadr + 7)) else: - joint_v_adr.extend(range(vadr, vadr + dof_width(jnt_type))) - joint_q_adr.extend(range(qadr, qadr + qpos_width(jnt_type))) + qw = qpos_width(jnt_type) + dw = dof_width(jnt_type) + joint_v_adr.extend(range(vadr, vadr + dw)) + joint_q_adr.extend(range(qadr, qadr + qw)) + + # Track per-joint dimensions. + joint_types_list.append(jnt_type) + joint_qpos_widths_list.append(qw) + joint_dof_widths_list.append(dw) + joint_q_adr = torch.tensor(joint_q_adr, dtype=torch.int, device=device) joint_v_adr = torch.tensor(joint_v_adr, dtype=torch.int, device=device) free_joint_v_adr = torch.tensor(free_joint_v_adr, dtype=torch.int, device=device) free_joint_q_adr = torch.tensor(free_joint_q_adr, dtype=torch.int, device=device) + # Convert per-joint tracking lists to tensors. + joint_qpos_widths = torch.tensor( + joint_qpos_widths_list, dtype=torch.int, device=device + ) + joint_dof_widths = torch.tensor( + joint_dof_widths_list, dtype=torch.int, device=device + ) + joint_types = torch.tensor(joint_types_list, dtype=torch.int, device=device) + + # Compute cumulative offsets for tensor-based index expansion. 
+ q_offsets = torch.zeros_like(joint_qpos_widths) + v_offsets = torch.zeros_like(joint_dof_widths) + if len(joint_qpos_widths) > 0: + q_offsets[1:] = joint_qpos_widths[:-1].cumsum(0) + v_offsets[1:] = joint_dof_widths[:-1].cumsum(0) + nq = int(joint_qpos_widths.sum().item()) if len(joint_qpos_widths) > 0 else 0 + nv = int(joint_dof_widths.sum().item()) if len(joint_dof_widths) > 0 else 0 + if self.is_fixed_base and self.is_mocap: mocap_id = int(model.body_mocapid[self.root_body.id]) else: @@ -969,6 +1115,13 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin joint_v_adr=joint_v_adr, free_joint_q_adr=free_joint_q_adr, free_joint_v_adr=free_joint_v_adr, + joint_qpos_widths=joint_qpos_widths, + joint_dof_widths=joint_dof_widths, + joint_types=joint_types, + q_offsets=q_offsets, + v_offsets=v_offsets, + nq=nq, + nv=nv, ) def _apply_actuator_controls(self) -> None: From 5c5c4130a895b5b93ec0325d6a9ced00bd4c456a Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 18 Jan 2026 17:03:11 -0800 Subject: [PATCH 2/9] Update string utils for ball joint support --- src/mjlab/utils/string.py | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/mjlab/utils/string.py b/src/mjlab/utils/string.py index e57d55c03..7f12df9cb 100644 --- a/src/mjlab/utils/string.py +++ b/src/mjlab/utils/string.py @@ -21,6 +21,46 @@ def resolve_expr( return tuple(result) +def resolve_expr_with_widths( + pattern_map: dict[str, Any], + names: tuple[str, ...], + widths: tuple[int, ...], + default_vals: tuple[tuple[float, ...], ...], +) -> tuple[float, ...]: + """Resolve field values accounting for per-joint widths (ball=4, hinge/slide=1). + + Scalars for ball joints use the default (identity quaternion) for backward + compatibility with {".*": 0.0} patterns. 
+ """ + patterns = [(re.compile(pat), val) for pat, val in pattern_map.items()] + + result: list[float] = [] + for name, width, default in zip(names, widths, default_vals, strict=True): + matched = False + for pat, val in patterns: + if pat.match(name): + if isinstance(val, (tuple, list)): + if len(val) != width: + raise ValueError(f"Joint '{name}' expects {width} values, got {len(val)}") + result.extend(val) + matched = True + else: + # Scalar value. + if width == 1: + result.append(val) + matched = True + else: + # Scalar for ball joint - use default (identity quaternion). + # This handles backward compatibility with {".*": 0.0} patterns. + result.extend(default) + matched = True + break + if not matched: + result.extend(default) + + return tuple(result) + + def filter_exp( exprs: list[str] | tuple[str, ...], names: tuple[str, ...] ) -> tuple[str, ...]: From 67c2ebe7ac6f0a4dd7ad50018036e979b55d62c9 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 18 Jan 2026 17:14:16 -0800 Subject: [PATCH 3/9] Update actuators for ball joint qpos/dof dimensions --- src/mjlab/actuator/actuator.py | 22 +++++++++--- src/mjlab/actuator/builtin_group.py | 46 ++++++++++++++++++++------ src/mjlab/actuator/delayed_actuator.py | 3 ++ 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/src/mjlab/actuator/actuator.py b/src/mjlab/actuator/actuator.py index d99156e0b..c297a42e4 100644 --- a/src/mjlab/actuator/actuator.py +++ b/src/mjlab/actuator/actuator.py @@ -113,6 +113,9 @@ def __init__( self._ctrl_ids: torch.Tensor | None = None self._mjs_actuators: list[mujoco.MjsActuator] = [] self._site_zeros: torch.Tensor | None = None + # Expanded indices for ball joint support. 
+ self._q_indices: torch.Tensor | None = None + self._v_indices: torch.Tensor | None = None @property def target_ids(self) -> torch.Tensor: @@ -173,6 +176,15 @@ def initialize( ctrl_ids_list = [act.id for act in self._mjs_actuators] self._ctrl_ids = torch.tensor(ctrl_ids_list, dtype=torch.long, device=device) + # Compute expanded indices for ball joint support. + if self.transmission_type == TransmissionType.JOINT: + indexing = self.entity.indexing + q_indices = indexing.expand_to_q_indices(self._target_ids_list) + v_indices = indexing.expand_to_v_indices(self._target_ids_list) + assert isinstance(q_indices, torch.Tensor) and isinstance(v_indices, torch.Tensor) + self._q_indices = q_indices + self._v_indices = v_indices + # Pre-allocate zeros for SITE transmission type to avoid repeated allocations. if self.transmission_type == TransmissionType.SITE: nenvs = data.nworld @@ -190,11 +202,11 @@ def get_command(self, data: EntityData) -> ActuatorCmd: """ if self.transmission_type == TransmissionType.JOINT: return ActuatorCmd( - position_target=data.joint_pos_target[:, self.target_ids], - velocity_target=data.joint_vel_target[:, self.target_ids], - effort_target=data.joint_effort_target[:, self.target_ids], - pos=data.joint_pos[:, self.target_ids], - vel=data.joint_vel[:, self.target_ids], + position_target=data.joint_pos_target[:, self._q_indices], + velocity_target=data.joint_vel_target[:, self._v_indices], + effort_target=data.joint_effort_target[:, self._v_indices], + pos=data.joint_pos[:, self._q_indices], + vel=data.joint_vel[:, self._v_indices], ) elif self.transmission_type == TransmissionType.TENDON: return ActuatorCmd( diff --git a/src/mjlab/actuator/builtin_group.py b/src/mjlab/actuator/builtin_group.py index ad14705e7..ebcc3be95 100644 --- a/src/mjlab/actuator/builtin_group.py +++ b/src/mjlab/actuator/builtin_group.py @@ -37,6 +37,16 @@ (BuiltinMuscleActuator, TransmissionType.TENDON): "tendon_effort_target", } +# Indicates whether the actuator type uses 
qpos indexing (vs qvel indexing). +# Position actuators read from qpos-indexed tensors (nq dimension). +# Velocity/motor actuators read from qvel-indexed tensors (nv dimension). +_USES_QPOS_INDEXING: dict[type[BuiltinActuatorType], bool] = { + BuiltinPositionActuator: True, + BuiltinVelocityActuator: False, + BuiltinMotorActuator: False, + BuiltinMuscleActuator: False, +} + @dataclass(frozen=True) class BuiltinActuatorGroup: @@ -47,7 +57,9 @@ class BuiltinActuatorGroup: enables direct writes without per-actuator overhead. """ - # Map from (BuiltinActuator type, transmission_type) to (target_ids, ctrl_ids). + # Map from (BuiltinActuator type, transmission_type) to (expanded_indices, ctrl_ids). + # For JOINT transmission, expanded_indices are qpos or qvel indices depending + # on the actuator type. For other transmissions, they match target_ids. _index_groups: dict[ tuple[type[BuiltinActuatorType], TransmissionType], tuple[torch.Tensor, torch.Tensor], @@ -84,17 +96,29 @@ def process( else: custom_actuators.append(act) - # Return stacked indices for each (actuator_type, transmission_type) group. + # Build stacked indices for each (actuator_type, transmission_type) group. index_groups: dict[ tuple[type[BuiltinActuatorType], TransmissionType], tuple[torch.Tensor, torch.Tensor], - ] = { - key: ( - torch.cat([act.target_ids for act in acts], dim=0), - torch.cat([act.ctrl_ids for act in acts], dim=0), - ) - for key, acts in builtin_groups.items() - } + ] = {} + + for key, acts in builtin_groups.items(): + actuator_type, transmission_type = key + ctrl_ids = torch.cat([act.ctrl_ids for act in acts], dim=0) + + if transmission_type == TransmissionType.JOINT: + # Use expanded qpos/qvel indices for joint transmission. 
+ uses_qpos = _USES_QPOS_INDEXING[actuator_type] + attr = "_q_indices" if uses_qpos else "_v_indices" + indices_list = [getattr(act, attr) for act in acts] + assert all(idx is not None for idx in indices_list) + expanded_indices = torch.cat(indices_list, dim=0) # type: ignore[arg-type] + index_groups[key] = (expanded_indices, ctrl_ids) + else: + # For tendon/site, use flat target_ids. + target_ids = torch.cat([act.target_ids for act in acts], dim=0) + index_groups[key] = (target_ids, ctrl_ids) + return BuiltinActuatorGroup(index_groups), tuple(custom_actuators) def apply_controls(self, data: EntityData) -> None: @@ -104,9 +128,9 @@ def apply_controls(self, data: EntityData) -> None: data: Entity data containing targets and control arrays. """ for (actuator_type, transmission_type), ( - target_ids, + expanded_indices, ctrl_ids, ) in self._index_groups.items(): attr_name = _TARGET_TENSOR_MAP[(actuator_type, transmission_type)] target_tensor = getattr(data, attr_name) - data.write_ctrl(target_tensor[:, target_ids], ctrl_ids) + data.write_ctrl(target_tensor[:, expanded_indices], ctrl_ids) diff --git a/src/mjlab/actuator/delayed_actuator.py b/src/mjlab/actuator/delayed_actuator.py index 52135db66..e98a226fb 100644 --- a/src/mjlab/actuator/delayed_actuator.py +++ b/src/mjlab/actuator/delayed_actuator.py @@ -104,6 +104,9 @@ def initialize( self._target_ids = self._base_actuator._target_ids self._ctrl_ids = self._base_actuator._ctrl_ids + # Copy expanded indices for ball joint support. 
+ self._q_indices = self._base_actuator._q_indices + self._v_indices = self._base_actuator._v_indices targets = ( (self.cfg.delay_target,) From 9fef621d14ced80632ce92abbadf79213467ade9 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 18 Jan 2026 17:24:08 -0800 Subject: [PATCH 4/9] Add ball joint tests Co-Authored-By: Claude Opus 4.5 --- tests/test_ball_joint.py | 269 +++++++++++++++++++++++++++++++++++++++ tests/test_events.py | 10 ++ 2 files changed, 279 insertions(+) create mode 100644 tests/test_ball_joint.py diff --git a/tests/test_ball_joint.py b/tests/test_ball_joint.py new file mode 100644 index 000000000..fd452be33 --- /dev/null +++ b/tests/test_ball_joint.py @@ -0,0 +1,269 @@ +"""Tests for ball joint support in the Entity system.""" + +import mujoco +import pytest +import torch +from conftest import get_test_device, initialize_entity + +from mjlab.entity import Entity, EntityCfg + +# XML with ball joints for testing. +BALL_JOINT_XML = """ + + + + + + + + + + + + + + + + + + + + +""" + +# XML with only hinge joints for comparison. +HINGE_ONLY_XML = """ + + + + + + + + + + + + + + + + +""" + + +@pytest.fixture(scope="module") +def device(): + """Test device fixture.""" + return get_test_device() + + +def create_ball_joint_entity(): + """Create an entity with ball joints.""" + cfg = EntityCfg(spec_fn=lambda: mujoco.MjSpec.from_string(BALL_JOINT_XML)) + return Entity(cfg) + + +def create_hinge_only_entity(): + """Create an entity with only hinge joints.""" + cfg = EntityCfg(spec_fn=lambda: mujoco.MjSpec.from_string(HINGE_ONLY_XML)) + return Entity(cfg) + + +def test_nq_nv_with_ball_joints(device): + """Test that nq and nv are correctly computed for ball joints. + + Ball joints have 4 qpos (quaternion) and 3 qvel (angular velocity). + Hinge joints have 1 qpos and 1 qvel. 
+ """ + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Entity has: ball1 (4 qpos, 3 dof), hinge1 (1 qpos, 1 dof), ball2 (4 qpos, 3 dof) + # Total: 4 + 1 + 4 = 9 qpos, 3 + 1 + 3 = 7 dof + assert entity.num_joints == 3 + assert entity.nq == 9, f"Expected nq=9, got {entity.nq}" + assert entity.nv == 7, f"Expected nv=7, got {entity.nv}" + + # Check indexing fields match. + assert entity.indexing.nq == 9 + assert entity.indexing.nv == 7 + + +def test_nq_nv_hinge_only(device): + """Test that nq and nv equal num_joints for hinge-only entities.""" + entity = create_hinge_only_entity() + entity, sim = initialize_entity(entity, device) + + # Entity has: hinge1 (1 qpos, 1 dof), hinge2 (1 qpos, 1 dof) + assert entity.num_joints == 2 + assert entity.nq == 2 + assert entity.nv == 2 + + +def test_joint_offset_tensors(device): + """Test that q_offsets and v_offsets are correctly computed.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + indexing = entity.indexing + + # Expected cumulative offsets: ball1(4,3), hinge1(1,1), ball2(4,3) + expected_q_offsets = torch.tensor([0, 4, 5], dtype=torch.int, device=device) + expected_v_offsets = torch.tensor([0, 3, 4], dtype=torch.int, device=device) + + assert torch.equal(indexing.q_offsets, expected_q_offsets) + assert torch.equal(indexing.v_offsets, expected_v_offsets) + + # Test expand_to_q_indices for single joint (ball1). + q_indices = indexing.expand_to_q_indices(torch.tensor([0], device=device)) + assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3], device=device)) + + # Test expand_to_v_indices for single joint (ball1). + v_indices = indexing.expand_to_v_indices(torch.tensor([0], device=device)) + assert torch.equal(v_indices, torch.tensor([0, 1, 2], device=device)) + + # Test expand for hinge joint (joint 1). 
+ q_indices = indexing.expand_to_q_indices(torch.tensor([1], device=device)) + assert torch.equal(q_indices, torch.tensor([4], device=device)) + + # Test expand for multiple joints (ball1, ball2). + q_indices = indexing.expand_to_q_indices(torch.tensor([0, 2], device=device)) + assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3, 5, 6, 7, 8], device=device)) + + +def test_joint_qpos_widths(device): + """Test that joint_qpos_widths and joint_dof_widths are correct.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + indexing = entity.indexing + + # ball1, hinge1, ball2 + expected_qpos_widths = torch.tensor([4, 1, 4], dtype=torch.int, device=device) + expected_dof_widths = torch.tensor([3, 1, 3], dtype=torch.int, device=device) + + assert torch.equal(indexing.joint_qpos_widths, expected_qpos_widths) + assert torch.equal(indexing.joint_dof_widths, expected_dof_widths) + + +def test_ball_joint_initial_state(device): + """Test that ball joints default to identity quaternion (1, 0, 0, 0).""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Check default joint positions have correct shape. + default_joint_pos = entity.data.default_joint_pos + assert default_joint_pos.shape == (1, 9), ( + f"Expected shape (1, 9), got {default_joint_pos.shape}" + ) + + # Ball joint 1 (indices 0:4) should be identity quaternion. + ball1_quat = default_joint_pos[0, 0:4] + expected_identity = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device) + assert torch.allclose(ball1_quat, expected_identity, atol=1e-6) + + # Hinge joint (index 4) should be 0. + hinge_pos = default_joint_pos[0, 4] + assert abs(hinge_pos.item()) < 1e-6 + + # Ball joint 2 (indices 5:9) should be identity quaternion. 
+ ball2_quat = default_joint_pos[0, 5:9] + assert torch.allclose(ball2_quat, expected_identity, atol=1e-6) + + +def test_default_joint_vel_shape(device): + """Test that default_joint_vel has correct shape (nv, not num_joints).""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + default_joint_vel = entity.data.default_joint_vel + assert default_joint_vel.shape == (1, 7), ( + f"Expected shape (1, 7), got {default_joint_vel.shape}" + ) + + +def test_joint_pos_target_shape(device): + """Test that joint_pos_target has correct shape (nq) for non-actuated entity.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Non-actuated entities have empty target tensors. + joint_pos_target = entity.data.joint_pos_target + assert joint_pos_target.shape == (1, 0), ( + f"Expected shape (1, 0) for non-actuated entity, got {joint_pos_target.shape}" + ) + + +def test_joint_vel_target_shape(device): + """Test that joint_vel_target has correct shape (nv) for non-actuated entity.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Non-actuated entities have empty target tensors. + joint_vel_target = entity.data.joint_vel_target + assert joint_vel_target.shape == (1, 0), ( + f"Expected shape (1, 0) for non-actuated entity, got {joint_vel_target.shape}" + ) + + +def test_write_joint_position_with_ball_joints(device): + """Test writing joint positions with ball joints.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Write all joint positions. + # ball1 (4) + hinge1 (1) + ball2 (4) = 9 values + new_pos = torch.tensor( + [ + [ + 0.7071, + 0.7071, + 0.0, + 0.0, # ball1: 90 degree rotation around x + 0.5, # hinge1 + 1.0, + 0.0, + 0.0, + 0.0, # ball2: identity + ] + ], + device=device, + ) + + entity.data.write_joint_position(new_pos) + sim.forward() + + # Read back and verify. 
+ joint_pos = entity.data.joint_pos + assert torch.allclose(joint_pos, new_pos, atol=1e-4) + + +def test_write_joint_velocity_with_ball_joints(device): + """Test writing joint velocities with ball joints.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Write all joint velocities. + # ball1 (3) + hinge1 (1) + ball2 (3) = 7 values + new_vel = torch.tensor( + [ + [ + 0.1, + 0.2, + 0.3, # ball1 angular velocity + 0.5, # hinge1 velocity + 0.0, + 0.0, + 0.0, # ball2 angular velocity + ] + ], + device=device, + ) + + entity.data.write_joint_velocity(new_vel) + sim.forward() + + # Read back and verify. + joint_vel = entity.data.joint_vel + assert torch.allclose(joint_vel, new_vel, atol=1e-4) diff --git a/tests/test_events.py b/tests/test_events.py index f2dc7ee47..b89af39ce 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,6 +25,7 @@ def test_reset_joints_by_offset(device): env.device = device mock_entity = Mock() + mock_entity.num_joints = 3 mock_entity.data.default_joint_pos = torch.zeros((2, 3), device=device) mock_entity.data.default_joint_vel = torch.zeros((2, 3), device=device) mock_entity.data.soft_joint_pos_limits = torch.tensor( @@ -36,6 +37,15 @@ def test_reset_joints_by_offset(device): ) mock_entity.write_joint_state_to_sim = Mock() + # Mock indexing for ball joint support (all hinge joints: 1 qpos, 1 dof each). + mock_entity.indexing.expand_to_q_indices = ( + lambda x: slice(None) if isinstance(x, slice) else x + ) + mock_entity.indexing.expand_to_v_indices = ( + lambda x: slice(None) if isinstance(x, slice) else x + ) + mock_entity.indexing.joint_qpos_widths = torch.ones(3, dtype=torch.int, device=device) + env.scene = {"robot": mock_entity} # Test normal offset application. 
From 47d2b8ca2fd13a1a8e6ca8514ecc2bc8c8d33f22 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 18 Jan 2026 17:49:08 -0800 Subject: [PATCH 5/9] Add MDP unit tests for ball joint support Tests verify observations and rewards correctly use expand_to_q_indices and expand_to_v_indices instead of direct joint_ids indexing. Co-Authored-By: Claude Opus 4.5 --- tests/test_ball_joint.py | 132 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 124 insertions(+), 8 deletions(-) diff --git a/tests/test_ball_joint.py b/tests/test_ball_joint.py index fd452be33..7180a351a 100644 --- a/tests/test_ball_joint.py +++ b/tests/test_ball_joint.py @@ -77,7 +77,7 @@ def test_nq_nv_with_ball_joints(device): Hinge joints have 1 qpos and 1 qvel. """ entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) # Entity has: ball1 (4 qpos, 3 dof), hinge1 (1 qpos, 1 dof), ball2 (4 qpos, 3 dof) # Total: 4 + 1 + 4 = 9 qpos, 3 + 1 + 3 = 7 dof @@ -93,7 +93,7 @@ def test_nq_nv_with_ball_joints(device): def test_nq_nv_hinge_only(device): """Test that nq and nv equal num_joints for hinge-only entities.""" entity = create_hinge_only_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) # Entity has: hinge1 (1 qpos, 1 dof), hinge2 (1 qpos, 1 dof) assert entity.num_joints == 2 @@ -104,7 +104,7 @@ def test_nq_nv_hinge_only(device): def test_joint_offset_tensors(device): """Test that q_offsets and v_offsets are correctly computed.""" entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) indexing = entity.indexing @@ -135,7 +135,7 @@ def test_joint_offset_tensors(device): def test_joint_qpos_widths(device): """Test that joint_qpos_widths and joint_dof_widths are correct.""" entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, 
device) indexing = entity.indexing @@ -150,7 +150,7 @@ def test_joint_qpos_widths(device): def test_ball_joint_initial_state(device): """Test that ball joints default to identity quaternion (1, 0, 0, 0).""" entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) # Check default joint positions have correct shape. default_joint_pos = entity.data.default_joint_pos @@ -175,7 +175,7 @@ def test_ball_joint_initial_state(device): def test_default_joint_vel_shape(device): """Test that default_joint_vel has correct shape (nv, not num_joints).""" entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) default_joint_vel = entity.data.default_joint_vel assert default_joint_vel.shape == (1, 7), ( @@ -186,7 +186,7 @@ def test_default_joint_vel_shape(device): def test_joint_pos_target_shape(device): """Test that joint_pos_target has correct shape (nq) for non-actuated entity.""" entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) # Non-actuated entities have empty target tensors. joint_pos_target = entity.data.joint_pos_target @@ -198,7 +198,7 @@ def test_joint_pos_target_shape(device): def test_joint_vel_target_shape(device): """Test that joint_vel_target has correct shape (nv) for non-actuated entity.""" entity = create_ball_joint_entity() - entity, sim = initialize_entity(entity, device) + entity, _ = initialize_entity(entity, device) # Non-actuated entities have empty target tensors. joint_vel_target = entity.data.joint_vel_target @@ -267,3 +267,119 @@ def test_write_joint_velocity_with_ball_joints(device): # Read back and verify. joint_vel = entity.data.joint_vel assert torch.allclose(joint_vel, new_vel, atol=1e-4) + + +## +# MDP function tests with ball joints. 
+## + + +def test_mdp_joint_pos_rel_with_ball_joints(device): + """Test joint_pos_rel observation with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import observations + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Get relative joint positions. + result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg) + + # Should have shape (1, nq=9), not (1, num_joints=3). + assert result.shape == (1, 9), f"Expected (1, 9), got {result.shape}" + + +def test_mdp_joint_vel_rel_with_ball_joints(device): + """Test joint_vel_rel observation with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import observations + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Get relative joint velocities. + result = observations.joint_vel_rel(env, asset_cfg=asset_cfg) + + # Should have shape (1, nv=7), not (1, num_joints=3). + assert result.shape == (1, 7), f"Expected (1, 7), got {result.shape}" + + +def test_mdp_joint_vel_l2_with_ball_joints(device): + """Test joint_vel_l2 reward with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import rewards + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Compute joint velocity L2 penalty. + result = rewards.joint_vel_l2(env, asset_cfg=asset_cfg) + + # Should return scalar per env. 
+ assert result.shape == (1,), f"Expected (1,), got {result.shape}" + + +def test_mdp_joint_acc_l2_with_ball_joints(device): + """Test joint_acc_l2 reward with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import rewards + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Compute joint acceleration L2 penalty. + result = rewards.joint_acc_l2(env, asset_cfg=asset_cfg) + + # Should return scalar per env. + assert result.shape == (1,), f"Expected (1,), got {result.shape}" + + +def test_mdp_joint_pos_limits_with_ball_joints(device): + """Test joint_pos_limits reward with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import rewards + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.device = device + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Compute joint position limits penalty. + result = rewards.joint_pos_limits(env, asset_cfg=asset_cfg) + + # Should return scalar per env. 
+ assert result.shape == (1,), f"Expected (1,), got {result.shape}" From 0da422262040bbd20cdf93cf32e8de246d68c018 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 18 Jan 2026 18:54:10 -0800 Subject: [PATCH 6/9] Fix MDP functions for ball joint support - observations: joint_pos_rel, joint_vel_rel use expanded indices - rewards: joint_vel_l2, joint_acc_l2, joint_pos_limits, posture, electrical_power_cost - events: reset_joints_by_offset, randomize_encoder_bias, _get_entity_indices --- src/mjlab/envs/mdp/events.py | 52 ++++++++++++++------- src/mjlab/envs/mdp/observations.py | 8 ++-- src/mjlab/envs/mdp/rewards.py | 50 +++++++++++++------- src/mjlab/tasks/manipulation/mdp/rewards.py | 3 +- src/mjlab/tasks/velocity/mdp/rewards.py | 21 ++++++--- 5 files changed, 90 insertions(+), 44 deletions(-) diff --git a/src/mjlab/envs/mdp/events.py b/src/mjlab/envs/mdp/events.py index a297424f7..b24ac0fe1 100644 --- a/src/mjlab/envs/mdp/events.py +++ b/src/mjlab/envs/mdp/events.py @@ -186,17 +186,33 @@ def reset_joints_by_offset( soft_joint_pos_limits = asset.data.soft_joint_pos_limits assert soft_joint_pos_limits is not None - joint_pos = default_joint_pos[env_ids][:, asset_cfg.joint_ids].clone() + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + + joint_pos = default_joint_pos[env_ids][:, q_indices].clone() joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device) - joint_pos_limits = soft_joint_pos_limits[env_ids][:, asset_cfg.joint_ids] - joint_pos = joint_pos.clamp_(joint_pos_limits[..., 0], joint_pos_limits[..., 1]) - joint_vel = default_joint_vel[env_ids][:, asset_cfg.joint_ids].clone() + joint_ids = asset_cfg.joint_ids + if isinstance(joint_ids, slice): + joint_ids_tensor = torch.arange(asset.num_joints, device=env.device) + elif isinstance(joint_ids, list): + joint_ids_tensor = torch.tensor(joint_ids, device=env.device) + else: + joint_ids_tensor = joint_ids + + # 
Expand limits to qpos dimensions for ball joints. + joint_limits = soft_joint_pos_limits[env_ids][:, joint_ids_tensor] + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + expanded_limits = joint_limits.repeat_interleave(qpos_widths, dim=1) + joint_pos = joint_pos.clamp_(expanded_limits[..., 0], expanded_limits[..., 1]) + + joint_vel = default_joint_vel[env_ids][:, v_indices].clone() joint_vel += sample_uniform(*velocity_range, joint_vel.shape, env.device) - joint_ids = asset_cfg.joint_ids - if isinstance(joint_ids, list): - joint_ids = torch.tensor(joint_ids, device=env.device) + if isinstance(asset_cfg.joint_ids, list): + joint_ids = torch.tensor(asset_cfg.joint_ids, device=env.device) + else: + joint_ids = asset_cfg.joint_ids asset.write_joint_state_to_sim( joint_pos.view(len(env_ids), -1), @@ -376,9 +392,11 @@ def _get_entity_indices( ) -> torch.Tensor: match spec.entity_type: case "dof": - return indexing.joint_v_adr[asset_cfg.joint_ids] + v_indices = indexing.expand_to_v_indices(asset_cfg.joint_ids) + return indexing.joint_v_adr[v_indices] case "joint" if spec.use_address: - return indexing.joint_q_adr[asset_cfg.joint_ids] + q_indices = indexing.expand_to_q_indices(asset_cfg.joint_ids) + return indexing.joint_q_adr[q_indices] case "joint": return indexing.joint_ids[asset_cfg.joint_ids] case "body": @@ -741,24 +759,24 @@ def randomize_encoder_bias( env_ids = env_ids.to(env.device, dtype=torch.int) joint_ids = asset_cfg.joint_ids - if isinstance(joint_ids, slice): - num_joints = asset.num_joints - joint_ids_tensor = torch.arange(num_joints, device=env.device) + q_indices = asset.indexing.expand_to_q_indices(joint_ids) + + if isinstance(q_indices, slice): + nq = asset.nq else: - joint_ids_tensor = torch.tensor(joint_ids, device=env.device) + nq = len(q_indices) - num_joints = len(joint_ids_tensor) bias_samples = sample_uniform( torch.tensor(bias_range[0], device=env.device), torch.tensor(bias_range[1], device=env.device), - (len(env_ids), 
num_joints), + (len(env_ids), nq), env.device, ) - if isinstance(joint_ids, slice): + if isinstance(q_indices, slice): asset.data.encoder_bias[env_ids] = bias_samples else: - asset.data.encoder_bias[env_ids[:, None], joint_ids_tensor] = bias_samples + asset.data.encoder_bias[env_ids[:, None], q_indices] = bias_samples def sync_actuator_delays( diff --git a/src/mjlab/envs/mdp/observations.py b/src/mjlab/envs/mdp/observations.py index 8ed528eb1..7f5de6677 100644 --- a/src/mjlab/envs/mdp/observations.py +++ b/src/mjlab/envs/mdp/observations.py @@ -56,9 +56,9 @@ def joint_pos_rel( asset: Entity = env.scene[asset_cfg.name] default_joint_pos = asset.data.default_joint_pos assert default_joint_pos is not None - jnt_ids = asset_cfg.joint_ids + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) joint_pos = asset.data.joint_pos_biased if biased else asset.data.joint_pos - return joint_pos[:, jnt_ids] - default_joint_pos[:, jnt_ids] + return joint_pos[:, q_indices] - default_joint_pos[:, q_indices] def joint_vel_rel( @@ -68,8 +68,8 @@ def joint_vel_rel( asset: Entity = env.scene[asset_cfg.name] default_joint_vel = asset.data.default_joint_vel assert default_joint_vel is not None - jnt_ids = asset_cfg.joint_ids - return asset.data.joint_vel[:, jnt_ids] - default_joint_vel[:, jnt_ids] + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + return asset.data.joint_vel[:, v_indices] - default_joint_vel[:, v_indices] ## diff --git a/src/mjlab/envs/mdp/rewards.py b/src/mjlab/envs/mdp/rewards.py index edfe5a4f1..816ef1114 100644 --- a/src/mjlab/envs/mdp/rewards.py +++ b/src/mjlab/envs/mdp/rewards.py @@ -42,7 +42,8 @@ def joint_vel_l2( ) -> torch.Tensor: """Penalize joint velocities on the articulation using L2 squared kernel.""" asset: Entity = env.scene[asset_cfg.name] - return torch.sum(torch.square(asset.data.joint_vel[:, asset_cfg.joint_ids]), dim=1) + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + return 
torch.sum(torch.square(asset.data.joint_vel[:, v_indices]), dim=1) def joint_acc_l2( @@ -50,7 +51,8 @@ def joint_acc_l2( ) -> torch.Tensor: """Penalize joint accelerations on the articulation using L2 squared kernel.""" asset: Entity = env.scene[asset_cfg.name] - return torch.sum(torch.square(asset.data.joint_acc[:, asset_cfg.joint_ids]), dim=1) + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + return torch.sum(torch.square(asset.data.joint_acc[:, v_indices]), dim=1) def action_rate_l2(env: ManagerBasedRlEnv) -> torch.Tensor: @@ -77,14 +79,25 @@ def joint_pos_limits( asset: Entity = env.scene[asset_cfg.name] soft_joint_pos_limits = asset.data.soft_joint_pos_limits assert soft_joint_pos_limits is not None - out_of_limits = -( - asset.data.joint_pos[:, asset_cfg.joint_ids] - - soft_joint_pos_limits[:, asset_cfg.joint_ids, 0] - ).clip(max=0.0) - out_of_limits += ( - asset.data.joint_pos[:, asset_cfg.joint_ids] - - soft_joint_pos_limits[:, asset_cfg.joint_ids, 1] - ).clip(min=0.0) + + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + joint_pos = asset.data.joint_pos[:, q_indices] + + joint_ids = asset_cfg.joint_ids + if isinstance(joint_ids, slice): + joint_ids_tensor = torch.arange(asset.num_joints, device=env.device) + elif isinstance(joint_ids, list): + joint_ids_tensor = torch.tensor(joint_ids, device=env.device) + else: + joint_ids_tensor = joint_ids + + # Expand limits to qpos dimensions for ball joints. 
+ limits = soft_joint_pos_limits[:, joint_ids_tensor] + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + expanded_limits = limits.repeat_interleave(qpos_widths, dim=1) + + out_of_limits = -(joint_pos - expanded_limits[..., 0]).clip(max=0.0) + out_of_limits += (joint_pos - expanded_limits[..., 1]).clip(min=0.0) return torch.sum(out_of_limits, dim=1) @@ -101,25 +114,29 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): assert default_joint_pos is not None self.default_joint_pos = default_joint_pos - _, joint_names = asset.find_joints( + joint_ids, joint_names = asset.find_joints( cfg.params["asset_cfg"].joint_names, ) + self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) _, _, std = resolve_matching_names_values( data=cfg.params["std"], list_of_strings=joint_names, ) self.std = torch.tensor(std, device=env.device, dtype=torch.float32) + self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids] def __call__( self, env: ManagerBasedRlEnv, std, asset_cfg: SceneEntityCfg ) -> torch.Tensor: del std # Unused. 
asset: Entity = env.scene[asset_cfg.name] - current_joint_pos = asset.data.joint_pos[:, asset_cfg.joint_ids] - desired_joint_pos = self.default_joint_pos[:, asset_cfg.joint_ids] + q_indices = asset.indexing.expand_to_q_indices(self._joint_ids) + current_joint_pos = asset.data.joint_pos[:, q_indices] + desired_joint_pos = self.default_joint_pos[:, q_indices] error_squared = torch.square(current_joint_pos - desired_joint_pos) - return torch.exp(-torch.mean(error_squared / (self.std**2), dim=1)) + std_expanded = self.std.repeat_interleave(self._joint_qpos_widths) + return torch.exp(-torch.mean(error_squared / (std_expanded**2), dim=1)) class electrical_power_cost: @@ -134,13 +151,14 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): actuator_ids, _ = asset.find_actuators( cfg.params["asset_cfg"].joint_names, ) - self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) + joint_ids_tensor = torch.tensor(joint_ids, device=env.device, dtype=torch.long) + self._v_indices = asset.indexing.expand_to_v_indices(joint_ids_tensor) self._actuator_ids = torch.tensor(actuator_ids, device=env.device, dtype=torch.long) def __call__(self, env: ManagerBasedRlEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: asset: Entity = env.scene[asset_cfg.name] tau = asset.data.actuator_force[:, self._actuator_ids] - qd = asset.data.joint_vel[:, self._joint_ids] + qd = asset.data.joint_vel[:, self._v_indices] mech = tau * qd mech_pos = torch.clamp(mech, min=0.0) # Don't penalize regen. return torch.sum(mech_pos, dim=1) diff --git a/src/mjlab/tasks/manipulation/mdp/rewards.py b/src/mjlab/tasks/manipulation/mdp/rewards.py index b5792cf94..1430d605f 100644 --- a/src/mjlab/tasks/manipulation/mdp/rewards.py +++ b/src/mjlab/tasks/manipulation/mdp/rewards.py @@ -64,6 +64,7 @@ def joint_velocity_hinge_penalty( penalty, shaped as the negative squared L2 norm of the excess velocities. 
""" robot: Entity = env.scene[asset_cfg.name] - joint_vel = robot.data.joint_vel[:, asset_cfg.joint_ids] + v_indices = robot.indexing.expand_to_v_indices(asset_cfg.joint_ids) + joint_vel = robot.data.joint_vel[:, v_indices] excess = (joint_vel.abs() - max_vel).clamp_min(0.0) return (excess**2).sum(dim=-1) diff --git a/src/mjlab/tasks/velocity/mdp/rewards.py b/src/mjlab/tasks/velocity/mdp/rewards.py index ff0199e53..3ea15396e 100644 --- a/src/mjlab/tasks/velocity/mdp/rewards.py +++ b/src/mjlab/tasks/velocity/mdp/rewards.py @@ -310,7 +310,8 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): assert default_joint_pos is not None self.default_joint_pos = default_joint_pos - _, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names) + joint_ids, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names) + self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) _, _, std_standing = resolve_matching_names_values( data=cfg.params["std_standing"], @@ -332,6 +333,8 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): ) self.std_running = torch.tensor(std_running, device=env.device, dtype=torch.float32) + self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids] + def __call__( self, env: ManagerBasedRlEnv, @@ -359,14 +362,20 @@ def __call__( ).float() running_mask = (total_speed >= running_threshold).float() + # Expand std to qpos dimensions for ball joints. 
+ std_standing_exp = self.std_standing.repeat_interleave(self._joint_qpos_widths) + std_walking_exp = self.std_walking.repeat_interleave(self._joint_qpos_widths) + std_running_exp = self.std_running.repeat_interleave(self._joint_qpos_widths) + std = ( - self.std_standing * standing_mask.unsqueeze(1) - + self.std_walking * walking_mask.unsqueeze(1) - + self.std_running * running_mask.unsqueeze(1) + std_standing_exp * standing_mask.unsqueeze(1) + + std_walking_exp * walking_mask.unsqueeze(1) + + std_running_exp * running_mask.unsqueeze(1) ) - current_joint_pos = asset.data.joint_pos[:, asset_cfg.joint_ids] - desired_joint_pos = self.default_joint_pos[:, asset_cfg.joint_ids] + q_indices = asset.indexing.expand_to_q_indices(self._joint_ids) + current_joint_pos = asset.data.joint_pos[:, q_indices] + desired_joint_pos = self.default_joint_pos[:, q_indices] error_squared = torch.square(current_joint_pos - desired_joint_pos) return torch.exp(-torch.mean(error_squared / (std**2), dim=1)) From bc24381f6adc5bc3a1abb35e55099cfc4c292b35 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Mon, 19 Jan 2026 03:02:44 -0800 Subject: [PATCH 7/9] Fix quaternion subtraction issue in events, obs, and rewards. 
Tracking still needs to be fixed --- src/mjlab/envs/mdp/events.py | 93 ++++++++++++++++++++++--- src/mjlab/envs/mdp/observations.py | 67 +++++++++++++++++- src/mjlab/envs/mdp/rewards.py | 65 ++++++++++++++--- src/mjlab/tasks/velocity/mdp/rewards.py | 86 ++++++++++++++++------- 4 files changed, 267 insertions(+), 44 deletions(-) diff --git a/src/mjlab/envs/mdp/events.py b/src/mjlab/envs/mdp/events.py index b24ac0fe1..bca88d1e0 100644 --- a/src/mjlab/envs/mdp/events.py +++ b/src/mjlab/envs/mdp/events.py @@ -10,6 +10,7 @@ from mjlab.entity import Entity, EntityIndexing from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.utils.lab_api.math import ( + quat_from_angle_axis, quat_from_euler_xyz, quat_mul, sample_gaussian, @@ -175,6 +176,11 @@ def reset_joints_by_offset( velocity_range: tuple[float, float], asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, ) -> None: + """Reset joint positions with random offsets from default. + + For hinge joints: adds uniform noise and clamps to limits. + For ball joints: applies random rotation via quaternion multiplication. + """ if env_ids is None: env_ids = torch.arange(env.num_envs, device=env.device, dtype=torch.int) @@ -186,12 +192,8 @@ def reset_joints_by_offset( soft_joint_pos_limits = asset.data.soft_joint_pos_limits assert soft_joint_pos_limits is not None - q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) - joint_pos = default_joint_pos[env_ids][:, q_indices].clone() - joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device) - joint_ids = asset_cfg.joint_ids if isinstance(joint_ids, slice): joint_ids_tensor = torch.arange(asset.num_joints, device=env.device) @@ -200,11 +202,84 @@ def reset_joints_by_offset( else: joint_ids_tensor = joint_ids - # Expand limits to qpos dimensions for ball joints. 
- joint_limits = soft_joint_pos_limits[env_ids][:, joint_ids_tensor] - qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] - expanded_limits = joint_limits.repeat_interleave(qpos_widths, dim=1) - joint_pos = joint_pos.clamp_(expanded_limits[..., 0], expanded_limits[..., 1]) + # Hinge-only entities: use simple add + clamp. + if asset.nq == asset.nv: + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + joint_pos = default_joint_pos[env_ids][:, q_indices].clone() + joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device) + + joint_limits = soft_joint_pos_limits[env_ids][:, joint_ids_tensor] + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + expanded_limits = joint_limits.repeat_interleave(qpos_widths, dim=1) + joint_pos = joint_pos.clamp_(expanded_limits[..., 0], expanded_limits[..., 1]) + else: + # Mixed ball/hinge joints: handle separately. + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + dof_widths = asset.indexing.joint_dof_widths[joint_ids_tensor] + q_offsets = asset.indexing.q_offsets[joint_ids_tensor] + + num_envs = len(env_ids) + device = env.device + total_qpos = int(qpos_widths.sum().item()) + joint_pos = torch.zeros((num_envs, total_qpos), device=device) + + is_ball_joint = (qpos_widths == 4) & (dof_widths == 3) + qpos_cumsum = torch.cat( + [ + torch.zeros(1, device=device, dtype=qpos_widths.dtype), + qpos_widths.cumsum(0)[:-1], + ] + ) + + # Ball joints: sample rotation angle, random axis, apply via quat_mul. + if is_ball_joint.any(): + ball_q_offsets = q_offsets[is_ball_joint] + ball_out_offsets = qpos_cumsum[is_ball_joint] + num_ball = int(is_ball_joint.sum().item()) + + # Get default quaternions. + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + default_quats = default_joint_pos[env_ids][:, ball_q_indices.flatten()].view( + num_envs, num_ball, 4 + ) + + # Sample rotation angles from position_range. 
+ angles = sample_uniform( + position_range[0], position_range[1], (num_envs * num_ball,), device + ) + + # Sample random rotation axes (unit vectors). + axes = torch.randn((num_envs * num_ball, 3), device=device) + axes = axes / axes.norm(dim=-1, keepdim=True) + + # Create delta quaternions. + delta_quats = quat_from_angle_axis(angles, axes) + + # Apply rotation: q_new = q_default * q_delta. + new_quats = quat_mul(default_quats.reshape(-1, 4), delta_quats).view( + num_envs, num_ball, 4 + ) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(4, device=device) + joint_pos[:, ball_out_indices.flatten()] = new_quats.reshape(num_envs, -1) + + # Hinge joints: add noise + clamp. + is_hinge_joint = ~is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = q_offsets[is_hinge_joint] + hinge_out_offsets = qpos_cumsum[is_hinge_joint] + hinge_joint_ids = joint_ids_tensor[is_hinge_joint] + num_hinge = int(is_hinge_joint.sum().item()) + + default_hinge = default_joint_pos[env_ids][:, hinge_q_offsets] + noise = sample_uniform(*position_range, (num_envs, num_hinge), device) + hinge_pos = default_hinge + noise + + # Clamp to limits. 
+ hinge_limits = soft_joint_pos_limits[env_ids][:, hinge_joint_ids] + hinge_pos = hinge_pos.clamp_(hinge_limits[..., 0], hinge_limits[..., 1]) + + joint_pos[:, hinge_out_offsets] = hinge_pos joint_vel = default_joint_vel[env_ids][:, v_indices].clone() joint_vel += sample_uniform(*velocity_range, joint_vel.shape, env.device) diff --git a/src/mjlab/envs/mdp/observations.py b/src/mjlab/envs/mdp/observations.py index 7f5de6677..f5de17cd1 100644 --- a/src/mjlab/envs/mdp/observations.py +++ b/src/mjlab/envs/mdp/observations.py @@ -9,6 +9,7 @@ from mjlab.entity import Entity from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sensor import BuiltinSensor +from mjlab.utils.lab_api.math import quat_box_minus if TYPE_CHECKING: from mjlab.envs import ManagerBasedRlEnv @@ -53,12 +54,74 @@ def joint_pos_rel( biased: bool = False, asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, ) -> torch.Tensor: + """Compute relative joint positions (current - default) in DOF space. + + For hinge/slide joints: returns scalar difference. + For ball joints: returns 3D axis-angle difference via quat_box_minus. + + Returns: + Tensor of shape (num_envs, total_dof) for selected joints. + """ asset: Entity = env.scene[asset_cfg.name] default_joint_pos = asset.data.default_joint_pos assert default_joint_pos is not None - q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) joint_pos = asset.data.joint_pos_biased if biased else asset.data.joint_pos - return joint_pos[:, q_indices] - default_joint_pos[:, q_indices] + + # Hinge-only entities (nq == nv means no ball joints). + if asset.nq == asset.nv: + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + return joint_pos[:, q_indices] - default_joint_pos[:, q_indices] + + # Ball joints. 
+ joint_ids = asset_cfg.joint_ids + if isinstance(joint_ids, slice): + joint_ids_tensor = torch.arange(asset.num_joints, device=joint_pos.device) + elif isinstance(joint_ids, list): + joint_ids_tensor = torch.tensor(joint_ids, device=joint_pos.device) + else: + joint_ids_tensor = joint_ids + + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + dof_widths = asset.indexing.joint_dof_widths[joint_ids_tensor] + q_offsets = asset.indexing.q_offsets[joint_ids_tensor] + + num_envs = joint_pos.shape[0] + device = joint_pos.device + + is_ball_joint = (qpos_widths == 4) & (dof_widths == 3) + dof_cumsum = torch.cat( + [torch.zeros(1, device=device, dtype=dof_widths.dtype), dof_widths.cumsum(0)[:-1]] + ) + total_dof = int(dof_widths.sum().item()) + result = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype) + + # Ball joints: quaternion difference via quat_box_minus. + if is_ball_joint.any(): + ball_q_offsets = q_offsets[is_ball_joint] + ball_out_offsets = dof_cumsum[is_ball_joint] + + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + default_quats = default_joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + + num_ball = current_quats.shape[1] + diff_flat = quat_box_minus( + current_quats.reshape(-1, 4), default_quats.reshape(-1, 4) + ) + diff = diff_flat.view(num_envs, num_ball, 3) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device) + result[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1) + + # Hinge joints: scalar subtraction. 
+ is_hinge_joint = ~is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = q_offsets[is_hinge_joint] + hinge_out_offsets = dof_cumsum[is_hinge_joint] + diff_hinge = joint_pos[:, hinge_q_offsets] - default_joint_pos[:, hinge_q_offsets] + result[:, hinge_out_offsets] = diff_hinge + + return result def joint_vel_rel( diff --git a/src/mjlab/envs/mdp/rewards.py b/src/mjlab/envs/mdp/rewards.py index 816ef1114..50ee4cfa0 100644 --- a/src/mjlab/envs/mdp/rewards.py +++ b/src/mjlab/envs/mdp/rewards.py @@ -9,6 +9,7 @@ from mjlab.entity import Entity from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg +from mjlab.utils.lab_api.math import quat_box_minus from mjlab.utils.lab_api.string import ( resolve_matching_names_values, ) @@ -114,9 +115,7 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): assert default_joint_pos is not None self.default_joint_pos = default_joint_pos - joint_ids, joint_names = asset.find_joints( - cfg.params["asset_cfg"].joint_names, - ) + joint_ids, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names) self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) _, _, std = resolve_matching_names_values( @@ -124,19 +123,67 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): list_of_strings=joint_names, ) self.std = torch.tensor(std, device=env.device, dtype=torch.float32) + + self._joint_dof_widths = asset.indexing.joint_dof_widths[self._joint_ids] self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids] + self._q_offsets = asset.indexing.q_offsets[self._joint_ids] + self._is_ball_joint = (self._joint_qpos_widths == 4) & (self._joint_dof_widths == 3) + + # Precompute cumulative DOF offsets for output indexing. 
+ self._dof_cumsum = torch.cat( + [ + torch.zeros(1, device=env.device, dtype=self._joint_dof_widths.dtype), + self._joint_dof_widths.cumsum(0)[:-1], + ] + ) + self._std_expanded = self.std.repeat_interleave(self._joint_dof_widths) def __call__( self, env: ManagerBasedRlEnv, std, asset_cfg: SceneEntityCfg ) -> torch.Tensor: del std # Unused. + asset: Entity = env.scene[asset_cfg.name] - q_indices = asset.indexing.expand_to_q_indices(self._joint_ids) - current_joint_pos = asset.data.joint_pos[:, q_indices] - desired_joint_pos = self.default_joint_pos[:, q_indices] - error_squared = torch.square(current_joint_pos - desired_joint_pos) - std_expanded = self.std.repeat_interleave(self._joint_qpos_widths) - return torch.exp(-torch.mean(error_squared / (std_expanded**2), dim=1)) + joint_pos = asset.data.joint_pos + num_envs = joint_pos.shape[0] + device = joint_pos.device + + total_dof = int(self._joint_dof_widths.sum().item()) + error = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype) + + # Ball joints: quaternion difference. + if self._is_ball_joint.any(): + ball_q_offsets = self._q_offsets[self._is_ball_joint] + ball_out_offsets = self._dof_cumsum[self._is_ball_joint] + + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + default_quats = self.default_joint_pos[:, ball_q_indices.flatten()].view( + num_envs, -1, 4 + ) + + num_ball = current_quats.shape[1] + diff_flat = quat_box_minus( + current_quats.reshape(-1, 4), default_quats.reshape(-1, 4) + ) + diff = diff_flat.view(num_envs, num_ball, 3) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device) + error[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1) + + # Hinge joints: scalar subtraction. 
+ is_hinge_joint = ~self._is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = self._q_offsets[is_hinge_joint] + hinge_out_offsets = self._dof_cumsum[is_hinge_joint] + + diff_hinge = ( + joint_pos[:, hinge_q_offsets] - self.default_joint_pos[:, hinge_q_offsets] + ) + error[:, hinge_out_offsets] = diff_hinge + + error_squared = torch.square(error) + return torch.exp(-torch.mean(error_squared / (self._std_expanded**2), dim=1)) class electrical_power_cost: diff --git a/src/mjlab/tasks/velocity/mdp/rewards.py b/src/mjlab/tasks/velocity/mdp/rewards.py index 3ea15396e..672ff32cb 100644 --- a/src/mjlab/tasks/velocity/mdp/rewards.py +++ b/src/mjlab/tasks/velocity/mdp/rewards.py @@ -8,7 +8,7 @@ from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sensor import BuiltinSensor, ContactSensor -from mjlab.utils.lab_api.math import quat_apply_inverse +from mjlab.utils.lab_api.math import quat_apply_inverse, quat_box_minus from mjlab.utils.lab_api.string import ( resolve_matching_names_values, ) @@ -314,26 +314,32 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) _, _, std_standing = resolve_matching_names_values( - data=cfg.params["std_standing"], - list_of_strings=joint_names, + data=cfg.params["std_standing"], list_of_strings=joint_names ) - self.std_standing = torch.tensor( - std_standing, device=env.device, dtype=torch.float32 - ) - _, _, std_walking = resolve_matching_names_values( - data=cfg.params["std_walking"], - list_of_strings=joint_names, + data=cfg.params["std_walking"], list_of_strings=joint_names ) - self.std_walking = torch.tensor(std_walking, device=env.device, dtype=torch.float32) - _, _, std_running = resolve_matching_names_values( - data=cfg.params["std_running"], - list_of_strings=joint_names, + data=cfg.params["std_running"], list_of_strings=joint_names ) + self.std_standing = 
torch.tensor( + std_standing, device=env.device, dtype=torch.float32 + ) + self.std_walking = torch.tensor(std_walking, device=env.device, dtype=torch.float32) self.std_running = torch.tensor(std_running, device=env.device, dtype=torch.float32) + self._joint_dof_widths = asset.indexing.joint_dof_widths[self._joint_ids] self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids] + self._q_offsets = asset.indexing.q_offsets[self._joint_ids] + self._is_ball_joint = (self._joint_qpos_widths == 4) & (self._joint_dof_widths == 3) + + # Precompute cumulative DOF offsets for output indexing. + self._dof_cumsum = torch.cat( + [ + torch.zeros(1, device=env.device, dtype=self._joint_dof_widths.dtype), + self._joint_dof_widths.cumsum(0)[:-1], + ] + ) def __call__( self, @@ -352,9 +358,7 @@ def __call__( command = env.command_manager.get_command(command_name) assert command is not None - linear_speed = torch.norm(command[:, :2], dim=1) - angular_speed = torch.abs(command[:, 2]) - total_speed = linear_speed + angular_speed + total_speed = torch.norm(command[:, :2], dim=1) + torch.abs(command[:, 2]) standing_mask = (total_speed < walking_threshold).float() walking_mask = ( @@ -363,9 +367,9 @@ def __call__( running_mask = (total_speed >= running_threshold).float() # Expand std to qpos dimensions for ball joints. 
- std_standing_exp = self.std_standing.repeat_interleave(self._joint_qpos_widths) - std_walking_exp = self.std_walking.repeat_interleave(self._joint_qpos_widths) - std_running_exp = self.std_running.repeat_interleave(self._joint_qpos_widths) + std_standing_exp = self.std_standing.repeat_interleave(self._joint_dof_widths) + std_walking_exp = self.std_walking.repeat_interleave(self._joint_dof_widths) + std_running_exp = self.std_running.repeat_interleave(self._joint_dof_widths) std = ( std_standing_exp * standing_mask.unsqueeze(1) @@ -373,9 +377,43 @@ def __call__( + std_running_exp * running_mask.unsqueeze(1) ) - q_indices = asset.indexing.expand_to_q_indices(self._joint_ids) - current_joint_pos = asset.data.joint_pos[:, q_indices] - desired_joint_pos = self.default_joint_pos[:, q_indices] - error_squared = torch.square(current_joint_pos - desired_joint_pos) - + joint_pos = asset.data.joint_pos + num_envs = joint_pos.shape[0] + device = joint_pos.device + + total_dof = int(self._joint_dof_widths.sum().item()) + error = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype) + + # Ball joints: quaternion difference. + if self._is_ball_joint.any(): + ball_q_offsets = self._q_offsets[self._is_ball_joint] + ball_out_offsets = self._dof_cumsum[self._is_ball_joint] + + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + default_quats = self.default_joint_pos[:, ball_q_indices.flatten()].view( + num_envs, -1, 4 + ) + + num_ball = current_quats.shape[1] + diff_flat = quat_box_minus( + current_quats.reshape(-1, 4), default_quats.reshape(-1, 4) + ) + diff = diff_flat.view(num_envs, num_ball, 3) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device) + error[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1) + + # Hinge joints: scalar subtraction. 
+ is_hinge_joint = ~self._is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = self._q_offsets[is_hinge_joint] + hinge_out_offsets = self._dof_cumsum[is_hinge_joint] + + diff_hinge = ( + joint_pos[:, hinge_q_offsets] - self.default_joint_pos[:, hinge_q_offsets] + ) + error[:, hinge_out_offsets] = diff_hinge + + error_squared = torch.square(error) return torch.exp(-torch.mean(error_squared / (std**2), dim=1)) From 3613ccc8af0825da634f2eef1c0c9d7b0a911865 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Mon, 19 Jan 2026 03:05:27 -0800 Subject: [PATCH 8/9] Modify ball joint tests Co-Authored-By: Claude Opus 4.5 --- tests/test_ball_joint.py | 59 ++++++++++++++++++++++++++++++++++++++-- tests/test_rewards.py | 3 +- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/tests/test_ball_joint.py b/tests/test_ball_joint.py index 7180a351a..82e0d6943 100644 --- a/tests/test_ball_joint.py +++ b/tests/test_ball_joint.py @@ -117,18 +117,22 @@ def test_joint_offset_tensors(device): # Test expand_to_q_indices for single joint (ball1). q_indices = indexing.expand_to_q_indices(torch.tensor([0], device=device)) + assert isinstance(q_indices, torch.Tensor) assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3], device=device)) # Test expand_to_v_indices for single joint (ball1). v_indices = indexing.expand_to_v_indices(torch.tensor([0], device=device)) + assert isinstance(v_indices, torch.Tensor) assert torch.equal(v_indices, torch.tensor([0, 1, 2], device=device)) # Test expand for hinge joint (joint 1). q_indices = indexing.expand_to_q_indices(torch.tensor([1], device=device)) + assert isinstance(q_indices, torch.Tensor) assert torch.equal(q_indices, torch.tensor([4], device=device)) # Test expand for multiple joints (ball1, ball2). 
q_indices = indexing.expand_to_q_indices(torch.tensor([0, 2], device=device)) + assert isinstance(q_indices, torch.Tensor) assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3, 5, 6, 7, 8], device=device)) @@ -292,8 +296,12 @@ def test_mdp_joint_pos_rel_with_ball_joints(device): # Get relative joint positions. result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg) - # Should have shape (1, nq=9), not (1, num_joints=3). - assert result.shape == (1, 9), f"Expected (1, 9), got {result.shape}" + # Should have shape (1, nv=7) in DOF space, not (1, nq=9) in qpos space. + # Ball joints contribute 3 DOF (axis-angle) instead of 4 qpos (quaternion). + assert result.shape == (1, 7), f"Expected (1, 7), got {result.shape}" + + # At default position, all relative positions should be zero. + assert torch.allclose(result, torch.zeros_like(result), atol=1e-6) def test_mdp_joint_vel_rel_with_ball_joints(device): @@ -383,3 +391,50 @@ def test_mdp_joint_pos_limits_with_ball_joints(device): # Should return scalar per env. assert result.shape == (1,), f"Expected (1,), got {result.shape}" + + +def test_quaternion_difference_correctness(device): + """Test that joint_pos_rel computes correct quaternion difference. + + For a 90-degree rotation around z-axis, the axis-angle representation + should be [0, 0, pi/2] (rotation axis z, angle 90 degrees). + """ + import math + from unittest.mock import Mock + + from mjlab.envs.mdp import observations + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Apply a 90-degree rotation around z-axis to ball joint 1. 
+ # Quaternion for 90 deg around z: (cos(45°), 0, 0, sin(45°)) = (0.7071, 0, 0, 0.7071) + angle = math.pi / 2 + half_angle = angle / 2 + new_pos = entity.data.joint_pos.clone() + new_pos[0, 0:4] = torch.tensor( + [math.cos(half_angle), 0.0, 0.0, math.sin(half_angle)], device=device + ) + entity.data.write_joint_position(new_pos) + sim.forward() + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg) + + # Ball joint 1 (indices 0:3 in DOF space) should have axis-angle [0, 0, pi/2]. + ball1_diff = result[0, 0:3] + expected_axis_angle = torch.tensor([0.0, 0.0, math.pi / 2], device=device) + assert torch.allclose(ball1_diff, expected_axis_angle, atol=1e-4), ( + f"Expected axis-angle {expected_axis_angle}, got {ball1_diff}" + ) + + # Hinge joint (index 3 in DOF space) should be unchanged (0). + assert abs(result[0, 3].item()) < 1e-6 + + # Ball joint 2 (indices 4:7 in DOF space) should be unchanged (identity -> 0). + ball2_diff = result[0, 4:7] + assert torch.allclose(ball2_diff, torch.zeros(3, device=device), atol=1e-6) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index f2ad35782..2d47e361b 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -199,7 +199,8 @@ def test_electrical_power_cost_partially_actuated(device): reward = electrical_power_cost(reward_cfg, env) - assert len(reward._joint_ids) == 2 + assert isinstance(reward._v_indices, torch.Tensor) + assert len(reward._v_indices) == 2 assert len(reward._actuator_ids) == 2 # Test case 1: All forces and velocities aligned (all positive work). 
From 7d4da982191f3607eea48671b28f76929b848fa5 Mon Sep 17 00:00:00 2001 From: mohitgadde Date: Sun, 1 Mar 2026 15:16:02 -0800 Subject: [PATCH 9/9] Fix test mock missing nq/nv attributes for ball joint guard --- tests/test_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 970ce86ce..82eba6a9b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -429,6 +429,8 @@ def test_reset_joints_by_offset(device): mock_entity = Mock() mock_entity.num_joints = 3 + mock_entity.nq = 3 + mock_entity.nv = 3 mock_entity.data.default_joint_pos = torch.zeros((2, 3), device=device) mock_entity.data.default_joint_vel = torch.zeros((2, 3), device=device) mock_entity.data.soft_joint_pos_limits = torch.tensor(