
Commit f192dc8

NikhilAPatel authored and facebook-github-bot committed
[Inductor][Triton] Fix SCALING_ROWWISE misclassification for scalar scales (pytorch#160450)
Summary:
Pull Request resolved: pytorch#160450

In `tuned_scaled_mm()`, we unsqueeze any scalar scale from [] -> [1, 1]. Later, when determining how to set the `SCALING_ROWWISE` kernel attribute, we check whether the scale has 2 dimensions. However, since we previously unsqueezed any scalar scales, this check always evaluates to True.

Test Plan:
`buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- test_tensorwise_scaling`
`buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- test_rowwise_scaling`
`buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- test_tensorwise_scaling_acceptable_input_dims`

Rollback Plan:

Reviewed By: eellison, PaulZhang12

Differential Revision: D80108117
1 parent adecb0c commit f192dc8
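
A minimal standalone sketch of the misclassification described in the summary (illustrative only; old_rule and new_rule are hypothetical stand-ins for the heuristic, not the Inductor code path):

import torch

# A tensorwise (scalar) scale with shape [] is unsqueezed to [1, 1] before the
# heuristic runs, so a rank-based check can no longer tell it apart from a
# rowwise scale.
scale_tensorwise = torch.tensor(1.0).reshape(1, 1)  # scalar scale after unsqueeze
scale_rowwise = torch.ones(1024, 1)                 # per-row scale for M = 1024


def old_rule(scale: torch.Tensor) -> bool:
    # What the heuristic effectively checked: "is the scale 2-D?"
    return scale.dim() == 2


def new_rule(scale: torch.Tensor) -> bool:
    # Scalar-like scales (every dim == 1) should not be treated as rowwise.
    return not all(d == 1 for d in scale.shape)


print(old_rule(scale_tensorwise))  # True  -> misclassified as rowwise
print(new_rule(scale_tensorwise))  # False -> correctly tensorwise
print(new_rule(scale_rowwise))     # True  -> still rowwise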

2 files changed: +165 additions, −2 deletions


test/inductor/test_fp8.py

Lines changed: 157 additions & 0 deletions
@@ -8,6 +8,7 @@
 from torch import Tensor
 from torch._inductor import config, utils
 from torch._inductor.test_case import run_tests, TestCase
+from torch._inductor.utils import run_and_get_code
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MX_GEMM,
@@ -24,6 +25,7 @@
     HAS_CPU,
     HAS_CUDA_AND_TRITON,
 )
+from torch.testing._internal.jit_utils import FileCheck
 from torch.utils._triton import has_triton_tma_device
 
 
@@ -465,6 +467,86 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
         # setting a small absolute tolerance in these tests
         torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("dtype", (torch.bfloat16, torch.float32))
+    @parametrize("shape", ("16,32,32", "1024,1024,512"))
+    @parametrize("use_fast_accum", (False, True))
+    def test_tensorwise_scaling_tma_template(
+        self,
+        dtype: torch.dtype,
+        shape: str,
+        use_fast_accum: bool,
+    ):
+        device = "cuda"
+        dtype_float8 = torch.float8_e4m3fn
+        dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device)
+
+        shape = [int(dim) for dim in shape.split(",")]
+        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
+        # input and output dtypes of _scaled_mm do not need to be the same, but
+        # typically in a model they are
+        x = torch.randn(M, K, dtype=dtype, device=device)
+        w = torch.randn(N, K, dtype=dtype, device=device)
+        bias = None
+
+        # quantize weight (prior to inference)
+        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
+        w_t_fp8 = w_fp8.t()
+
+        # quantize input x
+        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
+
+        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
+            y = torch._scaled_mm(
+                x_fp8,
+                w_t_fp8,
+                x_inverse_scale,
+                w_inverse_scale,
+                bias,
+                out_dtype=dtype,
+                use_fast_accum=use_fast_accum,
+            )
+            return y
+
+        y_eager = linear(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        with config.patch(
+            {
+                "triton.enable_persistent_tma_matmul": True,
+                "test_configs.autotune_choice_name_regex": "triton_scaled_mm_device_tma",
+                "max_autotune_gemm_backends": "TRITON",
+                "max_autotune": True,
+            }
+        ):
+            linear_compiled = torch.compile(
+                linear, backend="inductor", mode="max-autotune"
+            )
+            y_compiled, code = run_and_get_code(
+                linear_compiled,
+                x_fp8,
+                x_inverse_scale,
+                w_t_fp8,
+                w_inverse_scale,
+                bias,
+            )
+
+        FileCheck().check("SCALING_ROWWISE : tl.constexpr = False").run(code[0])
+        self.assertEqual(y_eager.dtype, dtype)
+        self.assertEqual(y_compiled.dtype, dtype)
+        # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
+        # autotuning for the compiled case, the results can be different because of
+        # the way blocks of results are accumulated (float addition not associative), so
+        # setting a small absolute tolerance in these tests
+        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
@@ -531,6 +613,81 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
         self.assertEqual(y_compiled.dtype, dtype)
         torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("shape", ("16,32,32", "1024,1024,512"))
+    @parametrize("use_fast_accum", (False, True))
+    def test_rowwise_scaling_tma_template(
+        self,
+        shape: str,
+        use_fast_accum: bool,
+    ):
+        # Only bf16 output type is supported for row-wise scaling, not fp32
+        dtype: torch.dtype = torch.bfloat16
+        device = "cuda"
+        dtype_float8 = torch.float8_e4m3fn
+        dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device)
+
+        shape = [int(dim) for dim in shape.split(",")]
+        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
+        x = torch.randn(M, K, dtype=dtype, device=device)
+        w = torch.randn(N, K, dtype=dtype, device=device)
+        bias = None
+
+        # quantize weight (prior to inference)
+        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
+        w_t_fp8 = w_fp8.t()
+        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
+
+        # quantize input x
+        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
+
+        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
+            y = torch._scaled_mm(
+                x_fp8,
+                w_t_fp8,
+                x_inverse_scale,
+                w_inverse_scale,
+                bias,
+                out_dtype=dtype,
+                use_fast_accum=use_fast_accum,
+            )
+            return y
+
+        y_eager = linear(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        with config.patch(
+            {
+                "triton.enable_persistent_tma_matmul": True,
+                "test_configs.autotune_choice_name_regex": "triton_scaled_mm_device_tma",
+                "max_autotune_gemm_backends": "TRITON",
+                "max_autotune": True,
+            }
+        ):
+            linear_compiled = torch.compile(
+                linear, backend="inductor", mode="max-autotune"
+            )
+            y_compiled, code = run_and_get_code(
+                linear_compiled,
+                x_fp8,
+                x_inverse_scale,
+                w_t_fp8,
+                w_inverse_scale,
+                bias,
+            )
+
+        FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0])
+        self.assertEqual(y_eager.dtype, dtype)
+        self.assertEqual(y_compiled.dtype, dtype)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
     @parametrize("K", (16, 32, 1024))

torch/_inductor/template_heuristics.py

Lines changed: 8 additions & 2 deletions
@@ -1487,6 +1487,11 @@ def are_compatible_scales(size_a: Any, size_b: Any) -> bool:
 
             return False
 
+        def is_scalar_like(sz: Any) -> bool:
+            return (len(sz) == 0) or all(
+                V.graph.sizevars.statically_known_equals(d, 1) for d in sz
+            )
+
         size_a, size_b = scale_a.get_size(), scale_b.get_size()
         assert are_compatible_scales(size_a, size_b), (
             "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
@@ -1500,8 +1505,9 @@ def are_compatible_scales(size_a: Any, size_b: Any) -> bool:
         # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
         # Override accumulator type for scaled MM
         template_kwargs["ACC_TYPE"] = "tl.float32"
-        # Add SCALING_ROWWISE attribute based on scale_a tensor shape
-        template_kwargs["SCALING_ROWWISE"] = len(size_a) == 2
+        # Add SCALING_ROWWISE attribute based on scale tensor shapes
+        both_scalar_like = is_scalar_like(size_a) and is_scalar_like(size_b)
+        template_kwargs["SCALING_ROWWISE"] = not both_scalar_like
 
         yield template_kwargs
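
For reference, a rough standalone approximation of the new check over plain integer shapes (the real code resolves symbolic sizes via `V.graph.sizevars.statically_known_equals`, which this sketch replaces with a direct comparison):

def is_scalar_like(sz) -> bool:
    # A scale counts as scalar-like when it is 0-d or every dimension is 1.
    return len(sz) == 0 or all(d == 1 for d in sz)


# SCALING_ROWWISE is enabled only when at least one scale is not scalar-like.
for size_a, size_b in [((), ()), ((1, 1), (1, 1)), ((1024, 1), (1, 512))]:
    scaling_rowwise = not (is_scalar_like(size_a) and is_scalar_like(size_b))
    print(size_a, size_b, "-> SCALING_ROWWISE =", scaling_rowwise)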
