Skip to content

Commit 1c03c46

Browse files
authored
Throw if all autotune configs are pruned (#8059)
1 parent a25e06d commit 1c03c46

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

python/test/unit/runtime/test_autotuner.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,32 @@ def grid(meta):
448448
warp_size = triton.runtime.driver.active.get_current_target().warp_size
449449
assert exception_out_of_resource is not None and f"out of resource: threads, Required: {128 * warp_size}" in str(
450450
exception_out_of_resource)
451+
452+
453+
def test_prune_all_configs(device):
454+
N = 1024
455+
src = torch.randn(N, device=device)
456+
dst = torch.empty(N, device=device)
457+
458+
def early_config_prune(configs, named_args, **kwargs):
459+
return []
460+
461+
configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
462+
463+
prune_configs_by = {'early_config_prune': early_config_prune}
464+
465+
@triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by)
466+
@triton.jit
467+
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
468+
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
469+
x = tl.load(src + offsets, mask=offsets < N)
470+
tl.store(dst + offsets, x, mask=offsets < N)
471+
472+
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
473+
try:
474+
_kernel[grid](dst, src, N=N)
475+
pytest.fail("Expected exception was not thrown.")
476+
except triton.TritonError as e:
477+
assert e is not None and str(
478+
e
479+
) == "Autotuner error: No valid autotuner configs after pruning. `early_config_prune` should return at least one config."

python/triton/runtime/autotuner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .. import knobs
1212
from .jit import KernelInterface, JITFunction
13-
from .errors import OutOfResources, PTXASError
13+
from .errors import OutOfResources, PTXASError, AutotunerError
1414
from .driver import driver
1515
from .cache import get_cache_manager, triton_key
1616
from triton._C.libtriton import get_cache_invalidating_env_vars
@@ -25,7 +25,9 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr
2525
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
2626
'perf_model': performance model used to predict running time with different configs, returns running time
2727
'top_k': number of configs to bench
28-
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
28+
'early_config_prune': a function used to prune configs. It should have the signature
29+
`early_config_prune(configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Any) -> List[triton.Config]`
30+
and return pruned configs. It should return at least one config.
2931
"""
3032
if not configs:
3133
self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
@@ -259,6 +261,9 @@ def prune_configs(self, kwargs: Dict) -> List[Config]:
259261
pruned_configs = self.configs
260262
if self.early_config_prune:
261263
pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
264+
if not pruned_configs:
265+
raise AutotunerError(
266+
"No valid autotuner configs after pruning. `early_config_prune` should return at least one config.")
262267
if self.perf_model:
263268
top_k = self.configs_top_k
264269
if isinstance(top_k, float) and top_k <= 1.0:
@@ -406,7 +411,9 @@ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
406411
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
407412
'perf_model': performance model used to predict running time with different configs, returns running time
408413
'top_k': number of configs to bench
409-
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
414+
'early_config_prune': a function used to prune configs. It should have the signature
415+
`early_config_prune(configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Any) -> List[triton.Config]`
416+
and return pruned configs. It should return at least one config.
410417
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
411418
:type reset_to_zero: list[str]
412419
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.

python/triton/runtime/errors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,13 @@ def __init__(self, error_message: Optional[str] = None):
3434
def __str__(self) -> str:
3535
error_message = self.error_message or ""
3636
return f"PTXAS error: {error_message}"
37+
38+
39+
class AutotunerError(TritonError):
40+
41+
def __init__(self, error_message: Optional[str] = None):
42+
self.error_message = error_message
43+
44+
def __str__(self) -> str:
45+
error_message = self.error_message or ""
46+
return f"Autotuner error: {error_message}"

0 commit comments

Comments
 (0)