diff --git a/src/mjlab/actuator/actuator.py b/src/mjlab/actuator/actuator.py
index 1f1c4340e..8196a9ea8 100644
--- a/src/mjlab/actuator/actuator.py
+++ b/src/mjlab/actuator/actuator.py
@@ -114,6 +114,9 @@ def __init__(
self._global_ctrl_ids: torch.Tensor | None = None
self._mjs_actuators: list[mujoco.MjsActuator] = []
self._site_zeros: torch.Tensor | None = None
+ # Expanded indices for ball joint support.
+ self._q_indices: torch.Tensor | None = None
+ self._v_indices: torch.Tensor | None = None
@property
def target_ids(self) -> torch.Tensor:
@@ -189,6 +192,15 @@ def initialize(
device=device,
)
+ # Compute expanded indices for ball joint support.
+ if self.transmission_type == TransmissionType.JOINT:
+ indexing = self.entity.indexing
+ q_indices = indexing.expand_to_q_indices(self._target_ids_list)
+ v_indices = indexing.expand_to_v_indices(self._target_ids_list)
+ assert isinstance(q_indices, torch.Tensor) and isinstance(v_indices, torch.Tensor)
+ self._q_indices = q_indices
+ self._v_indices = v_indices
+
# Pre-allocate zeros for SITE transmission type to avoid repeated allocations.
if self.transmission_type == TransmissionType.SITE:
nenvs = data.nworld
@@ -206,11 +218,11 @@ def get_command(self, data: EntityData) -> ActuatorCmd:
"""
if self.transmission_type == TransmissionType.JOINT:
return ActuatorCmd(
- position_target=data.joint_pos_target[:, self.target_ids],
- velocity_target=data.joint_vel_target[:, self.target_ids],
- effort_target=data.joint_effort_target[:, self.target_ids],
- pos=data.joint_pos[:, self.target_ids],
- vel=data.joint_vel[:, self.target_ids],
+ position_target=data.joint_pos_target[:, self._q_indices],
+ velocity_target=data.joint_vel_target[:, self._v_indices],
+ effort_target=data.joint_effort_target[:, self._v_indices],
+ pos=data.joint_pos[:, self._q_indices],
+ vel=data.joint_vel[:, self._v_indices],
)
elif self.transmission_type == TransmissionType.TENDON:
return ActuatorCmd(
diff --git a/src/mjlab/actuator/builtin_group.py b/src/mjlab/actuator/builtin_group.py
index ad14705e7..ebcc3be95 100644
--- a/src/mjlab/actuator/builtin_group.py
+++ b/src/mjlab/actuator/builtin_group.py
@@ -37,6 +37,16 @@
(BuiltinMuscleActuator, TransmissionType.TENDON): "tendon_effort_target",
}
+# Selects which expanded index set a builtin actuator type uses for JOINT
+# transmission: True reads the actuator's qpos indices (`_q_indices`), False
+# reads its qvel indices (`_v_indices`).
+# Position actuators read from qpos-indexed tensors (nq dimension).
+# Velocity/motor actuators read from qvel-indexed tensors (nv dimension).
+_USES_QPOS_INDEXING: dict[type[BuiltinActuatorType], bool] = {
+ BuiltinPositionActuator: True,
+ BuiltinVelocityActuator: False,
+ BuiltinMotorActuator: False,
+ BuiltinMuscleActuator: False,
+}
+
@dataclass(frozen=True)
class BuiltinActuatorGroup:
@@ -47,7 +57,9 @@ class BuiltinActuatorGroup:
enables direct writes without per-actuator overhead.
"""
- # Map from (BuiltinActuator type, transmission_type) to (target_ids, ctrl_ids).
+ # Map from (BuiltinActuator type, transmission_type) to (expanded_indices, ctrl_ids).
+ # For JOINT transmission, expanded_indices are qpos or qvel indices depending
+ # on the actuator type. For other transmissions, they match target_ids.
_index_groups: dict[
tuple[type[BuiltinActuatorType], TransmissionType],
tuple[torch.Tensor, torch.Tensor],
@@ -84,17 +96,29 @@ def process(
else:
custom_actuators.append(act)
- # Return stacked indices for each (actuator_type, transmission_type) group.
+ # Build stacked indices for each (actuator_type, transmission_type) group.
index_groups: dict[
tuple[type[BuiltinActuatorType], TransmissionType],
tuple[torch.Tensor, torch.Tensor],
- ] = {
- key: (
- torch.cat([act.target_ids for act in acts], dim=0),
- torch.cat([act.ctrl_ids for act in acts], dim=0),
- )
- for key, acts in builtin_groups.items()
- }
+ ] = {}
+
+ for key, acts in builtin_groups.items():
+ actuator_type, transmission_type = key
+ ctrl_ids = torch.cat([act.ctrl_ids for act in acts], dim=0)
+
+ if transmission_type == TransmissionType.JOINT:
+ # Use expanded qpos/qvel indices for joint transmission.
+ uses_qpos = _USES_QPOS_INDEXING[actuator_type]
+ attr = "_q_indices" if uses_qpos else "_v_indices"
+ indices_list = [getattr(act, attr) for act in acts]
+ assert all(idx is not None for idx in indices_list)
+ expanded_indices = torch.cat(indices_list, dim=0) # type: ignore[arg-type]
+ index_groups[key] = (expanded_indices, ctrl_ids)
+ else:
+ # For tendon/site, use flat target_ids.
+ target_ids = torch.cat([act.target_ids for act in acts], dim=0)
+ index_groups[key] = (target_ids, ctrl_ids)
+
return BuiltinActuatorGroup(index_groups), tuple(custom_actuators)
def apply_controls(self, data: EntityData) -> None:
@@ -104,9 +128,9 @@ def apply_controls(self, data: EntityData) -> None:
data: Entity data containing targets and control arrays.
"""
for (actuator_type, transmission_type), (
- target_ids,
+ expanded_indices,
ctrl_ids,
) in self._index_groups.items():
attr_name = _TARGET_TENSOR_MAP[(actuator_type, transmission_type)]
target_tensor = getattr(data, attr_name)
- data.write_ctrl(target_tensor[:, target_ids], ctrl_ids)
+ data.write_ctrl(target_tensor[:, expanded_indices], ctrl_ids)
diff --git a/src/mjlab/actuator/delayed_actuator.py b/src/mjlab/actuator/delayed_actuator.py
index fa1ac5361..736c6a48e 100644
--- a/src/mjlab/actuator/delayed_actuator.py
+++ b/src/mjlab/actuator/delayed_actuator.py
@@ -105,6 +105,9 @@ def initialize(
self._target_ids = self._base_actuator._target_ids
self._ctrl_ids = self._base_actuator._ctrl_ids
self._global_ctrl_ids = self._base_actuator._global_ctrl_ids
+ # Copy expanded indices for ball joint support.
+ self._q_indices = self._base_actuator._q_indices
+ self._v_indices = self._base_actuator._v_indices
targets = (
(self.cfg.delay_target,)
diff --git a/src/mjlab/entity/data.py b/src/mjlab/entity/data.py
index 80bef750b..d521407a6 100644
--- a/src/mjlab/entity/data.py
+++ b/src/mjlab/entity/data.py
@@ -39,6 +39,11 @@ class EntityData:
root_link_pose_w) require sim.forward() to be current. If you write then read,
call sim.forward() in between. Event order matters when mixing reads and writes.
All inputs/outputs use world frame.
+
+ Ball Joint Support:
+ Position tensors use nq dimensions (4 per ball, 1 per hinge/slide).
+ Velocity/effort tensors use nv dimensions (3 per ball, 1 per hinge/slide).
+ Joint limits remain per-joint: (nworld, num_joints, 2).
"""
indexing: EntityIndexing
@@ -155,9 +160,9 @@ def write_joint_position(
raise ValueError("Cannot write joint position for non-articulated entity.")
env_ids = self._resolve_env_ids(env_ids)
- joint_ids = joint_ids if joint_ids is not None else slice(None)
- q_slice = self.indexing.joint_q_adr[joint_ids]
- self.data.qpos[env_ids, q_slice] = position
+ q_indices = self.indexing.expand_to_q_indices(joint_ids)
+ q_adr = self.indexing.joint_q_adr[q_indices]
+ self.data.qpos[env_ids, q_adr] = position
def write_joint_velocity(
self,
@@ -169,9 +174,9 @@ def write_joint_velocity(
raise ValueError("Cannot write joint velocity for non-articulated entity.")
env_ids = self._resolve_env_ids(env_ids)
- joint_ids = joint_ids if joint_ids is not None else slice(None)
- v_slice = self.indexing.joint_v_adr[joint_ids]
- self.data.qvel[env_ids, v_slice] = velocity
+ v_indices = self.indexing.expand_to_v_indices(joint_ids)
+ v_adr = self.indexing.joint_v_adr[v_indices]
+ self.data.qvel[env_ids, v_adr] = velocity
def write_external_wrench(
self,
diff --git a/src/mjlab/entity/entity.py b/src/mjlab/entity/entity.py
index 88d4d6fb5..092bd7bbb 100644
--- a/src/mjlab/entity/entity.py
+++ b/src/mjlab/entity/entity.py
@@ -17,7 +17,7 @@
from mjlab.utils.lab_api.string import resolve_matching_names
from mjlab.utils.mujoco import dof_width, qpos_width
from mjlab.utils.spec import auto_wrap_fixed_base_mocap
-from mjlab.utils.string import resolve_expr
+from mjlab.utils.string import resolve_expr_with_widths
@dataclass(frozen=False)
@@ -53,10 +53,63 @@ class EntityIndexing:
free_joint_q_adr: torch.Tensor
free_joint_v_adr: torch.Tensor
+ # Ball joint support: per-joint dimensions and cumulative offsets.
+ joint_qpos_widths: torch.Tensor # (num_joints,) - 4 for ball, 1 for hinge/slide
+ joint_dof_widths: torch.Tensor # (num_joints,) - 3 for ball, 1 for hinge/slide
+ joint_types: torch.Tensor # (num_joints,) - mjtJoint enum values
+ q_offsets: torch.Tensor # (num_joints,) - cumulative qpos offsets per joint
+ v_offsets: torch.Tensor # (num_joints,) - cumulative qvel offsets per joint
+ nq: int # Total qpos dimension
+ nv: int # Total dof dimension
+
@property
def root_body_id(self) -> int:
return self.bodies[0].id
+  def expand_to_q_indices(
+    self, joint_ids: torch.Tensor | Sequence[int] | slice | None
+  ) -> torch.Tensor | slice:
+    """Map joint IDs to the qpos entries they own (several for a ball joint)."""
+    per_joint_widths = self.joint_qpos_widths
+    per_joint_starts = self.q_offsets
+    return self._expand_indices(joint_ids, per_joint_widths, per_joint_starts)
+
+  def expand_to_v_indices(
+    self, joint_ids: torch.Tensor | Sequence[int] | slice | None
+  ) -> torch.Tensor | slice:
+    """Map joint IDs to the qvel entries they own (several for a ball joint)."""
+    per_joint_widths = self.joint_dof_widths
+    per_joint_starts = self.v_offsets
+    return self._expand_indices(joint_ids, per_joint_widths, per_joint_starts)
+
+  def _expand_indices(
+    self,
+    joint_ids: torch.Tensor | Sequence[int] | slice | None,
+    widths: torch.Tensor,
+    offsets: torch.Tensor,
+  ) -> torch.Tensor | slice:
+    """Expand per-joint IDs into flat per-coordinate indices.
+
+    Joint ``j`` contributes the contiguous run
+    ``offsets[j], ..., offsets[j] + widths[j] - 1``; the runs are concatenated
+    in the order the joints appear in ``joint_ids``.
+
+    Args:
+      joint_ids: Joints to expand. ``None`` selects everything and is returned
+        as ``slice(None)``. A slice is resolved against ``len(widths)`` with
+        standard Python semantics (negative bounds, step, empty ranges).
+      widths: Number of coordinates owned by each joint, shape (num_joints,).
+      offsets: Index of each joint's first coordinate, shape (num_joints,).
+
+    Returns:
+      ``slice(None)`` when ``joint_ids`` is None, otherwise a 1-D integer
+      index tensor.
+    """
+    if joint_ids is None:
+      return slice(None)
+
+    device = widths.device
+    if isinstance(joint_ids, slice):
+      # slice.indices() normalizes None/negative bounds and the step. The
+      # previous `stop or len(widths)` form turned an explicit stop of 0 (an
+      # empty selection) into "all joints" and ignored the step entirely.
+      joint_ids = torch.arange(*joint_ids.indices(len(widths)), device=device)
+    elif not isinstance(joint_ids, torch.Tensor):
+      joint_ids = torch.tensor(joint_ids, dtype=torch.long, device=device)
+
+    selected_widths = widths[joint_ids]
+    # First coordinate of each selected joint, repeated once per coordinate.
+    run_starts = offsets[joint_ids].repeat_interleave(selected_widths)
+    # Exclusive prefix sum: where each joint's run begins in the output.
+    out_starts = (selected_widths.cumsum(0) - selected_widths).repeat_interleave(
+      selected_widths
+    )
+    # Position of every output element inside its own joint's run.
+    within_run = torch.arange(run_starts.numel(), device=device) - out_starts
+    return run_starts + within_run
+
@dataclass
class EntityCfg:
@@ -70,7 +123,10 @@ class InitialStateCfg:
ang_vel: tuple[float, float, float] = (0.0, 0.0, 0.0)
# Articulation (only for articulated entities).
# Set to None to use the model's existing keyframe (errors if none exists).
- joint_pos: dict[str, float] | None = field(default_factory=lambda: {".*": 0.0})
+ # Ball joints use tuple of 4 floats for quaternion (w, x, y, z).
+ joint_pos: dict[str, float | tuple[float, ...]] | None = field(
+ default_factory=lambda: {".*": 0.0}
+ )
joint_vel: dict[str, float] = field(default_factory=lambda: {".*": 0.0})
init_state: InitialStateCfg = field(default_factory=InitialStateCfg)
@@ -204,21 +260,37 @@ def _add_initial_state_keyframe(self) -> None:
self.root_body.quat[:] = self.cfg.init_state.rot
return
- qpos_components = []
+ qpos_components: list[tuple[float, ...]] = []
if self._free_joint is not None:
qpos_components.extend([self.cfg.init_state.pos, self.cfg.init_state.rot])
joint_pos = None
if self._non_free_joints:
- joint_pos = resolve_expr(self.cfg.init_state.joint_pos, self.joint_names, 0.0)
+ # Per-joint widths and defaults (ball: 4 qpos, hinge/slide: 1).
+ widths = tuple(qpos_width(j.type) for j in self._non_free_joints)
+ defaults = tuple(
+ (1.0, 0.0, 0.0, 0.0) if j.type == mujoco.mjtJoint.mjJNT_BALL else (0.0,)
+ for j in self._non_free_joints
+ )
+ joint_pos = resolve_expr_with_widths(
+ self.cfg.init_state.joint_pos, self.joint_names, widths, defaults
+ )
qpos_components.append(joint_pos)
key_qpos = np.hstack(qpos_components) if qpos_components else np.array([])
key = self._spec.add_key(name="init_state", qpos=key_qpos.tolist())
if self.is_actuated and joint_pos is not None:
- name_to_pos = {name: joint_pos[i] for i, name in enumerate(self.joint_names)}
+ # Map joint names to position values for ctrl (skip ball joints).
+ joint_q_offset = 0
+ name_to_pos: dict[str, float] = {}
+ for j, name in zip(self._non_free_joints, self.joint_names, strict=True):
+ width = qpos_width(j.type)
+ if width == 1:
+ name_to_pos[name] = joint_pos[joint_q_offset]
+ joint_q_offset += width
+
ctrl = []
for act in self._spec.actuators:
joint_name = act.target
@@ -315,6 +387,20 @@ def actuator_names(self) -> tuple[str, ...]:
def num_joints(self) -> int:
return len(self.joint_names)
+  @property
+  def nq(self) -> int:
+    """Number of qpos entries over the non-free joints (a ball joint takes 4)."""
+    if not hasattr(self, "indexing"):
+      # No `indexing` attribute on this instance: add up the per-joint widths.
+      return sum(qpos_width(joint.type) for joint in self._non_free_joints)
+    return self.indexing.nq
+
+  @property
+  def nv(self) -> int:
+    """Number of dof entries over the non-free joints (a ball joint takes 3)."""
+    if not hasattr(self, "indexing"):
+      # No `indexing` attribute on this instance: add up the per-joint widths.
+      return sum(dof_width(joint.type) for joint in self._non_free_joints)
+    return self.indexing.nv
+
@property
def num_bodies(self) -> int:
return len(self.body_names)
@@ -521,6 +607,13 @@ def initialize(
# Joint state.
if self.is_articulated:
+ # Use indexing tensors for widths; build defaults based on joint types.
+ qpos_widths = tuple(indexing.joint_qpos_widths.tolist())
+ dof_widths_tuple = tuple(indexing.joint_dof_widths.tolist())
+ is_ball = (indexing.joint_types == int(mujoco.mjtJoint.mjJNT_BALL)).tolist()
+ qpos_defaults = tuple((1.0, 0.0, 0.0, 0.0) if b else (0.0,) for b in is_ball)
+ vel_defaults = tuple((0.0, 0.0, 0.0) if b else (0.0,) for b in is_ball)
+
if self.cfg.init_state.joint_pos is None:
# Use keyframe joint positions.
key_qpos = mj_model.key("init_state").qpos
@@ -530,19 +623,40 @@ def initialize(
].repeat(nworld, 1)
else:
default_joint_pos = torch.tensor(
- resolve_expr(self.cfg.init_state.joint_pos, self.joint_names, 0.0),
+ resolve_expr_with_widths(
+ self.cfg.init_state.joint_pos,
+ self.joint_names,
+ qpos_widths,
+ qpos_defaults,
+ ),
device=device,
)[None].repeat(nworld, 1)
default_joint_vel = torch.tensor(
- resolve_expr(self.cfg.init_state.joint_vel, self.joint_names, 0.0),
+ resolve_expr_with_widths(
+ self.cfg.init_state.joint_vel,
+ self.joint_names,
+ dof_widths_tuple,
+ vel_defaults,
+ ),
device=device,
)[None].repeat(nworld, 1)
- # Joint limits.
+ # Joint limits: hinge/slide have real limits, ball joints don't.
+ # MuJoCo joint types: free=0, ball=1, slide=2, hinge=3.
joint_ids_global = torch.tensor(
[j.id for j in self._non_free_joints], device=device
)
- dof_limits = model.jnt_range[:, joint_ids_global]
+ dof_limits = model.jnt_range[:, joint_ids_global] # (1, num_joints, 2)
+ has_real_limits = indexing.joint_types >= 2 # slide=2, hinge=3
+ ball_mask = ~has_real_limits
+ if ball_mask.any():
+ # Ball joints: substitute with large range since they have no position limits.
+ large_range = 1e6
+ dof_limits = dof_limits.clone()
+ dof_limits[:, ball_mask, :] = torch.tensor(
+ [[-large_range, large_range]], device=device
+ )
+
default_joint_pos_limits = dof_limits.clone()
joint_pos_limits = default_joint_pos_limits.clone()
joint_pos_mean = (joint_pos_limits[..., 0] + joint_pos_limits[..., 1]) / 2
@@ -574,14 +688,15 @@ def initialize(
)
if self.is_actuated:
+ # Use nq for position targets, nv for velocity/effort targets.
joint_pos_target = torch.zeros(
- (nworld, self.num_joints), dtype=torch.float, device=device
+ (nworld, indexing.nq), dtype=torch.float, device=device
)
joint_vel_target = torch.zeros(
- (nworld, self.num_joints), dtype=torch.float, device=device
+ (nworld, indexing.nv), dtype=torch.float, device=device
)
joint_effort_target = torch.zeros(
- (nworld, self.num_joints), dtype=torch.float, device=device
+ (nworld, indexing.nv), dtype=torch.float, device=device
)
else:
joint_pos_target = torch.empty(nworld, 0, dtype=torch.float, device=device)
@@ -615,10 +730,10 @@ def initialize(
site_effort_target = torch.empty(nworld, 0, dtype=torch.float, device=device)
# Encoder bias for simulating encoder calibration errors.
- # Shape: (num_envs, num_joints). Defaults to zero (no bias).
+ # Shape: (num_envs, nq). Defaults to zero (no bias).
if self.is_articulated:
encoder_bias = torch.zeros(
- (nworld, self.num_joints), dtype=torch.float, device=device
+ (nworld, indexing.nq), dtype=torch.float, device=device
)
else:
encoder_bias = torch.empty(nworld, 0, dtype=torch.float, device=device)
@@ -812,15 +927,15 @@ def set_joint_position_target(
"""Set joint position targets.
Args:
- position: Target joint poisitions with shape (N, num_joints).
+ position: Target joint positions with shape (N, nq) or (N, selected_nq).
+ For ball joints, this should include all 4 quaternion values per joint.
joint_ids: Optional joint indices to set. If None, set all joints.
env_ids: Optional environment indices. If None, set all environments.
"""
if env_ids is None:
env_ids = slice(None)
- if joint_ids is None:
- joint_ids = slice(None)
- self._data.joint_pos_target[env_ids, joint_ids] = position
+ q_indices = self.indexing.expand_to_q_indices(joint_ids)
+ self._data.joint_pos_target[env_ids, q_indices] = position
def set_joint_velocity_target(
self,
@@ -831,15 +946,15 @@ def set_joint_velocity_target(
"""Set joint velocity targets.
Args:
- velocity: Target joint velocities with shape (N, num_joints).
+ velocity: Target joint velocities with shape (N, nv) or (N, selected_nv).
+ For ball joints, this should include all 3 angular velocity values per joint.
joint_ids: Optional joint indices to set. If None, set all joints.
env_ids: Optional environment indices. If None, set all environments.
"""
if env_ids is None:
env_ids = slice(None)
- if joint_ids is None:
- joint_ids = slice(None)
- self._data.joint_vel_target[env_ids, joint_ids] = velocity
+ v_indices = self.indexing.expand_to_v_indices(joint_ids)
+ self._data.joint_vel_target[env_ids, v_indices] = velocity
def set_joint_effort_target(
self,
@@ -850,15 +965,15 @@ def set_joint_effort_target(
"""Set joint effort targets.
Args:
- effort: Target joint efforts with shape (N, num_joints).
+ effort: Target joint efforts with shape (N, nv) or (N, selected_nv).
+ For ball joints, this should include all 3 torque values per joint.
joint_ids: Optional joint indices to set. If None, set all joints.
env_ids: Optional environment indices. If None, set all environments.
"""
if env_ids is None:
env_ids = slice(None)
- if joint_ids is None:
- joint_ids = slice(None)
- self._data.joint_effort_target[env_ids, joint_ids] = effort
+ v_indices = self.indexing.expand_to_v_indices(joint_ids)
+ self._data.joint_effort_target[env_ids, v_indices] = effort
def set_tendon_len_target(
self,
@@ -1010,6 +1125,11 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin
joint_v_adr = []
free_joint_q_adr = []
free_joint_v_adr = []
+
+ # Per-joint dimension tracking for ball joint support.
+ joint_qpos_widths_list: list[int] = []
+ joint_dof_widths_list: list[int] = []
+ joint_types_list: list[int] = []
for joint in self.spec.joints:
jnt = model.joint(joint.name)
jnt_type = jnt.type[0]
@@ -1019,13 +1139,39 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin
free_joint_v_adr.extend(range(vadr, vadr + 6))
free_joint_q_adr.extend(range(qadr, qadr + 7))
else:
- joint_v_adr.extend(range(vadr, vadr + dof_width(jnt_type)))
- joint_q_adr.extend(range(qadr, qadr + qpos_width(jnt_type)))
+ qw = qpos_width(jnt_type)
+ dw = dof_width(jnt_type)
+ joint_v_adr.extend(range(vadr, vadr + dw))
+ joint_q_adr.extend(range(qadr, qadr + qw))
+
+ # Track per-joint dimensions.
+ joint_types_list.append(jnt_type)
+ joint_qpos_widths_list.append(qw)
+ joint_dof_widths_list.append(dw)
+
joint_q_adr = torch.tensor(joint_q_adr, dtype=torch.int, device=device)
joint_v_adr = torch.tensor(joint_v_adr, dtype=torch.int, device=device)
free_joint_v_adr = torch.tensor(free_joint_v_adr, dtype=torch.int, device=device)
free_joint_q_adr = torch.tensor(free_joint_q_adr, dtype=torch.int, device=device)
+ # Convert per-joint tracking lists to tensors.
+ joint_qpos_widths = torch.tensor(
+ joint_qpos_widths_list, dtype=torch.int, device=device
+ )
+ joint_dof_widths = torch.tensor(
+ joint_dof_widths_list, dtype=torch.int, device=device
+ )
+ joint_types = torch.tensor(joint_types_list, dtype=torch.int, device=device)
+
+ # Compute cumulative offsets for tensor-based index expansion.
+ q_offsets = torch.zeros_like(joint_qpos_widths)
+ v_offsets = torch.zeros_like(joint_dof_widths)
+ if len(joint_qpos_widths) > 0:
+ q_offsets[1:] = joint_qpos_widths[:-1].cumsum(0)
+ v_offsets[1:] = joint_dof_widths[:-1].cumsum(0)
+ nq = int(joint_qpos_widths.sum().item()) if len(joint_qpos_widths) > 0 else 0
+ nv = int(joint_dof_widths.sum().item()) if len(joint_dof_widths) > 0 else 0
+
if self.is_fixed_base and self.is_mocap:
mocap_id = int(model.body_mocapid[self.root_body.id])
else:
@@ -1055,6 +1201,13 @@ def _compute_indexing(self, model: mujoco.MjModel, device: str) -> EntityIndexin
joint_v_adr=joint_v_adr,
free_joint_q_adr=free_joint_q_adr,
free_joint_v_adr=free_joint_v_adr,
+ joint_qpos_widths=joint_qpos_widths,
+ joint_dof_widths=joint_dof_widths,
+ joint_types=joint_types,
+ q_offsets=q_offsets,
+ v_offsets=v_offsets,
+ nq=nq,
+ nv=nv,
)
def _apply_actuator_controls(self) -> None:
diff --git a/src/mjlab/envs/mdp/events.py b/src/mjlab/envs/mdp/events.py
index f8b9e5974..7db3391b7 100644
--- a/src/mjlab/envs/mdp/events.py
+++ b/src/mjlab/envs/mdp/events.py
@@ -9,6 +9,7 @@
from mjlab.entity import Entity
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.utils.lab_api.math import (
+ quat_from_angle_axis,
quat_from_euler_xyz,
quat_mul,
sample_uniform,
@@ -284,6 +285,11 @@ def reset_joints_by_offset(
velocity_range: tuple[float, float],
asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG,
) -> None:
+ """Reset joint positions with random offsets from default.
+
+ For hinge joints: adds uniform noise and clamps to limits.
+ For ball joints: applies random rotation via quaternion multiplication.
+ """
if env_ids is None:
env_ids = torch.arange(env.num_envs, device=env.device, dtype=torch.int)
@@ -295,17 +301,102 @@ def reset_joints_by_offset(
soft_joint_pos_limits = asset.data.soft_joint_pos_limits
assert soft_joint_pos_limits is not None
- joint_pos = default_joint_pos[env_ids][:, asset_cfg.joint_ids].clone()
- joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device)
- joint_pos_limits = soft_joint_pos_limits[env_ids][:, asset_cfg.joint_ids]
- joint_pos = joint_pos.clamp_(joint_pos_limits[..., 0], joint_pos_limits[..., 1])
+ v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids)
+
+ joint_ids = asset_cfg.joint_ids
+ if isinstance(joint_ids, slice):
+ joint_ids_tensor = torch.arange(asset.num_joints, device=env.device)
+ elif isinstance(joint_ids, list):
+ joint_ids_tensor = torch.tensor(joint_ids, device=env.device)
+ else:
+ joint_ids_tensor = joint_ids
+
+ # No ball joints (nq == nv, i.e. only hinge/slide joints): use simple add + clamp.
+ if asset.nq == asset.nv:
+ q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids)
+ joint_pos = default_joint_pos[env_ids][:, q_indices].clone()
+ joint_pos += sample_uniform(*position_range, joint_pos.shape, env.device)
+
+ joint_limits = soft_joint_pos_limits[env_ids][:, joint_ids_tensor]
+ qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor]
+ expanded_limits = joint_limits.repeat_interleave(qpos_widths, dim=1)
+ joint_pos = joint_pos.clamp_(expanded_limits[..., 0], expanded_limits[..., 1])
+ else:
+ # Mixed ball/hinge joints: handle separately.
+ qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor]
+ dof_widths = asset.indexing.joint_dof_widths[joint_ids_tensor]
+ q_offsets = asset.indexing.q_offsets[joint_ids_tensor]
+
+ num_envs = len(env_ids)
+ device = env.device
+ total_qpos = int(qpos_widths.sum().item())
+ joint_pos = torch.zeros((num_envs, total_qpos), device=device)
+
+ is_ball_joint = (qpos_widths == 4) & (dof_widths == 3)
+ qpos_cumsum = torch.cat(
+ [
+ torch.zeros(1, device=device, dtype=qpos_widths.dtype),
+ qpos_widths.cumsum(0)[:-1],
+ ]
+ )
+
+ # Ball joints: sample rotation angle, random axis, apply via quat_mul.
+ if is_ball_joint.any():
+ ball_q_offsets = q_offsets[is_ball_joint]
+ ball_out_offsets = qpos_cumsum[is_ball_joint]
+ num_ball = int(is_ball_joint.sum().item())
+
+ # Get default quaternions.
+ ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device)
+ default_quats = default_joint_pos[env_ids][:, ball_q_indices.flatten()].view(
+ num_envs, num_ball, 4
+ )
+
+ # Sample rotation angles from position_range.
+ angles = sample_uniform(
+ position_range[0], position_range[1], (num_envs * num_ball,), device
+ )
+
+ # Sample random rotation axes (unit vectors).
+ axes = torch.randn((num_envs * num_ball, 3), device=device)
+ axes = axes / axes.norm(dim=-1, keepdim=True)
+
+ # Create delta quaternions.
+ delta_quats = quat_from_angle_axis(angles, axes)
+
+ # Apply rotation: q_new = q_default * q_delta.
+ new_quats = quat_mul(default_quats.reshape(-1, 4), delta_quats).view(
+ num_envs, num_ball, 4
+ )
+
+ ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(4, device=device)
+ joint_pos[:, ball_out_indices.flatten()] = new_quats.reshape(num_envs, -1)
- joint_vel = default_joint_vel[env_ids][:, asset_cfg.joint_ids].clone()
+ # Hinge joints: add noise + clamp.
+ is_hinge_joint = ~is_ball_joint
+ if is_hinge_joint.any():
+ hinge_q_offsets = q_offsets[is_hinge_joint]
+ hinge_out_offsets = qpos_cumsum[is_hinge_joint]
+ hinge_joint_ids = joint_ids_tensor[is_hinge_joint]
+ num_hinge = int(is_hinge_joint.sum().item())
+
+ default_hinge = default_joint_pos[env_ids][:, hinge_q_offsets]
+ noise = sample_uniform(*position_range, (num_envs, num_hinge), device)
+ hinge_pos = default_hinge + noise
+
+ # Clamp to limits.
+ hinge_limits = soft_joint_pos_limits[env_ids][:, hinge_joint_ids]
+ hinge_pos = hinge_pos.clamp_(hinge_limits[..., 0], hinge_limits[..., 1])
+
+ joint_pos[:, hinge_out_offsets] = hinge_pos
+
+ joint_vel = default_joint_vel[env_ids][:, v_indices].clone()
joint_vel += sample_uniform(*velocity_range, joint_vel.shape, env.device)
- joint_ids = asset_cfg.joint_ids
- if isinstance(joint_ids, list):
- joint_ids = torch.tensor(joint_ids, device=env.device)
+ if isinstance(asset_cfg.joint_ids, list):
+ joint_ids = torch.tensor(asset_cfg.joint_ids, device=env.device)
+ else:
+ joint_ids = asset_cfg.joint_ids
asset.write_joint_state_to_sim(
joint_pos.view(len(env_ids), -1),
diff --git a/src/mjlab/envs/mdp/observations.py b/src/mjlab/envs/mdp/observations.py
index 1b5b326f8..6c7fc00af 100644
--- a/src/mjlab/envs/mdp/observations.py
+++ b/src/mjlab/envs/mdp/observations.py
@@ -9,6 +9,7 @@
from mjlab.entity import Entity
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.sensor import BuiltinSensor, RayCastSensor
+from mjlab.utils.lab_api.math import quat_box_minus
if TYPE_CHECKING:
from mjlab.envs import ManagerBasedRlEnv
@@ -53,12 +54,74 @@ def joint_pos_rel(
biased: bool = False,
asset_cfg: SceneEntityCfg = _DEFAULT_ASSET_CFG,
) -> torch.Tensor:
+ """Compute relative joint positions (current - default) in DOF space.
+
+ For hinge/slide joints: returns scalar difference.
+ For ball joints: returns 3D axis-angle difference via quat_box_minus.
+
+ Returns:
+ Tensor of shape (num_envs, total_dof) for selected joints.
+ """
asset: Entity = env.scene[asset_cfg.name]
default_joint_pos = asset.data.default_joint_pos
assert default_joint_pos is not None
- jnt_ids = asset_cfg.joint_ids
joint_pos = asset.data.joint_pos_biased if biased else asset.data.joint_pos
- return joint_pos[:, jnt_ids] - default_joint_pos[:, jnt_ids]
+
+ # Hinge-only entities (nq == nv means no ball joints).
+ if asset.nq == asset.nv:
+ q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids)
+ return joint_pos[:, q_indices] - default_joint_pos[:, q_indices]
+
+ # Entity has ball joints: handle ball and hinge/slide joints separately below.
+ joint_ids = asset_cfg.joint_ids
+ if isinstance(joint_ids, slice):
+ joint_ids_tensor = torch.arange(asset.num_joints, device=joint_pos.device)
+ elif isinstance(joint_ids, list):
+ joint_ids_tensor = torch.tensor(joint_ids, device=joint_pos.device)
+ else:
+ joint_ids_tensor = joint_ids
+
+ qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor]
+ dof_widths = asset.indexing.joint_dof_widths[joint_ids_tensor]
+ q_offsets = asset.indexing.q_offsets[joint_ids_tensor]
+
+ num_envs = joint_pos.shape[0]
+ device = joint_pos.device
+
+ is_ball_joint = (qpos_widths == 4) & (dof_widths == 3)
+ dof_cumsum = torch.cat(
+ [torch.zeros(1, device=device, dtype=dof_widths.dtype), dof_widths.cumsum(0)[:-1]]
+ )
+ total_dof = int(dof_widths.sum().item())
+ result = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype)
+
+ # Ball joints: quaternion difference via quat_box_minus.
+ if is_ball_joint.any():
+ ball_q_offsets = q_offsets[is_ball_joint]
+ ball_out_offsets = dof_cumsum[is_ball_joint]
+
+ ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device)
+ current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4)
+ default_quats = default_joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4)
+
+ num_ball = current_quats.shape[1]
+ diff_flat = quat_box_minus(
+ current_quats.reshape(-1, 4), default_quats.reshape(-1, 4)
+ )
+ diff = diff_flat.view(num_envs, num_ball, 3)
+
+ ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device)
+ result[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1)
+
+ # Hinge joints: scalar subtraction.
+ is_hinge_joint = ~is_ball_joint
+ if is_hinge_joint.any():
+ hinge_q_offsets = q_offsets[is_hinge_joint]
+ hinge_out_offsets = dof_cumsum[is_hinge_joint]
+ diff_hinge = joint_pos[:, hinge_q_offsets] - default_joint_pos[:, hinge_q_offsets]
+ result[:, hinge_out_offsets] = diff_hinge
+
+ return result
def joint_vel_rel(
@@ -68,8 +131,8 @@ def joint_vel_rel(
asset: Entity = env.scene[asset_cfg.name]
default_joint_vel = asset.data.default_joint_vel
assert default_joint_vel is not None
- jnt_ids = asset_cfg.joint_ids
- return asset.data.joint_vel[:, jnt_ids] - default_joint_vel[:, jnt_ids]
+ v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids)
+ return asset.data.joint_vel[:, v_indices] - default_joint_vel[:, v_indices]
##
diff --git a/src/mjlab/envs/mdp/rewards.py b/src/mjlab/envs/mdp/rewards.py
index 300433da0..59c8eb807 100644
--- a/src/mjlab/envs/mdp/rewards.py
+++ b/src/mjlab/envs/mdp/rewards.py
@@ -9,6 +9,7 @@
from mjlab.entity import Entity
from mjlab.managers.reward_manager import RewardTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg
+from mjlab.utils.lab_api.math import quat_box_minus
from mjlab.utils.lab_api.string import (
resolve_matching_names_values,
)
@@ -44,7 +45,8 @@ def joint_vel_l2(
) -> torch.Tensor:
"""Penalize joint velocities on the articulation using L2 squared kernel."""
asset: Entity = env.scene[asset_cfg.name]
- return torch.sum(torch.square(asset.data.joint_vel[:, asset_cfg.joint_ids]), dim=1)
+ v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids)
+ return torch.sum(torch.square(asset.data.joint_vel[:, v_indices]), dim=1)
def joint_acc_l2(
@@ -52,7 +54,8 @@ def joint_acc_l2(
) -> torch.Tensor:
"""Penalize joint accelerations on the articulation using L2 squared kernel."""
asset: Entity = env.scene[asset_cfg.name]
- return torch.sum(torch.square(asset.data.joint_acc[:, asset_cfg.joint_ids]), dim=1)
+ v_indices = asset.indexing.expand_to_v_indices(asset_cfg.joint_ids)
+ return torch.sum(torch.square(asset.data.joint_acc[:, v_indices]), dim=1)
def action_rate_l2(env: ManagerBasedRlEnv) -> torch.Tensor:
@@ -79,14 +82,25 @@ def joint_pos_limits(
asset: Entity = env.scene[asset_cfg.name]
soft_joint_pos_limits = asset.data.soft_joint_pos_limits
assert soft_joint_pos_limits is not None
- out_of_limits = -(
- asset.data.joint_pos[:, asset_cfg.joint_ids]
- - soft_joint_pos_limits[:, asset_cfg.joint_ids, 0]
- ).clip(max=0.0)
- out_of_limits += (
- asset.data.joint_pos[:, asset_cfg.joint_ids]
- - soft_joint_pos_limits[:, asset_cfg.joint_ids, 1]
- ).clip(min=0.0)
+
+ q_indices = asset.indexing.expand_to_q_indices(asset_cfg.joint_ids)
+ joint_pos = asset.data.joint_pos[:, q_indices]
+
+ joint_ids = asset_cfg.joint_ids
+ if isinstance(joint_ids, slice):
+ joint_ids_tensor = torch.arange(asset.num_joints, device=env.device)
+ elif isinstance(joint_ids, list):
+ joint_ids_tensor = torch.tensor(joint_ids, device=env.device)
+ else:
+ joint_ids_tensor = joint_ids
+
+ # Expand limits to qpos dimensions for ball joints.
+ limits = soft_joint_pos_limits[:, joint_ids_tensor]
+ qpos_widths = asset.indexing.joint_qpos_widths[joint_ids_tensor]
+ expanded_limits = limits.repeat_interleave(qpos_widths, dim=1)
+
+ out_of_limits = -(joint_pos - expanded_limits[..., 0]).clip(max=0.0)
+ out_of_limits += (joint_pos - expanded_limits[..., 1]).clip(min=0.0)
return torch.sum(out_of_limits, dim=1)
@@ -103,9 +117,8 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
assert default_joint_pos is not None
self.default_joint_pos = default_joint_pos
- _, joint_names = asset.find_joints(
- cfg.params["asset_cfg"].joint_names,
- )
+ joint_ids, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names)
+ self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long)
_, _, std = resolve_matching_names_values(
data=cfg.params["std"],
@@ -113,15 +126,66 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
)
self.std = torch.tensor(std, device=env.device, dtype=torch.float32)
+ self._joint_dof_widths = asset.indexing.joint_dof_widths[self._joint_ids]
+ self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids]
+ self._q_offsets = asset.indexing.q_offsets[self._joint_ids]
+ self._is_ball_joint = (self._joint_qpos_widths == 4) & (self._joint_dof_widths == 3)
+
+ # Precompute cumulative DOF offsets for output indexing.
+ self._dof_cumsum = torch.cat(
+ [
+ torch.zeros(1, device=env.device, dtype=self._joint_dof_widths.dtype),
+ self._joint_dof_widths.cumsum(0)[:-1],
+ ]
+ )
+ self._std_expanded = self.std.repeat_interleave(self._joint_dof_widths)
+
def __call__(
self, env: ManagerBasedRlEnv, std, asset_cfg: SceneEntityCfg
) -> torch.Tensor:
del std # Unused.
+
asset: Entity = env.scene[asset_cfg.name]
- current_joint_pos = asset.data.joint_pos[:, asset_cfg.joint_ids]
- desired_joint_pos = self.default_joint_pos[:, asset_cfg.joint_ids]
- error_squared = torch.square(current_joint_pos - desired_joint_pos)
- return torch.exp(-torch.mean(error_squared / (self.std**2), dim=1))
+ joint_pos = asset.data.joint_pos
+ num_envs = joint_pos.shape[0]
+ device = joint_pos.device
+
+ total_dof = int(self._joint_dof_widths.sum().item())
+ error = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype)
+
+ # Ball joints: quaternion difference.
+ if self._is_ball_joint.any():
+ ball_q_offsets = self._q_offsets[self._is_ball_joint]
+ ball_out_offsets = self._dof_cumsum[self._is_ball_joint]
+
+ ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device)
+ current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4)
+ default_quats = self.default_joint_pos[:, ball_q_indices.flatten()].view(
+ num_envs, -1, 4
+ )
+
+ num_ball = current_quats.shape[1]
+ diff_flat = quat_box_minus(
+ current_quats.reshape(-1, 4), default_quats.reshape(-1, 4)
+ )
+ diff = diff_flat.view(num_envs, num_ball, 3)
+
+ ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device)
+ error[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1)
+
+ # Hinge joints: scalar subtraction.
+ is_hinge_joint = ~self._is_ball_joint
+ if is_hinge_joint.any():
+ hinge_q_offsets = self._q_offsets[is_hinge_joint]
+ hinge_out_offsets = self._dof_cumsum[is_hinge_joint]
+
+ diff_hinge = (
+ joint_pos[:, hinge_q_offsets] - self.default_joint_pos[:, hinge_q_offsets]
+ )
+ error[:, hinge_out_offsets] = diff_hinge
+
+ error_squared = torch.square(error)
+ return torch.exp(-torch.mean(error_squared / (self._std_expanded**2), dim=1))
class electrical_power_cost:
@@ -136,13 +200,14 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
actuator_ids, _ = asset.find_actuators(
cfg.params["asset_cfg"].joint_names,
)
- self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long)
+ joint_ids_tensor = torch.tensor(joint_ids, device=env.device, dtype=torch.long)
+ self._v_indices = asset.indexing.expand_to_v_indices(joint_ids_tensor)
self._actuator_ids = torch.tensor(actuator_ids, device=env.device, dtype=torch.long)
def __call__(self, env: ManagerBasedRlEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
asset: Entity = env.scene[asset_cfg.name]
tau = asset.data.actuator_force[:, self._actuator_ids]
- qd = asset.data.joint_vel[:, self._joint_ids]
+ qd = asset.data.joint_vel[:, self._v_indices]
mech = tau * qd
mech_pos = torch.clamp(mech, min=0.0) # Don't penalize regen.
return torch.sum(mech_pos, dim=1)
diff --git a/src/mjlab/tasks/manipulation/mdp/rewards.py b/src/mjlab/tasks/manipulation/mdp/rewards.py
index b5792cf94..1430d605f 100644
--- a/src/mjlab/tasks/manipulation/mdp/rewards.py
+++ b/src/mjlab/tasks/manipulation/mdp/rewards.py
@@ -64,6 +64,7 @@ def joint_velocity_hinge_penalty(
penalty, shaped as the negative squared L2 norm of the excess velocities.
"""
robot: Entity = env.scene[asset_cfg.name]
- joint_vel = robot.data.joint_vel[:, asset_cfg.joint_ids]
+ v_indices = robot.indexing.expand_to_v_indices(asset_cfg.joint_ids)
+ joint_vel = robot.data.joint_vel[:, v_indices]
excess = (joint_vel.abs() - max_vel).clamp_min(0.0)
return (excess**2).sum(dim=-1)
diff --git a/src/mjlab/tasks/velocity/mdp/rewards.py b/src/mjlab/tasks/velocity/mdp/rewards.py
index 772dad74c..1a11887fc 100644
--- a/src/mjlab/tasks/velocity/mdp/rewards.py
+++ b/src/mjlab/tasks/velocity/mdp/rewards.py
@@ -8,7 +8,7 @@
from mjlab.managers.reward_manager import RewardTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.sensor import BuiltinSensor, ContactSensor
-from mjlab.utils.lab_api.math import quat_apply_inverse
+from mjlab.utils.lab_api.math import quat_apply_inverse, quat_box_minus
from mjlab.utils.lab_api.string import (
resolve_matching_names_values,
)
@@ -322,27 +322,36 @@ def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
assert default_joint_pos is not None
self.default_joint_pos = default_joint_pos
- _, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names)
+ joint_ids, joint_names = asset.find_joints(cfg.params["asset_cfg"].joint_names)
+ self._joint_ids = torch.tensor(joint_ids, device=env.device, dtype=torch.long)
_, _, std_standing = resolve_matching_names_values(
- data=cfg.params["std_standing"],
- list_of_strings=joint_names,
+ data=cfg.params["std_standing"], list_of_strings=joint_names
+ )
+ _, _, std_walking = resolve_matching_names_values(
+ data=cfg.params["std_walking"], list_of_strings=joint_names
+ )
+ _, _, std_running = resolve_matching_names_values(
+ data=cfg.params["std_running"], list_of_strings=joint_names
)
self.std_standing = torch.tensor(
std_standing, device=env.device, dtype=torch.float32
)
-
- _, _, std_walking = resolve_matching_names_values(
- data=cfg.params["std_walking"],
- list_of_strings=joint_names,
- )
self.std_walking = torch.tensor(std_walking, device=env.device, dtype=torch.float32)
+ self.std_running = torch.tensor(std_running, device=env.device, dtype=torch.float32)
- _, _, std_running = resolve_matching_names_values(
- data=cfg.params["std_running"],
- list_of_strings=joint_names,
+ self._joint_dof_widths = asset.indexing.joint_dof_widths[self._joint_ids]
+ self._joint_qpos_widths = asset.indexing.joint_qpos_widths[self._joint_ids]
+ self._q_offsets = asset.indexing.q_offsets[self._joint_ids]
+ self._is_ball_joint = (self._joint_qpos_widths == 4) & (self._joint_dof_widths == 3)
+
+ # Precompute cumulative DOF offsets for output indexing.
+ self._dof_cumsum = torch.cat(
+ [
+ torch.zeros(1, device=env.device, dtype=self._joint_dof_widths.dtype),
+ self._joint_dof_widths.cumsum(0)[:-1],
+ ]
)
- self.std_running = torch.tensor(std_running, device=env.device, dtype=torch.float32)
def __call__(
self,
@@ -361,9 +370,7 @@ def __call__(
command = env.command_manager.get_command(command_name)
assert command is not None
- linear_speed = torch.norm(command[:, :2], dim=1)
- angular_speed = torch.abs(command[:, 2])
- total_speed = linear_speed + angular_speed
+ total_speed = torch.norm(command[:, :2], dim=1) + torch.abs(command[:, 2])
standing_mask = (total_speed < walking_threshold).float()
walking_mask = (
@@ -371,14 +378,54 @@ def __call__(
).float()
running_mask = (total_speed >= running_threshold).float()
+ # Expand per-joint std to DOF dimensions (a ball joint contributes 3 entries).
+ std_standing_exp = self.std_standing.repeat_interleave(self._joint_dof_widths)
+ std_walking_exp = self.std_walking.repeat_interleave(self._joint_dof_widths)
+ std_running_exp = self.std_running.repeat_interleave(self._joint_dof_widths)
+
std = (
- self.std_standing * standing_mask.unsqueeze(1)
- + self.std_walking * walking_mask.unsqueeze(1)
- + self.std_running * running_mask.unsqueeze(1)
+ std_standing_exp * standing_mask.unsqueeze(1)
+ + std_walking_exp * walking_mask.unsqueeze(1)
+ + std_running_exp * running_mask.unsqueeze(1)
)
- current_joint_pos = asset.data.joint_pos[:, asset_cfg.joint_ids]
- desired_joint_pos = self.default_joint_pos[:, asset_cfg.joint_ids]
- error_squared = torch.square(current_joint_pos - desired_joint_pos)
-
+ joint_pos = asset.data.joint_pos
+ num_envs = joint_pos.shape[0]
+ device = joint_pos.device
+
+ total_dof = int(self._joint_dof_widths.sum().item())
+ error = torch.zeros((num_envs, total_dof), device=device, dtype=joint_pos.dtype)
+
+ # Ball joints: quaternion difference.
+ if self._is_ball_joint.any():
+ ball_q_offsets = self._q_offsets[self._is_ball_joint]
+ ball_out_offsets = self._dof_cumsum[self._is_ball_joint]
+
+ ball_q_indices = ball_q_offsets.unsqueeze(1) + torch.arange(4, device=device)
+ current_quats = joint_pos[:, ball_q_indices.flatten()].view(num_envs, -1, 4)
+ default_quats = self.default_joint_pos[:, ball_q_indices.flatten()].view(
+ num_envs, -1, 4
+ )
+
+ num_ball = current_quats.shape[1]
+ diff_flat = quat_box_minus(
+ current_quats.reshape(-1, 4), default_quats.reshape(-1, 4)
+ )
+ diff = diff_flat.view(num_envs, num_ball, 3)
+
+ ball_out_indices = ball_out_offsets.unsqueeze(1) + torch.arange(3, device=device)
+ error[:, ball_out_indices.flatten()] = diff.reshape(num_envs, -1)
+
+ # Hinge joints: scalar subtraction.
+ is_hinge_joint = ~self._is_ball_joint
+ if is_hinge_joint.any():
+ hinge_q_offsets = self._q_offsets[is_hinge_joint]
+ hinge_out_offsets = self._dof_cumsum[is_hinge_joint]
+
+ diff_hinge = (
+ joint_pos[:, hinge_q_offsets] - self.default_joint_pos[:, hinge_q_offsets]
+ )
+ error[:, hinge_out_offsets] = diff_hinge
+
+ error_squared = torch.square(error)
return torch.exp(-torch.mean(error_squared / (std**2), dim=1))
diff --git a/src/mjlab/utils/string.py b/src/mjlab/utils/string.py
index e57d55c03..7f12df9cb 100644
--- a/src/mjlab/utils/string.py
+++ b/src/mjlab/utils/string.py
@@ -21,6 +21,46 @@ def resolve_expr(
return tuple(result)
+def resolve_expr_with_widths(
+ pattern_map: dict[str, Any],
+ names: tuple[str, ...],
+ widths: tuple[int, ...],
+ default_vals: tuple[tuple[float, ...], ...],
+) -> tuple[float, ...]:
+ """Resolve field values accounting for per-joint widths (ball=4, hinge/slide=1).
+
+ Scalars for ball joints use the default (identity quaternion) for backward
+ compatibility with {".*": 0.0} patterns.
+ """
+ patterns = [(re.compile(pat), val) for pat, val in pattern_map.items()]
+
+ result: list[float] = []
+ for name, width, default in zip(names, widths, default_vals, strict=True):
+ matched = False
+ for pat, val in patterns:
+ if pat.match(name):
+ if isinstance(val, (tuple, list)):
+ if len(val) != width:
+ raise ValueError(f"Joint '{name}' expects {width} values, got {len(val)}")
+ result.extend(val)
+ matched = True
+ else:
+ # Scalar value.
+ if width == 1:
+ result.append(val)
+ matched = True
+ else:
+ # Scalar for ball joint - use default (identity quaternion).
+ # This handles backward compatibility with {".*": 0.0} patterns.
+ result.extend(default)
+ matched = True
+ break
+ if not matched:
+ result.extend(default)
+
+ return tuple(result)
+
+
def filter_exp(
exprs: list[str] | tuple[str, ...], names: tuple[str, ...]
) -> tuple[str, ...]:
diff --git a/tests/test_ball_joint.py b/tests/test_ball_joint.py
new file mode 100644
index 000000000..82e0d6943
--- /dev/null
+++ b/tests/test_ball_joint.py
@@ -0,0 +1,440 @@
+"""Tests for ball joint support in the Entity system."""
+
+import mujoco
+import pytest
+import torch
+from conftest import get_test_device, initialize_entity
+
+from mjlab.entity import Entity, EntityCfg
+
+# XML with ball joints for testing.
+BALL_JOINT_XML = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+# XML with only hinge joints for comparison.
+HINGE_ONLY_XML = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+
+@pytest.fixture(scope="module")
+def device():
+ """Test device fixture."""
+ return get_test_device()
+
+
+def create_ball_joint_entity():
+ """Create an entity with ball joints."""
+ cfg = EntityCfg(spec_fn=lambda: mujoco.MjSpec.from_string(BALL_JOINT_XML))
+ return Entity(cfg)
+
+
+def create_hinge_only_entity():
+ """Create an entity with only hinge joints."""
+ cfg = EntityCfg(spec_fn=lambda: mujoco.MjSpec.from_string(HINGE_ONLY_XML))
+ return Entity(cfg)
+
+
+def test_nq_nv_with_ball_joints(device):
+ """Test that nq and nv are correctly computed for ball joints.
+
+ Ball joints have 4 qpos (quaternion) and 3 qvel (angular velocity).
+ Hinge joints have 1 qpos and 1 qvel.
+ """
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ # Entity has: ball1 (4 qpos, 3 dof), hinge1 (1 qpos, 1 dof), ball2 (4 qpos, 3 dof)
+ # Total: 4 + 1 + 4 = 9 qpos, 3 + 1 + 3 = 7 dof
+ assert entity.num_joints == 3
+ assert entity.nq == 9, f"Expected nq=9, got {entity.nq}"
+ assert entity.nv == 7, f"Expected nv=7, got {entity.nv}"
+
+ # Check indexing fields match.
+ assert entity.indexing.nq == 9
+ assert entity.indexing.nv == 7
+
+
+def test_nq_nv_hinge_only(device):
+ """Test that nq and nv equal num_joints for hinge-only entities."""
+ entity = create_hinge_only_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ # Entity has: hinge1 (1 qpos, 1 dof), hinge2 (1 qpos, 1 dof)
+ assert entity.num_joints == 2
+ assert entity.nq == 2
+ assert entity.nv == 2
+
+
+def test_joint_offset_tensors(device):
+ """Test that q_offsets and v_offsets are correctly computed."""
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ indexing = entity.indexing
+
+ # Expected cumulative offsets: ball1(4,3), hinge1(1,1), ball2(4,3)
+ expected_q_offsets = torch.tensor([0, 4, 5], dtype=torch.int, device=device)
+ expected_v_offsets = torch.tensor([0, 3, 4], dtype=torch.int, device=device)
+
+ assert torch.equal(indexing.q_offsets, expected_q_offsets)
+ assert torch.equal(indexing.v_offsets, expected_v_offsets)
+
+ # Test expand_to_q_indices for single joint (ball1).
+ q_indices = indexing.expand_to_q_indices(torch.tensor([0], device=device))
+ assert isinstance(q_indices, torch.Tensor)
+ assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3], device=device))
+
+ # Test expand_to_v_indices for single joint (ball1).
+ v_indices = indexing.expand_to_v_indices(torch.tensor([0], device=device))
+ assert isinstance(v_indices, torch.Tensor)
+ assert torch.equal(v_indices, torch.tensor([0, 1, 2], device=device))
+
+ # Test expand for hinge joint (joint 1).
+ q_indices = indexing.expand_to_q_indices(torch.tensor([1], device=device))
+ assert isinstance(q_indices, torch.Tensor)
+ assert torch.equal(q_indices, torch.tensor([4], device=device))
+
+ # Test expand for multiple joints (ball1, ball2).
+ q_indices = indexing.expand_to_q_indices(torch.tensor([0, 2], device=device))
+ assert isinstance(q_indices, torch.Tensor)
+ assert torch.equal(q_indices, torch.tensor([0, 1, 2, 3, 5, 6, 7, 8], device=device))
+
+
+def test_joint_qpos_widths(device):
+ """Test that joint_qpos_widths and joint_dof_widths are correct."""
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ indexing = entity.indexing
+
+ # ball1, hinge1, ball2
+ expected_qpos_widths = torch.tensor([4, 1, 4], dtype=torch.int, device=device)
+ expected_dof_widths = torch.tensor([3, 1, 3], dtype=torch.int, device=device)
+
+ assert torch.equal(indexing.joint_qpos_widths, expected_qpos_widths)
+ assert torch.equal(indexing.joint_dof_widths, expected_dof_widths)
+
+
+def test_ball_joint_initial_state(device):
+ """Test that ball joints default to identity quaternion (1, 0, 0, 0)."""
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ # Check default joint positions have correct shape.
+ default_joint_pos = entity.data.default_joint_pos
+ assert default_joint_pos.shape == (1, 9), (
+ f"Expected shape (1, 9), got {default_joint_pos.shape}"
+ )
+
+ # Ball joint 1 (indices 0:4) should be identity quaternion.
+ ball1_quat = default_joint_pos[0, 0:4]
+ expected_identity = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
+ assert torch.allclose(ball1_quat, expected_identity, atol=1e-6)
+
+ # Hinge joint (index 4) should be 0.
+ hinge_pos = default_joint_pos[0, 4]
+ assert abs(hinge_pos.item()) < 1e-6
+
+ # Ball joint 2 (indices 5:9) should be identity quaternion.
+ ball2_quat = default_joint_pos[0, 5:9]
+ assert torch.allclose(ball2_quat, expected_identity, atol=1e-6)
+
+
+def test_default_joint_vel_shape(device):
+ """Test that default_joint_vel has correct shape (nv, not num_joints)."""
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ default_joint_vel = entity.data.default_joint_vel
+ assert default_joint_vel.shape == (1, 7), (
+ f"Expected shape (1, 7), got {default_joint_vel.shape}"
+ )
+
+
+def test_joint_pos_target_shape(device):
+ """Test that joint_pos_target is empty, shape (1, 0), for a non-actuated entity."""
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ # Non-actuated entities have empty target tensors.
+ joint_pos_target = entity.data.joint_pos_target
+ assert joint_pos_target.shape == (1, 0), (
+ f"Expected shape (1, 0) for non-actuated entity, got {joint_pos_target.shape}"
+ )
+
+
+def test_joint_vel_target_shape(device):
+ """Test that joint_vel_target is empty, shape (1, 0), for a non-actuated entity."""
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ # Non-actuated entities have empty target tensors.
+ joint_vel_target = entity.data.joint_vel_target
+ assert joint_vel_target.shape == (1, 0), (
+ f"Expected shape (1, 0) for non-actuated entity, got {joint_vel_target.shape}"
+ )
+
+
+def test_write_joint_position_with_ball_joints(device):
+ """Test writing joint positions with ball joints."""
+ entity = create_ball_joint_entity()
+ entity, sim = initialize_entity(entity, device)
+
+ # Write all joint positions.
+ # ball1 (4) + hinge1 (1) + ball2 (4) = 9 values
+ new_pos = torch.tensor(
+ [
+ [
+ 0.7071,
+ 0.7071,
+ 0.0,
+ 0.0, # ball1: 90 degree rotation around x
+ 0.5, # hinge1
+ 1.0,
+ 0.0,
+ 0.0,
+ 0.0, # ball2: identity
+ ]
+ ],
+ device=device,
+ )
+
+ entity.data.write_joint_position(new_pos)
+ sim.forward()
+
+ # Read back and verify.
+ joint_pos = entity.data.joint_pos
+ assert torch.allclose(joint_pos, new_pos, atol=1e-4)
+
+
+def test_write_joint_velocity_with_ball_joints(device):
+ """Test writing joint velocities with ball joints."""
+ entity = create_ball_joint_entity()
+ entity, sim = initialize_entity(entity, device)
+
+ # Write all joint velocities.
+ # ball1 (3) + hinge1 (1) + ball2 (3) = 7 values
+ new_vel = torch.tensor(
+ [
+ [
+ 0.1,
+ 0.2,
+ 0.3, # ball1 angular velocity
+ 0.5, # hinge1 velocity
+ 0.0,
+ 0.0,
+ 0.0, # ball2 angular velocity
+ ]
+ ],
+ device=device,
+ )
+
+ entity.data.write_joint_velocity(new_vel)
+ sim.forward()
+
+ # Read back and verify.
+ joint_vel = entity.data.joint_vel
+ assert torch.allclose(joint_vel, new_vel, atol=1e-4)
+
+
+##
+# MDP function tests with ball joints.
+##
+
+
+def test_mdp_joint_pos_rel_with_ball_joints(device):
+ """Test joint_pos_rel observation with ball joints."""
+ from unittest.mock import Mock
+
+ from mjlab.envs.mdp import observations
+ from mjlab.managers.scene_entity_config import SceneEntityCfg
+
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ env = Mock()
+ env.scene = {"robot": entity}
+
+ asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None))
+
+ # Get relative joint positions.
+ result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg)
+
+ # Should have shape (1, nv=7) in DOF space, not (1, nq=9) in qpos space.
+ # Ball joints contribute 3 DOF (axis-angle) instead of 4 qpos (quaternion).
+ assert result.shape == (1, 7), f"Expected (1, 7), got {result.shape}"
+
+ # At default position, all relative positions should be zero.
+ assert torch.allclose(result, torch.zeros_like(result), atol=1e-6)
+
+
+def test_mdp_joint_vel_rel_with_ball_joints(device):
+ """Test joint_vel_rel observation with ball joints."""
+ from unittest.mock import Mock
+
+ from mjlab.envs.mdp import observations
+ from mjlab.managers.scene_entity_config import SceneEntityCfg
+
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ env = Mock()
+ env.scene = {"robot": entity}
+
+ asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None))
+
+ # Get relative joint velocities.
+ result = observations.joint_vel_rel(env, asset_cfg=asset_cfg)
+
+ # Should have shape (1, nv=7), not (1, num_joints=3).
+ assert result.shape == (1, 7), f"Expected (1, 7), got {result.shape}"
+
+
+def test_mdp_joint_vel_l2_with_ball_joints(device):
+ """Test joint_vel_l2 reward with ball joints."""
+ from unittest.mock import Mock
+
+ from mjlab.envs.mdp import rewards
+ from mjlab.managers.scene_entity_config import SceneEntityCfg
+
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ env = Mock()
+ env.scene = {"robot": entity}
+
+ asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None))
+
+ # Compute joint velocity L2 penalty.
+ result = rewards.joint_vel_l2(env, asset_cfg=asset_cfg)
+
+ # Should return scalar per env.
+ assert result.shape == (1,), f"Expected (1,), got {result.shape}"
+
+
+def test_mdp_joint_acc_l2_with_ball_joints(device):
+ """Test joint_acc_l2 reward with ball joints."""
+ from unittest.mock import Mock
+
+ from mjlab.envs.mdp import rewards
+ from mjlab.managers.scene_entity_config import SceneEntityCfg
+
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ env = Mock()
+ env.scene = {"robot": entity}
+
+ asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None))
+
+ # Compute joint acceleration L2 penalty.
+ result = rewards.joint_acc_l2(env, asset_cfg=asset_cfg)
+
+ # Should return scalar per env.
+ assert result.shape == (1,), f"Expected (1,), got {result.shape}"
+
+
+def test_mdp_joint_pos_limits_with_ball_joints(device):
+ """Test joint_pos_limits reward with ball joints."""
+ from unittest.mock import Mock
+
+ from mjlab.envs.mdp import rewards
+ from mjlab.managers.scene_entity_config import SceneEntityCfg
+
+ entity = create_ball_joint_entity()
+ entity, _ = initialize_entity(entity, device)
+
+ env = Mock()
+ env.device = device
+ env.scene = {"robot": entity}
+
+ asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None))
+
+ # Compute joint position limits penalty.
+ result = rewards.joint_pos_limits(env, asset_cfg=asset_cfg)
+
+ # Should return scalar per env.
+ assert result.shape == (1,), f"Expected (1,), got {result.shape}"
+
+
+def test_quaternion_difference_correctness(device):
+ """Test that joint_pos_rel computes correct quaternion difference.
+
+ For a 90-degree rotation around z-axis, the axis-angle representation
+ should be [0, 0, pi/2] (rotation axis z, angle 90 degrees).
+ """
+ import math
+ from unittest.mock import Mock
+
+ from mjlab.envs.mdp import observations
+ from mjlab.managers.scene_entity_config import SceneEntityCfg
+
+ entity = create_ball_joint_entity()
+ entity, sim = initialize_entity(entity, device)
+
+ # Apply a 90-degree rotation around z-axis to ball joint 1.
+ # Quaternion for 90 deg around z: (cos(45°), 0, 0, sin(45°)) = (0.7071, 0, 0, 0.7071)
+ angle = math.pi / 2
+ half_angle = angle / 2
+ new_pos = entity.data.joint_pos.clone()
+ new_pos[0, 0:4] = torch.tensor(
+ [math.cos(half_angle), 0.0, 0.0, math.sin(half_angle)], device=device
+ )
+ entity.data.write_joint_position(new_pos)
+ sim.forward()
+
+ env = Mock()
+ env.scene = {"robot": entity}
+
+ asset_cfg = SceneEntityCfg("robot", joint_ids=slice(None))
+ result = observations.joint_pos_rel(env, biased=False, asset_cfg=asset_cfg)
+
+ # Ball joint 1 (indices 0:3 in DOF space) should have axis-angle [0, 0, pi/2].
+ ball1_diff = result[0, 0:3]
+ expected_axis_angle = torch.tensor([0.0, 0.0, math.pi / 2], device=device)
+ assert torch.allclose(ball1_diff, expected_axis_angle, atol=1e-4), (
+ f"Expected axis-angle {expected_axis_angle}, got {ball1_diff}"
+ )
+
+ # Hinge joint (index 3 in DOF space) should be unchanged (0).
+ assert abs(result[0, 3].item()) < 1e-6
+
+ # Ball joint 2 (indices 4:7 in DOF space) should be unchanged (identity -> 0).
+ ball2_diff = result[0, 4:7]
+ assert torch.allclose(ball2_diff, torch.zeros(3, device=device), atol=1e-6)
diff --git a/tests/test_events.py b/tests/test_events.py
index 3f66c9f13..82eba6a9b 100644
--- a/tests/test_events.py
+++ b/tests/test_events.py
@@ -428,6 +428,9 @@ def test_reset_joints_by_offset(device):
env.device = device
mock_entity = Mock()
+ mock_entity.num_joints = 3
+ mock_entity.nq = 3
+ mock_entity.nv = 3
mock_entity.data.default_joint_pos = torch.zeros((2, 3), device=device)
mock_entity.data.default_joint_vel = torch.zeros((2, 3), device=device)
mock_entity.data.soft_joint_pos_limits = torch.tensor(
@@ -438,6 +441,15 @@ def test_reset_joints_by_offset(device):
device=device,
)
mock_entity.write_joint_state_to_sim = Mock()
+
+ # Mock indexing for ball joint support (all hinge joints: 1 qpos, 1 dof each).
+ mock_entity.indexing.expand_to_q_indices = (
+ lambda x: slice(None) if isinstance(x, slice) else x
+ )
+ mock_entity.indexing.expand_to_v_indices = (
+ lambda x: slice(None) if isinstance(x, slice) else x
+ )
+ mock_entity.indexing.joint_qpos_widths = torch.ones(3, dtype=torch.int, device=device)
env.scene = {"robot": mock_entity}
# Normal offset.
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index bd3708bc7..571022674 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -199,7 +199,8 @@ def test_electrical_power_cost_partially_actuated(device):
reward = electrical_power_cost(reward_cfg, env)
- assert len(reward._joint_ids) == 2
+ assert isinstance(reward._v_indices, torch.Tensor)
+ assert len(reward._v_indices) == 2
assert len(reward._actuator_ids) == 2
# Test case 1: All forces and velocities aligned (all positive work).