We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2871087 · commit 23cb839 — Copy full SHA for 23cb839
swift/megatron/model/gpt_model.py
@@ -250,9 +250,9 @@ def forward(
250
logits, _ = self.output_layer(
251
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
252
else:
253
- logits = self.output_layer(hidden_states)[0]
254
if args.sequence_parallel and args.tensor_model_parallel_size > 1:
255
- logits = gather_from_sequence_parallel_region(logits)
+ hidden_states = gather_from_sequence_parallel_region(hidden_states)
+ logits = self.output_layer(hidden_states)[0]
256
if has_config_logger_enabled(self.config):
257
payload = OrderedDict({
258
'input_ids': input_ids,
0 commit comments