Skip to content

Commit 72764da

Browse files
authored
[FRONTEND] Remove hardcoded warp size (#7253)
1 parent 5b7bc04 commit 72764da

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

python/triton/experimental/gluon/_runtime.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,9 @@ def make_ir(self, options, codegen_fns, module_map, context):
3131
module.set_attr("ttg.target", builder.get_string_attr(target))
3232
module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
3333
module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
34+
module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
3435

3536
is_cuda = options.backend_name == "cuda"
36-
37-
if is_cuda:
38-
module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
39-
else:
40-
module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(64))
41-
4237
if is_cuda and options.maxnreg is not None:
4338
module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
4439

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class CUDAOptions:
101101
num_warps: int = 4
102102
num_ctas: int = 1
103103
num_stages: int = 3
104+
warp_size: int = 32
104105
# maxnreg corresponds to the ptx parameter .maxnreg, which controls the
105106
# maximum number of 32-bit registers used by one thread.
106107
maxnreg: Optional[int] = None

0 commit comments

Comments
 (0)