     TestCase,
     assert_allclose,
     dummy_segments_positions,
+    is_supported_mesh_shape,
     set_threefry_partitionable,
 )
 from axlearn.common.torch_utils import parameters_from_torch_layer
@@ -2466,6 +2467,82 @@ def test_gqa_forward(
         )
         self.assertNestedAllClose(base_outputs, test_outputs)

+    @parameterized.product(kv_part=[None, PartitionSpec("fsdp", None, "model", None)])
+    @pytest.mark.d8
+    def test_qkvo_partition_spec(self, kv_part):
+        """Tests that QKVO partition specs are applied correctly when specified."""
+        mesh_shape = (2, 2, 2)
+        if not is_supported_mesh_shape(mesh_shape):
+            self.skipTest(f"Unsupported mesh shape {mesh_shape}")
+        model_dim = 16
+        num_heads = 4
+        mesh = jax.make_mesh(mesh_shape, axis_names=("fsdp", "seq", "model"))
+        q_part = PartitionSpec("fsdp", "seq", "model", None)
+        o_part = PartitionSpec("fsdp", "seq", None)
+
+        layer_kwargs = dict(
+            query_dim=model_dim,
+            key_dim=model_dim,
+            value_dim=model_dim,
+            num_heads=num_heads,
+            dtype=jnp.float32,
+            q_partition_spec=q_part,
+            o_partition_spec=o_part,
+            k_partition_spec=kv_part,
+            v_partition_spec=kv_part,
+        )
+        init_key = jax.random.PRNGKey(123)
+        base_cfg = attention.MultiheadAttention.default_config().set(**layer_kwargs)
+        base_layer = base_cfg.set(name="base").instantiate(parent=None)
+        base_state = base_layer.initialize_parameters_recursively(prng_key=init_key)
+
+        # Dummy inputs.
+        batch_size, tgt_len = 2, 6
+        base_inputs = dict(
+            query=jax.random.normal(
+                jax.random.PRNGKey(124),
+                [batch_size, tgt_len, model_dim],
+                dtype=jnp.float32,
+            ),
+            key=None,
+            value=None,
+        )
+        forward_key = jax.random.PRNGKey(456)
+
+        def patched_remat_name(_, tensor, name):
+            def callback(sharding):
+                # pylint: disable-next=protected-access
+                normalized_spec = sharding.spec._normalized_spec_for_aval(len(tensor.shape))
+                if name == "q_proj":
+                    self.assertEqual(normalized_spec, q_part)
+                elif name == "o_proj":
+                    self.assertEqual(normalized_spec, o_part)
+                elif name in ["k_proj", "v_proj"]:
+                    if kv_part is None:
+                        # k/v fall back to the q partition spec when not explicitly set.
+                        self.assertEqual(normalized_spec, q_part)
+                    else:
+                        self.assertEqual(normalized_spec, kv_part)
+
+            jax.debug.inspect_array_sharding(tensor, callback=callback)
+            return tensor
+
+        with mesh, mock.patch.object(
+            attention.MultiheadAttention, "_remat_name", patched_remat_name
+        ):
+
+            @jax.jit
+            def jit_fn():
+                base_outputs, _ = F(
+                    base_layer,
+                    state=base_state,
+                    is_training=False,
+                    prng_key=forward_key,
+                    inputs=base_inputs,
+                )
+                return base_outputs
+
+            jit_fn()
+
     def _test_extend_step(
         self,
         attention_cfg: attention.MultiheadAttention.Config,
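For context on the mechanism the new test relies on: MultiheadAttention._remat_name is patched so that each tagged projection output passes through jax.debug.inspect_array_sharding, whose callback receives the sharding that propagation assigns to that intermediate when the jitted function is compiled. Below is a minimal standalone sketch of that API, assuming an 8-device runtime (consistent with the test's pytest.mark.d8 marker); the mesh axes, shapes, and the names report_sharding and check_fn are illustrative only, not taken from the PR.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Build a (2, 2, 2) device mesh; requires 8 devices.
mesh = jax.make_mesh((2, 2, 2), axis_names=("fsdp", "seq", "model"))
spec = PartitionSpec("fsdp", "seq", "model")

def report_sharding(sharding):
    # Invoked during compilation with the sharding chosen for the inspected value.
    print("intermediate sharding:", sharding.spec)

@jax.jit
def check_fn(x):
    # Constrain the intermediate, then inspect the sharding it ends up with.
    y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
    jax.debug.inspect_array_sharding(y, callback=report_sharding)
    return y

check_fn(jnp.ones((4, 8, 16)))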