Skip to content

Commit b610f6e

Browse files
authored
[utils] Improve cache utils to support layer-based caches (#545)
* [utils] Improve cache utils to support layer-based caches. Let's improve its coverage. TICO-DCO-1.0-Signed-off-by: Dayoung Lee <dayoung.lee@samsung.com>
1 parent 7a5381f commit b610f6e

File tree

5 files changed

+1004
-110
lines changed

5 files changed

+1004
-110
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# DO NOT REMOVE THIS FILE
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""E2E test: TinyLlama in decode mode with a DynamicCache.
16+
17+
Scenario
18+
--------
19+
Simulates the token-generation (decode) step where a previously-computed
20+
key/value cache is fed back into the model alongside a single new token.
21+
22+
register_dynamic_cache() selects the correct pytree flatten strategy
23+
automatically based on the installed transformers version:
24+
25+
* transformers with DynamicLayer (newer): Layer-based layout (cache.layers)
26+
* transformers without DynamicLayer (e.g. 4.52.x): legacy layout
27+
(cache.key_cache / cache.value_cache)
28+
29+
register_dynamic_layer() is also called so that if the Layer-based layout is
30+
in use, DynamicLayer objects inside the cache are also pytree-traversable.
31+
It is a safe no-op when DynamicLayer does not exist in the installed
32+
transformers version.
33+
"""
34+
35+
import torch
36+
37+
from tico.utils.pytree_utils import register_dynamic_cache, register_dynamic_layer
38+
from transformers import AutoModelForCausalLM
39+
from transformers.cache_utils import DynamicCache
40+
41+
from test.modules.base import TestModuleBase
42+
43+
# How many already-processed tokens the cache is pre-filled with before the
# single new decode-step token is fed in.
_PAST_SEQ_LEN = 5

# Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, so
# unpickling a checkpoint that embeds a `transformers.cache_utils.DynamicCache`
# raises `_pickle.UnpicklingError` ("Unsupported global: GLOBAL
# transformers.cache_utils.DynamicCache ...") unless the class is explicitly
# allow-listed. Register it once here to suppress that failure/warning for
# this test module.
torch.serialization.add_safe_globals([DynamicCache])
class TinyLlamaWithDynamicCache(TestModuleBase):
    """Decode-step test module: TinyLlama driven by a pre-populated DynamicCache."""

    def __init__(self):
        super().__init__()
        # Tiny checkpoint keeps the E2E test fast; pin everything to CPU.
        self.model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to(
            "cpu"
        )
        self.cfg = self.model.config
        # Comparison tolerances used by the test harness.
        self.rtol = 1e-4
        self.atol = 1e-4

        # register_dynamic_cache picks the pytree flatten strategy matching
        # the installed transformers version; register_dynamic_layer is a
        # safe no-op when DynamicLayer doesn't exist in that version.
        register_dynamic_cache()
        register_dynamic_layer()

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped causal-LM model."""
        return self.model(*args, **kwargs)

    def get_example_inputs(self):
        """Return ``(args, kwargs)`` for one decode step with a filled KV cache.

        ``args`` is ``(input_ids, attention_mask, position_ids, past_key_values)``
        where the cache already holds ``_PAST_SEQ_LEN`` random past tokens per
        layer; ``kwargs`` is empty.
        """
        config = self.cfg
        n_layers = config.num_hidden_layers
        # Fall back to the attention-head count / derived head size for
        # configs that don't declare these attributes explicitly.
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        dim_per_head = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )

        # One fresh token for the decode step (token id for '▁.').
        input_ids = torch.tensor([[869]], dtype=torch.long)
        # Mask covers the cached tokens plus the new one; position picks up
        # right after the cached prefix.
        attention_mask = torch.ones(1, _PAST_SEQ_LEN + 1, dtype=torch.long)
        position_ids = torch.tensor([[_PAST_SEQ_LEN]], dtype=torch.long)

        # Pre-fill a DynamicCache with random past key/value tensors,
        # one (key, value) pair per transformer layer.
        cache = DynamicCache()
        for idx in range(n_layers):
            key = torch.randn(1, kv_heads, _PAST_SEQ_LEN, dim_per_head)
            value = torch.randn(1, kv_heads, _PAST_SEQ_LEN, dim_per_head)
            cache.update(key, value, idx)

        return (input_ids, attention_mask, position_ids, cache), {}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
transformers==4.52.4

0 commit comments

Comments (0)