Skip to content

Commit f8d0c10

Browse files
committed
blocksize check
1 parent b19458a commit f8d0c10

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,8 +2574,19 @@ def pointwise(
25742574
"""
25752575
def addConfig__(xblock: int, yblock: int, num_warps: int):
    """Append an (XBLOCK, YBLOCK) tiling config if it suits the size hints.

    The config is skipped, and nothing is appended, when either:
      * the launch grid implied by the block sizes would exceed the
        per-axis grid limit, or
      * the hinted problem size is smaller than the tile along x or y.

    Args:
        xblock: Tile extent along x, stored under the "XBLOCK" key.
        yblock: Tile extent along y, stored under the "YBLOCK" key.
        num_warps: Forwarded unchanged to ``Config(num_warps=...)``.
    """
    # NOTE(review): these look like the CUDA launch limits (2**31 - 1 blocks
    # in x, 65535 in y) -- confirm they are appropriate for other backends.
    max_x_grid = 2147483647
    max_y_grid = 65535

    x_size = size_hints["x"]
    y_size = size_hints["y"]

    # Ceiling division: number of blocks needed to cover each axis.
    blocks_x = (x_size + xblock - 1) // xblock
    blocks_y = (y_size + yblock - 1) // yblock

    grid_overflows = blocks_x > max_x_grid or blocks_y > max_y_grid
    tile_too_large = x_size < xblock or y_size < yblock
    if grid_overflows or tile_too_large:
        return

    # Every check passed: register the tiling config.
    tile = {"XBLOCK": xblock, "YBLOCK": yblock}
    configs.append(Config(tile, num_warps=num_warps))
25792590
addConfig__(512, 8, 8) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
25802591
addConfig__(32, 128, 4) # wrt2: 570us : triton_poi_fused_add_transpose_view_52
25812592
addConfig__(64, 32, 8) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103

0 commit comments

Comments
 (0)