diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 95bae1b766a..849718527ed 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -297,6 +297,7 @@ def __init__( dtype=dtype, ) for layer_id in range(config.n_layers) + if cache_lens[layer_id] > 0 } self.v_caches = { StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros( @@ -307,6 +308,7 @@ def __init__( dtype=dtype, ) for layer_id in range(config.n_layers) + if cache_lens[layer_id] > 0 } self.config = config