@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from abc import abstractmethod
 from typing import Tuple
 from weakref import ref

@@ -23,37 +22,24 @@
     generate_gparam,
     strategy_cdiv,
 )
-from compressed_tensors.utils import getattr_chain


-base_name_to_scheme_field = {
-    "q": "input_activations",
-    "k": "input_activations",
-    "v": "input_activations",
-    "input": "input_activations",
-    "weight": "weights",
-    "output": "output_activations",
-}
-
-
-class ObserverBase(torch.nn.Module):
-    def __init__(self, module: torch.nn.Module, base_name: str):
+class MockMinMaxObserver(torch.nn.Module):
+    def __init__(self, base_name: str, args: QuantizationArgs, module: torch.nn.Module):
         super().__init__()
         self.parent = ref(module)
         self.base_name = base_name
+        self.args = args

-        self.scheme_field = base_name_to_scheme_field[base_name]
-        self.args: QuantizationArgs = getattr_chain(
-            module, f"quantization_scheme.{self.scheme_field}"
-        )
-
-        # used for moving averages and testing
+        # used for testing
         self.min_vals = None
         self.max_vals = None

-    @abstractmethod
     def get_min_max(self, observed: torch.Tensor):
-        ...
+        min_vals = torch.amin(observed, dim=(0, -1))
+        max_vals = torch.amax(observed, dim=(0, -1))
+
+        return min_vals, max_vals

     def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         observed = flatten_for_quantization(observed, self.base_name, self.args)
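Note for readers of this hunk: the reworked constructor now receives its QuantizationArgs explicitly instead of resolving them from the parent module's quantization_scheme via getattr_chain. A minimal construction sketch under that reading, assuming compressed_tensors.quantization.QuantizationArgs with default field values (the defaults and the Linear module here are illustrative, not taken from this diff):

import torch
from compressed_tensors.quantization import QuantizationArgs

# Hypothetical setup: any torch module can act as the observed parent.
linear = torch.nn.Linear(64, 64)
weight_args = QuantizationArgs()  # assumed defaults; real tests may set strategy/group_size
observer = MockMinMaxObserver("weight", weight_args, linear)  # args passed in, no scheme lookup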
@@ -71,46 +57,12 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

     def get_global_scale(self, observed: torch.Tensor):
         observed = observed.reshape((1, 1, -1))  # per tensor reshape
-
         min_vals, max_vals = self.get_min_max(observed)
-
         global_scale = generate_gparam(min_vals, max_vals)

         return global_scale


-class MockMinMaxObserver(ObserverBase):
-    def __init__(self, module: torch.nn.Module, base_name: str):
-        super().__init__(module, base_name)
-
-    def get_min_max(self, observed: torch.Tensor):
-        min_vals = torch.amin(observed, dim=(0, -1))
-        max_vals = torch.amax(observed, dim=(0, -1))
-
-        return min_vals, max_vals
-
-
-class MockMovingMinMaxObserver(ObserverBase):
-    def __init__(self, module: torch.nn.Module, base_name: str):
-        super().__init__(module, base_name)
-
-        self.averaging_constant = self.args.observer_kwargs.get(
-            "averaging_constant", 0.01
-        )
-
-    def get_min_max(self, observed: torch.Tensor):
-        min_vals = torch.amin(observed, dim=(0, -1))
-        max_vals = torch.amax(observed, dim=(0, -1))
-
-        if self.min_vals is not None:
-            # FUTURE: consider scaling by num observations (first dim)
-            # rather than reducing by first dim
-            min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant)
-            max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant)
-
-        return min_vals, max_vals
-
-
 def flatten_for_quantization(
     value: torch.Tensor, base_name: str, args: QuantizationArgs
 ) -> torch.Tensor:
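For context on the newly inlined get_min_max: it reduces over the observation dimension (0) and the innermost dimension (-1) of the flattened input, leaving one min/max value per remaining group. A standalone sketch of just that reduction, using a made-up (num_observations, num_groups, group_size) layout (the exact shape produced by flatten_for_quantization is not shown in this section):

import torch

# Hypothetical flattened tensor: 4 observations, 8 groups, 16 elements per group.
observed = torch.randn(4, 8, 16)
min_vals = torch.amin(observed, dim=(0, -1))  # shape (8,): one min per group
max_vals = torch.amax(observed, dim=(0, -1))  # shape (8,): one max per group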