Skip to content
22 changes: 17 additions & 5 deletions src/mjlab/actuator/actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def __init__(
self._global_ctrl_ids: torch.Tensor | None = None
self._mjs_actuators: list[mujoco.MjsActuator] = []
self._site_zeros: torch.Tensor | None = None
# Expanded indices for ball joint support.
self._q_indices: torch.Tensor | None = None
self._v_indices: torch.Tensor | None = None

@property
def target_ids(self) -> torch.Tensor:
Expand Down Expand Up @@ -189,6 +192,15 @@ def initialize(
device=device,
)

# Compute expanded indices for ball joint support.
if self.transmission_type == TransmissionType.JOINT:
indexing = self.entity.indexing
q_indices = indexing.expand_to_q_indices(self._target_ids_list)
v_indices = indexing.expand_to_v_indices(self._target_ids_list)
assert isinstance(q_indices, torch.Tensor) and isinstance(v_indices, torch.Tensor)
self._q_indices = q_indices
self._v_indices = v_indices

# Pre-allocate zeros for SITE transmission type to avoid repeated allocations.
if self.transmission_type == TransmissionType.SITE:
nenvs = data.nworld
Expand All @@ -206,11 +218,11 @@ def get_command(self, data: EntityData) -> ActuatorCmd:
"""
if self.transmission_type == TransmissionType.JOINT:
return ActuatorCmd(
position_target=data.joint_pos_target[:, self.target_ids],
velocity_target=data.joint_vel_target[:, self.target_ids],
effort_target=data.joint_effort_target[:, self.target_ids],
pos=data.joint_pos[:, self.target_ids],
vel=data.joint_vel[:, self.target_ids],
position_target=data.joint_pos_target[:, self._q_indices],
velocity_target=data.joint_vel_target[:, self._v_indices],
effort_target=data.joint_effort_target[:, self._v_indices],
pos=data.joint_pos[:, self._q_indices],
vel=data.joint_vel[:, self._v_indices],
)
elif self.transmission_type == TransmissionType.TENDON:
return ActuatorCmd(
Expand Down
46 changes: 35 additions & 11 deletions src/mjlab/actuator/builtin_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@
(BuiltinMuscleActuator, TransmissionType.TENDON): "tendon_effort_target",
}

# Indicates whether the actuator type uses qpos indexing (vs qvel indexing).
# Position actuators read from qpos-indexed tensors (nq dimension).
# Velocity/motor actuators read from qvel-indexed tensors (nv dimension).
_USES_QPOS_INDEXING: dict[type[BuiltinActuatorType], bool] = {
BuiltinPositionActuator: True,
BuiltinVelocityActuator: False,
BuiltinMotorActuator: False,
BuiltinMuscleActuator: False,
}


@dataclass(frozen=True)
class BuiltinActuatorGroup:
Expand All @@ -47,7 +57,9 @@ class BuiltinActuatorGroup:
enables direct writes without per-actuator overhead.
"""

# Map from (BuiltinActuator type, transmission_type) to (target_ids, ctrl_ids).
# Map from (BuiltinActuator type, transmission_type) to (expanded_indices, ctrl_ids).
# For JOINT transmission, expanded_indices are qpos or qvel indices depending
# on the actuator type. For other transmissions, they match target_ids.
_index_groups: dict[
tuple[type[BuiltinActuatorType], TransmissionType],
tuple[torch.Tensor, torch.Tensor],
Expand Down Expand Up @@ -84,17 +96,29 @@ def process(
else:
custom_actuators.append(act)

# Return stacked indices for each (actuator_type, transmission_type) group.
# Build stacked indices for each (actuator_type, transmission_type) group.
index_groups: dict[
tuple[type[BuiltinActuatorType], TransmissionType],
tuple[torch.Tensor, torch.Tensor],
] = {
key: (
torch.cat([act.target_ids for act in acts], dim=0),
torch.cat([act.ctrl_ids for act in acts], dim=0),
)
for key, acts in builtin_groups.items()
}
] = {}

for key, acts in builtin_groups.items():
actuator_type, transmission_type = key
ctrl_ids = torch.cat([act.ctrl_ids for act in acts], dim=0)

if transmission_type == TransmissionType.JOINT:
# Use expanded qpos/qvel indices for joint transmission.
uses_qpos = _USES_QPOS_INDEXING[actuator_type]
attr = "_q_indices" if uses_qpos else "_v_indices"
indices_list = [getattr(act, attr) for act in acts]
assert all(idx is not None for idx in indices_list)
expanded_indices = torch.cat(indices_list, dim=0) # type: ignore[arg-type]
index_groups[key] = (expanded_indices, ctrl_ids)
else:
# For tendon/site, use flat target_ids.
target_ids = torch.cat([act.target_ids for act in acts], dim=0)
index_groups[key] = (target_ids, ctrl_ids)

return BuiltinActuatorGroup(index_groups), tuple(custom_actuators)

def apply_controls(self, data: EntityData) -> None:
Expand All @@ -104,9 +128,9 @@ def apply_controls(self, data: EntityData) -> None:
data: Entity data containing targets and control arrays.
"""
for (actuator_type, transmission_type), (
target_ids,
expanded_indices,
ctrl_ids,
) in self._index_groups.items():
attr_name = _TARGET_TENSOR_MAP[(actuator_type, transmission_type)]
target_tensor = getattr(data, attr_name)
data.write_ctrl(target_tensor[:, target_ids], ctrl_ids)
data.write_ctrl(target_tensor[:, expanded_indices], ctrl_ids)
3 changes: 3 additions & 0 deletions src/mjlab/actuator/delayed_actuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def initialize(
self._target_ids = self._base_actuator._target_ids
self._ctrl_ids = self._base_actuator._ctrl_ids
self._global_ctrl_ids = self._base_actuator._global_ctrl_ids
# Copy expanded indices for ball joint support.
self._q_indices = self._base_actuator._q_indices
self._v_indices = self._base_actuator._v_indices

targets = (
(self.cfg.delay_target,)
Expand Down
17 changes: 11 additions & 6 deletions src/mjlab/entity/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class EntityData:
root_link_pose_w) require sim.forward() to be current. If you write then read,
call sim.forward() in between. Event order matters when mixing reads and writes.
All inputs/outputs use world frame.

Ball Joint Support:
Position tensors use nq dimensions (4 per ball, 1 per hinge/slide).
Velocity/effort tensors use nv dimensions (3 per ball, 1 per hinge/slide).
Joint limits remain per-joint: (nworld, num_joints, 2).
"""

indexing: EntityIndexing
Expand Down Expand Up @@ -155,9 +160,9 @@ def write_joint_position(
raise ValueError("Cannot write joint position for non-articulated entity.")

env_ids = self._resolve_env_ids(env_ids)
joint_ids = joint_ids if joint_ids is not None else slice(None)
q_slice = self.indexing.joint_q_adr[joint_ids]
self.data.qpos[env_ids, q_slice] = position
q_indices = self.indexing.expand_to_q_indices(joint_ids)
q_adr = self.indexing.joint_q_adr[q_indices]
self.data.qpos[env_ids, q_adr] = position

def write_joint_velocity(
self,
Expand All @@ -169,9 +174,9 @@ def write_joint_velocity(
raise ValueError("Cannot write joint velocity for non-articulated entity.")

env_ids = self._resolve_env_ids(env_ids)
joint_ids = joint_ids if joint_ids is not None else slice(None)
v_slice = self.indexing.joint_v_adr[joint_ids]
self.data.qvel[env_ids, v_slice] = velocity
v_indices = self.indexing.expand_to_v_indices(joint_ids)
v_adr = self.indexing.joint_v_adr[v_indices]
self.data.qvel[env_ids, v_adr] = velocity

def write_external_wrench(
self,
Expand Down
Loading