Skip to content

Commit e14466c

Browse files
authored
Add option to ban/disallow autotuning (#184)
1 parent e6adcac commit e14466c

File tree

4 files changed

+46
-2
lines changed

4 files changed

+46
-2
lines changed

helion/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,7 @@ class TensorOperationsInHostCall(TensorOperationInWrapper):
341341

342342
class WrongDevice(BaseWarning):
343343
message = "Operation {0} returned a tensor on {1} device, but the kernel is on {2} device. "
344+
345+
346+
class AutotuningDisallowedInEnvironment(BaseWarning):
347+
message = "Autotuning is disabled {0}, please provide a config to @helion.kernel via the config= argument."

helion/runtime/kernel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ def autotune(
410410

411411
config = FiniteSearch(self, args, self.configs).autotune()
412412
else:
413+
self.settings.check_autotuning_disabled()
414+
413415
from ..autotuner import DifferentialEvolutionSearch
414416

415417
config = DifferentialEvolutionSearch(

helion/runtime/settings.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from typing import cast
1212

1313
import torch
14+
from torch._environment import is_fbcode
15+
16+
from helion import exc
1417

1518
if TYPE_CHECKING:
1619
from contextlib import AbstractContextManager
1720

18-
from helion import exc
19-
2021
class _TLS(Protocol):
2122
default_settings: Settings | None
2223

@@ -110,6 +111,20 @@ def shallow_copy(x: object) -> object:
110111

111112
return {k: shallow_copy(v) for k, v in dataclasses.asdict(self).items()}
112113

114+
def check_autotuning_disabled(self) -> None:
115+
msg = None
116+
if os.environ.get("HELION_DISALLOW_AUTOTUNING", "0") == "1":
117+
msg = "by HELION_DISALLOW_AUTOTUNING=1"
118+
if is_fbcode():
119+
from aiplatform.runtime_environment.runtime_environment_pybind import ( # pyre-fixme[21]
120+
RuntimeEnvironment,
121+
)
122+
123+
if RuntimeEnvironment().get_mast_job_name() is not None:
124+
msg = "because autotuning is not allowed in MAST environment"
125+
if msg:
126+
raise exc.AutotuningDisallowedInEnvironment(msg)
127+
113128
@staticmethod
114129
def default() -> Settings:
115130
"""

test/test_autotuner.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

3+
import os
34
from pathlib import Path
45
import random
56
import tempfile
67
import unittest
78
from unittest.mock import patch
89

910
from expecttest import TestCase
11+
import pytest
1012
import torch
1113

1214
import helion
@@ -200,6 +202,27 @@ def add(a, b):
200202
result = add(*args)
201203
torch.testing.assert_close(result, sum(args))
202204

205+
def test_autotuner_disabled(self):
206+
@helion.kernel
207+
def add(a, b):
208+
out = torch.empty_like(a)
209+
for tile in hl.tile(out.size()):
210+
out[tile] = a[tile] + b[tile]
211+
return out
212+
213+
args = (
214+
torch.randn([8, 512, 512], device=DEVICE),
215+
torch.randn([8, 512, 512], device=DEVICE),
216+
)
217+
with (
218+
patch.dict(os.environ, {"HELION_DISALLOW_AUTOTUNING": "1"}),
219+
pytest.raises(
220+
expected_exception=helion.exc.AutotuningDisallowedInEnvironment,
221+
match="Autotuning is disabled by HELION_DISALLOW_AUTOTUNING=1, please provide a config to @helion.kernel via the config= argument.",
222+
),
223+
):
224+
add(*args)
225+
203226

204227
if __name__ == "__main__":
205228
unittest.main()

0 commit comments

Comments
 (0)