|
16 | 16 | Miscellaneous helpers for the quantization lifecycle
|
17 | 17 | """
|
18 | 18 |
|
19 |
| -from typing import Optional |
20 |
| - |
21 |
| -import torch |
22 | 19 | from torch.nn import Module
|
23 | 20 |
|
24 | 21 |
|
25 | 22 | __all__ = [
|
26 |
| - "update_layer_weight_quant_params", |
27 | 23 | "enable_quantization",
|
28 | 24 | "disable_quantization",
|
29 | 25 | ]
|
30 | 26 |
|
31 | 27 |
|
32 |
| -def update_layer_weight_quant_params( |
33 |
| - layer: Module, |
34 |
| - weight: Optional[torch.Tensor] = None, |
35 |
| - g_idx: Optional[torch.Tensor] = None, |
36 |
| - reset_obs: bool = False, |
37 |
| -): |
38 |
| - """ |
39 |
| - Update quantization parameters on layer |
40 |
| -
|
41 |
| - :param layer: input layer |
42 |
| - :param weight: weight to update quant params with, defaults to layer weight |
43 |
| - :param g_idx: optional mapping from column index to group index |
44 |
| - :param reset_obs: reset the observer before calculating quant params, |
45 |
| - defaults to False |
46 |
| - """ |
47 |
| - attached_weight = getattr(layer, "weight", None) |
48 |
| - |
49 |
| - if weight is None: |
50 |
| - weight = attached_weight |
51 |
| - scale = getattr(layer, "weight_scale", None) |
52 |
| - zero_point = getattr(layer, "weight_zero_point", None) |
53 |
| - if g_idx is None: |
54 |
| - g_idx = getattr(layer, "weight_g_idx", None) |
55 |
| - observer = getattr(layer, "weight_observer", None) |
56 |
| - |
57 |
| - if weight is None or observer is None or scale is None or zero_point is None: |
58 |
| - # scale, zp, or observer not calibratable or weight not available |
59 |
| - return |
60 |
| - |
61 |
| - if reset_obs: |
62 |
| - observer.reset() |
63 |
| - |
64 |
| - if attached_weight is not None: |
65 |
| - weight = weight.to(attached_weight.dtype) |
66 |
| - |
67 |
| - updated_scale, updated_zero_point = observer(weight) |
68 |
| - |
69 |
| - # update scale and zero point |
70 |
| - device = next(layer.parameters()).device |
71 |
| - scale.data = updated_scale.to(device) |
72 |
| - zero_point.data = updated_zero_point.to(device) |
73 |
| - |
74 |
| - |
75 | 28 | def enable_quantization(module: Module):
|
76 | 29 | module.quantization_enabled = True
|
77 | 30 |
|
|
0 commit comments