 from MaxText.layers import attentions, linears, quantizations, nnx_wrappers
 from MaxText.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned

-if jax.__version__ >= "0.8.0":
-  from tokamax._src.ops.ragged_dot import api as tokamax_api
+from tokamax._src.ops.ragged_dot import api as tokamax_api

 set_xla_metadata = xla_metadata.set_xla_metadata

@@ -809,17 +808,17 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
           min(tiling[1], k),
           min(tiling[2], n),
       )
-      if self.config.megablox:
-        if self.config.use_tokamax_gmm:
-          output = tokamax_api.ragged_dot(  # pylint: disable=possibly-used-before-assignment
-              lhs=inputs,
-              rhs=kernel,
-              group_sizes=group_sizes,
-              precision=jax.lax.Precision.DEFAULT,
-              preferred_element_type=self.dtype,
-              implementation="mosaic",
-          )
-        else:
+      if self.config.use_tokamax_gmm:
+        output = tokamax_api.ragged_dot(
+            lhs=inputs,
+            rhs=kernel,
+            group_sizes=group_sizes,
+            precision=jax.lax.Precision.DEFAULT,
+            preferred_element_type=self.dtype,
+            implementation="mosaic",
+        )
+      else:
+        if self.config.megablox:
           output = mblx.gmm(
               lhs=inputs,
               rhs=kernel,
@@ -830,29 +829,29 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
               rhs_quantize_dtype=rhs_quantize_dtype,
               use_qwix_quantization=self.config.use_qwix_quantization,
           )
-      else:
-        rhs_inputs = kernel
-        if isinstance(kernel, aqt.QTensor):
-          if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
-            raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
-          rhs_inputs = kernel.qvalue
-        with set_xla_metadata(ragged_dot_tiling=",".join([str(t) for t in tiling])):
-          output = jax.lax.ragged_dot(
-              lhs=inputs,
-              rhs=rhs_inputs,
-              group_sizes=group_sizes,
-              preferred_element_type=self.dtype,
-          )
-        if isinstance(kernel, aqt.QTensor):
-          # Multiply outputs by the kernely scale
-          scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
-          if padding_amount > 0:
-            scales = jax.lax.pad(
-                scales,
-                jnp.array(0.0, dtype=scales.dtype),
-                [(0, padding_amount, 0), (0, 0, 0)],
+        else:
+          rhs_inputs = kernel
+          if isinstance(kernel, aqt.QTensor):
+            if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
+              raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
+            rhs_inputs = kernel.qvalue
+          with set_xla_metadata(ragged_dot_tiling=",".join([str(t) for t in tiling])):
+            output = jax.lax.ragged_dot(
+                lhs=inputs,
+                rhs=rhs_inputs,
+                group_sizes=group_sizes,
+                preferred_element_type=self.dtype,
             )
-          output *= scales
+          if isinstance(kernel, aqt.QTensor):
+            # Multiply outputs by the kernely scale
+            scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
+            if padding_amount > 0:
+              scales = jax.lax.pad(
+                  scales,
+                  jnp.array(0.0, dtype=scales.dtype),
+                  [(0, padding_amount, 0), (0, 0, 0)],
+              )
+            output *= scales
       if padding_amount > 0:
         output = output[: hs_shape[0]]
       return output
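
For orientation, here is a minimal sketch (not part of the commit) of the branch order after this change, using the names from the diff above. The wrapper name `dispatch_gmm` is hypothetical, `config` stands in for `self.config`, the argument lists are abbreviated relative to the real calls, and the imports (`jax`, `mblx`, `tokamax_api`) are assumed to come from the surrounding module:

# Sketch only: use_tokamax_gmm is now checked before megablox, so the tokamax
# path no longer requires megablox to be enabled.
def dispatch_gmm(config, inputs, kernel, group_sizes):
  if config.use_tokamax_gmm:
    # Tokamax ragged dot (mosaic implementation in the real call).
    return tokamax_api.ragged_dot(lhs=inputs, rhs=kernel, group_sizes=group_sizes)
  if config.megablox:
    # Megablox grouped matmul.
    return mblx.gmm(lhs=inputs, rhs=kernel, group_sizes=group_sizes)
  # Otherwise fall back to the plain XLA ragged dot.
  return jax.lax.ragged_dot(lhs=inputs, rhs=kernel, group_sizes=group_sizes)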