5 changes: 5 additions & 0 deletions examples/alignment/dpo/run_dpo.py
@@ -144,6 +144,11 @@ def main():
dtype=dtype,
)
model_config._attn_implementation = model_args.attn_impl
if training_args.use_expert_parallel and training_args.expert_parallel_degree >= 1:
model_config.n_group = training_args.expert_parallel_degree
model_config.use_fused_head_and_loss_fn = model_args.use_fused_head_and_loss_fn
model_config.use_filtered_label_loss = training_args.use_filtered_label_loss
model_config.loss_subbatch_sequence_length = training_args.loss_subbatch_sequence_length

LlmMetaConfig.set_llm_config(model_config, training_args)

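For context on `loss_subbatch_sequence_length` (wired into the model config above): it bounds how much of the `[tokens, vocab_size]` logits tensor is materialized at once by computing the loss over slices of the sequence. A minimal sketch of that idea; the function name and shapes are illustrative, not the PaddleFormers implementation:

```python
import paddle
import paddle.nn.functional as F


def subbatched_lm_loss(hidden_states, lm_weight, labels, subbatch_len=-1):
    """Token-level cross-entropy computed in sub-batches along the sequence axis.

    hidden_states: [seq_len, hidden], lm_weight: [hidden, vocab], labels: [seq_len] int64.
    subbatch_len <= 0 falls back to a single full-sequence pass.
    """
    seq_len = hidden_states.shape[0]
    if subbatch_len <= 0 or seq_len <= subbatch_len:
        logits = paddle.matmul(hidden_states, lm_weight)
        return F.cross_entropy(logits, labels, reduction="mean")

    losses = []
    for start in range(0, seq_len, subbatch_len):
        chunk = hidden_states[start : start + subbatch_len]
        chunk_labels = labels[start : start + subbatch_len]
        logits = paddle.matmul(chunk, lm_weight)  # only a [subbatch, vocab] slice lives at a time
        losses.append(F.cross_entropy(logits, chunk_labels, reduction="sum"))
    return paddle.add_n(losses) / seq_len
```

The `-1` default used in the example configs below keeps the single-pass behavior.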
3 changes: 3 additions & 0 deletions examples/config/dpo_full.yaml
@@ -47,3 +47,6 @@ recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
use_fused_head_and_loss_fn: false
use_filtered_label_loss: false
loss_subbatch_sequence_length: -1
3 changes: 3 additions & 0 deletions examples/config/dpo_lora.yaml
@@ -49,3 +49,6 @@ recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
use_fused_head_and_loss_fn: false
use_filtered_label_loss: false
loss_subbatch_sequence_length: -1
3 changes: 3 additions & 0 deletions examples/config/sft_full.yaml
@@ -47,3 +47,6 @@ recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
use_fused_head_and_loss_fn: false
use_filtered_label_loss: false
loss_subbatch_sequence_length: -1
3 changes: 3 additions & 0 deletions examples/config/sft_lora.yaml
@@ -49,3 +49,6 @@ recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
use_fused_head_and_loss_fn: false
use_filtered_label_loss: false
loss_subbatch_sequence_length: -1
5 changes: 5 additions & 0 deletions examples/run_finetune.py
@@ -134,6 +134,11 @@ def main():
model_config.max_sequence_length = training_args.max_seq_len
model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
model_config._attn_implementation = model_args.attn_impl
if training_args.use_expert_parallel and training_args.expert_parallel_degree >= 1:
model_config.n_group = training_args.expert_parallel_degree
Collaborator comment:

`model_config.n_group`: not every model calls this attribute `n_group`; consider how to handle this more generically.

model_config.use_fused_head_and_loss_fn = model_args.use_fused_head_and_loss_fn
model_config.use_filtered_label_loss = training_args.use_filtered_label_loss
model_config.loss_subbatch_sequence_length = training_args.loss_subbatch_sequence_length
logger.info(f"Final model config: {model_config}")
logger.info("Creating model")

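One way to address the reviewer's point that not every architecture names this attribute `n_group`: set only whatever expert-group attribute the config actually defines. This is a hypothetical helper, with illustrative attribute names, not code from this PR:

```python
# Hypothetical helper: the candidate attribute names are illustrative and would
# need to be checked against the real model configs.
EXPERT_GROUP_ATTRS = ("n_group", "num_expert_groups", "moe_num_expert_groups")


def set_expert_parallel_groups(model_config, degree):
    """Assign the expert-parallel degree to whichever group attribute the config defines."""
    for attr in EXPERT_GROUP_ATTRS:
        if hasattr(model_config, attr):
            setattr(model_config, attr, degree)
            return attr
    return None  # the model has no expert-group notion; leave the config untouched
```

The call sites in run_dpo.py and run_finetune.py could then share such a helper instead of hard-coding `n_group`.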
2 changes: 1 addition & 1 deletion paddleformers/nn/lm_head.py
@@ -93,7 +93,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
hidden_states,
self.weight,
self.bias,
self.config.tie_word_embeddings,
True,
)

return calc_lm_head_logits(
23 changes: 7 additions & 16 deletions paddleformers/transformers/glm4_moe/modeling.py
@@ -34,7 +34,6 @@
from ...nn.norm import Norm as GeneralNorm
from ...nn.pp_model import GeneralModelForCausalLMPipe
from ...utils.log import logger
from ..llama.modeling import get_use_casual_mask
from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ..model_utils import PretrainedModel, register_base_model
from ..moe_gate import PretrainedMoEGate
@@ -1013,27 +1012,18 @@ def forward(
# [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
inputs_embeds = ScatterOp.apply(inputs_embeds)

if attn_mask_startend_row_indices is not None or get_use_casual_mask():
attention_mask = None
else:
# [bs, seq_len]
attention_mask = (
paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
if attention_mask is None
else attention_mask
)
hidden_states = inputs_embeds

if attention_mask is not None:
causal_mask = self._prepare_decoder_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=cache_length,
dtype=inputs_embeds.dtype,
) # [bs, 1, seq_len, seq_len]
attention_mask, hidden_states.shape[:2], cache_length, hidden_states.dtype
)
else:
causal_mask = None

if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
@@ -1205,6 +1195,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

hidden_states = outputs[0] # [bs, seq_len, dim]
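For readers tracing the simplified `_prepare_decoder_attention_mask` call above: the mask it produces is the usual additive causal mask over cached plus current positions. A generic sketch of such a builder, under assumed shapes; it is not the paddleformers helper itself:

```python
import paddle


def build_causal_mask(batch_size, seq_len, past_len, dtype):
    """Additive causal mask of shape [bs, 1, seq_len, past_len + seq_len].

    0 where a query may attend, a large negative value where it may not.
    """
    total = past_len + seq_len
    # query i sits at absolute position past_len + i and may attend keys 0 .. past_len + i
    allowed = paddle.tril(paddle.ones([seq_len, total], dtype="float32"), diagonal=past_len)
    mask = (1.0 - allowed) * -1e4
    return mask.astype(dtype).unsqueeze([0, 1]).expand([batch_size, 1, seq_len, total])
```

With an empty cache (`past_len = 0`) this reduces to the `[bs, 1, seq_len, seq_len]` shape noted in the removed comment.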
3 changes: 3 additions & 0 deletions paddleformers/trl/model_config.py
@@ -153,3 +153,6 @@ class ModelConfig:
pp_seg_method: Optional[str] = field(
default="layer:DecoderLayer|EmptyLayer", metadata={"help": "PP Segmentation Method"}
)
use_fused_head_and_loss_fn: bool = field(
default=False, metadata={"help": "Whether to use fused head and loss function"}
)
2 changes: 1 addition & 1 deletion tests/nn/test_lm_head.py
@@ -60,7 +60,7 @@ def test_forward_fused_head_loss(self):
self.assertEqual(output[0].shape, test_input.shape)
self.assertEqual(output[1].shape, lm_head.weight.shape)
self.assertIs(output[2], lm_head.bias)
self.assertEqual(output[3], config.tie_word_embeddings)
self.assertEqual(output[3], True)


if __name__ == "__main__":
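On `use_fused_head_and_loss_fn`: as the updated test asserts, the fused path has the LM head return `(hidden_states, weight, bias, True)` rather than logits, leaving the vocabulary projection to the loss side. A minimal sketch of that consumer side, under assumed shapes (weight stored as `[vocab, hidden]` so `transpose_y=True`); it is not the library's fused kernel, which would also avoid materializing the full logits tensor:

```python
import paddle
import paddle.nn.functional as F


def fused_head_and_loss(hidden_states, lm_weight, bias, labels, transpose_y=True):
    """Project to the vocabulary and compute cross-entropy in one place.

    hidden_states: [tokens, hidden], lm_weight: [vocab, hidden] when transpose_y=True,
    labels: [tokens] int64.
    """
    logits = paddle.matmul(hidden_states, lm_weight, transpose_y=transpose_y)
    if bias is not None:
        logits = logits + bias
    return F.cross_entropy(logits, labels, reduction="mean")
```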