Skip to content

Commit 006a243

Browse files
authored
Add F.scaled_mm (#2720)
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent a03dd40 commit 006a243

File tree

3 files changed

+402
-0
lines changed

3 files changed

+402
-0
lines changed

thunder/executors/torchex.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,8 @@ def max_pool_with_indices_backward_meta(
16531653
nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
pad = _register_torch_operation("pad", module=torch.nn.functional)
scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
# F.scaled_mm only exists in newer PyTorch builds; register it conditionally so
# that importing this executor module keeps working on older versions.
if hasattr(torch.nn.functional, "scaled_mm"):
    scaled_mm = _register_torch_operation("scaled_mm", module=torch.nn.functional)
softmax = _register_torch_operation("softmax", like=ltorch._softmax)
16571659

16581660

@@ -1979,6 +1981,8 @@ def adaptive_avg_pool2d_bwd_wrapper(
19791981
pad_prim_impl = ex.register_operator("torch_pad_prim_impl", meta=prims.pad.meta, fn=_pad_prim_impl)
_register_implementation(prims.pad, pad_prim_impl, checker=_always_executable)
_register_implementation(ltorch._softmax, checker=_always_executable, execution_transform=_softmax_transform)
# Mirror the conditional registration of the torch operation: only map
# ltorch.scaled_mm when this PyTorch build provides F.scaled_mm (otherwise the
# `scaled_mm` name would be undefined here).
if hasattr(torch.nn.functional, "scaled_mm"):
    _register_implementation(ltorch.scaled_mm, scaled_mm, checker=_always_executable)
_register_implementation(ltorch.scaled_dot_product_attention, scaled_dot_product_attention, checker=_always_executable)
19831987

19841988

thunder/tests/test_ops.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from collections.abc import Callable
22
import os
3+
import functools
34

45
import numpy as np
56
import pytest
67
import torch
78
from torch.testing import assert_close
89

10+
if hasattr(torch.nn.functional, "scaled_mm"):
11+
from torch.nn.functional import ScalingType, SwizzleType
12+
913
import thunder
1014
import thunder.core.devices as devices
1115
import thunder.core.dtypes as dtypes
@@ -419,6 +423,302 @@ def fn(a):
419423
assert_close(b, b_ref)
420424

421425

426+
def _cuda_version_tuple() -> tuple[int, int] | None:
427+
if torch.version.cuda is None:
428+
return None
429+
parts = torch.version.cuda.split(".")
430+
try:
431+
major = int(parts[0])
432+
minor = int(parts[1]) if len(parts) > 1 else 0
433+
return major, minor
434+
except ValueError:
435+
return None
436+
437+
438+
def _require_scaled_mm(fn):
439+
@functools.wraps(fn)
440+
def wrapper(*args, **kwargs):
441+
if not hasattr(torch.nn.functional, "scaled_mm"):
442+
pytest.skip("torch.nn.functional.scaled_mm is not found in this PyTorch")
443+
return fn(*args, **kwargs)
444+
445+
return wrapper
446+
447+
448+
def _ensure_fp8_tensorwise(device: torch.device) -> None:
    """Skip the current test unless ``device`` can run tensor-wise fp8 scaled_mm."""
    capability = torch.cuda.get_device_capability(device)
    if capability < (8, 9):
        pytest.skip("scaled_mm tensor-wise support requires SM89 or newer")
451+
452+
453+
def _require_fp8_tensorwise(fn):
454+
@functools.wraps(fn)
455+
def wrapper(*args, **kwargs):
456+
device = torch.device("cuda")
457+
_ensure_fp8_tensorwise(device)
458+
return fn(*args, **kwargs)
459+
460+
return wrapper
461+
462+
463+
def _require_fp8_rowwise(device: torch.device) -> None:
    """Skip unless ``device`` plus the CUDA toolkit support row-wise scaled_mm.

    Row-wise scaling needs everything tensor-wise needs, plus SM90 hardware and
    CUDA 12.9 (an unknown toolkit version is allowed to proceed).
    """
    _ensure_fp8_tensorwise(device)
    if torch.cuda.get_device_capability(device) < (9, 0):
        pytest.skip("row-wise scaled_mm requires SM90 or newer")
    version = _cuda_version_tuple()
    if version is not None and version < (12, 9):
        pytest.skip("row-wise scaled_mm requires CUDA 12.9 or newer")
470+
471+
472+
def _require_fp8_blockwise(device: torch.device) -> None:
    """Skip unless ``device`` supports block-wise scaled_mm.

    Currently block-wise support is gated exactly like row-wise support.
    """
    _require_fp8_rowwise(device)
474+
475+
476+
# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L645-L659
@requiresCUDA
@_require_fp8_tensorwise
@_require_scaled_mm
def test_scaled_mm_tensorwise_matches_torch():
    """thunder.jit of F.scaled_mm (tensor-wise scaling) matches eager PyTorch."""
    device = torch.device("cuda")

    def reference_fn(mat_a, mat_b, scale_a, scale_b):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.TensorWise,
            scale_b,
            ScalingType.TensorWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=torch.bfloat16,
        )

    M, K, N = 16, 32, 16
    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
    mat_b = torch.randn(K, N, device=device, dtype=torch.float32)
    # Plain downcast to fp8; identity (1.0) scales keep the comparison simple.
    # NOTE(review): unlike the other tests here, mat_b is not transposed to
    # column-major — presumably the tensor-wise path accepts row-major; confirm.
    mat_a_lp = mat_a.to(torch.float8_e4m3fn)
    mat_b_lp = mat_b.to(torch.float8_e4m3fn)
    scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)

    # Run eager first: if this device/build rejects the recipe, skip instead of
    # failing the thunder comparison.
    try:
        expected = reference_fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(reference_fn)
    result = jf(mat_a_lp, mat_b_lp, scale_a, scale_b)
    assert_close(result, expected)
512+
513+
514+
# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L862-L910
@requiresCUDA
@_require_fp8_tensorwise
@_require_scaled_mm
def test_scaled_mm_matches_scaled_data():
    """thunder.jit of F.scaled_mm agrees with eager on genuinely quantized data."""
    device = torch.device("cuda")

    def quantize_to_fp8(tensor):
        # Tensor-wise quantization: a single encode scale derived from the
        # global absmax; decode is its reciprocal.
        dtype = torch.float8_e4m3fn
        max_val = torch.finfo(dtype).max
        amax = tensor.abs().max()
        # clamp guards against division by zero for an all-zero tensor
        encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
        quant = torch.clamp(tensor * encode, min=-max_val, max=max_val).to(dtype)
        decode = encode.reciprocal()
        return quant, decode, encode

    def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.TensorWise,
            scale_b,
            ScalingType.TensorWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=out_dtype,
        )

    M, K, N = 32, 64, 32
    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
    mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32)

    mat_a_lp, decode_a, encode_a = quantize_to_fp8(mat_a)
    mat_b_lp_pre, decode_b, encode_b = quantize_to_fp8(mat_b_base)
    # To use cublaslt, the second matrix needs to be column-major.
    mat_b_lp = mat_b_lp_pre.t()

    # Eager reference first; skip when the recipe is unsupported here.
    try:
        reference = scaled_mm_fp8(mat_a_lp, mat_b_lp, decode_a, decode_b, out_dtype=torch.float32)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_fp8(a, b, sa, sb, out_dtype=torch.float32))
    thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)

    assert_close(thunder_out, reference)
561+
562+
563+
@requiresCUDA
@_require_scaled_mm
def test_scaled_mm_rowwise_matches_torch():
    """thunder.jit of F.scaled_mm (row-wise scaling) matches eager PyTorch."""
    device = torch.device("cuda")
    # Row-wise support is gated at runtime (SM90 / CUDA 12.9), so the check is
    # inside the test body rather than a decorator.
    _require_fp8_rowwise(device)

    def reference_fn(mat_a, mat_b, scale_a, scale_b):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.RowWise,
            scale_b,
            ScalingType.RowWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=torch.bfloat16,
        )

    M, K, N = 16, 32, 16
    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
    mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32)
    mat_a_lp = mat_a.to(torch.float8_e4m3fn)
    # To use cublaslt, the second matrix needs to be column-major.
    mat_b_lp = mat_b_base.to(torch.float8_e4m3fn).t()
    # Identity scales: one per row of A (M, 1) and one per column of B (1, N).
    scale_a = torch.ones((M, 1), device=device, dtype=torch.float32)
    scale_b = torch.ones((1, N), device=device, dtype=torch.float32)

    # Eager reference first; skip when the recipe is unsupported here.
    try:
        expected = reference_fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(reference_fn)
    result = jf(mat_a_lp, mat_b_lp, scale_a, scale_b)
    assert_close(result, expected)
599+
600+
601+
@requiresCUDA
@_require_scaled_mm
def test_scaled_mm_rowwise_matches_scaled_data():
    """thunder.jit of row-wise F.scaled_mm agrees with eager on quantized data."""
    device = torch.device("cuda")
    _require_fp8_rowwise(device)

    dtype_fp8 = torch.float8_e4m3fn
    max_val = torch.finfo(dtype_fp8).max

    def rowwise_quantize(tensor, *, dim):
        # One encode scale per slice along `dim` (dim=1 -> per row of A,
        # dim=0 -> per column of B); decode is the reciprocal.
        amax = tensor.abs().amax(dim=dim, keepdim=True)
        encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
        quant = torch.clamp(tensor * encode, min=-max_val, max=max_val).to(dtype_fp8)
        decode = encode.reciprocal()
        return quant, decode, encode

    def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.RowWise,
            scale_b,
            ScalingType.RowWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=out_dtype,
        )

    M, K, N = 32, 64, 32
    mat_a = torch.randn(M, K, device=device, dtype=torch.bfloat16)
    mat_b = torch.randn(K, N, device=device, dtype=torch.bfloat16)

    # Quantize in float32 for accurate scale computation.
    mat_a_lp, decode_a, encode_a = rowwise_quantize(mat_a.to(torch.float32), dim=1)
    mat_b_lp, decode_b, encode_b = rowwise_quantize(mat_b.to(torch.float32), dim=0)

    # Eager reference first; skip when the recipe is unsupported here.
    try:
        reference = scaled_mm_rowwise(mat_a_lp, mat_b_lp, decode_a, decode_b, out_dtype=torch.bfloat16)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_rowwise(a, b, sa, sb, out_dtype=torch.bfloat16))
    thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)

    # Compare in float32 with a loose tolerance: fp8 quantization error
    # dominates, so bit-exact agreement is not expected.
    reference_f32 = reference.to(torch.float32)
    thunder_out_f32 = thunder_out.to(torch.float32)

    assert_close(thunder_out_f32, reference_f32, atol=3e-2, rtol=3e-2)
649+
650+
651+
def _blockwise_quantize(tensor: torch.Tensor, block_rows: int, block_cols: int) -> tuple[torch.Tensor, torch.Tensor]:
652+
dtype_fp8 = torch.float8_e4m3fn
653+
max_val = torch.finfo(dtype_fp8).max
654+
655+
M, K = tensor.shape
656+
assert M % block_rows == 0 and K % block_cols == 0
657+
658+
reshaped = tensor.reshape(M // block_rows, block_rows, K // block_cols, block_cols)
659+
amax = reshaped.abs().amax(dim=(1, 3), keepdim=True)
660+
encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
661+
quant = torch.clamp(reshaped * encode, min=-max_val, max=max_val).to(dtype_fp8)
662+
663+
return quant.reshape(M, K), encode.reshape(M // block_rows, K // block_cols).to(tensor.device)
664+
665+
666+
@requiresCUDA
@_require_scaled_mm
@pytest.mark.parametrize("output_dtype", [torch.bfloat16])
@pytest.mark.parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_blockwise_matches_torch(output_dtype, lhs_block, rhs_block):
    """thunder.jit of block-wise F.scaled_mm (1x128 / 128x128 recipes) matches eager."""
    device = torch.device("cuda")
    _require_fp8_blockwise(device)

    M, K, N = 256, 256, 256
    # pow(3) widens the dynamic range so the per-block scales actually differ.
    mat_a = torch.randn(M, K, device=device, dtype=output_dtype).pow(3)
    mat_b_rows = torch.randn(N, K, device=device, dtype=output_dtype).pow(3)

    # Quantize in float32; blocks are (lhs_block x 128) for A and
    # (rhs_block x 128) for B (B quantized in row layout, then transposed).
    mat_a_lp, encode_a = _blockwise_quantize(mat_a.to(torch.float32), lhs_block, 128)
    mat_b_lp_rows, encode_b = _blockwise_quantize(mat_b_rows.to(torch.float32), rhs_block, 128)
    mat_b_lp = mat_b_lp_rows.t().contiguous()

    # scaled_mm consumes decode scales (reciprocal of encode); B's scales are
    # transposed to match the transposed operand.
    scale_a = encode_a.reciprocal().contiguous()
    scale_b = encode_b.reciprocal().t().contiguous()

    # Block size 1 means one scale per 1x128 strip; 128 means 128x128 tiles.
    recipe_map = {
        1: ScalingType.BlockWise1x128,
        128: ScalingType.BlockWise128x128,
    }

    # Eager reference first; skip when the recipe is unsupported here.
    try:
        expected = torch.nn.functional.scaled_mm(
            mat_a_lp,
            mat_b_lp,
            scale_a,
            recipe_map[lhs_block],
            scale_b,
            recipe_map[rhs_block],
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=output_dtype,
        )
    except (RuntimeError, NotImplementedError, ValueError) as exc:
        pytest.skip(str(exc))

    fn = thunder.jit(
        lambda a, b, sa, sb: torch.nn.functional.scaled_mm(
            a,
            b,
            sa,
            recipe_map[lhs_block],
            sb,
            recipe_map[rhs_block],
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=output_dtype,
        )
    )
    thunder_out = fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
    assert_close(thunder_out, expected)
720+
721+
422722
# https://github.com/Lightning-AI/lightning-thunder/issues/1857
423723
def test_max_with_int():
424724
def f(x, ids):

0 commit comments

Comments
 (0)