
Commit f8fd186

Add page_pool argument (#1303)
* Add page_pool argument
* Fix for rattention
* Add missing
1 parent 2cb1b7b commit f8fd186

File tree

9 files changed: +94 −14 lines changed

axlearn/common/attention.py

Lines changed: 35 additions & 2 deletions
@@ -76,6 +76,11 @@
 * When the accompanying argument is `query`, the `positions` argument is named as
 `query_position`. Similarly, when the argument `target`, it is named as `target_positions`.

+On `page_pool`:
+* If not None, stores the external paged KV pool possibly shared by multiple
+layers. Additionally, `cached_states` will not contain KV state. Instead, it will
+contain indices used to index into `page_pool`.
+
 TODO(changlan): Merge the use of `positions` and `time_step` to reduce cognitive complexity.

 """
@@ -317,6 +322,7 @@ def extend_step(
 self_attention_logit_biases: Optional[Tensor] = None,
 cross_attention_data: Optional[Tensor] = None,
 cross_attention_logit_biases: Optional[Tensor] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[NestedTensor, Output]:
 """Computes incremental outputs.

@@ -335,6 +341,7 @@ def extend_step(
 cross_attention_logit_biases: An optional Tensor of shape
 [..., target_step_length, source_length], where `target_step_length` must match
 the shape of `data`.
+page_pool: See file-level comments on `page_pool`.

 Returns:
 (updated_cached_states, output), where:
@@ -1653,6 +1660,7 @@ def _forward_for_mode(
 query_positions: Optional[Tensor] = None,
 cached_states: Optional[NestedTensor] = None,
 return_aux: Optional[set[str]] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Nested[Tensor], Optional[Output]]:
 """Computes attention for the given query, key, value, and attention logit biases.

@@ -1672,6 +1680,7 @@ def _forward_for_mode(
 query_positions: See ``On positions`` in the file comments.
 cached_states: Optional NestedTensor as produced by `init_states`.
 return_aux: See comments on `Output`.
+page_pool: See file-level comments on `page_pool`.

 Returns:
 A tuple (cached_states, output):
@@ -1749,6 +1758,7 @@ def _forward_for_mode(
 v_proj=v_proj,
 key_positions=query_positions,
 live_step_len=live_step_len,
+page_pool=page_pool,
 )
 if mode == ForwardMode.EXTEND_STEP:
 kv_state = KVState(*kv_cache_output)
@@ -2042,6 +2052,7 @@ def extend_step(
 kv_state: Optional[KVState] = None,
 attention_logit_biases: Optional[Tensor] = None,
 return_aux: Optional[set[str]] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[NestedTensor, Output]:
 """Computes the value vector given the query of the current step.
 This function is used by autoregressive decoding.
@@ -2068,6 +2079,7 @@ def extend_step(
 The biases should already include causal masking for decoding, plus other biases
 if necessary.
 return_aux: See comments on `Output`.
+page_pool: See file-level comments on `page_pool`.

 Returns:
 A `NestedTensor` state of key and value pair along with index updated at `time_step`.
@@ -2083,6 +2095,7 @@ def extend_step(
 kv_state=kv_state,
 attention_logit_biases=attention_logit_biases,
 return_aux=return_aux,
+page_pool=page_pool,
 )

 @staticmethod
@@ -2640,6 +2653,7 @@ def _forward_for_mode(
 target_positions: Optional[Tensor] = None,
 cached_states: Optional[NestedTensor] = None,
 return_aux: Optional[set[str]] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Optional[Nested[Tensor]], Optional[Output]]:
 """Computes either self-attention or cross-attention for the given target and source.

@@ -2654,6 +2668,7 @@
 target_positions: See ``On positions`` in the file comments.
 cached_states: Optional NestedTensor as produced by `init_states`.
 return_aux: See comments on `Output`.
+page_pool: See file-level comments on `page_pool`.

 Returns:
 A tuple (cached_states, output):
@@ -2709,6 +2724,7 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
 target,
 **kv_kwargs,
 attention_logit_biases=attention_logit_biases,
+page_pool=page_pool,
 )
 else:
 raise ValueError(f"Unrecognized mode {mode}.")
@@ -2841,6 +2857,7 @@ def extend_step(
 source: Optional[Union[Tensor, KVState]] = None,
 attention_logit_biases: Optional[Tensor] = None,
 return_aux: Optional[set[str]] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Nested[Tensor], Output]:
 """Computes the value vector given the query of the current step.
 This function is used by autoregressive decoding.
@@ -2858,6 +2875,7 @@ def extend_step(
 attention_logit_biases should have already taken care of causal masking for
 decoding, plus other maskings necessary.
 return_aux: See comments on `Output`.
+page_pool: See file-level comments on `page_pool`.

 Returns:
 A `NestedTensor` state of key and value pair along with index updated at `time_step`.
@@ -2874,6 +2892,7 @@ def extend_step(
 cached_states=cached_states,
 attention_logit_biases=attention_logit_biases,
 return_aux=return_aux,
+page_pool=page_pool,
 )


@@ -3196,6 +3215,7 @@ def _forward_for_mode(
 target_positions: Optional[Tensor] = None,
 cached_states: Optional[NestedTensor] = None,
 return_aux: Optional[set[str]] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Optional[NestedTensor], Optional[BaseTransformerLayer.Output]]:
 """Computes transformer layer outputs and self/cross-attention probabilities.

@@ -3212,6 +3232,7 @@ def _forward_for_mode(
 target_positions: See ``positions`` in the file comments.
 cached_states: Optional NestedTensor as produced by `init_states`.
 return_aux: See comments on BaseTransformerLayer.forward.
+page_pool: See file-level comments on `page_pool`.

 Returns:
 A tuple (cached_states, output):
@@ -3273,6 +3294,7 @@ def _forward_for_mode(
 source=self_attention_kv_state,
 attention_logit_biases=self_attention_logit_biases,
 return_aux=self_attention_return_aux,
+page_pool=page_pool,
 )
 else:
 raise ValueError(f"Unrecognized mode {mode}.")
@@ -3979,6 +4001,7 @@ def _forward_for_mode(
 assert value.shape[0] == cfg.num_layers, f"{path}={shapes(value)}"

 def layer_fn(carry, x_i):
+x_i, page_pool = x_i
 if mode == ForwardMode.FORWARD:
 layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs)
 elif mode == ForwardMode.INIT_STATES:
@@ -3990,7 +4013,10 @@ def layer_fn(carry, x_i):
 elif mode == ForwardMode.EXTEND_STEP:
 assert x_i is not None
 layer_states, layer_outputs = self.layer.extend_step(
-cached_states=x_i, **carry, **layer_kwargs
+cached_states=x_i,
+**carry,
+**layer_kwargs,
+page_pool=page_pool,
 )
 else:
 raise ValueError(f"Unrecognized mode {mode}.")
@@ -4005,6 +4031,7 @@ def layer_fn(carry, x_i):
 return carry, ys

 ys.update({k: v for k, v in layer_outputs._asdict().items() if k not in carry})
+ys["page_pool"] = page_pool
 return {k: getattr(layer_outputs, k) for k in carry}, ys

 if cfg.carry is None:
@@ -4013,10 +4040,16 @@ def layer_fn(carry, x_i):
 layer_kwargs["data"] = data
 carry = {k: layer_kwargs.pop(k) for k in cfg.carry}

-repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states)
+page_pool = layer_kwargs.pop("page_pool", None)
+repeat_outputs: Repeat.Output = self._run(
+layer_fn, carry=carry, xs=(cached_states, page_pool)
+)
 carry = repeat_outputs.carry
 ys = repeat_outputs.ys
 updated_states = ys.pop("cached_states", None)
+out_page_pool = ys.pop("page_pool", None)
+if page_pool is not None and out_page_pool is not None:
+page_pool[:] = out_page_pool # type: ignore

 if cache_init:
 assert ys == {}
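
The write-back at the end of this hunk uses slice assignment so the caller-supplied pool observes the per-layer updates. A minimal Python sketch of that pattern (illustrative only):

def write_back(page_pool, out_page_pool):
    # Rebinding (`page_pool = out_page_pool`) would only change the local name and
    # be invisible to the caller; slice assignment mutates the shared list object.
    page_pool[:] = out_page_pool

pools = [["old_pool"]]
write_back(pools, [["new_pool"]])
assert pools == [["new_pool"]]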

axlearn/common/decoder.py

Lines changed: 2 additions & 0 deletions
@@ -496,6 +496,7 @@ def _forward_for_mode(
 cross_attention_data: Optional[Tensor] = None,
 cross_attention_logit_biases: Optional[Tensor] = None,
 cached_states: Optional[NestedTensor] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Optional[NestedTensor], Tensor]:
 validate_contains_paths(input_batch, paths=["input_ids"])
 input_segment_ids = input_batch.get("input_segment_ids", None)
@@ -538,6 +539,7 @@ def _forward_for_mode(
 self_attention_logit_biases=self_attention_logit_biases,
 cross_attention_data=cross_attention_data,
 cross_attention_logit_biases=cross_attention_logit_biases,
+page_pool=page_pool,
 )
 else:
 raise ValueError(f"Unrecognized mode {mode}.")

axlearn/common/kv_cache/base_kv_cache.py

Lines changed: 4 additions & 2 deletions
@@ -84,6 +84,7 @@ def extend_step(
 v_proj: Tensor,
 key_positions: Tensor,
 live_step_len: Optional[Tensor] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Nested[Tensor], Output]:
 """Updates the KV cache per extend step.

@@ -96,8 +97,9 @@
 k_proj: A Tensor of shape [batch, step_length, num_kv_heads, per_head_dim].
 v_proj: A Tensor of shape [batch, step_length, num_kv_heads, per_head_dim].
 key_positions: An optional Tensor of shape [1|batch, step_length].
-live_step_len: An optional Tensor of shape [batch]. Please refer to ``On live_step_len``
-in the file docstring for details.
+live_step_len: An optional Tensor of shape [batch]. See file-level docstring of
+`attention.py`
+page_pool: See file-level docstring of `attention.py`.

 Returns:
 A tuple (updated_state, output):

axlearn/common/kv_cache/kv_cache.py

Lines changed: 2 additions & 0 deletions
@@ -38,11 +38,13 @@ def extend_step(
 v_proj: Tensor,
 key_positions: Tensor,
 live_step_len: Optional[Tensor] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Nested[Tensor], BaseKVCache.Output]:
 # TODO(dhwang2): By returning only the valid portions of the KV (by live_step_len),
 # the attention complexity can be reduced from O(max_len²) to O(live_step_len²), especially
 # in prefill.
 # The remaining part after `live_step_len` is considered padding.
+assert page_pool is None
 del live_step_len
 if k_proj.shape != v_proj.shape:
 raise ValueError(f"{k_proj.shape=} != {v_proj.shape=}")

axlearn/common/kv_cache/paged_kv_cache.py

Lines changed: 29 additions & 7 deletions
@@ -133,6 +133,7 @@ def extend_step(
 v_proj: Tensor,
 key_positions: Tensor,
 live_step_len: Optional[Tensor] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Nested[Tensor], KVCache.Output]:
 """Extend the cache with the new key and value.

@@ -166,6 +167,7 @@ def extend_step(
 raise ValueError(f"{k_proj.shape[1]=} != {key_positions.shape[1]=}")

 if "page_indices" not in cached_states:
+assert page_pool is None
 # Prefill, return kv cache directly
 cached_states["key"] = k_proj
 cached_states["value"] = v_proj
@@ -177,8 +179,21 @@ def extend_step(

 # kv_pages shape: [num_heads, max_pages_global, page_size, head_dim]. Also refer to
 # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L388
-k_pages: Tensor = cached_states["key"]
-v_pages: Tensor = cached_states["value"]
+if page_pool is not None:
+# We use `group_info` to index into `page_pool` to get the paged KV pool for this
+# layer.
+group_info = cached_states["group_info"]
+# HACK(hanzhi-zhou): we store the indices as dict keys to workaround them being
+# converted to tracers.
+group_idx = list(group_info["group_idx"].keys())[0]
+repeat_idx = list(group_info["repeat_idx"].keys())[0]
+pool = page_pool[group_idx][repeat_idx]
+k_pages: Tensor = pool.k_pages # type: ignore
+v_pages: Tensor = pool.v_pages # type: ignore
+else:
+k_pages: Tensor = cached_states["key"]
+v_pages: Tensor = cached_states["value"]
+
 assert k_pages.shape == v_pages.shape

 batch = page_indices.shape[0]
@@ -220,11 +235,18 @@ def update_kv_pages(kv_pages, page_indices, key_positions, kv_proj):
 v_pages, page_indices, key_positions, v_proj.astype(v_pages.dtype)
 )

-updated_state = dict(
-key=updated_k_pages,
-value=updated_v_pages,
-page_indices=page_indices,
-)
+if page_pool is not None:
+page_pool[group_idx][repeat_idx] = type(pool)(updated_k_pages, updated_v_pages)
+
+# Updates are already performed through mutable arrays above. We don't perform state
+# updates through `updated_state`.
+updated_state = dict(key=None, value=None, page_indices=None)
+else:
+updated_state = dict(
+key=updated_k_pages,
+value=updated_v_pages,
+page_indices=page_indices,
+)

 assert updated_k_pages.shape == k_pages.shape
 assert updated_v_pages.shape == v_pages.shape
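
A minimal sketch of the dict-key workaround referenced in the HACK comment above (illustrative; `read_pool` and the toy pool are assumptions): leaves of a pytree argument become tracers under `jax.jit`, but dict keys are part of the tree structure and remain concrete Python ints, so they can still drive ordinary list indexing into the pool. The cost is a retrace whenever the key changes, which is acceptable for static layer indices.

import jax
import jax.numpy as jnp

pool = [jnp.zeros(4), jnp.ones(4)]  # stand-in for per-layer paged KV pools

@jax.jit
def read_pool(group_info, x):
    # The index is encoded as the sole dict key, so it is not converted to a tracer.
    idx = list(group_info["group_idx"].keys())[0]  # concrete int under jit
    return x + pool[idx]  # ordinary Python list indexing

out = read_pool({"group_idx": {1: None}}, jnp.zeros(4))
assert out.tolist() == [1.0, 1.0, 1.0, 1.0]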

axlearn/common/kv_cache/sliding_window_kv_cache.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@ def extend_step(
 v_proj: Tensor,
 key_positions: Tensor,
 live_step_len: Optional[Tensor] = None,
+page_pool: Optional[Nested[Tensor]] = None,
 ) -> tuple[Nested[Tensor], BaseKVCache.Output]:
 """Updates the sliding window KV cache per extend step.

@@ -70,6 +71,7 @@ def extend_step(
 * output: The output k_proj, v_proj, and key_positions, which are merged with
 KV cache, resulting in a length of `cached_kv_length + step_size`.
 """
+assert page_pool is None
 cfg = self.config
 cached_key: Tensor = cached_states["key"]
 cached_value: Tensor = cached_states["value"]

axlearn/common/module.py

Lines changed: 11 additions & 3 deletions
@@ -1002,6 +1002,8 @@ class _Functional:
 context: InvocationContext = struct.field(pytree_node=True)
 # Whether to require that context.parent is current_context().
 require_parent: bool = struct.field(pytree_node=False)
+# Whether to copy the argument pytrees to prevent method_fn from mutating the original.
+copy_args_tree: bool = struct.field(pytree_node=False, default=True)

 def __call__(self, *args, **kwargs) -> tuple[Any, OutputCollection]:
 """Invokes method_fn in a pure functional fashion.
@@ -1024,13 +1026,14 @@ def __call__(self, *args, **kwargs) -> tuple[Any, OutputCollection]:
 call = getattr(self.method_fn, "__qualname__", None) or getattr(self.method_fn, "__name__")
 logging.vlog(1, "functional: %s.%s (*%s, **%s)", call, self.method_fn, args, kwargs)

-# Copy to prevent method_fn from mutating the original.
 # Some badly behaved tests call F() with an InvocationContext.state that contains
 # circular references.
 # This results in a cryptic error that doesn't make the root cause obvious.
 # So we raise a clearer error explicitly.
 raise_for_cycles(dict(context=self.context, args=args, kwargs=kwargs))
-context, args, kwargs = jax.tree.map(lambda x: x, (self.context, args, kwargs))
+context = self.context
+if self.copy_args_tree:
+context, args, kwargs = jax.tree.map(lambda x: x, (self.context, args, kwargs))

 with set_current_context(context, require_parent=self.require_parent):
 # pylint: disable-next=not-an-iterable,not-a-mapping,not-callable
@@ -1047,6 +1050,7 @@ def functional(
 method: str = "forward",
 is_training: bool,
 drop_output_collections: Sequence[str] = ("module_outputs",),
+copy_args_tree: bool = True,
 ) -> tuple[Any, OutputCollection]:
 """Invokes <module>.<method> in a pure functional fashion.

@@ -1065,6 +1069,8 @@
 is_training: Whether the invocation should run in the training mode.
 drop_output_collections: The output collection types to drop.
 Defaults to dropping all module outputs.
+copy_args_tree: Whether to copy the `inputs` pytree to prevent method_fn from mutating the
+original. Defaults to True.

 Returns:
 (method_outputs, output_collection), where
@@ -1092,7 +1098,9 @@ def functional(
 args = inputs
 method_fn = getattr(module, method)

-fn = _Functional(context=context, method_fn=method_fn, require_parent=True)
+fn = _Functional(
+context=context, method_fn=method_fn, require_parent=True, copy_args_tree=copy_args_tree
+)
 method_outputs, output_collection = fn(*args, **kwargs)

 for output_collection_type in drop_output_collections:
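
A minimal sketch of what `copy_args_tree` controls (illustrative only): `jax.tree.map(lambda x: x, tree)` keeps the leaves but rebuilds every container, so in-place mutations made by the callee land on the copy. Passing `copy_args_tree=False` hands the original containers through, which is what lets a write-back such as `page_pool[:] = out_page_pool` be observed by the caller.

import jax

page_pool = [["old_pool"]]
kwargs = {"page_pool": page_pool}

# The identity tree-map copies the container structure while keeping the leaves.
copied = jax.tree.map(lambda x: x, kwargs)
copied["page_pool"][:] = [["new_pool"]]  # mutate the copied list in place

assert page_pool == [["old_pool"]]           # the caller's list is untouched
assert copied["page_pool"] is not page_pool  # containers were rebuilt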
