diff --git a/src/mjlab/actuator/actuator.py b/src/mjlab/actuator/actuator.py index 1f1c4340e..8196a9ea8 100644 --- a/src/mjlab/actuator/actuator.py +++ b/src/mjlab/actuator/actuator.py @@ -114,6 +114,9 @@ def __init__( self._global_ctrl_ids: torch.Tensor | None = None self._mjs_actuators: list[mujoco.MjsActuator] = [] self._site_zeros: torch.Tensor | None = None + # Expanded indices for ball joint support. + self._q_indices: torch.Tensor | None = None + self._v_indices: torch.Tensor | None = None @property def target_ids(self) -> torch.Tensor: @@ -189,6 +192,15 @@ def initialize( device=device, ) + # Compute expanded indices for ball joint support. + if self.transmission_type == TransmissionType.JOINT: + indexing = self.entity.indexing + q_indices = indexing.expand_to_q_indices(self._target_ids_list) + v_indices = indexing.expand_to_v_indices(self._target_ids_list) + assert isinstance(q_indices, torch.Tensor) and isinstance(v_indices, torch.Tensor) + self._q_indices = q_indices + self._v_indices = v_indices + # Pre-allocate zeros for SITE transmission type to avoid repeated allocations. 
if self.transmission_type == TransmissionType.SITE: nenvs = data.nworld @@ -206,11 +218,11 @@ def get_command(self, data: EntityData) -> ActuatorCmd: """ if self.transmission_type == TransmissionType.JOINT: return ActuatorCmd( - position_target=data.joint_pos_target[:, self.target_ids], - velocity_target=data.joint_vel_target[:, self.target_ids], - effort_target=data.joint_effort_target[:, self.target_ids], - pos=data.joint_pos[:, self.target_ids], - vel=data.joint_vel[:, self.target_ids], + position_target=data.joint_pos_target[:, self._q_indices], + velocity_target=data.joint_vel_target[:, self._v_indices], + effort_target=data.joint_effort_target[:, self._v_indices], + pos=data.joint_pos[:, self._q_indices], + vel=data.joint_vel[:, self._v_indices], ) elif self.transmission_type == TransmissionType.TENDON: return ActuatorCmd( diff --git a/src/mjlab/actuator/builtin_group.py b/src/mjlab/actuator/builtin_group.py index ad14705e7..ebcc3be95 100644 --- a/src/mjlab/actuator/builtin_group.py +++ b/src/mjlab/actuator/builtin_group.py @@ -37,6 +37,16 @@ (BuiltinMuscleActuator, TransmissionType.TENDON): "tendon_effort_target", } +# Indicates whether the actuator type uses qpos indexing (vs qvel indexing). +# Position actuators read from qpos-indexed tensors (nq dimension). +# Velocity/motor actuators read from qvel-indexed tensors (nv dimension). +_USES_QPOS_INDEXING: dict[type[BuiltinActuatorType], bool] = { + BuiltinPositionActuator: True, + BuiltinVelocityActuator: False, + BuiltinMotorActuator: False, + BuiltinMuscleActuator: False, +} + @dataclass(frozen=True) class BuiltinActuatorGroup: @@ -47,7 +57,9 @@ class BuiltinActuatorGroup: enables direct writes without per-actuator overhead. """ - # Map from (BuiltinActuator type, transmission_type) to (target_ids, ctrl_ids). + # Map from (BuiltinActuator type, transmission_type) to (expanded_indices, ctrl_ids). + # For JOINT transmission, expanded_indices are qpos or qvel indices depending + # on the actuator type. 
For other transmissions, they match target_ids. _index_groups: dict[ tuple[type[BuiltinActuatorType], TransmissionType], tuple[torch.Tensor, torch.Tensor], @@ -84,17 +96,29 @@ def process( else: custom_actuators.append(act) - # Return stacked indices for each (actuator_type, transmission_type) group. + # Build stacked indices for each (actuator_type, transmission_type) group. index_groups: dict[ tuple[type[BuiltinActuatorType], TransmissionType], tuple[torch.Tensor, torch.Tensor], - ] = { - key: ( - torch.cat([act.target_ids for act in acts], dim=0), - torch.cat([act.ctrl_ids for act in acts], dim=0), - ) - for key, acts in builtin_groups.items() - } + ] = {} + + for key, acts in builtin_groups.items(): + actuator_type, transmission_type = key + ctrl_ids = torch.cat([act.ctrl_ids for act in acts], dim=0) + + if transmission_type == TransmissionType.JOINT: + # Use expanded qpos/qvel indices for joint transmission. + uses_qpos = _USES_QPOS_INDEXING[actuator_type] + attr = "_q_indices" if uses_qpos else "_v_indices" + indices_list = [getattr(act, attr) for act in acts] + assert all(idx is not None for idx in indices_list) + expanded_indices = torch.cat(indices_list, dim=0) # type: ignore[arg-type] + index_groups[key] = (expanded_indices, ctrl_ids) + else: + # For tendon/site, use flat target_ids. + target_ids = torch.cat([act.target_ids for act in acts], dim=0) + index_groups[key] = (target_ids, ctrl_ids) + return BuiltinActuatorGroup(index_groups), tuple(custom_actuators) def apply_controls(self, data: EntityData) -> None: @@ -104,9 +128,9 @@ def apply_controls(self, data: EntityData) -> None: data: Entity data containing targets and control arrays. 
""" for (actuator_type, transmission_type), ( - target_ids, + expanded_indices, ctrl_ids, ) in self._index_groups.items(): attr_name = _TARGET_TENSOR_MAP[(actuator_type, transmission_type)] target_tensor = getattr(data, attr_name) - data.write_ctrl(target_tensor[:, target_ids], ctrl_ids) + data.write_ctrl(target_tensor[:, expanded_indices], ctrl_ids) diff --git a/src/mjlab/actuator/delayed_actuator.py b/src/mjlab/actuator/delayed_actuator.py index fa1ac5361..736c6a48e 100644 --- a/src/mjlab/actuator/delayed_actuator.py +++ b/src/mjlab/actuator/delayed_actuator.py @@ -105,6 +105,9 @@ def initialize( self._target_ids = self._base_actuator._target_ids self._ctrl_ids = self._base_actuator._ctrl_ids self._global_ctrl_ids = self._base_actuator._global_ctrl_ids + # Copy expanded indices for ball joint support. + self._q_indices = self._base_actuator._q_indices + self._v_indices = self._base_actuator._v_indices targets = ( (self.cfg.delay_target,) diff --git a/src/mjlab/entity/data.py b/src/mjlab/entity/data.py index 80bef750b..d521407a6 100644 --- a/src/mjlab/entity/data.py +++ b/src/mjlab/entity/data.py @@ -39,6 +39,11 @@ class EntityData: root_link_pose_w) require sim.forward() to be current. If you write then read, call sim.forward() in between. Event order matters when mixing reads and writes. All inputs/outputs use world frame. + + Ball Joint Support: + Position tensors use nq dimensions (4 per ball, 1 per hinge/slide). + Velocity/effort tensors use nv dimensions (3 per ball, 1 per hinge/slide). + Joint limits remain per-joint: (nworld, num_joints, 2). 
""" indexing: EntityIndexing @@ -155,9 +160,9 @@ def write_joint_position( raise ValueError("Cannot write joint position for non-articulated entity.") env_ids = self._resolve_env_ids(env_ids) - joint_ids = joint_ids if joint_ids is not None else slice(None) - q_slice = self.indexing.joint_q_adr[joint_ids] - self.data.qpos[env_ids, q_slice] = position + q_indices = self.indexing.expand_to_q_indices(joint_ids) + q_adr = self.indexing.joint_q_adr[q_indices] + self.data.qpos[env_ids, q_adr] = position def write_joint_velocity( self, @@ -169,9 +174,9 @@ def write_joint_velocity( raise ValueError("Cannot write joint velocity for non-articulated entity.") env_ids = self._resolve_env_ids(env_ids) - joint_ids = joint_ids if joint_ids is not None else slice(None) - v_slice = self.indexing.joint_v_adr[joint_ids] - self.data.qvel[env_ids, v_slice] = velocity + v_indices = self.indexing.expand_to_v_indices(joint_ids) + v_adr = self.indexing.joint_v_adr[v_indices] + self.data.qvel[env_ids, v_adr] = velocity def write_external_wrench( self, diff --git a/src/mjlab/entity/entity.py b/src/mjlab/entity/entity.py index 88d4d6fb5..092bd7bbb 100644 --- a/src/mjlab/entity/entity.py +++ b/src/mjlab/entity/entity.py @@ -17,7 +17,7 @@ from mjlab.utils.lab_api.string import resolve_matching_names from mjlab.utils.mujoco import dof_width, qpos_width from mjlab.utils.spec import auto_wrap_fixed_base_mocap -from mjlab.utils.string import resolve_expr +from mjlab.utils.string import resolve_expr_with_widths @dataclass(frozen=False) @@ -53,10 +53,63 @@ class EntityIndexing: free_joint_q_adr: torch.Tensor free_joint_v_adr: torch.Tensor + # Ball joint support: per-joint dimensions and cumulative offsets. 
+ joint_qpos_widths: torch.Tensor # (num_joints,) - 4 for ball, 1 for hinge/slide + joint_dof_widths: torch.Tensor # (num_joints,) - 3 for ball, 1 for hinge/slide + joint_types: torch.Tensor # (num_joints,) - mjtJoint enum values + q_offsets: torch.Tensor # (num_joints,) - cumulative qpos offsets per joint + v_offsets: torch.Tensor # (num_joints,) - cumulative qvel offsets per joint + nq: int # Total qpos dimension + nv: int # Total dof dimension + @property def root_body_id(self) -> int: return self.bodies[0].id + def expand_to_q_indices( + self, joint_ids: torch.Tensor | Sequence[int] | slice | None + ) -> torch.Tensor | slice: + """Expand joint IDs to qpos indices for ball joint support.""" + return self._expand_indices(joint_ids, self.joint_qpos_widths, self.q_offsets) + + def expand_to_v_indices( + self, joint_ids: torch.Tensor | Sequence[int] | slice | None + ) -> torch.Tensor | slice: + """Expand joint IDs to qvel indices for ball joint support.""" + return self._expand_indices(joint_ids, self.joint_dof_widths, self.v_offsets) + + def _expand_indices( + self, + joint_ids: torch.Tensor | Sequence[int] | slice | None, + widths: torch.Tensor, + offsets: torch.Tensor, + ) -> torch.Tensor | slice: + """Expand joint IDs to DOF indices using tensor operations.""" + if joint_ids is None: + return slice(None) + + device = widths.device + if isinstance(joint_ids, slice): + start = joint_ids.start or 0 + stop = joint_ids.stop or len(widths) + joint_ids = torch.arange(start, stop, device=device) + elif not isinstance(joint_ids, torch.Tensor): + joint_ids = torch.tensor(joint_ids, dtype=torch.long, device=device) + + selected_widths = widths[joint_ids] + selected_offsets = offsets[joint_ids] + + # Expand using repeat_interleave: starts + local_offsets + repeated_offsets = selected_offsets.repeat_interleave(selected_widths) + total = int(selected_widths.sum().item()) + cumwidths = torch.zeros_like(selected_widths) + cumwidths[1:] = selected_widths[:-1].cumsum(0) + 
local_offsets = torch.arange(total, device=device) - cumwidths.repeat_interleave( + selected_widths + ) + + return repeated_offsets + local_offsets + @dataclass class EntityCfg: @@ -70,7 +123,10 @@ class InitialStateCfg: ang_vel: tuple[float, float, float] = (0.0, 0.0, 0.0) # Articulation (only for articulated entities). # Set to None to use the model's existing keyframe (errors if none exists). - joint_pos: dict[str, float] | None = field(default_factory=lambda: {".*": 0.0}) + # Ball joints use tuple of 4 floats for quaternion (w, x, y, z). + joint_pos: dict[str, float | tuple[float, ...]] | None = field( + default_factory=lambda: {".*": 0.0} + ) joint_vel: dict[str, float] = field(default_factory=lambda: {".*": 0.0}) init_state: InitialStateCfg = field(default_factory=InitialStateCfg) @@ -204,21 +260,37 @@ def _add_initial_state_keyframe(self) -> None: self.root_body.quat[:] = self.cfg.init_state.rot return - qpos_components = [] + qpos_components: list[tuple[float, ...]] = [] if self._free_joint is not None: qpos_components.extend([self.cfg.init_state.pos, self.cfg.init_state.rot]) joint_pos = None if self._non_free_joints: - joint_pos = resolve_expr(self.cfg.init_state.joint_pos, self.joint_names, 0.0) + # Per-joint widths and defaults (ball: 4 qpos, hinge/slide: 1). + widths = tuple(qpos_width(j.type) for j in self._non_free_joints) + defaults = tuple( + (1.0, 0.0, 0.0, 0.0) if j.type == mujoco.mjtJoint.mjJNT_BALL else (0.0,) + for j in self._non_free_joints + ) + joint_pos = resolve_expr_with_widths( + self.cfg.init_state.joint_pos, self.joint_names, widths, defaults + ) qpos_components.append(joint_pos) key_qpos = np.hstack(qpos_components) if qpos_components else np.array([]) key = self._spec.add_key(name="init_state", qpos=key_qpos.tolist()) if self.is_actuated and joint_pos is not None: - name_to_pos = {name: joint_pos[i] for i, name in enumerate(self.joint_names)} + # Map joint names to position values for ctrl (skip ball joints). 
+ joint_q_offset = 0 + name_to_pos: dict[str, float] = {} + for j, name in zip(self._non_free_joints, self.joint_names, strict=True): + width = qpos_width(j.type) + if width == 1: + name_to_pos[name] = joint_pos[joint_q_offset] + joint_q_offset += width + ctrl = [] for act in self._spec.actuators: joint_name = act.target @@ -315,6 +387,20 @@ def actuator_names(self) -> tuple[str, ...]: def num_joints(self) -> int: return len(self.joint_names) + @property + def nq(self) -> int: + """Total qpos dimension (excludes free joint). Accounts for ball joints (4 qpos).""" + if hasattr(self, "indexing"): + return self.indexing.nq + return sum(qpos_width(j.type) for j in self._non_free_joints) + + @property + def nv(self) -> int: + """Total dof dimension (excludes free joint). Accounts for ball joints (3 dof).""" + if hasattr(self, "indexing"): + return self.indexing.nv + return sum(dof_width(j.type) for j in self._non_free_joints) + @property def num_bodies(self) -> int: return len(self.body_names) @@ -521,6 +607,13 @@ def initialize( # Joint state. if self.is_articulated: + # Use indexing tensors for widths; build defaults based on joint types. + qpos_widths = tuple(indexing.joint_qpos_widths.tolist()) + dof_widths_tuple = tuple(indexing.joint_dof_widths.tolist()) + is_ball = (indexing.joint_types == int(mujoco.mjtJoint.mjJNT_BALL)).tolist() + qpos_defaults = tuple((1.0, 0.0, 0.0, 0.0) if b else (0.0,) for b in is_ball) + vel_defaults = tuple((0.0, 0.0, 0.0) if b else (0.0,) for b in is_ball) + if self.cfg.init_state.joint_pos is None: # Use keyframe joint positions. 
key_qpos = mj_model.key("init_state").qpos @@ -530,19 +623,40 @@ def initialize( ].repeat(nworld, 1) else: default_joint_pos = torch.tensor( - resolve_expr(self.cfg.init_state.joint_pos, self.joint_names, 0.0), + resolve_expr_with_widths( + self.cfg.init_state.joint_pos, + self.joint_names, + qpos_widths, + qpos_defaults, + ), device=device, )[None].repeat(nworld, 1) default_joint_vel = torch.tensor( - resolve_expr(self.cfg.init_state.joint_vel, self.joint_names, 0.0), + resolve_expr_with_widths( + self.cfg.init_state.joint_vel, + self.joint_names, + dof_widths_tuple, + vel_defaults, + ), device=device, )[None].repeat(nworld, 1) - # Joint limits. + # Joint limits: hinge/slide have real limits, ball joints don't. + # MuJoCo joint types: free=0, ball=1, slide=2, hinge=3. joint_ids_global = torch.tensor( [j.id for j in self._non_free_joints], device=device ) - dof_limits = model.jnt_range[:, joint_ids_global] + dof_limits = model.jnt_range[:, joint_ids_global] # (1, num_joints, 2) + has_real_limits = indexing.joint_types >= 2 # slide=2, hinge=3 + ball_mask = ~has_real_limits + if ball_mask.any(): + # Ball joints: substitute with large range since they have no position limits. + large_range = 1e6 + dof_limits = dof_limits.clone() + dof_limits[:, ball_mask, :] = torch.tensor( + [[-large_range, large_range]], device=device + ) + default_joint_pos_limits = dof_limits.clone() joint_pos_limits = default_joint_pos_limits.clone() joint_pos_mean = (joint_pos_limits[..., 0] + joint_pos_limits[..., 1]) / 2 @@ -574,14 +688,15 @@ def initialize( ) if self.is_actuated: + # Use nq for position targets, nv for velocity/effort targets. 
joint_pos_target = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nq), dtype=torch.float, device=device ) joint_vel_target = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nv), dtype=torch.float, device=device ) joint_effort_target = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nv), dtype=torch.float, device=device ) else: joint_pos_target = torch.empty(nworld, 0, dtype=torch.float, device=device) @@ -615,10 +730,10 @@ def initialize( site_effort_target = torch.empty(nworld, 0, dtype=torch.float, device=device) # Encoder bias for simulating encoder calibration errors. - # Shape: (num_envs, num_joints). Defaults to zero (no bias). + # Shape: (num_envs, nq). Defaults to zero (no bias). if self.is_articulated: encoder_bias = torch.zeros( - (nworld, self.num_joints), dtype=torch.float, device=device + (nworld, indexing.nq), dtype=torch.float, device=device ) else: encoder_bias = torch.empty(nworld, 0, dtype=torch.float, device=device) @@ -812,15 +927,15 @@ def set_joint_position_target( """Set joint position targets. Args: - position: Target joint poisitions with shape (N, num_joints). + position: Target joint positions with shape (N, nq) or (N, selected_nq). + For ball joints, this should include all 4 quaternion values per joint. joint_ids: Optional joint indices to set. If None, set all joints. env_ids: Optional environment indices. If None, set all environments. """ if env_ids is None: env_ids = slice(None) - if joint_ids is None: - joint_ids = slice(None) - self._data.joint_pos_target[env_ids, joint_ids] = position + q_indices = self.indexing.expand_to_q_indices(joint_ids) + self._data.joint_pos_target[env_ids, q_indices] = position def set_joint_velocity_target( self, @@ -831,15 +946,15 @@ def set_joint_velocity_target( """Set joint velocity targets. Args: - velocity: Target joint velocities with shape (N, num_joints). 
+ velocity: Target joint velocities with shape (N, nv) or (N, selected_nv). + For ball joints, this should include all 3 angular velocity values per joint. joint_ids: Optional joint indices to set. If None, set all joints. env_ids: Optional environment indices. If None, set all environments. """ if env_ids is None: env_ids = slice(None) - if joint_ids is None: - joint_ids = slice(None) - self._data.joint_vel_target[env_ids, joint_ids] = velocity + v_indices = self.indexing.expand_to_v_indices(joint_ids) + self._data.joint_vel_target[env_ids, v_indices] = velocity def set_joint_effort_target( self, @@ -850,15 +965,15 @@ def set_joint_effort_target( """Set joint effort targets. Args: - effort: Target joint efforts with shape (N, num_joints). + effort: Target joint efforts with shape (N, nv) or (N, selected_nv). + For ball joints, this should include all 3 torque values per joint. joint_ids: Optional joint indices to set. If None, set all joints. env_ids: Optional environment indices. If None, set all environments. """ if env_ids is None: env_ids = slice(None) - if joint_ids is None: - joint_ids = slice(None) - self._data.joint_effort_target[env_ids, joint_ids] = effort + v_indices = self.indexing.expand_to_v_indices(joint_ids) + self._data.joint_effort_target[env_ids, v_indices] = effort def set_tendon_len_target( self, @@ -1010,6 +1125,11 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin joint_v_adr = [] free_joint_q_adr = [] free_joint_v_adr = [] + + # Per-joint dimension tracking for ball joint support. 
+ joint_qpos_widths_list: list[int] = [] + joint_dof_widths_list: list[int] = [] + joint_types_list: list[int] = [] for joint in self.spec.joints: jnt = model.joint(joint.name) jnt_type = jnt.type[0] @@ -1019,13 +1139,39 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin free_joint_v_adr.extend(range(vadr, vadr + 6)) free_joint_q_adr.extend(range(qadr, qadr + 7)) else: - joint_v_adr.extend(range(vadr, vadr + dof_width(jnt_type))) - joint_q_adr.extend(range(qadr, qadr + qpos_width(jnt_type))) + qw = qpos_width(jnt_type) + dw = dof_width(jnt_type) + joint_v_adr.extend(range(vadr, vadr + dw)) + joint_q_adr.extend(range(qadr, qadr + qw)) + + # Track per-joint dimensions. + joint_types_list.append(jnt_type) + joint_qpos_widths_list.append(qw) + joint_dof_widths_list.append(dw) + joint_q_adr = torch.tensor(joint_q_adr, dtype=torch.int, device=device) joint_v_adr = torch.tensor(joint_v_adr, dtype=torch.int, device=device) free_joint_v_adr = torch.tensor(free_joint_v_adr, dtype=torch.int, device=device) free_joint_q_adr = torch.tensor(free_joint_q_adr, dtype=torch.int, device=device) + # Convert per-joint tracking lists to tensors. + joint_qpos_widths = torch.tensor( + joint_qpos_widths_list, dtype=torch.int, device=device + ) + joint_dof_widths = torch.tensor( + joint_dof_widths_list, dtype=torch.int, device=device + ) + joint_types = torch.tensor(joint_types_list, dtype=torch.int, device=device) + + # Compute cumulative offsets for tensor-based index expansion. 
+ q_offsets = torch.zeros_like(joint_qpos_widths) + v_offsets = torch.zeros_like(joint_dof_widths) + if len(joint_qpos_widths) > 0: + q_offsets[1:] = joint_qpos_widths[:-1].cumsum(0) + v_offsets[1:] = joint_dof_widths[:-1].cumsum(0) + nq = int(joint_qpos_widths.sum().item()) if len(joint_qpos_widths) > 0 else 0 + nv = int(joint_dof_widths.sum().item()) if len(joint_dof_widths) > 0 else 0 + if self.is_fixed_base and self.is_mocap: mocap_id = int(model.body_mocapid[self.root_body.id]) else: @@ -1055,6 +1201,13 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin joint_v_adr=joint_v_adr, free_joint_q_adr=free_joint_q_adr, free_joint_v_adr=free_joint_v_adr, + joint_qpos_widths=joint_qpos_widths, + joint_dof_widths=joint_dof_widths, + joint_types=joint_types, + q_offsets=q_offsets, + v_offsets=v_offsets, + nq=nq, + nv=nv, ) def _apply_actuator_controls(self) -> None: diff --git a/src/mjlab/envs/mdp/events.py b/src/mjlab/envs/mdp/events.py index f8b9e5974..7db3391b7 100644 --- a/src/mjlab/envs/mdp/events.py +++ b/src/mjlab/envs/mdp/events.py @@ -9,6 +9,7 @@ from mjlab.entity import Entity from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.utils.lab_api.math import ( + quat_from_angle_axis, quat_from_euler_xyz, quat_mul, sample_uniform, @@ -284,6 +285,11 @@ def reset_joints_by_offset( velocity_range: tuple[float, float], asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, ) -> None: + """Reset joint positions with random offsets from default. + + For hinge joints: adds uniform noise and clamps to limits. + For ball joints: applies random rotation via quaternion multiplication. 
+ """ if env_ids is None: env_ids = torch.arange(env.num_envs, device=env.device, dtype=torch.int) @@ -295,17 +301,102 @@ def reset_joints_by_offset( soft_joint_pos_limits = asset.data.soft_joint_pos_limits assert soft_joint_pos_limits is not None - joint_pos = default_joint_pos[env_ids][:, asset_cfg.joint_ids].clone() - joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device) - joint_pos_limits = soft_joint_pos_limits[env_ids][:, asset_cfg.joint_ids] - joint_pos = joint_pos.clamp_(joint_pos_limits[..., 0], joint_pos_limits[..., 1]) + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + + joint_ids = asset_cfg.joint_ids + if isinstance(joint_ids, slice): + joint_ids_tensor = torch.arange(asset.num_joints, device=env.device) + elif isinstance(joint_ids, list): + joint_ids_tensor = torch.tensor(joint_ids, device=env.device) + else: + joint_ids_tensor = joint_ids + + # Hinge-only entities: use simple add + clamp. + if asset.nq == asset.nv: + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + joint_pos = default_joint_pos[env_ids][:, q_indices].clone() + joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device) + + joint_limits = soft_joint_pos_limits[env_ids][:, joint_ids_tensor] + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + expanded_limits = joint_limits.repeat_interleave(qpos_widths, dim=1) + joint_pos = joint_pos.clamp_(expanded_limits[..., 0], expanded_limits[..., 1]) + else: + # Mixed ball/hinge joints: handle separately. 
+ qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + dof_widths = asset.indexing.joint_dof_widths[joint_ids_tensor] + q_offsets = asset.indexing.q_offsets[joint_ids_tensor] + + num_envs = len(env_ids) + device = env.device + total_qpos = int(qpos_widths.sum().item()) + joint_pos = torch.zeros((num_envs, total_qpos), device=device) + + is_ball_joint = (qpos_widths == 4) & (dof_widths == 3) + qpos_cumsum = torch.cat( + [ + torch.zeros(1, device=device, dtype=qpos_widths.dtype), + qpos_widths.cumsum(0)[:-1], + ] + ) + + # Ball joints: sample rotation angle, random axis, apply via quat_mul. + if is_ball_joint.any(): + ball_q_offsets = q_offsets[is_ball_joint] + ball_out_offsets = qpos_cumsum[is_ball_joint] + num_ball = int(is_ball_joint.sum().item()) + + # Get default quaternions. + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + default_quats = default_joint_pos[env_ids][:, ball_q_indices.flatten()].view( + num_envs, num_ball, 4 + ) + + # Sample rotation angles from position_range. + angles = sample_uniform( + position_range[0], position_range[1], (num_envs * num_ball,), device + ) + + # Sample random rotation axes (unit vectors). + axes = torch.randn((num_envs * num_ball, 3), device=device) + axes = axes / axes.norm(dim=-1, keepdim=True) + + # Create delta quaternions. + delta_quats = quat_from_angle_axis(angles, axes) + + # Apply rotation: q_new = q_default * q_delta. + new_quats = quat_mul(default_quats.reshape(-1, 4), delta_quats).view( + num_envs, num_ball, 4 + ) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(4, device=device) + joint_pos[:, ball_out_indices.flatten()] = new_quats.reshape(num_envs, -1) - joint_vel = default_joint_vel[env_ids][:, asset_cfg.joint_ids].clone() + # Hinge joints: add noise + clamp. 
+ is_hinge_joint = ~is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = q_offsets[is_hinge_joint] + hinge_out_offsets = qpos_cumsum[is_hinge_joint] + hinge_joint_ids = joint_ids_tensor[is_hinge_joint] + num_hinge = int(is_hinge_joint.sum().item()) + + default_hinge = default_joint_pos[env_ids][:, hinge_q_offsets] + noise = sample_uniform(*position_range, (num_envs, num_hinge), device) + hinge_pos = default_hinge + noise + + # Clamp to limits. + hinge_limits = soft_joint_pos_limits[env_ids][:, hinge_joint_ids] + hinge_pos = hinge_pos.clamp_(hinge_limits[..., 0], hinge_limits[..., 1]) + + joint_pos[:, hinge_out_offsets] = hinge_pos + + joint_vel = default_joint_vel[env_ids][:, v_indices].clone() joint_vel += sample_uniform(*velocity_range, joint_vel.shape, env.device) - joint_ids = asset_cfg.joint_ids - if isinstance(joint_ids, list): - joint_ids = torch.tensor(joint_ids, device=env.device) + if isinstance(asset_cfg.joint_ids, list): + joint_ids = torch.tensor(asset_cfg.joint_ids, device=env.device) + else: + joint_ids = asset_cfg.joint_ids asset.write_joint_state_to_sim( joint_pos.view(len(env_ids), -1), diff --git a/src/mjlab/envs/mdp/observations.py b/src/mjlab/envs/mdp/observations.py index 1b5b326f8..6c7fc00af 100644 --- a/src/mjlab/envs/mdp/observations.py +++ b/src/mjlab/envs/mdp/observations.py @@ -9,6 +9,7 @@ from mjlab.entity import Entity from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sensor import BuiltinSensor, RayCastSensor +from mjlab.utils.lab_api.math import quat_box_minus if TYPE_CHECKING: from mjlab.envs import ManagerBasedRlEnv @@ -53,12 +54,74 @@ def joint_pos_rel( biased: bool = False, asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG, ) -> torch.Tensor: + """Compute relative joint positions (current - default) in DOF space. + + For hinge/slide joints: returns scalar difference. + For ball joints: returns 3D axis-angle difference via quat_box_minus. 
+ + Returns: + Tensor of shape (num_envs, total_dof) for selected joints. + """ asset: Entity = env.scene[asset_cfg.name] default_joint_pos = asset.data.default_joint_pos assert default_joint_pos is not None - jnt_ids = asset_cfg.joint_ids joint_pos = asset.data.joint_pos_biased if biased else asset.data.joint_pos - return joint_pos[:, jnt_ids] - default_joint_pos[:, jnt_ids] + + # Hinge-only entities (nq == nv means no ball joints). + if asset.nq == asset.nv: + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + return joint_pos[:, q_indices] - default_joint_pos[:, q_indices] + + # Ball joints. + joint_ids = asset_cfg.joint_ids + if isinstance(joint_ids, slice): + joint_ids_tensor = torch.arange(asset.num_joints, device=joint_pos.device) + elif isinstance(joint_ids, list): + joint_ids_tensor = torch.tensor(joint_ids, device=joint_pos.device) + else: + joint_ids_tensor = joint_ids + + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + dof_widths = asset.indexing.joint_dof_widths[joint_ids_tensor] + q_offsets = asset.indexing.q_offsets[joint_ids_tensor] + + num_envs = joint_pos.shape[0] + device = joint_pos.device + + is_ball_joint = (qpos_widths == 4) & (dof_widths == 3) + dof_cumsum = torch.cat( + [torch.zeros(1, device=device, dtype=dof_widths.dtype), dof_widths.cumsum(0)[:-1]] + ) + total_dof = int(dof_widths.sum().item()) + result = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype) + + # Ball joints: quaternion difference via quat_box_minus. 
+ if is_ball_joint.any(): + ball_q_offsets = q_offsets[is_ball_joint] + ball_out_offsets = dof_cumsum[is_ball_joint] + + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + default_quats = default_joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + + num_ball = current_quats.shape[1] + diff_flat = quat_box_minus( + current_quats.reshape(-1, 4), default_quats.reshape(-1, 4) + ) + diff = diff_flat.view(num_envs, num_ball, 3) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device) + result[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1) + + # Hinge joints: scalar subtraction. + is_hinge_joint = ~is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = q_offsets[is_hinge_joint] + hinge_out_offsets = dof_cumsum[is_hinge_joint] + diff_hinge = joint_pos[:, hinge_q_offsets] - default_joint_pos[:, hinge_q_offsets] + result[:, hinge_out_offsets] = diff_hinge + + return result def joint_vel_rel( @@ -68,8 +131,8 @@ def joint_vel_rel( asset: Entity = env.scene[asset_cfg.name] default_joint_vel = asset.data.default_joint_vel assert default_joint_vel is not None - jnt_ids = asset_cfg.joint_ids - return asset.data.joint_vel[:, jnt_ids] - default_joint_vel[:, jnt_ids] + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + return asset.data.joint_vel[:, v_indices] - default_joint_vel[:, v_indices] ## diff --git a/src/mjlab/envs/mdp/rewards.py b/src/mjlab/envs/mdp/rewards.py index 300433da0..59c8eb807 100644 --- a/src/mjlab/envs/mdp/rewards.py +++ b/src/mjlab/envs/mdp/rewards.py @@ -9,6 +9,7 @@ from mjlab.entity import Entity from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg +from mjlab.utils.lab_api.math import quat_box_minus from mjlab.utils.lab_api.string import ( resolve_matching_names_values, ) @@ -44,7 +45,8 @@ def joint_vel_l2( ) -> 
torch.Tensor: """Penalize joint velocities on the articulation using L2 squared kernel.""" asset: Entity = env.scene[asset_cfg.name] - return torch.sum(torch.square(asset.data.joint_vel[:, asset_cfg.joint_ids]), dim=1) + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + return torch.sum(torch.square(asset.data.joint_vel[:, v_indices]), dim=1) def joint_acc_l2( @@ -52,7 +54,8 @@ def joint_acc_l2( ) -> torch.Tensor: """Penalize joint accelerations on the articulation using L2 squared kernel.""" asset: Entity = env.scene[asset_cfg.name] - return torch.sum(torch.square(asset.data.joint_acc[:, asset_cfg.joint_ids]), dim=1) + v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids) + return torch.sum(torch.square(asset.data.joint_acc[:, v_indices]), dim=1) def action_rate_l2(env: ManagerBasedRlEnv) -> torch.Tensor: @@ -79,14 +82,25 @@ def joint_pos_limits( asset: Entity = env.scene[asset_cfg.name] soft_joint_pos_limits = asset.data.soft_joint_pos_limits assert soft_joint_pos_limits is not None - out_of_limits = -( - asset.data.joint_pos[:, asset_cfg.joint_ids] - - soft_joint_pos_limits[:, asset_cfg.joint_ids, 0] - ).clip(max=0.0) - out_of_limits += ( - asset.data.joint_pos[:, asset_cfg.joint_ids] - - soft_joint_pos_limits[:, asset_cfg.joint_ids, 1] - ).clip(min=0.0) + + q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids) + joint_pos = asset.data.joint_pos[:, q_indices] + + joint_ids = asset_cfg.joint_ids + if isinstance(joint_ids, slice): + joint_ids_tensor = torch.arange(asset.num_joints, device=env.device) + elif isinstance(joint_ids, list): + joint_ids_tensor = torch.tensor(joint_ids, device=env.device) + else: + joint_ids_tensor = joint_ids + + # Expand limits to qpos dimensions for ball joints. 
+ limits = soft_joint_pos_limits[:, joint_ids_tensor] + qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor] + expanded_limits = limits.repeat_interleave(qpos_widths, dim=1) + + out_of_limits = -(joint_pos - expanded_limits[..., 0]).clip(max=0.0) + out_of_limits += (joint_pos - expanded_limits[..., 1]).clip(min=0.0) return torch.sum(out_of_limits, dim=1) @@ -103,9 +117,8 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): assert default_joint_pos is not None self.default_joint_pos = default_joint_pos - _, joint_names = asset.find_joints( - cfg.params["asset_cfg"].joint_names, - ) + joint_ids, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names) + self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) _, _, std = resolve_matching_names_values( data=cfg.params["std"], @@ -113,15 +126,66 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): ) self.std = torch.tensor(std, device=env.device, dtype=torch.float32) + self._joint_dof_widths = asset.indexing.joint_dof_widths[self._joint_ids] + self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids] + self._q_offsets = asset.indexing.q_offsets[self._joint_ids] + self._is_ball_joint = (self._joint_qpos_widths == 4) & (self._joint_dof_widths == 3) + + # Precompute cumulative DOF offsets for output indexing. + self._dof_cumsum = torch.cat( + [ + torch.zeros(1, device=env.device, dtype=self._joint_dof_widths.dtype), + self._joint_dof_widths.cumsum(0)[:-1], + ] + ) + self._std_expanded = self.std.repeat_interleave(self._joint_dof_widths) + def __call__( self, env: ManagerBasedRlEnv, std, asset_cfg: SceneEntityCfg ) -> torch.Tensor: del std # Unused. 
+ asset: Entity = env.scene[asset_cfg.name] - current_joint_pos = asset.data.joint_pos[:, asset_cfg.joint_ids] - desired_joint_pos = self.default_joint_pos[:, asset_cfg.joint_ids] - error_squared = torch.square(current_joint_pos - desired_joint_pos) - return torch.exp(-torch.mean(error_squared / (self.std**2), dim=1)) + joint_pos = asset.data.joint_pos + num_envs = joint_pos.shape[0] + device = joint_pos.device + + total_dof = int(self._joint_dof_widths.sum().item()) + error = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype) + + # Ball joints: quaternion difference. + if self._is_ball_joint.any(): + ball_q_offsets = self._q_offsets[self._is_ball_joint] + ball_out_offsets = self._dof_cumsum[self._is_ball_joint] + + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + default_quats = self.default_joint_pos[:, ball_q_indices.flatten()].view( + num_envs, -1, 4 + ) + + num_ball = current_quats.shape[1] + diff_flat = quat_box_minus( + current_quats.reshape(-1, 4), default_quats.reshape(-1, 4) + ) + diff = diff_flat.view(num_envs, num_ball, 3) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device) + error[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1) + + # Hinge joints: scalar subtraction. 
+ is_hinge_joint = ~self._is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = self._q_offsets[is_hinge_joint] + hinge_out_offsets = self._dof_cumsum[is_hinge_joint] + + diff_hinge = ( + joint_pos[:, hinge_q_offsets] - self.default_joint_pos[:, hinge_q_offsets] + ) + error[:, hinge_out_offsets] = diff_hinge + + error_squared = torch.square(error) + return torch.exp(-torch.mean(error_squared / (self._std_expanded**2), dim=1)) class electrical_power_cost: @@ -136,13 +200,14 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): actuator_ids, _ = asset.find_actuators( cfg.params["asset_cfg"].joint_names, ) - self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) + joint_ids_tensor = torch.tensor(joint_ids, device=env.device, dtype=torch.long) + self._v_indices = asset.indexing.expand_to_v_indices(joint_ids_tensor) self._actuator_ids = torch.tensor(actuator_ids, device=env.device, dtype=torch.long) def __call__(self, env: ManagerBasedRlEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: asset: Entity = env.scene[asset_cfg.name] tau = asset.data.actuator_force[:, self._actuator_ids] - qd = asset.data.joint_vel[:, self._joint_ids] + qd = asset.data.joint_vel[:, self._v_indices] mech = tau * qd mech_pos = torch.clamp(mech, min=0.0) # Don't penalize regen. return torch.sum(mech_pos, dim=1) diff --git a/src/mjlab/tasks/manipulation/mdp/rewards.py b/src/mjlab/tasks/manipulation/mdp/rewards.py index b5792cf94..1430d605f 100644 --- a/src/mjlab/tasks/manipulation/mdp/rewards.py +++ b/src/mjlab/tasks/manipulation/mdp/rewards.py @@ -64,6 +64,7 @@ def joint_velocity_hinge_penalty( penalty, shaped as the negative squared L2 norm of the excess velocities. 
""" robot: Entity = env.scene[asset_cfg.name] - joint_vel = robot.data.joint_vel[:, asset_cfg.joint_ids] + v_indices = robot.indexing.expand_to_v_indices(asset_cfg.joint_ids) + joint_vel = robot.data.joint_vel[:, v_indices] excess = (joint_vel.abs() - max_vel).clamp_min(0.0) return (excess**2).sum(dim=-1) diff --git a/src/mjlab/tasks/velocity/mdp/rewards.py b/src/mjlab/tasks/velocity/mdp/rewards.py index 772dad74c..1a11887fc 100644 --- a/src/mjlab/tasks/velocity/mdp/rewards.py +++ b/src/mjlab/tasks/velocity/mdp/rewards.py @@ -8,7 +8,7 @@ from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sensor import BuiltinSensor, ContactSensor -from mjlab.utils.lab_api.math import quat_apply_inverse +from mjlab.utils.lab_api.math import quat_apply_inverse, quat_box_minus from mjlab.utils.lab_api.string import ( resolve_matching_names_values, ) @@ -322,27 +322,36 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): assert default_joint_pos is not None self.default_joint_pos = default_joint_pos - _, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names) + joint_ids, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names) + self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long) _, _, std_standing = resolve_matching_names_values( - data=cfg.params["std_standing"], - list_of_strings=joint_names, + data=cfg.params["std_standing"], list_of_strings=joint_names + ) + _, _, std_walking = resolve_matching_names_values( + data=cfg.params["std_walking"], list_of_strings=joint_names + ) + _, _, std_running = resolve_matching_names_values( + data=cfg.params["std_running"], list_of_strings=joint_names ) self.std_standing = torch.tensor( std_standing, device=env.device, dtype=torch.float32 ) - - _, _, std_walking = resolve_matching_names_values( - data=cfg.params["std_walking"], - list_of_strings=joint_names, - ) self.std_walking = 
torch.tensor(std_walking, device=env.device, dtype=torch.float32) + self.std_running = torch.tensor(std_running, device=env.device, dtype=torch.float32) - _, _, std_running = resolve_matching_names_values( - data=cfg.params["std_running"], - list_of_strings=joint_names, + self._joint_dof_widths = asset.indexing.joint_dof_widths[self._joint_ids] + self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids] + self._q_offsets = asset.indexing.q_offsets[self._joint_ids] + self._is_ball_joint = (self._joint_qpos_widths == 4) & (self._joint_dof_widths == 3) + + # Precompute cumulative DOF offsets for output indexing. + self._dof_cumsum = torch.cat( + [ + torch.zeros(1, device=env.device, dtype=self._joint_dof_widths.dtype), + self._joint_dof_widths.cumsum(0)[:-1], + ] ) - self.std_running = torch.tensor(std_running, device=env.device, dtype=torch.float32) def __call__( self, @@ -361,9 +370,7 @@ def __call__( command = env.command_manager.get_command(command_name) assert command is not None - linear_speed = torch.norm(command[:, :2], dim=1) - angular_speed = torch.abs(command[:, 2]) - total_speed = linear_speed + angular_speed + total_speed = torch.norm(command[:, :2], dim=1) + torch.abs(command[:, 2]) standing_mask = (total_speed < walking_threshold).float() walking_mask = ( @@ -371,14 +378,54 @@ def __call__( ).float() running_mask = (total_speed >= running_threshold).float() + # Expand std to DOF dimensions (3 per ball joint) to match the error vector.
+ std_standing_exp = self.std_standing.repeat_interleave(self._joint_dof_widths) + std_walking_exp = self.std_walking.repeat_interleave(self._joint_dof_widths) + std_running_exp = self.std_running.repeat_interleave(self._joint_dof_widths) + std = ( - self.std_standing * standing_mask.unsqueeze(1) - + self.std_walking * walking_mask.unsqueeze(1) - + self.std_running * running_mask.unsqueeze(1) + std_standing_exp * standing_mask.unsqueeze(1) + + std_walking_exp * walking_mask.unsqueeze(1) + + std_running_exp * running_mask.unsqueeze(1) ) - current_joint_pos = asset.data.joint_pos[:, asset_cfg.joint_ids] - desired_joint_pos = self.default_joint_pos[:, asset_cfg.joint_ids] - error_squared = torch.square(current_joint_pos - desired_joint_pos) - + joint_pos = asset.data.joint_pos + num_envs = joint_pos.shape[0] + device = joint_pos.device + + total_dof = int(self._joint_dof_widths.sum().item()) + error = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype) + + # Ball joints: quaternion difference. + if self._is_ball_joint.any(): + ball_q_offsets = self._q_offsets[self._is_ball_joint] + ball_out_offsets = self._dof_cumsum[self._is_ball_joint] + + ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device) + current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4) + default_quats = self.default_joint_pos[:, ball_q_indices.flatten()].view( + num_envs, -1, 4 + ) + + num_ball = current_quats.shape[1] + diff_flat = quat_box_minus( + current_quats.reshape(-1, 4), default_quats.reshape(-1, 4) + ) + diff = diff_flat.view(num_envs, num_ball, 3) + + ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device) + error[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1) + + # Hinge joints: scalar subtraction. 
+ is_hinge_joint = ~self._is_ball_joint + if is_hinge_joint.any(): + hinge_q_offsets = self._q_offsets[is_hinge_joint] + hinge_out_offsets = self._dof_cumsum[is_hinge_joint] + + diff_hinge = ( + joint_pos[:, hinge_q_offsets] - self.default_joint_pos[:, hinge_q_offsets] + ) + error[:, hinge_out_offsets] = diff_hinge + + error_squared = torch.square(error) return torch.exp(-torch.mean(error_squared / (std**2), dim=1)) diff --git a/src/mjlab/utils/string.py b/src/mjlab/utils/string.py index e57d55c03..7f12df9cb 100644 --- a/src/mjlab/utils/string.py +++ b/src/mjlab/utils/string.py @@ -21,6 +21,46 @@ def resolve_expr( return tuple(result) +def resolve_expr_with_widths( + pattern_map: dict[str, Any], + names: tuple[str, ...], + widths: tuple[int, ...], + default_vals: tuple[tuple[float, ...], ...], +) -> tuple[float, ...]: + """Resolve field values accounting for per-joint widths (ball=4, hinge/slide=1). + + Scalars for ball joints use the default (identity quaternion) for backward + compatibility with {".*": 0.0} patterns. + """ + patterns = [(re.compile(pat), val) for pat, val in pattern_map.items()] + + result: list[float] = [] + for name, width, default in zip(names, widths, default_vals, strict=True): + matched = False + for pat, val in patterns: + if pat.match(name): + if isinstance(val, (tuple, list)): + if len(val) != width: + raise ValueError(f"Joint '{name}' expects {width} values, got {len(val)}") + result.extend(val) + matched = True + else: + # Scalar value. + if width == 1: + result.append(val) + matched = True + else: + # Scalar for ball joint - use default (identity quaternion). + # This handles backward compatibility with {".*": 0.0} patterns. + result.extend(default) + matched = True + break + if not matched: + result.extend(default) + + return tuple(result) + + def filter_exp( exprs: list[str] | tuple[str, ...], names: tuple[str, ...] 
) -> tuple[str, ...]: diff --git a/tests/test_ball_joint.py b/tests/test_ball_joint.py new file mode 100644 index 000000000..82e0d6943 --- /dev/null +++ b/tests/test_ball_joint.py @@ -0,0 +1,440 @@ +"""Tests for ball joint support in the Entity system.""" + +import mujoco +import pytest +import torch +from conftest import get_test_device, initialize_entity + +from mjlab.entity import Entity, EntityCfg + +# XML with ball joints for testing. +BALL_JOINT_XML = """ + + + + + + + + + + + + + + + + + + + + +""" + +# XML with only hinge joints for comparison. +HINGE_ONLY_XML = """ + + + + + + + + + + + + + + + + +""" + + +@pytest.fixture(scope="module") +def device(): + """Test device fixture.""" + return get_test_device() + + +def create_ball_joint_entity(): + """Create an entity with ball joints.""" + cfg = EntityCfg(spec_fn=lambda: mujoco.MjSpec.from_string(BALL_JOINT_XML)) + return Entity(cfg) + + +def create_hinge_only_entity(): + """Create an entity with only hinge joints.""" + cfg = EntityCfg(spec_fn=lambda: mujoco.MjSpec.from_string(HINGE_ONLY_XML)) + return Entity(cfg) + + +def test_nq_nv_with_ball_joints(device): + """Test that nq and nv are correctly computed for ball joints. + + Ball joints have 4 qpos (quaternion) and 3 qvel (angular velocity). + Hinge joints have 1 qpos and 1 qvel. + """ + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + # Entity has: ball1 (4 qpos, 3 dof), hinge1 (1 qpos, 1 dof), ball2 (4 qpos, 3 dof) + # Total: 4 + 1 + 4 = 9 qpos, 3 + 1 + 3 = 7 dof + assert entity.num_joints == 3 + assert entity.nq == 9, f"Expected nq=9, got {entity.nq}" + assert entity.nv == 7, f"Expected nv=7, got {entity.nv}" + + # Check indexing fields match. 
+ assert entity.indexing.nq == 9 + assert entity.indexing.nv == 7 + + +def test_nq_nv_hinge_only(device): + """Test that nq and nv equal num_joints for hinge-only entities.""" + entity = create_hinge_only_entity() + entity, _ = initialize_entity(entity, device) + + # Entity has: hinge1 (1 qpos, 1 dof), hinge2 (1 qpos, 1 dof) + assert entity.num_joints == 2 + assert entity.nq == 2 + assert entity.nv == 2 + + +def test_joint_offset_tensors(device): + """Test that q_offsets and v_offsets are correctly computed.""" + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + indexing = entity.indexing + + # Expected cumulative offsets: ball1(4,3), hinge1(1,1), ball2(4,3) + expected_q_offsets = torch.tensor([0, 4, 5], dtype=torch.int, device=device) + expected_v_offsets = torch.tensor([0, 3, 4], dtype=torch.int, device=device) + + assert torch.equal(indexing.q_offsets, expected_q_offsets) + assert torch.equal(indexing.v_offsets, expected_v_offsets) + + # Test expand_to_q_indices for single joint (ball1). + q_indices = indexing.expand_to_q_indices(torch.tensor([0], device=device)) + assert isinstance(q_indices, torch.Tensor) + assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3], device=device)) + + # Test expand_to_v_indices for single joint (ball1). + v_indices = indexing.expand_to_v_indices(torch.tensor([0], device=device)) + assert isinstance(v_indices, torch.Tensor) + assert torch.equal(v_indices, torch.tensor([0, 1, 2], device=device)) + + # Test expand for hinge joint (joint 1). + q_indices = indexing.expand_to_q_indices(torch.tensor([1], device=device)) + assert isinstance(q_indices, torch.Tensor) + assert torch.equal(q_indices, torch.tensor([4], device=device)) + + # Test expand for multiple joints (ball1, ball2). 
+ q_indices = indexing.expand_to_q_indices(torch.tensor([0, 2], device=device)) + assert isinstance(q_indices, torch.Tensor) + assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3, 5, 6, 7, 8], device=device)) + + +def test_joint_qpos_widths(device): + """Test that joint_qpos_widths and joint_dof_widths are correct.""" + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + indexing = entity.indexing + + # ball1, hinge1, ball2 + expected_qpos_widths = torch.tensor([4, 1, 4], dtype=torch.int, device=device) + expected_dof_widths = torch.tensor([3, 1, 3], dtype=torch.int, device=device) + + assert torch.equal(indexing.joint_qpos_widths, expected_qpos_widths) + assert torch.equal(indexing.joint_dof_widths, expected_dof_widths) + + +def test_ball_joint_initial_state(device): + """Test that ball joints default to identity quaternion (1, 0, 0, 0).""" + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + # Check default joint positions have correct shape. + default_joint_pos = entity.data.default_joint_pos + assert default_joint_pos.shape == (1, 9), ( + f"Expected shape (1, 9), got {default_joint_pos.shape}" + ) + + # Ball joint 1 (indices 0:4) should be identity quaternion. + ball1_quat = default_joint_pos[0, 0:4] + expected_identity = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device) + assert torch.allclose(ball1_quat, expected_identity, atol=1e-6) + + # Hinge joint (index 4) should be 0. + hinge_pos = default_joint_pos[0, 4] + assert abs(hinge_pos.item()) < 1e-6 + + # Ball joint 2 (indices 5:9) should be identity quaternion. 
+ ball2_quat = default_joint_pos[0, 5:9] + assert torch.allclose(ball2_quat, expected_identity, atol=1e-6) + + +def test_default_joint_vel_shape(device): + """Test that default_joint_vel has correct shape (nv, not num_joints).""" + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + default_joint_vel = entity.data.default_joint_vel + assert default_joint_vel.shape == (1, 7), ( + f"Expected shape (1, 7), got {default_joint_vel.shape}" + ) + + +def test_joint_pos_target_shape(device): + """Test that joint_pos_target is empty for a non-actuated entity.""" + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + # Non-actuated entities have empty target tensors. + joint_pos_target = entity.data.joint_pos_target + assert joint_pos_target.shape == (1, 0), ( + f"Expected shape (1, 0) for non-actuated entity, got {joint_pos_target.shape}" + ) + + +def test_joint_vel_target_shape(device): + """Test that joint_vel_target is empty for a non-actuated entity.""" + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + # Non-actuated entities have empty target tensors. + joint_vel_target = entity.data.joint_vel_target + assert joint_vel_target.shape == (1, 0), ( + f"Expected shape (1, 0) for non-actuated entity, got {joint_vel_target.shape}" + ) + + +def test_write_joint_position_with_ball_joints(device): + """Test writing joint positions with ball joints.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Write all joint positions. + # ball1 (4) + hinge1 (1) + ball2 (4) = 9 values + new_pos = torch.tensor( + [ + [ + 0.7071, + 0.7071, + 0.0, + 0.0, # ball1: 90 degree rotation around x + 0.5, # hinge1 + 1.0, + 0.0, + 0.0, + 0.0, # ball2: identity + ] + ], + device=device, + ) + + entity.data.write_joint_position(new_pos) + sim.forward() + + # Read back and verify.
+ joint_pos = entity.data.joint_pos + assert torch.allclose(joint_pos, new_pos, atol=1e-4) + + +def test_write_joint_velocity_with_ball_joints(device): + """Test writing joint velocities with ball joints.""" + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Write all joint velocities. + # ball1 (3) + hinge1 (1) + ball2 (3) = 7 values + new_vel = torch.tensor( + [ + [ + 0.1, + 0.2, + 0.3, # ball1 angular velocity + 0.5, # hinge1 velocity + 0.0, + 0.0, + 0.0, # ball2 angular velocity + ] + ], + device=device, + ) + + entity.data.write_joint_velocity(new_vel) + sim.forward() + + # Read back and verify. + joint_vel = entity.data.joint_vel + assert torch.allclose(joint_vel, new_vel, atol=1e-4) + + +## +# MDP function tests with ball joints. +## + + +def test_mdp_joint_pos_rel_with_ball_joints(device): + """Test joint_pos_rel observation with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import observations + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Get relative joint positions. + result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg) + + # Should have shape (1, nv=7) in DOF space, not (1, nq=9) in qpos space. + # Ball joints contribute 3 DOF (axis-angle) instead of 4 qpos (quaternion). + assert result.shape == (1, 7), f"Expected (1, 7), got {result.shape}" + + # At default position, all relative positions should be zero. 
+ assert torch.allclose(result, torch.zeros_like(result), atol=1e-6) + + +def test_mdp_joint_vel_rel_with_ball_joints(device): + """Test joint_vel_rel observation with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import observations + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Get relative joint velocities. + result = observations.joint_vel_rel(env, asset_cfg=asset_cfg) + + # Should have shape (1, nv=7), not (1, num_joints=3). + assert result.shape == (1, 7), f"Expected (1, 7), got {result.shape}" + + +def test_mdp_joint_vel_l2_with_ball_joints(device): + """Test joint_vel_l2 reward with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import rewards + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Compute joint velocity L2 penalty. + result = rewards.joint_vel_l2(env, asset_cfg=asset_cfg) + + # Should return scalar per env. + assert result.shape == (1,), f"Expected (1,), got {result.shape}" + + +def test_mdp_joint_acc_l2_with_ball_joints(device): + """Test joint_acc_l2 reward with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import rewards + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Compute joint acceleration L2 penalty. + result = rewards.joint_acc_l2(env, asset_cfg=asset_cfg) + + # Should return scalar per env. 
+ assert result.shape == (1,), f"Expected (1,), got {result.shape}" + + +def test_mdp_joint_pos_limits_with_ball_joints(device): + """Test joint_pos_limits reward with ball joints.""" + from unittest.mock import Mock + + from mjlab.envs.mdp import rewards + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, _ = initialize_entity(entity, device) + + env = Mock() + env.device = device + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + + # Compute joint position limits penalty. + result = rewards.joint_pos_limits(env, asset_cfg=asset_cfg) + + # Should return scalar per env. + assert result.shape == (1,), f"Expected (1,), got {result.shape}" + + +def test_quaternion_difference_correctness(device): + """Test that joint_pos_rel computes correct quaternion difference. + + For a 90-degree rotation around z-axis, the axis-angle representation + should be [0, 0, pi/2] (rotation axis z, angle 90 degrees). + """ + import math + from unittest.mock import Mock + + from mjlab.envs.mdp import observations + from mjlab.managers.scene_entity_config import SceneEntityCfg + + entity = create_ball_joint_entity() + entity, sim = initialize_entity(entity, device) + + # Apply a 90-degree rotation around z-axis to ball joint 1. + # Quaternion for 90 deg around z: (cos(45°), 0, 0, sin(45°)) = (0.7071, 0, 0, 0.7071) + angle = math.pi / 2 + half_angle = angle / 2 + new_pos = entity.data.joint_pos.clone() + new_pos[0, 0:4] = torch.tensor( + [math.cos(half_angle), 0.0, 0.0, math.sin(half_angle)], device=device + ) + entity.data.write_joint_position(new_pos) + sim.forward() + + env = Mock() + env.scene = {"robot": entity} + + asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None)) + result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg) + + # Ball joint 1 (indices 0:3 in DOF space) should have axis-angle [0, 0, pi/2]. 
+ ball1_diff = result[0, 0:3] + expected_axis_angle = torch.tensor([0.0, 0.0, math.pi / 2], device=device) + assert torch.allclose(ball1_diff, expected_axis_angle, atol=1e-4), ( + f"Expected axis-angle {expected_axis_angle}, got {ball1_diff}" + ) + + # Hinge joint (index 3 in DOF space) should be unchanged (0). + assert abs(result[0, 3].item()) < 1e-6 + + # Ball joint 2 (indices 4:7 in DOF space) should be unchanged (identity -> 0). + ball2_diff = result[0, 4:7] + assert torch.allclose(ball2_diff, torch.zeros(3, device=device), atol=1e-6) diff --git a/tests/test_events.py b/tests/test_events.py index 3f66c9f13..82eba6a9b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -428,6 +428,9 @@ def test_reset_joints_by_offset(device): env.device = device mock_entity = Mock() + mock_entity.num_joints = 3 + mock_entity.nq = 3 + mock_entity.nv = 3 mock_entity.data.default_joint_pos = torch.zeros((2, 3), device=device) mock_entity.data.default_joint_vel = torch.zeros((2, 3), device=device) mock_entity.data.soft_joint_pos_limits = torch.tensor( @@ -438,6 +441,15 @@ def test_reset_joints_by_offset(device): device=device, ) mock_entity.write_joint_state_to_sim = Mock() + + # Mock indexing for ball joint support (all hinge joints: 1 qpos, 1 dof each). + mock_entity.indexing.expand_to_q_indices = ( + lambda x: slice(None) if isinstance(x, slice) else x + ) + mock_entity.indexing.expand_to_v_indices = ( + lambda x: slice(None) if isinstance(x, slice) else x + ) + mock_entity.indexing.joint_qpos_widths = torch.ones(3, dtype=torch.int, device=device) env.scene = {"robot": mock_entity} # Normal offset. 
diff --git a/tests/test_rewards.py b/tests/test_rewards.py index bd3708bc7..571022674 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -199,7 +199,8 @@ def test_electrical_power_cost_partially_actuated(device): reward = electrical_power_cost(reward_cfg, env) - assert len(reward._joint_ids) == 2 + assert isinstance(reward._v_indices, torch.Tensor) + assert len(reward._v_indices) == 2 assert len(reward._actuator_ids) == 2 # Test case 1: All forces and velocities aligned (all positive work).