forked from vllm-project/llm-compressor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
143 lines (127 loc) · 4.66 KB
/
utils.py
File metadata and controls
143 lines (127 loc) · 4.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import functools
from collections import namedtuple
from loguru import logger
# Names exported by `from <this module> import *`
__all__ = [
    "get_layer_mappings_from_architecture",
    "MAPPINGS_REGISTRY",
    "DEFAULT_SMOOTHQUANT_MAPPINGS",
]
# Structural shape of a single mapping entry:
# (patterns of the layers to balance, pattern of the layer to smooth)
LayerMapType = tuple[list[str], str]

# Named record for one mapping entry, with fields `balance_layers` and
# `smooth_layers`.
# NOTE: LayerMap is the namedtuple *class*, not an instance, so it must not be
# annotated with LayerMapType (the shape of its instances). Such an annotation
# makes type checkers reject this assignment and every `list[LayerMap]`
# annotation that uses the class as a type.
LayerMap = namedtuple("LayerMap", ["balance_layers", "smooth_layers"])
# Mappings used when an architecture has no dedicated registry entry:
#   - q/k/v projections are paired with the input layernorm
#   - gate/up projections are paired with the post-attention layernorm
DEFAULT_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        balance_layers=[
            "re:.*q_proj",
            "re:.*k_proj",
            "re:.*v_proj",
        ],
        smooth_layers="re:.*input_layernorm",
    ),
    LayerMap(
        balance_layers=[
            "re:.*gate_proj",
            "re:.*up_proj",
        ],
        smooth_layers="re:.*post_attention_layernorm",
    ),
]
# Mixtral mappings: only the q/k/v projections are paired with the input
# layernorm; there is no entry for the MLP projections.
MIXTRAL_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        balance_layers=[
            "re:.*q_proj",
            "re:.*k_proj",
            "re:.*v_proj",
        ],
        smooth_layers="re:.*input_layernorm",
    ),
]
# Bloom mappings (also registered for ChatGLM in MAPPINGS_REGISTRY):
#   - the fused `query_key_value` projection is paired with the input layernorm
#   - `dense_h_to_4h` is paired with the post-attention layernorm
BLOOM_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        smooth_layers="re:.*input_layernorm",
        balance_layers=["re:.*query_key_value"],
    ),
    LayerMap(
        smooth_layers="re:.*post_attention_layernorm",
        balance_layers=["re:.*dense_h_to_4h"],
    ),
]
# Phi-3 Vision mappings:
#   - `qkv_proj` is paired with the input layernorm
#   - `gate_up_proj` is paired with the post-attention layernorm
PHI3_VISION_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        smooth_layers="re:.*input_layernorm",
        balance_layers=["re:.*qkv_proj"],
    ),
    LayerMap(
        smooth_layers="re:.*post_attention_layernorm",
        balance_layers=["re:.*gate_up_proj"],
    ),
]
# Whisper mappings:
#   - k/v/q projections are paired with `self_attn_layer_norm`
#   - `fc1` is paired with `final_layer_norm`
WHISPER_V2_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        balance_layers=[
            "re:.*k_proj",
            "re:.*v_proj",
            "re:.*q_proj",
        ],
        smooth_layers="re:.*self_attn_layer_norm",
    ),
    LayerMap(
        balance_layers=["re:.*fc1"],
        smooth_layers="re:.*final_layer_norm",
    ),
]
# DeepSeek-V2 mappings: a single entry paired with the input layernorm.
# The first pattern is anchored with `$` and accepts names ending in either
# `q_proj` or `q_a_proj`; the second targets `kv_a_proj_with_mqa`.
DEEPSEEK_V2_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        balance_layers=[
            "re:.*q(_a)?_proj$",
            "re:.*kv_a_proj_with_mqa",
        ],
        smooth_layers="re:.*input_layernorm",
    ),
]
# AFMoE mappings:
#   - `self_attn` q/k/v/gate projections are paired with the input layernorm
#   - gate/up projections under `mlp` are paired with `pre_mlp_layernorm`
AFMOE_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    LayerMap(
        balance_layers=[
            "re:.*self_attn\\.q_proj",
            "re:.*self_attn\\.k_proj",
            "re:.*self_attn\\.v_proj",
            "re:.*self_attn\\.gate_proj",
        ],
        smooth_layers="re:.*input_layernorm",
    ),
    LayerMap(
        balance_layers=[
            "re:.*mlp.*gate_proj",
            "re:.*mlp.*up_proj",
        ],
        smooth_layers="re:.*pre_mlp_layernorm",
    ),
]
# Registry of layer mappings, keyed by model architecture name.
# Architectures missing from this table fall back to
# DEFAULT_SMOOTHQUANT_MAPPINGS in get_layer_mappings_from_architecture.
# To support another architecture, add an entry here.
MAPPINGS_REGISTRY: dict[str, list[LayerMap]] = {
    "BloomForCausalLM": BLOOM_SMOOTHQUANT_MAPPINGS,
    "ChatGLMForConditionalGeneration": BLOOM_SMOOTHQUANT_MAPPINGS,
    "DeepseekV2ForCausalLM": DEEPSEEK_V2_SMOOTHQUANT_MAPPINGS,
    "Gemma2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "Gemma3ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "Gemma3ForConditionalGeneration": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "Llama4ForConditionalGeneration": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "LlamaForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "Mistral3ForConditionalGeneration": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "MistralForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "MixtralForCausalLM": MIXTRAL_SMOOTHQUANT_MAPPINGS,
    "Phi3VForCausalLM": PHI3_VISION_SMOOTHQUANT_MAPPINGS,
    "Qwen2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "Qwen3ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "WhisperForConditionalGeneration": WHISPER_V2_SMOOTHQUANT_MAPPINGS,
    "AfmoeForCausalLM": AFMOE_SMOOTHQUANT_MAPPINGS,
}
def get_layer_mappings_from_architecture(architecture: str) -> list[LayerMap]:
    """
    Look up the layer mappings registered for a model architecture. Unknown
    architectures fall back to the default mappings, and the fallback is logged
    at info level.

    :param architecture: str: architecture name, used as the key into
        MAPPINGS_REGISTRY
    :return: list: the registered mappings, or DEFAULT_SMOOTHQUANT_MAPPINGS
        when the architecture is not registered
    """
    # Known architecture: hand back its registered mappings directly
    if architecture in MAPPINGS_REGISTRY:
        return MAPPINGS_REGISTRY[architecture]

    # Unknown architecture: tell the user which mappings are being used instead
    logger.info(
        f"Architecture {architecture} not found in mappings. "
        f"Using default mappings: {DEFAULT_SMOOTHQUANT_MAPPINGS}"
    )
    return DEFAULT_SMOOTHQUANT_MAPPINGS
def handle_mapping_resolution_errors(func):
    """
    Decorator to catch any errors that occur when resolving mappings and provide a
    helpful error message to the user pointing them to the README

    :param func: callable to wrap; it is invoked with the arguments passed to
        the returned wrapper
    :return: wrapper that preserves the metadata of ``func`` and returns
        whatever ``func`` returns
    :raises RuntimeError: if ``func`` raises any ``Exception``; the original
        exception is chained as ``__cause__``
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as original_exception:
            readme_location = (
                "https://github.com/vllm-project/llm-compressor/tree/main/"
                "src/llmcompressor/modifiers/transform/smoothquant"
            )
            # The trailing space after the first sentence matters: adjacent
            # string literals are concatenated with no separator.
            raise RuntimeError(
                "Error resolving mappings for given architecture. "
                f"Please refer to the README at {readme_location} for more information."
            ) from original_exception

    return wrapper