Skip to content

Commit f971e9f

Browse files
authored
[Runtime] Fix out of threads error to raise on every call (#7857)
Currently the attention tutorial is failing on gb200. This happens because we only check for too many threads on the first `_init_handles` call. This change makes it so that when we raise `OutOfResources` we also set `self._run` to a function that will raise the error. This means we will still raise the error on every subsequent call, even if the module has already been loaded.
1 parent ff3832d commit f971e9f

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

python/triton/compiler/compiler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ def __missing__(self, key):
403403
return value
404404

405405

406+
def _raise_error(err, *args, **kwargs):
407+
raise err
408+
409+
406410
class CompiledKernel:
407411

408412
def __init__(self, src, metadata_group, hash):
@@ -439,32 +443,38 @@ def __init__(self, src, metadata_group, hash):
439443
def _init_handles(self):
    """Lazily load the compiled binary onto the current device.

    Creates the launcher (``self._run``), validates the kernel against the
    device's resource limits (shared memory, tensor memory, threads), and
    loads the module. If a limit is exceeded, the failure is made permanent:
    ``self._run`` is replaced with a raiser so every later launch re-raises
    the same ``OutOfResources`` error, even though the module may already
    have been loaded by then.

    Raises:
        OutOfResources: if shared memory, tensor memory, or the requested
            thread count exceeds what the device provides.
    """
    if self.module is not None:
        # Handles were already initialized on a previous call.
        return

    def _fail_forever(err):
        # Record the failure so every subsequent launch re-raises it,
        # then raise it now.
        self._run = functools.partial(_raise_error, err)
        raise err

    meta = self.metadata
    device = driver.active.get_current_device()
    # create launcher
    self._run = driver.active.launcher_cls(self.src, meta)
    # not enough shared memory to run the kernel
    max_shared = max_shared_mem(device)
    if meta.shared > max_shared:
        _fail_forever(OutOfResources(meta.shared, max_shared, "shared memory"))
    if hasattr(meta, "tmem_size") and meta.tmem_size is not None:
        # Use blackwell max tmem size for now, this should be moved in device properties
        max_tmem_size = 512  # tmem size in number of columns
        if meta.tmem_size > max_tmem_size:
            _fail_forever(OutOfResources(meta.tmem_size, max_tmem_size, "tensor memory"))
    if knobs.runtime.kernel_load_start_hook is not None:
        knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
    # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
    self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
        self.name, self.kernel, meta.shared, device)
    warp_size = driver.active.get_current_target().warp_size
    if meta.num_warps * warp_size > self.n_max_threads:
        _fail_forever(OutOfResources(meta.num_warps * warp_size, self.n_max_threads, "threads"))
    if knobs.runtime.kernel_load_end_hook is not None:
        knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
464473

465474
@property
def run(self):
    """Return the callable that launches this kernel, loading it on demand.

    ``self._run`` may already hold either a real launcher or the
    permanent-failure raiser installed by a failed ``_init_handles``; in
    both cases it is returned as-is without re-initializing, so a prior
    ``OutOfResources`` keeps raising on every call.
    """
    if self._run is not None:
        return self._run
    self._init_handles()
    return self._run
469479

470480
def launch_metadata(self, grid, stream, *args):

0 commit comments

Comments
 (0)