Skip to content

Commit 558841e

Browse files
Reland upstream commit f7d2674: "[Knobs] Reduce overhead of reading knobs (#7841)" (#4966)
This reverts commit 298480d.
2 parents 231a3e8 + 49339d7 commit 558841e

File tree

2 files changed

+164
-86
lines changed

2 files changed

+164
-86
lines changed

python/src/ir.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,6 +1904,85 @@ void init_triton_ir(py::module &&m) {
19041904
py::call_guard<py::gil_scoped_release>());
19051905
}
19061906

1907+
bool str_eq_ignore_case(const char *s1, const char *s2, int n) {
1908+
for (int i = 0; i < n; ++i) {
1909+
if (tolower(s1[i]) != s2[i])
1910+
return false;
1911+
}
1912+
return true;
1913+
}
1914+
1915+
int strlen_max(const char *str, int max) {
1916+
for (int i = 0; i <= max; ++i) {
1917+
if (str[i] == '\0') {
1918+
return i;
1919+
}
1920+
}
1921+
return 0;
1922+
}
1923+
1924+
bool is_truthy(char *str) {
1925+
int len = strlen_max(str, 4);
1926+
switch (len) {
1927+
case 1:
1928+
return str[0] == '1' || tolower(str[0]) == 'y';
1929+
case 2:
1930+
return str_eq_ignore_case(str, "on", len);
1931+
case 3:
1932+
return str_eq_ignore_case(str, "yes", len);
1933+
case 4:
1934+
return str_eq_ignore_case(str, "true", len);
1935+
default:
1936+
return false;
1937+
}
1938+
}
1939+
1940+
PyObject *py_getenv(PyObject *self, PyObject *const *args, Py_ssize_t nargs) {
1941+
if (!(nargs == 1 || nargs == 2)) {
1942+
PyErr_SetString(PyExc_TypeError, "getenv expected 1 or 2 arguments");
1943+
return NULL;
1944+
}
1945+
PyObject *name = args[0];
1946+
PyObject *default_val = nargs == 2 ? args[1] : Py_None;
1947+
if (!PyUnicode_CheckExact(name)) {
1948+
PyErr_SetString(PyExc_TypeError, "name must be a string");
1949+
return NULL;
1950+
}
1951+
char *env_val = getenv(PyUnicode_AsUTF8(name));
1952+
if (!env_val) {
1953+
Py_INCREF(default_val);
1954+
return default_val;
1955+
}
1956+
return PyUnicode_FromString(env_val);
1957+
}
1958+
1959+
PyObject *py_getenv_bool(PyObject *self, PyObject *const *args,
1960+
Py_ssize_t nargs) {
1961+
if (nargs != 2) {
1962+
PyErr_SetString(PyExc_TypeError, "getenv_bool expected 2 arguments");
1963+
return NULL;
1964+
}
1965+
PyObject *name = args[0];
1966+
PyObject *default_val = args[1];
1967+
if (!PyUnicode_CheckExact(name)) {
1968+
PyErr_SetString(PyExc_TypeError, "name must be a string");
1969+
return NULL;
1970+
}
1971+
char *env_val = getenv(PyUnicode_AsUTF8(name));
1972+
PyObject *res = default_val;
1973+
if (env_val) {
1974+
res = is_truthy(env_val) ? Py_True : Py_False;
1975+
}
1976+
Py_INCREF(res);
1977+
return res;
1978+
}
1979+
1980+
static PyMethodDef ModuleMethods[] = {
1981+
{"getenv", (PyCFunction)py_getenv, METH_FASTCALL, NULL},
1982+
{"getenv_bool", (PyCFunction)py_getenv_bool, METH_FASTCALL, NULL},
1983+
{NULL, NULL, 0, NULL} // sentinel
1984+
};
1985+
19071986
void init_triton_env_vars(py::module &m) {
19081987
m.def("get_cache_invalidating_env_vars",
19091988
[]() -> std::map<std::string, std::string> {
@@ -1920,4 +1999,5 @@ void init_triton_env_vars(py::module &m) {
19201999
}
19212000
return ret;
19222001
});
2002+
PyModule_AddFunctions(m.ptr(), ModuleMethods);
19232003
}

python/triton/knobs.py

Lines changed: 84 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from contextlib import contextmanager
1313
from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
1414

15+
from triton._C.libtriton import getenv, getenv_bool
16+
1517
if TYPE_CHECKING:
1618
from .runtime.cache import CacheManager, RemoteCacheBackend
1719
from .runtime.jit import JitFunctionInfo, KernelParam
@@ -27,11 +29,6 @@ class Env:
2729
propagate_env: bool = True
2830

2931

30-
def getenv(key: str) -> Optional[str]:
31-
res = os.getenv(key)
32-
return res.strip() if res is not None else res
33-
34-
3532
def setenv(key: str, value: Optional[str]) -> None:
3633
if not propagate_env:
3734
return
@@ -64,32 +61,25 @@ def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
6461
SetType = TypeVar("SetType")
6562
GetType = TypeVar("GetType")
6663

64+
_NOTHING = object()
65+
6766

6867
class env_base(Generic[SetType, GetType]):
6968

70-
def __init__(self, key: str, default: Union[SetType, Callable[[], SetType]]) -> None:
69+
def __init__(self, key: str) -> None:
7170
self.key = key
72-
self.default: Callable[[], SetType] = default if callable(default) else lambda: default
7371

7472
def __set_name__(self, objclass: Type[object], name: str) -> None:
7573
self.name = name
7674

7775
def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
78-
if obj is None:
79-
raise AttributeError(f"Cannot access {type(self)} on non-instance")
80-
81-
if self.name in obj.__dict__:
82-
return self.transform(obj.__dict__[self.name])
83-
else:
76+
py_val = obj.__dict__.get(self.name, _NOTHING)
77+
if py_val is _NOTHING:
8478
return self.get()
85-
86-
@property
87-
def env_val(self) -> str | None:
88-
return getenv(self.key)
79+
return self.transform(py_val)
8980

9081
def get(self) -> GetType:
91-
env = self.env_val
92-
return self.transform(self.default() if env is None else self.from_env(env))
82+
raise NotImplementedError()
9383

9484
def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
9585
if isinstance(value, Env):
@@ -107,54 +97,70 @@ def transform(self, val: SetType) -> GetType:
10797
# if GetType != SetType.
10898
return cast(GetType, val)
10999

110-
def from_env(self, val: str) -> SetType:
111-
raise NotImplementedError()
112-
113100

114101
class env_str(env_base[str, str]):
115102

116-
def from_env(self, val: str) -> str:
117-
return val
103+
def __init__(self, key: str, default: str):
104+
super().__init__(key)
105+
self.default = default
106+
107+
def get(self) -> str:
108+
return getenv(self.key, self.default)
109+
110+
111+
class env_str_callable_default(env_base[str, str]):
112+
113+
def __init__(self, key: str, default_factory: Callable[[], str]):
114+
super().__init__(key)
115+
self.default_factory = default_factory
116+
117+
def get(self) -> str:
118+
env_val = getenv(self.key)
119+
if env_val is None:
120+
return self.default_factory()
121+
return env_val
118122

119123

120124
class env_bool(env_base[bool, bool]):
121125

122-
def __init__(self, key: str, default: Union[bool, Callable[[], bool]] = False) -> None:
123-
super().__init__(key, default)
126+
def __init__(self, key: str, default: bool = False) -> None:
127+
super().__init__(key)
128+
self.default = default
124129

125-
def from_env(self, val: str) -> bool:
126-
return val.lower() in ("1", "true", "yes", "on", "y")
130+
def get(self) -> bool:
131+
return getenv_bool(self.key, self.default)
127132

128133

129134
class env_int(env_base[int, int]):
130135

131-
def __init__(self, key: str, default: Union[int, Callable[[], int]] = 0) -> None:
132-
super().__init__(key, default)
136+
def __init__(self, key: str, default: int = 0) -> None:
137+
super().__init__(key)
138+
self.default = default
133139

134-
def from_env(self, val: str) -> int:
140+
def get(self) -> int:
141+
val = getenv(self.key)
142+
if val is None:
143+
return self.default
135144
try:
136145
return int(val)
137146
except ValueError as exc:
138147
raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
139148

140149

141-
class env_opt_base(Generic[GetType, SetType], env_base[Optional[GetType], Optional[SetType]]):
142-
143-
def __init__(self, key: str) -> None:
144-
super().__init__(key, None)
145-
146-
147150
ClassType = TypeVar("ClassType")
148151

149152

150-
class env_class(Generic[ClassType], env_opt_base[Type[ClassType], Type[ClassType]]):
153+
class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]):
151154

152155
def __init__(self, key: str, type: str) -> None:
153156
super().__init__(key)
154157
# We can't pass the type directly to avoid import cycles
155158
self.type = type
156159

157-
def from_env(self, val: str) -> Type[ClassType]:
160+
def get(self) -> Optional[Type[ClassType]]:
161+
val = getenv(self.key)
162+
if val is None:
163+
return None
158164
comps = val.split(":", 1)
159165
if len(comps) != 2:
160166
raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
@@ -201,7 +207,7 @@ def from_path(path: str) -> Optional[IntelTool]:
201207
if version is None:
202208
return None
203209
return IntelTool(path, version.group(1))
204-
except subprocess.CalledProcessError:
210+
except (subprocess.CalledProcessError, FileNotFoundError):
205211
return None
206212

207213

@@ -210,73 +216,65 @@ class env_nvidia_tool(env_base[str, NvidiaTool]):
210216
def __init__(self, binary: str) -> None:
211217
binary += sysconfig.get_config_var("EXE")
212218
self.binary = binary
213-
super().__init__(f"TRITON_{binary.upper()}_PATH", lambda: os.path.join(
214-
os.path.dirname(__file__),
215-
"backends",
216-
"nvidia",
217-
"bin",
218-
self.binary,
219-
))
219+
self.default_path = os.path.join(os.path.dirname(__file__), "backends", "nvidia", "bin", binary)
220+
super().__init__(f"TRITON_{binary.upper()}_PATH")
221+
222+
def get(self) -> NvidiaTool:
223+
return self.transform(getenv(self.key))
220224

221225
def transform(self, path: str) -> NvidiaTool:
222-
paths = [
223-
path,
224-
# We still add default as fallback in case the pointed binary isn't
225-
# accessible.
226-
self.default(),
227-
]
226+
# We still add default as fallback in case the pointed binary isn't
227+
# accessible.
228+
if path is not None:
229+
paths = [path, self.default_path]
230+
else:
231+
paths = [self.default_path]
232+
228233
for path in paths:
229-
if not path or not os.access(path, os.X_OK):
230-
continue
231234
if tool := NvidiaTool.from_path(path):
232235
return tool
233236

234237
raise RuntimeError(f"Cannot find {self.binary}")
235238

236-
def from_env(self, val: str) -> str:
237-
return val
238-
239239

240240
class env_intel_tool(env_base[str, IntelTool]):
241241

242242
def __init__(self, binary: str) -> None:
243243
binary += sysconfig.get_config_var("EXE")
244244
self.binary = binary
245-
super().__init__(f"TRITON_{binary.upper().replace('-', '_')}_PATH", lambda: os.path.join(
246-
os.path.dirname(__file__),
247-
"backends",
248-
"intel",
249-
"bin",
250-
self.binary,
251-
))
245+
self.default_path = os.path.join(os.path.dirname(__file__), "backends", "intel", "bin", binary)
246+
super().__init__(f"TRITON_{binary.upper()}_PATH")
247+
248+
def get(self) -> IntelTool:
249+
return self.transform(getenv(self.key))
252250

253251
def transform(self, path: str) -> IntelTool:
254-
paths = [
255-
path,
256-
# We still add default as fallback in case the pointed binary isn't
257-
# accessible.
258-
self.default(),
259-
shutil.which(self.binary) or "",
260-
]
252+
# We still add default as fallback in case the pointed binary isn't
253+
# accessible.
254+
if path is not None:
255+
paths = [path, self.default_path]
256+
else:
257+
paths = [self.default_path]
258+
if shutil_path := shutil.which(self.binary):
259+
paths += [shutil_path]
261260
for path in paths:
262-
if not path or not os.access(path, os.X_OK):
263-
continue
264261
if tool := IntelTool.from_path(path):
265262
return tool
266263

267264
raise RuntimeError(f"Cannot find {self.binary}")
268265

269-
def from_env(self, val: str) -> str:
270-
return val
271-
272266

273267
# Separate classes so that types are correct
274-
class env_opt_str(env_opt_base[str, str], env_str):
275-
pass
268+
class env_opt_str(env_base[Optional[str], Optional[str]]):
276269

270+
def get(self) -> Optional[str]:
271+
return getenv(self.key)
277272

278-
class env_opt_bool(env_opt_base[bool, bool], env_bool):
279-
pass
273+
274+
class env_opt_bool(env_base):
275+
276+
def get(self) -> Optional[str]:
277+
return getenv_bool(self.key, None)
280278

281279

282280
@dataclass(frozen=True)
@@ -344,7 +342,7 @@ def reset(self: knobs_type) -> knobs_type:
344342
@contextmanager
345343
def scope(self) -> Generator[None, None, None]:
346344
try:
347-
initial_env = {knob.key: knob.env_val for knob in self.knob_descriptors.values()}
345+
initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()}
348346
orig = dict(self.__dict__)
349347
yield
350348
finally:
@@ -389,11 +387,11 @@ class redis_knobs(base_knobs):
389387

390388

391389
class cache_knobs(base_knobs):
392-
home_dir: env_str = env_str("TRITON_HOME", lambda: os.path.expanduser("~/"))
390+
home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/"))
393391

394-
dump_dir: env_str = env_str("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
395-
override_dir: env_str = env_str("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
396-
dir: env_str = env_str("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
392+
dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
393+
override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
394+
dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
397395

398396
manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
399397
remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")

0 commit comments

Comments
 (0)