
Commit efc5d35

danzimm authored and FindHao committed
[python][compiler] Memoize device max shared memory per device (triton-lang#6503)
Similar to triton-lang#6000, this is an internal Meta patch being upstreamed to reduce the number of patches we carry internally; cc @jamesjwu, the original author. When running various benchmarks with small kernels, we see a non-trivial amount of time spent fetching this property, and memoizing it helped. It might be worth memoizing `get_device_properties` itself, but I think that would need more careful treatment in the driver package to properly handle arbitrary backends.
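One possible reading of the "arbitrary backends" concern is that a cache of device properties would have to be keyed by backend as well as by device, so that two backends using the same device index do not share an entry. The sketch below illustrates that idea only. Every name in it (`cached_device_properties`, `_props_cache`, the backend strings, the property values) is hypothetical and not part of Triton's driver package.

```python
# Hypothetical sketch: none of these names are Triton APIs.
_props_cache = {}


def cached_device_properties(backend_name, device, query):
    """Return properties for (backend, device), calling `query` at most once per pair."""
    key = (backend_name, device)
    if key not in _props_cache:
        _props_cache[key] = query(device)
    return _props_cache[key]


# Stand-in query functions that record how often they are called.
calls = []


def query_backend_a(device):
    calls.append(("a", device))
    return {"max_shared_mem": 1000}  # made-up value


def query_backend_b(device):
    calls.append(("b", device))
    return {"max_shared_mem": 2000}  # made-up value


# Same device index, different backends: the entries stay separate.
props_a = cached_device_properties("a", 0, query_backend_a)
props_b = cached_device_properties("b", 0, query_backend_b)
# A repeated lookup is served from the cache and does not call the query again.
props_a_again = cached_device_properties("a", 0, query_backend_a)
```

A plain dictionary is used here instead of `functools.lru_cache` so that the cache key is explicit; whether a real implementation would need invalidation when the active driver changes is left open.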
1 parent 3936805 commit efc5d35

1 file changed: +6 -1 lines changed


python/triton/compiler/compiler.py

Lines changed: 6 additions & 1 deletion
@@ -164,6 +164,11 @@ def triton_key():
     return f'{__version__}' + '-'.join(contents)
 
 
+@functools.lru_cache()
+def max_shared_mem(device):
+    return driver.active.utils.get_device_properties(device)["max_shared_mem"]
+
+
 def parse(full_name, ext, context):
     if ext == "ttir" or ext == "ttgir":
         module = ir.parse_mlir_module(full_name, context)
@@ -400,7 +405,7 @@ def _init_handles(self):
         # create launcher
         self.run = driver.active.launcher_cls(self.src, self.metadata)
         # not enough shared memory to run the kernel
-        max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
+        max_shared = max_shared_mem(device)
         if self.metadata.shared > max_shared:
             raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
         if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
