
Commit 2b09c92

Update test cases
Signed-off-by: Jingyu Xin <[email protected]>
1 parent bf33ae4 commit 2b09c92

File tree

1 file changed: +78 -3 lines changed


tests/gpu/torch/peft/test_forward_megatron.py

Lines changed: 78 additions & 3 deletions
@@ -464,18 +464,91 @@ def forward_func(mod):
         print(
             f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
         )
+        print(model)
         print("Test passed!")
     finally:
         # Clean up model parallel groups
         parallel_state.destroy_model_parallel()
 
 
 def _test_quantize_then_lora_save_restore():
-    pass
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model_ref = _gpt_model_provider(tp_size=1)
+        model_test = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(
+            0, model_test.vocab_size, (2, model_test.max_sequence_length)
+        ).cuda()
+
+        def forward_func(mod):
+            output = megatron_prefill(model_ref, prompt_tokens)
+
+        mtq.quantize(model_ref, mtq.FP8_DEFAULT_CFG, forward_func)
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+        model_ref = mtp.update_model(model_ref, lora_config)
+        tmp_path = "./model_ref"
+        save_distributed_checkpoint(tmp_path, model_ref)
+        save_sharded_modelopt_state([model_ref], tmp_path)
+        restore_sharded_modelopt_state([model_test], tmp_path)
+        model_test = load_distributed_checkpoint(tmp_path, model_test)
+        # Run forward pass
+        output_test = megatron_prefill(model_test, prompt_tokens)
+        output_ref = megatron_prefill(model_ref, prompt_tokens)
+        print(
+            f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}"
+        )
+        print(model_ref)
+        print(f"output_test: {output_test}")
+        print(f"output_ref: {output_ref}")
+
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()
 
 
 def _test_lora_then_quantize():
-    pass
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+
+        def forward_func(mod):
+            output = megatron_prefill(model, prompt_tokens)
+
+        model = mtp.update_model(model, lora_config)
+        mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func)
+        lora_count = 0
+        for name, module in model.named_modules():
+            if hasattr(module, "_lora_adapters"):
+                lora_count += 1
+                print(f"LoRA module found: {name}")
+        print(f"\nTotal LoRA modules: {lora_count}")
+        output_lora_quant = megatron_prefill(model, prompt_tokens)
+        print(
+            f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
+        )
+        print("Test passed!")
+        print(model)
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()
 
 
 def _test_lora_then_quantize_save_restore():
@@ -520,7 +593,9 @@ def main():
         # _test_lora_save_and_restore()
         # _test_lora_add_2nd_lora()
         # _test_lora_save_and_restore_with2loras()
-        _test_quantize_then_lora()
+        # _test_quantize_then_lora()
+        # _test_quantize_then_lora_save_restore()
+        _test_lora_then_quantize()
     finally:
         if torch.distributed.is_initialized():
             torch.distributed.destroy_process_group()
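
A possible follow-up, not part of this commit: _test_quantize_then_lora_save_restore currently only prints output_test and output_ref. If the save/restore path is expected to be lossless, that comparison could be made an explicit assertion. A minimal sketch, assuming megatron_prefill returns plain CUDA tensors of matching shape; the helper name and tolerances are illustrative, not taken from this repository:

    import torch

    def assert_outputs_close(output_ref, output_test, rtol=1e-5, atol=1e-5):
        # Hypothetical check for _test_quantize_then_lora_save_restore: both tensors
        # come from megatron_prefill on the same prompt_tokens, so they should match
        # if the sharded modelopt state and distributed checkpoint restore losslessly.
        assert output_ref.shape == output_test.shape, "restored model changed the output shape"
        assert torch.allclose(output_ref, output_test, rtol=rtol, atol=atol), (
            "restored LoRA + quantized model does not reproduce the reference output"
        )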
