 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedColumnParameter,
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 
 
 class GPTQConfig(QuantizationConfig):
@@ -108,6 +112,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs.get("weight_loader")
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -138,73 +143,81 @@ def create_weights(
         scale_and_zero_size = input_size_per_partition // group_size
         scale_and_zero_input_dim = 0
 
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 0,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        g_idx = Parameter(
-            torch.tensor(
-                [
-                    i // self.quant_config.group_size
-                    for i in range(input_size_per_partition)
-                ],
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        # Ignore warning from fused linear layers such as QKVParallelLinear.
-        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
-        qzeros = Parameter(
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        g_idx = RowvLLMParameter(data=torch.tensor(
+            [
+                i // self.quant_config.group_size
+                for i in range(input_size_per_partition)
+            ],
+            dtype=torch.int32,
+        ),
+                                 input_dim=0,
+                                 weight_loader=weight_loader)
+        qzeros_args = {
+            "data":
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": scale_and_zero_input_dim,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        scales = Parameter(
+            "weight_loader":
+            weight_loader
+        }
+        weight_scale_args = {
+            "data":
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition,
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": scale_and_zero_input_dim,
-            "output_dim": 1,
-        })
+            "weight_loader":
+            weight_loader
+        }
+        if scale_and_zero_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
+
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("g_idx", g_idx)
-        set_weight_attrs(g_idx, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
 
         layer.exllama_state = exllama_state
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # for torch.compile
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = Parameter(layer.scales.data, requires_grad=False)
+        layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
+
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
         if layer.exllama_state == ExllamaState.UNINITIALIZED:
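The gist of the change: each GPTQ tensor is now created as a subclass from `vllm.model_executor.parameter` that carries its own sharding/packing metadata and the `weight_loader` callback, replacing the old `Parameter` + `set_weight_attrs` pattern, and `process_weights_after_loading` re-wraps the tensors as plain `torch.nn.Parameter` so torch.compile only sees standard tensor types. The sketch below illustrates that pattern with toy stand-ins; `ToyPackedParameter` and `toy_weight_loader` are illustrative names, not vLLM's actual implementations.

```python
import torch
from torch.nn import Parameter


class ToyPackedParameter(Parameter):
    """Toy stand-in for PackedvLLMParameter: the tensor carries its own
    sharding/packing metadata and loader callback instead of having them
    attached afterwards via set_weight_attrs."""

    def __new__(cls, data, **kwargs):
        return super().__new__(cls, data, requires_grad=False)

    def __init__(self, data, input_dim, output_dim, packed_dim, packed_factor,
                 weight_loader):
        self.input_dim = input_dim          # dim sharded over the input size
        self.output_dim = output_dim        # dim sharded over the output size
        self.packed_dim = packed_dim        # dim holding the packed values
        self.packed_factor = packed_factor  # e.g. 8 int4 values per int32
        self.weight_loader = weight_loader  # called by the model loader


def toy_weight_loader(param, loaded_weight):
    # A real loader would use param.input_dim / output_dim / packed_factor to
    # pick this tensor-parallel rank's shard; here we just copy verbatim.
    param.data.copy_(loaded_weight)


qweight = ToyPackedParameter(data=torch.empty(16, 32, dtype=torch.int32),
                             input_dim=0,
                             output_dim=1,
                             packed_dim=0,
                             packed_factor=8,
                             weight_loader=toy_weight_loader)
qweight.weight_loader(qweight, torch.zeros(16, 32, dtype=torch.int32))

# After loading, re-wrap as a plain Parameter, mirroring what
# process_weights_after_loading does above for torch.compile.
qweight = Parameter(qweight.data, requires_grad=False)
```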