Skip to content

Commit 37df2dd

Browse files
authored
Clean up observer defaulting logic, better error message (#200)
1 parent 2b79056 commit 37df2dd

File tree

2 files changed

11 additions & 15 deletions

2 files changed

11 additions & 15 deletions

src/compressed_tensors/quantization/quant_args.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -114,12 +114,6 @@ def get_observer(self):
114114
"""
115115
:return: torch quantization FakeQuantize built based on these QuantizationArgs
116116
"""
117-
118-
# No observer required for the dynamic case
119-
if self.dynamic:
120-
self.observer = None
121-
return self.observer
122-
123117
return self.observer
124118

125119
@field_validator("type", mode="before")
@@ -203,6 +197,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
203197
"activation ordering"
204198
)
205199

200+
# infer observer w.r.t. dynamic
206201
if dynamic:
207202
if strategy not in (
208203
QuantizationStrategy.TOKEN,
@@ -214,18 +209,19 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
214209
"quantization",
215210
)
216211
if observer is not None:
217-
warnings.warn(
218-
"No observer is used for dynamic quantization, setting to None"
219-
)
220-
model.observer = None
212+
if observer != "memoryless": # avoid annoying users with old configs
213+
warnings.warn(
214+
"No observer is used for dynamic quantization, setting to None"
215+
)
216+
observer = None
221217

222-
# if we have not set an observer and we
223-
# are running static quantization, use minmax
224-
if not observer and not dynamic:
225-
model.observer = "minmax"
218+
elif observer is None:
219+
# default to minmax for non-dynamic cases
220+
observer = "minmax"
226221

227222
# write back modified values
228223
model.strategy = strategy
224+
model.observer = observer
229225
return model
230226

231227
def pytorch_dtype(self) -> torch.dtype:

src/compressed_tensors/registry/registry.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -258,7 +258,7 @@ def get_from_registry(
258258
retrieved_value = _import_and_get_value_from_module(module_path, value_name)
259259
else:
260260
# look up name in alias registry
261-
name = _ALIAS_REGISTRY[parent_class].get(name)
261+
name = _ALIAS_REGISTRY[parent_class].get(name, name)
262262
# look up name in registry
263263
retrieved_value = _REGISTRY[parent_class].get(name)
264264
if retrieved_value is None:

0 commit comments

Comments
 (0)