2 files changed: +4 −10 lines changed

@@ -146,14 +146,6 @@ class DPOModelArgument:
         default=None,
         metadata={"help": "whether to fuse first up and gate proj in mlp block"},
     )
-    use_sparse_head_and_loss_fn: bool = field(
-        default=True,
-        metadata={"help": "Whether to use sparse indexing for loss calculation."},
-    )
-    use_fused_head_and_loss_fn: bool = field(
-        default=True,
-        metadata={"help": "Whether to use fused kernel to calculate lm head and loss."},
-    )
     use_attn_mask_startend_row_indices: bool = field(
         default=True,
         metadata={"help": "Sparse attention mode."},

@@ -152,6 +152,8 @@ def main():
         model_args.model_name_or_path,
         dtype=dtype,
     )
+    ref_model_config._attn_implementation = model_args.attn_impl
+
     LlmMetaConfig.set_llm_config(ref_model_config, training_args)

     if training_args.pipeline_parallel_degree > 1:
@@ -309,8 +311,8 @@ def main():
             collate_fn,
             tokenizer=tokenizer,
             max_seq_len=max_seq_len,
-            use_sparse_head_and_loss_fn=model_args.use_sparse_head_and_loss_fn,
-            use_fused_head_and_loss_fn=model_args.use_fused_head_and_loss_fn,
+            use_sparse_head_and_loss_fn=model_config.use_sparse_head_and_loss_fn,
+            use_fused_head_and_loss_fn=model_config.use_fused_head_and_loss_fn,
         ),
         ignore_eos_token=True,
     )