
Commit 9af74e1

jananisriram authored and facebook-github-bot committed
[Inductor][FP8] Validate exhaustive autotuning for FP8 Inductor templates (pytorch#161442)
Summary:
X-link: meta-pytorch/tritonbench#355

Validate exhaustive autotuning for FP8 Inductor templates: scaled MM templates require `block_k >= 32`. Before, exhaustive autotuning defaulted to a limited set of autotuning configs, as the limitations of exhaustive autotuning on FP8 shapes had not been tested.

Test Plan:
```
CUDA_VISIBLE_DEVICES=0 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=DEFAULT buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --input-loader=/home/jananisriram/personal/exhaustive_autotune_rowwise_persistent_tma/json_files/rowwise_ptma_0.json --output="/home/jananisriram/personal/exhaustive_autotune_rowwise_persistent_tma/autotune/gpu0_bench.csv" --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/exhaustive_autotune_rowwise_persistent_tma/autotune/gpu0.log
```
This autotunes on the maximum configs available, rather than the defaults, and skips configs not compatible with TMA.

Rollback Plan:

Reviewed By: coconutruben

Differential Revision: D80958642
1 parent 68d395d commit 9af74e1
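For readers who want to exercise the same code path outside of tritonbench, here is a minimal sketch (not part of this commit) of driving the FP8 scaled-MM Inductor templates through torch.compile with max-autotune. The shapes, scale values, and config knobs are illustrative assumptions, and an FP8-capable GPU (e.g. H100) is assumed.

```
# Minimal sketch (assumption, not from this PR): exercising the FP8 scaled-MM
# autotuning path that this change affects. Requires an FP8-capable GPU.
import torch
import torch._inductor.config as inductor_config

# Ask Inductor to autotune GEMMs; "EXHAUSTIVE" selects the larger search space
# whose FP8 support this commit validates ("DEFAULT" is the smaller set).
inductor_config.max_autotune_gemm = True
inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"

def fp8_gemm(a, b, scale_a, scale_b):
    # torch._scaled_mm lowers to the scaled-MM Inductor templates touched here.
    return torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)

compiled = torch.compile(fp8_gemm, mode="max-autotune")

m, k, n = 1024, 512, 2048
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # tensor-wise scales
scale_b = torch.tensor(1.0, device="cuda")
out = compiled(a, b, scale_a, scale_b)
```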

File tree

1 file changed: +14 -18 lines changed


torch/_inductor/template_heuristics.py

Lines changed: 14 additions & 18 deletions
```
@@ -1293,9 +1293,9 @@ def get_template_configs(
         Convert config lists to template kwargs.
         This replaces the logic from choices.get_mm_configs and inlines mm_options.
         """
-        assert isinstance(kernel_inputs, MMKernelInputs), (
-            f"{self.__class__.__name__} requires MMKernelInputs"
-        )
+        assert isinstance(
+            kernel_inputs, MMKernelInputs
+        ), f"{self.__class__.__name__} requires MMKernelInputs"
         input_nodes = kernel_inputs.nodes()
         if len(input_nodes) < 2:
             raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")
@@ -1467,9 +1467,9 @@ def get_template_configs(
         input_nodes = kernel_inputs.nodes()

         # Initial assertion from mm_common.scaled_mm_options
-        assert len(input_nodes) >= 4, (
-            f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}"
-        )
+        assert (
+            len(input_nodes) >= 4
+        ), f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}"

         # Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3])
         scale_a = input_nodes[2]
@@ -1522,9 +1522,11 @@ class ScaledTMAConfigMixin(ScaledMMConfigMixin):

     def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
         """
-        TMA specific filtering, as num_warps=2 not safe for TMA
+        TMA specific filtering:
+        - num_warps=2 not safe for TMA
+        - block_k >= 32 required for TMA (requires inner-most dimension >= 32)
         """
-        configs = [c for c in configs if c.num_warps != 2]
+        configs = [c for c in configs if c.num_warps != 2 and c.block_k >= 32]
         return super()._filter_configs(configs)

     def get_template_configs(
@@ -1596,11 +1598,10 @@ def __init__(self) -> None:
         super().__init__()
         # Override mm_configs to use scaled_mm_configs
         self.mm_configs = self.scaled_mm_configs
-        # NOTE: overriding exhaustive configs here to be the same as mm_configs
-        # as we haven't validated exhaustive support here yet
-        # TODO(coconutruben): remove this once we have validated exhaustive support
-        # for scaled_mm
-        self.exhaustive_configs = self.scaled_mm_configs
+
+    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
+        configs = [c for c in configs if c.block_k >= 32]
+        return super()._filter_configs(configs)


 # TODO(coconutruben): replace with template.name once templates are importable
@@ -1614,11 +1615,6 @@ def __init__(self) -> None:
         super().__init__()
         # Override mm_configs to use scaled_persistent_mm_configs for TMA
         self.mm_configs = self.scaled_persistent_mm_configs
-        # NOTE: overriding exhaustive configs here to be the same as mm_configs
-        # as we haven't validated exhaustive support here yet
-        # TODO(coconutruben): remove this once we have validated exhaustive support
-        # for scaled_mm
-        self.exhaustive_configs = self.scaled_persistent_mm_configs


 # TODO(coconutruben): replace with template.name once templates are importable
```
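To make the filtering logic in this diff easier to follow, here is a self-contained sketch of how config filters compose through `super()._filter_configs`. The class and config names below are illustrative stand-ins, not the actual Inductor hierarchy: `GemmConfig` plays the role of `BaseConfig`, and the two subclasses mimic the scaled-MM (`block_k >= 32`) and TMA (`num_warps != 2`) constraints from the change.

```
# Illustrative stand-ins only; not the real Inductor classes.
from dataclasses import dataclass

@dataclass
class GemmConfig:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int

class BaseHeuristic:
    def _filter_configs(self, configs: list[GemmConfig]) -> list[GemmConfig]:
        # Parent-level filtering (a no-op in this sketch).
        return configs

class ScaledMMHeuristic(BaseHeuristic):
    def _filter_configs(self, configs: list[GemmConfig]) -> list[GemmConfig]:
        # Scaled-MM templates require block_k >= 32 (the constraint this commit validates).
        configs = [c for c in configs if c.block_k >= 32]
        return super()._filter_configs(configs)

class ScaledTMAHeuristic(ScaledMMHeuristic):
    def _filter_configs(self, configs: list[GemmConfig]) -> list[GemmConfig]:
        # TMA additionally rejects num_warps=2, which is not safe for TMA.
        configs = [c for c in configs if c.num_warps != 2]
        return super()._filter_configs(configs)

if __name__ == "__main__":
    candidates = [
        GemmConfig(64, 64, 16, 3, 4),    # dropped: block_k < 32
        GemmConfig(128, 64, 32, 4, 2),   # dropped by the TMA heuristic: num_warps == 2
        GemmConfig(128, 128, 64, 4, 8),  # kept
    ]
    print(ScaledTMAHeuristic()._filter_configs(candidates))
```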
