@@ -52,6 +52,8 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
+        self.return_float_values = True
+        self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
         self.register_buffer(
@@ -61,17 +63,17 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
-                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )
             self.register_buffer(
-                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )

     def _quantize(self, value):
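These dtype changes shrink the quantization metadata: scales drop from float64 to float32 and zero points from int64 to int8, while the cache itself stays int8. A rough footprint sketch, with hypothetical llama-style shapes that are not taken from this PR:

```python
import torch

# Hypothetical sizes, for illustration only.
max_batch_size, max_context_length, n_heads, head_dim = 1, 2048, 32, 128

cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_context_length, n_heads, 1)

k_cache = torch.zeros(cache_shape, dtype=torch.int8)
scales_f32 = torch.ones(scale_shape, dtype=torch.float32)
scales_f64 = torch.ones(scale_shape, dtype=torch.float64)

print(k_cache.numel() * k_cache.element_size())        # 8388608 bytes (8 MiB)
print(scales_f32.numel() * scales_f32.element_size())  # 262144 bytes (256 KiB)
print(scales_f64.numel() * scales_f64.element_size())  # 524288 bytes (512 KiB)
```

The int8 cache dominates either way; the change halves the per-token scale storage and cuts zero-point storage by 8x.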
@@ -91,20 +93,15 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points

-    def update(self, input_pos, k_val, v_val):
-        """
-        k_val, v_val: [B, H, S, D]
-        return: [B, H, S, D]
-        However the storage is [B, S, H, D] so we incur transpose in, transpose out
-        This shall be removed by subsequent post-export graph pass
-        """
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
-        # quantize current k_val and store it in the cache
+    def _quantize_and_update(self, input_pos, k_val, v_val):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
-
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

+        k_scales = k_scales.to(torch.float32)
+        k_zero_points = k_zero_points.to(self.quantized_cache_dtype)
+        v_scales = v_scales.to(torch.float32)
+        v_zero_points = v_zero_points.to(self.quantized_cache_dtype)
+
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
             _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
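`_quantize` (elided above) produces one scale and zero point per token. For orientation, a minimal self-contained sketch of affine asymmetric per-token int8 quantization and its inverse, assuming the standard mapping; the real code calls the `quantized_decomposed` ops rather than this hand-rolled version:

```python
import torch

def quantize_per_token(x: torch.Tensor, eps: float = 1e-9):
    # One (scale, zero_point) pair per token, reducing over the last dim only.
    qmin, qmax = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max
    x_min = x.amin(dim=-1, keepdim=True).clamp(max=0.0)  # keep 0 representable
    x_max = x.amax(dim=-1, keepdim=True).clamp(min=0.0)
    scales = (x_max - x_min).clamp(min=eps) / float(qmax - qmin)
    zero_points = (qmin - x_min / scales).round()
    q = (x / scales + zero_points).round().clamp(qmin, qmax).to(torch.int8)
    return q, scales, zero_points

def dequantize_per_token(q, scales, zero_points):
    # Inverse affine map back to float.
    return (q.to(torch.float32) - zero_points) * scales

x = torch.randn(2, 16)
q, s, zp = quantize_per_token(x)
assert (dequantize_per_token(q, s, zp) - x).abs().max() <= s.max()
```

The `.to(torch.float32)` / `.to(self.quantized_cache_dtype)` casts in the hunk above then fold the op outputs (float64 scales, int64 zero points) into the smaller buffer dtypes registered in `__init__`.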
@@ -125,25 +122,30 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points

+    def _update_and_return_float_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            self.k_cache_scales.to(torch.float64),
+            self.k_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            self.v_cache_scales.to(torch.float64),
+            self.v_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )

+        # When returning float values we just use the latest float k_val/v_val
+        # instead of the dequantized values.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
             _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
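The new comment captures a subtle choice: after dequantizing the whole cache, the slots at the current `input_pos` are overwritten with the original float `k_val`/`v_val`, so only past positions carry quantization error. A toy illustration with made-up shapes:

```python
import torch

B, S, H, D = 1, 8, 1, 4                  # storage layout is [B, S, H, D]
k_out = torch.zeros(B, S, H, D)          # stand-in for the dequantized cache
k_val = torch.randn(B, 2, H, D)          # fresh float keys for two new tokens
input_pos = torch.tensor([3, 4])

k_out[:, input_pos] = k_val              # newest positions bypass quantization
assert torch.equal(k_out[:, 3:5], k_val)
```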
@@ -152,6 +154,29 @@ def update(self, input_pos, k_val, v_val):
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val

+        return k_out, v_out
+
+    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
+        return self.k_cache, self.v_cache
+
+    def update(self, input_pos, k_val, v_val):
+ """
166
+ k_val, v_val: [B, H, S, D]
167
+ return: [B, H, S, D]
168
+ However the storage is [B, S, H, D] so we incur transpose in, transpose out
169
+ This shall be removed by subsequent post-export graph pass
170
+ """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
+
+        if self.return_float_values:
+            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+        else:
+            k_out, v_out = self._update_and_return_quantized_values(
+                input_pos, k_val, v_val
+            )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod
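For reference, the layout contract from the `update` docstring as a standalone check (hypothetical sizes): callers pass `[B, H, S, D]`, storage is `[B, S, H, D]`, and the two transposes round-trip exactly:

```python
import torch

B, H, S, D = 1, 8, 4, 16
k_val = torch.randn(B, H, S, D)          # attention layout
stored = k_val.transpose(1, 2)           # [B, S, H, D] storage layout
assert torch.equal(stored.transpose(1, 2), k_val)
```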