Skip to content

Commit 723ac95

Browse files
committed
Minor refactoring
1 parent 546a2ca commit 723ac95

File tree

5 files changed

+140
-33
lines changed

5 files changed

+140
-33
lines changed

src/brevitas/utils/python_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Optional
1313
from typing import TypeVar
1414
from typing import Union
15+
import warnings
1516

1617

1718
class AutoName(str, Enum):
@@ -83,9 +84,16 @@ def __init__(self, registry_name: Optional[str] = None) -> None:
8384
self._registry_name = registry_name
8485
self._registry: Dict[str, T] = {}
8586

87+
@staticmethod
88+
def register(
89+
registry: "Registry[T]",
90+
names: Union[str, List[str]],
91+
) -> Callable[[T], T]:
92+
return registry.register(names)
93+
8694
@property
8795
def registry_name(self) -> str:
88-
return "registry" if self._registry_name is None else self._registry_name
96+
return "Registry" if self._registry_name is None else self._registry_name
8997

9098
def register(self, names: Union[str, List[str]]) -> Callable[[T], T]:
9199
if isinstance(names, str):
@@ -95,7 +103,9 @@ def decorator(value: T) -> T:
95103
# Allow registering the same value to multiple keys
96104
for name in names:
97105
if name in self._registry:
98-
raise ValueError(f"'{name}' is already registered in {self.registry_name}.")
106+
warnings.warn(
107+
f"'{name}' is already registered in {self.registry_name}. Overwriting the existing value."
108+
)
99109
self._registry[name] = value
100110
return value
101111

src/brevitas_examples/common/generative/quantizers.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Dict
88
from typing import Optional
99
from typing import Type
10-
from typing import TypeVar
10+
from typing import TypeAlias
1111

1212
from torch import nn
1313

@@ -52,6 +52,34 @@
5252

5353
from .quant_blocks import *
5454

55+
# Prevents Pylance from raising "Variable not allowed in type expression" error in every type hint in BaseQuantizer
56+
QuantInjector: TypeAlias = ExtendedInjector # type: ignore
57+
58+
59+
class BaseQuantizer:
60+
weight_quant: ClassVar[Optional[QuantInjector]] = None
61+
linear_input_quant: ClassVar[Optional[QuantInjector]] = None
62+
input_quant: ClassVar[Optional[QuantInjector]] = None
63+
q_scaled_quant: ClassVar[Optional[QuantInjector]] = None
64+
k_transposed_quant: ClassVar[Optional[QuantInjector]] = None
65+
v_quant: ClassVar[Optional[QuantInjector]] = None
66+
attn_output_weights_quant: ClassVar[Optional[QuantInjector]] = None
67+
68+
@classmethod
69+
def override_quantizers_dict(
70+
cls: "BaseQuantizer",
71+
quantizers_dict: Dict[str,
72+
Optional[QuantInjector]]) -> Dict[str, Optional[QuantInjector]]:
73+
# Overrides the quantizers in the input dictionary
74+
for key in quantizers_dict:
75+
if (value := getattr(cls, key)) is not None:
76+
quantizers_dict[key] = value
77+
return quantizers_dict
78+
79+
80+
# Registry for custom quantizers
81+
CUSTOM_QUANTIZERS_REGISTRY = Registry[Type[BaseQuantizer]](registry_name="CustomQuantizersRegistry")
82+
5583

5684
class DynamicActProxyMixin(ExtendedInjector):
5785
proxy_class = DynamicActQuantProxyFromInjector
@@ -226,31 +254,3 @@ class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat):
226254

227255
class Fp8e4m3WeightPerChannelFloatMSE(MSESymmetricScale, Fp8e4m3WeightPerChannelFloat):
228256
pass
229-
230-
231-
# TODO: Subject to change
232-
class BaseQuantizer:
233-
weight_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
234-
linear_input_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
235-
input_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
236-
q_scaled_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
237-
k_transposed_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
238-
v_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
239-
attn_output_weights_quant: ClassVar[Optional[ExtendedInjector]] = None # type: ignore
240-
241-
@classmethod
242-
def override_quantizers_dict(
243-
cls: "BaseQuantizer",
244-
quantizers_dict: Dict[str, Optional[ExtendedInjector]]): # type: ignore
245-
for key in quantizers_dict:
246-
if hasattr(cls, key) and (value := getattr(cls, key)) is not None:
247-
quantizers_dict[key] = value
248-
return quantizers_dict
249-
250-
251-
CUSTOM_QUANTIZERS_REGISTRY = Registry[Type[BaseQuantizer]](registry_name="CustomQuantizersRegistry")
252-
253-
254-
@CUSTOM_QUANTIZERS_REGISTRY.register("custom_quant")
255-
class CustomQuantizerExample(BaseQuantizer):
256-
weight_quant: ClassVar[Optional[ExtendedInjector]] = Int8WeightPerTensorFloat # type: ignore

src/brevitas_examples/stable_diffusion/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def sdpa_zp_stats_type():
530530
input_kwargs=input_kwargs)
531531

532532
layer_map = generate_quant_maps(
533-
*quantizers, dtype=dtype, device=args.device, quantize_embedding=False)
533+
**quantizers, dtype=dtype, device=args.device, quantize_embedding=False)
534534

535535
linear_qkwargs = layer_map[torch.nn.Linear][1]
536536
linear_qkwargs[

tests/brevitas/utils/test_python_utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
from enum import auto
55

6+
import pytest
7+
68
from brevitas.utils.python_utils import AutoName
9+
from brevitas.utils.python_utils import Registry
710

811

912
class TestEnum(AutoName):
@@ -38,3 +41,77 @@ def test_eq_enum():
3841

3942
def test_neq_enum():
4043
assert TestEnum.FIRST != TestEnum.SECOND
44+
45+
46+
class TestRegistry:
47+
48+
def test_register_single_name(self):
49+
r = Registry()
50+
51+
@r.register("k")
52+
class Dummy:
53+
pass
54+
55+
dummy = r.get("k")
56+
57+
assert dummy is Dummy
58+
assert len(r.get_registered_keys()) == 1
59+
assert next(iter(r.get_registered_keys())) == "k"
60+
61+
def test_register_multiple_names(self):
62+
r = Registry()
63+
64+
@r.register(["k1", "k2"])
65+
class Dummy:
66+
pass
67+
68+
dummy1 = r.get("k1")
69+
dummy2 = r.get("k2")
70+
71+
assert dummy1 is Dummy
72+
assert dummy2 is Dummy
73+
assert len(r.get_registered_keys()) == 2
74+
assert set(r.get_registered_keys()) == {"k1", "k2"}
75+
76+
def test_register_single_name_static(self):
77+
r = Registry()
78+
79+
@Registry.register(r, "k")
80+
class Dummy:
81+
pass
82+
83+
dummy = r.get("k")
84+
85+
assert dummy is Dummy
86+
assert len(r.get_registered_keys()) == 1
87+
assert next(iter(r.get_registered_keys())) == "k"
88+
89+
def test_register_duplicate_raises_warning(self):
90+
r = Registry("TestRegistry")
91+
92+
r.register("dup")("k")
93+
# Patch warnings.warn and check that it is called with the expected message
94+
with pytest.warns(UserWarning) as record:
95+
r.register("dup")("k")
96+
msg = str(record[0].message)
97+
assert "'dup' is already registered in TestRegistry. Overwriting the existing value." == msg
98+
99+
def test_get_missing_empty_raises_valueerror(self):
100+
r = Registry("TestRegistry")
101+
102+
with pytest.raises(ValueError) as excinfo:
103+
r.get("missing")
104+
105+
msg = str(excinfo.value)
106+
assert msg == "'missing' not found in TestRegistry. The available values are: <empty>"
107+
108+
def test_get_missing_raises_valueerror(self):
109+
r = Registry("TestRegistry")
110+
111+
r.register("k1")("v1")
112+
r.register("k2")("v2")
113+
with pytest.raises(ValueError) as excinfo:
114+
r.get("missing")
115+
116+
msg = str(excinfo.value)
117+
assert msg == "'missing' not found in TestRegistry. The available values are: k1, k2"

tests/brevitas_examples/test_llm_cases.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class LLMRunCases:
2727
"mistral", #"mixtral",
2828
"opt",],
2929
)
30-
def case_small_models_with_ppl(self, run_dict, default_run_args, request):
30+
def case_small_models_run(self, run_dict, default_run_args, request):
3131
yield process_args_and_metrics(default_run_args, run_dict)
3232

3333
# yapf: disable
@@ -93,6 +93,26 @@ def case_small_models_toggle_args(self, run_dict, default_run_args, request):
9393
pytest.skip(reason=f'MSE as weight_param_method requires JIT to be disabled')
9494
yield process_args_and_metrics(default_run_args, run_dict)
9595

96+
@pytest_cases.parametrize(
97+
"run_dict",
98+
[
99+
{
100+
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
101+
"custom_quantizer": "example_int8_weight_quant"},],
102+
ids=[
103+
"llama",]
104+
)
105+
def case_small_models_custom_quantizer(self, run_dict, default_run_args, request):
106+
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
107+
from brevitas.utils.python_utils import Registry
108+
from brevitas_examples.common.generative.quantizers import BaseQuantizer
109+
from brevitas_examples.common.generative.quantizers import CUSTOM_QUANTIZERS_REGISTRY
110+
@Registry.register(CUSTOM_QUANTIZERS_REGISTRY, "example_int8_weight_quant")
111+
class ExampleInt8WeightQuantizer(BaseQuantizer):
112+
weight_quant = Int8WeightPerTensorFloat
113+
yield process_args_and_metrics(default_run_args, run_dict)
114+
115+
96116
class LLMPerplexityCases:
97117

98118
METRICS = ["float_ppl", "quant_ppl"]

0 commit comments

Comments
 (0)