Skip to content

Commit 5af0ffe

Browse files
authored
Fix Articulation GPU issue (#44)
1 parent f23e65f commit 5af0ffe

File tree

5 files changed

+38
-15
lines changed

5 files changed

+38
-15
lines changed

embodichain/lab/sim/cfg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def attr(self) -> PhysicalAttr:
165165
attr.mass = self.mass
166166
attr.contact_offset = self.contact_offset
167167
attr.rest_offset = self.rest_offset
168-
attr.enable_collision = self.enable_collision
169168
attr.dynamic_friction = self.dynamic_friction
170169
attr.static_friction = self.static_friction
171170
attr.angular_damping = self.angular_damping

embodichain/lab/sim/objects/articulation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,7 @@ def __init__(
511511
self.device = device
512512

513513
# Store all indices for batch operations
514-
self._all_indices = torch.arange(
515-
len(entities), dtype=torch.int32, device=device
516-
)
514+
self._all_indices = torch.arange(len(entities), dtype=torch.int32).tolist()
517515

518516
if device.type == "cuda":
519517
self._world.update(0.001)
@@ -799,6 +797,7 @@ def set_local_pose(
799797
# we should keep `pose_` life cycle to the end of the function.
800798
pose_ = torch.cat((quat, xyz), dim=-1)
801799
indices = self.body_data.gpu_indices[local_env_ids]
800+
torch.cuda.synchronize(self.device)
802801
self._ps.gpu_apply_root_data(
803802
data=pose_,
804803
gpu_indices=indices,
@@ -978,6 +977,7 @@ def set_qpos(
978977
indices = self.body_data.gpu_indices[local_env_ids]
979978
qpos_set = self.body_data._qpos[local_env_ids]
980979
qpos_set[:, local_joint_ids] = qpos
980+
torch.cuda.synchronize(self.device)
981981
self._ps.gpu_apply_joint_data(
982982
data=qpos_set,
983983
gpu_indices=indices,
@@ -1041,6 +1041,7 @@ def set_qvel(
10411041
self.body_data.qvel
10421042
qvel_set = self.body_data._qvel[local_env_ids]
10431043
qvel_set[:, joint_ids] = qvel
1044+
torch.cuda.synchronize(self.device)
10441045
self._ps.gpu_apply_joint_data(
10451046
data=qvel_set,
10461047
gpu_indices=indices,
@@ -1081,6 +1082,7 @@ def set_qf(
10811082
self.body_data.qf
10821083
qf_set = self.body_data._qf[local_env_ids]
10831084
qf_set[:, joint_ids] = qf
1085+
torch.cuda.synchronize(self.device)
10841086
self._ps.gpu_apply_joint_data(
10851087
data=qf_set,
10861088
gpu_indices=indices,
@@ -1161,11 +1163,13 @@ def clear_dynamics(self, env_ids: Sequence[int] | None = None) -> None:
11611163
(len(local_env_ids), self.dof), dtype=torch.float32, device=self.device
11621164
)
11631165
indices = self.body_data.gpu_indices[local_env_ids]
1166+
torch.cuda.synchronize(self.device)
11641167
self._ps.gpu_apply_joint_data(
11651168
data=zeros,
11661169
gpu_indices=indices,
11671170
data_type=ArticulationGPUAPIWriteType.JOINT_VELOCITY,
11681171
)
1172+
torch.cuda.synchronize(self.device)
11691173
self._ps.gpu_apply_joint_data(
11701174
data=zeros,
11711175
gpu_indices=indices,

embodichain/lab/sim/objects/rigid_object.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,7 @@ def __init__(
176176
self._world = dexsim.default_world()
177177
self._ps = self._world.get_physics_scene()
178178

179-
self._all_indices = torch.arange(
180-
len(entities), dtype=torch.int32, device=device
181-
)
179+
self._all_indices = torch.arange(len(entities), dtype=torch.int32).tolist()
182180

183181
# data for managing body data (only for dynamic and kinematic bodies) on GPU.
184182
self._data: RigidBodyData | None = None
@@ -200,6 +198,12 @@ def __init__(
200198
# set default collision filter
201199
self._set_default_collision_filter()
202200

201+
# TODO: Must be called after setting all attributes.
202+
# May be improved in the future.
203+
if cfg.attrs.enable_collision is False:
204+
flag = torch.zeros(len(entities), dtype=torch.bool)
205+
self.enable_collision(flag)
206+
203207
# reserve flag for collision visible node existence
204208
self._has_collision_visible_node = False
205209

@@ -614,6 +618,26 @@ def get_user_ids(self) -> torch.Tensor:
614618
device=self.device,
615619
)
616620

621+
def enable_collision(
622+
self, enable: torch.Tensor, env_ids: Sequence[int] | None = None
623+
) -> None:
624+
"""Enable or disable collision for the rigid bodies.
625+
626+
Args:
627+
enable (torch.Tensor): A tensor of shape (N,) representing whether to enable collision for each rigid body.
628+
env_ids (Sequence[int] | None): Environment indices. If None, then all indices are used.
629+
"""
630+
local_env_ids = self._all_indices if env_ids is None else env_ids
631+
632+
if len(local_env_ids) != len(enable):
633+
logger.log_error(
634+
f"Length of env_ids {len(local_env_ids)} does not match enable length {len(enable)}."
635+
)
636+
637+
enable_list = enable.tolist()
638+
for i, env_idx in enumerate(local_env_ids):
639+
self._entities[env_idx].enable_collision(bool(enable_list[i]))
640+
617641
def clear_dynamics(self, env_ids: Sequence[int] | None = None) -> None:
618642
"""Clear the dynamics of the rigid bodies by resetting velocities and applying zero forces and torques.
619643

embodichain/lab/sim/objects/rigid_object_group.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,10 @@ def __init__(
187187
self._world = dexsim.default_world()
188188
self._ps = self._world.get_physics_scene()
189189

190-
self._all_indices = torch.arange(
191-
len(entities), dtype=torch.int32, device=device
192-
)
190+
self._all_indices = torch.arange(len(entities), dtype=torch.int32).tolist()
193191
self._all_obj_indices = torch.arange(
194-
len(entities[0]), dtype=torch.int32, device=device
195-
)
192+
len(entities[0]), dtype=torch.int32
193+
).tolist()
196194

197195
# data for managing body data (only for dynamic and kinematic bodies) on GPU.
198196
self._data = RigidBodyGroupData(entities=entities, ps=self._ps, device=device)

embodichain/lab/sim/objects/soft_object.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,7 @@ def __init__(
165165
) -> None:
166166
self._world = dexsim.default_world()
167167
self._ps = self._world.get_physics_scene()
168-
self._all_indices = torch.arange(
169-
len(entities), dtype=torch.int32, device=device
170-
)
168+
self._all_indices = torch.arange(len(entities), dtype=torch.int32).tolist()
171169

172170
self._data = SoftBodyData(entities=entities, ps=self._ps, device=device)
173171

0 commit comments

Comments
 (0)