Commit 38472b1: add explicit sharding support
Parent: 64d6d9b

35 files changed: +573, -203 lines

src/MaxText/common_types.py

Lines changed: 5 additions & 0 deletions

@@ -100,3 +100,8 @@ class AttentionType(enum.Enum):
   CHUNK = "chunk"
   MLA = "mla"
   FULL = "full"
+
+
+class ShardMode(enum.Enum):
+  AUTO = "auto"  # default
+  EXPLICIT = "explicit"
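ShardMode is the switch the rest of this commit keys off: "auto" keeps the existing behavior of advisory jax.lax.with_sharding_constraint hints that the compiler is free to adjust, while "explicit" appears to opt into JAX's explicit-sharding (sharding-in-types) style, where shardings travel with the array types and layout changes are requested directly.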

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions

@@ -353,6 +353,7 @@ jax_cache_dir: "~/jax_cache"
 hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

 # Parallelism
+shard_mode: "auto" # can be either auto or explicit
 mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
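Because the flag defaults to "auto", existing configs keep their current behavior; opting a run into the new code path should only require setting shard_mode: "explicit" here.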

src/MaxText/data_loader.py

Lines changed: 5 additions & 1 deletion

@@ -48,7 +48,11 @@ def load_next_batch(self):
       else:
         example_batch = next(self.data_iterator)
       # Reshard data from loaded sharding to performant activation sharding
-      self.last_batch = jax.lax.with_sharding_constraint(example_batch, self.input_data_shardings)
+      self.last_batch = maxtext_utils.maybe_shard_with_name(
+          example_batch,
+          self.input_data_shardings,
+          self.config.shard_mode,
+      )
       self.check_example_batch()
     except Exception as e:  # pylint: disable=broad-except
       if isinstance(e, StopIteration):
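The maxtext_utils.maybe_shard_with_name helper that this call site (and the ones below) relies on is not shown in this excerpt. A minimal sketch of the shape such a helper could take, assuming it simply dispatches on ShardMode; the jax.sharding.reshard branch for explicit mode is an assumption, not the repository's implementation:

# Hypothetical sketch only; the real maxtext_utils.maybe_shard_with_name may differ.
import jax

from MaxText.common_types import ShardMode


def maybe_shard_with_name_sketch(inputs, shardings, shard_mode: ShardMode):
  """Apply `shardings` to `inputs` according to the configured shard mode."""
  if inputs is None:
    return None
  if shard_mode == ShardMode.EXPLICIT:
    # Explicit mode: request the exact target layout (assumed behavior).
    return jax.sharding.reshard(inputs, shardings)
  # Auto mode: keep the existing advisory constraint.
  return jax.lax.with_sharding_constraint(inputs, shardings)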

src/MaxText/gradient_accumulation.py

Lines changed: 33 additions & 5 deletions

@@ -16,6 +16,10 @@

 import jax
 import jax.numpy as jnp
+from jax.sharding import NamedSharding
+
+from MaxText.common_types import ShardMode
+from MaxText.maxtext_utils import maybe_shard_with_name


 def gradient_accumulation_loss_and_grad(

@@ -58,6 +62,17 @@ def gradient_accumulation_loss_and_grad(
     - final_aux (PyTree): Auxiliary outputs, summed across microbatches.
     - raw_grads (PyTree): The accumulated and averaged gradients.
   """
+
+  def _maybe_shard_with_name(inputs, sharding_names):
+    """Wrapper of maybe_shard_with_name with fixed shard_mode"""
+    return maybe_shard_with_name(inputs, sharding_names, config.shard_mode)
+
+  # For more efficient DP/ZeRO-1 + GA
+  if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+    ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
+    grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
+  else:
+    ga_params_shardings = grad_shardings = params_shardings
   # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints
   # so that all-gather is done once in the lower precision before the gradient accumulation loop
   if config.shard_optimizer_over_data:

@@ -68,15 +83,14 @@ def convert_to_bf16(param):
       return param

     ga_params = jax.tree_util.tree_map(convert_to_bf16, params)
-    ga_params = jax.tree.map(jax.lax.with_sharding_constraint, ga_params, params_shardings)
   else:
     ga_params = params

+  ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings)
   grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)

   def accumulate_gradient(acc_grad_and_loss, data):
     ga_params = acc_grad_and_loss["ga_params"]
-
     (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True)
     acc_grad_and_loss["loss"] += aux["total_loss"]
     acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]

@@ -94,7 +108,7 @@ def reshape_to_microbatch_accumulations(batch_arr):

   data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data)
   init_grad = jax.tree_util.tree_map(jnp.zeros_like, ga_params)
-  init_grad = jax.tree.map(jax.lax.with_sharding_constraint, init_grad, params_shardings)
+  init_grad = jax.tree.map(_maybe_shard_with_name, init_grad, grad_shardings)
   init_grad_and_loss = {
       "loss": 0.0,
       "grad": init_grad,

@@ -113,9 +127,23 @@ def reshape_to_microbatch_accumulations(batch_arr):
       + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps
   )
   raw_grads = grad_and_loss["grad"]
-  if config.shard_optimizer_over_data:
-    raw_grads = jax.tree.map(jax.lax.with_sharding_constraint, raw_grads, params_shardings)
+  raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings)
   raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads)
   aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux)  # pytype: disable=module-attr

   return loss, aux, raw_grads
+
+
+# GA helper functions
+def update_sharding_for_reduced(sharding: NamedSharding) -> NamedSharding:
+  """
+  Add reduced on data axis of given NamedSharding
+  """
+  return sharding.update(spec=sharding.spec.update(reduced={"data"}))
+
+
+def update_sharding_for_unreduced(sharding: NamedSharding) -> NamedSharding:
+  """
+  Add unreduced on data axis of given NamedSharding
+  """
+  return sharding.update(spec=sharding.spec.update(unreduced={"data"}))
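For orientation, the reduced/unreduced sets on PartitionSpec (used via sharding.spec.update above) appear to be what lets explicit mode defer the cross-data-parallel all-reduce: per-microbatch gradients can stay marked unreduced over the data axis during accumulation, with the matching reduced annotation on the loop-carried params. A small illustration of applying the two helpers, with the mesh shape and axis name assumed rather than taken from the commit:

# Illustration only: derive GA shardings from a parameter sharding, mirroring
# the branch guarded by ShardMode.EXPLICIT and ici_data_parallelism > 1 above.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))  # assumed 1-D data-parallel mesh
param_sharding = NamedSharding(mesh, PartitionSpec("data"))

grad_sharding = update_sharding_for_unreduced(param_sharding)    # gradients: unreduced over "data"
ga_param_sharding = update_sharding_for_reduced(param_sharding)  # params: reduced over "data"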

src/MaxText/layers/attention_mla.py

Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,7 @@
 from typing import Any, Optional, Tuple

 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
 import jax.numpy as jnp

 from flax import linen as nn

@@ -663,6 +663,7 @@ def __call__(
       inputs_kv: Array,
       inputs_positions: Array | None = None,
       decoder_segment_ids: Array | None = None,
+      out_sharding: NamedSharding | None = None,
       *,
       model_mode: str = MODEL_MODE_TRAIN,
       deterministic: bool = False,

src/MaxText/layers/attention_op.py

Lines changed: 18 additions & 3 deletions

@@ -28,7 +28,7 @@
 from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention
 from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention
 from jax.experimental import pallas as pl
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
 import jax.numpy as jnp

 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

@@ -44,6 +44,7 @@


 from MaxText import max_utils
+from MaxText.maxtext_utils import maybe_shard_with_name
 from MaxText.common_types import (
     DEFAULT_MASK_VALUE,
     BATCH,

@@ -1302,12 +1303,26 @@ def wrap_flash_attention(
       )
       return attention_output

+    def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
+      # decoder_segment_ids can be None
+      if pspec is None:
+        return None
+      sharding = NamedSharding(self.mesh, pspec)
+      return maybe_shard_with_name(inputs, sharding, shard_mode=self.config.shard_mode)
+
+    query = _maybe_shard_with_pspec(query, axis_names_q)
+    key = _maybe_shard_with_pspec(key, axis_names_kv)
+    value = _maybe_shard_with_pspec(value, axis_names_kv)
+    decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
+    decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
+    sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
+
     x = wrap_flash_attention(
         query,
         key,
         value,
-        decoder_segment_ids,
-        decoder_segment_ids,
+        decoder_segment_ids_q,
+        decoder_segment_ids_kv,
         splash_kernel,
         cp_size,
         load_balanced_context_parallel,
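Constraining query and key/value, and their segment ids, through separate partition specs matters once context parallelism is in play, since the two operands can legitimately carry different sequence shardings; passing decoder_segment_ids_q and decoder_segment_ids_kv instead of the same array twice keeps the splash-attention inputs consistent with those layouts.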
