Skip to content

Commit f638fc8

Browse files
committed
[quantization] Receive default observer in builders
This commit adds a default observer option to builder API. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 1e01956 commit f638fc8

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

test/quantization/config/test_builders.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from tico.quantization.config.ptq import PTQConfig
2828
from tico.quantization.wrapq.dtypes import DType
29+
from tico.quantization.wrapq.observers.ema import EMAObserver
2930
from tico.quantization.wrapq.qscheme import QScheme
3031

3132

@@ -224,3 +225,12 @@ def test_build_llm_ptq_config_unsupported_model_type_raises(self):
224225
model_type="mistral",
225226
num_hidden_layers=1,
226227
)
228+
229+
def test_build_llm_ptq_config_accepts_default_observer(self):
230+
cfg = build_llm_ptq_config(
231+
model_type="llama",
232+
num_hidden_layers=1,
233+
default_observer=EMAObserver,
234+
)
235+
236+
self.assertIs(cfg.default_observer, EMAObserver)

tico/quantization/config/builders.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
# limitations under the License.
1414

1515
import copy
16-
from typing import Any, Dict, Optional, Tuple
16+
from typing import Any, Dict, Optional, Tuple, Type
1717

1818
from tico.quantization.config.ptq import PTQConfig, WrapperVariant
1919
from tico.quantization.wrapq.dtypes import DType
20+
from tico.quantization.wrapq.observers.base import ObserverBase
21+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
2022
from tico.quantization.wrapq.qscheme import QScheme
2123

2224

@@ -338,6 +340,7 @@ def build_llm_ptq_config(
338340
wrapper_variant: WrapperVariant = "prefill",
339341
activation_dtype: DType = DType.int(16),
340342
default_qscheme: QScheme = QScheme.PER_TENSOR_SYMM,
343+
default_observer: Type[ObserverBase] = MinMaxObserver,
341344
linear_weight_bits: Optional[int] = None,
342345
linear_weight_dtype: Optional[DType] = None,
343346
embedding_weight_bits: Optional[int] = None,
@@ -375,6 +378,11 @@ def build_llm_ptq_config(
375378
default_qscheme : QScheme, default=QScheme.PER_TENSOR_SYMM
376379
Default quantization scheme for observers that do not receive an
377380
explicit override.
381+
default_observer : Type[ObserverBase], default=MinMaxObserver
382+
Observer class to instantiate when no explicit observer is provided
383+
via overrides.
384+
This should be a subclass of `ObserverBase` (e.g., MinMaxObserver,
385+
EMAObserver). The class itself (not an instance) must be passed.
378386
linear_weight_bits : Optional[int], default=None
379387
Convenience bit-width for decoder-layer linear projection weights.
380388
Used only when `linear_weight_dtype` is not provided.
@@ -445,6 +453,7 @@ def build_llm_ptq_config(
445453
return PTQConfig(
446454
default_dtype=activation_dtype,
447455
default_qscheme=default_qscheme,
456+
default_observer=default_observer,
448457
wrapper_variant=wrapper_variant,
449458
overrides=overrides,
450459
strict_wrap=strict_wrap,

0 commit comments

Comments
 (0)