Skip to content

Commit 49350a3

Browse files
authored
Add cpu num to SimulationCfg & Add some interfaces to robot (#51)
1 parent f9f0103 commit 49350a3

File tree

4 files changed

+311
-10
lines changed

4 files changed

+311
-10
lines changed

embodichain/lab/sim/objects/articulation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,11 @@ def get_link_pose(
891891
return link_pose
892892

893893
def get_qpos(self) -> torch.Tensor:
894-
"""Get the current positions (qpos) of the articulation."""
894+
"""Get the current positions (qpos) of the articulation.
895+
896+
Returns:
897+
torch.Tensor: Joint positions with shape (N, dof), where N is the number of environments.
898+
"""
895899
return self.body_data.qpos
896900

897901
def set_qpos(

embodichain/lab/sim/objects/robot.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,107 @@ def get_joint_ids(
123123
else [i for i in self._joint_ids[name] if i not in self.mimic_ids]
124124
)
125125

126+
def get_link_names(self, name: str | None = None) -> Union[List[str], None]:
127+
"""Get the link names of the robot for a specific control part.
128+
129+
If no control part is specified, return all link names.
130+
131+
Args:
132+
name (str, optional): The name of the control part to get the link names for. If None, the default part is used.
133+
134+
Returns:
135+
List[str]: The link names of the robot for the specified control part.
136+
"""
137+
if not self.control_parts or name is None:
138+
return self.link_names
139+
140+
if name not in self.control_parts:
141+
logger.log_error(
142+
f"The control part '{name}' does not exist in the robot's control parts {self.control_parts}."
143+
)
144+
return self._control_groups[name].link_names
145+
146+
def get_qpos_limits(
147+
self, name: str | None = None, env_ids: Sequence[int] | None = None
148+
) -> torch.Tensor:
149+
"""Get the joint position limits (qpos) of the robot for a specific control part.
150+
151+
It returns all joint position limits if no control part is specified.
152+
153+
Args:
154+
name (str | None): The name of the control part to get the qpos limits for.
155+
env_ids (Sequence[int] | None): The environment ids to get the qpos limits for. If None, all environments are used.
156+
157+
Returns:
158+
torch.Tensor: Joint position limits with shape (N, dof, 2), where N is the number of environments.
159+
"""
160+
local_env_ids = self._all_indices if env_ids is None else env_ids
161+
162+
qpos_limits = self.body_data.qpos_limits
163+
if name is None:
164+
return qpos_limits[local_env_ids, :]
165+
else:
166+
if not self.control_parts or name not in self.control_parts:
167+
logger.log_error(
168+
f"The control part '{name}' does not exist in the robot's control parts."
169+
)
170+
part_joint_ids = self.get_joint_ids(name=name)
171+
return qpos_limits[local_env_ids][:, part_joint_ids, :]
172+
173+
def get_qvel_limits(
174+
self, name: str | None = None, env_ids: Sequence[int] | None = None
175+
) -> torch.Tensor:
176+
"""Get the joint velocity limits (qvel) of the robot for a specific control part.
177+
178+
It returns all joint velocity limits if no control part is specified.
179+
180+
Args:
181+
name (str | None): The name of the control part to get the qvel limits for.
182+
env_ids (Sequence[int] | None): The environment ids to get the qvel limits for. If None, all environments are used.
183+
184+
Returns:
185+
torch.Tensor: Joint velocity limits with shape (N, dof, 2), where N is the number of environments.
186+
"""
187+
local_env_ids = self._all_indices if env_ids is None else env_ids
188+
189+
qvel_limits = self.body_data.qvel_limits
190+
if name is None:
191+
return qvel_limits[local_env_ids, :]
192+
else:
193+
if not self.control_parts or name not in self.control_parts:
194+
logger.log_error(
195+
f"The control part '{name}' does not exist in the robot's control parts."
196+
)
197+
part_joint_ids = self.get_joint_ids(name=name)
198+
return qvel_limits[local_env_ids][:, part_joint_ids, :]
199+
200+
def get_qf_limits(
201+
self, name: str | None = None, env_ids: Sequence[int] | None = None
202+
) -> torch.Tensor:
203+
"""Get the joint effort limits (qf) of the robot for a specific control part.
204+
205+
It returns all joint effort limits if no control part is specified.
206+
207+
Args:
208+
name (str | None): The name of the control part to get the qf limits for.
209+
env_ids (Sequence[int] | None): The environment ids to get the qf limits for. If None, all environments are used.
210+
211+
Returns:
212+
torch.Tensor: Joint effort limits with shape (N, dof, 2), where N is the number of environments.
213+
"""
214+
local_env_ids = self._all_indices if env_ids is None else env_ids
215+
216+
qf_limits = self.body_data.qf_limits
217+
if name is None:
218+
return qf_limits[local_env_ids, :]
219+
else:
220+
if not self.control_parts or name not in self.control_parts:
221+
logger.log_error(
222+
f"The control part '{name}' does not exist in the robot's control parts."
223+
)
224+
part_joint_ids = self.get_joint_ids(name=name)
225+
return qf_limits[local_env_ids][:, part_joint_ids, :]
226+
126227
def get_proprioception(self) -> Dict[str, torch.Tensor]:
127228
"""Gets robot proprioception information, primarily for agent state representation in robot learning scenarios.
128229
@@ -139,6 +240,191 @@ def get_proprioception(self) -> Dict[str, torch.Tensor]:
139240
qpos=self.body_data.qpos, qvel=self.body_data.qvel, qf=self.body_data.qf
140241
)
141242

243+
def set_qpos(
244+
self,
245+
qpos: torch.Tensor,
246+
joint_ids: Sequence[int] | None = None,
247+
env_ids: Sequence[int] | None = None,
248+
target: bool = True,
249+
name: str | None = None,
250+
) -> None:
251+
"""Set the joint positions (qpos) or target positions for the articulation.
252+
253+
Args:
254+
qpos (torch.Tensor): Joint positions with shape (N, dof), where N is the number of environments.
255+
joint_ids (Sequence[int] | None, optional): Joint indices to apply the positions. If None, applies to all joints.
256+
env_ids (Sequence[int] | None): Environment indices to apply the positions. Defaults to all environments.
257+
target (bool): If True, sets target positions for simulation. If False, updates current positions directly.
258+
name (str | None): The name of the control part to set the qpos for. If None, the default part is used.
259+
260+
Raises:
261+
ValueError: If the length of `env_ids` does not match the length of `qpos`.
262+
"""
263+
if name is None:
264+
super().set_qpos(
265+
qpos=qpos,
266+
joint_ids=joint_ids,
267+
env_ids=env_ids,
268+
target=target,
269+
)
270+
else:
271+
if not self.control_parts or name not in self.control_parts:
272+
logger.log_error(
273+
f"The control part '{name}' does not exist in the robot's control parts."
274+
)
275+
part_joint_ids = self.get_joint_ids(name=name)
276+
if joint_ids is not None:
277+
logger.log_warning(f"`joint_ids` is ignored when `name` is specified.")
278+
279+
super().set_qpos(
280+
qpos=qpos,
281+
joint_ids=part_joint_ids,
282+
env_ids=env_ids,
283+
target=target,
284+
)
285+
286+
def get_qpos(self, name: str | None = None) -> torch.Tensor:
287+
"""Get the joint positions (qpos) of the robot.
288+
289+
Args:
290+
name (str | None): The name of the control part to get the qpos for. If None, the default part is used.
291+
Returns:
292+
torch.Tensor: Joint positions with shape (N, dof), where N is the number of environments.
293+
"""
294+
295+
qpos = super().get_qpos()
296+
if name is None:
297+
return qpos
298+
else:
299+
if not self.control_parts or name not in self.control_parts:
300+
logger.log_error(
301+
f"The control part '{name}' does not exist in the robot's control parts."
302+
)
303+
part_joint_ids = self.get_joint_ids(name=name)
304+
return qpos[:, part_joint_ids]
305+
306+
def set_qvel(
307+
self,
308+
qvel: torch.Tensor,
309+
joint_ids: Sequence[int] | None = None,
310+
env_ids: Sequence[int] | None = None,
311+
target: bool = True,
312+
name: str | None = None,
313+
) -> None:
314+
"""Set the joint velocities (qvel) or target velocities for the articulation.
315+
316+
Args:
317+
qvel (torch.Tensor): Joint velocities with shape (N, dof), where N is the number of environments.
318+
joint_ids (Sequence[int] | None, optional): Joint indices to apply the velocities. If None, applies to all joints.
319+
env_ids (Sequence[int] | None): Environment indices to apply the velocities. Defaults to all environments.
320+
target (bool): If True, sets target velocities for simulation. If False, updates current velocities directly.
321+
name (str | None): The name of the control part to set the qvel for. If None, the default part is used.
322+
323+
Raises:
324+
ValueError: If the length of `env_ids` does not match the length of `qvel`.
325+
"""
326+
if name is None:
327+
super().set_qvel(
328+
qvel=qvel,
329+
joint_ids=joint_ids,
330+
env_ids=env_ids,
331+
target=target,
332+
)
333+
else:
334+
if not self.control_parts or name not in self.control_parts:
335+
logger.log_error(
336+
f"The control part '{name}' does not exist in the robot's control parts."
337+
)
338+
part_joint_ids = self.get_joint_ids(name=name)
339+
if joint_ids is not None:
340+
logger.log_warning(f"`joint_ids` is ignored when `name` is specified.")
341+
342+
super().set_qvel(
343+
qvel=qvel,
344+
joint_ids=part_joint_ids,
345+
env_ids=env_ids,
346+
target=target,
347+
)
348+
349+
def get_qvel(self, name: str | None = None) -> torch.Tensor:
350+
"""Get the joint velocities (qvel) of the robot.
351+
352+
Args:
353+
name (str | None): The name of the control part to get the qvel for. If None, the default part is used.
354+
Returns:
355+
torch.Tensor: Joint velocities with shape (N, dof), where N is the number of environments.
356+
"""
357+
358+
qvel = super().get_qvel()
359+
if name is None:
360+
return qvel
361+
else:
362+
if not self.control_parts or name not in self.control_parts:
363+
logger.log_error(
364+
f"The control part '{name}' does not exist in the robot's control parts."
365+
)
366+
part_joint_ids = self.get_joint_ids(name=name)
367+
return qvel[:, part_joint_ids]
368+
369+
def set_qf(
370+
self,
371+
qf: torch.Tensor,
372+
joint_ids: Sequence[int] | None = None,
373+
env_ids: Sequence[int] | None = None,
374+
name: str | None = None,
375+
) -> None:
376+
"""Set the joint efforts (qf) for the articulation.
377+
378+
Args:
379+
qf (torch.Tensor): Joint efforts with shape (N, dof), where N is the number of environments.
380+
joint_ids (Sequence[int] | None, optional): Joint indices to apply the efforts. If None, applies to all joints.
381+
env_ids (Sequence[int] | None): Environment indices to apply the efforts. Defaults to all environments.
382+
name (str | None): The name of the control part to set the qf for. If None, the default part is used.
383+
384+
Raises:
385+
ValueError: If the length of `env_ids` does not match the length of `qf`.
386+
"""
387+
if name is None:
388+
super().set_qf(
389+
qf=qf,
390+
joint_ids=joint_ids,
391+
env_ids=env_ids,
392+
)
393+
else:
394+
if not self.control_parts or name not in self.control_parts:
395+
logger.log_error(
396+
f"The control part '{name}' does not exist in the robot's control parts."
397+
)
398+
part_joint_ids = self.get_joint_ids(name=name)
399+
if joint_ids is not None:
400+
logger.log_warning(f"`joint_ids` is ignored when `name` is specified.")
401+
402+
super().set_qf(
403+
qf=qf,
404+
joint_ids=part_joint_ids,
405+
env_ids=env_ids,
406+
)
407+
408+
def get_qf(self, name: str | None = None) -> torch.Tensor:
409+
"""Get the joint efforts (qf) of the robot.
410+
411+
Args:
412+
name (str | None): The name of the control part to get the qf for. If None, the default part is used.
413+
Returns:
414+
torch.Tensor: Joint efforts with shape (N, dof), where N is the number of environments.
415+
"""
416+
417+
qf = super().get_qf()
418+
if name is None:
419+
return qf
420+
else:
421+
if not self.control_parts or name not in self.control_parts:
422+
logger.log_error(
423+
f"The control part '{name}' does not exist in the robot's control parts."
424+
)
425+
part_joint_ids = self.get_joint_ids(name=name)
426+
return qf[:, part_joint_ids]
427+
142428
def compute_fk(
143429
self,
144430
qpos: torch.Tensor | np.ndarray | None,

embodichain/lab/sim/sim_manager.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
RobotCfg,
7676
)
7777
from embodichain.lab.sim import VisualMaterial, VisualMaterialCfg
78-
from embodichain.data.assets import SimResources
7978
from embodichain.utils import configclass, logger
8079

8180
__all__ = [
@@ -123,6 +122,9 @@ class SimulationManagerCfg:
123122
- RENDER_SCENE_SHARE_ENGINE: The rendering thread and scene update thread share the same thread with the simulation engine.
124123
"""
125124

125+
cpu_num: int = 1
126+
"""The number of CPU threads to use for the simulation engine."""
127+
126128
arena_space: float = 5.0
127129
"""The distance between each arena when building multiple arenas."""
128130

@@ -291,6 +293,7 @@ def _convert_sim_config(
291293
win_config = dexsim.WindowsConfig()
292294
win_config.width = sim_config.width
293295
win_config.height = sim_config.height
296+
world_config.cpu_num = sim_config.cpu_num
294297
world_config.win_config = win_config
295298
world_config.open_windows = not sim_config.headless
296299
self.is_window_opened = not sim_config.headless
@@ -327,16 +330,10 @@ def _convert_sim_config(
327330

328331
return world_config
329332

330-
def get_default_resources(self) -> SimResources:
331-
"""Get the default resources instance.
332-
333-
Returns:
334-
SimResources: The default resources path.
335-
"""
336-
return self._default_resources
337-
338333
def _init_sim_resources(self) -> None:
339334
"""Initialize the default simulation resources."""
335+
from embodichain.data.assets import SimResources
336+
340337
self._default_resources = SimResources()
341338

342339
def enable_physics(self, enable: bool) -> None:

tests/sim/objects/test_robot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,20 @@ def test_mimic(self):
241241
len(right_eef_ids_without_mimic) == 6
242242
), f"Expected 6 right eef joint IDs without mimic, got {len(right_eef_ids_without_mimic)}"
243243

244+
def test_setter_and_getter_with_control_part(self):
245+
left_arm_qpos = self.robot.get_qpos(name="left_arm")
246+
assert left_arm_qpos.shape == (10, 7)
247+
248+
left_qpos_limits = self.robot.get_qpos_limits(name="left_arm")
249+
assert left_qpos_limits.shape == (10, 7, 2)
250+
251+
dummy_qpos = torch.randn(10, 7, device=self.sim.device)
252+
# Clamp to limits
253+
dummy_qpos = torch.max(
254+
torch.min(dummy_qpos, left_qpos_limits[:, :, 1]), left_qpos_limits[:, :, 0]
255+
)
256+
self.robot.set_qpos(qpos=dummy_qpos, name="left_arm")
257+
244258
def teardown_method(self):
245259
"""Clean up resources after each test method."""
246260
self.sim.destroy()

0 commit comments

Comments
 (0)