Skip to content

Commit 8b2ed6d

Browse files
authored
[IMP][Launch Latency] remove __getattr__ overhead from DriverConfig and CompiledKernel (#7770)
1 parent 58ae6f0 commit 8b2ed6d

File tree

3 files changed

+29
-49
lines changed

3 files changed

+29
-49
lines changed

python/test/unit/runtime/test_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ def test_is_lazy():
1010
from importlib import reload
1111
reload(sys.modules["triton.runtime.driver"])
1212
reload(sys.modules["triton.runtime"])
13-
mod = sys.modules[triton.runtime.driver.__module__]
14-
assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy"))
15-
assert triton.runtime.driver.active._obj is None
13+
assert triton.runtime.driver._active is None
14+
assert triton.runtime.driver._default is None
15+
assert isinstance(triton.runtime.driver.active, getattr(triton.backends.driver, "DriverBase"))
16+
assert isinstance(triton.runtime.driver.default, getattr(triton.backends.driver, "DriverBase"))
1617
utils = triton.runtime.driver.active.utils # noqa: F841
17-
assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))
1818

1919

2020
def test_kernel_in_thread(device):

python/triton/compiler/compiler.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,13 +437,14 @@ def __init__(self, src, metadata_group, hash):
437437
# (e.g., checking amount of shared memory on current device)
438438
self.module = None
439439
self.function = None
440+
self._run = None
440441

441442
def _init_handles(self):
442443
if self.module is not None:
443444
return
444445
device = driver.active.get_current_device()
445446
# create launcher
446-
self.run = driver.active.launcher_cls(self.src, self.metadata)
447+
self._run = driver.active.launcher_cls(self.src, self.metadata)
447448
# not enough shared memory to run the kernel
448449
max_shared = max_shared_mem(device)
449450
if self.metadata.shared > max_shared:
@@ -462,10 +463,14 @@ def _init_handles(self):
462463
if knobs.runtime.init_handle_hook is not None:
463464
knobs.runtime.init_handle_hook(self.module, self.function, self.name, self.metadata_group)
464465

465-
def __getattribute__(self, name):
466-
if name == 'run':
466+
@property
467+
def run(self):
468+
# it should be safe to do this as launch_metadata will
469+
# call _init_handles before running the kernel or it
470+
# was called manually or it was already initialized
471+
if self._run is None:
467472
self._init_handles()
468-
return super().__getattribute__(name)
473+
return self._run
469474

470475
def launch_metadata(self, grid, stream, *args):
471476
if knobs.runtime.launch_enter_hook is None:

python/triton/runtime/driver.py

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from ..backends import backends, DriverBase
44

5-
from typing import Any, Callable, Generic, TypeVar, Union
6-
75

86
def _create_driver() -> DriverBase:
97
active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
@@ -12,52 +10,29 @@ def _create_driver() -> DriverBase:
1210
return active_drivers[0]()
1311

1412

15-
T = TypeVar("T")
16-
17-
18-
class LazyProxy(Generic[T]):
19-
20-
def __init__(self, init_fn: Callable[[], T]) -> None:
21-
self._init_fn = init_fn
22-
self._obj: Union[T, None] = None
23-
24-
def _initialize_obj(self) -> T:
25-
if self._obj is None:
26-
self._obj = self._init_fn()
27-
return self._obj
28-
29-
def __getattr__(self, name) -> Any:
30-
return getattr(self._initialize_obj(), name)
31-
32-
def __setattr__(self, name: str, value: Any) -> None:
33-
if name in ["_init_fn", "_obj"]:
34-
super().__setattr__(name, value)
35-
else:
36-
setattr(self._initialize_obj(), name, value)
37-
38-
def __delattr__(self, name: str) -> None:
39-
delattr(self._initialize_obj(), name)
40-
41-
def __repr__(self) -> str:
42-
if self._obj is None:
43-
return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
44-
return repr(self._obj)
45-
46-
def __str__(self) -> str:
47-
return str(self._initialize_obj())
48-
49-
5013
class DriverConfig:
5114

5215
def __init__(self) -> None:
53-
self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver)
54-
self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default
16+
self._default: DriverBase | None = None
17+
self._active: DriverBase | None = None
18+
19+
@property
20+
def default(self) -> DriverBase:
21+
if self._default is None:
22+
self._default = _create_driver()
23+
return self._default
24+
25+
@property
26+
def active(self) -> DriverBase:
27+
if self._active is None:
28+
self._active = self.default
29+
return self._active
5530

5631
def set_active(self, driver: DriverBase) -> None:
57-
self.active = driver
32+
self._active = driver
5833

5934
def reset_active(self) -> None:
60-
self.active = self.default
35+
self._active = self.default
6136

6237

6338
driver = DriverConfig()

0 commit comments

Comments
 (0)