Skip to content
3 changes: 2 additions & 1 deletion paddlenlp/transformers/qwen2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,8 @@ def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=Fals
def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):
# add this for fused_head_and_loss_fn
if self.config.use_fused_head_and_loss_fn:
return hidden_states, self.weight, None, self.transpose_y
# return hidden_states, self.weight, None, self.transpose_y
return hidden_states, self.weight, None, None

if self.config.sequence_parallel:
hidden_states = GatherOp.apply(hidden_states)
Expand Down
Loading