
Commit a24092e

Update configs for tokamax
1 parent 9d5cbba

2 files changed: +45 −50 lines


src/MaxText/layers/attention_op.py

Lines changed: 11 additions & 15 deletions
@@ -20,7 +20,6 @@
 import math
 
 import numpy as np
-from packaging import version
 
 import jax
 from jax import lax
@@ -1286,21 +1285,18 @@ def wrap_flash_attention(
       decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv)
     else:
       decoder_segment_ids_tuple = None
-    # TODO(ranran): remove if/else branch once b/441336842 is fixed
-    if version.parse(jax.__version__) < version.parse("0.7.2.dev20250824"):
-      attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
-    else:
-      if self.config.use_tokamax_splash:
-        if max_logit_value is not None:
-          attention_output = jax.vmap(partial(splash_kernel, max_logit_value=max_logit_value))(
-              query, key, value, decoder_segment_ids_tuple
-          )
-        else:
-          attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
-      else:
-        attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
-            query, key, value, decoder_segment_ids_tuple, sinks
+
+    if self.config.use_tokamax_splash:
+      if max_logit_value is not None:
+        attention_output = jax.vmap(partial(splash_kernel, max_logit_value=max_logit_value))(
+            query, key, value, decoder_segment_ids_tuple
         )
+      else:
+        attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
+    else:
+      attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
+          query, key, value, decoder_segment_ids_tuple, sinks
+      )
     return attention_output
 
 def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
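Note on the new dispatch above: with the JAX version guard gone, the kernel call always goes through jax.vmap, and the sinks argument is broadcast to every mapped call via in_axes=(0, 0, 0, 0, None). The standalone sketch below only illustrates that vmap wiring; plain_attention is a hypothetical stand-in with the same positional signature as the real splash kernel, and the way it consumes max_logit_value and sinks is made up for the example.

from functools import partial

import jax
import jax.numpy as jnp


def plain_attention(q, k, v, segment_ids=None, sinks=None, max_logit_value=None):
  # Hypothetical stand-in for the splash kernel: q, k, v are [seq, head_dim]
  # slices for a single mapped head; sinks/max_logit_value are extra scalars.
  logits = q @ k.T / jnp.sqrt(jnp.float32(q.shape[-1]))
  if max_logit_value is not None:
    logits = logits - max_logit_value  # only exercises the keyword wiring
  if sinks is not None:
    logits = logits + sinks            # shared, unmapped additive term
  return jax.nn.softmax(logits, axis=-1) @ v


heads, seq, dim = 4, 8, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (heads, seq, dim))
key = jax.random.normal(kk, (heads, seq, dim))
value = jax.random.normal(kv, (heads, seq, dim))
sinks = jnp.float32(0.1)           # one value shared by all heads
max_logit_value = jnp.float32(2.0)
use_tokamax_splash = True          # mirrors config.use_tokamax_splash

if use_tokamax_splash:
  if max_logit_value is not None:
    attention_output = jax.vmap(partial(plain_attention, max_logit_value=max_logit_value))(
        query, key, value, None
    )
  else:
    attention_output = jax.vmap(plain_attention)(query, key, value, None)
else:
  # in_axes=(0, 0, 0, 0, None): map over q/k/v/segment ids, broadcast sinks.
  attention_output = jax.vmap(plain_attention, in_axes=(0, 0, 0, 0, None))(
      query, key, value, None, sinks
  )

print(attention_output.shape)  # (4, 8, 16)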

src/MaxText/layers/moe.py

Lines changed: 34 additions & 35 deletions
@@ -36,8 +36,7 @@
 from MaxText.layers import attentions, linears, quantizations, nnx_wrappers
 from MaxText.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned
 
-if jax.__version__ >= "0.8.0":
-  from tokamax._src.ops.ragged_dot import api as tokamax_api
+from tokamax._src.ops.ragged_dot import api as tokamax_api
 
 set_xla_metadata = xla_metadata.set_xla_metadata
 
@@ -809,17 +808,17 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
          min(tiling[1], k),
          min(tiling[2], n),
      )
-      if self.config.megablox:
-        if self.config.use_tokamax_gmm:
-          output = tokamax_api.ragged_dot(  # pylint: disable=possibly-used-before-assignment
-              lhs=inputs,
-              rhs=kernel,
-              group_sizes=group_sizes,
-              precision=jax.lax.Precision.DEFAULT,
-              preferred_element_type=self.dtype,
-              implementation="mosaic",
-          )
-        else:
+      if self.config.use_tokamax_gmm:
+        output = tokamax_api.ragged_dot(
+            lhs=inputs,
+            rhs=kernel,
+            group_sizes=group_sizes,
+            precision=jax.lax.Precision.DEFAULT,
+            preferred_element_type=self.dtype,
+            implementation="mosaic",
+        )
+      else:
+        if self.config.megablox:
          output = mblx.gmm(
              lhs=inputs,
              rhs=kernel,
@@ -830,29 +829,29 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
              rhs_quantize_dtype=rhs_quantize_dtype,
              use_qwix_quantization=self.config.use_qwix_quantization,
          )
-      else:
-        rhs_inputs = kernel
-        if isinstance(kernel, aqt.QTensor):
-          if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
-            raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
-          rhs_inputs = kernel.qvalue
-        with set_xla_metadata(ragged_dot_tiling=",".join([str(t) for t in tiling])):
-          output = jax.lax.ragged_dot(
-              lhs=inputs,
-              rhs=rhs_inputs,
-              group_sizes=group_sizes,
-              preferred_element_type=self.dtype,
-          )
-        if isinstance(kernel, aqt.QTensor):
-          # Multiply outputs by the kernely scale
-          scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
-          if padding_amount > 0:
-            scales = jax.lax.pad(
-                scales,
-                jnp.array(0.0, dtype=scales.dtype),
-                [(0, padding_amount, 0), (0, 0, 0)],
+        else:
+          rhs_inputs = kernel
+          if isinstance(kernel, aqt.QTensor):
+            if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
+              raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
+            rhs_inputs = kernel.qvalue
+          with set_xla_metadata(ragged_dot_tiling=",".join([str(t) for t in tiling])):
+            output = jax.lax.ragged_dot(
+                lhs=inputs,
+                rhs=rhs_inputs,
+                group_sizes=group_sizes,
+                preferred_element_type=self.dtype,
            )
-          output *= scales
+          if isinstance(kernel, aqt.QTensor):
+            # Multiply outputs by the kernely scale
+            scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
+            if padding_amount > 0:
+              scales = jax.lax.pad(
+                  scales,
+                  jnp.array(0.0, dtype=scales.dtype),
+                  [(0, padding_amount, 0), (0, 0, 0)],
+              )
+            output *= scales
      if padding_amount > 0:
        output = output[: hs_shape[0]]
      return output
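All three gmm paths above (tokamax_api.ragged_dot, mblx.gmm, and the jax.lax.ragged_dot fallback) implement the same grouped-matmul contract: rows of lhs are split into contiguous groups by group_sizes, and each group is multiplied by its own expert's kernel slice from rhs. A small numeric sketch of that contract using the jax.lax.ragged_dot fallback (shapes and values are made up for illustration):

import jax
import jax.numpy as jnp

num_experts, m, k, n = 2, 6, 4, 3
lhs = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)   # [tokens, hidden]
rhs = jnp.ones((num_experts, k, n), dtype=jnp.float32)     # [experts, hidden, out]
rhs = rhs.at[1].multiply(2.0)                              # make the experts distinguishable
group_sizes = jnp.array([4, 2], dtype=jnp.int32)           # rows 0-3 -> expert 0, rows 4-5 -> expert 1

# Grouped matmul: each contiguous block of lhs rows hits its own expert kernel.
output = jax.lax.ragged_dot(
    lhs,
    rhs,
    group_sizes=group_sizes,
    preferred_element_type=jnp.float32,
)

# Reference computation, slicing per group explicitly.
reference = jnp.concatenate([lhs[:4] @ rhs[0], lhs[4:] @ rhs[1]])
assert jnp.allclose(output, reference)
print(output.shape)  # (6, 3)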

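The quantized-kernel branch that moved in the last hunk rescales the ragged_dot output per token: jnp.take gathers each token's expert scale via expert_assignments, jax.lax.pad zero-fills the scales when the token dimension was padded, and the padded rows are sliced off afterwards. A standalone sketch of just that step, with made-up shapes; expert_scales here is an illustrative stand-in for kernel.scale[0].squeeze() from the AQT QTensor:

import jax
import jax.numpy as jnp

n = 3
real_tokens, padding_amount = 5, 3                    # gmm output was padded to 8 rows
output = jnp.ones((real_tokens + padding_amount, n), dtype=jnp.float32)

# Illustrative per-expert dequantization scales (one row of n values per expert).
expert_scales = jnp.array([[0.5, 0.5, 0.5],
                           [2.0, 2.0, 2.0]], dtype=jnp.float32)
# Expert assignment for each unpadded token, as produced upstream of gmm.
expert_assignments = jnp.array([0, 0, 1, 1, 0], dtype=jnp.int32)

# Gather each token's scale row, then zero-pad to match the padded output rows.
scales = jnp.take(expert_scales, indices=expert_assignments, axis=0)
if padding_amount > 0:
  scales = jax.lax.pad(
      scales,
      jnp.array(0.0, dtype=scales.dtype),
      [(0, padding_amount, 0), (0, 0, 0)],            # (low, high, interior) per axis
  )
output *= scales
output = output[:real_tokens]                          # drop the padded rows
print(output)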