Skip to content

Commit 546a2ca

Browse files
committed
Proposal custom quantizer
1 parent 493445c commit 546a2ca

File tree

5 files changed

+105
-12
lines changed

5 files changed

+105
-12
lines changed

src/brevitas/utils/python_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
from contextlib import contextmanager
55
from enum import Enum
66
import functools
7+
from typing import Callable
8+
from typing import Dict
9+
from typing import Generic
10+
from typing import Iterable
11+
from typing import List
12+
from typing import Optional
13+
from typing import TypeVar
14+
from typing import Union
715

816

917
class AutoName(str, Enum):
@@ -64,3 +72,43 @@ def run(*args, **kwargs):
6472
return function(*args, **kwargs)
6573

6674
return run
75+
76+
77+
T = TypeVar("T")
78+
79+
80+
class Registry(Generic[T]):
81+
82+
def __init__(self, registry_name: Optional[str] = None) -> None:
83+
self._registry_name = registry_name
84+
self._registry: Dict[str, T] = {}
85+
86+
@property
87+
def registry_name(self) -> str:
88+
return "registry" if self._registry_name is None else self._registry_name
89+
90+
def register(self, names: Union[str, List[str]]) -> Callable[[T], T]:
91+
if isinstance(names, str):
92+
names = [names]
93+
94+
def decorator(value: T) -> T:
95+
# Allow registering the same value to multiple keys
96+
for name in names:
97+
if name in self._registry:
98+
raise ValueError(f"'{name}' is already registered in {self.registry_name}.")
99+
self._registry[name] = value
100+
return value
101+
102+
return decorator
103+
104+
def get_registered_keys(self) -> Iterable[str]:
105+
return self._registry.keys()
106+
107+
def get(self, name: str) -> T:
108+
try:
109+
return self._registry[name]
110+
except KeyError:
111+
available = ", ".join(sorted(self._registry)) or "<empty>"
112+
raise ValueError(
113+
f"'{name}' not found in {self.registry_name}. The available values are: {available}"
114+
)

src/brevitas_examples/common/generative/quantize.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,14 @@ def quant_format_from_string(quant_format):
468468
linear_input_quant = linear_input_quant.let(
469469
**{
470470
'group_dim': -1, 'group_size': input_group_size})
471-
return linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant
471+
return {
472+
'linear_input_quant': linear_input_quant,
473+
'weight_quant': weight_quant,
474+
'input_quant': input_quant,
475+
'q_scaled_quant': q_scaled_quant,
476+
'k_transposed_quant': k_transposed_quant,
477+
'v_quant': v_quant,
478+
'attn_output_weights_quant': attn_output_weights_quant}
472479

473480

474481
def generate_quant_maps(

src/brevitas_examples/common/generative/quantizers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
# SPDX-License-Identifier: BSD-3-Clause
44
"""
55

6+
from typing import ClassVar
7+
from typing import Dict
8+
from typing import Optional
9+
from typing import Type
10+
from typing import TypeVar
11+
612
from torch import nn
713

814
from brevitas.core.function_wrapper.ops_ste import FloorSte
@@ -39,8 +45,10 @@
3945
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
4046
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
4147
from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
48+
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
4249
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
4350
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
51+
from brevitas.utils.python_utils import Registry
4452

4553
from .quant_blocks import *
4654

@@ -218,3 +226,31 @@ class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat):
218226

219227
class Fp8e4m3WeightPerChannelFloatMSE(MSESymmetricScale, Fp8e4m3WeightPerChannelFloat):
220228
pass
229+
230+
231+
# TODO: Subject to change
232+
class BaseQuantizer:
233+
weight_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
234+
linear_input_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
235+
input_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
236+
q_scaled_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
237+
k_transposed_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
238+
v_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
239+
attn_output_weights_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
240+
241+
@classmethod
242+
def override_quantizers_dict(
243+
cls: "BaseQuantizer",
244+
quantizers_dict: Dict[str, Optional[ExtendedInjector]]): # type: ignore
245+
for key in quantizers_dict:
246+
if hasattr(cls, key) and (value := getattr(cls, key)) is not None:
247+
quantizers_dict[key] = value
248+
return quantizers_dict
249+
250+
251+
CUSTOM_QUANTIZERS_REGISTRY = Registry[Type[BaseQuantizer]](registry_name="CustomQuantizersRegistry")
252+
253+
254+
@CUSTOM_QUANTIZERS_REGISTRY.register("custom_quant")
255+
class CustomQuantizerExample(BaseQuantizer):
256+
weight_quant: ClassVar[Optional[ExtendedInjector]] = Int8WeightPerTensorFloat # type: ignore

src/brevitas_examples/llm/llm_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ def create_args_parser() -> ArgumentParser:
2020
type=str,
2121
default="facebook/opt-125m",
2222
help='HF model name. Default: facebook/opt-125m.')
23+
parser.add_argument(
24+
'--custom-quantizer',
25+
type=str,
26+
default=None,
27+
help=
28+
'Override the quantization list with custom user defined quantizers. This must be a .py file with a list of seven quantizers. Default: None.'
29+
)
2330
parser.add_argument(
2431
'--dtype',
2532
type=str,

src/brevitas_examples/llm/main.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict
3535
from brevitas_examples.common.generative.quantize import generate_quant_maps
3636
from brevitas_examples.common.generative.quantize import generate_quantizers
37+
from brevitas_examples.common.generative.quantizers import CUSTOM_QUANTIZERS_REGISTRY
3738
from brevitas_examples.common.parse_utils import override_defaults
3839
from brevitas_examples.common.parse_utils import parse_args
3940
from brevitas_examples.llm.gguf_export.export import save_quantized_as_gguf
@@ -411,7 +412,7 @@ def quantize_llm(args, extra_args=None):
411412
'zero_point_affine_rescaling_init': args.weight_quant_rescaling_init}}
412413
if args.weight_narrow_range:
413414
weight_kwargs = {**weight_kwargs, **{'narrow_range': args.weight_narrow_range}}
414-
linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers(
415+
quantizers_dict = generate_quantizers(
415416
weight_bit_width=args.weight_bit_width,
416417
weight_param_method=args.weight_param_method,
417418
weight_scale_precision=args.weight_scale_precision,
@@ -444,17 +445,11 @@ def quantize_llm(args, extra_args=None):
444445
quant_attn_mode='sdpa',
445446
scaling_min_val=args.scaling_min_val,
446447
weight_kwargs=weight_kwargs)
448+
if args.custom_quantizer is not None:
449+
custom_quantizer = CUSTOM_QUANTIZERS_REGISTRY.get(args.custom_quantizer)
450+
quantizers_dict = custom_quantizer.override_quantizers_dict(quantizers_dict)
447451
layer_map = generate_quant_maps(
448-
linear_input_quant=linear_input_quant,
449-
weight_quant=weight_quant,
450-
input_quant=input_quant,
451-
q_scaled_quant=q_scaled_quant,
452-
k_transposed_quant=k_transposed_quant,
453-
v_quant=v_quant,
454-
attn_output_weights_quant=attn_output_weights_quant,
455-
dtype=dtype,
456-
device=device,
457-
quantize_embedding=False)
452+
**quantizers_dict, dtype=dtype, device=device, quantize_embedding=False)
458453
if not args.quantize_last_layer:
459454
# Dynamo tracing changes the name of the modules, thus we need this workaround to pick
460455
# up the last module.

0 commit comments

Comments
 (0)