|
| 1 | +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""E2E test: TinyLlama in decode mode with a DynamicCache. |
| 16 | +
|
| 17 | +Scenario |
| 18 | +-------- |
| 19 | +Simulates the token-generation (decode) step where a previously-computed |
| 20 | +key/value cache is fed back into the model alongside a single new token. |
| 21 | +
|
| 22 | +register_dynamic_cache() selects the correct pytree flatten strategy |
| 23 | +automatically based on the installed transformers version: |
| 24 | +
|
| 25 | +* transformers with DynamicLayer (newer): Layer-based layout (cache.layers) |
| 26 | +* transformers without DynamicLayer (e.g. 4.52.x): legacy layout |
| 27 | + (cache.key_cache / cache.value_cache) |
| 28 | +
|
| 29 | +register_dynamic_layer() is also called so that if the Layer-based layout is |
| 30 | +in use, DynamicLayer objects inside the cache are also pytree-traversable. |
| 31 | +It is a safe no-op when DynamicLayer does not exist in the installed |
| 32 | +transformers version. |
| 33 | +""" |
| 34 | + |
| 35 | +import torch |
| 36 | + |
| 37 | +from tico.utils.pytree_utils import register_dynamic_cache, register_dynamic_layer |
| 38 | +from transformers import AutoModelForCausalLM |
| 39 | +from transformers.cache_utils import DynamicCache |
| 40 | + |
| 41 | +from test.modules.base import TestModuleBase |
| 42 | + |
# Number of previously-processed tokens to pre-fill into the cache.
# The decode-step inputs below (attention_mask length, position_ids) are
# derived from this value, so it is the single knob for "past" length.
_PAST_SEQ_LEN = 5

# Allowlist DynamicCache for torch.load(weights_only=True) to suppress:
# _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
# (1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
# (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
# WeightsUnpickler error: Unsupported global: GLOBAL transformers.cache_utils.DynamicCache was not an allowed global by default. Please use `torch.serialization.add_safe_globals([transformers.cache_utils.DynamicCache])` or the `torch.serialization.safe_globals([transformers.cache_utils.DynamicCache])` context manager to allowlist this global if you trust this class/function.
torch.serialization.add_safe_globals([DynamicCache])
| 52 | + |
| 53 | + |
class TinyLlamaWithDynamicCache(TestModuleBase):
    """TinyLlama decode step with a pre-populated DynamicCache.

    Loads the tiny TinyLlama checkpoint on CPU and registers the pytree
    handlers needed to flatten a ``DynamicCache`` (and, on newer
    transformers versions, its ``DynamicLayer`` entries) for export.
    """

    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to(
            "cpu"
        )
        self.cfg = self.model.config
        # Tolerances for output comparison in the E2E harness.
        self.rtol = 1e-4
        self.atol = 1e-4

        # register_dynamic_cache picks the right flatten strategy for the
        # installed transformers version automatically.
        # register_dynamic_layer is a no-op when DynamicLayer doesn't exist.
        register_dynamic_cache()
        register_dynamic_layer()

    def forward(self, *args, **kwargs):
        # Pure pass-through to the wrapped causal-LM model.
        return self.model(*args, **kwargs)

    def get_example_inputs(self):
        """Build (args, kwargs) for one decode step with a pre-filled cache."""
        config = self.cfg
        layer_count = config.num_hidden_layers
        # Fall back gracefully for configs without grouped-query attention
        # or an explicit per-head dimension.
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        dim_per_head = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )

        # Decode step: exactly one fresh token enters the model.
        input_ids = torch.tensor([[869]], dtype=torch.long)  # token id for '▁.'
        attention_mask = torch.ones(1, _PAST_SEQ_LEN + 1, dtype=torch.long)
        position_ids = torch.tensor([[_PAST_SEQ_LEN]], dtype=torch.long)

        # Pre-fill every layer of the cache with random past KV tensors.
        kv_shape = (1, kv_heads, _PAST_SEQ_LEN, dim_per_head)
        past_key_values = DynamicCache()
        for idx in range(layer_count):
            past_key_values.update(torch.randn(kv_shape), torch.randn(kv_shape), idx)

        return (input_ids, attention_mask, position_ids, past_key_values), {}
0 commit comments