Commit 23cb839

[bugfix] fix megatron-swift seq_cls (#6115)
1 parent 2871087 commit 23cb839

1 file changed: +2 -2 lines changed

swift/megatron/model/gpt_model.py

Lines changed: 2 additions & 2 deletions
@@ -250,9 +250,9 @@ def forward(
                 logits, _ = self.output_layer(
                     hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
             else:
-                logits = self.output_layer(hidden_states)[0]
                 if args.sequence_parallel and args.tensor_model_parallel_size > 1:
-                    logits = gather_from_sequence_parallel_region(logits)
+                    hidden_states = gather_from_sequence_parallel_region(hidden_states)
+                logits = self.output_layer(hidden_states)[0]
         if has_config_logger_enabled(self.config):
             payload = OrderedDict({
                 'input_ids': input_ids,
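
A minimal single-process sketch of the order of operations the changed lines adopt: gather the hidden states to full sequence length first, then apply the output layer once. It assumes two simulated ranks that each hold a slice of the sequence. `gather_sequence_shards` and the toy `output_layer` are hypothetical stand-ins written for illustration, not the Megatron or ms-swift functions, and the sketch does not model gradients or real communication.

```python
from typing import List

Vector = List[float]


def gather_sequence_shards(shards: List[List[Vector]]) -> List[Vector]:
    """Hypothetical stand-in for a sequence-parallel gather:
    concatenate per-rank shards along the sequence dimension, in rank order."""
    return [token for shard in shards for token in shard]


def output_layer(hidden_states: List[Vector], weight: List[Vector]) -> List[Vector]:
    """Toy per-token linear head: logits[t][k] = dot(hidden_states[t], weight[k])."""
    return [
        [sum(h * w for h, w in zip(token, row)) for row in weight]
        for token in hidden_states
    ]


# Two simulated ranks, each holding 2 of the 4 sequence positions (hidden size 3).
shards = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # rank 0
    [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],  # rank 1
]
weight = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]  # 2 labels x hidden size 3

# Order used after the change: gather hidden states, then project.
hidden_states = gather_sequence_shards(shards)
logits = output_layer(hidden_states, weight)

print(len(logits), len(logits[0]))
```

After the gather, every simulated rank would see all four sequence positions before the projection runs, so the output layer receives the full sequence rather than a local slice.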
