|
7 | 7 | from executorch.examples.models.llama.rope import Rope |
8 | 8 | from executorch.examples.models.llama.static_attention import ( |
9 | 9 | StaticAttention, |
| 10 | + StaticAttentionMask, |
10 | 11 | StaticKVCache, |
11 | 12 | ) |
12 | 13 |
|
@@ -92,48 +93,54 @@ def test_with_cache(self): |
92 | 93 | n_chunks = 3 |
93 | 94 | chunk_len = config.max_seq_len // n_chunks |
94 | 95 | cache_len = config.max_seq_len - chunk_len |
95 | | - mask = torch.zeros(1, chunk_len, cache_len + chunk_len) |
96 | | - mask[:, :, :cache_len] = float("-inf") |
97 | | - mask[:, :, cache_len:] = torch.triu( |
98 | | - torch.full((1, chunk_len, chunk_len), float("-inf")), |
99 | | - diagonal=1, |
100 | | - ) |
101 | | - k_caches = { |
102 | | - StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
103 | | - 1, cache_len, config.head_dim |
104 | | - ) |
105 | | - for i in range(config.n_kv_heads) |
106 | | - } |
107 | | - v_caches = { |
108 | | - StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
109 | | - 1, cache_len, config.head_dim |
110 | | - ) |
111 | | - for i in range(config.n_kv_heads) |
112 | | - } |
113 | | - ys = [] |
114 | | - for i in range(n_chunks): |
115 | | - y_i, attn_update = static_attn( |
116 | | - x[:, i * chunk_len : (i + 1) * chunk_len, :], |
117 | | - freqs_cos[i * chunk_len : (i + 1) * chunk_len], |
118 | | - freqs_sin[i * chunk_len : (i + 1) * chunk_len], |
119 | | - mask=mask, |
120 | | - in_cache_state=(k_caches, v_caches), |
121 | | - out_cache_state=({}, {}), |
| 96 | + |
| 97 | + def test_with_style(style): |
| 98 | + mask = StaticAttentionMask(chunk_len, cache_len, style=style) |
| 99 | + mask.tensor[:, :, cache_len:] = torch.triu( |
| 100 | + torch.full((1, chunk_len, chunk_len), float("-inf")), |
| 101 | + diagonal=1, |
122 | 102 | ) |
123 | | - ys.append(y_i) |
124 | | - mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0 |
125 | | - k_cache_updates, v_cache_updates = attn_update["out_cache_state"] |
126 | | - for cache_id, update in k_cache_updates.items(): |
127 | | - k_caches[cache_id] = StaticKVCache.apply_update( |
128 | | - k_caches[cache_id], update |
| 103 | + k_caches = { |
| 104 | + StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
| 105 | + 1, cache_len, config.head_dim |
129 | 106 | ) |
130 | | - for cache_id, update in v_cache_updates.items(): |
131 | | - v_caches[cache_id] = StaticKVCache.apply_update( |
132 | | - v_caches[cache_id], update |
| 107 | + for i in range(config.n_kv_heads) |
| 108 | + } |
| 109 | + v_caches = { |
| 110 | + StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
| 111 | + 1, cache_len, config.head_dim |
133 | 112 | ) |
134 | | - |
135 | | - y = torch.cat(ys, dim=1) |
136 | | - self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all()) |
| 113 | + for i in range(config.n_kv_heads) |
| 114 | + } |
| 115 | + ys = [] |
| 116 | + for i in range(n_chunks): |
| 117 | + y_i, attn_update = static_attn( |
| 118 | + x[:, i * chunk_len : (i + 1) * chunk_len, :], |
| 119 | + freqs_cos[i * chunk_len : (i + 1) * chunk_len], |
| 120 | + freqs_sin[i * chunk_len : (i + 1) * chunk_len], |
| 121 | + mask=mask.tensor, |
| 122 | + in_cache_state=(k_caches, v_caches), |
| 123 | + out_cache_state=({}, {}), |
| 124 | + ) |
| 125 | + ys.append(y_i) |
| 126 | + mask.unmask(chunk_len) |
| 127 | + k_cache_updates, v_cache_updates = attn_update["out_cache_state"] |
| 128 | + |
| 129 | + if i < n_chunks - 1: |
| 130 | + for cache_id, update in k_cache_updates.items(): |
| 131 | + k_caches[cache_id] = StaticKVCache.apply_update( |
| 132 | + k_caches[cache_id], update, pos=chunk_len * i, style=style |
| 133 | + ) |
| 134 | + for cache_id, update in v_cache_updates.items(): |
| 135 | + v_caches[cache_id] = StaticKVCache.apply_update( |
| 136 | + v_caches[cache_id], update, pos=chunk_len * i, style=style |
| 137 | + ) |
| 138 | + |
| 139 | + y = torch.cat(ys, dim=1) |
| 140 | + self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all()) |
| 141 | + |
| 142 | + test_with_style("shift_pointer") |
| 143 | + test_with_style("smart_mask") |
137 | 144 |
|
138 | 145 | def test_within_transformer(self): |
139 | 146 | config = ModelArgs( |
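
The removed lines above built the attention mask by hand and re-opened cache columns after every chunk; `StaticAttentionMask` now owns that bookkeeping. Below is a minimal sketch of the same logic in plain torch — not the executorch class, just the manual steps the old test performed, assuming the right-aligned cache layout implied by the `shift_pointer` style:

```python
# Hedged sketch only: reproduces the mask bookkeeping the old test did by hand.
import torch


def build_chunk_mask(chunk_len, cache_len):
    # Cache columns start fully masked; the current chunk gets a causal triangle.
    mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
    mask[:, :, :cache_len] = float("-inf")
    mask[:, :, cache_len:] = torch.triu(
        torch.full((1, chunk_len, chunk_len), float("-inf")), diagonal=1
    )
    return mask


def unmask_tail(mask, cache_len, unmasked_len):
    # After processing a chunk, expose that many more cache positions,
    # right-aligned to match the layout assumed above.
    mask[:, :, cache_len - unmasked_len : cache_len] = 0


mask = build_chunk_mask(chunk_len=2, cache_len=4)
unmask_tail(mask, cache_len=4, unmasked_len=2)  # after the first chunk
print(mask[0])
```

In the updated test this is replaced by constructing `StaticAttentionMask(chunk_len, cache_len, style=style)` and calling `mask.unmask(chunk_len)` after each chunk.
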
@@ -162,48 +169,57 @@ def test_within_transformer(self): |
162 | 169 | n_chunks = 3 |
163 | 170 | chunk_len = config.max_seq_len // n_chunks |
164 | 171 | cache_len = config.max_seq_len - chunk_len |
165 | | - mask = torch.zeros(1, chunk_len, cache_len + chunk_len) |
166 | | - mask[:, :, :cache_len] = float("-inf") |
167 | | - mask[:, :, cache_len:] = torch.triu( |
168 | | - torch.full((1, chunk_len, chunk_len), float("-inf")), |
169 | | - diagonal=1, |
170 | | - ) |
171 | | - k_caches = { |
172 | | - StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
173 | | - 1, cache_len, config.head_dim |
174 | | - ) |
175 | | - for layer_id in range(config.n_layers) |
176 | | - for i in range(config.n_kv_heads) |
177 | | - } |
178 | | - v_caches = { |
179 | | - StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
180 | | - 1, cache_len, config.head_dim |
181 | | - ) |
182 | | - for layer_id in range(config.n_layers) |
183 | | - for i in range(config.n_kv_heads) |
184 | | - } |
185 | | - ys = [] |
186 | | - for i in range(n_chunks): |
187 | | - y_i, attn_update = static_transformer( |
188 | | - x[:, i * chunk_len : (i + 1) * chunk_len], |
189 | | - attn_options=ForwardOptions( |
190 | | - mask=mask, |
191 | | - freqs_cos_override=freqs_cos[i * chunk_len : (i + 1) * chunk_len], |
192 | | - freqs_sin_override=freqs_sin[i * chunk_len : (i + 1) * chunk_len], |
193 | | - in_cache_state=(k_caches, v_caches), |
194 | | - out_cache_state=({}, {}), |
195 | | - ), |
| 172 | + |
| 173 | + def test_with_style(style): |
| 174 | + mask = StaticAttentionMask(chunk_len, cache_len, style=style) |
| 175 | + mask.tensor[:, :, cache_len:] = torch.triu( |
| 176 | + torch.full((1, chunk_len, chunk_len), float("-inf")), |
| 177 | + diagonal=1, |
196 | 178 | ) |
197 | | - ys.append(y_i) |
198 | | - mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0 |
199 | | - k_cache_updates, v_cache_updates = attn_update["out_cache_state"] |
200 | | - for cache_id, update in k_cache_updates.items(): |
201 | | - k_caches[cache_id] = StaticKVCache.apply_update( |
202 | | - k_caches[cache_id], update |
| 179 | + k_caches = { |
| 180 | + StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
| 181 | + 1, cache_len, config.head_dim |
203 | 182 | ) |
204 | | - for cache_id, update in v_cache_updates.items(): |
205 | | - v_caches[cache_id] = StaticKVCache.apply_update( |
206 | | - v_caches[cache_id], update |
| 183 | + for layer_id in range(config.n_layers) |
| 184 | + for i in range(config.n_kv_heads) |
| 185 | + } |
| 186 | + v_caches = { |
| 187 | + StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( |
| 188 | + 1, cache_len, config.head_dim |
207 | 189 | ) |
208 | | - |
209 | | - self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all()) |
| 190 | + for layer_id in range(config.n_layers) |
| 191 | + for i in range(config.n_kv_heads) |
| 192 | + } |
| 193 | + ys = [] |
| 194 | + for i in range(n_chunks): |
| 195 | + y_i, attn_update = static_transformer( |
| 196 | + x[:, i * chunk_len : (i + 1) * chunk_len], |
| 197 | + attn_options=ForwardOptions( |
| 198 | + mask=mask.tensor, |
| 199 | + freqs_cos_override=freqs_cos[ |
| 200 | + i * chunk_len : (i + 1) * chunk_len |
| 201 | + ], |
| 202 | + freqs_sin_override=freqs_sin[ |
| 203 | + i * chunk_len : (i + 1) * chunk_len |
| 204 | + ], |
| 205 | + in_cache_state=(k_caches, v_caches), |
| 206 | + out_cache_state=({}, {}), |
| 207 | + ), |
| 208 | + ) |
| 209 | + ys.append(y_i) |
| 210 | + mask.unmask(chunk_len) |
| 211 | + k_cache_updates, v_cache_updates = attn_update["out_cache_state"] |
| 212 | + if i < n_chunks - 1: |
| 213 | + for cache_id, update in k_cache_updates.items(): |
| 214 | + k_caches[cache_id] = StaticKVCache.apply_update( |
| 215 | + k_caches[cache_id], update, pos=chunk_len * i, style=style |
| 216 | + ) |
| 217 | + for cache_id, update in v_cache_updates.items(): |
| 218 | + v_caches[cache_id] = StaticKVCache.apply_update( |
| 219 | + v_caches[cache_id], update, pos=chunk_len * i, style=style |
| 220 | + ) |
| 221 | + |
| 222 | + self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all()) |
| 223 | + |
| 224 | + test_with_style("shift_pointer") |
| 225 | + test_with_style("smart_mask") |
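
Both tests now also pass `pos` and `style` to `StaticKVCache.apply_update`. The following is a hedged sketch of the two update conventions the `style` flag is meant to select between, written in plain torch rather than against the executorch API; the roll-to-tail and write-in-place behaviors are assumptions for illustration, not the library's implementation:

```python
# Illustrative sketch only -- NOT the executorch implementation.
import torch


def apply_update_sketch(cache, update, pos, style):
    """Toy per-head KV cache update for tensors shaped (batch, cache_len, head_dim)."""
    update_len = update.size(1)
    new_cache = cache.clone()
    if style == "shift_pointer":
        # Sliding-window view: drop the oldest entries, append the new chunk at the tail.
        new_cache = torch.roll(new_cache, -update_len, dims=1)
        new_cache[:, -update_len:, :] = update
    elif style == "smart_mask":
        # Fixed layout: overwrite the slots starting at `pos`; the attention mask
        # (not the data layout) decides which slots are visible.
        new_cache[:, pos : pos + update_len, :] = update
    else:
        raise ValueError(f"unknown style: {style}")
    return new_cache


cache = torch.zeros(1, 4, 2)
chunk = torch.ones(1, 2, 2)
print(apply_update_sketch(cache, chunk, pos=0, style="shift_pointer")[0, :, 0])
print(apply_update_sketch(cache, chunk, pos=0, style="smart_mask")[0, :, 0])
```

Under either convention the cache contents end up equivalent for this test; the `style` argument just keeps the mask unmasking and the cache writes consistent with each other, which is why the tests run both `"shift_pointer"` and `"smart_mask"`.
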