Commit 82f327b

committed
all 9 flags
1 parent 4d2419e commit 82f327b

File tree: 3 files changed (+161 additions, -55 deletions)


src/MaxText/configs/base.yml

Lines changed: 9 additions & 3 deletions
@@ -178,9 +178,15 @@ use_random_routing: False # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops
 use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism
 # Tunable tiling dimensions used for Megablox
-tile_batch_seq: 512
-tile_embed_dim: 1024
-tile_mlp_dim: 1024
+tile_fwd_batch_seq: 512
+tile_fwd_embed_dim: 1024
+tile_fwd_mlp_dim: 1024
+tile_dlhs_batch_seq: 512
+tile_dlhs_embed_dim: 1024
+tile_dlhs_mlp_dim: 1024
+tile_drhs_batch_seq: 512
+tile_drhs_embed_dim: 1024
+tile_drhs_mlp_dim: 1024
 norm_topk_prob: False # Boolean to enable the top-k probability normalization. Qwen3-specific normalization of router weights.

 # How the expert axis is used to shard attention weights and activations
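The nine flags are consumed as three tile triples downstream. A minimal illustration (the dict-style access and the values here are placeholders; the split mirrors tiling[:3], tiling[3:6], and tiling[-3:] in ops.py below):

# Illustration only: how the nine flags pack into the tiling tuple that
# ops.py later splits into forward, dlhs-backward, and drhs-backward tiles.
config = {
    "tile_fwd_batch_seq": 512, "tile_fwd_embed_dim": 1024, "tile_fwd_mlp_dim": 1024,
    "tile_dlhs_batch_seq": 512, "tile_dlhs_embed_dim": 1024, "tile_dlhs_mlp_dim": 1024,
    "tile_drhs_batch_seq": 512, "tile_drhs_embed_dim": 1024, "tile_drhs_mlp_dim": 1024,
}
tiling = tuple(
    config[f"tile_{phase}_{dim}"]
    for phase in ("fwd", "dlhs", "drhs")
    for dim in ("batch_seq", "embed_dim", "mlp_dim")
)
fwd_tiles = tiling[:3]    # forward gmm
dlhs_tiles = tiling[3:6]  # backward gmm for the activation gradient (dlhs)
drhs_tiles = tiling[-3:]  # backward tgmm for the weight gradient (drhs)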

src/MaxText/kernels/megablox/ops.py

Lines changed: 114 additions & 46 deletions
@@ -17,27 +17,33 @@
 # pylint: disable=too-many-positional-arguments

 import functools
+import itertools
+import dataclasses
 from typing import Literal
 import jax
 import jax.numpy as jnp
-from MaxText.kernels.megablox import backend
+from jax.experimental.xla_metadata import set_xla_metadata
+from MaxText.kernels.megablox import backend as megablox_backend
+# from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
+import tokamax
 import qwix
 import qwix.pallas as qpl

-
+_counter = itertools.count()
 def gmm(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int] = (128, 128, 128),
+    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     use_qwix_quantization: bool = False,
+    use_tokamax_backend: bool = False,
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
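For reference, a minimal call sketch against the widened signature. Shapes, values, and the import path are illustrative assumptions rather than taken from the commit, and the kernels target TPU, so treat this as a sketch, not a portable script:

import jax.numpy as jnp
from MaxText.kernels.megablox import ops  # assumed import path for this module

lhs = jnp.ones((512, 1024), dtype=jnp.bfloat16)      # [tokens, embed]
rhs = jnp.ones((8, 1024, 4096), dtype=jnp.bfloat16)  # [num_experts, embed, mlp]
group_sizes = jnp.full((8,), 64, dtype=jnp.int32)    # tokens routed to each expert

out = ops.gmm(
    lhs,
    rhs,
    group_sizes,
    preferred_element_type=jnp.bfloat16,
    tiling=(512, 1024, 1024,   # forward tiles
            512, 1024, 1024,   # dlhs-backward tiles
            512, 1024, 1024),  # drhs-backward tiles
    use_tokamax_backend=False,  # keep the Megablox kernels
)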
@@ -70,6 +76,7 @@ def gmm(
       transpose_rhs,
       interpret,
       quantization_rule,
+      use_tokamax_backend,
   )


@@ -78,12 +85,13 @@ def _gmm_fwd(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int] = (128, 128, 128),
+    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
     quantization_rule: qwix.QtRule | None = None,
+    use_tokamax_backend: bool = False,
 ) -> tuple[
     jnp.ndarray,
     tuple[
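The new use_tokamax_backend flag is threaded into both the forward and backward rules (it reappears in the _gmm_bwd signature further down). A generic sketch of that pattern with jax.custom_vjp, not the MaxText wiring itself, looks like this:

import functools
import jax
import jax.numpy as jnp

# Static configuration (here: a backend flag) is passed as a non-differentiated
# argument so the forward and backward rules see the same setting.
@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def my_matmul(x, w, use_fast_path):
  return x @ w

def my_matmul_fwd(x, w, use_fast_path):
  return x @ w, (x, w)

def my_matmul_bwd(use_fast_path, residual, g):
  x, w = residual
  # The flag is visible here too, so both passes can pick matching backends.
  return (g @ w.T, x.T @ g)

my_matmul.defvjp(my_matmul_fwd, my_matmul_bwd)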
@@ -94,15 +102,17 @@ def _gmm_fwd(
     ],
 ]:
   """Forward function for GMM VJP."""
+  fwd_counter = next(_counter)
   if quantization_rule:
     if quantization_rule.act_qtype:
-      lhs = qpl.quantize(
-          lhs,
-          quantization_rule.act_qtype,
-          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
-          calibration_method=quantization_rule.act_calibration_method,
-          scale_dtype=jnp.float32,
-      )
+      with set_xla_metadata(MUST_FUSE=fwd_counter):
+        lhs = qpl.quantize(
+            lhs,
+            quantization_rule.act_qtype,
+            channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
+            calibration_method=quantization_rule.act_calibration_method,
+            scale_dtype=jnp.float32,
+        )
     if quantization_rule.weight_qtype:
       rhs = qpl.quantize(
           rhs,
@@ -114,29 +124,50 @@ def _gmm_fwd(
           calibration_method=quantization_rule.weight_calibration_method,
           scale_dtype=jnp.float32,
       )
-
-  out = backend.gmm(
+      # QAG is only supported for following conditions
+      if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
+        rhs_qvalue = jax.lax.all_gather(rhs.qvalue, "fsdp", axis=0, tiled=True)
+        rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
+  if use_tokamax_backend:
+    with set_xla_metadata(MUST_FUSE=fwd_counter):
+      out = tokamax.ragged_dot_general(
+          lhs=lhs,
+          rhs=rhs,
+          group_sizes=group_sizes,
+          ragged_dot_dimension_numbers=tokamax.RaggedDotDimensionNumbers(
+              dot_dimension_numbers=(([1], [1]), ([], [])),
+              lhs_ragged_dimensions=[0],
+              rhs_group_dimensions=[0],
+          ),
+          precision=jax.lax.Precision.DEFAULT,
+          preferred_element_type=preferred_element_type,
+          group_offset=group_offset,
+          implementation="mosaic",
+      )
+  else:
+    out = megablox_backend.gmm(
       lhs,
       rhs,
       group_sizes,
       preferred_element_type,
-      tiling,
+      tiling[:3],
       group_offset,
       existing_out,
       transpose_rhs=transpose_rhs,
       interpret=interpret,
-  )
+    )
   return out, (lhs, rhs, group_sizes, group_offset)


 def _gmm_bwd(
     lhs_dtype: jax.typing.DTypeLike,
     rhs_dtype: jax.typing.DTypeLike,
     preferred_element_type: jnp.dtype,
-    tiling: tuple[int, int, int],
+    tiling: tuple[int, int, int, int, int, int, int, int, int],
     transpose_rhs: bool,
     interpret: bool,
     quantization_rule: qwix.QtRule | None,
+    use_tokamax_backend: bool,
     residual: tuple[
         jnp.ndarray | qpl.QArray,
         jnp.ndarray | qpl.QArray,
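The dimension numbers in the tokamax forward call above match plain ragged-dot semantics: contract the embed axis (lhs dim 1 with rhs dim 1), ragged over the token axis, with one rhs slice per group. A small pure-JAX reference using jax.lax.ragged_dot, with illustrative shapes:

import jax
import jax.numpy as jnp

lhs = jnp.ones((16, 8), dtype=jnp.float32)        # [tokens, embed]
rhs = jnp.ones((4, 8, 32), dtype=jnp.float32)     # [groups, embed, mlp]
group_sizes = jnp.array([4, 4, 4, 4], jnp.int32)  # rows of lhs per group

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)   # -> [tokens, mlp]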
@@ -160,6 +191,8 @@
   # - drhs_dout: the incoming gradient used to calculate drhs.

   # dlhs_dout and drhs_dout can be different when quantization is enabled.
+  dlhs_counter = next(_counter)
+  drhs_counter = next(_counter)
   dlhs_dout = grad
   drhs_dout = grad
   if isinstance(rhs, qpl.QArray):  # qvalue: [g, k, n] scale: [1, 1, n]
@@ -173,41 +206,76 @@
   lhs = lhs.qvalue
   if quantization_rule and quantization_rule.bwd_qtype:
     # Enable backward pass quantization
-    dlhs_dout = qpl.quantize(
+    with set_xla_metadata(MUST_FUSE=dlhs_counter):
+      dlhs_dout = qpl.quantize(
+          dlhs_dout,
+          quantization_rule.bwd_qtype,
+          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
+          calibration_method=quantization_rule.bwd_calibration_method,
+          scale_dtype=jnp.float32,
+      )
+    with set_xla_metadata(MUST_FUSE=drhs_counter):
+      drhs_dout = qpl.quantize(
+          drhs_dout,
+          quantization_rule.bwd_qtype,
+          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
+          calibration_method=quantization_rule.bwd_calibration_method,
+          scale_dtype=jnp.float32,
+      )
+  if use_tokamax_backend:
+    with set_xla_metadata(MUST_FUSE=dlhs_counter):
+      dlhs = tokamax.ragged_dot_general(
+          lhs=dlhs_dout,
+          rhs=rhs,
+          group_sizes=group_sizes,
+          ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers(
+              dot_dimension_numbers=(([1], [2]), ([], [])),
+              lhs_ragged_dimensions=[0],
+              rhs_group_dimensions=[0],
+          ),
+          precision=jax.lax.Precision.DEFAULT,
+          preferred_element_type=preferred_element_type,
+          group_offset=group_offset,
+          implementation="mosaic",
+      )
+    drhs = tokamax.tgmm(
+        lhs.swapaxes(0, 1),
+        drhs_dout,
+        group_sizes=group_sizes,
+        ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers(
+            dot_dimension_numbers=(([0], [0]), ([], [])),
+            lhs_ragged_dimensions=[0],
+            rhs_group_dimensions=[],
+        ),
+        precision=jax.lax.Precision.DEFAULT,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
+        implementation="mosaic",
+    )
+  else:
+    dlhs = megablox_backend.gmm(
         dlhs_dout,
-        quantization_rule.bwd_qtype,
-        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
-        calibration_method=quantization_rule.bwd_calibration_method,
-        scale_dtype=jnp.float32,
+        rhs,
+        group_sizes,
+        lhs_dtype,
+        tiling[3:6],
+        group_offset,
+        transpose_rhs=not transpose_rhs,
+        interpret=interpret,
     )
-    drhs_dout = qpl.quantize(
+    drhs = megablox_backend.tgmm(
+        lhs.swapaxes(0, 1),
         drhs_dout,
-        quantization_rule.bwd_qtype,
-        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
-        calibration_method=quantization_rule.bwd_calibration_method,
-        scale_dtype=jnp.float32,
+        group_sizes,
+        rhs_dtype,
+        tiling[-3:],
+        group_offset,
+        num_actual_groups,
+        interpret=interpret,
     )

-  dlhs = backend.gmm(
-      dlhs_dout,
-      rhs,
-      group_sizes,
-      lhs_dtype,
-      tiling,
-      group_offset,
-      transpose_rhs=not transpose_rhs,
-      interpret=interpret,
-  )
-  drhs = backend.tgmm(
-      lhs.swapaxes(0, 1),
-      drhs_dout,
-      group_sizes,
-      rhs_dtype,
-      tiling,
-      group_offset,
-      num_actual_groups,
-      interpret=interpret,
-  )
+  if quantization_rule and quantization_rule.bwd_qtype:
+    drhs = jax.lax.psum_scatter(drhs, "fsdp", scatter_dimension=0, tiled=True)

   # NOTE: If the rhs transposition is fused into the forward pass we need to
   # return the transpose of the rhs gradient that we calculated above.
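For orientation, the two backward products above have simple reference semantics (a pure-JAX sketch, not the Pallas or tokamax kernels; reference_bwd is a hypothetical helper): per group g, the dlhs rows are dout @ rhs[g].T (the activation gradient, tiled by tiling[3:6]) and drhs[g] is lhs[g].T @ dout (the per-expert weight gradient, tiled by tiling[-3:]).

import jax.numpy as jnp

def reference_bwd(lhs, rhs, dout, group_sizes):
  # lhs: [tokens, embed], rhs: [groups, embed, mlp], dout: [tokens, mlp]
  starts = [0]
  for n in group_sizes:
    starts.append(starts[-1] + int(n))
  dlhs = jnp.zeros_like(lhs)
  drhs = jnp.zeros_like(rhs)
  for g, n in enumerate(group_sizes):
    s, e = starts[g], starts[g] + int(n)
    dlhs = dlhs.at[s:e].set(dout[s:e] @ rhs[g].T)  # gradient w.r.t. the activations
    drhs = drhs.at[g].set(lhs[s:e].T @ dout[s:e])  # gradient w.r.t. the expert weights
  return dlhs, drhs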

src/MaxText/layers/moe.py

Lines changed: 38 additions & 6 deletions
@@ -810,6 +810,12 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
          min(tiling[0], m),
          min(tiling[1], k),
          min(tiling[2], n),
+          min(tiling[3], m),
+          min(tiling[4], k),
+          min(tiling[5], n),
+          min(tiling[6], m),
+          min(tiling[7], k),
+          min(tiling[8], n),
      )
      if self.config.use_tokamax_gmm:
        output = tokamax_api.ragged_dot(
@@ -820,6 +826,19 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
            preferred_element_type=self.dtype,
            implementation="mosaic",
        )
+      elif self.config.use_tokamax_gmm and self.config.quantization:
+        # call mblx gmm with tokamax quantization
+        output = mblx.gmm(
+            lhs=inputs,
+            rhs=kernel,
+            group_sizes=group_sizes,
+            preferred_element_type=self.dtype,
+            tiling=tiling,
+            lhs_quantize_dtype=lhs_quantize_dtype,
+            rhs_quantize_dtype=rhs_quantize_dtype,
+            use_qwix_quantization=self.config.use_qwix_quantization,
+            use_tokamax_backend=self.config.use_tokamax_gmm,
+        )
      else:
        if self.config.megablox:
          output = mblx.gmm(
@@ -831,6 +850,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
              lhs_quantize_dtype=lhs_quantize_dtype,
              rhs_quantize_dtype=rhs_quantize_dtype,
              use_qwix_quantization=self.config.use_qwix_quantization,
+              use_tokamax_backend=self.config.use_tokamax_gmm,
          )
        else:
          rhs_inputs = kernel
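Taken together, the gmm wrapper in moe.py now dispatches between three kernel paths. A condensed, illustrative sketch is below; mblx and tokamax_api are the module aliases already used in this file, argument lists are abbreviated, and the quantized branch is written first here because, as committed, an elif that follows the unconditional use_tokamax_gmm check cannot be reached:

# Control-flow sketch only, not the committed code.
def dispatch_gmm(config, inputs, kernel, group_sizes, tiling):
  if config.use_tokamax_gmm and config.quantization:
    # Megablox gmm routed through the qwix-quantized tokamax backend.
    return mblx.gmm(lhs=inputs, rhs=kernel, group_sizes=group_sizes,
                    tiling=tiling, use_tokamax_backend=True)
  if config.use_tokamax_gmm:
    # Unquantized tokamax ragged dot.
    return tokamax_api.ragged_dot(inputs, kernel, group_sizes, implementation="mosaic")
  if config.megablox:
    # Original Megablox kernel path.
    return mblx.gmm(lhs=inputs, rhs=kernel, group_sizes=group_sizes,
                    tiling=tiling, use_tokamax_backend=False)
  ...  # dense fallback path (unchanged by this commit)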
@@ -1041,14 +1061,26 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
          expert_assignments=selected_experts,
      )
      wi_tile_size = (
-          self.config.tile_batch_seq,
-          self.config.tile_embed_dim,
-          self.config.tile_mlp_dim,
+          self.config.tile_fwd_batch_seq,
+          self.config.tile_fwd_embed_dim,
+          self.config.tile_fwd_mlp_dim,
+          self.config.tile_dlhs_batch_seq,
+          self.config.tile_dlhs_embed_dim,
+          self.config.tile_dlhs_mlp_dim,
+          self.config.tile_drhs_batch_seq,
+          self.config.tile_drhs_embed_dim,
+          self.config.tile_drhs_mlp_dim,
      )
      wo_tile_size = (
-          self.config.tile_batch_seq,
-          self.config.tile_mlp_dim,
-          self.config.tile_embed_dim,
+          self.config.tile_fwd_batch_seq,
+          self.config.tile_fwd_mlp_dim,
+          self.config.tile_fwd_embed_dim,
+          self.config.tile_dlhs_batch_seq,
+          self.config.tile_dlhs_mlp_dim,
+          self.config.tile_dlhs_embed_dim,
+          self.config.tile_drhs_batch_seq,
+          self.config.tile_drhs_mlp_dim,
+          self.config.tile_drhs_embed_dim,
      )
      layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
      if self.get_tensor_transpose_parallelism_size() > 1:
