Skip to content

Commit 8dd7ccb

Browse files
authored
[KERNELS] Fix warnings from tl.where (#7630)
``` tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got int64 ```
1 parent 7a9c004 commit 8dd7ccb

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

python/triton_kernels/triton_kernels/compaction_details/_masked_compaction.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sent
1111
rem = yi % 32
1212
active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
1313
exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
14-
rev_arange = tl.where(active_bits, 0, K - 1 - tl.arange(0, K))
14+
active_flags = active_bits.to(tl.int1)
15+
rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
1516
write_indx = exc_cumsum + rev_arange
16-
yv = tl.where(active_bits, yv, sentinel)
17-
yi = tl.where(active_bits, yi, sentinel)
17+
yv = tl.where(active_flags, yv, sentinel)
18+
yi = tl.where(active_flags, yi, sentinel)
1819
tl.store(RetYv + pid_m * K + write_indx, yv)
1920
tl.store(RetYi + pid_m * K + write_indx, yi)

0 commit comments

Comments
 (0)