
Commit a014498
Rename live_step_len parameter to unpadded_len for clarity
- Rename live_step_len -> unpadded_len across attention and KV cache modules.
- Update documentation to clarify that unpadded_len specifies the number of
  non-padding tokens per sequence, with actual behavior depending on the KV
  cache implementation.
- Fix a pre-existing pylint error in rattention.py where rla_output was used
  before assignment.
- Update all test files to use the new parameter name.

The new name better reflects the parameter's purpose: indicating the number of
non-padding tokens in each sequence, rather than the ambiguous "live step
length". Implementation behavior varies by KV cache type:

- Standard KVCache: ignores the parameter.
- SlidingWindowKVCache: uses it for sequence masking.
- PagedKVCache: ignores the parameter.

GitOrigin-RevId: 5b0d848cc3d0ce3cbf736fa30e7953b8de4d1997
1 parent 10a9ec0 commit a014498

File tree: 9 files changed (+64, -56 lines)

axlearn/common/attention.py

Lines changed: 14 additions & 12 deletions
@@ -46,14 +46,16 @@
 
 TODO(apghml) Convert everything to take an instance of BaseAttentionBias rather than a Tensor.
 
-On `live_step_len`:
-* An int tensor of shape [batch], indicating the valid step length in the given inputs.
-* We assume that live steps must be contiguous at the beginning. So once
-  `live_step_len < max_step_len` for a sequence, the remaining `max_step_len - live_step_len`
-  part is considered padding.
-* During prefill, `time_step == live_step_len`.
+On `unpadded_len`:
+* An int tensor of shape [batch], indicating the number of non-padding tokens in each sequence.
+* Non-padding tokens are assumed to be contiguous at the beginning of each sequence.
+  For a sequence with `unpadded_len[i] < sequence_length`, tokens at positions
+  `unpadded_len[i]:` are considered padding and should be ignored.
+* During prefill, `time_step == unpadded_len` since we process exactly the non-padding tokens.
+* This parameter enables optimizations in some KV cache implementations by avoiding
+  computation on padding tokens.
 
-TODO (dhwang2): Replace `time_step` argument with `live_step_len` to reduce cognitive complexity.
+TODO (dhwang2): Replace `time_step` argument with `unpadded_len` to reduce cognitive complexity.
 
 On `segment_ids`:
 * A tensor of shape [batch, target_length] with values in [0, num_segments].

@@ -1669,7 +1671,7 @@ def _forward_for_mode(
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
         kv_state: Optional[KVState] = None,
-        live_step_len: Optional[Tensor] = None,
+        unpadded_len: Optional[Tensor] = None,
         attention_logit_biases: Union[None, Tensor, BaseAttentionBias] = None,
         segment_ids: Optional[Tensor] = None,
         query_positions: Optional[Tensor] = None,

@@ -1688,7 +1690,7 @@ def _forward_for_mode(
             key: An optional Tensor of shape [batch, source_length, source_dim].
             value: An optional Tensor of shape [batch, source_length, source_dim].
             kv_state: An optional KVState. If specified, both `key` and `value` should be None.
-            live_step_len: An optional Tensor of shape [batch]. Please refer to ``On live_step_len``
+            unpadded_len: An optional Tensor of shape [batch]. Please refer to ``On unpadded_len``
                 in the file docstring for details.
             attention_logit_biases: See ``On attention logit biases`` in the file comments.
             segment_ids: See ``On segment_ids`` in the file comments.

@@ -1768,7 +1770,7 @@ def _forward_for_mode(
            )
        elif mode in (ForwardMode.EXTEND_STEP, ForwardMode.INIT_STATES):
            assert cached_states is not None
-           step_len = live_step_len if live_step_len is not None else q_proj.shape[1]
+           step_len = unpadded_len if unpadded_len is not None else q_proj.shape[1]
            new_cached_states = dict(time_step=time_step + step_len)
            if not has_external_kv_state:
                # In prefill, init_states already called self.kv_cache.init_states.

@@ -1778,7 +1780,7 @@ def _forward_for_mode(
                    k_proj=k_proj,
                    v_proj=v_proj,
                    key_positions=query_positions,
-                   live_step_len=live_step_len,
+                   unpadded_len=unpadded_len,
                    page_pool=page_pool,
                )
            if mode == ForwardMode.EXTEND_STEP:

@@ -2057,7 +2059,7 @@ def init_states(
            query=query,
            key=key,
            value=value,
-           live_step_len=time_step,
+           unpadded_len=time_step,
            cached_states=init_states,
            kv_state=kv_state,
            attention_logit_biases=attention_logit_biases,
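
The `unpadded_len` contract above ("non-padding tokens are contiguous at the beginning") implies a simple mask construction. A minimal sketch, assuming JAX; the `padding_mask` helper is hypothetical and not part of this diff:

    import jax.numpy as jnp

    def padding_mask(unpadded_len: jnp.ndarray, seq_len: int) -> jnp.ndarray:
        """Returns a [batch, seq_len] mask; position j of sequence i is valid iff j < unpadded_len[i]."""
        return jnp.arange(seq_len)[None, :] < unpadded_len[:, None]

    mask = padding_mask(jnp.array([3, 5]), seq_len=6)
    print(mask.astype(jnp.int32))
    # [[1 1 1 0 0 0]
    #  [1 1 1 1 1 0]]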

axlearn/common/kv_cache/base_kv_cache.py

Lines changed: 5 additions & 3 deletions
@@ -83,7 +83,7 @@ def extend_step(
         k_proj: Tensor,
         v_proj: Tensor,
         key_positions: Tensor,
-        live_step_len: Optional[Tensor] = None,
+        unpadded_len: Optional[Tensor] = None,
         page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Nested[Tensor], Output]:
         """Updates the KV cache per extend step.

@@ -97,8 +97,10 @@ def extend_step(
             k_proj: A Tensor of shape [batch, step_length, num_kv_heads, per_head_dim].
             v_proj: A Tensor of shape [batch, step_length, num_kv_heads, per_head_dim].
             key_positions: An optional Tensor of shape [1|batch, step_length].
-            live_step_len: An optional Tensor of shape [batch]. See file-level docstring of
-                `attention.py`
+            unpadded_len: An optional Tensor of shape [batch]. Specifies the number of
+                non-padding tokens per sequence. When provided, only the first `unpadded_len[i]`
+                tokens of sequence `i` are considered valid for caching. The actual behavior
+                depends on the specific KV cache implementation.
             page_pool: See file-level docstring of `attention.py`.

         Returns:

axlearn/common/kv_cache/kv_cache.py

Lines changed: 6 additions & 6 deletions
@@ -37,15 +37,15 @@ def extend_step(
         k_proj: Tensor,
         v_proj: Tensor,
         key_positions: Tensor,
-        live_step_len: Optional[Tensor] = None,
+        unpadded_len: Optional[Tensor] = None,
         page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Nested[Tensor], BaseKVCache.Output]:
-        # TODO(dhwang2): By returning only the valid portions of the KV (by live_step_len),
-        # the attention complexity can be reduced from O(max_len²) to O(live_step_len²), especially
+        # TODO(dhwang2): By returning only the valid portions of the KV (by unpadded_len),
+        # the attention complexity can be reduced from O(max_len²) to O(unpadded_len²), especially
         # in prefill.
-        # The remaining part after `live_step_len` is considered padding.
+        # The remaining part after `unpadded_len` is considered padding.
         assert page_pool is None
-        del live_step_len
+        del unpadded_len
         if k_proj.shape != v_proj.shape:
             raise ValueError(f"{k_proj.shape=} != {v_proj.shape=}")
         if k_proj.shape[1] != key_positions.shape[1]:

@@ -101,7 +101,7 @@ def update_single(cached_kv_slice, kv_proj_slice, time_idx):
         # [B, S, N, H]
         k_proj = jnp.einsum("bnhs->bsnh", cached_key)
         v_proj = jnp.einsum("bnhs->bsnh", cached_value)
-        # Currently, the part larger than live_step_len is also being overwritten in the KV cache,
+        # Currently, the part larger than unpadded_len is also being overwritten in the KV cache,
         # and this part is filtered out by the causal mask through key_positions.
         key_positions = jnp.arange(k_proj.shape[1])[None]  # [1, source_length]
         return updated_state, self.Output(k_proj=k_proj, v_proj=v_proj, key_positions=key_positions)
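
Why the standard KVCache can safely ignore `unpadded_len`: as the comment above notes, stale or padding entries are filtered by the causal mask through `key_positions`. A toy illustration (a sketch, not code from this diff):

    import jax.numpy as jnp

    source_length = 8
    key_positions = jnp.arange(source_length)[None]  # [1, source_length], as returned above
    time_step = jnp.array([[3]])                     # current query position
    visible = key_positions <= time_step             # causal visibility per key position
    print(visible.astype(jnp.int32))                 # [[1 1 1 1 0 0 0 0]]
    # Positions beyond the true sequence length are never attended to.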

axlearn/common/kv_cache/kv_cache_test.py

Lines changed: 8 additions & 8 deletions
@@ -17,9 +17,9 @@ class KVCacheTest(TestCase):
         cached_kv_length=[8],
         time_step_value=[2, 4],
         cache_dtype=[None, jnp.bfloat16],
-        live_step_len=[-1, 2, 4],
+        unpadded_len=[-1, 2, 4],
     )
-    def test_kv_cache(self, cached_kv_length, time_step_value, cache_dtype, live_step_len):
+    def test_kv_cache(self, cached_kv_length, time_step_value, cache_dtype, unpadded_len):
         test_layer = (
             KVCache.default_config()
             .set(name="ref", cache_dtype=cache_dtype)

@@ -33,12 +33,12 @@ def test_kv_cache(self, cached_kv_length, time_step_value, cache_dtype, live_ste
         k_proj = jax.random.normal(prng_key, shape=step_shape)
         v_proj = jax.random.normal(prng_key, shape=step_shape)
         key_positions = jnp.arange(step_len)[None] + time_step_value
-        if live_step_len < 0:
+        if unpadded_len < 0:
             valid_step_len = step_len
-            live_step_len = None
+            unpadded_len = None
         else:
-            valid_step_len = live_step_len
-            live_step_len = jnp.full([batch], fill_value=live_step_len, dtype=jnp.int32)
+            valid_step_len = unpadded_len
+            unpadded_len = jnp.full([batch], fill_value=unpadded_len, dtype=jnp.int32)

         kv_shape = KVCache.Shape(batch, cached_kv_length, heads, dim)
         test_states = test_layer.init_states(kv_shape, dtype=k_proj.dtype)

@@ -49,7 +49,7 @@ def test_kv_cache(self, cached_kv_length, time_step_value, cache_dtype, live_ste
             k_proj=k_proj,
             v_proj=v_proj,
             key_positions=key_positions,
-            live_step_len=live_step_len,
+            unpadded_len=unpadded_len,
         )

@@ -65,7 +65,7 @@ def check(input_kv, output_kv):
         check(v_proj, test_output.v_proj)
         key_positions = jnp.arange(cached_kv_length)[None]
         assert_allclose(test_output.key_positions, key_positions)
-        # Currently, the part larger than live_step_len is also being overwritten in the KV cache.
+        # Currently, the part larger than unpadded_len is also being overwritten in the KV cache.
         # TODO(dhwang2): remove this check when KVCache updates only valid part.
         assert_allclose(
             test_output.k_proj[:, time_step_value : time_step_value + step_len],

axlearn/common/kv_cache/paged_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -132,7 +132,7 @@ def extend_step(
         k_proj: Tensor,
         v_proj: Tensor,
         key_positions: Tensor,
-        live_step_len: Optional[Tensor] = None,
+        unpadded_len: Optional[Tensor] = None,
         page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Nested[Tensor], KVCache.Output]:
         """Extend the cache with the new key and value.

@@ -159,7 +159,7 @@ def extend_step(
             k_pages = k_pages.at[k, actual_page_idx, page_offset].set(k_proj[i, j, k, :])
             v_pages = v_pages.at[k, actual_page_idx, page_offset].set(v_proj[i, j, k, :])
         """
-        del live_step_len
+        del unpadded_len

         if k_proj.shape != v_proj.shape:
             raise ValueError(f"{k_proj.shape=} != {v_proj.shape=}")
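
For orientation, the docstring pseudocode above scatters each token's KV row into a page pool at a (page, offset) slot derived from its position. A rough single-head sketch under assumed shapes (all names here are illustrative, not the module's API):

    import jax.numpy as jnp

    page_size, num_pages, dim = 4, 3, 2
    k_pages = jnp.zeros((num_pages, page_size, dim))       # hypothetical single-head page pool
    page_indices = jnp.array([2, 0, 1])                    # logical page -> physical page

    position = 5                                           # token position in the sequence
    logical_page, page_offset = divmod(position, page_size)  # -> (1, 1)
    actual_page_idx = page_indices[logical_page]
    k_pages = k_pages.at[actual_page_idx, page_offset].set(jnp.ones((dim,)))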

axlearn/common/kv_cache/paged_kv_cache_test.py

Lines changed: 6 additions & 6 deletions
@@ -103,8 +103,8 @@ def test_paged_kv_cache(
         v_proj = jax.random.normal(prng_key, shape=step_shape, dtype=cache_dtype)
         key_positions = jnp.full((batch, 1), time_step_value, dtype=jnp.int32)

-        # TODO(xiyou): consider live_step_len when it's supported
-        live_step_len = None
+        # TODO(xiyou): consider unpadded_len when it's supported
+        unpadded_len = None

         kv_shape = KVCache.Shape(batch, max_pages_each_request * page_size, heads, dim)
         ref_states = ref_layer.init_states(kv_shape, dtype=k_proj.dtype)

@@ -131,22 +131,22 @@ def test_paged_kv_cache(
         @partial(jax.jit, static_argnums=(0,))
         def jit_extend_step(
-            layer: KVCache, test_states, k_proj, v_proj, key_positions, live_step_len
+            layer: KVCache, test_states, k_proj, v_proj, key_positions, unpadded_len
         ):
             _, test_output = layer.extend_step(
                 test_states,
                 k_proj=k_proj,
                 v_proj=v_proj,
                 key_positions=key_positions,
-                live_step_len=live_step_len,
+                unpadded_len=unpadded_len,
             )
             return test_output

         ref_out: KVState = jit_extend_step(
-            ref_layer, ref_states, k_proj, v_proj, key_positions, live_step_len
+            ref_layer, ref_states, k_proj, v_proj, key_positions, unpadded_len
        )
         test_out: KVState = jit_extend_step(
-            test_layer, test_states, k_proj, v_proj, key_positions, live_step_len
+            test_layer, test_states, k_proj, v_proj, key_positions, unpadded_len
         )

         test_k_proj = reconstruct_kv(page_indices, test_out.k_proj)
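
A side note on the `@partial(jax.jit, static_argnums=(0,))` pattern in the test above: the layer object is not a JAX array, so it is passed as a compile-time constant rather than a traced argument. A minimal standalone sketch of the same idiom:

    from functools import partial
    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnums=(0,))
    def scale_fn(factor: int, x):  # `factor` is hashable and static; `x` is traced
        return x * factor

    print(scale_fn(3, jnp.arange(4)))  # [0 3 6 9]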

axlearn/common/kv_cache/sliding_window_kv_cache.py

Lines changed: 9 additions & 7 deletions
@@ -52,7 +52,7 @@ def extend_step(
         k_proj: Tensor,
         v_proj: Tensor,
         key_positions: Tensor,
-        live_step_len: Optional[Tensor] = None,
+        unpadded_len: Optional[Tensor] = None,
         page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Nested[Tensor], BaseKVCache.Output]:
         """Updates the sliding window KV cache per extend step.

@@ -62,8 +62,10 @@ def extend_step(
             k_proj: A Tensor of shape [batch, step_length, num_kv_heads, per_head_dim].
             v_proj: A Tensor of shape [batch, step_length, num_kv_heads, per_head_dim].
             key_positions: An optional Tensor of shape [1|batch, step_length].
-            live_step_len: An optional Tensor of shape [batch]. Please refer to ``On live_step_len``
-                in the file docstring for details.
+            unpadded_len: An optional Tensor of shape [batch]. Specifies the number of
+                non-padding tokens per sequence. When provided, only the first `unpadded_len[i]`
+                tokens of sequence `i` are considered valid and will be cached. Padding tokens
+                are masked out and marked as invalid positions.

         Returns:
             A tuple (updated_state, output):

@@ -81,10 +83,10 @@ def extend_step(

         # [1|batch, step_length] -> [batch, step_length]
         key_positions = jnp.broadcast_to(key_positions, (batch, step_len))
-        if live_step_len is not None:
-            if live_step_len.shape[0] != batch:
-                raise ValueError(f"{live_step_len.shape=} must be [{batch}].")
-            steps = live_step_len
+        if unpadded_len is not None:
+            if unpadded_len.shape[0] != batch:
+                raise ValueError(f"{unpadded_len.shape=} must be [{batch}].")
+            steps = unpadded_len
         seq_mask = sequence_mask(lengths=steps, max_len=step_len, dtype=key_positions.dtype)
         # update_single rolls key_positions, so mark invalid positions.
         key_positions = jnp.where(seq_mask, key_positions, self._invaild_position())
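
The masking step above is the one place where `unpadded_len` changes behavior. A self-contained sketch with a local stand-in for axlearn's `sequence_mask` helper and an assumed invalid-position sentinel of -1 (the real module uses `self._invaild_position()`):

    import jax.numpy as jnp

    def sequence_mask(*, lengths, max_len, dtype=jnp.int32):
        # 1 for positions < length, 0 otherwise (assumed behavior of the helper).
        return (jnp.arange(max_len)[None, :] < lengths[:, None]).astype(dtype)

    batch, step_len = 2, 4
    key_positions = jnp.broadcast_to(jnp.arange(step_len)[None], (batch, step_len))
    unpadded_len = jnp.array([2, 4])
    seq_mask = sequence_mask(lengths=unpadded_len, max_len=step_len, dtype=jnp.bool_)
    INVALID = -1  # stand-in for self._invaild_position()
    print(jnp.where(seq_mask, key_positions, INVALID))
    # [[ 0  1 -1 -1]
    #  [ 0  1  2  3]]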

axlearn/common/kv_cache/sliding_window_kv_cache_test.py

Lines changed: 6 additions & 6 deletions
@@ -13,8 +13,8 @@
 class SlidingWindowKVCacheTest(TestCase):
     """Tests SlidingWindowKVCache."""

-    @parameterized.product(cached_kv_length=[8], time_step_value=[2, 4, 6], live_step_len=[None, 2])
-    def test_sliding_window_kv_cache(self, cached_kv_length, time_step_value, live_step_len):
+    @parameterized.product(cached_kv_length=[8], time_step_value=[2, 4, 6], unpadded_len=[None, 2])
+    def test_sliding_window_kv_cache(self, cached_kv_length, time_step_value, unpadded_len):
         test_layer = (
             SlidingWindowKVCache.default_config()
             .set(name="ref", cached_kv_length=cached_kv_length)

@@ -29,16 +29,16 @@ def test_sliding_window_kv_cache(self, cached_kv_length, time_step_value, live_s
         k_proj = jax.random.normal(prng_key, shape=step_shape)
         v_proj = jax.random.normal(prng_key, shape=step_shape)
         key_positions = jnp.arange(step_len)[None] + time_step_value
-        valid_out_len = live_step_len or step_len
-        live_step_len = (
-            jnp.full([batch], fill_value=live_step_len) if live_step_len is not None else None
+        valid_out_len = unpadded_len or step_len
+        unpadded_len = (
+            jnp.full([batch], fill_value=unpadded_len) if unpadded_len is not None else None
         )
         _, test_output = test_layer.extend_step(
             test_states,
             k_proj=k_proj,
             v_proj=v_proj,
             key_positions=key_positions,
-            live_step_len=live_step_len,
+            unpadded_len=unpadded_len,
         )
         kv_shape = (2, cached_kv_length + step_len, 2, 2)
         self.assertEqual(test_output.key_positions.shape, kv_shape[:2])

axlearn/common/rattention/rattention.py

Lines changed: 8 additions & 6 deletions
@@ -600,7 +600,7 @@ def _forward_for_mode(
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
         kv_state: Optional[KVState] = None,
-        live_step_len: Optional[int] = None,
+        unpadded_len: Optional[int] = None,
         attention_logit_biases: Union[None, Tensor, BaseAttentionBias] = None,
         segment_ids: Optional[Tensor] = None,
         query_positions: Optional[Tensor] = None,

@@ -616,7 +616,7 @@ def _forward_for_mode(
             k_proj/v_proj as kv_state to the output.

         Notes on intermediate variables:
-        * live_step_len vs time_step: time_step denotes the starting point where live_step_len
+        * unpadded_len vs time_step: time_step denotes the starting point where unpadded_len
           denotes the length of progression.
         * k_proj/v_proj vs full_k_proj/full_v_proj: the former could be single token during
           extend_step whereas the latter always means the kv for the whole sequence. Residual_la

@@ -651,6 +651,8 @@ def _forward_for_mode(

             if cfg.residual_la is not None:
                 rla_output = self.residual_la(query, i_proj_output)
+            else:
+                rla_output = None
             new_cached_states = {}
         else:
             if kv_state is None:

@@ -661,7 +663,7 @@ def _forward_for_mode(
                 k_proj=k_proj,
                 v_proj=v_proj,
                 key_positions=query_positions,
-                live_step_len=live_step_len,
+                unpadded_len=unpadded_len,
                 page_pool=page_pool,
             )
             if mode == ForwardMode.EXTEND_STEP:

@@ -690,14 +692,14 @@ def _forward_for_mode(
             else:
                 if mode == ForwardMode.INIT_STATES:
                     rla_state, rla_output = self.residual_la.init_states(
-                        query, (q_proj, full_k_proj, full_v_proj), live_step_len
+                        query, (q_proj, full_k_proj, full_v_proj), unpadded_len
                     )
                 else:
                     rla_state, rla_output = self.residual_la.extend_step(
                         cached_states["rla_state"], query, (q_proj, full_k_proj, full_v_proj)
                     )

-            step_len = live_step_len if live_step_len is not None else query.shape[1]
+            step_len = unpadded_len if unpadded_len is not None else query.shape[1]
             new_time_step = time_step + step_len
             new_cached_states = dict(
                 swa_state=swa_state, rla_state=rla_state, time_step=new_time_step

@@ -800,7 +802,7 @@ def init_states(
             query=query,
             key=key,
             value=value,
-            live_step_len=time_step,
+            unpadded_len=time_step,
             cached_states=init_states,
             kv_state=kv_state,
             attention_logit_biases=attention_logit_biases,
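
The "time_step denotes the starting point, unpadded_len denotes the length of progression" note reduces to simple bookkeeping. A toy illustration under assumed values (not code from this diff):

    import jax.numpy as jnp

    time_step = jnp.array([0, 0])     # starting point per sequence (prefill begins at 0)
    unpadded_len = jnp.array([5, 3])  # non-padding tokens processed this call
    new_time_step = time_step + unpadded_len
    print(new_time_step)              # [5 3]; decoding resumes at these positions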
