File tree: 2 files changed, +10 −10 lines changed

Original file line number | Diff line number | Diff line change
@@ -187,15 +187,17 @@ def test_train(self, model_id):
187187 # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
188188 # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
189189 # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels).
190-     is_nemotron = "NemotronH" in model_id
190+     kwargs = {}
191+     if "NemotronH" in model_id:
192+         kwargs["gradient_checkpointing"] = False
193+         kwargs["use_cpu"] = True
191194
192195 # Initialize the trainer
193 196     training_args = DPOConfig(
194 197         output_dir=self.tmp_dir,
195 198         learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
196 199         report_to="none",
197-            gradient_checkpointing=not is_nemotron,
198-            use_cpu=is_nemotron,
    200+        **kwargs,
199 201     )
200202 trainer = DPOTrainer (model = model_id , args = training_args , train_dataset = dataset )
201203
Original file line number | Diff line number | Diff line change
@@ -300,15 +300,13 @@ def test_train(self, model_id):
300300 # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
301301 # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
302302 # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels).
303-     is_nemotron = "NemotronH" in model_id
303+     kwargs = {}
304+     if "NemotronH" in model_id:
305+         kwargs["gradient_checkpointing"] = False
306+         kwargs["use_cpu"] = True
304307
305308 # Initialize the trainer
306-     training_args = SFTConfig(
307-         output_dir=self.tmp_dir,
308-         report_to="none",
309-         gradient_checkpointing=not is_nemotron,
310-         use_cpu=is_nemotron,
311-     )
    309+ training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs)
312310 trainer = SFTTrainer (model = model_id , args = training_args , train_dataset = dataset )
313311
314312 # Save the initial parameters to compare them later
You can’t perform that action at this time.
0 commit comments