Skip to content

Commit ac0d9b0

Browse files
authored
[TMEM] Fix missing log2 in 32x32b_split and fix the error message (triton-lang#8696)
1 parent 1a9734c commit ac0d9b0

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,9 +65,12 @@ def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant
65 65
layout_obj.lane_bases[-1] = [0, 0]
66 66
elif layout_obj.reg_bases[-1] != [0, N // 2]:
67 67
bitwidth = element_ty.primitive_bitwidth
68+
num_reg = 2**len(layout_obj.reg_bases)
68 69
_check(
69-
len(layout_obj.reg_bases) * bitwidth > 32,
70-
lambda: "splitn requires register bases of more than 2 32 bit registers")
70+
num_reg > 32 // bitwidth, lambda: "To be able to `tmem.load` into `tl.split` you need to have more "
71+
f"than {32 // bitwidth} {bitwidth}-bit registers, as you need to use "
72+
"the instruction 32x32b.x1 twice. You can always load into "
73+
"instr_variant=\"32x32b\" and then convert_layout to this layout otherwise.")
71 74

72 75
reg_bases = layout_obj.reg_bases
73 76
for bases_str in ("lane_bases", "warp_bases"):

0 commit comments

Comments
 (0)