@@ -312,6 +312,12 @@ def __init__(
     self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
     self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
 
+    if self.config.attention == "vllm_rpa":
+      # vLLM uses 'model' as the tensor parallelism axis name
+      self._tensor_parallelism_name = "model"
+    else:
+      self._tensor_parallelism_name = "tensor"
+
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
         out_features_shape=self.num_experts,
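A note on the axis-name switch above: the tensor-parallel mesh axis is looked up by name, and a vLLM-style mesh that names that axis "model" would report a size of 1 under the default "tensor" lookup, silently skipping the reduce-scatter further down. A minimal sketch, not from this PR; the axis names and device layout are assumptions:

import numpy as np
import jax
from jax.sharding import Mesh

# Toy mesh whose tensor-parallel axis is named "model", as the vLLM path expects.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "model"))

print(mesh.shape.get("tensor", 1))  # 1 -- the MaxText-style axis name is absent
print(mesh.shape.get("model", 1))   # size of the tensor-parallel axis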
@@ -398,7 +404,7 @@ def get_expert_parallelism_size(self):
     return self.mesh.shape.get("expert", 1)
 
   def get_tensor_parallelism_size(self):
-    return self.mesh.shape.get("tensor", 1)
+    return self.mesh.shape.get(self._tensor_parallelism_name, 1)
 
   def get_tensor_transpose_parallelism_size(self):
     return self.mesh.shape.get("tensor_transpose", 1)
@@ -1073,14 +1079,17 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
 
       if self.config.mlp_bias:
         w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
+
       def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
-        if pspec_dim_axes is None: return []
+        if pspec_dim_axes is None:
+          return []
         axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
         active = []
         for ax in axes:
           if ax and self.mesh.shape.get(ax, 1) > 1:
             active.append((ax, tensor_dim_index))
         return active
+
       wi_gather_axes = []
       wo_gather_axes = []
 
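For context, a standalone sketch of what get_active_sharding_axes computes: it normalizes one PartitionSpec dimension entry (None, a string, or a tuple of strings) and keeps only the axes the mesh actually shards over (size > 1), paired with the weight dimension they apply to. The plain dict below stands in for self.mesh.shape, and the axis sizes are made up for illustration:

mesh_shape = {"data": 2, "tensor": 4, "expert": 1}  # assumed sizes

def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
  if pspec_dim_axes is None:
    return []
  axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
  active = []
  for ax in axes:
    # Keep only axes with mesh size > 1, tagged with the weight dim they shard.
    if ax and mesh_shape.get(ax, 1) > 1:
      active.append((ax, tensor_dim_index))
  return active

# These inputs mimic individual PartitionSpec dimension entries.
print(get_active_sharding_axes(("tensor", "expert"), tensor_dim_index=1))  # [('tensor', 1)]
print(get_active_sharding_axes("exp", tensor_dim_index=0))                 # []
print(get_active_sharding_axes(None, tensor_dim_index=2))                  # []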
@@ -1136,7 +1145,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
       intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes)
       if self.get_tensor_parallelism_size() > 1:
-        intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True)
+        intermediate_output = jax.lax.psum_scatter(
+            intermediate_output, self._tensor_parallelism_name, scatter_dimension=1, tiled=True
+        )
       if self.config.mlp_bias:
         intermediate_output = intermediate_output + wo_bias
       intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
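For reference, a self-contained sketch of the reduce-scatter this hunk parameterizes, run under shard_map on a toy replicated array. The mesh, shapes, and axis name are assumptions rather than values from this PR; in the layer the axis name comes from self._tensor_parallelism_name:

from functools import partial

import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

tp_axis = "model"  # stand-in for self._tensor_parallelism_name
mesh = Mesh(np.array(jax.devices()), axis_names=(tp_axis,))

@partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, tp_axis), check_rep=False)
def reduce_scatter(x):
  # Sum the per-device partial results over the tensor-parallel axis and leave
  # each device with a 1/N slice of dimension 1; tiled=True keeps the rank.
  return jax.lax.psum_scatter(x, tp_axis, scatter_dimension=1, tiled=True)

n = len(jax.devices())
x = np.ones((8, 4 * n), dtype=np.float32)
print(reduce_scatter(x).shape)  # global (8, 4 * n); each device holds an (8, 4) shard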