Skip to content

Commit c5a238f

Browse files
Merge pull request #2790 from AI-Hypercomputer:nicogrande/add-vllm-axis-moe
PiperOrigin-RevId: 840778838
2 parents 6f47311 + e7cbe14 commit c5a238f

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/MaxText/layers/moe.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -312,6 +312,12 @@ def __init__(
312312
self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
313313
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
314314

315+
if self.config.attention == "vllm_rpa":
316+
# vLLM uses 'model' as the tensor parallelism axis name
317+
self._tensor_parallelism_name = "model"
318+
else:
319+
self._tensor_parallelism_name = "tensor"
320+
315321
self.gate = GateLogit(
316322
in_features_shape=self.config.emb_dim,
317323
out_features_shape=self.num_experts,
@@ -398,7 +404,7 @@ def get_expert_parallelism_size(self):
398404
return self.mesh.shape.get("expert", 1)
399405

400406
def get_tensor_parallelism_size(self):
401-
return self.mesh.shape.get("tensor", 1)
407+
return self.mesh.shape.get(self._tensor_parallelism_name, 1)
402408

403409
def get_tensor_transpose_parallelism_size(self):
404410
return self.mesh.shape.get("tensor_transpose", 1)
@@ -1073,14 +1079,17 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
10731079

10741080
if self.config.mlp_bias:
10751081
w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
1082+
10761083
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
1077-
if pspec_dim_axes is None: return []
1084+
if pspec_dim_axes is None:
1085+
return []
10781086
axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
10791087
active = []
10801088
for ax in axes:
10811089
if ax and self.mesh.shape.get(ax, 1) > 1:
10821090
active.append((ax, tensor_dim_index))
10831091
return active
1092+
10841093
wi_gather_axes = []
10851094
wo_gather_axes = []
10861095

@@ -1136,7 +1145,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
11361145

11371146
intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes)
11381147
if self.get_tensor_parallelism_size() > 1:
1139-
intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True)
1148+
intermediate_output = jax.lax.psum_scatter(
1149+
intermediate_output, self._tensor_parallelism_name, scatter_dimension=1, tiled=True
1150+
)
11401151
if self.config.mlp_bias:
11411152
intermediate_output = intermediate_output + wo_bias
11421153
intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")

0 commit comments

Comments (0)