
Commit bde750b

Update test cases
Signed-off-by: Jingyu Xin <[email protected]>
1 parent f1f94af commit bde750b


2 files changed: +81, -37 lines changed


modelopt/torch/peft/conversion.py

Lines changed: 0 additions & 37 deletions

@@ -26,11 +26,8 @@

 # TODO: Add test cases to cover these functions
 __all__ = [
-    "freeze_base_weights",
     "freeze_lora_weights",
     "replace_lora_module",
-    "unfreeze_base_weights",
-    "unfreeze_lora_weights",
 ]


@@ -184,40 +181,6 @@ def _set_lora_requires_grad(
         param.requires_grad = requires_grad


-def freeze_base_weights(model, *, layer_patterns=None):
-    """Freeze base model weights to prevent gradient updates during training.
-
-    This function sets requires_grad=False for all base model parameters (including
-    linear weights, embeddings, layer norms, etc.) across the entire model,
-    while keeping LoRA adapter parameters trainable. Useful for LoRA fine-tuning where
-    only adapter weights should be updated.
-
-    Args:
-        model: Model containing LoRA modules whose base weights should be frozen
-        layer_patterns: Optional patterns (str, bytes, or Iterable) to match specific
-            layer names. If provided, only layers matching these patterns will be affected.
-            Supports Unix-style wildcards (e.g., "*.linear", "transformer.*")
-    """
-    _set_base_requires_grad(model, requires_grad=False, layer_patterns=layer_patterns)
-
-
-def unfreeze_base_weights(model, *, layer_patterns=None):
-    """Unfreeze base model weights to allow gradient updates during training.
-
-    This function sets requires_grad=True for all base model parameters (including
-    linear weights, embeddings, layer norms, etc.) across the entire model,
-    while keeping LoRA adapter parameters unchanged. Useful when you want to fine-tune
-    both base model and LoRA adapter weights together.
-
-    Args:
-        model: Model containing LoRA modules whose base weights should be unfrozen
-        layer_patterns: Optional patterns (str, bytes, or Iterable) to match specific
-            layer names. If provided, only layers matching these patterns will be affected.
-            Supports Unix-style wildcards (e.g., "*.linear", "transformer.*")
-    """
-    _set_base_requires_grad(model, requires_grad=True, layer_patterns=layer_patterns)
-
-
 def freeze_lora_weights(model, *, layer_patterns=None, adapter_patterns=None):
     """Freeze LoRA adapter weights to prevent gradient updates during training.
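
Note on the removal above: freeze_base_weights and unfreeze_base_weights were thin wrappers around _set_base_requires_grad, which stays in the module. Below is a minimal stand-alone sketch of the same behavior in plain PyTorch for anyone who relied on the removed helpers; the fnmatch-based name matching and the "lora" substring check are assumptions drawn from the removed docstrings and the new test, not the library's internals.

    import fnmatch

    import torch.nn as nn


    def set_base_requires_grad(model: nn.Module, requires_grad: bool, layer_patterns=None) -> None:
        # Hypothetical replacement for the removed freeze/unfreeze_base_weights helpers:
        # toggle requires_grad on every non-LoRA parameter, optionally restricted to
        # layers whose names match a Unix-style wildcard such as "*self_attention*".
        for name, param in model.named_parameters():
            if "lora" in name:
                continue  # leave LoRA adapter weights untouched
            if layer_patterns is not None and not fnmatch.fnmatch(name, layer_patterns):
                continue  # a pattern was given and this parameter falls outside it
            param.requires_grad = requires_grad

For example, set_base_requires_grad(model, False) freezes all base weights for adapter-only fine-tuning, which is what freeze_base_weights(model) used to do.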

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 81 additions & 0 deletions

@@ -543,6 +543,87 @@ def test_adapter_gradient_flow(lora_config, tmp_path):
     )


+def _test_adapter_gradient_flow_freeze_lora_with_api(lora_config, tmp_path, rank, size):
+    hidden_size = 256
+
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+    lora_config["freeze_lora_weights"] = False
+    lora_config["freeze_base_model"] = False
+
+    mtpeft.update_model(model, lora_config)
+    # Freeze the self_attention layers only
+    mtpeft.freeze_lora_weights(model, layer_patterns="*self_attention*")
+    model.train()
+
+    # Use a simple forward pass instead for grad check
+    batch_size = prompt_tokens.shape[0]
+    seq_len = prompt_tokens.shape[-1]
+    device = prompt_tokens.device
+
+    attention_mask = (
+        torch.triu(torch.ones((batch_size, seq_len, seq_len), device=device), diagonal=1)
+        .bool()
+        .view(batch_size, 1, seq_len, seq_len)
+    )
+
+    output = model(prompt_tokens, position_ids=None, attention_mask=attention_mask)
+
+    loss = output.sum()
+    loss.backward()
+
+    for name, param in model.named_parameters():
+        if "lora" in name and "self_attention" in name:
+            assert param.grad is None
+        else:
+            assert param.grad is not None
+            assert torch.any(param.grad != 0), "weight gradient is all zeros"
+
+    for p in model.parameters():
+        p.grad = None
+
+    mtpeft.freeze_lora_weights(model)
+    model.train()
+
+    # Use a simple forward pass instead for grad check
+    batch_size = prompt_tokens.shape[0]
+    seq_len = prompt_tokens.shape[-1]
+    device = prompt_tokens.device
+
+    attention_mask = (
+        torch.triu(torch.ones((batch_size, seq_len, seq_len), device=device), diagonal=1)
+        .bool()
+        .view(batch_size, 1, seq_len, seq_len)
+    )
+
+    output = model(prompt_tokens, position_ids=None, attention_mask=attention_mask)
+
+    loss = output.sum()
+    loss.backward()
+
+    for name, param in model.named_parameters():
+        if "lora" in name:
+            assert param.grad is None
+        else:
+            assert param.grad is not None
+            assert torch.any(param.grad != 0), "weight gradient is all zeros"
+
+
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        LARGE_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_adapter_gradient_flow_freeze_lora_with_api(lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=partial(_test_adapter_gradient_flow_freeze_lora_with_api, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
 def _test_quantize_then_lora(lora_config, tmp_path, rank, size):
     hidden_size = 512
     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
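
The new test drives the surviving public API in two passes: first a pattern-scoped freeze that only affects LoRA adapters inside self-attention blocks, then a blanket freeze of every adapter. A condensed usage sketch of that flow follows; the import path for mtpeft is assumed from the test's alias, and model is assumed to already have LoRA adapters attached via mtpeft.update_model.

    import modelopt.torch.peft as mtpeft  # alias as used in the test file (import path assumed)

    # Pass 1: freeze only the adapters whose parameter names match the wildcard.
    # Adapters elsewhere (e.g. in MLP blocks) keep receiving gradients, which is
    # what the first set of asserts in the test checks.
    mtpeft.freeze_lora_weights(model, layer_patterns="*self_attention*")

    # Pass 2: freeze every LoRA adapter in the model; after backward(), every
    # parameter with "lora" in its name should have grad is None.
    mtpeft.freeze_lora_weights(model)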
