
Commit 5ef75e2

support for 2d-2d emulated mxfp8 grouped gemm (#2632)

1 parent 9ab291f commit 5ef75e2

3 files changed: +280 -9 lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 56 additions & 4 deletions
@@ -9,7 +9,11 @@
 
 pytest.importorskip("triton", reason="Triton required to run this test")
 
-from torchao.prototype.moe_training.utils import generate_jagged_offs
+from torchao.prototype.moe_training.utils import (
+    _to_mxfp8_per_group_colwise,
+    _to_mxfp8_per_group_rowwise,
+    generate_jagged_offs,
+)
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 # We need to skip before doing any imports which would use triton, since
@@ -30,8 +34,9 @@
 from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _emulated_mxfp8_scaled_grouped_mm_2d_2d,
+    _emulated_mxfp8_scaled_grouped_mm_2d_3d,
     _scaled_grouped_mm,
-    emulated_mxfp8_scaled_grouped_mm,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
@@ -223,7 +228,7 @@ def compute_reference_forward(
 @skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
-def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
+def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
     offs = generate_jagged_offs(num_experts, M)
@@ -242,7 +247,7 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
 
     ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
-    out = emulated_mxfp8_scaled_grouped_mm(
+    out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
         x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
     )
 
@@ -252,6 +257,53 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
 
 
 @skip_if_rocm("ROCm not supported")
+@pytest.mark.parametrize("M", (1024, 4096))
+@pytest.mark.parametrize("N", (1024, 4096))
+@pytest.mark.parametrize("num_experts", (8, 16))
+def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
+    # Simulate 2d-2d grouped gemm grad_weight = grad_output_t @ x
+    block_size = 32
+    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+    grad_out_t = grad_out.t().contiguous()
+    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+    x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone()
+
+    # bf16 reference grouped gemm
+    ref_out = torch._grouped_mm(
+        grad_out_t_ref,
+        x_ref,
+        offs=offs_ref,
+        out_dtype=torch.bfloat16,
+    )
+
+    # mxfp8 grouped gemm
+    x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+    grad_out_t_mx, grad_out_t_scale = _to_mxfp8_per_group_rowwise(
+        grad_out_t,
+        offs=offs,
+        block_size=block_size,
+    )
+    x_mx, x_scale = _to_mxfp8_per_group_colwise(
+        x,
+        offs=offs,
+        block_size=block_size,
+    )
+    out = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+        grad_out_t_mx,
+        grad_out_t_scale,
+        x_mx,
+        x_scale,
+        offs=offs,
+        out_dtype=torch.bfloat16,
+        block_size=block_size,
+    )
+
+    sqnr = compute_error(ref_out, out)
+    min_sqnr = 27.0
+    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
+
+
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
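Editor's note: the new 2d-2d test above mirrors how the weight gradient is produced during MoE training, where each expert's grad_weight comes from that expert's slice of grad_output_t multiplied by its slice of the input activations. Below is a hedged, plain-bf16 sketch (not part of this commit) of the per-group computation that torch._grouped_mm is expected to perform when both operands are 2D; the helper name grouped_mm_2d_2d_reference is hypothetical.

import torch

def grouped_mm_2d_2d_reference(grad_out_t, x, offs):
    # grad_out_t: (N, total_M), x: (total_M, K); offs holds cumulative group-end
    # indices along the shared total_M dim. Conceptually this yields one
    # (N, K) grad_weight per expert, stacked into a (num_groups, N, K) tensor.
    outs, start = [], 0
    for end in offs.tolist():
        outs.append(grad_out_t[:, start:end] @ x[start:end, :])
        start = end
    return torch.stack(outs)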

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 120 additions & 1 deletion
@@ -300,6 +300,7 @@ def forward(
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
+        ctx.block_size = block_size
        ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
@@ -317,7 +318,7 @@ def forward(
         return out
 
     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
+    def backward(ctx, grad_out: torch.Tensor):
         raise NotImplementedError
 
 
@@ -352,6 +353,27 @@ def emulated_mxfp8_scaled_grouped_mm(
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
+) -> torch.Tensor:
+    if A_mx.ndim == 2 and B_t_mx.ndim == 3:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    elif A_mx.ndim == 2 and B_t_mx.ndim == 2:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    else:
+        raise NotImplementedError
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
 ) -> torch.Tensor:
     # Dequantize input
     # A_mx shape: (M, K)
@@ -397,3 +419,100 @@ def emulated_mxfp8_scaled_grouped_mm(
     # Perform bf16 grouped GEMM.
     out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+    A_mx: torch.Tensor,  # (M, K)
+    A_scale: torch.Tensor,  # (M, K//block_size)
+    B_mx: torch.Tensor,  # (K, N)
+    B_scale: torch.Tensor,  # (K//block_size, N)
+    offs: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    assert A_mx.ndim == 2, "A must be 2D"
+    assert B_mx.ndim == 2, "B must be 2D"
+    A = torch.zeros(
+        A_mx.shape,
+        dtype=torch.bfloat16,
+        device=A_mx.device,
+        requires_grad=A_mx.requires_grad,
+    )
+    B = torch.zeros(
+        B_mx.shape,
+        dtype=torch.bfloat16,
+        device=B_mx.device,
+        requires_grad=B_mx.requires_grad,
+    )
+
+    # Dequantize input per each scaling group
+    scales_start_idx = 0
+    group_start_idx = 0
+    for group_end_idx in offs.tolist():
+        group_size = group_end_idx - group_start_idx
+        scale_group_size = group_size // block_size
+        if group_size == 0:
+            group_start_idx = group_end_idx
+            continue
+
+        # -- Dequantize A tensor
+        # A_group shape: (M, group_size)
+        # A_scale shape: (M, group_size//block_size)
+        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group_shape = A_group.shape
+
+        # Get scales for this group.
+        # scales shape: (M, group_size//block_size)
+        scales = A_scale[:, scales_start_idx : scales_start_idx + scale_group_size]
+
+        # Reshape to be able to do per-scaling group multiplication
+        # A_group shape: (M, group_size//block_size, block_size)
+        # A_scale shape: (M, group_size//block_size, 1)
+        A_group = A_group.reshape(
+            *A_group.shape[:-1], A_group.shape[-1] // block_size, block_size
+        )
+        scales = scales.unsqueeze(-1)
+
+        # Rescale and cast to bfloat16
+        A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16)
+
+        # Reshape back to original shape and store in dequantized A buffer
+        # A shape: (M, group_size)
+        A_group = A_group.reshape(A_group_shape)
+        A[:, group_start_idx:group_end_idx] = A_group
+
+        # -- Dequantize B tensor
+        # B_group shape is (group_size, N)
+        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group_shape = B_group.shape
+
+        # Scales shape is (group_size//block_size, N)
+        scales = B_scale[scales_start_idx : scales_start_idx + scale_group_size, :]
+
+        # Transpose B to get scaling group on rightmost dim, to make things easier
+        # B_group shape: (N, group_size)
+        # scales shape: (N, group_size//block_size)
+        B_group, scales = B_group.transpose(-2, -1), scales.transpose(-2, -1)
+
+        # Reshape B to be able to do per-scaling group multiplication
+        # B_group shape: (N, group_size//block_size, block_size)
+        # scales shape: (N, group_size//block_size, 1)
+        B_group = B_group.reshape(
+            *B_group.shape[:-1], B_group.shape[-1] // block_size, block_size
+        )
+        scales = scales.unsqueeze(-1)
+
+        # Cast to bf16 and perform scaling
+        B_group = B_group.to(torch.bfloat16) * scales.to(torch.bfloat16)
+
+        # Reshape B_group back to original shape and store in dequantized B buffer
+        B_group = B_group.reshape(B_group_shape[1], B_group_shape[0]).transpose(-2, -1)
+        B[group_start_idx:group_end_idx, :] = B_group
+
+        # Increment group start and scale start indices
+        group_start_idx = group_end_idx
+        scales_start_idx += scale_group_size
+
+    # Perform bf16 grouped GEMM using dequantized A and B.
+    out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
+    return out
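Editor's note: the per-group dequantization loop above follows standard MX semantics: each float8_e8m0fnu scale is a power of two covering one block_size-long block along the contraction dim, so casting the scale to bfloat16 and multiplying the fp8 block by it recovers an approximation of the original values. A minimal illustration of that block-wise round trip (not part of this commit), assuming the to_mx helper used in the test file above and a CUDA device:

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
hp = torch.randn(4, 64, dtype=torch.bfloat16, device="cuda")

# Quantize: one e8m0 (power-of-two) scale per 32-element block of the last dim.
scales, lp = to_mx(hp, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Dequantize the same way the emulated grouped GEMM above does: view the fp8
# data as (rows, num_blocks, block_size) and broadcast each scale over its block.
lp_blocks = lp.reshape(4, 64 // block_size, block_size).to(torch.bfloat16)
scale_per_block = scales.reshape(4, 64 // block_size, 1).to(torch.bfloat16)
dq = (lp_blocks * scale_per_block).reshape(4, 64)  # approximately equal to hp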

torchao/prototype/moe_training/utils.py

Lines changed: 104 additions & 4 deletions
@@ -5,8 +5,10 @@
 
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 
+# --- float8 rowwise scaling ---
 def _to_2d_jagged_float8_tensor_colwise(
     A_col_major: torch.Tensor,
     offs: torch.Tensor,
@@ -143,6 +145,104 @@ def _to_2d_jagged_float8_tensor_rowwise(
     return x_fp8, x_scales
 
 
+# --- mxfp8 scaling ---
+def _to_mxfp8_per_group_rowwise(
+    x: torch.Tensor,
+    offs: torch.Tensor,
+    block_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This is a reference implementation used for testing correctness; it is not performant.
+
+    This function converts the 2D input tensor to an mxfp8 tensor along dim 0 with per-token-group scaling,
+    where groups are determined based on the offsets.
+
+    Args:
+        x (torch.Tensor): The input tensor to be converted to a jagged mxfp8 tensor.
+
+    Returns:
+        A tuple containing the jagged mxfp8 tensor and the scales used for the conversion.
+    """
+    assert x.ndim == 2, "input tensor must be 2D"
+    assert offs.numel() > 0, "offs must be non-empty"
+
+    x_mx = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    x_scales = None
+
+    start_idx = 0
+    for end_idx in offs.tolist():
+        # Get the subtensor of x for this group, fetching all rows of the next group of columns.
+        subtensor = x[:, start_idx:end_idx]  # (M, local_group_size)
+
+        # Perform mxfp8 conversion on the logically distinct subtensor.
+        scales, mx_subtensor = to_mx(
+            subtensor.contiguous(),
+            elem_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+        )
+
+        # Store this portion of the resulting mxfp8 tensor and scales.
+        x_mx[:, start_idx:end_idx] = mx_subtensor
+        if x_scales is None:
+            x_scales = scales.view(torch.uint8)  # Needed to support cat op below
+        else:
+            x_scales = torch.cat((x_scales, scales.view(torch.uint8)), dim=1)
+
+        # Update start index for next group.
+        start_idx = end_idx
+
+    return x_mx, x_scales.view(torch.float8_e8m0fnu)
+
+
+def _to_mxfp8_per_group_colwise(
+    A_col_major: torch.Tensor,  # (K, N)
+    offs: torch.Tensor,
+    block_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This is a reference implementation used for testing correctness; it is not performant.
+
+    This function converts the 2D input tensor to an mxfp8 tensor along dim 1 with per-token-group scaling,
+    where groups are determined based on the offsets.
+
+    Args:
+        A_col_major (torch.Tensor): The input tensor to be converted to an mxfp8 tensor.
+
+    Returns:
+        A tuple containing the mxfp8 tensor and the scales used for the conversion.
+    """
+    assert A_col_major.ndim == 2, "A must be 2D"
+    assert offs.numel() > 0, "offs must be non-empty"
+
+    A_mx = torch.empty_like(A_col_major, dtype=torch.float8_e4m3fn)
+    A_scales = None
+
+    start_idx = 0
+    for end_idx in offs.tolist():
+        # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each.
+        subtensor = A_col_major[start_idx:end_idx, :]  # (local_group_size, N)
+
+        # Convert to mxfp8 along dim1, by transposing, converting, and transposing back.
+        scales, mx_subtensor = to_mx(
+            subtensor.transpose(-2, -1).contiguous(),
+            elem_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+        )
+        scales, mx_subtensor = scales.transpose(-2, -1), mx_subtensor.transpose(-2, -1)
+
+        # Store this portion of the resulting mxfp8 tensor and scales.
+        A_mx[start_idx:end_idx, :] = mx_subtensor
+        if A_scales is None:
+            A_scales = scales.view(torch.uint8)  # Needed to support cat op below
+        else:
+            A_scales = torch.cat((A_scales, scales.view(torch.uint8)), dim=0)
+
+        # Update start index for next group.
+        start_idx = end_idx
+
+    return A_mx, A_scales.view(torch.float8_e8m0fnu)
+
+
 def _is_column_major(x: torch.Tensor) -> bool:
     """
     This function checks if the input tensor is column-major.
@@ -157,7 +257,7 @@ def _is_column_major(x: torch.Tensor) -> bool:
     return x.stride(-2) == 1 and x.stride(-1) > 1
 
 
-def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
+def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
     """
     Utility function for tests and benchmarks.
 
@@ -170,11 +270,11 @@ def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
         torch.Tensor: A tensor of length E with the specified properties.
     """
     # Ensure M is divisible by 16
-    if M % 16 != 0:
-        raise ValueError("M must be divisible by 16")
+    if M % multiple_of != 0:
+        raise ValueError(f"M must be divisible by {multiple_of}")
 
     # Generate a list of possible values
-    possible_values = [i for i in range(0, M + 1, 16)]
+    possible_values = [i for i in range(multiple_of, M + 1, multiple_of)]
 
     # If E is larger than the number of possible values, raise an error
     if E > len(possible_values):
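Editor's note: a hedged usage sketch (not part of this commit) showing how the new utils fit together and the scale shapes implied by the annotations above; the sizes are illustrative and a CUDA device is assumed.

import torch
from torchao.prototype.moe_training.utils import (
    _to_mxfp8_per_group_colwise,
    _to_mxfp8_per_group_rowwise,
    generate_jagged_offs,
)

E, M, N, block_size = 8, 1024, 4096, 32

# Cumulative per-expert token offsets; each group size is a multiple of 32,
# so every token group lines up with whole mxfp8 scaling blocks.
offs = generate_jagged_offs(E, M, multiple_of=block_size)  # last element == M

grad_out_t = torch.randn(N, M, dtype=torch.bfloat16, device="cuda")
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")

# Rowwise: the offsets partition the column (token) dim of grad_out_t.
go_mx, go_scales = _to_mxfp8_per_group_rowwise(grad_out_t, offs=offs, block_size=block_size)
# go_mx: (N, M) float8_e4m3fn, go_scales: (N, M // block_size) float8_e8m0fnu

# Colwise: the offsets partition the row (token) dim of x.
x_mx, x_scales = _to_mxfp8_per_group_colwise(x, offs=offs, block_size=block_size)
# x_mx: (M, N) float8_e4m3fn, x_scales: (M // block_size, N) float8_e8m0fnu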
