1717# pylint: disable=too-many-positional-arguments
1818
1919import functools
20- import itertools
21- import dataclasses
2220from typing import Literal
2321import jax
2422import jax .numpy as jnp
25- from jax .experimental .xla_metadata import set_xla_metadata
26- from MaxText .kernels .megablox import backend as megablox_backend
27- # from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
28- import tokamax
23+ from MaxText .kernels .megablox import backend
2924import qwix
3025import qwix .pallas as qpl
3126
32- _counter = itertools . count ()
27+
3328def gmm (
3429 lhs : jnp .ndarray ,
3530 rhs : jnp .ndarray ,
3631 group_sizes : jnp .ndarray ,
3732 preferred_element_type : jnp .dtype = jnp .float32 ,
38- tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
33+ tiling : tuple [int , int , int ] = (128 , 128 , 128 ),
3934 group_offset : jnp .ndarray | None = None ,
4035 existing_out : jnp .ndarray | None = None ,
4136 transpose_rhs : bool = False ,
4237 interpret : bool = False ,
4338 lhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None = None ,
4439 rhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None = None ,
4540 use_qwix_quantization : bool = False ,
46- use_tokamax_backend : bool = False ,
4741):
4842 """Grouped matrix multiplication operation."""
4943 quantization_rule = None
@@ -76,7 +70,6 @@ def gmm(
7670 transpose_rhs ,
7771 interpret ,
7872 quantization_rule ,
79- use_tokamax_backend ,
8073 )
8174
8275
@@ -85,13 +78,12 @@ def _gmm_fwd(
8578 rhs : jnp .ndarray ,
8679 group_sizes : jnp .ndarray ,
8780 preferred_element_type : jnp .dtype = jnp .float32 ,
88- tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
81+ tiling : tuple [int , int , int ] = (128 , 128 , 128 ),
8982 group_offset : jnp .ndarray | None = None ,
9083 existing_out : jnp .ndarray | None = None ,
9184 transpose_rhs : bool = False ,
9285 interpret : bool = False ,
9386 quantization_rule : qwix .QtRule | None = None ,
94- use_tokamax_backend : bool = False ,
9587) -> tuple [
9688 jnp .ndarray ,
9789 tuple [
@@ -102,17 +94,15 @@ def _gmm_fwd(
10294 ],
10395]:
10496 """Forward function for GMM VJP."""
105- fwd_counter = next (_counter )
10697 if quantization_rule :
10798 if quantization_rule .act_qtype :
108- with set_xla_metadata (MUST_FUSE = fwd_counter ):
109- lhs = qpl .quantize (
110- lhs ,
111- quantization_rule .act_qtype ,
112- channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
113- calibration_method = quantization_rule .act_calibration_method ,
114- scale_dtype = jnp .float32 ,
115- )
99+ lhs = qpl .quantize (
100+ lhs ,
101+ quantization_rule .act_qtype ,
102+ channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
103+ calibration_method = quantization_rule .act_calibration_method ,
104+ scale_dtype = jnp .float32 ,
105+ )
116106 if quantization_rule .weight_qtype :
117107 rhs = qpl .quantize (
118108 rhs ,
@@ -124,50 +114,29 @@ def _gmm_fwd(
124114 calibration_method = quantization_rule .weight_calibration_method ,
125115 scale_dtype = jnp .float32 ,
126116 )
127- # QAG is only supported for following conditions
128- if quantization_rule .weight_calibration_method .startswith ("fixed" ) and isinstance (rhs , qpl .QArray ):
129- rhs_qvalue = jax .lax .all_gather (rhs .qvalue , "fsdp" , axis = 0 , tiled = True )
130- rhs = dataclasses .replace (rhs , qvalue = rhs_qvalue )
131- if use_tokamax_backend :
132- with set_xla_metadata (MUST_FUSE = fwd_counter ):
133- out = tokamax .ragged_dot_general (
134- lhs = lhs ,
135- rhs = rhs ,
136- group_sizes = group_sizes ,
137- ragged_dot_dimension_numbers = tokamax .RaggedDotDimensionNumbers (
138- dot_dimension_numbers = (([1 ], [1 ]), ([], [])),
139- lhs_ragged_dimensions = [0 ],
140- rhs_group_dimensions = [0 ],
141- ),
142- precision = jax .lax .Precision .DEFAULT ,
143- preferred_element_type = preferred_element_type ,
144- group_offset = group_offset ,
145- implementation = "mosaic" ,
146- )
147- else :
148- out = megablox_backend .gmm (
117+
118+ out = backend .gmm (
149119 lhs ,
150120 rhs ,
151121 group_sizes ,
152122 preferred_element_type ,
153- tiling [: 3 ] ,
123+ tiling ,
154124 group_offset ,
155125 existing_out ,
156126 transpose_rhs = transpose_rhs ,
157127 interpret = interpret ,
158- )
128+ )
159129 return out , (lhs , rhs , group_sizes , group_offset )
160130
161131
162132def _gmm_bwd (
163133 lhs_dtype : jax .typing .DTypeLike ,
164134 rhs_dtype : jax .typing .DTypeLike ,
165135 preferred_element_type : jnp .dtype ,
166- tiling : tuple [int , int , int , int , int , int , int , int , int ],
136+ tiling : tuple [int , int , int ],
167137 transpose_rhs : bool ,
168138 interpret : bool ,
169139 quantization_rule : qwix .QtRule | None ,
170- use_tokamax_backend : bool ,
171140 residual : tuple [
172141 jnp .ndarray | qpl .QArray ,
173142 jnp .ndarray | qpl .QArray ,
@@ -191,8 +160,6 @@ def _gmm_bwd(
191160 # - drhs_dout: the incoming gradient used to calculate drhs.
192161
193162 # dlhs_dout and drhs_dout can be different when quantization is enabled.
194- dlhs_counter = next (_counter )
195- drhs_counter = next (_counter )
196163 dlhs_dout = grad
197164 drhs_dout = grad
198165 if isinstance (rhs , qpl .QArray ): # qvalue: [g, k, n] scale: [1, 1, n]
@@ -206,76 +173,41 @@ def _gmm_bwd(
206173 lhs = lhs .qvalue
207174 if quantization_rule and quantization_rule .bwd_qtype :
208175 # Enable backward pass quantization
209- with set_xla_metadata (MUST_FUSE = dlhs_counter ):
210- dlhs_dout = qpl .quantize (
211- dlhs_dout ,
212- quantization_rule .bwd_qtype ,
213- channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
214- calibration_method = quantization_rule .bwd_calibration_method ,
215- scale_dtype = jnp .float32 ,
216- )
217- with set_xla_metadata (MUST_FUSE = drhs_counter ):
218- drhs_dout = qpl .quantize (
219- drhs_dout ,
220- quantization_rule .bwd_qtype ,
221- channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [1 ],
222- calibration_method = quantization_rule .bwd_calibration_method ,
223- scale_dtype = jnp .float32 ,
224- )
225- if use_tokamax_backend :
226- with set_xla_metadata (MUST_FUSE = dlhs_counter ):
227- dlhs = tokamax .ragged_dot_general (
228- lhs = dlhs_dout ,
229- rhs = rhs ,
230- group_sizes = group_sizes ,
231- ragged_dot_dimension_numbers = jax .lax .RaggedDotDimensionNumbers (
232- dot_dimension_numbers = (([1 ], [2 ]), ([], [])),
233- lhs_ragged_dimensions = [0 ],
234- rhs_group_dimensions = [0 ],
235- ),
236- precision = jax .lax .Precision .DEFAULT ,
237- preferred_element_type = preferred_element_type ,
238- group_offset = group_offset ,
239- implementation = "mosaic" ,
240- )
241- drhs = tokamax .tgmm (
242- lhs .swapaxes (0 , 1 ),
243- drhs_dout ,
244- group_sizes = group_sizes ,
245- ragged_dot_dimension_numbers = jax .lax .RaggedDotDimensionNumbers (
246- dot_dimension_numbers = (([0 ], [0 ]), ([], [])),
247- lhs_ragged_dimensions = [0 ],
248- rhs_group_dimensions = [],
249- ),
250- precision = jax .lax .Precision .DEFAULT ,
251- preferred_element_type = preferred_element_type ,
252- group_offset = group_offset ,
253- implementation = "mosaic" ,
254- )
255- else :
256- dlhs = megablox_backend .gmm (
176+ dlhs_dout = qpl .quantize (
257177 dlhs_dout ,
258- rhs ,
259- group_sizes ,
260- lhs_dtype ,
261- tiling [3 :6 ],
262- group_offset ,
263- transpose_rhs = not transpose_rhs ,
264- interpret = interpret ,
178+ quantization_rule .bwd_qtype ,
179+ channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [0 ],
180+ calibration_method = quantization_rule .bwd_calibration_method ,
181+ scale_dtype = jnp .float32 ,
265182 )
266- drhs = megablox_backend .tgmm (
267- lhs .swapaxes (0 , 1 ),
183+ drhs_dout = qpl .quantize (
268184 drhs_dout ,
269- group_sizes ,
270- rhs_dtype ,
271- tiling [- 3 :],
272- group_offset ,
273- num_actual_groups ,
274- interpret = interpret ,
185+ quantization_rule .bwd_qtype ,
186+ channelwise_axes = [] if quantization_rule .disable_channelwise_axes else [1 ],
187+ calibration_method = quantization_rule .bwd_calibration_method ,
188+ scale_dtype = jnp .float32 ,
275189 )
276190
277- if quantization_rule and quantization_rule .bwd_qtype :
278- drhs = jax .lax .psum_scatter (drhs , "fsdp" , scatter_dimension = 0 , tiled = True )
191+ dlhs = backend .gmm (
192+ dlhs_dout ,
193+ rhs ,
194+ group_sizes ,
195+ lhs_dtype ,
196+ tiling ,
197+ group_offset ,
198+ transpose_rhs = not transpose_rhs ,
199+ interpret = interpret ,
200+ )
201+ drhs = backend .tgmm (
202+ lhs .swapaxes (0 , 1 ),
203+ drhs_dout ,
204+ group_sizes ,
205+ rhs_dtype ,
206+ tiling ,
207+ group_offset ,
208+ num_actual_groups ,
209+ interpret = interpret ,
210+ )
279211
280212 # NOTE: If the rhs transposition is fused into the forward pass we need to
281213 # return the transpose of the rhs gradient that we calculated above.
0 commit comments