@@ -122,7 +122,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
@@ -136,18 +136,18 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), (
+        assert not _is_column_major(A_data_row_major), (
             "A must be row-major for output = A @ B"
         )
-        assert _is_column_major(B_t_fp8_col_major), (
+        assert _is_column_major(B_t_data_col_major), (
             "B must be column-major for output = A @ B"
         )
 
@@ -157,8 +157,8 @@ def forward(
         A_scales = A_scales.squeeze(-1)
         B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
-            A_fp8_row_major,
-            B_t_fp8_col_major,
+            A_data_row_major,
+            B_t_data_col_major,
             A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
             B_t_scales.reciprocal(),
             offs,
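The hunks above hand torch._scaled_grouped_mm reciprocal scales because quantization multiplies by the scale before the fp8 cast, so the matmul output must be multiplied by 1/scale to land back in the original range. A minimal plain-torch sketch of that convention (the illustrative rowwise scale below stands in for tensor_to_scale, which additionally rounds scales to powers of two; toy shapes, not the file's helpers):

```python
import torch

A = torch.randn(16, 32)
A_scales = 1.0 / A.abs().amax(dim=-1, keepdim=True)                 # illustrative rowwise scale, shape (M, 1)
A_data = (A.to(torch.float32) * A_scales).to(torch.float8_e4m3fn)   # quantize: data ~= A * scale
A_approx = A_data.to(torch.float32) * A_scales.reciprocal()         # dequantize: undo with 1 / scale
```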
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_fp8_row_major = to_fp8_saturated(
+        grad_output_data_row_major = to_fp8_saturated(
             grad_output_scaled, torch.float8_e4m3fn
         )
 
         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), (
+        assert not _is_column_major(grad_output_data_row_major), (
             "grad_output must be row-major for grad_A = grad_output @ B"
         )
-        assert _is_column_major(B_fp8_col_major), (
+        assert _is_column_major(B_data_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
 
@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_scales = grad_output_scales.squeeze(-1)
         B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
-            grad_output_fp8_row_major,
-            B_fp8_col_major,
+            grad_output_data_row_major,
+            B_data_col_major,
             grad_output_scales.reciprocal(),
             B_scales.reciprocal(),
             offs,
@@ -227,18 +227,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
-        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+        grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
             grad_output.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_data_row_major = grad_out_data_colwise.t()
         grad_output_t_scales = grad_out_scales.t()
 
-        A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
+        A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
             A.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
@@ -249,19 +249,19 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_fp8_row_major), (
+        assert not _is_column_major(grad_output_t_data_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
-        assert _is_column_major(A_fp8_col_major), (
+        assert _is_column_major(A_data_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
 
         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we
         # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
-            grad_output_t_fp8_row_major,
-            A_fp8_col_major,
+            grad_output_t_data_row_major,
+            A_data_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
             offs,
@@ -295,13 +295,15 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
 
-        # A_mx shape: (M, K)
+        # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_data = to_mx(
+            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
 
-        # B_mx shape: (E, N, K)
+        # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -315,9 +317,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_data,
             A_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -332,15 +334,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated
 
-        # grad_out_mx shape: (M, N)
+        # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_data = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )
 
-        # B_mx shape: (E, K, N)
+        # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -354,43 +356,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_data,
             grad_out_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_data shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_data = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_data shape = (M, K)
+        A_data = A_t_data.transpose(-2, -1)
 
         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_data,
             grad_out_t_scales,
-            A_mx,
+            A_data,
             A_scales,
             offs=offs,
         )
@@ -402,64 +404,68 @@ def backward(ctx, grad_out: torch.Tensor):
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_data: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_data: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
+    assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
+    assert A_scale.shape[0] == A_data.shape[0], (
+        f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_data.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
    )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_data.shape[0], (
+        f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_data.shape[1], (
+        f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_data.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
    )
 
     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_data shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_data shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_data = A_data.reshape(
+        *A_data.shape[:-1], A_data.shape[-1] // block_size, block_size
+    )
     A_scale = A_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_data.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)
 
     # Dequantize weights
     # Transpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_data shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_data shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_data = B_data.reshape(
+        *B_data.shape[:-1], B_data.shape[-1] // block_size, block_size
+    )
     B_scale = B_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_data.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # B shape: (E, K, N)
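Both dequantization passes above follow the same pattern: reshape the quantized payload so each scaling block sits on the last dimension, broadcast one scale per block, and reshape back. A minimal plain-torch sketch under illustrative shapes (hypothetical helper, not the library API):

```python
import torch

def dequantize_blockwise(data: torch.Tensor, scale: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # data: (..., K) quantized values; scale: (..., K // block_size), one scale per block.
    orig_shape = data.shape
    blocks = data.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
    out = blocks.to(torch.bfloat16) * scale.unsqueeze(-1).to(torch.bfloat16)
    return out.reshape(orig_shape)

A_data = torch.randn(8, 64).to(torch.float8_e4m3fn)   # (M, K) stand-in for the fp8 payload
A_scale = torch.ones(8, 2)                            # (M, K // block_size)
A_bf16 = dequantize_blockwise(A_data, A_scale)        # (M, K) in bfloat16
```

The same helper would cover the 3D weight case, since the leading (E, N) dimensions pass through the reshape untouched.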
@@ -471,27 +477,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_data: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_data: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_data.ndim == 2, "A must be 2D"
+    assert B_data.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_data.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_data.device,
+        requires_grad=A_data.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_data.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_data.device,
+        requires_grad=B_data.requires_grad,
     )
 
     # Dequantize input per each scaling group
@@ -507,7 +513,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_data[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape
 
         # Get scales for this group.
@@ -532,7 +538,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
         # -- Dequantize B tensor
        # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_data[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape
 
         # Scales shape is (group_size//block_size, N)