Skip to content

Commit 23e4c61

Browse files
committed
[quantization] Output kv-tuples
This PR outputs kv-tuples when `use_cache` is set. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent fcacf65 commit 23e4c61

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,12 @@ def main():
334334
type=str,
335335
default=None,
336336
)
337+
parser.add_argument(
338+
"--use_cache",
339+
action="store_true",
340+
default=False,
341+
help="Whether to use cache",
342+
)
337343
args = parser.parse_args()
338344
print(args)
339345

@@ -370,7 +376,7 @@ def main():
370376
.eval()
371377
)
372378

373-
model.config.use_cache = False # TODO use args for it
379+
model.config.use_cache = args.use_cache
374380
if args.calibrate_seq_len is not None:
375381
model.config.max_position_embeddings = min(
376382
model.config.max_position_embeddings, args.calibrate_seq_len
@@ -420,6 +426,8 @@ def main():
420426
if not args.no_GPTQ:
421427
print("Applying GPTQ …")
422428

429+
old_use_cache = model.config.use_cache
430+
model.config.use_cache = False # to save memory
423431
sens = None
424432
if args.gptq_mse is not None and args.gptq_mse == "smse":
425433
if args.sensitivity_path is not None:
@@ -440,6 +448,7 @@ def main():
440448
q_m(inp.to(args.device))
441449

442450
q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
451+
model.config.use_cache = old_use_cache
443452
else:
444453
q_m = model
445454

tico/quantization/wrapq/wrappers/llama/quant_model.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,15 @@ def forward(
185185
inputs_embeds = self.embed_tokens(input_ids)
186186

187187
if use_cache and past_key_values is None:
188-
past_key_values = DynamicCache()
188+
past_key_values = []
189189

190190
if cache_position is None:
191191
past_seen_tokens = (
192-
past_key_values.get_seq_length() if past_key_values is not None else 0
192+
0
193+
if (past_key_values is None or len(past_key_values) == 0)
194+
else past_key_values[0][0].shape[-2]
193195
)
196+
194197
cache_position = torch.arange(
195198
past_seen_tokens,
196199
past_seen_tokens + inputs_embeds.shape[1],
@@ -217,15 +220,21 @@ def forward(
217220
all_hidden_states = () if output_hidden_states else None
218221
all_self_attns = () if output_attentions else None
219222

220-
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
223+
for idx, decoder_layer in enumerate(
224+
self.layers[: self.config.num_hidden_layers]
225+
):
221226
if output_hidden_states:
222227
all_hidden_states += (hidden_states,) # type: ignore[operator]
223228

224229
layer_outputs = decoder_layer(
225230
hidden_states,
226231
attention_mask=causal_mask,
227232
position_ids=position_ids,
228-
past_key_value=past_key_values,
233+
past_key_value=(
234+
past_key_values[idx]
235+
if past_key_values is not None and len(past_key_values) > idx
236+
else None
237+
),
229238
output_attentions=output_attentions,
230239
use_cache=use_cache,
231240
cache_position=cache_position,
@@ -235,6 +244,15 @@ def forward(
235244

236245
if decoder_layer.wrapped.return_type == "tuple":
237246
hidden_states = layer_outputs[0]
247+
elif use_cache:
248+
hidden_states = layer_outputs[0]
249+
assert isinstance(layer_outputs[1], tuple)
250+
if len(past_key_values) >= idx: # type: ignore[arg-type]
251+
# prefill mode
252+
past_key_values += (layer_outputs[1],) # type: ignore[operator]
253+
else:
254+
# decode mode
255+
past_key_values[idx] = (layer_outputs[1],) # type: ignore[index]
238256
else:
239257
hidden_states = layer_outputs
240258

0 commit comments

Comments
 (0)