@@ -3,11 +3,17 @@
 import pytest
 import torch
 from _test_utils.import_helper import skip_if_no_megatron
-from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
+from _test_utils.torch_dist.dist_utils import get_device_counts, spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
     get_mcore_gpt_model,
     initialize_for_megatron,
 )
+from megatron.core import dist_checkpointing
+
+from modelopt.torch.opt.plugins.mcore_dist_checkpointing import (
+    restore_sharded_modelopt_state,
+    save_sharded_modelopt_state,
+)

 skip_if_no_megatron()

@@ -27,6 +33,22 @@
             "lora_b_init": "zero_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
+    },
+}
+
+LARGE_LORA_CFG_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "default",
+    "adapter_cfg": {
+        "*": {
+            "rank": 128,
+            "scale": 1,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
+            "enable": True,
+        },
+        "*output_layer*": {"enable": False},
     },
 }

@@ -41,6 +63,22 @@
             "lora_b_init": "kaiming_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
+    },
+}
+
+LARGE_LORA_CFG_RANDOM_INIT_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "random",
+    "adapter_cfg": {
+        "*": {
+            "rank": 128,
+            "scale": 1,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "kaiming_init",
+            "enable": True,
+        },
+        "*output_layer*": {"enable": False},
     },
 }

@@ -55,6 +93,7 @@
             "lora_b_init": "kaiming_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
     },
 }

@@ -70,10 +109,25 @@
             "lora_b_init": "zero_init",
             "enable": True,
         },
+        "*output_layer*": {"enable": False},
     },
 }


+def save_distributed_checkpoint(checkpoint_path, gpt_model):
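+    """Save gpt_model's sharded state dict to checkpoint_path via Megatron dist_checkpointing."""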
+    sharded_state_dict = gpt_model.sharded_state_dict(prefix="")
+    dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
+
+
+def load_distributed_checkpoint(checkpoint_path, gpt_model):
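+    """Load the sharded checkpoint at checkpoint_path into gpt_model and return the model."""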
+    sharded_state_dict = gpt_model.sharded_state_dict(prefix="")
+    checkpoint = dist_checkpointing.load(
+        sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path
+    )
+    gpt_model.load_state_dict(checkpoint)
+    return gpt_model
+
+
 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""

@@ -157,8 +211,9 @@ def _test_forward_with_one_lora(lora_config, rank, size):
                 assert lora_config["adapter_name"] not in module._lora_adapters
             else:
                 # Task: For non-selective configs, all LoRA modules should have the adapter
-                assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
-                assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                for adapter_name in module._lora_adapters:
+                    assert hasattr(module, f"lora_a_{adapter_name}")
+                    assert hasattr(module, f"lora_b_{adapter_name}")
                 lora_with_adapter_count += 1

     assert lora_module_count > 0
@@ -216,11 +271,9 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2, rank, size):

     for _, module in model.named_modules():
         if isinstance(module, LoRAModule):
-            assert hasattr(module, f"lora_a_{lora_config_1['adapter_name']}")
-            assert hasattr(module, f"lora_b_{lora_config_1['adapter_name']}")
-            assert hasattr(module, f"lora_a_{lora_config_2['adapter_name']}")
-            assert hasattr(module, f"lora_b_{lora_config_2['adapter_name']}")
-            assert len(module._lora_adapters) == 2
+            for adapter_name in module._lora_adapters:
+                assert hasattr(module, f"lora_a_{adapter_name}")
+                assert hasattr(module, f"lora_b_{adapter_name}")


 @pytest.mark.parametrize(
@@ -237,7 +290,91 @@ def test_forward_with_two_loras(lora_config_1, lora_config_2):
     )


-# TODO: Save and restore with 1 or 2 GPUs
+# TODO: Rank check
+def _test_attr_changes_with_one_lora(lora_config, rank, size):
+    """Test that changing an adapter's scale changes the forward output and is reversible."""
+    hidden_size = 320
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    mtpf.update_model(model, lora_config)
+    lora_1_output = megatron_prefill(model, prompt_tokens)
+
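+    # Raise the scale on every adapter; the LoRA contribution should change the output.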
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            for adapter_name in module._lora_adapters:
+                adapter = module._lora_adapters[adapter_name]
+                adapter["scale"] = 10.0
+
+    lora_2_output = megatron_prefill(model, prompt_tokens)
+    assert not torch.allclose(lora_1_output, lora_2_output)
+
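+    # Restore the original scale; the output should match the first LoRA forward pass.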
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            for adapter_name in module._lora_adapters:
+                adapter = module._lora_adapters[adapter_name]
+                adapter["scale"] = 1.0
+    lora_back_output = megatron_prefill(model, prompt_tokens)
+
+    assert torch.allclose(lora_1_output, lora_back_output)
+
+
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_attr_changes_with_one_lora(lora_config):
+    spawn_multiprocess_job(
+        size=1, job=partial(_test_attr_changes_with_one_lora, lora_config), backend="nccl"
+    )
+
+
+def _test_mcore_save_restore(lora_config, tmp_path, rank, size):
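+    """Round-trip a LoRA-adapted model through distributed checkpointing into a fresh model."""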
+    hidden_size = 1280
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model_ref = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    model_test = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+    original_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    mtpf.update_model(model_ref, lora_config)
+
+    lora_output_ref = megatron_prefill(model_ref, prompt_tokens)
+
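+    # Save both the sharded model weights and the modelopt (LoRA) state for model_ref.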
+    save_distributed_checkpoint(tmp_path, model_ref)
+    save_sharded_modelopt_state([model_ref], tmp_path)
+
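+    # Restore the modelopt state first so the LoRA modules exist, then load the sharded weights.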
+    restore_sharded_modelopt_state([model_test], tmp_path)
+    model_test = load_distributed_checkpoint(tmp_path, model_test)
+
+    lora_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    # Task: If the save and restore functions work correctly, they should produce the same output.
+    assert torch.allclose(lora_output_test, lora_output_ref)
+
+    assert not torch.allclose(original_output_test, lora_output_test)
+
+
+@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_mcore_save_restore(device_count, lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=device_count,
+        job=partial(_test_mcore_save_restore, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
 # TODO: Grad check

 # def test_edge_cases_and_error_handling():