
Commit b47928e

debug: reduce kv cache size from ttt_steps * parallel_draft_step to ttt_steps + parallel_draft_step - 1; in each ttt step, only the non-parallel tokens from the previous ttt step are used as context
Signed-off-by: Ye Yu <[email protected]>
1 parent cb9282e commit b47928e

1 file changed: +77 −46 lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

@@ -195,48 +195,72 @@ def set_multi_step_attention_mask(attn_mask, step):
     h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states)
     l0 l1 l2 l3 l4 l5 l6 l7 (base labels)

-    ttt_step=2
-    parallel_draft_step=2
-    ->step=3
-
-                  | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- |
-            (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 |
-    =======================================================================================================================
-    F1 l1 | i1 h0 | x                       |                         |                         |                         |
-    F2 l2 | i2 h1 | x  x                    |                         |                         |                         |
-    F3 l3 | i3 h2 | x  x  x                 |                         |                         |                         |
-    F4 l4 | i4 h3 | x  x  x  x              |                         |                         |                         |
-    F5 l5 | i5 h4 | x  x  x  x  x           |                         |                         |                         |
-    F6 l6 | i6 h5 | x  x  x  x  x  x        |                         |                         |                         |
-    F7 l7 | i7 h6 | x  x  x  x  x  x  x     |                         |                         |                         |
-    -- -- | -- h7 | o  o  o  o  o  o  o  o  |                         |                         |                         |
+    ttt_steps=2
+    parallel_draft_step=3
+
+
+    ttt_step=0
+                  | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
+                  | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 |
+    =============================================================================================
+    F1 l1 | i1 h0 | x                       |                         |                         |
+    F2 l2 | i2 h1 | x  x                    |                         |                         |
+    F3 l3 | i3 h2 | x  x  x                 |                         |                         |
+    F4 l4 | i4 h3 | x  x  x  x              |                         |                         |
+    F5 l5 | i5 h4 | x  x  x  x  x           |                         |                         |
+    F6 l6 | i6 h5 | x  x  x  x  x  x        |                         |                         |
+    F7 l7 | i7 h6 | x  x  x  x  x  x  x     |                         |                         |
+    -- -- | -- h7 | o  o  o  o  o  o  o  o  |                         |                         |
+    =============================================================================================
+    -- -- | m0 -- |                         |                         |                         |
+    G2 l2 | m0 h1 | x  o                    |    x                    |                         |
+    G3 l3 | m0 h2 | x  x  o                 |       x                 |                         |
+    G4 l4 | m0 h3 | x  x  x  o              |          x              |                         |
+    G5 l5 | m0 h4 | x  x  x  x  o           |             x           |                         |
+    G6 l6 | m0 h5 | x  x  x  x  x  o        |                x        |                         |
+    G7 l7 | m0 h6 | x  x  x  x  x  x  o     |                   x     |                         |
+    -- -- | -- h7 |                         |                         |                         |
+    =============================================================================================
+    -- -- | m1 -- |                         |                         |                         |
+    -- -- | m1 h1 |                         |                         |                         |
+    H3 l3 | m1 h2 | x  o  o                 |    x  o                 |       x                 |
+    H4 l4 | m1 h3 | x  x  o  o              |       x  o              |          x              |
+    H5 l5 | m1 h4 | x  x  x  o  o           |          x  o           |             x           |
+    H6 l6 | m1 h5 | x  x  x  x  o  o        |             x  o        |                x        |
+    H7 l7 | m1 h6 | x  x  x  x  x  o  o     |                x  o     |                   x     |
+    -- -- | -- h7 |                         |                         |                         |
+
+
+    ttt_step=1
+                  | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- |
+                  | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 | -- F1 F2 F3 F4 F5 F6 F7 |
     =======================================================================================================================
     -- -- | i1 -- |                         |                         |                         |                         |
-    G2 l2 | i2 h1 | x  o                    |    x                    |                         |                         |
-    G3 l3 | i3 h2 | x  x  o                 |       x                 |                         |                         |
-    G4 l4 | i4 h3 | x  x  x  o              |          x              |                         |                         |
-    G5 l5 | i5 h4 | x  x  x  x  o           |             x           |                         |                         |
-    G6 l6 | i6 h5 | x  x  x  x  x  o        |                x        |                         |                         |
-    G7 l7 | i7 h6 | x  x  x  x  x  x  o     |                   x     |                         |                         |
-    -- -- | -- h7 |                         |                         |                         |                         |
+    J2 l2 | i2 F1 | x  o                    |    x                    |                         |                         |
+    J3 l3 | i3 F2 | x  x  o                 |       x                 |                         |                         |
+    J4 l4 | i4 F3 | x  x  x  o              |          x              |                         |                         |
+    J5 l5 | i5 F4 | x  x  x  x  o           |             x           |                         |                         |
+    J6 l6 | i6 F5 | x  x  x  x  x  o        |                x        |                         |                         |
+    J7 l7 | i7 F6 | x  x  x  x  x  x  o     |                   x     |                         |                         |
+    -- -- | -- F7 |                         |                         |                         |                         |
     =======================================================================================================================
-    -- -- | i1 -- |                         |                         |                         |                         |
-    -- -- | i2 -- |                         |                         |                         |                         |
-    H3 l3 | i3 G2 | x  o  o                 |    x  o                 |       x                 |                         |
-    H4 l4 | i4 G3 | x  x  o  o              |       x  o              |          x              |                         |
-    H5 l5 | i5 G4 | x  x  x  o  o           |          x  o           |             x           |                         |
-    H6 l6 | i6 G5 | x  x  x  x  o  o        |             x  o        |                x        |                         |
-    H7 l7 | i7 G6 | x  x  x  x  x  o  o     |                x  o     |                   x     |                         |
-    -- -- | -- G7 |                         |                         |                         |                         |
-    =======================================================================================================================
-    -- -- | m0 -- |                         |                         |                         |                         |
     -- -- | m0 -- |                         |                         |                         |                         |
     -- -- | m0 -- |                         |                         |                         |                         |
-    K4 l4 | m0 G3 | x                       |          x              |          x              |          x              |
-    K5 l5 | m0 G4 | x  x                    |             x           |             x           |             x           |
-    K6 l6 | m0 G5 | x  x  x                 |                x        |                x        |                x        |
-    K7 l7 | m0 G6 | x  x  x  x              |                   x     |                   x     |                   x     |
-    -- -- | -- G7 |                         |                         |                         |                         |
+    K3 l3 | m0 F2 | x  o  o                 |    x  o                 |       x                 |                         |
+    K4 l4 | m0 F3 | x  x  o  o              |       x  o              |          x              |                         |
+    K5 l5 | m0 F4 | x  x  x  o  o           |          x  o           |             x           |                         |
+    K6 l6 | m0 F5 | x  x  x  x  o  o        |             x  o        |                x        |                         |
+    K7 l7 | m0 F6 | x  x  x  x  x  o  o     |                x  o     |                   x     |                         |
+    -- -- | -- F7 |                         |                         |                         |                         |
+    =======================================================================================================================
+    -- -- | m1 -- |                         |                         |                         |                         |
+    -- -- | m1 -- |                         |                         |                         |                         |
+    -- -- | m1 -- |                         |                         |                         |                         |
+    N4 l4 | m1 F3 | x                       |          x              |          x              |          x              |
+    N5 l5 | m1 F4 | x  x                    |             x           |             x           |             x           |
+    N6 l6 | m1 F5 | x  x  x                 |                x        |                x        |                x        |
+    N7 l7 | m1 F6 | x  x  x  x              |                   x     |                   x     |                   x     |
+    -- -- | -- F7 |                         |                         |                         |                         |
     =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
@@ -765,7 +789,6 @@ def _get_eagle_module_inputs(
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
         h = hidden_states.shape[2]
-        s = input_ids.shape[1]

         # [b, 1]
         id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device)
@@ -801,8 +824,7 @@ def _get_eagle_module_inputs(
801824
else:
802825
gathered_hidden_states = hidden_states
803826
gathered_features = features
804-
if gathered_features is not None:
805-
feature = gathered_features[-s:]
827+
806828
eagle_inputs["hidden_states"] = (
807829
gathered_hidden_states
808830
if ttt_step == 0
@@ -813,7 +835,7 @@ def _get_eagle_module_inputs(
813835
dtype=hidden_states.dtype,
814836
device=hidden_states.device,
815837
),
816-
feature[:-1, :, :],
838+
gathered_features[:-1, :, :],
817839
)
818840
)
819841
)
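Context for the two hunks above: for `ttt_step > 0` the EAGLE inputs are the previous pass's features shifted right by one position behind a zero row, now taken directly from `gathered_features` instead of the sliced `feature`. A hedged, stand-alone sketch of that assembly; shapes follow Megatron's `[s, b, h]` layout, and the tensors here are dummies, not the real activations:

```python
import torch

s, b, h = 8, 2, 16  # assumed toy sizes
gathered_features = torch.randn(s, b, h)  # stands in for the F1..F7 features

# Zero row in front, previous features shifted right by one position.
eagle_hidden_states = torch.cat(
    (
        torch.zeros((1, b, h), dtype=gathered_features.dtype),
        gathered_features[:-1, :, :],
    ),
    dim=0,
)
assert eagle_hidden_states.shape == (s, b, h)
```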
@@ -824,12 +846,11 @@ def _get_eagle_module_inputs(
824846
)
825847

826848
eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
827-
attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index
849+
attn_mask, ttt_step + parallel_draft_index
828850
)
829851

830852
eagle_inputs["rotary_pos_emb"] = torch.cat(
831-
[rotary_pos_emb]
832-
* (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index + 1),
853+
[rotary_pos_emb] * (ttt_step + parallel_draft_index + 1),
833854
dim=0,
834855
)
835856

@@ -1015,7 +1036,7 @@ def forward(
10151036
# EAGLE kv cache
10161037
eagle_inference_context = StaticInferenceContext(
10171038
input_ids.shape[0],
1018-
input_ids.shape[1] * self.eagle_config.parallel_draft_step * ttt_steps,
1039+
input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1),
10191040
)
10201041

10211042
if self.eagle_offline:
@@ -1087,9 +1108,19 @@ def forward(
10871108
**(extra_block_kwargs or {}),
10881109
)
10891110

1111+
if i == 0:
1112+
next_eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
1113+
10901114
eagle_logits.append(eagle_logits_)
10911115
eagle_logits = torch.cat(eagle_logits, dim=0)
1092-
eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
1116+
eagle_hidden_states_pre_norm = next_eagle_hidden_states_pre_norm
1117+
1118+
# Discard kv cache for the last parallel_draft_step - 1 tokens
1119+
# as the next ttt_step will only base on the first token in the
1120+
# current ttt_step
1121+
eagle_inference_context.sequence_len_offset -= input_ids.shape[1] * (
1122+
self.eagle_config.parallel_draft_step - 1
1123+
)
10931124

10941125
# If labels are not provided, return the original logits. We only return after
10951126
# all eagle weights have been exercised for quantization calibration purpose.
