Skip to content

Commit 4a6832a

Browse files
authored
Fix Rigid Object/Group CUDA memory issue (#46)
1 parent 5af0ffe commit 4a6832a

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

embodichain/lab/sim/objects/rigid_object.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def set_local_pose(
350350
# we should keep `pose_` life cycle to the end of the function.
351351
pose = torch.cat((quat, xyz), dim=-1)
352352
indices = self.body_data.gpu_indices[local_env_ids]
353+
torch.cuda.synchronize(self.device)
353354
self._ps.gpu_apply_rigid_body_data(
354355
data=pose.clone(),
355356
gpu_indices=indices,
@@ -458,6 +459,7 @@ def add_force_torque(
458459

459460
else:
460461
indices = self.body_data.gpu_indices[local_env_ids]
462+
torch.cuda.synchronize(self.device)
461463
if force is not None:
462464
self._ps.gpu_apply_rigid_body_data(
463465
data=force,
@@ -658,6 +660,7 @@ def clear_dynamics(self, env_ids: Sequence[int] | None = None) -> None:
658660
(len(local_env_ids), 3), dtype=torch.float32, device=self.device
659661
)
660662
indices = self.body_data.gpu_indices[local_env_ids]
663+
torch.cuda.synchronize(self.device)
661664
self._ps.gpu_apply_rigid_body_data(
662665
data=zeros,
663666
gpu_indices=indices,

embodichain/lab/sim/objects/rigid_object_group.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def set_local_pose(
362362
indices = self.body_data.gpu_indices[local_env_ids][
363363
:, local_obj_ids
364364
].flatten()
365+
torch.cuda.synchronize(self.device)
365366
self._ps.gpu_apply_rigid_body_data(
366367
data=pose.clone(),
367368
gpu_indices=indices,
@@ -433,6 +434,7 @@ def clear_dynamics(self, env_ids: Sequence[int] | None = None) -> None:
433434
device=self.device,
434435
)
435436
indices = self.body_data.gpu_indices[local_env_ids].flatten()
437+
torch.cuda.synchronize(self.device)
436438
self._ps.gpu_apply_rigid_body_data(
437439
data=zeros,
438440
gpu_indices=indices,

0 commit comments

Comments
 (0)