
Commit 9021845

export working, cleanup needed
Signed-off-by: Suguna Velury <[email protected]>
1 parent d48aaf2 commit 9021845

File tree: 6 files changed, +435 -107 lines


modelopt/torch/export/layer_utils.py

Lines changed: 3 additions & 1 deletion
@@ -345,7 +345,9 @@ def is_moe(module: nn.Module) -> bool:
 
 def is_quantlinear(module: nn.Module) -> bool:
     """Returns whether the module is a quantized linear layer."""
-    return "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower()
+    return (
+        "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower()
+    ) or ("Quant" in type(module).__name__ and "Linear" in type(module).__name__)
 
 
 def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
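
Note: the widened predicate now also accepts quantized linear classes whose names contain "Quant" and "Linear" without the literal "QuantLinear" substring. A minimal sketch with stand-in class names (hypothetical classes, not real modelopt types; the import path follows the file path above):

import torch.nn as nn

from modelopt.torch.export.layer_utils import is_quantlinear  # path assumed from the file above


class FP8QuantLinear(nn.Linear):  # hypothetical: matched by the original "QuantLinear" clause
    pass


class QuantColumnParallelLinear(nn.Linear):  # hypothetical: matched only by the new clause
    pass


print(is_quantlinear(FP8QuantLinear(4, 4)))             # True
print(is_quantlinear(QuantColumnParallelLinear(4, 4)))  # True (newly covered)
print(is_quantlinear(nn.Linear(4, 4)))                  # False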

modelopt/torch/export/unified_export_hf.py

Lines changed: 6 additions & 6 deletions
@@ -150,14 +150,11 @@ def _output_hook(module, input, output):
             # For encoder-decoder models, we need to pass both the encoder and decoder input ids
             model(fake_input, decoder_input_ids=decoder_fake_input)
         else:
-            print("DEBUG LOG: Calling model(fake_input)")
             model(fake_input)
 
     for handle in handles:
         handle.remove()
 
-    print(f"DEBUG LOG: input_to_linear: {input_to_linear}")
-
     for tensor, modules in input_to_linear.items():
         quantization_format = get_quantization_format(modules[0])
         if len(modules) > 1 and quantization_format not in [
@@ -177,7 +174,8 @@ def _output_hook(module, input, output):
             and tensor in output_to_layernorm
         ):
             # Pre quant scale of modules is already updated to avg_pre_quant_scale
-            fuse_prequant_layernorm(output_to_layernorm[tensor], modules)
+            with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]):
+                fuse_prequant_layernorm(output_to_layernorm[tensor], modules)
 
     # The dummy forward may not be able to activate all the experts.
     # Process experts by naming rules like experts.0, experts.1, etc.
@@ -470,7 +468,8 @@ def _export_hf_checkpoint(
         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
             has_quantized_layers = True
         if is_quantlinear(sub_module):
-            _export_quantized_weight(sub_module, dtype)
+            with fsdp2_aware_weight_update(model, sub_module):
+                _export_quantized_weight(sub_module, dtype)
         elif (
             "Llama4TextExperts" in type(sub_module).__name__
             or "GptOssExperts" in type(sub_module).__name__
@@ -488,7 +487,8 @@
             )
             # Export the quantized weights
             for weight_name in ["gate_up_proj", "down_proj"]:
-                _export_quantized_weight(sub_module, dtype, weight_name)
+                with fsdp2_aware_weight_update(model, sub_module):
+                    _export_quantized_weight(sub_module, dtype, weight_name)
 
     quantized_state_dict = model.state_dict()
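
Taken together, these hunks route every in-place weight rewrite during export through the new context manager. A condensed caller-side sketch of the pattern (helper names follow the diff above; the import paths are assumed from the touched files, and export_fn stands in for the private _export_quantized_weight helper):

import torch

from modelopt.torch.export.layer_utils import is_quantlinear
from modelopt.torch.quantization.qtensor.base_qtensor import fsdp2_aware_weight_update


def export_quantized_weights(model: torch.nn.Module, dtype: torch.dtype, export_fn) -> None:
    """Sketch only: export_fn is a stand-in for _export_quantized_weight."""
    for _, sub_module in model.named_modules():
        if is_quantlinear(sub_module):
            # Unshard (if FSDP2-sharded), rewrite the weight, then rebuild the
            # FSDPParam bookkeeping and reshard on exit.
            with fsdp2_aware_weight_update(model, sub_module):
                export_fn(sub_module, dtype)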

modelopt/torch/quantization/qtensor/base_qtensor.py

Lines changed: 74 additions & 93 deletions
@@ -274,76 +274,88 @@ def fsdp2_aware_weight_update(root_model, modules_to_update):
 
         from modelopt.torch.quantization.utils import _get_enclosing_fsdp_module, _get_module_name
 
-        breakpoint()
-        # Get FSDP root module, if none is returned, then the update is not made to a submodule of an FSDPModule
-        if not isinstance(modules_to_update, list):
-            modules_to_update = [modules_to_update]
-
-        root_modules = set()
-        for module in modules_to_update:
-            root_module = _get_enclosing_fsdp_module(module, root_model)
-            root_modules.add(root_module)
-
-        # Ensure all modules in root_modules are the same
-        assert len(root_modules) == 1, "All modules must be in the same root FSDPModule"
-        root_module = next(iter(root_modules))
-
-        # Check if root module state is sharded and unshard if needed
-        if fully_shard.state(root_module)._fsdp_param_group.is_sharded:
-            with enable_fake_quant(root_module):
-                root_module.unshard()
-
-        # Get FSDPParam list
-        fsdp_param_group = fully_shard.state(root_module)._fsdp_param_group
-        fsdp_param_mapping = _create_fsdp_param_mapping(fsdp_param_group.fsdp_params, root_module)
-
-        # Assert that all the modules in the module list are present in this fsdp_param_group
-        for module in modules_to_update:
-            name = _get_module_name(module, root_module)
-            assert name in fsdp_param_mapping, f"Module {module} not found in fsdp_param_mapping"
+        if isinstance(root_model, FSDPModule):
+            # Get FSDP root module, if none is returned, then the update is not made to a submodule of an FSDPModule
+            if not isinstance(modules_to_update, list):
+                modules_to_update = [modules_to_update]
+
+            root_modules = set()
+            for module in modules_to_update:
+                root_module = _get_enclosing_fsdp_module(module, root_model)
+                root_modules.add(root_module)
+
+            # Ensure all modules in root_modules are the same
+            assert len(root_modules) == 1, "All modules must be in the same root FSDPModule"
+            root_module = next(iter(root_modules))
+
+            # Check if root module state is sharded and unshard if needed
+            if fully_shard.state(root_module)._fsdp_param_group.is_sharded:
+                with enable_fake_quant(root_module):
+                    root_module.unshard()
+
+            # Get FSDPParam list
+            fsdp_param_group = fully_shard.state(root_module)._fsdp_param_group
+            fsdp_param_mapping = _create_fsdp_param_mapping(
+                fsdp_param_group.fsdp_params, root_model
+            )
 
+            # Assert that all the modules in the module list are present in this fsdp_param_group
+            for module in modules_to_update:
+                name = _get_module_name(module, root_model)
+                assert name in fsdp_param_mapping, (
+                    f"Module {module} not found in fsdp_param_mapping"
+                )
         # Yields for necessary weight updates/processing
         yield
     finally:
-        # Update FSDPParam list
-        for module in modules_to_update:
-            name = _get_module_name(module, root_module)
-            old_fsdp_param = fsdp_param_mapping[name]
-
-            # Update mp policy to reflect the new dtype
-            new_mp_policy = MixedPrecisionPolicy(
-                param_dtype=module.weight.dtype,
-                reduce_dtype=None,
-                output_dtype=None,
-                cast_forward_inputs=False,
-            )
+        from torch.distributed.fsdp import fully_shard
 
-            with no_requires_grad():
-                # Create a new QFSDPParam or FSDPParam based on weight type
-                param_class = QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
-                new_param = param_class(
-                    module.weight,
-                    old_fsdp_param._module_info,
-                    old_fsdp_param.mesh_info,
-                    old_fsdp_param.post_forward_mesh_info,
-                    old_fsdp_param.device,
-                    None,
-                    new_mp_policy,
-                    None,
+        from modelopt.torch.quantization.utils import _get_enclosing_fsdp_module, _get_module_name
+
+        if isinstance(root_model, FSDPModule):
+            # Update FSDPParam list
+            for module in modules_to_update:
+                name = _get_module_name(module, root_model)
+                old_fsdp_param = fsdp_param_mapping[name]
+
+                # Update mp policy to reflect the new dtype
+                new_mp_policy = MixedPrecisionPolicy(
+                    param_dtype=module.weight.dtype,
+                    reduce_dtype=None,
+                    output_dtype=None,
+                    cast_forward_inputs=False,
                 )
 
-            # Update the FSDPParam mapping to keep track of the new FSDPParam
-            fsdp_param_mapping[name] = new_param
+                with no_requires_grad():
+                    # Create a new QFSDPParam or FSDPParam based on weight type
+                    param_class = (
+                        QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
+                    )
+                    new_param = param_class(
+                        module.weight,
+                        old_fsdp_param._module_info,
+                        old_fsdp_param.mesh_info,
+                        old_fsdp_param.post_forward_mesh_info,
+                        old_fsdp_param.device,
+                        None,
+                        new_mp_policy,
+                        None,
+                    )
+                    if not isinstance(new_param, QFSDPParam):
+                        new_param.init_dtype_attrs(new_mp_policy)
+
+                # Update the FSDPParam mapping to keep track of the new FSDPParam
+                fsdp_param_mapping[name] = new_param
 
-            # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
-            old_fsdp_param._post_load_hook_handle.remove()
+                # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
+                old_fsdp_param._post_load_hook_handle.remove()
 
-        # Update FSDPParam list with new compressed weights
-        fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
+            # Update FSDPParam list with new compressed weights
+            fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
 
-        # Reshard FSDP root module
-        # TODO: Check if reshard is needed or not
-        root_module.reshard()
+            # Reshard FSDP root module
+            # TODO: Check if reshard is needed or not
+            root_module.reshard()
 
 
 def pack_real_quantize_weight(module, force_quantize: bool = False):
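
One behavioral consequence of the new isinstance(root_model, FSDPModule) guards is that the context manager degrades to a plain passthrough for non-FSDP models, so the same export code works with and without FSDP2 sharding. A minimal single-process sketch of that no-op path (import path assumed from this file; no distributed setup required):

import torch

from modelopt.torch.quantization.qtensor.base_qtensor import fsdp2_aware_weight_update

# A plain module is not an FSDPModule, so both the setup and the finally block
# above are skipped and the context manager simply yields.
model = torch.nn.Linear(8, 8, bias=False)
with fsdp2_aware_weight_update(model, model):
    model.weight.data = model.weight.data.to(torch.bfloat16)  # arbitrary in-place rewrite
print(model.weight.dtype)  # torch.bfloat16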
@@ -422,39 +434,8 @@ def _compress_fsdp_module(fsdp_module):
             if name not in fsdp_param_mapping:
                 continue
 
-            if _compress_and_update_module_weight(submodule):
-                old_fsdp_param = fsdp_param_mapping[name]
-
-                # Update mp policy to reflect the new dtype
-                new_mp_policy = MixedPrecisionPolicy(
-                    param_dtype=submodule.weight.dtype,
-                    reduce_dtype=None,
-                    output_dtype=None,
-                    cast_forward_inputs=False,
-                )
-                with no_requires_grad():
-                    # Create a new QFSDPParam parameter
-                    new_param = QFSDPParam(
-                        submodule.weight,
-                        old_fsdp_param._module_info,
-                        old_fsdp_param.mesh_info,
-                        old_fsdp_param.post_forward_mesh_info,
-                        old_fsdp_param.device,
-                        None,
-                        new_mp_policy,
-                        None,
-                    )
-
-                # Update the FSDPParam mapping to keep track of the new FSDPParam
-                fsdp_param_mapping[name] = new_param
-                # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
-                old_fsdp_param._post_load_hook_handle.remove()
-
-        # Update FSDPParam list with new compressed weights
-        fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
-
-        # Reshard FSDP root module
-        fsdp_module.reshard()
+            with fsdp2_aware_weight_update(fsdp_module, submodule):
+                _compress_and_update_module_weight(submodule)
 
     with SequentialQuantizer.convert_to_single_quantizer(module), torch.no_grad():
         for _, m in module.named_modules():

tests/_test_utils/torch_export/export_utils.py

Lines changed: 35 additions & 4 deletions
@@ -18,20 +18,22 @@
 
 # Models
 class ToyModel(torch.nn.Module):
-    def __init__(self, dims=[10, 10, 10, 10]):
+    def __init__(self, dims=[10, 10, 10, 10], bias=True):
         super().__init__()
         assert len(dims) >= 2
         if len(dims) == 2:
-            self.linears = torch.nn.Linear(dims[0], dims[1])
+            self.linears = torch.nn.Linear(dims[0], dims[1], bias=bias)
         else:
-            linears = [torch.nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
+            linears = [
+                torch.nn.Linear(dims[i], dims[i + 1], bias=bias) for i in range(len(dims) - 1)
+            ]
             self.linears = torch.nn.Sequential(*linears)
 
     def forward(self, x):
         return self.linears(x)
 
 
-class SmallQKVModel(torch.nn.Module):
+class SmallLinearModelwithCustomWeight(torch.nn.Module):
     def __init__(self, weights):
         super().__init__()
         self.q_proj = torch.nn.Linear(weights[0].shape[1], weights[0].shape[0], bias=False)
@@ -52,6 +54,35 @@ def forward(self, x):
         return x
 
 
+class SmallQKVModel(torch.nn.Module):
+    def __init__(self, dim=4, device="cuda", apply_embed=False):
+        super().__init__()
+        self.embedding = torch.nn.Embedding(2, dim)
+        self.q_proj = torch.nn.Linear(dim, dim, bias=False)
+        self.k_proj = torch.nn.Linear(dim, dim, bias=False)
+        self.v_proj = torch.nn.Linear(dim, dim, bias=False)
+        self.o_proj = torch.nn.Linear(dim, dim, bias=False)
+        self.device = device
+        self.config = None
+        self.apply_embed = apply_embed
+        # TODO: Debug why fsdp2 modifies bias of layernorm for awq
+        self.input_layernorm = torch.nn.LayerNorm(dim, bias=False)
+
+    def forward(self, x):
+        if self.apply_embed:
+            x = self.embedding(x)
+
+        x = self.input_layernorm(x)
+        q_proj = self.q_proj(x)
+        k_proj = self.k_proj(x)
+        v_proj = self.v_proj(x)
+        scores = torch.matmul(q_proj, k_proj.transpose(-2, -1))
+        attn = torch.nn.functional.softmax(scores, dim=-1)
+        x = torch.matmul(attn, v_proj)
+        o_proj = self.o_proj(x)
+        return o_proj
+
+
 # Quantization configs
 partial_fp8_config = {
     "quant_cfg": {

tests/gpu/torch/export/test_export.py

Lines changed: 4 additions & 3 deletions
@@ -16,7 +16,7 @@
 import pytest
 import torch
 from _test_utils.torch_export.export_utils import (
-    SmallQKVModel,
+    SmallLinearModelwithCustomWeight,
     ToyModel,
     only_input_quantizer_fp8_config,
     only_output_quantizer_fp8_config,
@@ -306,7 +306,7 @@ def test_adjust_attn_amax_values(
     q_weight, k_weight, v_weight, o_weight, expected_qkv_amax, expected_o_amax, config
 ):
     # Initialize model and quantize to insert quantizers
-    model = SmallQKVModel([q_weight, k_weight, v_weight, o_weight]).to("cuda")
+    model = SmallLinearModelwithCustomWeight([q_weight, k_weight, v_weight, o_weight]).to("cuda")
     mtq.quantize(model, config, lambda x: x(torch.randn(1, 4, q_weight.shape[1], device="cuda")))
     adjust_attn_amax_values(model)
     # Weight quantizer amax must remain unchanged for non qkv layers
@@ -375,11 +375,12 @@ def test_get_scaling_factor(
     q_weight, k_weight, v_weight, o_weight, config, expected_amax, maxbound
 ):
     # Initialize model and quantize to insert quantizers
-    model = SmallQKVModel([q_weight, k_weight, v_weight, o_weight]).to("cuda")
+    model = SmallLinearModelwithCustomWeight([q_weight, k_weight, v_weight, o_weight]).to("cuda")
     mtq.quantize(model, config, lambda x: x(torch.ones(1, 2, q_weight.shape[1], device="cuda")))
     for name, module in model.named_modules():
         if isinstance(module, TensorQuantizer) and module.is_enabled:
             scale = get_scaling_factor(module)
+            print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
             assert torch.allclose(
                 scale,
                 torch.tensor((expected_amax[0] / maxbound), dtype=scale.dtype),
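
The quantity this assertion checks is simply amax divided by the quantizer's maxbound. A one-line sketch with illustrative numbers (not taken from the test's parametrization); for an FP8 E4M3 quantizer the maxbound is 448:

amax, maxbound = 0.5, 448.0  # illustrative values; E4M3 finite max is 448
scale = amax / maxbound      # ~0.001116, the value get_scaling_factor is expected to return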

0 commit comments
