
Commit eb1c705

remove function (#156)

Authored by Kyle Sayers (kylesayrs)
Co-authored-by: Kyle Sayers <[email protected]>

1 parent 0067091 · commit eb1c705

File tree

1 file changed: +0 additions, −47 deletions

  • src/compressed_tensors/quantization/lifecycle/helpers.py


src/compressed_tensors/quantization/lifecycle/helpers.py

Lines changed: 0 additions & 47 deletions
@@ -16,62 +16,15 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
-from typing import Optional
-
-import torch
 from torch.nn import Module
 
 
 __all__ = [
-    "update_layer_weight_quant_params",
     "enable_quantization",
     "disable_quantization",
 ]
 
 
-def update_layer_weight_quant_params(
-    layer: Module,
-    weight: Optional[torch.Tensor] = None,
-    g_idx: Optional[torch.Tensor] = None,
-    reset_obs: bool = False,
-):
-    """
-    Update quantization parameters on layer
-
-    :param layer: input layer
-    :param weight: weight to update quant params with, defaults to layer weight
-    :param g_idx: optional mapping from column index to group index
-    :param reset_obs: reset the observer before calculating quant params,
-        defaults to False
-    """
-    attached_weight = getattr(layer, "weight", None)
-
-    if weight is None:
-        weight = attached_weight
-    scale = getattr(layer, "weight_scale", None)
-    zero_point = getattr(layer, "weight_zero_point", None)
-    if g_idx is None:
-        g_idx = getattr(layer, "weight_g_idx", None)
-    observer = getattr(layer, "weight_observer", None)
-
-    if weight is None or observer is None or scale is None or zero_point is None:
-        # scale, zp, or observer not calibratable or weight not available
-        return
-
-    if reset_obs:
-        observer.reset()
-
-    if attached_weight is not None:
-        weight = weight.to(attached_weight.dtype)
-
-    updated_scale, updated_zero_point = observer(weight)
-
-    # update scale and zero point
-    device = next(layer.parameters()).device
-    scale.data = updated_scale.to(device)
-    zero_point.data = updated_zero_point.to(device)
-
-
 def enable_quantization(module: Module):
     module.quantization_enabled = True

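After this commit, the helpers that remain in lifecycle/helpers.py are enable_quantization and disable_quantization, which toggle a quantization_enabled flag on a module. The sketch below shows how they might be called; it assumes the module path compressed_tensors.quantization.lifecycle.helpers is importable as-is and that disable_quantization mirrors enable_quantization by setting the flag to False (only enable_quantization's body is visible in this diff).

# Minimal usage sketch of the remaining lifecycle helpers.
# Assumptions: the package is installed as `compressed_tensors`, and
# disable_quantization sets quantization_enabled = False (not shown in the diff).
from torch.nn import Linear

from compressed_tensors.quantization.lifecycle.helpers import (
    disable_quantization,
    enable_quantization,
)

layer = Linear(16, 16)

# enable_quantization sets module.quantization_enabled = True (per the diff)
enable_quantization(layer)
assert layer.quantization_enabled is True

# assumed counterpart: clears the same flag
disable_quantization(layer)
assert layer.quantization_enabled is False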