5 files changed (+18, -94 lines)

test/inductor/test_max_autotune.py

Lines changed: 1 addition & 21 deletions
@@ -19,7 +19,7 @@
 from torch._dynamo import reset
 from torch._dynamo.exc import BackendCompilerFailed
 from torch._dynamo.testing import rand_strided, reset_rng_state
-from torch._dynamo.utils import counters, same
+from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.autotune_process import (
     _TestBenchmarkRequest,
@@ -1682,26 +1682,6 @@ def mm(x, y):
         out, code = run_and_get_code(compiled_f, a, b)
         torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2)
 
-    @config.patch(
-        max_autotune_gemm=True,
-        max_autotune_prune_choices_based_on_shared_mem=True,
-    )
-    def test_max_autotune_prune_choices(self):
-        def mm(x, y):
-            return x @ y
-
-        M, K, N = (3, 3, 3)
-
-        x = torch.rand([M, K], device=GPU_TYPE, dtype=torch.float32)
-        y = torch.rand([K, N], device=GPU_TYPE, dtype=torch.float32)
-
-        compiled_f = torch.compile(mm)
-        compiled_f(x, y)
-
-        self.assertEqual(
-            counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0
-        )
-
 
 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):
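The deleted test compiled a tiny matmul under max-autotune with the now-removed flag enabled and asserted that the precompilation-exception counter stayed at zero. For context, a minimal sketch of that test pattern, assuming a CUDA-capable GPU and using `torch._dynamo.utils.counters` (the counter key shown is the one the deleted test checked; after this change nothing increments it anymore):

import torch
from torch._dynamo.utils import counters
from torch._inductor import config


def check_counter_after_compile():
    # Toggle an inductor config flag just for this block; max_autotune_gemm is a real flag.
    with config.patch(max_autotune_gemm=True):
        counters.clear()

        def mm(x, y):
            return x @ y

        x = torch.rand(3, 3, device="cuda")
        y = torch.rand(3, 3, device="cuda")
        torch.compile(mm)(x, y)

    # counters["inductor"] is a collections.Counter, so a missing key reads as 0.
    return counters["inductor"]["select_algorithm_num_precompilation_exceptions"]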

test/inductor/test_triton_heuristics.py

Lines changed: 1 addition & 30 deletions
@@ -9,13 +9,7 @@
 from torch._dynamo.testing import rand_strided
 from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
 from torch._inductor.utils import clone_preserve_strides
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    IS_LINUX,
-    parametrize,
-    runOnRocm,
-    skipIfXpu,
-)
+from torch.testing._internal.common_utils import IS_LINUX, runOnRocm, skipIfXpu
 from torch.testing._internal.inductor_utils import (
     GPU_TYPE,
     HAS_GPU,
@@ -73,7 +67,6 @@ def get_autotuned_amd_sqr_kernel():
     )(amd_sqr_kernel)
 
 
-@instantiate_parametrized_tests
 class TestTritonHeuristics(TestCase):
     device_type = GPU_TYPE
 
@@ -269,28 +262,6 @@ def fn(x):
         res = torch.compile(fn)(x)
         self.assertEqual(ref, res)
 
-    @parametrize("do_pruning", [False, True])
-    def test_prune_configs_over_shared_memory_limit(self, do_pruning):
-        from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
-
-        expected_count = 1 if do_pruning else 2
-        mm_configs = [
-            GemmConfig(32, 32, 32, 1, 8, 8),
-            GemmConfig(
-                128, 128, 128, 100, 8, 4
-            ),  # intentionally large to exceed shared memory limit
-        ]
-        with config.patch(
-            {"max_autotune_prune_choices_based_on_shared_mem": do_pruning}
-        ):
-            config_heuristic = CUDAConfigHeuristic()
-            config_heuristic.should_scale_configs = False
-            config_heuristic.mm_configs = mm_configs
-            configs = list(
-                config_heuristic.get_mm_configs()(3, 3, 3, dtype_size=4, op_name="mm")
-            )
-            self.assertEqual(len(configs), expected_count)
-
 
 class TestArgumentCloneAndRestore(TestCase):
     # Our tensor is large enough. If a unexpected copy happens, the
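The deleted test was the only parametrized test in this file, which is why the `parametrize`/`instantiate_parametrized_tests` imports and the class decorator go away too. For reference, a minimal sketch of how those helpers fit together (the class and test names below are illustrative only):

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@instantiate_parametrized_tests
class ExampleParametrizedTests(TestCase):
    @parametrize("do_pruning", [False, True])
    def test_example(self, do_pruning):
        # The class decorator expands this into one concrete test per value,
        # e.g. test_example_do_pruning_False and test_example_do_pruning_True.
        self.assertIn(do_pruning, (False, True))


if __name__ == "__main__":
    run_tests()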

torch/_inductor/config.py

Lines changed: 0 additions & 6 deletions
@@ -448,12 +448,6 @@ def prologue_fusion_enabled() -> bool:
     os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1"
 )
 
-# Prune configs that require more shared memory than the hardware limit
-max_autotune_prune_choices_based_on_shared_mem = (
-    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM", "1")
-    == "1"
-)
-
 # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
 graph_partition: bool = (
     os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")

torch/_inductor/select_algorithm.py

Lines changed: 0 additions & 3 deletions
@@ -2760,9 +2760,6 @@ def wait_on_futures():
                     timeout=precompilation_timeout_seconds,
                 ):
                     if e := future.exception():
-                        counters["inductor"][
-                            "select_algorithm_num_precompilation_exceptions"
-                        ] += 1
                         exceptions.append((futures[future], e))
                         from torch._inductor.codegen.cuda.cuda_kernel import (
                             CUDATemplateCaller,
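The removed increment was the only writer of that counter key. `torch._dynamo.utils.counters` is a `defaultdict` of `collections.Counter`, so dropping the writer is safe for readers: the key simply reads back as zero. A small sketch of those semantics:

from torch._dynamo.utils import counters

# Incrementing needs no initialization; namespaces and keys spring into existence.
counters["inductor"]["hypothetical_event"] += 1
assert counters["inductor"]["hypothetical_event"] == 1

# A key that is never written, such as the one removed here, reads as 0.
assert counters["inductor"]["select_algorithm_num_precompilation_exceptions"] == 0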

torch/_inductor/template_heuristics.py

Lines changed: 16 additions & 34 deletions
@@ -540,43 +540,34 @@ def _scale_mm_configs(
 
         return scaled_configs
 
-    def _exceed_available_shared_memeory(
-        self, gemm_config: BaseConfig, dtype_size: int
-    ) -> bool:
-        try:
-            if dtype_size <= 0:
-                return False
-
-            device = torch.cuda.current_device()
-            props = torch.cuda.get_device_properties(device)
-            if not hasattr(props, "shared_memory_per_block_optin"):
-                return False
-            sm_available = props.shared_memory_per_block_optin  # type: ignore[attr-defined]
-            shared_mem_accum = dtype_size * (
-                gemm_config.block_m * gemm_config.block_k
-                + gemm_config.block_n * gemm_config.block_k
-            )
-            return shared_mem_accum * gemm_config.num_stages > sm_available
-        except Exception:
-            return False
-
     def _prune_exhaustive_configs(
         self,
         configs: list[BaseConfig],
         dtype_size: int,
     ) -> list[BaseConfig]:
+        import torch
+
         pruned_configs = []
         for gemm_config in configs:
-            # Will use more shared memory than available
-            if self._exceed_available_shared_memeory(gemm_config, dtype_size):
-                continue
-
+            device = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device)
+            sm_available = props.shared_memory_per_block_optin  # type: ignore[attr-defined]
             NUM_REG = 255
+
             acc_regs = math.ceil(
                 gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32)
            )
+
+            shared_mem_accum = dtype_size * (
+                gemm_config.block_m * gemm_config.block_k
+                + gemm_config.block_n * gemm_config.block_k
+            )
+
+            # Will use more shared memory than available
+            if shared_mem_accum * gemm_config.num_stages > sm_available:
+                continue
             # Lower bound for register spillage, if exceeds the kernel will certainly spill
-            if acc_regs > NUM_REG:
+            elif acc_regs > NUM_REG:
                 continue
 
             pruned_configs.append(gemm_config)
@@ -608,15 +599,6 @@ def preprocess_mm_configs(
         scaled_configs = self._scale_mm_configs(
             m, n, k, configs, scale, has_int8_tensor, exclude
         )
-
-        # Filter out configs that require more shared memory than is available.
-        if dtype_size > 0 and config.max_autotune_prune_choices_based_on_shared_mem:
-            scaled_configs = [
-                c
-                for c in scaled_configs
-                if not self._exceed_available_shared_memeory(c, dtype_size)
-            ]
-
         if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
             assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
             scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)
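With the helper gone, the shared-memory check lives only on the exhaustive-search path above. The arithmetic it performs is simple: per pipeline stage, the A tile (`block_m x block_k`) and B tile (`block_n x block_k`) are kept in shared memory, and the total must fit in the device's per-block opt-in limit. A standalone sketch of that check, assuming a CUDA device is available:

import torch


def tile_exceeds_shared_memory(
    block_m: int, block_n: int, block_k: int, num_stages: int, dtype_size: int
) -> bool:
    """Mirror of the inlined check above: estimated shared-memory footprint of
    one GEMM config versus the per-block opt-in limit of the current GPU."""
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    sm_available = props.shared_memory_per_block_optin
    shared_mem_per_stage = dtype_size * (block_m * block_k + block_n * block_k)
    return shared_mem_per_stage * num_stages > sm_available


# The oversized config from the deleted test (128x128x128 tiles, 100 stages, fp32)
# needs 100 * 4 * (128*128 + 128*128) bytes, roughly 13 MB of shared memory, far
# above the ~100-228 KB opt-in limits of current NVIDIA GPUs, so it would be pruned.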
