
Commit ecc3506

Revert "all 9 flags"
This reverts commit 82f327b.
1 parent 82f327b · commit ecc3506

File tree: 3 files changed, +55 -161 lines changed

src/MaxText/configs/base.yml
Lines changed: 3 additions & 9 deletions

@@ -178,15 +178,9 @@ use_random_routing: False # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops
 use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism
 # Tunable tiling dimensions used for Megablox
-tile_fwd_batch_seq: 512
-tile_fwd_embed_dim: 1024
-tile_fwd_mlp_dim: 1024
-tile_dlhs_fwd_batch_seq: 512
-tile_dlhs_fwd_embed_dim: 1024
-tile_dlhs_fwd_mlp_dim: 1024
-tile_drhs_fwd_batch_seq: 512
-tile_drhs_fwd_embed_dim: 1024
-tile_drhs_fwd_mlp_dim: 1024
+tile_batch_seq: 512
+tile_embed_dim: 1024
+tile_mlp_dim: 1024
 norm_topk_prob: False # Boolean to enable the top-k probability normalization. Qwen3-specific normalization of router weights.

 # How the expert axis is used to shard attention weights and activations
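
For reference, the nine keys removed above were packed into one flat tuple and sliced per pass (tiling[:3] for the forward gmm, tiling[3:6] for the dlhs gradient, tiling[-3:] for the drhs gradient, as the ops.py hunks below show), while the restored keys form a single (m, k, n) triple shared by all three grouped matmuls. A rough Python sketch of that relationship using the defaults from this config (illustrative only, not MaxText code):

# Sketch: the reverted 9-flag scheme vs. the restored 3-flag scheme.
nine_flag_tiling = (
    512, 1024, 1024,  # tile_fwd_*       -> forward gmm
    512, 1024, 1024,  # tile_dlhs_fwd_*  -> dlhs gradient gmm
    512, 1024, 1024,  # tile_drhs_fwd_*  -> drhs gradient gmm
)
three_flag_tiling = (512, 1024, 1024)  # tile_batch_seq, tile_embed_dim, tile_mlp_dim
# With the defaults above, every per-pass slice collapses to the same triple.
assert nine_flag_tiling[:3] == nine_flag_tiling[3:6] == nine_flag_tiling[-3:] == three_flag_tiling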

src/MaxText/kernels/megablox/ops.py
Lines changed: 46 additions & 114 deletions

@@ -17,33 +17,27 @@
 # pylint: disable=too-many-positional-arguments

 import functools
-import itertools
-import dataclasses
 from typing import Literal
 import jax
 import jax.numpy as jnp
-from jax.experimental.xla_metadata import set_xla_metadata
-from MaxText.kernels.megablox import backend as megablox_backend
-# from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
-import tokamax
+from MaxText.kernels.megablox import backend
 import qwix
 import qwix.pallas as qpl

-_counter = itertools.count()
+
 def gmm(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
+    tiling: tuple[int, int, int] = (128, 128, 128),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     use_qwix_quantization: bool = False,
-    use_tokamax_backend: bool = False,
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
@@ -76,7 +70,6 @@ def gmm(
       transpose_rhs,
       interpret,
       quantization_rule,
-      use_tokamax_backend,
   )


@@ -85,13 +78,12 @@ def _gmm_fwd(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
+    tiling: tuple[int, int, int] = (128, 128, 128),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
     quantization_rule: qwix.QtRule | None = None,
-    use_tokamax_backend: bool = False,
 ) -> tuple[
     jnp.ndarray,
     tuple[
@@ -102,17 +94,15 @@ def _gmm_fwd(
     ],
 ]:
   """Forward function for GMM VJP."""
-  fwd_counter = next(_counter)
   if quantization_rule:
     if quantization_rule.act_qtype:
-      with set_xla_metadata(MUST_FUSE=fwd_counter):
-        lhs = qpl.quantize(
-            lhs,
-            quantization_rule.act_qtype,
-            channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
-            calibration_method=quantization_rule.act_calibration_method,
-            scale_dtype=jnp.float32,
-        )
+      lhs = qpl.quantize(
+          lhs,
+          quantization_rule.act_qtype,
+          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
+          calibration_method=quantization_rule.act_calibration_method,
+          scale_dtype=jnp.float32,
+      )
     if quantization_rule.weight_qtype:
       rhs = qpl.quantize(
           rhs,
@@ -124,50 +114,29 @@
           calibration_method=quantization_rule.weight_calibration_method,
           scale_dtype=jnp.float32,
       )
-      # QAG is only supported for following conditions
-      if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
-        rhs_qvalue = jax.lax.all_gather(rhs.qvalue, "fsdp", axis=0, tiled=True)
-        rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
-  if use_tokamax_backend:
-    with set_xla_metadata(MUST_FUSE=fwd_counter):
-      out = tokamax.ragged_dot_general(
-          lhs=lhs,
-          rhs=rhs,
-          group_sizes=group_sizes,
-          ragged_dot_dimension_numbers=tokamax.RaggedDotDimensionNumbers(
-              dot_dimension_numbers=(([1], [1]), ([], [])),
-              lhs_ragged_dimensions=[0],
-              rhs_group_dimensions=[0],
-          ),
-          precision=jax.lax.Precision.DEFAULT,
-          preferred_element_type=preferred_element_type,
-          group_offset=group_offset,
-          implementation="mosaic",
-      )
-  else:
-    out = megablox_backend.gmm(
+
+  out = backend.gmm(
       lhs,
       rhs,
       group_sizes,
       preferred_element_type,
-      tiling[:3],
+      tiling,
      group_offset,
      existing_out,
      transpose_rhs=transpose_rhs,
      interpret=interpret,
-    )
+  )
   return out, (lhs, rhs, group_sizes, group_offset)


 def _gmm_bwd(
     lhs_dtype: jax.typing.DTypeLike,
     rhs_dtype: jax.typing.DTypeLike,
     preferred_element_type: jnp.dtype,
-    tiling: tuple[int, int, int, int, int, int, int, int, int],
+    tiling: tuple[int, int, int],
     transpose_rhs: bool,
     interpret: bool,
     quantization_rule: qwix.QtRule | None,
-    use_tokamax_backend: bool,
     residual: tuple[
         jnp.ndarray | qpl.QArray,
         jnp.ndarray | qpl.QArray,
@@ -191,8 +160,6 @@ def _gmm_bwd(
   # - drhs_dout: the incoming gradient used to calculate drhs.

   # dlhs_dout and drhs_dout can be different when quantization is enabled.
-  dlhs_counter = next(_counter)
-  drhs_counter = next(_counter)
   dlhs_dout = grad
   drhs_dout = grad
   if isinstance(rhs, qpl.QArray): # qvalue: [g, k, n] scale: [1, 1, n]
@@ -206,76 +173,41 @@
     lhs = lhs.qvalue
   if quantization_rule and quantization_rule.bwd_qtype:
     # Enable backward pass quantization
-    with set_xla_metadata(MUST_FUSE=dlhs_counter):
-      dlhs_dout = qpl.quantize(
-          dlhs_dout,
-          quantization_rule.bwd_qtype,
-          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
-          calibration_method=quantization_rule.bwd_calibration_method,
-          scale_dtype=jnp.float32,
-      )
-    with set_xla_metadata(MUST_FUSE=drhs_counter):
-      drhs_dout = qpl.quantize(
-          drhs_dout,
-          quantization_rule.bwd_qtype,
-          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
-          calibration_method=quantization_rule.bwd_calibration_method,
-          scale_dtype=jnp.float32,
-      )
-  if use_tokamax_backend:
-    with set_xla_metadata(MUST_FUSE=dlhs_counter):
-      dlhs = tokamax.ragged_dot_general(
-          lhs=dlhs_dout,
-          rhs=rhs,
-          group_sizes=group_sizes,
-          ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers(
-              dot_dimension_numbers=(([1], [2]), ([], [])),
-              lhs_ragged_dimensions=[0],
-              rhs_group_dimensions=[0],
-          ),
-          precision=jax.lax.Precision.DEFAULT,
-          preferred_element_type=preferred_element_type,
-          group_offset=group_offset,
-          implementation="mosaic",
-      )
-    drhs = tokamax.tgmm(
-        lhs.swapaxes(0, 1),
-        drhs_dout,
-        group_sizes=group_sizes,
-        ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers(
-            dot_dimension_numbers=(([0], [0]), ([], [])),
-            lhs_ragged_dimensions=[0],
-            rhs_group_dimensions=[],
-        ),
-        precision=jax.lax.Precision.DEFAULT,
-        preferred_element_type=preferred_element_type,
-        group_offset=group_offset,
-        implementation="mosaic",
-    )
-  else:
-    dlhs = megablox_backend.gmm(
+    dlhs_dout = qpl.quantize(
         dlhs_dout,
-        rhs,
-        group_sizes,
-        lhs_dtype,
-        tiling[3:6],
-        group_offset,
-        transpose_rhs=not transpose_rhs,
-        interpret=interpret,
+        quantization_rule.bwd_qtype,
+        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
+        calibration_method=quantization_rule.bwd_calibration_method,
+        scale_dtype=jnp.float32,
     )
-    drhs = megablox_backend.tgmm(
-        lhs.swapaxes(0, 1),
+    drhs_dout = qpl.quantize(
         drhs_dout,
-        group_sizes,
-        rhs_dtype,
-        tiling[-3:],
-        group_offset,
-        num_actual_groups,
-        interpret=interpret,
+        quantization_rule.bwd_qtype,
+        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
+        calibration_method=quantization_rule.bwd_calibration_method,
+        scale_dtype=jnp.float32,
     )

-  if quantization_rule and quantization_rule.bwd_qtype:
-    drhs = jax.lax.psum_scatter(drhs, "fsdp", scatter_dimension=0, tiled=True)
+  dlhs = backend.gmm(
+      dlhs_dout,
+      rhs,
+      group_sizes,
+      lhs_dtype,
+      tiling,
+      group_offset,
+      transpose_rhs=not transpose_rhs,
+      interpret=interpret,
+  )
+  drhs = backend.tgmm(
+      lhs.swapaxes(0, 1),
+      drhs_dout,
+      group_sizes,
+      rhs_dtype,
+      tiling,
+      group_offset,
+      num_actual_groups,
+      interpret=interpret,
+  )

   # NOTE: If the rhs transposition is fused into the forward pass we need to
   # return the transpose of the rhs gradient that we calculated above.
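
The signature changes above touch gmm, _gmm_fwd, and _gmm_bwd together because tiling appears to travel through the custom VJP as a non-differentiable argument (it is listed before residual in _gmm_bwd). Below is a minimal, self-contained sketch of that general jax.custom_vjp pattern with a single (tm, tk, tn) triple; toy_gmm is a hypothetical stand-in for the Pallas backend kernel and this is not the file's actual wiring:

import functools
import jax
import jax.numpy as jnp


def toy_gmm(lhs, rhs, tiling):
  """Hypothetical stand-in for the Pallas kernel; a real kernel would block by (tm, tk, tn)."""
  del tiling  # the tile sizes would drive the Pallas grid here
  return lhs @ rhs


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def gmm(lhs, rhs, tiling):
  return toy_gmm(lhs, rhs, tiling)


def _gmm_fwd(lhs, rhs, tiling):
  # Save residuals for the backward pass, as the real _gmm_fwd does.
  return toy_gmm(lhs, rhs, tiling), (lhs, rhs)


def _gmm_bwd(tiling, residual, grad):
  # After the revert, one (tm, tk, tn) triple serves both gradient matmuls,
  # analogous to backend.gmm and backend.tgmm receiving the same `tiling`.
  lhs, rhs = residual
  dlhs = toy_gmm(grad, rhs.T, tiling)
  drhs = toy_gmm(lhs.T, grad, tiling)
  return dlhs, drhs


gmm.defvjp(_gmm_fwd, _gmm_bwd)

x, w = jnp.ones((8, 4)), jnp.ones((4, 16))
print(jax.grad(lambda a: gmm(a, w, (128, 128, 128)).sum())(x).shape)  # (8, 4)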

src/MaxText/layers/moe.py
Lines changed: 6 additions & 38 deletions

@@ -810,12 +810,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
           min(tiling[0], m),
           min(tiling[1], k),
           min(tiling[2], n),
-          min(tiling[3], m),
-          min(tiling[4], k),
-          min(tiling[5], n),
-          min(tiling[6], m),
-          min(tiling[7], k),
-          min(tiling[8], n),
       )
       if self.config.use_tokamax_gmm:
         output = tokamax_api.ragged_dot(
@@ -826,19 +820,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             preferred_element_type=self.dtype,
             implementation="mosaic",
         )
-      elif self.config.use_tokamax_gmm and self.config.quantization:
-        # call mblx gmm with tokamax quantization
-        output = mblx.gmm(
-            lhs=inputs,
-            rhs=kernel,
-            group_sizes=group_sizes,
-            preferred_element_type=self.dtype,
-            tiling=tiling,
-            lhs_quantize_dtype=lhs_quantize_dtype,
-            rhs_quantize_dtype=rhs_quantize_dtype,
-            use_qwix_quantization=self.config.use_qwix_quantization,
-            use_tokamax_backend=self.config.use_tokamax_gmm,
-        )
       else:
         if self.config.megablox:
           output = mblx.gmm(
@@ -850,7 +831,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
              lhs_quantize_dtype=lhs_quantize_dtype,
              rhs_quantize_dtype=rhs_quantize_dtype,
              use_qwix_quantization=self.config.use_qwix_quantization,
-              use_tokamax_backend=self.config.use_tokamax_gmm,
          )
        else:
          rhs_inputs = kernel
@@ -1061,26 +1041,14 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
          expert_assignments=selected_experts,
      )
      wi_tile_size = (
-          self.config.tile_fwd_batch_seq,
-          self.config.tile_fwd_embed_dim,
-          self.config.tile_fwd_mlp_dim,
-          self.config.tile_dlhs_batch_seq,
-          self.config.tile_dlhs_embed_dim,
-          self.config.tile_dlhs_mlp_dim,
-          self.config.tile_drhs_batch_seq,
-          self.config.tile_drhs_embed_dim,
-          self.config.tile_drhs_mlp_dim,
+          self.config.tile_batch_seq,
+          self.config.tile_embed_dim,
+          self.config.tile_mlp_dim,
      )
      wo_tile_size = (
-          self.config.tile_fwd_batch_seq,
-          self.config.tile_fwd_mlp_dim,
-          self.config.tile_fwd_embed_dim,
-          self.config.tile_dlhs_batch_seq,
-          self.config.tile_dlhs_mlp_dim,
-          self.config.tile_dlhs_embed_dim,
-          self.config.tile_drhs_batch_seq,
-          self.config.tile_drhs_mlp_dim,
-          self.config.tile_drhs_embed_dim,
+          self.config.tile_batch_seq,
+          self.config.tile_mlp_dim,
+          self.config.tile_embed_dim,
      )
      layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
      if self.get_tensor_transpose_parallelism_size() > 1:
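
The restored call site clamps each configured tile size against the actual grouped-matmul dimensions and reuses the same three config values for both projections, with the embed and mlp positions swapped for the wo matmul. A small sketch of that logic, assuming m, k, and n are the token, input-feature, and output-feature dimensions (the dimension meanings and example sizes are assumptions, not MaxText code):

# Sketch of the restored tiling logic (assumed dimension meanings noted above).
tile_batch_seq, tile_embed_dim, tile_mlp_dim = 512, 1024, 1024

def clamp_tiling(tiling, m, k, n):
  # Mirrors the min(tiling[i], dim) guard above: a tile never exceeds the
  # dimension it covers.
  return (min(tiling[0], m), min(tiling[1], k), min(tiling[2], n))

wi_tile_size = (tile_batch_seq, tile_embed_dim, tile_mlp_dim)  # used for gmm_fn(x, w0, tiling=wi_tile_size)
wo_tile_size = (tile_batch_seq, tile_mlp_dim, tile_embed_dim)  # embed/mlp swapped for the wo projection

print(clamp_tiling(wi_tile_size, m=256, k=4096, n=14336))  # -> (256, 1024, 1024)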
