    },
}

-LARGE_SCALE_LORA_CFG = {
-    "adapter_type": "lora",
-    "adapter_name": "large_scale",
-    "adapter_cfg": {
-        "*": {
-            "rank": 16,
-            "scale": 10.0,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
-            "enable": True,
-        },
-    },
-}
-
SELECTIVE_LAYER_LORA_CFG = {
    "adapter_type": "lora",
    "adapter_name": "selective",
@@ -130,16 +116,25 @@ def _test_forward_with_one_lora(lora_config, rank, size):
    lora_output = megatron_prefill(model, prompt_tokens)
    assert lora_output.shape == original_output.shape
    if lora_config == DEFAULT_LORA_CFG_RANDOM_INIT_TEST:
+        # Task: Verify that the LoRA output differs from the original
+        # output, since both LoRA layers are initialized randomly.
        assert not torch.allclose(lora_output, original_output, rtol=1e-5)
    else:
+        # Task: The LoRA output should match the original output when the two
+        # LoRA layers are initialized in the standard way
+        # (one with random values and one with zeros).
        assert torch.allclose(lora_output, original_output, rtol=1e-5), (
            f"{lora_output}, {original_output}"
        )
    mtpf.disable_adapters(model)
    lora_disabled_output = megatron_prefill(model, prompt_tokens)
+    # Task: With all LoRA layers disabled, the output should be
+    # identical to the original output.
    assert torch.allclose(lora_disabled_output, original_output, rtol=1e-5)
    mtpf.enable_adapters(model)
    lora_reenabled_output = megatron_prefill(model, prompt_tokens)
+    # Task: Verify that toggling the LoRA layers from disabled back to
+    # enabled leaves the output identical to the earlier LoRA output.
    assert torch.allclose(lora_reenabled_output, lora_output, rtol=1e-5)
    lora_module_count = 0
    lora_with_adapter_count = 0
@@ -149,19 +144,19 @@ def _test_forward_with_one_lora(lora_config, rank, size):

        if lora_config == SELECTIVE_LAYER_LORA_CFG:
            if "self_attention" in name:
-                # Only self_attention modules should have the adapter
+                # Task: Only self_attention modules should have the adapter
                assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
                assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
                assert lora_config["adapter_name"] in module._lora_adapters
                assert module._lora_adapters[lora_config["adapter_name"]]["enable"]
                lora_with_adapter_count += 1
            else:
-                # Other modules should NOT have the adapter at all
+                # Task: Other modules should NOT have the adapter at all
                assert not hasattr(module, f"lora_a_{lora_config['adapter_name']}")
                assert not hasattr(module, f"lora_b_{lora_config['adapter_name']}")
                assert lora_config["adapter_name"] not in module._lora_adapters
        else:
-            # For non-selective configs, all LoRA modules should have the adapter
+            # Task: For non-selective configs, all LoRA modules should have the adapter
            assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
            assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
            lora_with_adapter_count += 1
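
For reference, a minimal hypothetical sketch of a selective-layer LoRA config using the same schema as the removed LARGE_SCALE_LORA_CFG above. The EXAMPLE_SELECTIVE_LORA_CFG name, the "*self_attention*" wildcard pattern, and the numeric values are illustrative assumptions only; they are not taken from this PR, and the actual SELECTIVE_LAYER_LORA_CFG body is truncated in the diff.

# Hypothetical sketch (not from the PR): a config that applies LoRA only to
# modules whose names match a self-attention wildcard, following the schema
# shown by LARGE_SCALE_LORA_CFG above. Pattern and values are assumptions.
EXAMPLE_SELECTIVE_LORA_CFG = {
    "adapter_type": "lora",
    "adapter_name": "selective_example",
    "adapter_cfg": {
        # Wildcard key: only modules with "self_attention" in their name get
        # lora_a_*/lora_b_* weights, which is what the test above asserts.
        "*self_attention*": {
            "rank": 16,
            "scale": 1.0,
            "lora_a_init": "kaiming_init",
            "lora_b_init": "zero_init",
            "enable": True,
        },
    },
}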