1717# pylint: disable=too-many-positional-arguments
1818
1919import functools
20+ import itertools
21+ import dataclasses
2022from typing import Literal
2123import jax
2224import jax .numpy as jnp
23- from MaxText .kernels .megablox import backend
25+ from jax .experimental .xla_metadata import set_xla_metadata
26+ from MaxText .kernels .megablox import backend as megablox_backend
27+ # from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
28+ import tokamax
2429import qwix
2530import qwix .pallas as qpl
2631
27-
32+ _counter = itertools . count ()  # Module-level counter; _gmm_fwd and _gmm_bwd draw ids from it via next() and pass them as MUST_FUSE to set_xla_metadata.
2833def gmm (
2934 lhs : jnp .ndarray ,
3035 rhs : jnp .ndarray ,
3136 group_sizes : jnp .ndarray ,
3237 preferred_element_type : jnp .dtype = jnp .float32 ,
33- tiling : tuple [int , int , int ] = (128 , 128 , 128 ),
38+ tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
3439 group_offset : jnp .ndarray | None = None ,
3540 existing_out : jnp .ndarray | None = None ,
3641 transpose_rhs : bool = False ,
3742 interpret : bool = False ,
3843 lhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None = None ,
3944 rhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None = None ,
4045 use_qwix_quantization : bool = False ,
46+ use_tokamax_backend : bool = False ,
4147):
4248 """Grouped matrix multiplication operation."""
4349 quantization_rule = None
@@ -70,6 +76,7 @@ def gmm(
7076 transpose_rhs ,
7177 interpret ,
7278 quantization_rule ,
79+ use_tokamax_backend ,
7380 )
7481
7582
@@ -78,12 +85,13 @@ def _gmm_fwd(
7885 rhs : jnp .ndarray ,
7986 group_sizes : jnp .ndarray ,
8087 preferred_element_type : jnp .dtype = jnp .float32 ,
81- tiling : tuple [int , int , int ] = (128 , 128 , 128 ),
88+ tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
8289 group_offset : jnp .ndarray | None = None ,
8390 existing_out : jnp .ndarray | None = None ,
8491 transpose_rhs : bool = False ,
8592 interpret : bool = False ,
8693 quantization_rule : qwix .QtRule | None = None ,
94+ use_tokamax_backend : bool = False ,
8795) -> tuple [
8896 jnp .ndarray ,
8997 tuple [
@@ -94,15 +102,17 @@ def _gmm_fwd(
94102 ],
95103]:
96104 """Forward function for GMM VJP."""
105+ fwd_counter = next (_counter )
97106 if quantization_rule :
98107 if quantization_rule .act_qtype :
99- lhs = qpl .quantize (
100- lhs ,
101- quantization_rule .act_qtype ,
102- channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
103- calibration_method = quantization_rule .act_calibration_method ,
104- scale_dtype = jnp .float32 ,
105- )
108+ with set_xla_metadata (MUST_FUSE = fwd_counter ):
109+ lhs = qpl .quantize (
110+ lhs ,
111+ quantization_rule .act_qtype ,
112+ channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
113+ calibration_method = quantization_rule .act_calibration_method ,
114+ scale_dtype = jnp .float32 ,
115+ )
106116 if quantization_rule .weight_qtype :
107117 rhs = qpl .quantize (
108118 rhs ,
@@ -114,29 +124,50 @@ def _gmm_fwd(
114124 calibration_method = quantization_rule .weight_calibration_method ,
115125 scale_dtype = jnp .float32 ,
116126 )
117-
118- out = backend .gmm (
127+ # QAG is only supported under the following conditions.
128+ if quantization_rule .weight_calibration_method .startswith ("fixed" ) and isinstance (rhs , qpl .QArray ):
129+ rhs_qvalue = jax .lax .all_gather (rhs .qvalue , "fsdp" , axis = 0 , tiled = True )
130+ rhs = dataclasses .replace (rhs , qvalue = rhs_qvalue )
131+ if use_tokamax_backend :
132+ with set_xla_metadata (MUST_FUSE = fwd_counter ):
133+ out = tokamax .ragged_dot_general (
134+ lhs = lhs ,
135+ rhs = rhs ,
136+ group_sizes = group_sizes ,
137+ ragged_dot_dimension_numbers = tokamax .RaggedDotDimensionNumbers (
138+ dot_dimension_numbers = (([1 ], [1 ]), ([], [])),
139+ lhs_ragged_dimensions = [0 ],
140+ rhs_group_dimensions = [0 ],
141+ ),
142+ precision = jax .lax .Precision .DEFAULT ,
143+ preferred_element_type = preferred_element_type ,
144+ group_offset = group_offset ,
145+ implementation = "mosaic" ,
146+ )
147+ else :
148+ out = megablox_backend .gmm (
119149 lhs ,
120150 rhs ,
121151 group_sizes ,
122152 preferred_element_type ,
123- tiling ,
153+ tiling [: 3 ] ,
124154 group_offset ,
125155 existing_out ,
126156 transpose_rhs = transpose_rhs ,
127157 interpret = interpret ,
128- )
158+ )
129159 return out , (lhs , rhs , group_sizes , group_offset )
130160
131161
132162def _gmm_bwd (
133163 lhs_dtype : jax .typing .DTypeLike ,
134164 rhs_dtype : jax .typing .DTypeLike ,
135165 preferred_element_type : jnp .dtype ,
136- tiling : tuple [int , int , int ],
166+ tiling : tuple [int , int , int , int , int , int , int , int , int ],
137167 transpose_rhs : bool ,
138168 interpret : bool ,
139169 quantization_rule : qwix .QtRule | None ,
170+ use_tokamax_backend : bool ,
140171 residual : tuple [
141172 jnp .ndarray | qpl .QArray ,
142173 jnp .ndarray | qpl .QArray ,
@@ -160,6 +191,8 @@ def _gmm_bwd(
160191 # - drhs_dout: the incoming gradient used to calculate drhs.
161192
162193 # dlhs_dout and drhs_dout can be different when quantization is enabled.
194+ dlhs_counter = next (_counter )
195+ drhs_counter = next (_counter )
163196 dlhs_dout = grad
164197 drhs_dout = grad
165198 if isinstance (rhs , qpl .QArray ): # qvalue: [g, k, n] scale: [1, 1, n]
@@ -173,41 +206,76 @@ def _gmm_bwd(
173206 lhs = lhs .qvalue
174207 if quantization_rule and quantization_rule .bwd_qtype :
175208 # Enable backward pass quantization
176- dlhs_dout = qpl .quantize (
209+ with set_xla_metadata (MUST_FUSE = dlhs_counter ):
210+ dlhs_dout = qpl .quantize (
211+ dlhs_dout ,
212+ quantization_rule .bwd_qtype ,
213+ channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
214+ calibration_method = quantization_rule .bwd_calibration_method ,
215+ scale_dtype = jnp .float32 ,
216+ )
217+ with set_xla_metadata (MUST_FUSE = drhs_counter ):
218+ drhs_dout = qpl .quantize (
219+ drhs_dout ,
220+ quantization_rule .bwd_qtype ,
221+ channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [1 ],
222+ calibration_method = quantization_rule .bwd_calibration_method ,
223+ scale_dtype = jnp .float32 ,
224+ )
225+ if use_tokamax_backend :
226+ with set_xla_metadata (MUST_FUSE = dlhs_counter ):
227+ dlhs = tokamax .ragged_dot_general (
228+ lhs = dlhs_dout ,
229+ rhs = rhs ,
230+ group_sizes = group_sizes ,
231+ ragged_dot_dimension_numbers = jax .lax .RaggedDotDimensionNumbers (
232+ dot_dimension_numbers = (([1 ], [2 ]), ([], [])),
233+ lhs_ragged_dimensions = [0 ],
234+ rhs_group_dimensions = [0 ],
235+ ),
236+ precision = jax .lax .Precision .DEFAULT ,
237+ preferred_element_type = preferred_element_type ,
238+ group_offset = group_offset ,
239+ implementation = "mosaic" ,
240+ )
241+ drhs = tokamax .tgmm (
242+ lhs .swapaxes (0 , 1 ),
243+ drhs_dout ,
244+ group_sizes = group_sizes ,
245+ ragged_dot_dimension_numbers = jax .lax .RaggedDotDimensionNumbers (
246+ dot_dimension_numbers = (([0 ], [0 ]), ([], [])),
247+ lhs_ragged_dimensions = [0 ],
248+ rhs_group_dimensions = [],
249+ ),
250+ precision = jax .lax .Precision .DEFAULT ,
251+ preferred_element_type = preferred_element_type ,
252+ group_offset = group_offset ,
253+ implementation = "mosaic" ,
254+ )
255+ else :
256+ dlhs = megablox_backend .gmm (
177257 dlhs_dout ,
178- quantization_rule .bwd_qtype ,
179- channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
180- calibration_method = quantization_rule .bwd_calibration_method ,
181- scale_dtype = jnp .float32 ,
258+ rhs ,
259+ group_sizes ,
260+ lhs_dtype ,
261+ tiling [3 :6 ],
262+ group_offset ,
263+ transpose_rhs = not transpose_rhs ,
264+ interpret = interpret ,
182265 )
183- drhs_dout = qpl .quantize (
266+ drhs = megablox_backend .tgmm (
267+ lhs .swapaxes (0 , 1 ),
184268 drhs_dout ,
185- quantization_rule .bwd_qtype ,
186- channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [1 ],
187- calibration_method = quantization_rule .bwd_calibration_method ,
188- scale_dtype = jnp .float32 ,
269+ group_sizes ,
270+ rhs_dtype ,
271+ tiling [- 3 :],
272+ group_offset ,
273+ num_actual_groups ,
274+ interpret = interpret ,
189275 )
190276
191- dlhs = backend .gmm (
192- dlhs_dout ,
193- rhs ,
194- group_sizes ,
195- lhs_dtype ,
196- tiling ,
197- group_offset ,
198- transpose_rhs = not transpose_rhs ,
199- interpret = interpret ,
200- )
201- drhs = backend .tgmm (
202- lhs .swapaxes (0 , 1 ),
203- drhs_dout ,
204- group_sizes ,
205- rhs_dtype ,
206- tiling ,
207- group_offset ,
208- num_actual_groups ,
209- interpret = interpret ,
210- )
277+ if quantization_rule and quantization_rule .bwd_qtype :
278+ drhs = jax .lax .psum_scatter (drhs , "fsdp" , scatter_dimension = 0 , tiled = True )
211279
212280 # NOTE: If the rhs transposition is fused into the forward pass we need to
213281 # return the transpose of the rhs gradient that we calculated above.
0 commit comments