Commit 6c7e78d

jananisriram authored and facebook-github-bot committed
[Inductor][FP8] Validate exhaustive autotuning for FP8 Inductor templates (pytorch#161442)
Summary:
X-link: meta-pytorch/tritonbench#355

Validate exhaustive autotuning for FP8 Inductor templates: scaled MM templates require `block_k >= 32`.

Before, exhaustive autotuning defaulted to a limited set of autotuning configs, because the limitations of exhaustively autotuning FP8 shapes had not been tested.

Test Plan:
```
CUDA_VISIBLE_DEVICES=0 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=DEFAULT buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --input-loader=/home/jananisriram/personal/exhaustive_autotune_rowwise_persistent_tma/json_files/rowwise_ptma_0.json --output="/home/jananisriram/personal/exhaustive_autotune_rowwise_persistent_tma/autotune/gpu0_bench.csv" --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/exhaustive_autotune_rowwise_persistent_tma/autotune/gpu0.log
```
This autotunes on the maximum configs available, rather than the defaults, and skips configs not compatible with TMA.

Rollback Plan:

Reviewed By: coconutruben

Differential Revision: D80958642
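For context (not part of the commit), here is a minimal sketch of a user-level FP8 GEMM that would be lowered to these scaled-MM Inductor templates with exhaustive autotuning enabled. It assumes a CUDA device with FP8 support (e.g. H100) and a recent PyTorch build; the shapes, rowwise scale layout, and the `fp8_gemm` wrapper are illustrative only.

```python
import torch
import torch._inductor.config as inductor_config

# Mirror the TORCHINDUCTOR_MAX_AUTOTUNE_GEMM* env vars from the test plan,
# but ask for the exhaustive GEMM search space that this change validates.
inductor_config.max_autotune_gemm = True
inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"

M, K, N = 1024, 512, 2048  # arbitrary example shapes
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major mat2
scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)    # rowwise scales
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)

@torch.compile(mode="max-autotune")
def fp8_gemm(a, b, scale_a, scale_b):
    # torch._scaled_mm is the scaled-MM op whose Inductor template configs
    # are being filtered/autotuned by this change.
    return torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

out = fp8_gemm(a, b, scale_a, scale_b)
```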
1 parent 624bc36 commit 6c7e78d

File tree

1 file changed: +14 −18 lines


torch/_inductor/template_heuristics.py

Lines changed: 14 additions & 18 deletions
```diff
@@ -1300,9 +1300,9 @@ def get_template_configs(
         Convert config lists to template kwargs.
         This replaces the logic from choices.get_mm_configs and inlines mm_options.
         """
-        assert isinstance(kernel_inputs, MMKernelInputs), (
-            f"{self.__class__.__name__} requires MMKernelInputs"
-        )
+        assert isinstance(
+            kernel_inputs, MMKernelInputs
+        ), f"{self.__class__.__name__} requires MMKernelInputs"
         input_nodes = kernel_inputs.nodes()
         if len(input_nodes) < 2:
             raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")
@@ -1474,9 +1474,9 @@ def get_template_configs(
         input_nodes = kernel_inputs.nodes()
 
         # Initial assertion from mm_common.scaled_mm_options
-        assert len(input_nodes) >= 4, (
-            f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}"
-        )
+        assert (
+            len(input_nodes) >= 4
+        ), f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}"
 
         # Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3])
         scale_a = input_nodes[2]
@@ -1529,9 +1529,11 @@ class ScaledTMAConfigMixin(ScaledMMConfigMixin):
 
     def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
         """
-        TMA specific filtering, as num_warps=2 not safe for TMA
+        TMA specific filtering:
+        - num_warps=2 not safe for TMA
+        - block_k >= 32 required for TMA (requires inner-most dimension >= 32)
         """
-        configs = [c for c in configs if c.num_warps != 2]
+        configs = [c for c in configs if c.num_warps != 2 and c.block_k >= 32]
         return super()._filter_configs(configs)
 
     def get_template_configs(
@@ -1603,11 +1605,10 @@ def __init__(self) -> None:
         super().__init__()
         # Override mm_configs to use scaled_mm_configs
         self.mm_configs = self.scaled_mm_configs
-        # NOTE: overriding exhaustive configs here to be the same as mm_configs
-        # as we haven't validated exhaustive support here yet
-        # TODO(coconutruben): remove this once we have validated exhaustive support
-        # for scaled_mm
-        self.exhaustive_configs = self.scaled_mm_configs
+
+    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
+        configs = [c for c in configs if c.block_k >= 32]
+        return super()._filter_configs(configs)
 
 
 # TODO(coconutruben): replace with template.name once templates are importable
@@ -1621,11 +1622,6 @@ def __init__(self) -> None:
         super().__init__()
         # Override mm_configs to use scaled_persistent_mm_configs for TMA
        self.mm_configs = self.scaled_persistent_mm_configs
-        # NOTE: overriding exhaustive configs here to be the same as mm_configs
-        # as we haven't validated exhaustive support here yet
-        # TODO(coconutruben): remove this once we have validated exhaustive support
-        # for scaled_mm
-        self.exhaustive_configs = self.scaled_persistent_mm_configs
 
 
 # TODO(coconutruben): replace with template.name once templates are importable
```
