Commit 48e9ab5

Update comments for test cases
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 8df12bc commit 48e9ab5

File tree: 1 file changed (+12 −17 lines)

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 12 additions & 17 deletions
@@ -58,20 +58,6 @@
     },
 }
 
-LARGE_SCALE_LORA_CFG = {
-    "adapter_type": "lora",
-    "adapter_name": "large_scale",
-    "adapter_cfg": {
-        "*": {
-            "rank": 16,
-            "scale": 10.0,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
-            "enable": True,
-        },
-    },
-}
-
 SELECTIVE_LAYER_LORA_CFG = {
     "adapter_type": "lora",
     "adapter_name": "selective",
@@ -130,16 +116,25 @@ def _test_forward_with_one_lora(lora_config, rank, size):
     lora_output = megatron_prefill(model, prompt_tokens)
     assert lora_output.shape == original_output.shape
     if lora_config == DEFAULT_LORA_CFG_RANDOM_INIT_TEST:
+        # Task: Verify that the LoRA output differs from the original
+        # output, since both LoRA layers are initialized randomly.
         assert not torch.allclose(lora_output, original_output, rtol=1e-5)
     else:
+        # Task: The LoRA output should match the original output when the two
+        # LoRA layers are initialized in the standard way
+        # (one with random values and one with zeros).
         assert torch.allclose(lora_output, original_output, rtol=1e-5), (
             f"{lora_output}, {original_output}"
         )
     mtpf.disable_adapters(model)
     lora_disabled_output = megatron_prefill(model, prompt_tokens)
+    # Task: With all LoRA layers disabled, the output should be
+    # identical to the original output.
     assert torch.allclose(lora_disabled_output, original_output, rtol=1e-5)
     mtpf.enable_adapters(model)
     lora_reenabled_output = megatron_prefill(model, prompt_tokens)
+    # Task: Toggling LoRA layers from disabled back to enabled should
+    # leave the output unchanged relative to the earlier LoRA output.
     assert torch.allclose(lora_reenabled_output, lora_output, rtol=1e-5)
     lora_module_count = 0
     lora_with_adapter_count = 0
@@ -149,19 +144,19 @@ def _test_forward_with_one_lora(lora_config, rank, size):
 
             if lora_config == SELECTIVE_LAYER_LORA_CFG:
                 if "self_attention" in name:
-                    # Only self_attention modules should have the adapter
+                    # Task: Only self_attention modules should have the adapter
                     assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
                     assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
                     assert lora_config["adapter_name"] in module._lora_adapters
                     assert module._lora_adapters[lora_config["adapter_name"]]["enable"]
                     lora_with_adapter_count += 1
                 else:
-                    # Other modules should NOT have the adapter at all
+                    # Task: Other modules should NOT have the adapter at all
                     assert not hasattr(module, f"lora_a_{lora_config['adapter_name']}")
                     assert not hasattr(module, f"lora_b_{lora_config['adapter_name']}")
                     assert lora_config["adapter_name"] not in module._lora_adapters
             else:
-                # For non-selective configs, all LoRA modules should have the adapter
+                # Task: For non-selective configs, all LoRA modules should have the adapter
                 assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
                 assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
                 lora_with_adapter_count += 1
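
The allclose assertions in the hunks above rest on the standard LoRA identity: with the B matrix initialized to zeros, the adapted layer computes W x + B(A x) = W x, so the output matches the base model until B becomes non-zero. Below is a minimal standalone sketch of that property in plain PyTorch; the shapes and variable names are illustrative only and do not use the Megatron/modelopt harness the test relies on.

import torch

torch.manual_seed(0)

d_in, d_out, r = 8, 8, 4        # illustrative sizes
x = torch.randn(2, d_in)        # dummy input batch

W = torch.randn(d_out, d_in)    # frozen base weight
lora_a = torch.randn(r, d_in)   # random init, standing in for "kaiming_init"
lora_b = torch.zeros(d_out, r)  # standard "zero_init" for the B matrix

base_out = x @ W.T
lora_out = base_out + x @ lora_a.T @ lora_b.T

# B == 0 makes the adapter a no-op, mirroring the
# torch.allclose(lora_output, original_output) assertion above.
assert torch.allclose(lora_out, base_out, rtol=1e-5)

# With a randomly initialized B (the DEFAULT_LORA_CFG_RANDOM_INIT_TEST case),
# the outputs must differ, mirroring the "assert not torch.allclose" branch.
lora_b_random = torch.randn(d_out, r)
lora_out_random = base_out + x @ lora_a.T @ lora_b_random.T
assert not torch.allclose(lora_out_random, base_out, rtol=1e-5)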
