 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
 from axolotl.utils.models import ensure_dtype
-from axolotl.utils.schemas.enums import CustomSupportedOptimizers
+from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType

 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -353,7 +353,7 @@ def build(self, total_num_steps):
         training_arguments_kwargs["warmup_steps"] = warmup_steps
         training_arguments_kwargs["logging_steps"] = logging_steps

-        if self.cfg.seed:
+        if self.cfg.seed is not None:
             training_arguments_kwargs["seed"] = self.cfg.seed

         if self.cfg.gradient_checkpointing:
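The `is not None` check above matters because `0` is a valid seed but is falsy, so the old truthiness test silently skipped it. A minimal sketch of the difference, using a hypothetical `seed` value for illustration:

```python
seed = 0  # hypothetical config value; 0 is a legitimate seed

kwargs_truthy = {}
if seed:  # old check: a falsy 0 is silently dropped
    kwargs_truthy["seed"] = seed

kwargs_explicit = {}
if seed is not None:  # new check: only a genuinely unset seed is skipped
    kwargs_explicit["seed"] = seed

assert kwargs_truthy == {} and kwargs_explicit == {"seed": 0}
```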
@@ -547,8 +547,6 @@ def build(self, total_num_steps):
         report_to = []
         if self.cfg.use_wandb:
             report_to.append("wandb")
-            if self.cfg.wandb_name:
-                training_arguments_kwargs["run_name"] = self.cfg.wandb_name
         if self.cfg.use_mlflow:
             report_to.append("mlflow")
         if self.cfg.use_tensorboard:
@@ -821,14 +819,15 @@ def build(self, total_num_steps):
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
+        multiple = 64
         if self.cfg.pad_to_sequence_len:
-            data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
-                self.cfg.sequence_len / 64
+            data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
+                self.cfg.sequence_len / multiple
             )
         else:
             # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
-            data_collator_kwargs["pad_to_multiple_of"] = 64
+            data_collator_kwargs["pad_to_multiple_of"] = multiple

         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
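Factoring the literal into `multiple = 64` keeps the two padding paths in sync: with `pad_to_sequence_len` set, the padded length is `sequence_len` rounded up to the next multiple of 64, which keeps matmul dimensions on the fast path per the NVIDIA guide linked above. A standalone sketch of that rounding, with an illustrative `sequence_len` that is not taken from the diff:

```python
import math

multiple = 64
sequence_len = 2051  # illustrative value only

# Round up to the nearest multiple of 64, mirroring the collator kwargs above.
pad_to_multiple_of = multiple * math.ceil(sequence_len / multiple)
assert pad_to_multiple_of == 2112  # 33 * 64
```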
@@ -1034,6 +1033,10 @@ def build_training_arguments(self, total_num_steps):
             training_args_kwargs["dataloader_prefetch_factor"] = (
                 self.cfg.dataloader_prefetch_factor
             )
+
+        if self.cfg.seed is not None:
+            training_args_kwargs["seed"] = self.cfg.seed
+
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
                 self.cfg.gradient_checkpointing
@@ -1076,23 +1079,27 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.use_wandb:
             training_args_kwargs["run_name"] = self.cfg.wandb_name

+        training_args_kwargs["sequence_parallel_degree"] = (
+            self.cfg.sequence_parallel_degree
+        )
+
         training_args_cls = None
         blocklist_args_kwargs = []
-        if self.cfg.rl == "simpo":
+        if self.cfg.rl is RLType.SIMPO:
             training_args_cls = AxolotlCPOConfig
             training_args_kwargs["loss_type"] = "simpo"
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

-        elif self.cfg.rl == "orpo":
+        elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

-        elif self.cfg.rl == "kto":
+        elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig

             training_args_kwargs["desirable_weight"] = (
@@ -1106,14 +1113,14 @@ def build_training_arguments(self, total_num_steps):
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

-        elif self.cfg.rl == "grpo":
+        elif self.cfg.rl is RLType.GRPO:
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
             blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

         else:
             training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl == "ipo":
+            if self.cfg.rl is RLType.IPO:
                 training_args_kwargs["loss_type"] = "ipo"
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             training_args_kwargs["max_completion_length"] = None
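The comparisons above move from raw strings (`self.cfg.rl == "simpo"`) to enum members (`self.cfg.rl is RLType.SIMPO`), which works because enum members are singletons: identity checks are safe, and a misspelled value fails at attribute lookup instead of silently falling through to the `else` branch. A minimal sketch of how such a string-backed enum behaves; the class below is a hypothetical stand-in, not the actual definition in `axolotl.utils.schemas.enums`:

```python
from enum import Enum


class RLType(str, Enum):
    """Hypothetical stand-in; member values are assumed."""

    DPO = "dpo"
    IPO = "ipo"
    ORPO = "orpo"
    KTO = "kto"
    SIMPO = "simpo"
    GRPO = "grpo"


rl = RLType("simpo")          # parsing a config string returns the singleton member
assert rl is RLType.SIMPO     # identity check, as used above
assert rl not in [RLType.DPO, RLType.IPO]  # membership checks also work
assert rl == "simpo"          # the str mixin keeps plain string comparisons valid
```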
@@ -1156,67 +1163,69 @@ def build_training_arguments(self, total_num_steps):

     def build(self, total_num_steps):
         training_args = self.build_training_arguments(total_num_steps)
-        dpo_trainer_kwargs = {}
-        if self.cfg.rl == "ipo":
+        trainer_kwargs = {}
+        if self.cfg.rl is RLType.IPO:
             if self.cfg.dpo_label_smoothing:
-                dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
+                trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
-            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
+            trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
-            dpo_trainer_kwargs["peft_config"] = self.peft_config
+            trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
+            trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs
             )
-        if self.cfg.rl == "grpo":
-            trainer_cls = GRPOStrategy.get_trainer_class()
+        if self.cfg.rl is RLType.GRPO:
+            trainer_cls = GRPOStrategy.get_trainer_class(
+                sequence_parallel=self.cfg.sequence_parallel_degree > 1
+            )
             trainer_cls_args = [self.model]
             trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
-            dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
-        elif self.cfg.rl in ["dpo", "ipo"]:
+            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
+        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             trainer_cls = DPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model, self.model_ref]
-        elif self.cfg.rl == "orpo":
+        elif self.cfg.rl is RLType.ORPO:
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
-        elif self.cfg.rl in ["kto"]:
+        elif self.cfg.rl is RLType.KTO:
             trainer_cls = AxolotlKTOTrainer
             trainer_cls_args = [self.model]
-        elif self.cfg.rl in ["simpo"]:
+        elif self.cfg.rl is RLType.SIMPO:
             trainer_cls = AxolotlCPOTrainer
             trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")

         sig = inspect.signature(trainer_cls)
         if "tokenizer" in sig.parameters.keys():
-            dpo_trainer_kwargs["tokenizer"] = self.tokenizer
+            trainer_kwargs["tokenizer"] = self.tokenizer
         else:
-            dpo_trainer_kwargs["processing_class"] = self.tokenizer
+            trainer_kwargs["processing_class"] = self.tokenizer

         if self.cfg.datasets is not None and (
             trainer_cls is DPOStrategy.get_trainer_class()
         ):
-            dpo_trainer_kwargs["dataset_tags"] = [
+            trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
-        dpo_trainer = trainer_cls(
+        trainer = trainer_cls(
             *trainer_cls_args,
             args=training_args,
             train_dataset=self.train_dataset,
             callbacks=self.get_callbacks(),
-            **dpo_trainer_kwargs,
+            **trainer_kwargs,
         )
         if self.cfg.fsdp:
-            ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
-            if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
-                ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
+            ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
+            if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
+                ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)

-        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
-        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
-            dpo_trainer.add_callback(callback)
+        trainer = self.hook_post_create_trainer(trainer)
+        for callback in self.get_post_trainer_create_callbacks(trainer):
+            trainer.add_callback(callback)

-        return dpo_trainer
+        return trainer


 class HFPPOTrainerBuilder(TrainerBuilderBase):
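The `inspect.signature` branch in the hunk above (pass `tokenizer` if the trainer still accepts it, otherwise `processing_class`) is a common pattern for staying compatible across TRL releases that renamed the argument. A self-contained sketch of the same idea with a hypothetical trainer class:

```python
import inspect


class NewStyleTrainer:
    """Hypothetical trainer whose constructor only accepts processing_class."""

    def __init__(self, processing_class=None):
        self.processing_class = processing_class


trainer_cls = NewStyleTrainer
trainer_kwargs = {}
tokenizer = object()  # stand-in for a real tokenizer

# Choose the keyword the trainer actually accepts, mirroring the diff above.
if "tokenizer" in inspect.signature(trainer_cls).parameters:
    trainer_kwargs["tokenizer"] = tokenizer
else:
    trainer_kwargs["processing_class"] = tokenizer

trainer = trainer_cls(**trainer_kwargs)
assert trainer.processing_class is tokenizer
```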