 from copy import deepcopy
 
 import pytest
-import torch
 from _test_utils.torch_model.transformers_models import (
     create_tiny_llama_dir,
     get_tiny_llama,
@@ -69,122 +68,3 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config):
     model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model")
     assert isinstance(model_test, mtsp.plugins.HFEagleModel)
     tf_modelopt_state_and_output_tester(model_ref, model_test)
-
-
-# fmt: off
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_eagle_model_prepare_eagle_inputs(dtype):
-    dummy_model = get_tiny_llama(num_hidden_layers=4)
-
-    config = EAGLE3_DEFAULT_CFG["config"]
-    config["eagle_architecture_config"].update({
-        "draft_vocab_size": dummy_model.config.vocab_size,
-        "hidden_size": dummy_model.config.hidden_size,
-    })
-    mtsp.convert(dummy_model, mode=[("eagle", config)])
-
-    eagle_input_ids_0 = torch.tensor([[10, 20, 30, 40]], dtype=torch.long)
-    position_ids_0 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
-
-    # Concatenated from 3 intermediate base-model layers
-    cat_aux_hidden_states = torch.randn(1, 4, 32, dtype=dtype)
-
-    # Eagle output from the previous eagle forward pass
-    dummy_eagle_output_hidden_states = torch.randn(1, 4, 32, dtype=dtype)
-
-    # Causal mask for the 0th eagle step
-    m = torch.finfo(dtype).min
-    attention_mask_0 = torch.tensor([[0, m, m, m],   # input tok 10 -> predicting token 20
-                                     [0, 0, m, m],   # 20 -> 30
-                                     [0, 0, 0, m],   # 30 -> 40
-                                     [0, 0, 0, 0]],  # 40 -> tok after 40
-                                    dtype=dtype).view(1, 1, 4, 4)
-
-    # 2nd eagle step
-    eagle_input_h_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = dummy_model._concat_eagle_inputs(
-        eagle_input_ids_0,
-        cat_aux_hidden_states,
-        attention_mask_0,
-        position_ids_0,
-        dummy_eagle_output_hidden_states,
-    )
-
-    assert eagle_input_ids_1.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
-    assert position_ids_1.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
-
-    assert attention_mask_1.equal(torch.tensor([[0, m, m, m, m, m, m, m],  # (x) output discarded
-                                                [0, 0, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m],  # (x) input tok 10 -> predicting token 20
-                                                [0, m, m, m, m, 0, m, m],  # 20 -> 30
-                                                [0, 0, m, m, m, m, 0, m],  # 30 -> 40
-                                                [0, 0, 0, 0, m, m, m, m],  # (x) 40 -> tok after 40
-                                                ], dtype=dtype).view(1, 1, 8, 8))
-
-    # 3rd eagle step
-    eagle_input_hidden_states_2, eagle_input_ids_2, attention_mask_2, position_ids_2 = dummy_model._concat_eagle_inputs(
-        eagle_input_ids_0,
-        cat_aux_hidden_states,
-        attention_mask_0,
-        position_ids_0,
-        torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states], dim=1),
-    )
-    assert eagle_input_ids_2.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
-    assert position_ids_2.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
-
-    assert attention_mask_2.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, m, m, m, m, 0, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, 0, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m],  # (x) 10 -> 20
-                                                [m, m, m, m, m, m, m, m, m, m, m, m],  # (x) 20 -> 30
-                                                [0, m, m, m, m, 0, m, m, m, m, 0, m],  # 30 -> 40
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m],  # (x) 40 -> tok after 40
-                                                ], dtype=dtype).view(1, 1, 12, 12))
-
-    # 4th eagle step
-    eagle_input_hidden_states_3, eagle_input_ids_3, attention_mask_3, position_ids_3 = dummy_model._concat_eagle_inputs(
-        eagle_input_ids_0,
-        cat_aux_hidden_states,
-        attention_mask_0,
-        position_ids_0,
-        torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states,
-                   dummy_eagle_output_hidden_states], dim=1),
-    )
-
-    assert eagle_input_ids_3.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40,
-                                                  10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long))
-    assert position_ids_3.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long))
-
-    assert attention_mask_3.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, m, m, m, m, 0, m, m, m, m, 0, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x) 10 -> 20
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x) 20 -> 30
-                                                [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m],  # (x)
-                                                ], dtype=dtype).view(1, 1, 16, 16))
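
As context for the assertions in the deleted test: each successive eagle step appends another copy of the step-0 input ids and position ids, and the step-0 mask is a standard additive causal mask (0 where attention is allowed, dtype-min where it is blocked). A minimal standalone sketch of both patterns; the helper name additive_causal_mask is illustrative only and not part of the modelopt API:

import torch

def additive_causal_mask(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Illustrative helper (not from modelopt): 0 on and below the
    # diagonal (visible), dtype-min strictly above it (masked out).
    m = torch.finfo(dtype).min
    mask = torch.triu(torch.full((seq_len, seq_len), m, dtype=dtype), diagonal=1)
    return mask.view(1, 1, seq_len, seq_len)

ids = torch.tensor([[10, 20, 30, 40]], dtype=torch.long)
pos = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)

# One more copy per eagle step, matching eagle_input_ids_1 and
# position_ids_1 asserted in the test above.
ids_1 = torch.cat([ids, ids], dim=-1)
pos_1 = torch.cat([pos, pos], dim=-1)

# Reproduces the hand-written attention_mask_0 from the test.
mask_0 = additive_causal_mask(4, torch.bfloat16)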