
Commit fc0d6e8

updated unit tests

Signed-off-by: Suguna Velury <[email protected]>

1 parent 11d4cf4 commit fc0d6e8

File tree

3 files changed (+27, -19 lines)


modelopt/torch/quantization/utils.py

Lines changed: 9 additions & 3 deletions
@@ -712,9 +712,15 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
     # Assert that all the modules in the module list are present in this fsdp_param_group
     if len(modules_to_update) > 1:
         for module in modules_to_update:
-            name = _get_module_name(module, root_model)
-            assert name in fsdp_param_mapping, (
-                f"Module {module} not found in fsdp_param_mapping"
+            module_name = _get_module_name(module, root_model)
+            # Check if any parameter from this module is in the mapping
+            module_params_in_mapping = any(
+                f"{module_name}.{n}" in fsdp_param_mapping
+                for n, _ in module.named_parameters()
+            )
+            assert module_params_in_mapping, (
+                f"Module {module} with name '{module_name}' not found in fsdp_param_mapping. "
+                f"Available keys: {list(fsdp_param_mapping.keys())}"
             )
     # Yields for necessary weight updates/processing
     yield
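
The updated assertion reflects that fsdp_param_mapping is keyed by fully qualified parameter names rather than bare module names, so the old membership test could never succeed. A minimal sketch of the distinction, using hypothetical mapping contents:

# Minimal sketch, assuming fsdp_param_mapping is keyed by parameter
# names such as "q_proj.weight"; the contents here are hypothetical.
import torch

module_name = "q_proj"
module = torch.nn.Linear(4, 4, bias=True)
fsdp_param_mapping = {"q_proj.weight": None, "q_proj.bias": None}

# Old check: a bare module name never matches a parameter-keyed mapping.
assert module_name not in fsdp_param_mapping

# New check: passes as soon as any of the module's parameters is tracked.
assert any(
    f"{module_name}.{n}" in fsdp_param_mapping
    for n, _ in module.named_parameters()
)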

tests/_test_utils/torch/export/utils.py

Lines changed: 6 additions & 6 deletions
@@ -55,18 +55,18 @@ def forward(self, x):
 
 
 class SmallQKVModel(torch.nn.Module):
-    def __init__(self, dim=4, device="cuda", apply_embed=False):
+    def __init__(self, dim=4, device="cuda", apply_embed=False, bias=False):
         super().__init__()
         self.embedding = torch.nn.Embedding(2, dim)
-        self.q_proj = torch.nn.Linear(dim, dim, bias=False)
-        self.k_proj = torch.nn.Linear(dim, dim, bias=False)
-        self.v_proj = torch.nn.Linear(dim, dim, bias=False)
-        self.o_proj = torch.nn.Linear(dim, dim, bias=False)
+        self.q_proj = torch.nn.Linear(dim, dim, bias=bias)
+        self.k_proj = torch.nn.Linear(dim, dim, bias=bias)
+        self.v_proj = torch.nn.Linear(dim, dim, bias=bias)
+        self.o_proj = torch.nn.Linear(dim, dim, bias=bias)
         self.device = device
         self.config = None
         self.apply_embed = apply_embed
         # TODO: Debug why fsdp2 modifies bias of layernorm for awq
-        self.input_layernorm = torch.nn.LayerNorm(dim, bias=False)
+        self.input_layernorm = torch.nn.LayerNorm(dim, bias=bias)
 
     def forward(self, x):
         if self.apply_embed:
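
With the new flag, the four projections and the layernorm toggle their bias together, and the default preserves the old bias-free behavior. A usage sketch, assuming SmallQKVModel is imported from this test utils module and a PyTorch version where LayerNorm accepts a bias argument:

# Usage sketch; SmallQKVModel comes from tests/_test_utils/torch/export/utils.py.
model = SmallQKVModel(dim=32, bias=True)
assert model.q_proj.bias is not None
assert model.input_layernorm.bias is not None

default_model = SmallQKVModel(dim=32)  # bias defaults to False, as before
assert default_model.q_proj.bias is None
assert default_model.input_layernorm.bias is None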

tests/gpu/torch/export/test_fsdp2_export.py

Lines changed: 12 additions & 10 deletions
@@ -118,11 +118,11 @@ def _compare_parameters_and_buffers(model1, model2):
     )
 
 
-def _fuse_layers(rank, size, quant_config):
+def _fuse_layers(rank, size, quant_config, bias):
     with patch_fsdp_mp_dtypes():
         # Initialize model
-        model = SmallQKVModel(dim=32).to("cuda")
-        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        model = SmallQKVModel(dim=32, bias=bias).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
         non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
         model.eval()
         non_fsdp_model.eval()
@@ -159,15 +159,15 @@ def calib_fn(x):
         _compare_parameters_and_buffers(model, non_fsdp_model)
 
 
-def _export_quantized_weight_test(rank, size, quant_config):
+def _export_quantized_weight_test(rank, size, quant_config, bias):
     import copy
 
     from torch.distributed._composable.fsdp import fully_shard
 
     with patch_fsdp_mp_dtypes():
         # Initialize model
-        model = SmallQKVModel(dim=32).to("cuda")
-        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        model = SmallQKVModel(dim=32, bias=bias).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
         non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
         model.eval()
         non_fsdp_model.eval()
@@ -247,10 +247,11 @@ def test_fsdp2_weight_update_context_for_export(device_count):
     ],
 )
 @pytest.mark.parametrize("device_count", get_device_counts())
-def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config):
+@pytest.mark.parametrize("bias", [True, False])
+def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config, bias):
     spawn_multiprocess_job(
         size=device_count,
-        job=partial(_fuse_layers, quant_config=quant_config),
+        job=partial(_fuse_layers, quant_config=quant_config, bias=bias),
         backend="nccl",
     )
@@ -270,9 +271,10 @@ def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config)
     ],
 )
 @pytest.mark.parametrize("device_count", get_device_counts())
-def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config):
+@pytest.mark.parametrize("bias", [True, False])
+def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config, bias):
     spawn_multiprocess_job(
         size=device_count,
-        job=partial(_export_quantized_weight_test, quant_config=quant_config),
+        job=partial(_export_quantized_weight_test, quant_config=quant_config, bias=bias),
         backend="nccl",
     )
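
The stacked parametrize decorators multiply, so each test now runs for every combination of quant_config, bias, and device_count, with the extra bias kwarg forwarded into the spawned workers via functools.partial. A standalone sketch of that pattern, with a hypothetical worker standing in for the real jobs:

# Standalone sketch of the parametrize + partial pattern used above;
# _worker is a hypothetical stand-in for _fuse_layers and friends.
from functools import partial

import pytest


def _worker(rank, size, quant_config, bias):
    assert isinstance(bias, bool)  # the extra kwarg arrives via partial


@pytest.mark.parametrize("bias", [True, False])
def test_spawn_pattern(bias):
    # spawn_multiprocess_job would call this once per rank over NCCL;
    # here it is invoked directly for a single rank.
    job = partial(_worker, quant_config={}, bias=bias)
    job(rank=0, size=1)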
