@@ -13,6 +13,7 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.extension.llm.export.builder import DType
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -191,73 +192,31 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-            assert self.args.preq_mode in [
-                "8da4w",
-                "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-            assert hasattr(
-                self.args, "preq_group_size"
-            ), "preq_group_size must be specified"
-            assert hasattr(
-                self.args, "dtype_override"
-            ), "dtype_override must be specified"
+            self._transform_for_pre_quantization(checkpoint)
+
             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
-                transform_linear_for_pre_quantization,
-            )
-
-            mapping = {
-                "fp32": torch.float32,
-                "fp16": torch.float16,
-                "bf16": torch.bfloat16,
-            }
-
-            # Transform the output layer first if needed.
-            if self.args.preq_mode == "8da4w_output_8da8w":
-                from .source_transformation.pre_quantization import (
-                    transform_output_linear_for_pre_quantization,
-                )
-
-                self.model_ = transform_output_linear_for_pre_quantization(
-                    module=self.model_,
-                    checkpoint=checkpoint,
-                    dtype=mapping[self.args.dtype_override],
-                )
-
-            self.model_ = transform_linear_for_pre_quantization(
-                self.model_,
-                checkpoint,
-                self.args.preq_group_size,
-                mapping[self.args.dtype_override],
             )
 
-            embedding_bit_width, embedding_group_size = None, None
-            if hasattr(self.args, "preq_embedding_quantize"):
-                embedding_bit_width, embedding_group_size = (
-                    self.args.preq_embedding_quantize.split(",")
-                )
-                from .source_transformation.pre_quantization import (
-                    transform_embedding_for_pre_quantization,
+            sanitize_checkpoint_from_pre_quantization(checkpoint)
+        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+            print("Using QAT quantization.")
+            self._transform_for_pre_quantization(checkpoint)
+            if hasattr(self.args, "use_lora") and self.args.use_lora:
+                from .source_transformation.lora import (
+                    transform_linear_for_lora_after_quantization,
                 )
 
-                if (
-                    embedding_group_size == "none"
-                    or embedding_group_size == "None"
-                    or embedding_group_size == "0"
-                ):
-                    embedding_group_size = None
-                else:
-                    embedding_group_size = int(embedding_group_size)
-
-                self.model_ = transform_embedding_for_pre_quantization(
+                self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    mapping[self.args.dtype_override],
-                    int(embedding_bit_width),
-                    embedding_group_size,
+                    self.args.use_lora,
                 )
 
+            from .source_transformation.pre_quantization import (
+                sanitize_checkpoint_from_pre_quantization,
+            )
+
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
@@ -318,3 +277,62 @@ def get_example_inputs_kvcache_sdpa(self):
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
+
+    def _transform_for_pre_quantization(self, checkpoint):
+        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
+        assert self.args.preq_mode in [
+            "8da4w",
+            "8da4w_output_8da8w",
+        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
+        assert hasattr(
+            self.args, "preq_group_size"
+        ), "preq_group_size must be specified"
+        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        from .source_transformation.pre_quantization import (
+            transform_linear_for_pre_quantization,
+        )
+
+        # Transform the output layer first if needed.
+        if self.args.preq_mode == "8da4w_output_8da8w":
+            from .source_transformation.pre_quantization import (
+                transform_output_linear_for_pre_quantization,
+            )
+
+            self.model_ = transform_output_linear_for_pre_quantization(
+                module=self.model_,
+                checkpoint=checkpoint,
+                dtype=DType[self.args.dtype_override].to_torch_dtype(),
+            )
+
+        self.model_ = transform_linear_for_pre_quantization(
+            self.model_,
+            checkpoint,
+            self.args.preq_group_size,
+            DType[self.args.dtype_override].to_torch_dtype(),
+        )
+
+        embedding_bit_width, embedding_group_size = None, None
+        if hasattr(self.args, "preq_embedding_quantize"):
+            embedding_bit_width, embedding_group_size = (
+                self.args.preq_embedding_quantize.split(",")
+            )
+            from .source_transformation.pre_quantization import (
+                transform_embedding_for_pre_quantization,
+            )
+
+            if (
+                embedding_group_size == "none"
+                or embedding_group_size == "None"
+                or embedding_group_size == "0"
+            ):
+                embedding_group_size = None
+            else:
+                embedding_group_size = int(embedding_group_size)
+
+            self.model_ = transform_embedding_for_pre_quantization(
+                self.model_,
+                checkpoint,
+                DType[self.args.dtype_override].to_torch_dtype(),
+                int(embedding_bit_width),
+                embedding_group_size,
+            )
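
Note: the diff swaps the hand-rolled dtype mapping for a lookup on the newly imported DType enum. A minimal sketch of the equivalence this relies on, assuming DType exposes fp32/fp16/bf16 members (the same names as the removed dict keys) and a to_torch_dtype() method, as the new call sites imply:

import torch
from executorch.extension.llm.export.builder import DType

# The removed mapping dict, reproduced only to illustrate the assumed equivalence.
legacy_mapping = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

for name, expected in legacy_mapping.items():
    # DType[name] selects the enum member by its string name, e.g. DType["fp16"].
    assert DType[name].to_torch_dtype() == expected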
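Likewise, the preq_embedding_quantize handling that moves into _transform_for_pre_quantization parses a "bit_width,group_size" string, treating a group size of "none"/"None"/"0" as no grouping. A standalone sketch of that parsing (parse_preq_embedding_quantize is a hypothetical helper, not part of the diff):

def parse_preq_embedding_quantize(value: str):
    # Mirrors the parsing inside _transform_for_pre_quantization:
    # "<bit_width>,<group_size>", where "none"/"None"/"0" means ungrouped quantization.
    bit_width, group_size = value.split(",")
    if group_size in ("none", "None", "0"):
        group_size = None
    else:
        group_size = int(group_size)
    return int(bit_width), group_size

assert parse_preq_embedding_quantize("4,32") == (4, 32)
assert parse_preq_embedding_quantize("8,0") == (8, None)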