Skip to content

Commit 192349a

Browse files
committed
[quantization] Output kv-tuples
This PR outputs kv-tuples in case `use_cache` was set. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent fcacf65 commit 192349a

File tree

3 files changed

+97
-5
lines changed

3 files changed

+97
-5
lines changed

test/quantization/wrapq/wrappers/llama/test_quant_model.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,68 @@ def test_forward_diff(self):
109109
self.assertGreater(diff, 0.0)
110110
self.assertLess(diff, 0.4)
111111
self.assertEqual(fp_out.shape, q_out.shape)
112+
113+
114+
@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaModelWithCache(unittest.TestCase):
    """Verify that QuantLlamaModel emits per-layer KV tuples when use_cache=True.

    Builds a tiny float Llama model, wraps it in QuantLlamaModel (prefill
    variant), runs one calibration forward pass, and checks that the output
    is (last_hidden_state, past_key_values) with one (key, value) tuple per
    hidden layer whose sequence dimension matches the input length.
    """

    # Class-level fixtures shared by all tests (populated in setUpClass).
    seq_len: int
    vocab_size: int
    hid_layers: int
    fp_model: torch.nn.Module

    @classmethod
    def setUpClass(cls):
        # Fixed seed so the random token input (and any random init) is
        # reproducible across runs.
        torch.manual_seed(0)

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaModel

        cls.seq_len = 16
        cls.vocab_size = 10000
        cls.hid_layers = 3
        cfg = LlamaConfig(
            # Fix: forward vocab_size to the config; the original left the
            # default (32000), so the embedding table disagreed with the
            # advertised cls.vocab_size and was 3x larger than needed.
            vocab_size=cls.vocab_size,
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=cls.hid_layers,
            max_position_embeddings=cls.seq_len,
            use_cache=True,  # the behavior under test: KV tuples must be returned
            return_dict=False,  # tuple outputs, matching the wrapper's contract
        )
        cls.fp_model = LlamaModel(cfg)

    def test_model_output(self):
        qmodel = QuantLlamaModel(
            self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill")
        )
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        # One batch of random token ids within the configured vocabulary.
        x = torch.randint(0, self.vocab_size, (1, self.seq_len))
        output = qmodel(x)

        # Consistency fix: use unittest assertions (as the rest of this file
        # does) instead of bare asserts, so failures report actual values.
        self.assertEqual(len(output), 2)  # last_hidden_state + past_key_values
        past_key_values = output[1]
        self.assertEqual(len(past_key_values), self.hid_layers)
        for index in range(self.hid_layers):
            past_key_value = past_key_values[index]
            self.assertIsInstance(past_key_value, tuple)

            # Prefill mode: cached keys/values cover the full input sequence.
            past_key = past_key_value[0]
            self.assertEqual(past_key.shape[-2], self.seq_len)

            past_value = past_key_value[1]
            self.assertEqual(past_value.shape[-2], self.seq_len)

tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,12 @@ def main():
334334
type=str,
335335
default=None,
336336
)
337+
parser.add_argument(
338+
"--use_cache",
339+
action="store_true",
340+
default=False,
341+
help="Whether to use cache",
342+
)
337343
args = parser.parse_args()
338344
print(args)
339345

@@ -370,7 +376,7 @@ def main():
370376
.eval()
371377
)
372378

373-
model.config.use_cache = False # TODO use args for it
379+
model.config.use_cache = args.use_cache
374380
if args.calibrate_seq_len is not None:
375381
model.config.max_position_embeddings = min(
376382
model.config.max_position_embeddings, args.calibrate_seq_len
@@ -420,6 +426,8 @@ def main():
420426
if not args.no_GPTQ:
421427
print("Applying GPTQ …")
422428

429+
old_use_cache = model.config.use_cache
430+
model.config.use_cache = False # to save memory
423431
sens = None
424432
if args.gptq_mse is not None and args.gptq_mse == "smse":
425433
if args.sensitivity_path is not None:
@@ -440,6 +448,7 @@ def main():
440448
q_m(inp.to(args.device))
441449

442450
q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
451+
model.config.use_cache = old_use_cache
443452
else:
444453
q_m = model
445454

tico/quantization/wrapq/wrappers/llama/quant_model.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,15 @@ def forward(
185185
inputs_embeds = self.embed_tokens(input_ids)
186186

187187
if use_cache and past_key_values is None:
188-
past_key_values = DynamicCache()
188+
past_key_values = []
189189

190190
if cache_position is None:
191191
past_seen_tokens = (
192-
past_key_values.get_seq_length() if past_key_values is not None else 0
192+
0
193+
if (past_key_values is None or len(past_key_values) == 0)
194+
else past_key_values[0][0].shape[-2]
193195
)
196+
194197
cache_position = torch.arange(
195198
past_seen_tokens,
196199
past_seen_tokens + inputs_embeds.shape[1],
@@ -217,15 +220,21 @@ def forward(
217220
all_hidden_states = () if output_hidden_states else None
218221
all_self_attns = () if output_attentions else None
219222

220-
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
223+
for idx, decoder_layer in enumerate(
224+
self.layers[: self.config.num_hidden_layers]
225+
):
221226
if output_hidden_states:
222227
all_hidden_states += (hidden_states,) # type: ignore[operator]
223228

224229
layer_outputs = decoder_layer(
225230
hidden_states,
226231
attention_mask=causal_mask,
227232
position_ids=position_ids,
228-
past_key_value=past_key_values,
233+
past_key_value=(
234+
past_key_values[idx]
235+
if past_key_values is not None and len(past_key_values) > idx
236+
else None
237+
),
229238
output_attentions=output_attentions,
230239
use_cache=use_cache,
231240
cache_position=cache_position,
@@ -235,6 +244,15 @@ def forward(
235244

236245
if decoder_layer.wrapped.return_type == "tuple":
237246
hidden_states = layer_outputs[0]
247+
elif use_cache:
248+
hidden_states = layer_outputs[0]
249+
assert isinstance(layer_outputs[1], tuple)
250+
if len(past_key_values) >= idx: # type: ignore[arg-type]
251+
# prefill mode
252+
past_key_values += (layer_outputs[1],) # type: ignore[operator]
253+
else:
254+
# decode mode
255+
past_key_values[idx] = (layer_outputs[1],) # type: ignore[index]
238256
else:
239257
hidden_states = layer_outputs
240258

0 commit comments

Comments
 (0)