
Commit 4e69203

hanzhi713 authored and changlan committed
Allow qkvo partition specs
* Allow qkvo partition specs
* Add unit test
* Add missing skip

GitOrigin-RevId: 143aaef
1 parent e9d9177 commit 4e69203

2 files changed: +100 -0 lines changed

axlearn/common/attention.py

Lines changed: 23 additions & 0 deletions
@@ -162,6 +162,7 @@
     save_and_offload_only_these_names_regex,
     shapes,
     split_prng_key,
+    with_sharding_constraint,
 )

@@ -1556,6 +1557,20 @@ class Config(BaseLayer.Config):
         # If true, use learnable logit sinks.
         logit_sink: Optional[bool] = None

+        # Partition spec for query ([batch, seq, q_heads, head_dim]) after input projections.
+        q_partition_spec: Optional[PartitionSpec] = None
+
+        # Partition spec for key ([batch, seq, kv_heads, head_dim]) after input projections.
+        # Follows `q_partition_spec` if None.
+        k_partition_spec: Optional[PartitionSpec] = None
+
+        # Partition spec for value ([batch, seq, kv_heads, head_dim]) after input projections.
+        # Follows `q_partition_spec` if None.
+        v_partition_spec: Optional[PartitionSpec] = None
+
+        # Partition spec for output ([batch, seq, hidden_dim]) after output projections.
+        o_partition_spec: Optional[PartitionSpec] = None
+
     def __init__(self, cfg: Config, *, parent: Module):
         super().__init__(cfg, parent=parent)
         cfg = self.config

@@ -1719,6 +1734,12 @@ def _forward_for_mode(
             time_step = cached_states["time_step"]
             query_positions = query_positions + time_step[:, None]  # [batch, steps]
         q_proj, k_proj, v_proj = self.i_proj(query, query_positions=query_positions, **kv_kwargs)
+        if cfg.q_partition_spec:
+            q_proj = with_sharding_constraint(q_proj, cfg.q_partition_spec)
+        if cfg.q_partition_spec or cfg.k_partition_spec:
+            k_proj = with_sharding_constraint(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
+        if cfg.q_partition_spec or cfg.v_partition_spec:
+            v_proj = with_sharding_constraint(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)

         if cfg.scale_kv_before_cache_update:
             if has_external_kv_state:

@@ -1821,6 +1842,8 @@ def _forward_for_mode(

         # [batch, target_length, output_dim].
         o_proj = self.o_proj(context)
+        if cfg.o_partition_spec:
+            o_proj = with_sharding_constraint(o_proj, cfg.o_partition_spec)
         outputs = self._remat_name(o_proj, "o_proj")
         self._add_tensor_stats("o_proj_outputs", outputs)
         return_aux = return_aux or set()

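As a quick orientation for the new config fields, a minimal, illustrative sketch of how a layer might opt in to them. The mesh axis names ("fsdp", "seq", "model") and the dimensions are assumptions borrowed from the unit test below, not requirements of this commit:

from jax.sharding import PartitionSpec

from axlearn.common import attention

cfg = attention.MultiheadAttention.default_config().set(
    name="attn",
    query_dim=16,
    key_dim=16,
    value_dim=16,
    num_heads=4,
    # Constrain projected q as [batch -> "fsdp", seq -> "seq", heads -> "model", head_dim -> unsharded].
    q_partition_spec=PartitionSpec("fsdp", "seq", "model", None),
    # k_partition_spec / v_partition_spec are left as None and fall back to q_partition_spec.
    # Constrain the output projection as [batch -> "fsdp", seq -> "seq", hidden_dim -> unsharded].
    o_partition_spec=PartitionSpec("fsdp", "seq", None),
)
layer = cfg.instantiate(parent=None)

The specs only take effect when the forward pass runs under jit inside a mesh that defines the referenced axis names, which is the setup the unit test below creates.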
axlearn/common/attention_test.py

Lines changed: 77 additions & 0 deletions
@@ -129,6 +129,7 @@
     TestCase,
     assert_allclose,
     dummy_segments_positions,
+    is_supported_mesh_shape,
     set_threefry_partitionable,
 )
 from axlearn.common.torch_utils import parameters_from_torch_layer

@@ -2466,6 +2467,82 @@ def test_gqa_forward(
         )
         self.assertNestedAllClose(base_outputs, test_outputs)

+    @parameterized.product(kv_part=[None, PartitionSpec("fsdp", None, "model", None)])
+    @pytest.mark.d8
+    def test_qkvo_partition_spec(self, kv_part):
+        """Tests that QKVO partition spec are applied correctly when specified."""
+        mesh_shape = (2, 2, 2)
+        if not is_supported_mesh_shape(mesh_shape):
+            self.skipTest(f"Unsupported mesh shape {mesh_shape}")
+        model_dim = 16
+        num_heads = 4
+        mesh = jax.make_mesh(mesh_shape, axis_names=("fsdp", "seq", "model"))
+        q_part = PartitionSpec("fsdp", "seq", "model", None)
+        o_part = PartitionSpec("fsdp", "seq", None)
+
+        layer_kwargs = dict(
+            query_dim=model_dim,
+            key_dim=model_dim,
+            value_dim=model_dim,
+            num_heads=num_heads,
+            dtype=jnp.float32,
+            q_partition_spec=q_part,
+            o_partition_spec=o_part,
+            k_partition_spec=kv_part,
+            v_partition_spec=kv_part,
+        )
+        init_key = jax.random.PRNGKey(123)
+        base_cfg = attention.MultiheadAttention.default_config().set(**layer_kwargs)
+        base_layer = base_cfg.set(name="base").instantiate(parent=None)
+        base_state = base_layer.initialize_parameters_recursively(prng_key=init_key)
+
+        # Dummy inputs.
+        batch_size, tgt_len = 2, 6
+        base_inputs = dict(
+            query=jax.random.normal(
+                jax.random.PRNGKey(124),
+                [batch_size, tgt_len, model_dim],
+                dtype=jnp.float32,
+            ),
+            key=None,
+            value=None,
+        )
+        forward_key = jax.random.PRNGKey(456)
+
+        def patched_remat_name(_, tensor, name):
+            def callback(sharding):
+                # pylint: disable-next=protected-access
+                normalize_spec = sharding.spec._normalized_spec_for_aval(len(tensor.shape))
+                if name == "q_proj":
+                    self.assertEqual(normalize_spec, q_part)
+                elif name == "o_proj":
+                    self.assertEqual(normalize_spec, o_part)
+                elif name in ["k_proj", "v_proj"]:
+                    if kv_part is None:
+                        self.assertEqual(normalize_spec, q_part)
+                    else:
+                        self.assertEqual(normalize_spec, kv_part)
+
+            jax.debug.inspect_array_sharding(tensor, callback=callback)
+            return tensor
+
+        with mesh, mock.patch.object(
+            attention.MultiheadAttention, "_remat_name", patched_remat_name
+        ):
+
+            @jax.jit
+            def jit_fn():
+                base_outputs, _ = F(
+                    base_layer,
+                    state=base_state,
+                    is_training=False,
+                    prng_key=forward_key,
+                    inputs=base_inputs,
+                )
+                return base_outputs
+
+            jit_fn()
+
     def _test_extend_step(
         self,
         attention_cfg: attention.MultiheadAttention.Config,

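For context on the mechanism the test exercises: a sharding constraint inside a jitted function asks the partitioner to lay out that intermediate according to the given spec, which the test then observes via jax.debug.inspect_array_sharding. A standalone sketch in plain JAX (not axlearn's wrapper; the (2, 2, 2) mesh below is an assumption mirroring the d8-marked test and requires eight devices):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Assumed 8-device mesh, mirroring the (2, 2, 2) mesh in the test above.
mesh = jax.make_mesh((2, 2, 2), axis_names=("fsdp", "seq", "model"))

@jax.jit
def f(x):
    y = x * 2.0  # Some intermediate, e.g. a projection output.
    # Constrain the intermediate as [batch -> "fsdp", seq -> "seq", dim -> unsharded].
    return jax.lax.with_sharding_constraint(
        y, NamedSharding(mesh, PartitionSpec("fsdp", "seq", None))
    )

out = f(jnp.ones((4, 8, 16)))
print(out.sharding)  # Expect a NamedSharding whose spec shards the first two axes.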