1212from contextlib import contextmanager
1313from typing import cast , Any , Callable , Generator , Generic , Optional , Protocol , Type , TypeVar , TypedDict , TYPE_CHECKING , Union
1414
15+ from triton ._C .libtriton import getenv , getenv_bool
16+
1517if TYPE_CHECKING :
1618 from .runtime .cache import CacheManager , RemoteCacheBackend
1719 from .runtime .jit import JitFunctionInfo , KernelParam
@@ -27,11 +29,6 @@ class Env:
2729propagate_env : bool = True
2830
2931
30- def getenv (key : str ) -> Optional [str ]:
31- res = os .getenv (key )
32- return res .strip () if res is not None else res
33-
34-
3532def setenv (key : str , value : Optional [str ]) -> None :
3633 if not propagate_env :
3734 return
@@ -64,32 +61,25 @@ def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
6461SetType = TypeVar ("SetType" )
6562GetType = TypeVar ("GetType" )
6663
64+ _NOTHING = object ()
65+
6766
6867class env_base (Generic [SetType , GetType ]):
6968
70- def __init__ (self , key : str , default : Union [ SetType , Callable [[], SetType ]] ) -> None :
69+ def __init__ (self , key : str ) -> None :
7170 self .key = key
72- self .default : Callable [[], SetType ] = default if callable (default ) else lambda : default
7371
7472 def __set_name__ (self , objclass : Type [object ], name : str ) -> None :
7573 self .name = name
7674
7775 def __get__ (self , obj : Optional [object ], objclass : Optional [Type [object ]]) -> GetType :
78- if obj is None :
79- raise AttributeError (f"Cannot access { type (self )} on non-instance" )
80-
81- if self .name in obj .__dict__ :
82- return self .transform (obj .__dict__ [self .name ])
83- else :
76+ py_val = obj .__dict__ .get (self .name , _NOTHING )
77+ if py_val is _NOTHING :
8478 return self .get ()
85-
86- @property
87- def env_val (self ) -> str | None :
88- return getenv (self .key )
79+ return self .transform (py_val )
8980
9081 def get (self ) -> GetType :
91- env = self .env_val
92- return self .transform (self .default () if env is None else self .from_env (env ))
82+ raise NotImplementedError ()
9383
9484 def __set__ (self , obj : object , value : Union [SetType , Env ]) -> None :
9585 if isinstance (value , Env ):
@@ -107,54 +97,70 @@ def transform(self, val: SetType) -> GetType:
10797 # if GetType != SetType.
10898 return cast (GetType , val )
10999
110- def from_env (self , val : str ) -> SetType :
111- raise NotImplementedError ()
112-
113100
114101class env_str (env_base [str , str ]):
115102
116- def from_env (self , val : str ) -> str :
117- return val
103+ def __init__ (self , key : str , default : str ):
104+ super ().__init__ (key )
105+ self .default = default
106+
107+ def get (self ) -> str :
108+ return getenv (self .key , self .default )
109+
110+
111+ class env_str_callable_default (env_base [str , str ]):
112+
113+ def __init__ (self , key : str , default_factory : Callable [[], str ]):
114+ super ().__init__ (key )
115+ self .default_factory = default_factory
116+
117+ def get (self ) -> str :
118+ env_val = getenv (self .key )
119+ if env_val is None :
120+ return self .default_factory ()
121+ return env_val
118122
119123
120124class env_bool (env_base [bool , bool ]):
121125
122- def __init__ (self , key : str , default : Union [bool , Callable [[], bool ]] = False ) -> None :
123- super ().__init__ (key , default )
126+ def __init__ (self , key : str , default : bool = False ) -> None :
127+ super ().__init__ (key )
128+ self .default = default
124129
125- def from_env (self , val : str ) -> bool :
126- return val . lower () in ( "1" , "true" , "yes" , "on" , "y" )
130+ def get (self ) -> bool :
131+ return getenv_bool ( self . key , self . default )
127132
128133
129134class env_int (env_base [int , int ]):
130135
131- def __init__ (self , key : str , default : Union [int , Callable [[], int ]] = 0 ) -> None :
132- super ().__init__ (key , default )
136+ def __init__ (self , key : str , default : int = 0 ) -> None :
137+ super ().__init__ (key )
138+ self .default = default
133139
134- def from_env (self , val : str ) -> int :
140+ def get (self ) -> int :
141+ val = getenv (self .key )
142+ if val is None :
143+ return self .default
135144 try :
136145 return int (val )
137146 except ValueError as exc :
138147 raise RuntimeError (f"Unable to use { self .key } ={ val } : expected int" ) from exc
139148
140149
141- class env_opt_base (Generic [GetType , SetType ], env_base [Optional [GetType ], Optional [SetType ]]):
142-
143- def __init__ (self , key : str ) -> None :
144- super ().__init__ (key , None )
145-
146-
147150ClassType = TypeVar ("ClassType" )
148151
149152
150- class env_class (Generic [ClassType ], env_opt_base [ Type [ClassType ], Type [ClassType ]]):
153+ class env_class (Generic [ClassType ], env_base [ Optional [ Type [ClassType ]], Optional [ Type [ClassType ] ]]):
151154
152155 def __init__ (self , key : str , type : str ) -> None :
153156 super ().__init__ (key )
154157 # We can't pass the type directly to avoid import cycles
155158 self .type = type
156159
157- def from_env (self , val : str ) -> Type [ClassType ]:
160+ def get (self ) -> Optional [Type [ClassType ]]:
161+ val = getenv (self .key )
162+ if val is None :
163+ return None
158164 comps = val .split (":" , 1 )
159165 if len (comps ) != 2 :
160166 raise RuntimeError (f"Unable to read { self .key } : '{ val } ' isn't of the form MODULE:CLASS" )
@@ -201,7 +207,7 @@ def from_path(path: str) -> Optional[IntelTool]:
201207 if version is None :
202208 return None
203209 return IntelTool (path , version .group (1 ))
204- except subprocess .CalledProcessError :
210+ except ( subprocess .CalledProcessError , FileNotFoundError ) :
205211 return None
206212
207213
@@ -210,73 +216,65 @@ class env_nvidia_tool(env_base[str, NvidiaTool]):
210216 def __init__ (self , binary : str ) -> None :
211217 binary += sysconfig .get_config_var ("EXE" )
212218 self .binary = binary
213- super ().__init__ (f"TRITON_{ binary .upper ()} _PATH" , lambda : os .path .join (
214- os .path .dirname (__file__ ),
215- "backends" ,
216- "nvidia" ,
217- "bin" ,
218- self .binary ,
219- ))
219+ self .default_path = os .path .join (os .path .dirname (__file__ ), "backends" , "nvidia" , "bin" , binary )
220+ super ().__init__ (f"TRITON_{ binary .upper ()} _PATH" )
221+
222+ def get (self ) -> NvidiaTool :
223+ return self .transform (getenv (self .key ))
220224
221225 def transform (self , path : str ) -> NvidiaTool :
222- paths = [
223- path ,
224- # We still add default as fallback in case the pointed binary isn't
225- # accessible.
226- self .default (),
227- ]
226+ # We still add default as fallback in case the pointed binary isn't
227+ # accessible.
228+ if path is not None :
229+ paths = [path , self .default_path ]
230+ else :
231+ paths = [self .default_path ]
232+
228233 for path in paths :
229- if not path or not os .access (path , os .X_OK ):
230- continue
231234 if tool := NvidiaTool .from_path (path ):
232235 return tool
233236
234237 raise RuntimeError (f"Cannot find { self .binary } " )
235238
236- def from_env (self , val : str ) -> str :
237- return val
238-
239239
240240class env_intel_tool (env_base [str , IntelTool ]):
241241
242242 def __init__ (self , binary : str ) -> None :
243243 binary += sysconfig .get_config_var ("EXE" )
244244 self .binary = binary
245- super ().__init__ (f"TRITON_{ binary .upper ().replace ('-' , '_' )} _PATH" , lambda : os .path .join (
246- os .path .dirname (__file__ ),
247- "backends" ,
248- "intel" ,
249- "bin" ,
250- self .binary ,
251- ))
245+ self .default_path = os .path .join (os .path .dirname (__file__ ), "backends" , "intel" , "bin" , binary )
246+ super ().__init__ (f"TRITON_{ binary .upper ()} _PATH" )
247+
248+ def get (self ) -> IntelTool :
249+ return self .transform (getenv (self .key ))
252250
253251 def transform (self , path : str ) -> IntelTool :
254- paths = [
255- path ,
256- # We still add default as fallback in case the pointed binary isn't
257- # accessible.
258- self .default (),
259- shutil .which (self .binary ) or "" ,
260- ]
252+ # We still add default as fallback in case the pointed binary isn't
253+ # accessible.
254+ if path is not None :
255+ paths = [path , self .default_path ]
256+ else :
257+ paths = [self .default_path ]
258+ if shutil_path := shutil .which (self .binary ):
259+ paths += [shutil_path ]
261260 for path in paths :
262- if not path or not os .access (path , os .X_OK ):
263- continue
264261 if tool := IntelTool .from_path (path ):
265262 return tool
266263
267264 raise RuntimeError (f"Cannot find { self .binary } " )
268265
269- def from_env (self , val : str ) -> str :
270- return val
271-
272266
273267# Separate classes so that types are correct
274- class env_opt_str (env_opt_base [str , str ], env_str ):
275- pass
268+ class env_opt_str (env_base [Optional [str ], Optional [str ]]):
276269
270+ def get (self ) -> Optional [str ]:
271+ return getenv (self .key )
277272
278- class env_opt_bool (env_opt_base [bool , bool ], env_bool ):
279- pass
273+
274+ class env_opt_bool (env_base ):
275+
276+ def get (self ) -> Optional [str ]:
277+ return getenv_bool (self .key , None )
280278
281279
282280@dataclass (frozen = True )
@@ -344,7 +342,7 @@ def reset(self: knobs_type) -> knobs_type:
344342 @contextmanager
345343 def scope (self ) -> Generator [None , None , None ]:
346344 try :
347- initial_env = {knob .key : knob .env_val for knob in self .knob_descriptors .values ()}
345+ initial_env = {knob .key : getenv ( knob .key ) for knob in self .knob_descriptors .values ()}
348346 orig = dict (self .__dict__ )
349347 yield
350348 finally :
@@ -389,11 +387,11 @@ class redis_knobs(base_knobs):
389387
390388
391389class cache_knobs (base_knobs ):
392- home_dir : env_str = env_str ("TRITON_HOME" , lambda : os .path .expanduser ("~/" ))
390+ home_dir : env_str = env_str ("TRITON_HOME" , os .path .expanduser ("~/" ))
393391
394- dump_dir : env_str = env_str ("TRITON_DUMP_DIR" , lambda : cache .get_triton_dir ("dump" ))
395- override_dir : env_str = env_str ("TRITON_OVERRIDE_DIR" , lambda : cache .get_triton_dir ("override" ))
396- dir : env_str = env_str ("TRITON_CACHE_DIR" , lambda : cache .get_triton_dir ("cache" ))
392+ dump_dir = env_str_callable_default ("TRITON_DUMP_DIR" , lambda : cache .get_triton_dir ("dump" ))
393+ override_dir = env_str_callable_default ("TRITON_OVERRIDE_DIR" , lambda : cache .get_triton_dir ("override" ))
394+ dir = env_str_callable_default ("TRITON_CACHE_DIR" , lambda : cache .get_triton_dir ("cache" ))
397395
398396 manager_class : env_class [CacheManager ] = env_class ("TRITON_CACHE_MANAGER" , "CacheManager" )
399397 remote_manager_class : env_class [RemoteCacheBackend ] = env_class ("TRITON_REMOTE_CACHE_BACKEND" , "RemoteCacheBackend" )
0 commit comments