3838from MaxText .layers import pipeline
3939from MaxText import maxtext_utils
4040from MaxText import multimodal_utils
41+ from MaxText import sharding
4142from MaxText .layers .attentions import attention_as_linen
4243from MaxText .layers .normalizations import rms_norm
4344from MaxText .layers .embeddings import attend_on_embedding , embed_as_linen , positional_embedding_as_linen
@@ -90,7 +91,7 @@ def __call__(
9091 cfg = self .config
9192 mesh = self .mesh
9293 _maybe_shard_with_logical = functools .partial (
93- maxtext_utils .maybe_shard_with_logical ,
94+ sharding .maybe_shard_with_logical ,
9495 mesh = mesh ,
9596 shard_mode = cfg .shard_mode ,
9697 )
@@ -722,7 +723,7 @@ def __call__(
722723 moe_layer = RemattedBlockLayers [1 ]
723724 num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
724725 num_moe_layers_outside_pp = num_moe_layers - self .config .pipeline_parallel_layers
725- logical_axis_rules_pp_as_dp = maxtext_utils .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
726+ logical_axis_rules_pp_as_dp = sharding .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
726727 # We chose not to pipeline the dense layers, only sparse for SPMD.
727728 with self .mesh , nn .partitioning .axis_rules (logical_axis_rules_pp_as_dp ):
728729 y , _ = self .scan_decoder_layers (
@@ -749,7 +750,7 @@ def __call__(
749750 y = self .pipeline_module (y , * broadcast_args , partition_spec = partition_spec )
750751 remaining_layers = self .config .num_decoder_layers - self .config .pipeline_parallel_layers
751752 if remaining_layers > 0 :
752- logical_axis_rules_pp_as_dp = maxtext_utils .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
753+ logical_axis_rules_pp_as_dp = sharding .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
753754 with self .mesh , nn .partitioning .axis_rules (logical_axis_rules_pp_as_dp ):
754755 y , _ = self .scan_decoder_layers (
755756 cfg ,
0 commit comments