1111from triton .runtime .build import _build
1212from triton .runtime .cache import get_cache_manager
1313from triton .backends .compiler import GPUTarget
14- from triton .backends .driver import DriverBase
14+ from triton .backends .driver import DriverBase , platform_key
1515
1616# A hard-coded cache version that can be updated when we know that the cached file is invalid and
1717# there are no other ways to detect that the runtime environment has changed. For example, a shared
@@ -251,11 +251,11 @@ def __del__(self):
251251
252252def compile_module_from_src (src , name ):
253253 hasher = hashlib .sha256 (__CACHE_VERSION .encode ("utf-8" ))
254- hasher .update (src .encode ("utf-8" ))
254+ hasher .update (( src + platform_key ()) .encode ("utf-8" ))
255255 key = hasher .hexdigest ()
256256 cache = get_cache_manager (key )
257- file_name = f" { name } . { sysconfig .get_config_var (' EXT_SUFFIX' ). split ( '.' )[ - 1 ] } "
258- cache_path = cache .get_file (file_name )
257+ suffix = sysconfig .get_config_var (" EXT_SUFFIX" )
258+ cache_path = cache .get_file (f" { name } { suffix } " )
259259 if cache_path is None :
260260 with tempfile .TemporaryDirectory () as tmpdir :
261261 src_path = os .path .join (tmpdir , "main.cpp" )
@@ -271,7 +271,7 @@ def compile_module_from_src(src, name):
271271 so = _build (name , src_path , tmpdir , COMPILATION_HELPER .library_dir , COMPILATION_HELPER .include_dir ,
272272 COMPILATION_HELPER .libraries , extra_compile_args = extra_compiler_args )
273273 with open (so , "rb" ) as f :
274- cache_path = cache .put (f .read (), file_name , binary = True )
274+ cache_path = cache .put (f .read (), f" { name } { suffix } " , binary = True )
275275
276276 if name == 'arch_utils' :
277277 return ArchParser (cache_path )
0 commit comments