Skip to content

Commit 6fe3c7a

Browse files
resolved a 0-dim tensor slicing bug from _get_state_without_padding (#7659)
fixes #7650 adding a `value.dim()>0` check to prevent slicing of 0-dim tensors cc @sfc-gh-truwase Signed-off-by: Naveenraj Kamalakannan <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 76a4075 commit 6fe3c7a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2297,7 +2297,7 @@ def _get_groups_without_padding(self, groups_with_padding):
22972297
def _get_state_without_padding(self, state_with_padding, padding):
22982298
lean_state = {}
22992299
for key, value in state_with_padding.items():
2300-
if torch.is_tensor(value):
2300+
if torch.is_tensor(value) and value.dim() > 0:
23012301
lean_length = value.numel() - padding
23022302
lean_state[key] = value[:lean_length]
23032303
else:

0 commit comments

Comments
 (0)