Skip to content

Commit 8d3077b

Browse files
tastelikefeettastelikefeet
authored andcommitted
fix loss_scale sp (#4880)
--------- Co-authored-by: tastelikefeet <[email protected]>
1 parent 64af52f commit 8d3077b

File tree

1 file changed

+3
-1
lines changed
  • swift/trainers/sequence_parallel

1 file changed

+3
-1
lines changed

swift/trainers/sequence_parallel/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None
111111
device = logits.device
112112

113113
if labels.shape[1] > logits.shape[1]:
114-
_, _, labels, _, _, loss_scale = sp_instance.pad_and_split_inputs(None, None, labels, None, None, loss_scale)
114+
_, _, labels, _, _, _ = sp_instance.pad_and_split_inputs(None, None, labels, None, None, None)
115+
if loss_scale.shape[1] > logits.shape[1]:
116+
_, _, _, _, _, loss_scale = sp_instance.pad_and_split_inputs(None, None, None, None, None, loss_scale)
115117
logits = logits.view(-1, logits.shape[-1])
116118

117119
labels = labels.flatten().to(device)

0 commit comments

Comments
 (0)