@@ -76,6 +76,11 @@
 * When the accompanying argument is `query`, the `positions` argument is named as
   `query_positions`. Similarly, when the argument is `target`, it is named as `target_positions`.

+On `page_pool`:
+* If not None, stores the external paged KV pool, possibly shared by multiple
+  layers. Additionally, `cached_states` will not contain KV state. Instead, it will
+  contain indices used to index into `page_pool`.
+
 TODO(changlan): Merge the use of `positions` and `time_step` to reduce cognitive complexity.

 """
@@ -317,6 +322,7 @@ def extend_step(
         self_attention_logit_biases: Optional[Tensor] = None,
         cross_attention_data: Optional[Tensor] = None,
         cross_attention_logit_biases: Optional[Tensor] = None,
+        page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[NestedTensor, Output]:
         """Computes incremental outputs.

@@ -335,6 +341,7 @@ def extend_step(
             cross_attention_logit_biases: An optional Tensor of shape
                 [..., target_step_length, source_length], where `target_step_length` must match
                 the shape of `data`.
+            page_pool: See file-level comments on `page_pool`.

         Returns:
             (updated_cached_states, output), where:
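As a usage illustration of the new keyword (the decode loop, `layer`, `data`, and `biases` below are hypothetical stand-ins; only the calling convention added by this diff is assumed), a caller passes the same pool into every decode step while `cached_states` carries only page indices:

```python
# Hypothetical single decode step against a transformer layer.
cached_states, output = layer.extend_step(
    cached_states,
    data,  # [batch, step_len, dim] target for the current step
    self_attention_logit_biases=biases,
    page_pool=page_pool,  # shared paged KV pool, reused across steps and layers
)
```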
@@ -1653,6 +1660,7 @@ def _forward_for_mode(
         query_positions: Optional[Tensor] = None,
         cached_states: Optional[NestedTensor] = None,
         return_aux: Optional[set[str]] = None,
+        page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Nested[Tensor], Optional[Output]]:
         """Computes attention for the given query, key, value, and attention logit biases.

@@ -1672,6 +1680,7 @@ def _forward_for_mode(
             query_positions: See ``On positions`` in the file comments.
             cached_states: Optional NestedTensor as produced by `init_states`.
             return_aux: See comments on `Output`.
+            page_pool: See file-level comments on `page_pool`.

         Returns:
             A tuple (cached_states, output):
@@ -1749,6 +1758,7 @@ def _forward_for_mode(
                 v_proj=v_proj,
                 key_positions=query_positions,
                 live_step_len=live_step_len,
+                page_pool=page_pool,
             )
             if mode == ForwardMode.EXTEND_STEP:
                 kv_state = KVState(*kv_cache_output)
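The KV cache update invoked just above is where paged storage diverges from a dense cache: new keys and values land in pool pages rather than in `cached_states`. A minimal sketch, assuming a hypothetical layout with `page_indices` and a per-sequence `time_step` (none of these names are from this diff):

```python
import jax.numpy as jnp

def paged_kv_update(page_pool, page_indices, time_step, k_step, v_step, page_size):
    """Writes one decode step of K/V into its pool page instead of a dense cache.

    page_pool: dict with "k"/"v" arrays of shape [num_pages, page_size, heads, dim].
    page_indices: [batch, pages_per_seq] pages owned by each sequence.
    time_step: [batch] current decode position per sequence.
    k_step / v_step: [batch, heads, dim] projections for the new token.
    """
    batch = time_step.shape[0]
    page = page_indices[jnp.arange(batch), time_step // page_size]  # logical -> physical
    slot = time_step % page_size  # offset within the page
    k = page_pool["k"].at[page, slot].set(k_step)  # functional update; returns new array
    v = page_pool["v"].at[page, slot].set(v_step)
    return {"k": k, "v": v}
```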
@@ -2042,6 +2052,7 @@ def extend_step(
         kv_state: Optional[KVState] = None,
         attention_logit_biases: Optional[Tensor] = None,
         return_aux: Optional[set[str]] = None,
+        page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[NestedTensor, Output]:
         """Computes the value vector given the query of the current step.
         This function is used by autoregressive decoding.
@@ -2068,6 +2079,7 @@ def extend_step(
                 The biases should already include causal masking for decoding, plus other biases
                 if necessary.
             return_aux: See comments on `Output`.
+            page_pool: See file-level comments on `page_pool`.

         Returns:
             A `NestedTensor` state of key and value pair along with index updated at `time_step`.
@@ -2083,6 +2095,7 @@ def extend_step(
             kv_state=kv_state,
             attention_logit_biases=attention_logit_biases,
             return_aux=return_aux,
+            page_pool=page_pool,
         )

     @staticmethod
@@ -2640,6 +2653,7 @@ def _forward_for_mode(
         target_positions: Optional[Tensor] = None,
         cached_states: Optional[NestedTensor] = None,
         return_aux: Optional[set[str]] = None,
+        page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Optional[Nested[Tensor]], Optional[Output]]:
         """Computes either self-attention or cross-attention for the given target and source.

@@ -2654,6 +2668,7 @@ def _forward_for_mode(
             target_positions: See ``On positions`` in the file comments.
             cached_states: Optional NestedTensor as produced by `init_states`.
             return_aux: See comments on `Output`.
+            page_pool: See file-level comments on `page_pool`.

         Returns:
             A tuple (cached_states, output):
@@ -2709,6 +2724,7 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
                 target,
                 **kv_kwargs,
                 attention_logit_biases=attention_logit_biases,
+                page_pool=page_pool,
             )
         else:
             raise ValueError(f"Unrecognized mode {mode}.")
@@ -2841,6 +2857,7 @@ def extend_step(
         source: Optional[Union[Tensor, KVState]] = None,
         attention_logit_biases: Optional[Tensor] = None,
         return_aux: Optional[set[str]] = None,
+        page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Nested[Tensor], Output]:
         """Computes the value vector given the query of the current step.
         This function is used by autoregressive decoding.
@@ -2858,6 +2875,7 @@ def extend_step(
                 attention_logit_biases should have already taken care of causal masking for
                 decoding, plus other maskings necessary.
             return_aux: See comments on `Output`.
+            page_pool: See file-level comments on `page_pool`.

         Returns:
             A `NestedTensor` state of key and value pair along with index updated at `time_step`.
@@ -2874,6 +2892,7 @@ def extend_step(
             cached_states=cached_states,
             attention_logit_biases=attention_logit_biases,
             return_aux=return_aux,
+            page_pool=page_pool,
         )

@@ -3196,6 +3215,7 @@ def _forward_for_mode(
         target_positions: Optional[Tensor] = None,
         cached_states: Optional[NestedTensor] = None,
         return_aux: Optional[set[str]] = None,
+        page_pool: Optional[Nested[Tensor]] = None,
     ) -> tuple[Optional[NestedTensor], Optional[BaseTransformerLayer.Output]]:
         """Computes transformer layer outputs and self/cross-attention probabilities.

@@ -3212,6 +3232,7 @@ def _forward_for_mode(
             target_positions: See ``positions`` in the file comments.
             cached_states: Optional NestedTensor as produced by `init_states`.
             return_aux: See comments on BaseTransformerLayer.forward.
+            page_pool: See file-level comments on `page_pool`.

         Returns:
             A tuple (cached_states, output):
@@ -3273,6 +3294,7 @@ def _forward_for_mode(
                 source=self_attention_kv_state,
                 attention_logit_biases=self_attention_logit_biases,
                 return_aux=self_attention_return_aux,
+                page_pool=page_pool,
             )
         else:
             raise ValueError(f"Unrecognized mode {mode}.")
@@ -3979,6 +4001,7 @@ def _forward_for_mode(
             assert value.shape[0] == cfg.num_layers, f"{path}={shapes(value)}"

         def layer_fn(carry, x_i):
+            x_i, page_pool = x_i
             if mode == ForwardMode.FORWARD:
                 layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs)
             elif mode == ForwardMode.INIT_STATES:
@@ -3990,7 +4013,10 @@ def layer_fn(carry, x_i):
             elif mode == ForwardMode.EXTEND_STEP:
                 assert x_i is not None
                 layer_states, layer_outputs = self.layer.extend_step(
-                    cached_states=x_i, **carry, **layer_kwargs
+                    cached_states=x_i,
+                    **carry,
+                    **layer_kwargs,
+                    page_pool=page_pool,
                 )
             else:
                 raise ValueError(f"Unrecognized mode {mode}.")
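The `x_i` tuple unpacked at the top of `layer_fn` mirrors the `xs=(cached_states, page_pool)` passed to `self._run` further down: in scan terms, each element of `xs` is sliced along the leading (layer) axis and one slice is handed to each layer, while outputs returned per layer are restacked into `ys`. A minimal sketch of the same threading pattern with raw `jax.lax.scan` (the stacked shapes are assumptions):

```python
import jax
import jax.numpy as jnp

def body(carry, x_i):
    states_i, pool_i = x_i  # one layer's cache slice and pool slice
    carry = carry + 1.0  # stand-in for the actual layer computation
    # Echo both back so they are restacked into `ys` along the layer axis.
    return carry, {"cached_states": states_i, "page_pool": pool_i}

carry, ys = jax.lax.scan(
    body,
    jnp.zeros(()),  # carry threaded through all layers
    (jnp.zeros((4, 2)), jnp.zeros((4, 8))),  # [num_layers, ...] stacked xs
)
```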
@@ -4005,6 +4031,7 @@ def layer_fn(carry, x_i):
                 return carry, ys

             ys.update({k: v for k, v in layer_outputs._asdict().items() if k not in carry})
+            ys["page_pool"] = page_pool
             return {k: getattr(layer_outputs, k) for k in carry}, ys

         if cfg.carry is None:
@@ -4013,10 +4040,16 @@ def layer_fn(carry, x_i):
             layer_kwargs["data"] = data
             carry = {k: layer_kwargs.pop(k) for k in cfg.carry}

-        repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states)
+        page_pool = layer_kwargs.pop("page_pool", None)
+        repeat_outputs: Repeat.Output = self._run(
+            layer_fn, carry=carry, xs=(cached_states, page_pool)
+        )
         carry = repeat_outputs.carry
         ys = repeat_outputs.ys
         updated_states = ys.pop("cached_states", None)
+        out_page_pool = ys.pop("page_pool", None)
+        if page_pool is not None and out_page_pool is not None:
+            page_pool[:] = out_page_pool  # type: ignore

         if cache_init:
             assert ys == {}
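The in-place `page_pool[:] = out_page_pool` works because `page_pool` is a mutable container (e.g., a list of per-layer tensors) rather than a bare array: slice assignment replaces the contents while preserving the object the caller holds a reference to. A small self-contained illustration of that aliasing pattern (the types here are assumptions):

```python
def run(pool: list) -> None:
    out = [p + 1 for p in pool]  # stand-in for the repeat body's updated entries
    pool[:] = out  # mutate in place so the caller's reference observes the update

caller_pool = [0, 10]
run(caller_pool)
assert caller_pool == [1, 11]
```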