Commit 67cc30b

compile attention mask in first step
Signed-off-by: h-guo18 <[email protected]>
Parent: 254e546

File tree

4 files changed: +10 additions, -129 deletions

examples/speculative_decoding/main.py
modelopt/torch/speculative/eagle/eagle_model.py
modelopt/torch/speculative/plugins/transformers.py
tests/unit/torch/speculative/plugins/test_hf_speculative.py

examples/speculative_decoding/main.py

Lines changed: 0 additions & 2 deletions
@@ -203,8 +203,6 @@ def train():
             "draft_vocab_size": custom_config["draft_vocab_size"]
             if eagle_args.eagle_config and "draft_vocab_size" in custom_config
             else model.config.vocab_size,
-            # pass in the seq length for flex attention mask compilation
-            "training_seq_len": training_args.training_seq_len,
         }
     )
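For context, a hedged sketch of the call-site effect. The update pattern and field names below follow the removed unit test further down; the exact fields in main.py may differ, so treat this as illustrative only. The EAGLE architecture config no longer has to carry the training sequence length, because the flex attention masks are now compiled from the runtime sequence length on the first forward call.

config["eagle_architecture_config"].update(
    {
        # model-derived fields only; shown with the simple fallback branch
        "draft_vocab_size": model.config.vocab_size,
        "hidden_size": model.config.hidden_size,
        # "training_seq_len" is intentionally no longer passed here
    }
)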

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@ def modify(
         self.eagle_report_acc = eagle_report_acc
         self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
         self.eagle_loss_decay_factor = eagle_loss_decay_factor
-
         if eagle_architecture_config.get("parallel_draft_step", 1) > 1:
             for i in range(eagle_architecture_config.get("parallel_draft_step") - 1):
                 self.register_buffer(f"mask_token_{i}", torch.tensor(-1))

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 10 additions & 6 deletions
@@ -452,11 +452,15 @@ def modify(
             layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

         self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        # compile and cach flex attention masks
-        self.cached_attn_blk_masks = [
-            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
-            for i in range(self.num_ttt_steps)
-        ]
+        self._cached_attn_blk_masks = []
+
+    def _get_ttt_attention_mask(self, seq_length, ttt_step):
+        # compile and cache the flex attention mask on the first call
+        if ttt_step >= len(self._cached_attn_blk_masks):
+            self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step))
+
+        # return the cached flex attention mask
+        return self._cached_attn_blk_masks[ttt_step]

     def _prepare_decoder_attention_mask(
         self, attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -773,7 +777,7 @@ def forward(
                 ),
                 dim=1,
             )
-            attention_mask = self.cached_attn_blk_masks[ttt_step]
+            attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step)
             _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
                 eagle_input_hidden_states,
                 inputs_embeds,
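Below is a minimal, self-contained sketch of the lazy compile-and-cache pattern this diff introduces, written against PyTorch's torch.nn.attention.flex_attention.create_block_mask. The class name TTTMaskCache, the plain causal mask_mod, and the (ttt_step + 2) * seq_length total length are illustrative assumptions, not the plugin's actual _compile_ttt_block_mask logic; only the cache-on-first-use behavior mirrors the change above.

from torch.nn.attention.flex_attention import create_block_mask


class TTTMaskCache:
    """Compile one flex attention block mask per TTT step, on first use."""

    def __init__(self):
        self._cached_attn_blk_masks = []

    def _compile_ttt_block_mask(self, seq_length, ttt_step):
        # Hypothetical stand-in: a plain causal mask over a concatenated
        # (ttt_step + 2) * seq_length token sequence. The plugin's real
        # _compile_ttt_block_mask builds the TTT-specific mask instead.
        total_len = (ttt_step + 2) * seq_length

        def causal(b, h, q_idx, kv_idx):
            # keep key positions at or before the query position
            return q_idx >= kv_idx

        # create_block_mask precomputes the sparse block layout for this shape;
        # its device argument defaults to "cuda", pass device=... for other placements
        return create_block_mask(causal, B=None, H=None, Q_LEN=total_len, KV_LEN=total_len)

    def get(self, seq_length, ttt_step):
        # Lazy compile-and-cache: step i is compiled the first time it is
        # requested, so the sequence length need not be known at init time.
        # (Assumes ttt_step is visited in increasing order, as in the TTT loop.)
        if ttt_step >= len(self._cached_attn_blk_masks):
            self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step))
        return self._cached_attn_blk_masks[ttt_step]

Because the mask is requested with the runtime seq_length inside forward(), the training_seq_len plumbing removed from examples/speculative_decoding/main.py is no longer needed, and no mask is compiled until the first training step.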

tests/unit/torch/speculative/plugins/test_hf_speculative.py

Lines changed: 0 additions & 120 deletions
@@ -17,7 +17,6 @@
 from copy import deepcopy

 import pytest
-import torch
 from _test_utils.torch_model.transformers_models import (
     create_tiny_llama_dir,
     get_tiny_llama,
@@ -69,122 +68,3 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config):
     model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model")
     assert isinstance(model_test, mtsp.plugins.HFEagleModel)
     tf_modelopt_state_and_output_tester(model_ref, model_test)
-
-
-# fmt: off
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_eagle_model_prepare_eagle_inputs(dtype):
-    dummy_model = get_tiny_llama(num_hidden_layers=4)
-
-    config = EAGLE3_DEFAULT_CFG["config"]
-    config["eagle_architecture_config"].update({
-        "draft_vocab_size": dummy_model.config.vocab_size,
-        "hidden_size": dummy_model.config.hidden_size,
-    })
-    mtsp.convert(dummy_model, mode=[("eagle", config)])
-
-    eagle_input_ids_0 = torch.tensor([[10, 20, 30, 40]], dtype=torch.long)
-    position_ids_0 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
-
-    #This is concatenated from 3 intermediate base model layers
-    cat_aux_hidden_states = torch.randn(1, 4, 32, dtype=dtype)
-
-    #This is eagle output from previous eagle forward pass
-    dummy_eagle_output_hidden_states = torch.randn(1, 4, 32, dtype=dtype)
-
-    #This is the causal mask for the 0th eagle step
-    m = torch.finfo(dtype).min
-    attention_mask_0 = torch.tensor([[0, m, m, m],  # input tok 10-> predicting token 20
-                                     [0, 0, m, m],  # 20 -> 30
-                                     [0, 0, 0, m],  # 30 -> 40
-                                     [0, 0, 0, 0]]  # 40 -> tok after 40
-
-                                    , dtype=dtype).view(1, 1, 4, 4)
-
-    # 2nd eagle step
-    eagle_input_h_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = dummy_model._concat_eagle_inputs(
-        eagle_input_ids_0,
-        cat_aux_hidden_states,
-        attention_mask_0,
-        position_ids_0,
-        dummy_eagle_output_hidden_states,
-    )
-
-    assert eagle_input_ids_1.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
-    assert position_ids_1.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
-
-    assert attention_mask_1.equal(torch.tensor([[0, m, m, m, m, m, m, m],  # (x) output discarded
-                                                [0, 0, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m],  # (x) input tok 10-> predicting token 20
-                                                [0, m, m, m, m, 0, m, m],  # 20 -> 30
-                                                [0, 0, m, m, m, m, 0, m],  # 30 -> 40
-                                                [0, 0, 0, 0, m, m, m, m],  # (x) 40 -> tok after 40
-                                                ], dtype=dtype).view(1, 1, 8, 8))
-
-    # 3rd eagle step
-    eagle_input_hidden_states_2, eagle_input_ids_2, attention_mask_2, position_ids_2 = dummy_model._concat_eagle_inputs(
-        eagle_input_ids_0,
-        cat_aux_hidden_states,
-        attention_mask_0,
-        position_ids_0,
-        torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states], dim=1),
-    )
-    assert eagle_input_ids_2.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
-    assert position_ids_2.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
-
-    assert attention_mask_2.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, m, m, m, m, 0, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, 0, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m],  # (x)10 -> 20
-                                                [m, m, m, m, m, m, m, m, m, m, m, m],  # (x)20 -> 30
-                                                [0, m, m, m, m, 0, m, m, m, m, 0, m],  # 30 -> 40
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m],  # (x) 40 -> tok after 40
-
-                                                ], dtype=dtype).view(1, 1, 12, 12))
-
-    # 4th eagle step
-    eagle_input_hidden_states_3, eagle_input_ids_3, attention_mask_3, position_ids_3 = dummy_model._concat_eagle_inputs(
-        eagle_input_ids_0,
-        cat_aux_hidden_states,
-        attention_mask_0,
-        position_ids_0,
-        torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states,
-                   dummy_eagle_output_hidden_states],dim=1),
-    )
-
-    assert eagle_input_ids_3.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40,
-                                                  10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
-    assert position_ids_3.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
-
-    assert attention_mask_3.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, m, m, m, m, 0, m, m, m, m, 0, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)10 -> 20
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)20 -> 30
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                ], dtype=dtype).view(1, 1, 16, 16))
