Commit d691aad

Create KV cache input tensor only if cache len > 0 for that layer
Differential Revision: D84471388
Pull Request resolved: #15042
1 parent 7421646 commit d691aad

1 file changed, +2 -0 lines changed

examples/models/llama/static_attention.py

Lines changed: 2 additions & 0 deletions
@@ -297,6 +297,7 @@ def __init__(
                 dtype=dtype,
             )
             for layer_id in range(config.n_layers)
+            if cache_lens[layer_id] > 0
         }
         self.v_caches = {
             StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
@@ -307,6 +308,7 @@ def __init__(
                 dtype=dtype,
             )
             for layer_id in range(config.n_layers)
+            if cache_lens[layer_id] > 0
         }
 
         self.config = config
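
The effect of the added filter can be sketched in isolation. This is a minimal illustration, not the code in static_attention.py: cache_key is a hypothetical stand-in for StaticKVCache.calculate_cache_key (the real key format may differ), shape tuples stand in for the torch.zeros tensors, and the dimension order is assumed because the diff elides those arguments.

```python
from typing import Dict, List, Tuple


def cache_key(layer_id: int, head_id: int) -> str:
    # Hypothetical stand-in for StaticKVCache.calculate_cache_key;
    # the real key format may differ.
    return f"l{layer_id},h{head_id}"


def build_cache_shapes(
    cache_lens: List[int], n_heads: int, head_dim: int
) -> Dict[str, Tuple[int, int, int, int]]:
    # Same comprehension pattern as the diff: a layer gets an entry
    # only when its cache length is greater than zero.
    # The shape layout below is an assumption made for illustration.
    return {
        cache_key(layer_id, 0): (1, n_heads, cache_lens[layer_id], head_dim)
        for layer_id in range(len(cache_lens))
        if cache_lens[layer_id] > 0
    }


# Layer 1 has a cache length of 0, so it gets no entry at all.
shapes = build_cache_shapes([128, 0, 64], n_heads=8, head_dim=16)
print(sorted(shapes))
```

With this pattern, any code that looks up a cache by layer presumably needs to handle a missing key for layers whose cache length is 0, for example with a membership test before indexing.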