Skip to content

Commit c555040

Browse files
committed
Update
1 parent 8122d6d commit c555040

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

tests/test_dpo_trainer.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -184,15 +184,18 @@ def test_train(self, model_id):
184184
# Get the dataset
185185
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
186186

187-
# NemotronH does not support gradient checkpointing
188-
gradient_checkpointing = "NemotronH" not in model_id
187+
# NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
188+
# kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
189+
# Force CPU so that the model takes the pure PyTorch path (that path also works on GPU when the Mamba CUDA kernels are not installed).
190+
is_nemotron = "NemotronH" in model_id
189191

190192
# Initialize the trainer
191193
training_args = DPOConfig(
192194
output_dir=self.tmp_dir,
193195
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
194196
report_to="none",
195-
gradient_checkpointing=gradient_checkpointing,
197+
gradient_checkpointing=not is_nemotron,
198+
use_cpu=is_nemotron,
196199
)
197200
trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset)
198201

tests/test_sft_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -297,12 +297,17 @@ def test_train(self, model_id):
297297
# Get the dataset
298298
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
299299

300-
# NemotronH does not support gradient checkpointing
301-
gradient_checkpointing = "NemotronH" not in model_id
300+
# NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
301+
# kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
302+
# Force CPU so that the model takes the pure PyTorch path (that path also works on GPU when the Mamba CUDA kernels are not installed).
303+
is_nemotron = "NemotronH" in model_id
302304

303305
# Initialize the trainer
304306
training_args = SFTConfig(
305-
output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=gradient_checkpointing
307+
output_dir=self.tmp_dir,
308+
report_to="none",
309+
gradient_checkpointing=not is_nemotron,
310+
use_cpu=is_nemotron,
306311
)
307312
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)
308313

0 commit comments

Comments
 (0)