Skip to content

Commit 138c806

Browse files
committed
add training_seq_len to eagleconfig; remove deprecated unittest
Signed-off-by: h-guo18 <[email protected]>
1 parent 3531467 commit 138c806

File tree

7 files changed

+18
-126
lines changed

7 files changed

+18
-126
lines changed

examples/speculative_decoding/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,12 @@ def train():
185185
}[training_args.mode]["config"]
186186

187187
# overwrite config with custom config
188-
if use_offline_training:
189-
config["eagle_offline"] = True
188+
config.update(
189+
{
190+
"eagle_offline": use_offline_training,
191+
"eagle_training_seq_len": training_args.training_seq_len,
192+
}
193+
)
190194

191195
if eagle_args.eagle_config:
192196
with open(eagle_args.eagle_config) as f:
@@ -203,8 +207,6 @@ def train():
203207
"draft_vocab_size": custom_config["draft_vocab_size"]
204208
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
205209
else model.config.vocab_size,
206-
# pass in the seq length for flex attention mask compilation
207-
"training_seq_len": training_args.training_seq_len,
208210
}
209211
)
210212

modelopt/torch/speculative/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ class EagleConfig(ModeloptBaseConfig):
9595
default=0.9, description=("The decay factor for multiple eagle_loss.")
9696
)
9797

98+
eagle_training_seq_len: int = ModeloptField(
99+
default=1024, description=("The training sequence length.")
100+
)
101+
98102
eagle_architecture_config: dict = ModeloptField(
99103
default={}, description=("The config for eagle module architecture.")
100104
)

modelopt/torch/speculative/eagle/conversion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
4747
eagle_report_acc=config.eagle_report_acc,
4848
eagle_reuse_base_decoder=config.eagle_reuse_base_decoder,
4949
eagle_loss_decay_factor=config.eagle_loss_decay_factor,
50+
eagle_training_seq_len=config.eagle_training_seq_len,
5051
eagle_architecture_config=config.eagle_architecture_config,
5152
)
5253

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def modify(
3535
eagle_report_acc,
3636
eagle_reuse_base_decoder,
3737
eagle_loss_decay_factor,
38+
eagle_training_seq_len,
3839
eagle_architecture_config,
3940
):
4041
"""Base Eagle Model modify function. Child class should implement the details."""
@@ -45,7 +46,7 @@ def modify(
4546
self.eagle_report_acc = eagle_report_acc
4647
self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
4748
self.eagle_loss_decay_factor = eagle_loss_decay_factor
48-
49+
self.eagle_training_seq_len = eagle_training_seq_len
4950
if eagle_architecture_config.get("parallel_draft_step", 1) > 1:
5051
for i in range(eagle_architecture_config.get("parallel_draft_step") - 1):
5152
self.register_buffer(f"mask_token_{i}", torch.tensor(-1))

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,7 @@ def modify(
721721
eagle_report_acc,
722722
eagle_reuse_base_decoder,
723723
eagle_loss_decay_factor,
724+
eagle_training_seq_len,
724725
eagle_architecture_config,
725726
):
726727
if self.config.pipeline_model_parallel_size > 1:
@@ -742,6 +743,7 @@ def modify(
742743
eagle_report_acc=eagle_report_acc,
743744
eagle_reuse_base_decoder=eagle_reuse_base_decoder,
744745
eagle_loss_decay_factor=eagle_loss_decay_factor,
746+
eagle_training_seq_len=eagle_training_seq_len,
745747
eagle_architecture_config=eagle_architecture_config,
746748
)
747749

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ def modify(
382382
eagle_report_acc,
383383
eagle_reuse_base_decoder,
384384
eagle_loss_decay_factor,
385+
eagle_training_seq_len,
385386
eagle_architecture_config,
386387
):
387388
"""Constructor.
@@ -397,6 +398,7 @@ def modify(
397398
eagle_report_acc=eagle_report_acc,
398399
eagle_reuse_base_decoder=eagle_reuse_base_decoder,
399400
eagle_loss_decay_factor=eagle_loss_decay_factor,
401+
eagle_training_seq_len=eagle_training_seq_len,
400402
eagle_architecture_config=eagle_architecture_config,
401403
)
402404
self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
@@ -454,7 +456,7 @@ def modify(
454456
self.num_ttt_steps = 3 # NOTE: (hg) hardcoded for now. Might add to config later.
455457
# compile and cache flex attention masks
456458
self.cached_attn_blk_masks = [
457-
self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
459+
self._compile_ttt_block_mask(eagle_training_seq_len, i)
458460
for i in range(self.num_ttt_steps)
459461
]
460462

tests/unit/torch/speculative/plugins/test_hf_speculative.py

Lines changed: 0 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from copy import deepcopy
1818

1919
import pytest
20-
import torch
2120
from _test_utils.torch_model.transformers_models import (
2221
create_tiny_llama_dir,
2322
get_tiny_llama,
@@ -69,122 +68,3 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config):
6968
model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model")
7069
assert isinstance(model_test, mtsp.plugins.HFEagleModel)
7170
tf_modelopt_state_and_output_tester(model_ref, model_test)
72-
73-
74-
# fmt: off
75-
@pytest.mark.parametrize("dtype", [torch.bfloat16])
76-
def test_eagle_model_prepare_eagle_inputs(dtype):
77-
dummy_model = get_tiny_llama(num_hidden_layers=4)
78-
79-
config = EAGLE3_DEFAULT_CFG["config"]
80-
config["eagle_architecture_config"].update({
81-
"draft_vocab_size": dummy_model.config.vocab_size,
82-
"hidden_size": dummy_model.config.hidden_size,
83-
})
84-
mtsp.convert(dummy_model, mode=[("eagle", config)])
85-
86-
eagle_input_ids_0 = torch.tensor([[10, 20, 30, 40]], dtype=torch.long)
87-
position_ids_0 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
88-
89-
90-
#This is concatenated from 3 intermediate base model layers
91-
cat_aux_hidden_states = torch.randn(1, 4, 32, dtype=dtype)
92-
93-
#This is eagle output from previous eagle forward pass
94-
dummy_eagle_output_hidden_states = torch.randn(1, 4, 32, dtype=dtype)
95-
96-
#This is the causal mask for the 0th eagle step
97-
m = torch.finfo(dtype).min
98-
attention_mask_0 = torch.tensor([[0, m, m, m], # input tok 10-> predicting token 20
99-
[0, 0, m, m], # 20 -> 30
100-
[0, 0, 0, m], # 30 -> 40
101-
[0, 0, 0, 0]] # 40 -> tok after 40
102-
103-
, dtype=dtype).view(1, 1, 4, 4)
104-
105-
# 2nd eagle step
106-
eagle_input_h_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = dummy_model._concat_eagle_inputs(
107-
eagle_input_ids_0,
108-
cat_aux_hidden_states,
109-
attention_mask_0,
110-
position_ids_0,
111-
dummy_eagle_output_hidden_states,
112-
)
113-
114-
assert eagle_input_ids_1.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
115-
assert position_ids_1.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
116-
117-
assert attention_mask_1.equal(torch.tensor([[0, m, m, m, m, m, m, m], # (x) output discarded
118-
[0, 0, m, m, m, m, m, m], # (x)
119-
[0, 0, 0, m, m, m, m, m], # (x)
120-
[0, 0, 0, 0, m, m, m, m], # (x)
121-
122-
[m, m, m, m, m, m, m, m], # (x) input tok 10-> predicting token 20
123-
[0, m, m, m, m, 0, m, m], # 20 -> 30
124-
[0, 0, m, m, m, m, 0, m], # 30 -> 40
125-
[0, 0, 0, 0, m, m, m, m], # (x) 40 -> tok after 40
126-
], dtype=dtype).view(1, 1, 8, 8))
127-
128-
# 3rd eagle step
129-
eagle_input_hidden_states_2, eagle_input_ids_2, attention_mask_2, position_ids_2 = dummy_model._concat_eagle_inputs(
130-
eagle_input_ids_0,
131-
cat_aux_hidden_states,
132-
attention_mask_0,
133-
position_ids_0,
134-
torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states], dim=1),
135-
)
136-
assert eagle_input_ids_2.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
137-
assert position_ids_2.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
138-
139-
assert attention_mask_2.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m], # (x)
140-
[0, 0, m, m, m, m, m, m, m, m, m, m], # (x)
141-
[0, 0, 0, m, m, m, m, m, m, m, m, m], # (x)
142-
[0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x)
143-
144-
[m, m, m, m, m, m, m, m, m, m, m, m], # (x)
145-
[0, m, m, m, m, 0, m, m, m, m, m, m], # (x)
146-
[0, 0, m, m, m, m, 0, m, m, m, m, m], # (x)
147-
[0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x)
148-
149-
[m, m, m, m, m, m, m, m, m, m, m, m], # (x)10 -> 20
150-
[m, m, m, m, m, m, m, m, m, m, m, m], # (x)20 -> 30
151-
[0, m, m, m, m, 0, m, m, m, m, 0, m], # 30 -> 40
152-
[0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) 40 -> tok after 40
153-
154-
], dtype=dtype).view(1, 1, 12, 12))
155-
156-
# 4th eagle step
157-
eagle_input_hidden_states_3, eagle_input_ids_3, attention_mask_3, position_ids_3 = dummy_model._concat_eagle_inputs(
158-
eagle_input_ids_0,
159-
cat_aux_hidden_states,
160-
attention_mask_0,
161-
position_ids_0,
162-
torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states,
163-
dummy_eagle_output_hidden_states],dim=1),
164-
)
165-
166-
assert eagle_input_ids_3.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40,
167-
10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
168-
assert position_ids_3.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
169-
170-
assert attention_mask_3.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
171-
[0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
172-
[0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
173-
[0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
174-
175-
[m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
176-
[0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m, m], # (x)
177-
[0, 0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m], # (x)
178-
[0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
179-
180-
[m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
181-
[m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
182-
[0, m, m, m, m, 0, m, m, m, m, 0, m, m, m, m, m], # (x)
183-
[0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
184-
185-
[m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)10 -> 20
186-
[m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)20 -> 30
187-
[m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
188-
[0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x)
189-
190-
], dtype=dtype).view(1, 1, 16, 16))

0 commit comments

Comments
 (0)