Skip to content

Commit 69c74b2

Browse files
authored
Reapply "[IMP][Launch Latency] misc improvements to compiler/kernel" and fix too aggressive optimization (#7851)
1 parent 37baa79 commit 69c74b2

File tree

3 files changed

+26
-50
lines changed

3 files changed

+26
-50
lines changed

python/test/unit/runtime/test_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ def test_is_lazy():
1010
from importlib import reload
1111
reload(sys.modules["triton.runtime.driver"])
1212
reload(sys.modules["triton.runtime"])
13-
mod = sys.modules[triton.runtime.driver.__module__]
14-
assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy"))
15-
assert triton.runtime.driver.active._obj is None
13+
assert triton.runtime.driver._active is None
14+
assert triton.runtime.driver._default is None
15+
assert isinstance(triton.runtime.driver.active, getattr(triton.backends.driver, "DriverBase"))
16+
assert isinstance(triton.runtime.driver.default, getattr(triton.backends.driver, "DriverBase"))
1617
utils = triton.runtime.driver.active.utils # noqa: F841
17-
assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))
1818

1919

2020
def test_kernel_in_thread(device):

python/triton/compiler/compiler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,13 +434,14 @@ def __init__(self, src, metadata_group, hash):
434434
# (e.g., checking amount of shared memory on current device)
435435
self.module = None
436436
self.function = None
437+
self._run = None
437438

438439
def _init_handles(self):
439440
if self.module is not None:
440441
return
441442
device = driver.active.get_current_device()
442443
# create launcher
443-
self.run = driver.active.launcher_cls(self.src, self.metadata)
444+
self._run = driver.active.launcher_cls(self.src, self.metadata)
444445
# not enough shared memory to run the kernel
445446
max_shared = max_shared_mem(device)
446447
if self.metadata.shared > max_shared:
@@ -461,10 +462,10 @@ def _init_handles(self):
461462
if knobs.runtime.kernel_load_end_hook is not None:
462463
knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
463464

464-
def __getattribute__(self, name):
465-
if name == 'run':
466-
self._init_handles()
467-
return super().__getattribute__(name)
465+
@property
466+
def run(self):
467+
self._init_handles()
468+
return self._run
468469

469470
def launch_metadata(self, grid, stream, *args):
470471
if knobs.runtime.launch_enter_hook is None:

python/triton/runtime/driver.py

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from ..backends import backends, DriverBase
44

5-
from typing import Any, Callable, Generic, TypeVar, Union
6-
75

86
def _create_driver() -> DriverBase:
97
active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
@@ -12,52 +10,29 @@ def _create_driver() -> DriverBase:
1210
return active_drivers[0]()
1311

1412

15-
T = TypeVar("T")
16-
17-
18-
class LazyProxy(Generic[T]):
19-
20-
def __init__(self, init_fn: Callable[[], T]) -> None:
21-
self._init_fn = init_fn
22-
self._obj: Union[T, None] = None
23-
24-
def _initialize_obj(self) -> T:
25-
if self._obj is None:
26-
self._obj = self._init_fn()
27-
return self._obj
28-
29-
def __getattr__(self, name) -> Any:
30-
return getattr(self._initialize_obj(), name)
31-
32-
def __setattr__(self, name: str, value: Any) -> None:
33-
if name in ["_init_fn", "_obj"]:
34-
super().__setattr__(name, value)
35-
else:
36-
setattr(self._initialize_obj(), name, value)
37-
38-
def __delattr__(self, name: str) -> None:
39-
delattr(self._initialize_obj(), name)
40-
41-
def __repr__(self) -> str:
42-
if self._obj is None:
43-
return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
44-
return repr(self._obj)
45-
46-
def __str__(self) -> str:
47-
return str(self._initialize_obj())
48-
49-
5013
class DriverConfig:
5114

5215
def __init__(self) -> None:
53-
self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver)
54-
self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default
16+
self._default: DriverBase | None = None
17+
self._active: DriverBase | None = None
18+
19+
@property
20+
def default(self) -> DriverBase:
21+
if self._default is None:
22+
self._default = _create_driver()
23+
return self._default
24+
25+
@property
26+
def active(self) -> DriverBase:
27+
if self._active is None:
28+
self._active = self.default
29+
return self._active
5530

5631
def set_active(self, driver: DriverBase) -> None:
57-
self.active = driver
32+
self._active = driver
5833

5934
def reset_active(self) -> None:
60-
self.active = self.default
35+
self._active = self.default
6136

6237

6338
driver = DriverConfig()

0 commit comments

Comments
 (0)