Commit 8c31821

Update test case
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 49a1e65 commit 8c31821

File tree

1 file changed: +146 -9 lines changed

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 146 additions & 9 deletions
@@ -3,11 +3,17 @@
 import pytest
 import torch
 from _test_utils.import_helper import skip_if_no_megatron
-from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
+from _test_utils.torch_dist.dist_utils import get_device_counts, spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
     get_mcore_gpt_model,
     initialize_for_megatron,
 )
+from megatron.core import dist_checkpointing
+
+from modelopt.torch.opt.plugins.mcore_dist_checkpointing import (
+    restore_sharded_modelopt_state,
+    save_sharded_modelopt_state,
+)

 skip_if_no_megatron()

@@ -27,6 +33,22 @@
             "lora_b_init": "zero_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
+    },
+}
+
+LARGE_LORA_CFG_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "default",
+    "adapter_cfg": {
+        "*": {
+            "rank": 128,
+            "scale": 1,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
+            "enable": True,
+        },
+        "*output_layer*": {"enable": False},
     },
 }

@@ -41,6 +63,22 @@
             "lora_b_init": "kaiming_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
+    },
+}
+
+LARGE_LORA_CFG_RANDOM_INIT_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "random",
+    "adapter_cfg": {
+        "*": {
+            "rank": 128,
+            "scale": 1,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "kaiming_init",
+            "enable": True,
+        },
+        "*output_layer*": {"enable": False},
     },
 }

@@ -55,6 +93,7 @@
             "lora_b_init": "kaiming_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
     },
 }

@@ -70,10 +109,25 @@
             "lora_b_init": "zero_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
     },
 }


+def save_distributed_checkpoint(checkpoint_path, gpt_model):
+    sharded_state_dict = gpt_model.sharded_state_dict(prefix="")
+    dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
+
+
+def load_distributed_checkpoint(checkpoint_path, gpt_model):
+    sharded_state_dict = gpt_model.sharded_state_dict(prefix="")
+    checkpoint = dist_checkpointing.load(
+        sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path
+    )
+    gpt_model.load_state_dict(checkpoint)
+    return gpt_model
+
+
 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""

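The two helpers added above are thin wrappers around Megatron-Core's dist_checkpointing save/load. Both are collective operations, so every spawned rank has to call them under the same parallel state. As a minimal sketch (not part of this commit; gpt_model and tmp_path are assumed to come from the surrounding test), a round trip looks like:

# Hypothetical round trip using the helpers above; assumes Megatron parallel state is
# already initialized on every rank and tmp_path is visible to all ranks.
def _checkpoint_roundtrip(gpt_model, tmp_path):
    save_distributed_checkpoint(tmp_path, gpt_model)  # each rank writes its shards
    return load_distributed_checkpoint(tmp_path, gpt_model)  # each rank reads them back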
@@ -157,8 +211,9 @@ def _test_forward_with_one_lora(lora_config, rank, size):
                 assert lora_config["adapter_name"] not in module._lora_adapters
             else:
                 # Task: For non-selective configs, all LoRA modules should have the adapter
-                assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
-                assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                for adapter_name in module._lora_adapters:
+                    assert hasattr(module, f"lora_a_{adapter_name}")
+                    assert hasattr(module, f"lora_b_{adapter_name}")
                 lora_with_adapter_count += 1

     assert lora_module_count > 0
@@ -216,11 +271,9 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2, rank, size):

     for _, module in model.named_modules():
         if isinstance(module, LoRAModule):
-            assert hasattr(module, f"lora_a_{lora_config_1['adapter_name']}")
-            assert hasattr(module, f"lora_b_{lora_config_1['adapter_name']}")
-            assert hasattr(module, f"lora_a_{lora_config_2['adapter_name']}")
-            assert hasattr(module, f"lora_b_{lora_config_2['adapter_name']}")
-            assert len(module._lora_adapters) == 2
+            for adapter_name in module._lora_adapters:
+                assert hasattr(module, f"lora_a_{adapter_name}")
+                assert hasattr(module, f"lora_b_{adapter_name}")


 @pytest.mark.parametrize(
@@ -237,7 +290,91 @@ def test_forward_with_two_loras(lora_config_1, lora_config_2):
     )


-# TODO: Save and restore with 1 or 2 GPUs
+# TODO: Rank check
+def _test_attr_changes_with_one_lora(lora_config, rank, size):
+    """Test forward pass with a single LoRA adapter with various configurations."""
+    hidden_size = 320
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    mtpf.update_model(model, lora_config)
+    lora_1_output = megatron_prefill(model, prompt_tokens)
+
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            for adapter_name in module._lora_adapters:
+                adapter = module._lora_adapters[adapter_name]
+                adapter["scale"] = 10.0
+
+    lora_2_output = megatron_prefill(model, prompt_tokens)
+    assert not torch.allclose(lora_1_output, lora_2_output)
+
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            for adapter_name in module._lora_adapters:
+                adapter = module._lora_adapters[adapter_name]
+                adapter["scale"] = 1.0
+    lora_back_output = megatron_prefill(model, prompt_tokens)
+
+    assert torch.allclose(lora_1_output, lora_back_output)
+
+
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_attr_changes_with_one_lora(lora_config):
+    spawn_multiprocess_job(
+        size=1, job=partial(_test_attr_changes_with_one_lora, lora_config), backend="nccl"
+    )
+
+
+def _test_mcore_save_restore(lora_config, tmp_path, rank, size):
+    hidden_size = 1280
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model_ref = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    model_test = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+    original_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    mtpf.update_model(model_ref, lora_config)
+
+    lora_output_ref = megatron_prefill(model_ref, prompt_tokens)
+
+    save_distributed_checkpoint(tmp_path, model_ref)
+    save_sharded_modelopt_state([model_ref], tmp_path)
+
+    restore_sharded_modelopt_state([model_test], tmp_path)
+    model_test = load_distributed_checkpoint(tmp_path, model_test)
+
+    lora_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    # Task: If the save and restore functions work correctly, they should produce the same output.
+    assert torch.allclose(lora_output_test, lora_output_ref)
+
+    assert not torch.allclose(original_output_test, lora_output_test)
+
+
+@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_mcore_save_restore(device_count, lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=device_count,
+        job=partial(_test_mcore_save_restore, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
 # TODO: Grad check

 # def test_edge_cases_and_error_handling():
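Taken together, the new test_mcore_save_restore exercises a save/restore flow that pairs the generic Megatron-Core distributed checkpoint with ModelOpt's sharded modelopt state. A condensed sketch of that flow, assuming placeholder model and ckpt_dir arguments (only calls that already appear in the diff are used):

# Condensed sketch of the flow exercised by test_mcore_save_restore (placeholder arguments).
from megatron.core import dist_checkpointing
from modelopt.torch.opt.plugins.mcore_dist_checkpointing import (
    restore_sharded_modelopt_state,
    save_sharded_modelopt_state,
)


def save_lora_model(model, ckpt_dir):
    # Weights go into the distributed checkpoint; the LoRA/modelopt state is saved alongside it.
    dist_checkpointing.save(
        sharded_state_dict=model.sharded_state_dict(prefix=""), checkpoint_dir=ckpt_dir
    )
    save_sharded_modelopt_state([model], ckpt_dir)


def restore_lora_model(model, ckpt_dir):
    # Restore the modelopt state first so the LoRA modules exist, then load the weights into them.
    restore_sharded_modelopt_state([model], ckpt_dir)
    state = dist_checkpointing.load(
        sharded_state_dict=model.sharded_state_dict(prefix=""), checkpoint_dir=ckpt_dir
    )
    model.load_state_dict(state)
    return model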
