
Commit 5030b43

Update gradient handling and some test cases
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 9b96fea commit 5030b43

File tree

2 files changed (+58 -80 lines)


modelopt/torch/peft/conversion.py

Lines changed: 41 additions & 22 deletions
@@ -40,11 +40,17 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     # initialize the true module if necessary
     model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
 
+    # Freeze all base model weights before replacing modules if freeze_base_model is True
+    if config.freeze_base_model:
+        for param in model.parameters():
+            param.requires_grad = False
+
     replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
 
     metadata = {}
     add_adapter(model, config)
-    update_grads(model, config)
+    # Update gradient settings for LoRA parameters only
+    _update_lora_grads(model, config)
 
     return model, metadata
 
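The ordering here matters: the blanket freeze runs before `replace_lora_module`, so only pre-existing base parameters are affected, and the LoRA parameters created afterwards keep their default `requires_grad=True`. A minimal sketch of that ordering in plain PyTorch (not the modelopt API; `TinyLoRALinear` is a hypothetical stand-in for a LoRA-wrapped layer):

```python
# Minimal sketch of the freeze-then-wrap ordering in plain PyTorch (not the
# modelopt API). TinyLoRALinear is a hypothetical stand-in for a layer produced
# by replace_lora_module.
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    """Reuses an existing (already frozen) Linear and adds fresh LoRA factors."""

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)

    def forward(self, x):
        return self.base(x) + self.lora_b(self.lora_a(x))


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# 1) Freeze every pre-existing parameter first (the new freeze_base_model branch).
for param in model.parameters():
    param.requires_grad = False

# 2) Wrap modules afterwards: the newly created lora_a / lora_b parameters keep
#    their default requires_grad=True, so only they remain trainable.
model[0] = TinyLoRALinear(model[0])

print([n for n, p in model.named_parameters() if p.requires_grad])
# -> ['0.lora_a.weight', '0.lora_b.weight']
```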

@@ -169,17 +175,25 @@ def _iter_lora_modules(model, layer_patterns=None):
 
 
 def _set_base_requires_grad(model, *, requires_grad: bool, layer_patterns=None):
-    for _, module in _iter_lora_modules(model, layer_patterns):
-        lora_param_ids = {
-            id(param)
-            for adapter in module._lora_adapters.values()
-            for submodule in ("lora_a", "lora_b")
-            for _, param in adapter[submodule].named_parameters()
-        }
-        for _, param in module.named_parameters():
-            if id(param) in lora_param_ids:
+    # Collect all LoRA parameter IDs across the entire model
+    lora_param_ids = set()
+    for _, module in _iter_lora_modules(model, layer_patterns=None):
+        for adapter in module._lora_adapters.values():
+            for submodule in ("lora_a", "lora_b"):
+                for _, param in adapter[submodule].named_parameters():
+                    lora_param_ids.add(id(param))
+
+    # Set requires_grad for all parameters in the model (excluding LoRA parameters)
+    for name, param in model.named_parameters():
+        # Skip LoRA parameters
+        if id(param) in lora_param_ids:
+            continue
+        # If layer_patterns is specified, only affect matching layers
+        if layer_patterns is not None:
+            module_name = ".".join(name.split(".")[:-1])  # Get module name without param name
+            if not _matches(module_name, layer_patterns):
                 continue
-            param.requires_grad = requires_grad
+        param.requires_grad = requires_grad
 
 
 def _iter_adapter_names(module, adapter_patterns=None):
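Read on its own, the rewritten helper now operates model-wide: it first records the identity of every LoRA parameter (note it passes `layer_patterns=None` when collecting, so pattern filtering only restricts which base parameters get touched, never which LoRA parameters are excluded), then flips `requires_grad` on everything else. A self-contained approximation is below; `fnmatch` stands in for the module's `_matches` helper and a name-based check stands in for walking `module._lora_adapters` — both are assumptions for illustration.

```python
# Self-contained approximation of the rewritten _set_base_requires_grad.
# Assumptions: fnmatch stands in for the repo's _matches() helper, and the
# "lora_a"/"lora_b" name check stands in for walking module._lora_adapters.
from fnmatch import fnmatch

import torch.nn as nn


def set_base_requires_grad(model: nn.Module, *, requires_grad: bool, layer_patterns=None):
    # 1) Collect ids of every LoRA parameter across the whole model.
    lora_param_ids = {
        id(p)
        for name, p in model.named_parameters()
        if "lora_a" in name or "lora_b" in name
    }

    # 2) Flip requires_grad on every non-LoRA parameter; if layer_patterns is
    #    given, only touch parameters whose parent module name matches.
    for name, param in model.named_parameters():
        if id(param) in lora_param_ids:
            continue
        if layer_patterns is not None:
            module_name = ".".join(name.split(".")[:-1])  # drop the parameter name
            if not any(fnmatch(module_name, pattern) for pattern in layer_patterns):
                continue
        param.requires_grad = requires_grad


# Example (illustrative pattern): freeze only base weights under *mlp* modules.
# set_base_requires_grad(model, requires_grad=False, layer_patterns=["*mlp*"])
```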
@@ -202,7 +216,8 @@ def _set_lora_requires_grad(
 def freeze_base_weights(model, *, layer_patterns=None):
     """Freeze base model weights to prevent gradient updates during training.
 
-    This function sets requires_grad=False for all base model parameters in LoRA modules,
+    This function sets requires_grad=False for all base model parameters (including
+    linear weights, embeddings, layer norms, etc.) across the entire model,
     while keeping LoRA adapter parameters trainable. Useful for LoRA fine-tuning where
     only adapter weights should be updated.
 
@@ -218,8 +233,10 @@ def freeze_base_weights(model, *, layer_patterns=None):
 def unfreeze_base_weights(model, *, layer_patterns=None):
     """Unfreeze base model weights to allow gradient updates during training.
 
-    This function sets requires_grad=True for all base model parameters in LoRA modules.
-    Useful when you want to fine-tune both base model and LoRA adapter weights together.
+    This function sets requires_grad=True for all base model parameters (including
+    linear weights, embeddings, layer norms, etc.) across the entire model,
+    while keeping LoRA adapter parameters unchanged. Useful when you want to fine-tune
+    both base model and LoRA adapter weights together.
 
     Args:
         model: Model containing LoRA modules whose base weights should be unfrozen
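For reference, a hedged usage sketch of the two public helpers documented above. It assumes they can be imported from `modelopt.torch.peft.conversion` (the file being changed), that `model` is already a LoRA-converted model, and that `layer_patterns` accepts glob-style patterns; the pattern shown is purely illustrative.

```python
# Usage sketch; assumes `model` is already a LoRA-converted model and that the
# helpers are importable from the module changed in this commit.
from modelopt.torch.peft.conversion import freeze_base_weights, unfreeze_base_weights


def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def lora_only_then_full(model):
    freeze_base_weights(model)       # only LoRA adapter weights stay trainable
    print("LoRA-only trainable params:", count_trainable(model))

    unfreeze_base_weights(model)     # base weights trainable again; LoRA unchanged
    print("Fully trainable params:   ", count_trainable(model))

    # Pattern-scoped variant (hypothetical pattern syntax): freeze base weights
    # only in layers whose dotted name matches "*self_attention*".
    freeze_base_weights(model, layer_patterns=["*self_attention*"])
```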
@@ -277,18 +294,20 @@ def unfreeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
     )
 
 
-def update_grads(model, config: PEFTConfig):
-    """Update gradient computation settings based on PEFTConfig.
+def _update_lora_grads(model, config: PEFTConfig):
+    """Update gradient computation settings for LoRA parameters only (internal function).
+
+    This internal function configures which LoRA adapter parameters should have gradients
+    computed based on the freeze_lora_weights setting in the PEFTConfig. It's typically
+    called during model initialization after LoRA adapters have been added.
 
-    This function configures which model parameters should have gradients computed
-    based on the freeze settings in the PEFTConfig. It's typically called during
-    model initialization or when switching training configurations.
+    Note: This function only affects LoRA parameters. Base model parameter gradients
+    should be set separately (e.g., in convert_to_peft_model before LoRA module replacement).
 
     Args:
         model: Model containing LoRA modules to configure
-        config: PEFTConfig instance with freeze_base_model and freeze_lora_weights settings
-            - If config.freeze_base_model is True, base weights will have requires_grad=False
+        config: PEFTConfig instance with freeze_lora_weights setting
             - If config.freeze_lora_weights is True, LoRA weights will have requires_grad=False
+            - If config.freeze_lora_weights is False, LoRA weights will have requires_grad=True
     """
-    _set_base_requires_grad(model, requires_grad=not config.freeze_base_model)
     _set_lora_requires_grad(model, requires_grad=not config.freeze_lora_weights)
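After this hunk, the two PEFTConfig flags are handled in different places: `freeze_base_model` is applied up front in `convert_to_peft_model`, while `freeze_lora_weights` is applied afterwards by `_update_lora_grads`. The sketch below summarizes the resulting `requires_grad` states as this diff reads; it is an interpretation, not code from the repository.

```python
# Interpretation of the combined effect of the two PEFTConfig flags after this
# commit (summary only, not repository code). When freeze_base_model is False,
# base parameters are simply left untouched, so they keep their prior setting
# (usually requires_grad=True).
EXPECTED_REQUIRES_GRAD = {
    # (freeze_base_model, freeze_lora_weights): (base params, LoRA params)
    (True, False): (False, True),   # typical LoRA fine-tuning: adapters only
    (True, True): (False, False),   # everything frozen, e.g. evaluation
    (False, False): (True, True),   # base and adapters both trainable
    (False, True): (True, False),   # train the base, keep adapters fixed
}
```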

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 17 additions & 58 deletions
@@ -313,6 +313,8 @@ def test_forward_with_two_loras(lora_config_1, lora_config_2):
 
 
 # TODO: Rank check
+
+
 def _test_attr_changes_with_one_lora(lora_config, rank, size):
     """Test forward pass with a single LoRA adapter with various configurations."""
     hidden_size = 320
@@ -422,26 +424,12 @@ def _test_adapter_gradient_flow_freeze_base_model(lora_config, tmp_path, rank, s
     loss = output.sum()
     loss.backward()
 
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            if len(module._lora_adapters) == 0:
-                continue
-            for adapter_name in module._lora_adapters:
-                lora_a_module = module._lora_adapters[adapter_name]["lora_a"]
-                lora_b_module = module._lora_adapters[adapter_name]["lora_b"]
-
-                for param_name, param in lora_a_module.named_parameters():
-                    assert param.grad is not None, f"lora_a.{param_name} in {name} has no gradient"
-                    assert torch.any(param.grad != 0), (
-                        f"lora_a.{param_name} gradient is all zeros in {name}"
-                    )
-
-                for param_name, param in lora_b_module.named_parameters():
-                    assert param.grad is not None, f"lora_b.{param_name} in {name} has no gradient"
-                    assert torch.any(param.grad != 0), (
-                        f"lora_b.{param_name} gradient is all zeros in {name}"
-                    )
-            assert module.weight.grad is None
+    for name, param in model.named_parameters():
+        if "lora" in name:
+            assert param.grad is not None
+            assert torch.any(param.grad != 0)
+        else:
+            assert param.grad is None
 
 
 @pytest.mark.parametrize("device_count", get_device_counts())
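The simplified assertions above lean on a naming convention: every LoRA parameter has "lora" somewhere in its dotted name, and every frozen base parameter finishes the backward pass with `param.grad is None`. A toy, self-contained illustration of the same check (the layer and names are made up; the real test drives a Megatron model through modelopt):

```python
# Toy illustration of the name-based gradient checks; not the Megatron test.
import torch
import torch.nn as nn


class ToyLoRALayer(nn.Module):
    def __init__(self, d: int = 8, rank: int = 2):
        super().__init__()
        # Frozen base weight: no gradient should ever be accumulated for it.
        self.weight = nn.Parameter(torch.randn(d, d), requires_grad=False)
        # lora_b keeps PyTorch's default (non-zero) init here so that lora_a
        # also receives a non-zero gradient in this toy example.
        self.lora_a = nn.Linear(d, rank, bias=False)
        self.lora_b = nn.Linear(rank, d, bias=False)

    def forward(self, x):
        return x @ self.weight.T + self.lora_b(self.lora_a(x))


model = ToyLoRALayer()
loss = model(torch.randn(4, 8)).sum()
loss.backward()

for name, param in model.named_parameters():
    if "lora" in name:
        assert param.grad is not None
        assert torch.any(param.grad != 0)
    else:
        assert param.grad is None  # frozen base weight collected no gradient
```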
@@ -487,22 +475,12 @@ def _test_adapter_gradient_flow_freeze_lora_model(lora_config, tmp_path, rank, s
     loss = output.sum()
     loss.backward()
 
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            if len(module._lora_adapters) == 0:
-                continue
-            for adapter_name in module._lora_adapters:
-                lora_a_module = module._lora_adapters[adapter_name]["lora_a"]
-                lora_b_module = module._lora_adapters[adapter_name]["lora_b"]
-
-                for param_name, param in lora_a_module.named_parameters():
-                    assert param.grad is None, f"lora_a.{param_name} in {name} has gradient"
-
-                for param_name, param in lora_b_module.named_parameters():
-                    assert param.grad is None, f"lora_b.{param_name} in {name} has gradient"
-
-            assert module.weight.grad is not None
-            assert torch.any(module.weight.grad != 0), "weight gradient is all zeros"
+    for name, param in model.named_parameters():
+        if "lora" in name:
+            assert param.grad is None
+        else:
+            assert param.grad is not None
+            assert torch.any(param.grad != 0)
 
 
 @pytest.mark.parametrize("device_count", get_device_counts())
@@ -548,28 +526,9 @@ def _test_adapter_gradient_flow(lora_config, tmp_path, rank, size):
     loss = output.sum()
     loss.backward()
 
-    for name, module in model.named_modules():
-        if isinstance(module, LoRAModule):
-            if len(module._lora_adapters) == 0:
-                continue
-            for adapter_name in module._lora_adapters:
-                lora_a_module = module._lora_adapters[adapter_name]["lora_a"]
-                lora_b_module = module._lora_adapters[adapter_name]["lora_b"]
-
-                for param_name, param in lora_a_module.named_parameters():
-                    assert param.grad is not None, f"lora_a.{param_name} in {name} has gradient"
-                    assert torch.any(param.grad != 0), (
-                        f"lora_a.{param_name} gradient is all zeros in {name}"
-                    )
-
-                for param_name, param in lora_b_module.named_parameters():
-                    assert param.grad is not None, f"lora_b.{param_name} in {name} has gradient"
-                    assert torch.any(param.grad != 0), (
-                        f"lora_b.{param_name} gradient is all zeros in {name}"
-                    )
-
-            assert module.weight.grad is not None
-            assert torch.any(module.weight.grad != 0), "weight gradient is all zeros"
+    for name, param in model.named_parameters():
+        assert param.grad is not None
+        assert torch.any(param.grad != 0), "weight gradient is all zeros"
 
 
 @pytest.mark.parametrize("device_count", get_device_counts())
