Commit 21de3c1

Minor patch on the annotation and indents of TIR func dispatch (#34)
This PR adds a function annotation to the TIR functions that have already been scheduled, so that the default GPU schedule pass will skip them; these annotations were previously missing. It also re-indents some of the PrimFunc TVMScript to remove redundant indentation.
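For readers unfamiliar with the mechanism, here is a minimal sketch of how a TIR pass can honor such a marker. The pass below is illustrative, not code from this repo; it only assumes TVM's standard prim_func_pass decorator and per-function attributes:

import tvm
from tvm import tir

# Illustrative pass: leave a function untouched once it carries the
# "tir.is_scheduled" attribute, the same marker this commit adds to the
# hand-scheduled "*_after" functions.
@tvm.tir.transform.prim_func_pass(opt_level=0)
def skip_if_scheduled(func: tir.PrimFunc, mod: tvm.IRModule, ctx) -> tir.PrimFunc:
    if func.attrs is not None and "tir.is_scheduled" in func.attrs:
        return func  # already scheduled by hand; do not reschedule
    # ... a default GPU schedule would be applied here ...
    return func

Applied over an IRModule, a pass written this way rewrites only the un-annotated functions, which is the skipping behavior the commit message describes for the default GPU schedule pass.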
1 parent 25d235f commit 21de3c1

1 file changed: +71 −71 lines

web_llm/transform/dispatch_tir_operator.py

Lines changed: 71 additions & 71 deletions
@@ -1973,28 +1973,28 @@ def decode_sch_func(orig_func):
 
 @T.prim_func
 def fused_decode3_matmul1_before(lv2931: T.Buffer((T.int64(512), T.int64(32001)), "uint32"), lv2932: T.Buffer((T.int64(128), T.int64(32001)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32001)), "float32")):
-        T.func_attr({"tir.noalias": T.bool(True)})
-        # with T.block("root"):
-        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32001)))
-        for i, j in T.grid(T.int64(4096), T.int64(32001)):
-            with T.block("decode"):
-                v_i, v_j = T.axis.remap("SS", [i, j])
-                T.reads(lv2931[v_i // T.int64(8), v_j], lv2932[v_i // T.int64(32), v_j])
-                T.writes(var_decode_intermediate[v_i, v_j])
-                var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv2932[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv2932[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
-        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32001), T.int64(4096)):
-            with T.block("matmul"):
-                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
-                T.reads(lv1511[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
-                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
-                with T.init():
-                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
-                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1511[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+    T.func_attr({"tir.noalias": T.bool(True)})
+    # with T.block("root"):
+    var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32001)))
+    for i, j in T.grid(T.int64(4096), T.int64(32001)):
+        with T.block("decode"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            T.reads(lv2931[v_i // T.int64(8), v_j], lv2932[v_i // T.int64(32), v_j])
+            T.writes(var_decode_intermediate[v_i, v_j])
+            var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv2932[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv2932[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+    for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32001), T.int64(4096)):
+        with T.block("matmul"):
+            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+            T.reads(lv1511[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+            T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+            with T.init():
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+            var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1511[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
 
 
 @T.prim_func
 def fused_decode3_matmul1_after(lv1123: T.Buffer((T.int64(512), T.int64(32001)), "uint32"), lv1124: T.Buffer((T.int64(128), T.int64(32001)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32001)), "float32")):
-    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1})
     # with T.block("root"):
     var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(35072)), scope="local")
     var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(35072)), scope="local")
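An aside on how such an annotation is typically attached: PrimFunc attribute dictionaries are immutable in TVM, so the annotated variant is produced with with_attr. A minimal sketch, assuming TVM's Python API; the helper name mark_scheduled is hypothetical:

import tvm
from tvm import tir

def mark_scheduled(func: tir.PrimFunc) -> tir.PrimFunc:
    # Hypothetical helper: with_attr returns a copy of the PrimFunc that
    # carries the new attribute; the original function is unchanged.
    return func.with_attr("tir.is_scheduled", 1)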
@@ -2415,59 +2415,59 @@ def fused_decode6_fused_matmul9_add3_before(lv1623: T.Buffer((T.int64(1376), T.i
 
 @T.prim_func
 def fused_decode6_fused_matmul9_add3_after(lv1158: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1159: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")):
-        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
-        # with T.block("root"):
-        var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local")
-        var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local")
-        lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared")
-        for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}):
-            for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
-                for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
-                    with T.block("matmul_init"):
-                        v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                        v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
-                        v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
-                        T.reads()
-                        T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
-                        var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)
-                    for k_0_0 in range(T.int64(2)):
-                        for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)):
-                            for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
-                                with T.block("lv6_shared"):
-                                    v0 = T.axis.spatial(T.int64(1), ax0)
-                                    v1 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1))
-                                    T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504))
-                                    T.reads(lv6[v0, v1, v2])
-                                    T.writes(lv6_shared[v0, v1, v2])
-                                    T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]})
-                                    lv6_shared[v0, v1, v2] = lv6[v0, v1, v2]
-                        for k_0_1 in range(T.int64(86)):
-                            for ax0_0 in range(T.int64(8)):
-                                for ax0_1 in T.unroll(T.int64(8)):
-                                    for ax1 in range(T.int64(1)):
-                                        with T.block("decode"):
-                                            v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)
-                                            v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)
-                                            T.reads(lv1158[v_j // T.int64(8), v_i], lv1159[v_j // T.int64(32), v_i])
-                                            T.writes(var_decode_intermediate_local[v_j, v_i])
-                                            var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1159[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1159[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
-                            for k_0_2_k_1_fused in range(T.int64(64)):
-                                with T.block("matmul_update"):
-                                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
-                                    v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused)
-                                    T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])
-                                    T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
-                                    var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]
-                    for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
-                        with T.block("var_matmul_intermediate_local"):
-                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
-                            v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)
-                            T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])
-                            T.writes(p_output0_intermediate[v0, v1, v2])
-                            p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]
+    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1})
+    # with T.block("root"):
+    var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local")
+    var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local")
+    lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared")
+    for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}):
+        for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+            for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                with T.block("matmul_init"):
+                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
+                    T.reads()
+                    T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
+                    var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)
+                for k_0_0 in range(T.int64(2)):
+                    for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)):
+                        for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                            with T.block("lv6_shared"):
+                                v0 = T.axis.spatial(T.int64(1), ax0)
+                                v1 = T.axis.spatial(T.int64(1), T.int64(0))
+                                v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1))
+                                T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504))
+                                T.reads(lv6[v0, v1, v2])
+                                T.writes(lv6_shared[v0, v1, v2])
+                                T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]})
+                                lv6_shared[v0, v1, v2] = lv6[v0, v1, v2]
+                    for k_0_1 in range(T.int64(86)):
+                        for ax0_0 in range(T.int64(8)):
+                            for ax0_1 in T.unroll(T.int64(8)):
+                                for ax1 in range(T.int64(1)):
+                                    with T.block("decode"):
+                                        v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)
+                                        v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)
+                                        T.reads(lv1158[v_j // T.int64(8), v_i], lv1159[v_j // T.int64(32), v_i])
+                                        T.writes(var_decode_intermediate_local[v_j, v_i])
+                                        var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1159[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1159[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+                        for k_0_2_k_1_fused in range(T.int64(64)):
+                            with T.block("matmul_update"):
+                                v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                v_i1 = T.axis.spatial(T.int64(1), T.int64(0))
+                                v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)
+                                v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused)
+                                T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])
+                                T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])
+                                var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]
+                for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
+                    with T.block("var_matmul_intermediate_local"):
+                        v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                        v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)
+                        T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])
+                        T.writes(p_output0_intermediate[v0, v1, v2])
+                        p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]
 # fmt: on
 
 ################################################
