# See the License for the specific language governing permissions and
# limitations under the License.

+import torch

from megatron.bridge.models.ministral3.ministral3_provider import (
    Ministral3ModelProvider,
@@ -172,3 +173,175 @@ def test_ministral3_14b_initialization(self):
        assert provider.ffn_hidden_size == 16384
        assert provider.num_layers == 40
        assert provider.rotary_base == 1000000000.0
+
+
+class TestGetLlama4AttnScale:
+    """Test cases for the _get_llama_4_attn_scale function used in MinistralTEDotProductAttention.
+
+    This function computes attention scaling based on Llama 4 attention parameters.
+    The key change in PR 1997 is that it now handles different query shapes for
+    packed (3D) vs unpacked (4D) tensors.
+    """
+
+    def _get_llama_4_attn_scale(
+        self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
+    ) -> torch.Tensor:
+        """Reimplementation of the function for testing."""
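+        # Scaling follows 1 + beta * log(1 + floor(position / max_position_embeddings)),
+        # so positions inside the first max_position_embeddings window get a scale of 1.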
+        scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
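+        # Append one trailing singleton dim per query dim beyond the leading sequence
+        # dim so the [seq_len] scaling broadcasts against the query tensor.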
+        num_dims_to_add = len(query_shape) - 1
+        for _ in range(num_dims_to_add):
+            scaling = scaling.unsqueeze(-1)
+        return scaling
+
+    def test_unpacked_4d_query_shape(self):
+        """Test attention scaling with unpacked 4D query shape [seq_len, batch, num_heads, head_dim]."""
+        seq_len = 8
+        batch_size = 2
+        num_heads = 4
+        head_dim = 64
+
+        positions_ids = torch.arange(seq_len)
+        beta = 0.1
+        max_position_embeddings = 16384
+        query_shape = (seq_len, batch_size, num_heads, head_dim)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        # Output should have shape [seq_len, 1, 1, 1] for broadcasting
+        assert scaling.shape == (seq_len, 1, 1, 1)
+
+        # First position should have scaling = 1 (since log(1 + 0) = 0)
+        expected_first = 1 + beta * torch.log(torch.tensor(1.0))
+        assert torch.isclose(scaling[0, 0, 0, 0], expected_first, atol=1e-6)
+
+    def test_packed_3d_query_shape(self):
+        """Test attention scaling with packed 3D query shape [seq_len, num_heads, head_dim]."""
+        seq_len = 16
+        num_heads = 8
+        head_dim = 32
+
+        positions_ids = torch.arange(seq_len)
+        beta = 0.2
+        max_position_embeddings = 8192
+        query_shape = (seq_len, num_heads, head_dim)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        # Output should have shape [seq_len, 1, 1] for broadcasting (3D - 1 = 2 dims added)
+        assert scaling.shape == (seq_len, 1, 1)
+
+        # Verify scaling values are computed correctly
+        expected = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
+        assert torch.allclose(scaling.squeeze(), expected, atol=1e-6)
+
+    def test_scaling_formula_correctness(self):
+        """Test that the scaling formula matches expected Llama 4 attention scaling."""
+        positions_ids = torch.tensor([0, 1, 100, 1000, 16384, 32768])
+        beta = 0.15
+        max_position_embeddings = 16384
+        query_shape = (6, 1, 1, 1)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        # Manual computation of expected values
+        # For position 0: 1 + 0.15 * log(1 + 0) = 1
+        # For position 16384: 1 + 0.15 * log(1 + 1) = 1 + 0.15 * log(2)
+        # For position 32768: 1 + 0.15 * log(1 + 2) = 1 + 0.15 * log(3)
+
+        expected_0 = 1.0
+        expected_16384 = 1 + beta * torch.log(torch.tensor(2.0))
+        expected_32768 = 1 + beta * torch.log(torch.tensor(3.0))
+
+        assert torch.isclose(scaling[0].squeeze(), torch.tensor(expected_0), atol=1e-6)
+        assert torch.isclose(scaling[4].squeeze(), expected_16384, atol=1e-6)
+        assert torch.isclose(scaling[5].squeeze(), expected_32768, atol=1e-6)
+
+    def test_beta_zero_returns_ones(self):
+        """Test that beta=0 returns all ones (no scaling)."""
+        positions_ids = torch.arange(10)
+        beta = 0.0
+        max_position_embeddings = 4096
+        query_shape = (10, 4, 64)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        assert torch.allclose(scaling.squeeze(), torch.ones(10), atol=1e-6)
+
+    def test_different_query_shapes_get_correct_dims(self):
+        """Test that different query shapes result in correct number of dimensions added."""
+        positions_ids = torch.arange(4)
+        beta = 0.1
+        max_position_embeddings = 1000
+
+        # 2D query shape
+        query_shape_2d = (4, 32)
+        scaling_2d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_2d)
+        assert scaling_2d.shape == (4, 1)  # 2-1 = 1 dim added
+
+        # 3D query shape (packed THD)
+        query_shape_3d = (4, 8, 32)
+        scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_3d)
+        assert scaling_3d.shape == (4, 1, 1)  # 3-1 = 2 dims added
+
+        # 4D query shape (unpacked SBHD, matching the seq-first layout used above)
+        query_shape_4d = (4, 2, 8, 32)
+        scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_4d)
+        assert scaling_4d.shape == (4, 1, 1, 1)  # 4-1 = 3 dims added
+
+    def test_broadcasting_compatibility(self):
+        """Test that scaling tensor is broadcastable to query tensor."""
+        seq_len = 8
+        num_heads = 4
+        head_dim = 64
+
+        positions_ids = torch.arange(seq_len)
+        beta = 0.1
+        max_position_embeddings = 16384
+
+        # Test for 3D packed format
+        query_3d = torch.randn(seq_len, num_heads, head_dim)
+        scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_3d.shape)
+
+        # Broadcasting should work
+        result_3d = query_3d * scaling_3d.to(query_3d.dtype)
+        assert result_3d.shape == query_3d.shape
+
+        # Test for 4D unpacked format
+        batch = 2
+        query_4d = torch.randn(seq_len, batch, num_heads, head_dim)
+        scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_4d.shape)
+
+        # Broadcasting should work
+        result_4d = query_4d * scaling_4d.to(query_4d.dtype)
+        assert result_4d.shape == query_4d.shape
+
+    def test_gpu_tensor_support(self):
+        """Test that the function works with GPU tensors if available."""
+        if not torch.cuda.is_available():
+            return  # Skip test if no GPU
+
+        positions_ids = torch.arange(8, device="cuda")
+        beta = 0.1
+        max_position_embeddings = 1024
+        query_shape = (8, 4, 32)
+
+        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+        assert scaling.device.type == "cuda"
+        assert scaling.shape == (8, 1, 1)
+
+    def test_dtype_preservation(self):
+        """Test that output dtype matches input positions_ids dtype."""
+        positions_ids_float32 = torch.arange(4, dtype=torch.float32)
+        positions_ids_float64 = torch.arange(4, dtype=torch.float64)
+        beta = 0.1
+        max_position_embeddings = 100
+        query_shape = (4, 2, 8)
+
+        scaling_32 = self._get_llama_4_attn_scale(positions_ids_float32, beta, max_position_embeddings, query_shape)
+        scaling_64 = self._get_llama_4_attn_scale(positions_ids_float64, beta, max_position_embeddings, query_shape)
+
+        # The floating-point dtype of positions_ids propagates through the division,
+        # floor, and log operations, so the scaling keeps the input dtype.
+        assert scaling_32.dtype == torch.float32
+        assert scaling_64.dtype == torch.float64