
Commit 03b533c

Fix conflict bug in DPO (#2705)
1 parent ee9ad93 commit 03b533c

File tree

2 files changed: +4 -10 lines changed


examples/alignment/dpo/dpo_argument.py

Lines changed: 0 additions & 8 deletions

@@ -146,14 +146,6 @@ class DPOModelArgument:
         default=None,
         metadata={"help": "whether to fuse first up and gate proj in mlp block"},
     )
-    use_sparse_head_and_loss_fn: bool = field(
-        default=True,
-        metadata={"help": "Whether to use sparse indexing for loss calculation."},
-    )
-    use_fused_head_and_loss_fn: bool = field(
-        default=True,
-        metadata={"help": "Whether to use fused kernel to calculate lm head and loss."},
-    )
     use_attn_mask_startend_row_indices: bool = field(
         default=True,
         metadata={"help": "Sparse attention mode."},

examples/alignment/dpo/run_dpo.py

Lines changed: 4 additions & 2 deletions

@@ -152,6 +152,8 @@ def main():
         model_args.model_name_or_path,
         dtype=dtype,
     )
+    ref_model_config._attn_implementation = model_args.attn_impl
+
     LlmMetaConfig.set_llm_config(ref_model_config, training_args)

     if training_args.pipeline_parallel_degree > 1:
@@ -309,8 +311,8 @@ def main():
             collate_fn,
             tokenizer=tokenizer,
             max_seq_len=max_seq_len,
-            use_sparse_head_and_loss_fn=model_args.use_sparse_head_and_loss_fn,
-            use_fused_head_and_loss_fn=model_args.use_fused_head_and_loss_fn,
+            use_sparse_head_and_loss_fn=model_config.use_sparse_head_and_loss_fn,
+            use_fused_head_and_loss_fn=model_config.use_fused_head_and_loss_fn,
         ),
         ignore_eos_token=True,
     )
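After this change the reference model's attention implementation is copied from the CLI argument, and the collator's kernel flags come from the same config the model was built with. A rough, self-contained sketch of the collator call pattern, with stubbed-out stand-ins for collate_fn, tokenizer, and model_config (not the exact surrounding code):

from functools import partial
from types import SimpleNamespace

# Hypothetical stand-in for the model config used in run_dpo.py.
model_config = SimpleNamespace(
    use_sparse_head_and_loss_fn=True,
    use_fused_head_and_loss_fn=True,
)

def collate_fn(batch, tokenizer=None, max_seq_len=None,
               use_sparse_head_and_loss_fn=False,
               use_fused_head_and_loss_fn=False):
    # Stub: the real collator builds DPO training batches.
    return batch

# Both kernel flags are read from model_config, the object the model itself
# was constructed from, so the data path cannot drift from the model.
collator = partial(
    collate_fn,
    tokenizer=None,
    max_seq_len=4096,
    use_sparse_head_and_loss_fn=model_config.use_sparse_head_and_loss_fn,
    use_fused_head_and_loss_fn=model_config.use_fused_head_and_loss_fn,
)
print(collator(["example"]))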
