|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import copy |
16 | | -from typing import Any, Dict, Optional, Tuple |
| 16 | +from typing import Any, Dict, Optional, Tuple, Type |
17 | 17 |
|
18 | 18 | from tico.quantization.config.ptq import PTQConfig, WrapperVariant |
19 | 19 | from tico.quantization.wrapq.dtypes import DType |
| 20 | +from tico.quantization.wrapq.observers.base import ObserverBase |
| 21 | +from tico.quantization.wrapq.observers.minmax import MinMaxObserver |
20 | 22 | from tico.quantization.wrapq.qscheme import QScheme |
21 | 23 |
|
22 | 24 |
|
@@ -338,6 +340,7 @@ def build_llm_ptq_config( |
338 | 340 | wrapper_variant: WrapperVariant = "prefill", |
339 | 341 | activation_dtype: DType = DType.int(16), |
340 | 342 | default_qscheme: QScheme = QScheme.PER_TENSOR_SYMM, |
| 343 | + default_observer: Type[ObserverBase] = MinMaxObserver, |
341 | 344 | linear_weight_bits: Optional[int] = None, |
342 | 345 | linear_weight_dtype: Optional[DType] = None, |
343 | 346 | embedding_weight_bits: Optional[int] = None, |
@@ -375,6 +378,11 @@ def build_llm_ptq_config( |
375 | 378 | default_qscheme : QScheme, default=QScheme.PER_TENSOR_SYMM |
376 | 379 | Default quantization scheme for observers that do not receive an |
377 | 380 | explicit override. |
| 381 | + default_observer : Type[ObserverBase], default=MinMaxObserver |
| 382 | + Observer class to instantiate when no explicit observer is provided |
| 383 | + via overrides. |
| 384 | + This should be a subclass of `ObserverBase` (e.g., MinMaxObserver, |
| 385 | + EMAObserver). The class itself (not an instance) must be passed. |
378 | 386 | linear_weight_bits : Optional[int], default=None |
379 | 387 | Convenience bit-width for decoder-layer linear projection weights. |
380 | 388 | Used only when `linear_weight_dtype` is not provided. |
@@ -445,6 +453,7 @@ def build_llm_ptq_config( |
445 | 453 | return PTQConfig( |
446 | 454 | default_dtype=activation_dtype, |
447 | 455 | default_qscheme=default_qscheme, |
| 456 | + default_observer=default_observer, |
448 | 457 | wrapper_variant=wrapper_variant, |
449 | 458 | overrides=overrides, |
450 | 459 | strict_wrap=strict_wrap, |
|
0 commit comments