Skip to content

Commit 3c8f9d4

Browse files
committed
Cursor review
1 parent c555040 commit 3c8f9d4

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

tests/test_dpo_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,17 @@ def test_train(self, model_id):
187187
# NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
188188
# kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
189189
# Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels).
190-
is_nemotron = "NemotronH" in model_id
190+
kwargs = {}
191+
if "NemotronH" in model_id:
192+
kwargs["gradient_checkpointing"] = False
193+
kwargs["use_cpu"] = True
191194

192195
# Initialize the trainer
193196
training_args = DPOConfig(
194197
output_dir=self.tmp_dir,
195198
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
196199
report_to="none",
197-
gradient_checkpointing=not is_nemotron,
198-
use_cpu=is_nemotron,
200+
**kwargs,
199201
)
200202
trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset)
201203

tests/test_sft_trainer.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,13 @@ def test_train(self, model_id):
300300
# NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
301301
# kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
302302
# Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels).
303-
is_nemotron = "NemotronH" in model_id
303+
kwargs = {}
304+
if "NemotronH" in model_id:
305+
kwargs["gradient_checkpointing"] = False
306+
kwargs["use_cpu"] = True
304307

305308
# Initialize the trainer
306-
training_args = SFTConfig(
307-
output_dir=self.tmp_dir,
308-
report_to="none",
309-
gradient_checkpointing=not is_nemotron,
310-
use_cpu=is_nemotron,
311-
)
309+
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs)
312310
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)
313311

314312
# Save the initial parameters to compare them later

0 commit comments

Comments
 (0)