Commit aebd278

Merge pull request #2562 from AI-Hypercomputer:chengnuojin-move-sharding
PiperOrigin-RevId: 826212582
2 parents 60b1f1b + 3aaad8a

18 files changed: +553 −539 lines

src/MaxText/data_loader.py

Lines changed: 3 additions & 3 deletions

@@ -20,7 +20,7 @@
 from jax.experimental import checkify

 from MaxText import exceptions
-from MaxText import maxtext_utils
+from MaxText import sharding
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -37,7 +37,7 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     self.goodput_recorder = goodput_recorder
     self.data_iterator = data_iterator
     self.last_batch = None
-    self.input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
+    self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)

   def load_next_batch(self):
     """Loads the next batch. Can keep reusing the same batch for performance reasons."""
@@ -48,7 +48,7 @@ def load_next_batch(self):
       else:
         example_batch = next(self.data_iterator)
       # Reshard data from loaded sharding to performant activation sharding
-      self.last_batch = maxtext_utils.maybe_shard_with_name(
+      self.last_batch = sharding.maybe_shard_with_name(
          example_batch,
          self.input_data_shardings,
          self.config.shard_mode,
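For context on the call site above: load_next_batch reshards the freshly loaded host batch onto the mesh-based input-data sharding via sharding.maybe_shard_with_name. The diff only shows the call sites, not the helper's body; below is a minimal, hypothetical JAX sketch of that resharding pattern. The mesh shape, axis names, and the reshard_batch helper are assumptions for illustration, not MaxText code.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical sketch: build a 1-D data-parallel mesh and a batch sharding.
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("data",))
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))

def reshard_batch(example_batch):
  # Move every array in the (possibly nested) batch onto the data-parallel sharding.
  return jax.tree_util.tree_map(lambda x: jax.device_put(x, batch_sharding), example_batch)

host_batch = {"inputs": np.zeros((8, 16), dtype=np.int32)}
sharded_batch = reshard_batch(host_batch)  # each leaf now carries batch_sharding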

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 3 additions & 2 deletions

@@ -72,6 +72,7 @@
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import sharding
 from MaxText import train_utils
 from MaxText import profiler
 from MaxText import pyconfig
@@ -566,7 +567,7 @@ def setup_train_loop(
   )[2]
   if not config.using_pipeline_parallelism:
     # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
-    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
+    sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)

   return (
       init_rng,
@@ -688,7 +689,7 @@ def train_loop(config, config_inference, recorder, state=None):
       config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator
   )

-  data_sharding = maxtext_utils.get_input_data_sharding(config, mesh)
+  data_sharding = sharding.get_input_data_sharding(config, mesh)

   inference_engine = offline_engine.OfflineEngine(
       config=config_inference,

src/MaxText/gradient_accumulation.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@
 from jax.sharding import NamedSharding

 from MaxText.common_types import ShardMode
-from MaxText.maxtext_utils import maybe_shard_with_name
+from MaxText.sharding import maybe_shard_with_name


 def gradient_accumulation_loss_and_grad(

src/MaxText/layers/attention_op.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@


 from MaxText import max_utils
-from MaxText.maxtext_utils import maybe_shard_with_name
+from MaxText.sharding import maybe_shard_with_name
 from MaxText.common_types import (
     DEFAULT_MASK_VALUE,
     BATCH,

src/MaxText/layers/attentions.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@
     EP_AS_CONTEXT,
     AttentionType,
 )
-from MaxText.maxtext_utils import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical
 from MaxText.inference import kvcache
 from MaxText.inference import page_manager
 from MaxText.inference import paged_attention

src/MaxText/layers/decoders.py

Lines changed: 4 additions & 3 deletions

@@ -38,6 +38,7 @@
 from MaxText.layers import pipeline
 from MaxText import maxtext_utils
 from MaxText import multimodal_utils
+from MaxText import sharding
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.normalizations import rms_norm
 from MaxText.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
@@ -90,7 +91,7 @@ def __call__(
     cfg = self.config
     mesh = self.mesh
     _maybe_shard_with_logical = functools.partial(
-        maxtext_utils.maybe_shard_with_logical,
+        sharding.maybe_shard_with_logical,
         mesh=mesh,
         shard_mode=cfg.shard_mode,
     )
@@ -722,7 +723,7 @@ def __call__(
       moe_layer = RemattedBlockLayers[1]
       num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
       num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers
-      logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
+      logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
       # We chose not to pipeline the dense layers, only sparse for SPMD.
       with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
         y, _ = self.scan_decoder_layers(
@@ -749,7 +750,7 @@ def __call__(
       y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
       remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
       if remaining_layers > 0:
-        logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
+        logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
         with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
           y, _ = self.scan_decoder_layers(
               cfg,
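The decoder above pins the mesh and shard mode once with functools.partial around sharding.maybe_shard_with_logical, then reuses the wrapper throughout __call__. As a rough, hedged sketch of what such a helper can look like (the real implementation lives in MaxText.sharding and is not shown in this diff; the ShardMode stand-in and the function body below are assumptions):

import enum
import functools
from flax import linen as nn

class ShardMode(enum.Enum):  # hypothetical stand-in for MaxText.common_types.ShardMode
  AUTO = "auto"
  EXPLICIT = "explicit"

def maybe_shard_with_logical(x, logical_axis_names, *, mesh, shard_mode):
  """Sketch: constrain x to logical axis names only when explicit sharding is requested."""
  if shard_mode != ShardMode.EXPLICIT:
    return x  # leave the compiler free to pick a layout
  with mesh:
    return nn.with_logical_constraint(x, logical_axis_names)

# Usage mirrors the diff: pin mesh/shard_mode once, then reuse inside __call__.
# _maybe_shard_with_logical = functools.partial(
#     maybe_shard_with_logical, mesh=mesh, shard_mode=cfg.shard_mode)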

src/MaxText/layers/linears.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@

 from MaxText import max_logging
 from MaxText import max_utils
-from MaxText.maxtext_utils import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical
 from MaxText.common_types import DecoderBlockType, ShardMode, DType, Array, Config
 from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT
 from MaxText.layers import nnx_wrappers, quantizations

src/MaxText/layers/llama2.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
 from MaxText.inference import page_manager
 from MaxText.common_types import Config
 from MaxText import max_utils
-from MaxText.maxtext_utils import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical
 from MaxText.layers.linears import Dropout, MlpBlock
 from MaxText.layers import initializers
 from MaxText.layers import nnx_wrappers

src/MaxText/layers/models.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
 from MaxText.layers.encoders import VisionEncoder
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock
-from MaxText.maxtext_utils import all_gather_over_fsdp
+from MaxText.sharding import all_gather_over_fsdp

 # ------------------------------------------------------------------------------
 # The network: Transformer Definitions

src/MaxText/layers/pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 from flax import linen as nn

 from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT
-from MaxText.maxtext_utils import all_gather_over_fsdp
+from MaxText.sharding import all_gather_over_fsdp


 class Pipeline(nn.Module):
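Taken together, the hunks in this commit are a pure import migration: sharding helpers previously reached through MaxText.maxtext_utils are now imported from the new MaxText.sharding module. A condensed before/after view of the pattern, as reflected in the hunks above:

# Before (left-hand side of the hunks above):
#   from MaxText import maxtext_utils
#   maxtext_utils.get_input_data_sharding(config, mesh)
#   from MaxText.maxtext_utils import maybe_shard_with_name, maybe_shard_with_logical, all_gather_over_fsdp
# After:
from MaxText import sharding
from MaxText.sharding import (
    all_gather_over_fsdp,
    maybe_shard_with_logical,
    maybe_shard_with_name,
)
# Module-level calls move accordingly, e.g. sharding.get_input_data_sharding(config, mesh),
# sharding.assert_params_sufficiently_sharded(...), sharding.logical_axis_rules_pp_act_as_dp(...).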
