File tree: 2 files changed, +14 −6 lines changed

@@ -184,15 +184,18 @@ def test_train(self, model_id):
184184 # Get the dataset
185185 dataset = load_dataset ("trl-internal-testing/zen" , "standard_preference" , split = "train" )
186186
187- # NemotronH does not support gradient checkpointing
188- gradient_checkpointing = "NemotronH" not in model_id
187+ # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
188+ # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
189+ # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels).
190+ is_nemotron = "NemotronH" in model_id
189191
190192 # Initialize the trainer
191193 training_args = DPOConfig (
192194 output_dir = self .tmp_dir ,
193195 learning_rate = 0.1 , # use higher lr because gradients are tiny and default lr can stall updates
194196 report_to = "none" ,
195- gradient_checkpointing = gradient_checkpointing ,
197+ gradient_checkpointing = not is_nemotron ,
198+ use_cpu = is_nemotron ,
196199 )
197200 trainer = DPOTrainer (model = model_id , args = training_args , train_dataset = dataset )
198201
@@ -297,12 +297,17 @@ def test_train(self, model_id):
297297 # Get the dataset
298298 dataset = load_dataset ("trl-internal-testing/zen" , "standard_language_modeling" , split = "train" )
299299
300- # NemotronH does not support gradient checkpointing
301- gradient_checkpointing = "NemotronH" not in model_id
300+ # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA
301+ # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions.
302+ # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels).
303+ is_nemotron = "NemotronH" in model_id
302304
303305 # Initialize the trainer
304306 training_args = SFTConfig (
305- output_dir = self .tmp_dir , report_to = "none" , gradient_checkpointing = gradient_checkpointing
307+ output_dir = self .tmp_dir ,
308+ report_to = "none" ,
309+ gradient_checkpointing = not is_nemotron ,
310+ use_cpu = is_nemotron ,
306311 )
307312 trainer = SFTTrainer (model = model_id , args = training_args , train_dataset = dataset )
308313
0 commit comments