Skip to content

Commit d978082

Browse files
yonekenduburcqa
andauthored
[BUG FIX] Fix boolean mask inversion for PyTorch 2.x (#2056)
Co-authored-by: Alexis DUBURCQ <[email protected]>
1 parent 5672000 commit d978082

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,9 +1305,9 @@ def inverse_kinematics_multilink(
13051305
# pos and rot mask
13061306
pos_mask = broadcast_tensor(pos_mask, gs.tc_bool, (3,)).contiguous()
13071307
rot_mask = broadcast_tensor(rot_mask, gs.tc_bool, (3,)).contiguous()
1308-
if sum(rot_mask) == 1:
1309-
rot_mask = 1 - rot_mask
1310-
elif sum(rot_mask) == 2:
1308+
if (num_axis := rot_mask.sum()) == 1:
1309+
rot_mask = ~rot_mask if gs.tc_bool == torch.bool else 1 - rot_mask
1310+
elif num_axis == 2:
13111311
gs.raise_exception("You can only align 0, 1 axis or all 3 axes.")
13121312

13131313
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs)

0 commit comments

Comments
 (0)