Skip to content

Commit 077e8eb

Browse files
authored
[Fix][KVCache] Fix incorrect tile size calculation (#17595)
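This PR fixes the tile size calculation in the TIR attention kernels, where the computed tile sizes may not divide the total loop extent. Previously the search loop in `get_tile_size` only checked that `tile_y` divides both `cnt` and `y`, so the derived `tile_x = cnt // tile_y` was not guaranteed to divide `x`. The loop condition now also keeps incrementing `tile_y` while `x % (cnt // tile_y) != 0`.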
1 parent da2e89a commit 077e8eb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -855,7 +855,7 @@ def get_tile_size(x, y, t):
855855
cnt = (x * y) // t
856856
assert (x * y) % t == 0
857857
tile_y = (int)(math.ceil(math.sqrt(cnt)))
858-
while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
858+
while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt:
859859
tile_y += 1
860860
assert tile_y <= cnt
861861
tile_x = cnt // tile_y
@@ -1509,7 +1509,7 @@ def get_tile_size(x, y, t):
15091509
cnt = (x * y) // t
15101510
assert (x * y) % t == 0
15111511
tile_y = (int)(math.ceil(math.sqrt(cnt)))
1512-
while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
1512+
while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt:
15131513
tile_y += 1
15141514
assert tile_y <= cnt
15151515
tile_x = cnt // tile_y
@@ -1867,7 +1867,7 @@ def get_tile_size(x, y, t):
18671867
cnt = (x * y) // t
18681868
assert (x * y) % t == 0
18691869
tile_y = (int)(math.ceil(math.sqrt(cnt)))
1870-
while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
1870+
while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt:
18711871
tile_y += 1
18721872
assert tile_y <= cnt
18731873
tile_x = cnt // tile_y

0 commit comments

Comments
 (0)