17 | 17 | from copy import deepcopy |
18 | 18 |
19 | 19 | import pytest |
20 | | -import torch |
21 | 20 | from _test_utils.torch_model.transformers_models import ( |
22 | 21 | create_tiny_llama_dir, |
23 | 22 | get_tiny_llama, |
@@ -69,122 +68,3 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config): |
69 | 68 | model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") |
70 | 69 | assert isinstance(model_test, mtsp.plugins.HFEagleModel) |
71 | 70 | tf_modelopt_state_and_output_tester(model_ref, model_test) |
72 | | - |
73 | | - |
74 | | -# fmt: off |
75 | | -@pytest.mark.parametrize("dtype", [torch.bfloat16]) |
76 | | -def test_eagle_model_prepare_eagle_inputs(dtype): |
77 | | - dummy_model = get_tiny_llama(num_hidden_layers=4) |
78 | | - |
79 | | - config = EAGLE3_DEFAULT_CFG["config"] |
80 | | - config["eagle_architecture_config"].update({ |
81 | | - "draft_vocab_size": dummy_model.config.vocab_size, |
82 | | - "hidden_size": dummy_model.config.hidden_size, |
83 | | - }) |
84 | | - mtsp.convert(dummy_model, mode=[("eagle", config)]) |
85 | | - |
86 | | - eagle_input_ids_0 = torch.tensor([[10, 20, 30, 40]], dtype=torch.long) |
87 | | - position_ids_0 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) |
88 | | - |
89 | | - |
90 | | - # This is concatenated from 3 intermediate base-model layers |
91 | | - cat_aux_hidden_states = torch.randn(1, 4, 32, dtype=dtype) |
92 | | - |
93 | | - # This is the eagle output from the previous eagle forward pass |
94 | | - dummy_eagle_output_hidden_states = torch.randn(1, 4, 32, dtype=dtype) |
95 | | - |
96 | | - # This is the causal mask for the 0th eagle step |
97 | | - m = torch.finfo(dtype).min |
98 | | - attention_mask_0 = torch.tensor([[0, m, m, m], # input tok 10 -> predicting token 20 |
99 | | - [0, 0, m, m], # 20 -> 30 |
100 | | - [0, 0, 0, m], # 30 -> 40 |
101 | | - [0, 0, 0, 0]], # 40 -> tok after 40 |
102 | | - |
103 | | - dtype=dtype).view(1, 1, 4, 4) |
104 | | - |
105 | | - # 2nd eagle step |
106 | | - eagle_input_hidden_states_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = dummy_model._concat_eagle_inputs( |
107 | | - eagle_input_ids_0, |
108 | | - cat_aux_hidden_states, |
109 | | - attention_mask_0, |
110 | | - position_ids_0, |
111 | | - dummy_eagle_output_hidden_states, |
112 | | - ) |
113 | | - |
114 | | - assert eagle_input_ids_1.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long)) |
115 | | - assert position_ids_1.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)) |
116 | | - |
117 | | - assert attention_mask_1.equal(torch.tensor([[0, m, m, m, m, m, m, m], # (x) output discarded |
118 | | - [0, 0, m, m, m, m, m, m], # (x) |
119 | | - [0, 0, 0, m, m, m, m, m], # (x) |
120 | | - [0, 0, 0, 0, m, m, m, m], # (x) |
121 | | - |
122 | | - [m, m, m, m, m, m, m, m], # (x) input tok 10 -> predicting token 20 |
123 | | - [0, m, m, m, m, 0, m, m], # 20 -> 30 |
124 | | - [0, 0, m, m, m, m, 0, m], # 30 -> 40 |
125 | | - [0, 0, 0, 0, m, m, m, m], # (x) 40 -> tok after 40 |
126 | | - ], dtype=dtype).view(1, 1, 8, 8)) |
127 | | - |
128 | | - # 3rd eagle step |
129 | | - eagle_input_hidden_states_2, eagle_input_ids_2, attention_mask_2, position_ids_2 = dummy_model._concat_eagle_inputs( |
130 | | - eagle_input_ids_0, |
131 | | - cat_aux_hidden_states, |
132 | | - attention_mask_0, |
133 | | - position_ids_0, |
134 | | - torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states], dim=1), |
135 | | - ) |
136 | | - assert eagle_input_ids_2.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long)) |
137 | | - assert position_ids_2.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)) |
138 | | - |
139 | | - assert attention_mask_2.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m], # (x) |
140 | | - [0, 0, m, m, m, m, m, m, m, m, m, m], # (x) |
141 | | - [0, 0, 0, m, m, m, m, m, m, m, m, m], # (x) |
142 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) |
143 | | - |
144 | | - [m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
145 | | - [0, m, m, m, m, 0, m, m, m, m, m, m], # (x) |
146 | | - [0, 0, m, m, m, m, 0, m, m, m, m, m], # (x) |
147 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) |
148 | | - |
149 | | - [m, m, m, m, m, m, m, m, m, m, m, m], # (x) 10 -> 20 |
150 | | - [m, m, m, m, m, m, m, m, m, m, m, m], # (x) 20 -> 30 |
151 | | - [0, m, m, m, m, 0, m, m, m, m, 0, m], # 30 -> 40 |
152 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) 40 -> tok after 40 |
153 | | - |
154 | | - ], dtype=dtype).view(1, 1, 12, 12)) |
155 | | - |
156 | | - # 4th eagle step |
157 | | - eagle_input_hidden_states_3, eagle_input_ids_3, attention_mask_3, position_ids_3 = dummy_model._concat_eagle_inputs( |
158 | | - eagle_input_ids_0, |
159 | | - cat_aux_hidden_states, |
160 | | - attention_mask_0, |
161 | | - position_ids_0, |
162 | | - torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states, |
163 | | - dummy_eagle_output_hidden_states],dim=1), |
164 | | - ) |
165 | | - |
166 | | - assert eagle_input_ids_3.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, |
167 | | - 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long)) |
168 | | - assert position_ids_3.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)) |
169 | | - |
170 | | - assert attention_mask_3.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
171 | | - [0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
172 | | - [0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
173 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
174 | | - |
175 | | - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
176 | | - [0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m, m], # (x) |
177 | | - [0, 0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m], # (x) |
178 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
179 | | - |
180 | | - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
181 | | - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
182 | | - [0, m, m, m, m, 0, m, m, m, m, 0, m, m, m, m, m], # (x) |
183 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
184 | | - |
185 | | - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) 10 -> 20 |
186 | | - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) 20 -> 30 |
187 | | - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
188 | | - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) |
189 | | - |
190 | | - ], dtype=dtype).view(1, 1, 16, 16)) |