
Commit 1b867f0

xin3he and HolyFalafel authored
cherry-pick two internal PRs for vllm-gaudi (#2392)
* [SW-228723] Added flag if op supports dynamic quant to ModuleInfo (#314)
* [GAUDISW-244137] Set FusedMoE and VllmMixtureOfExperts to support dynamic quantization (#328)
  * Set FusedMoE to support dynamic quantization
  * added vllm moe

---------

Co-authored-by: Danny Semiat <[email protected]>
1 parent 4063437 commit 1b867f0

File tree: 3 files changed, +33 −25 lines changed
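In outline, this commit moves the "does this op support dynamic quantization?" decision off a hardcoded list in quant_config.py and onto the ModuleInfo registry entries themselves. Below is a minimal, self-contained sketch of that pattern; the class and names here are simplified stand-ins, not the library's real implementation beyond what the diffs that follow show.

# Simplified stand-in of the pattern introduced in the diffs below: each
# registered op's ModuleInfo carries a supports_dynamic_quantization flag,
# and the check becomes a registry lookup instead of list membership.
class ModuleInfo:
    def __init__(self, type, patched_module, should_measure_and_quant=True, *, supports_dynamic_quantization=False):
        self.type = type
        self.patched_module = patched_module
        self.should_measure_and_quant = should_measure_and_quant
        self.supports_dynamic_quantization = supports_dynamic_quantization

_mod_default_dict = {
    "Linear": ModuleInfo("linear", object, supports_dynamic_quantization=True),
    "Conv2d": ModuleInfo("linear", object),  # patched, but does not opt in
}

def is_supported_dynamic_op(op_str):
    info = _mod_default_dict.get(op_str)
    return getattr(info, "supports_dynamic_quantization", False) if info is not None else False

assert is_supported_dynamic_op("Linear")
assert not is_supported_dynamic_op("Conv2d")
assert not is_supported_dynamic_op("SomethingUnpatched")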

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 20 additions & 20 deletions
@@ -67,43 +67,43 @@ def create_mod_info_recursion(parent):
 
 _mod_default_dict = {
     "Matmul": ModuleInfo("matmul", PatchedMatmul),
-    "Linear": ModuleInfo("linear", PatchedLinear),
-    "ParallelLMHead": ModuleInfo("linear", PatchedParallelLMHead),
-    "RowParallelLinear": ModuleInfo("row_parallel_linear", PatchedRowParallelLinear),
-    "ColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
-    "MergedColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
-    "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
-    "FalconLinear": ModuleInfo("linear", PatchedLinear),
+    "Linear": ModuleInfo("linear", PatchedLinear, supports_dynamic_quantization=True),
+    "ParallelLMHead": ModuleInfo("linear", PatchedParallelLMHead, supports_dynamic_quantization=True),
+    "RowParallelLinear": ModuleInfo("row_parallel_linear", PatchedRowParallelLinear, supports_dynamic_quantization=True),
+    "ColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear, supports_dynamic_quantization=True),
+    "MergedColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear, supports_dynamic_quantization=True),
+    "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear, supports_dynamic_quantization=True),
+    "FalconLinear": ModuleInfo("linear", PatchedLinear, supports_dynamic_quantization=True),
     "KVCache": ModuleInfo("kv_cache", PatchedKVCache),
     "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache),
     "Conv2d": ModuleInfo("linear", PatchedConv2d),
-    "LoRACompatibleLinear": ModuleInfo("linear", PatchedLoRACompatibleLinear),
+    "LoRACompatibleLinear": ModuleInfo("linear", PatchedLoRACompatibleLinear, supports_dynamic_quantization=True),
     "LoRACompatibleConv": ModuleInfo("linear", PatchedLoRACompatibleConv),
     "Softmax": ModuleInfo("softmax", PatchedSoftmax),
     "BlockSoftmaxConstMax": ModuleInfo("softmax", PatchedBlockSoftmaxConstMax),
     "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
-    "MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
-    "MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
-    "ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
+    "MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul, supports_dynamic_quantization=True),
+    "MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul, supports_dynamic_quantization=True),
+    "ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear, supports_dynamic_quantization=True),
     # Note: `no_quantize_op` indicates that this module is patched but does not require measurement or quantization.
-    "FusedMoE": ModuleInfo("no_quantize_op", PatchedMixtralMoE, False),
-    "SharedFusedMoE": ModuleInfo("no_quantize_op", PatchedMixtralMoE, False),
+    "FusedMoE": ModuleInfo("no_quantize_op", PatchedMixtralMoE, False, supports_dynamic_quantization=True),
+    "SharedFusedMoE": ModuleInfo("no_quantize_op", PatchedMixtralMoE, False, supports_dynamic_quantization=True),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
     "GaudiDeepseekV3MoE": ModuleInfo("dynamic_moe", PatchedGaudiDeepseekV3MoE),
     "GaudiFP8Linear": ModuleInfo("linear", PatchedMoeFP8Matmul),
-    "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
-    "VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
-    "VllmMixtureOfExpertsOpFP8PerChannel": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
+    "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp, supports_dynamic_quantization=True),
+    "VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8, supports_dynamic_quantization=True),
+    "VllmMixtureOfExpertsOpFP8PerChannel": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8, supports_dynamic_quantization=True),
 }
 
 
 if deepspeed_exists:
     _mod_default_dict.update(
         {
-            "LinearLayer": ModuleInfo("linear", PatchedLinear),
-            "LinearAllreduce": ModuleInfo("linear", PatchedLinearAllReduce),
-            "ScopedLinearAllReduce": ModuleInfo("linear", PatchedLinearAllReduce),
-            "LmHeadLinearAllreduce": ModuleInfo("linear", PatchedLmHeadLinearAllreduce),
+            "LinearLayer": ModuleInfo("linear", PatchedLinear, supports_dynamic_quantization=True),
+            "LinearAllreduce": ModuleInfo("linear", PatchedLinearAllReduce, supports_dynamic_quantization=True),
+            "ScopedLinearAllReduce": ModuleInfo("linear", PatchedLinearAllReduce, supports_dynamic_quantization=True),
+            "LmHeadLinearAllreduce": ModuleInfo("linear", PatchedLmHeadLinearAllreduce, supports_dynamic_quantization=True),
         }
     )
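For illustration, the flags set above can be read back through the same table that quant_config.py now consults (next diff). A hedged sketch, assuming get_patched_module_table() returns a plain dict keyed by these op names, as the .get() lookup in the next file suggests:

# Hedged sketch: list which registered ops advertise dynamic-quantization
# support by reading the new flag straight off their ModuleInfo entries.
from neural_compressor.torch.algorithms.fp8_quant.model_configs import get_patched_module_table

dynamic_ops = sorted(
    name
    for name, info in get_patched_module_table().items()
    if getattr(info, "supports_dynamic_quantization", False)
)
print(dynamic_ops)  # expected to include "Linear", "FusedMoE", "VllmMixtureOfExpertsOp", ...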

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 9 additions & 3 deletions
@@ -98,10 +98,16 @@ class DeviceForScalesType(Enum):
 ]
 
 # TODO [SW-217813]: support dynamic quantization in all ops and remove
-# TODO [SW-228723]: get a better way to list all linear ops, like set in ModuleInfo if supports dynamic
-supported_dynamic_ops = ["Linear", "RowParallelLinear", "ColumnParallelLinear", "MergedColumnParallelLinear", "QKVParallelLinear", "FalconLinear", "LoRACompatibleLinear", "ReplicatedLinear", "LinearLayer", "LinearAllreduce", "ScopedLinearAllReduce", "LmHeadLinearAllreduce", "FusedMoE", "GaudiMixtralSparseMoeBlock", "VllmMixtureOfExpertsOp", "VllmMixtureOfExpertsOpFP8", "GaudiDeepseekV3MoE", "GaudiFP8Linear"]
+from neural_compressor.torch.algorithms.fp8_quant.model_configs import get_patched_module_table, ModuleInfo
+
 def is_supported_dynamic_op(op_str):
-    ret = op_str in supported_dynamic_ops
+    """
+    Dynamically checks if the given op supports dynamic quantization
+    by looking up its ModuleInfo and checking for a 'supports_dynamic_quantization' attribute.
+    """
+    patched_table = get_patched_module_table()
+    info = patched_table.get(op_str)
+    ret = getattr(info, "supports_dynamic_quantization", False) if info is not None else False
     logger.trace("Checking if %s is supported for dynamic quantization: %s", op_str, ret)
     return ret
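A hedged usage sketch of the reworked check; the import path follows the file location above, and the expected results follow from the flags set in patching_common.py:

# With the registry-driven lookup, membership in the old hardcoded
# supported_dynamic_ops list is no longer what decides the answer.
from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import is_supported_dynamic_op

print(is_supported_dynamic_op("Linear"))      # True  -- flagged in _mod_default_dict
print(is_supported_dynamic_op("Conv2d"))      # False -- patched, but not flagged
print(is_supported_dynamic_op("NotARealOp"))  # False -- no ModuleInfo entry; the getattr guard returns False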

neural_compressor/torch/algorithms/fp8_quant/model_configs.py

Lines changed: 4 additions & 2 deletions
@@ -40,16 +40,18 @@ class ModuleInfo:
     Configures a relation between a ModuleType key (from `_mod_types` dict in `common.py`)
     to a PatchedModule class.
     """
-    def __init__(self, type, patched_module, should_measure_and_quant=True):
+    def __init__(self, type, patched_module, should_measure_and_quant=True, *, supports_dynamic_quantization=False):
         self.type = type
         self.patched_module = patched_module
         self.should_measure_and_quant = should_measure_and_quant
+        self.supports_dynamic_quantization = supports_dynamic_quantization
 
     def __repr__(self):
         return (
             f"ModuleInfo(type={self.type}, "
             f"patched_module={self.patched_module.__name__}), "
-            f"should_measure_and_quant={self.should_measure_and_quant}"
+            f"should_measure_and_quant={self.should_measure_and_quant}, "
+            f"supports_dynamic_quantization={self.supports_dynamic_quantization}"
         )
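A short hedged example of the widened constructor; because the new argument is keyword-only and defaults to False, existing ModuleInfo(...) call sites keep their previous behavior. DummyPatched is a stand-in class for illustration, not part of the library:

from neural_compressor.torch.algorithms.fp8_quant.model_configs import ModuleInfo

class DummyPatched:  # stand-in; __repr__ only needs a class with a __name__
    pass

info = ModuleInfo("linear", DummyPatched, supports_dynamic_quantization=True)
legacy = ModuleInfo("linear", DummyPatched)  # flag defaults to False

print(info)                                  # repr now reports supports_dynamic_quantization=True
print(legacy.supports_dynamic_quantization)  # False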
