Commit de560d2

Refactor to match TVM Upstream PR#18689 (#3413)
Refactor to match the changes introduced in apache/tvm#18689; the 3rdparty/tvm submodule is updated accordingly.
1 parent af1020e commit de560d2
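
The rename is mechanical throughout: TVMScript blocks move from T.block to T.sblock, and the corresponding IR nodes from tir.Block / tir.BlockRealize to tir.SBlock / tir.SBlockRealize. A minimal sketch of the script-side change (the kernel is illustrative, not from this repo, and runs only against the 3rdparty/tvm revision pinned by this commit):

from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    for i in T.serial(8):
        with T.sblock("compute"):  # formerly: with T.block("compute")
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi] + T.float32(1)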

21 files changed: +126 −126 lines

3rdparty/tvm

Submodule tvm updated 507 files

python/mlc_llm/compiler_pass/attach_logit_processor.py

Lines changed: 5 additions & 5 deletions
@@ -101,7 +101,7 @@ def _apply_logit_bias_inplace(

     for p0 in T.thread_binding(0, (num_token + tx - 1) // tx, "blockIdx.x"):
         for p1 in T.thread_binding(0, tx, "threadIdx.x"):
-            with T.block("block"):
+            with T.sblock("block"):
                 vp = T.axis.spatial(num_token, p0 * tx + p1)
                 T.where(p0 * tx + p1 < num_token)
                 logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp]
@@ -139,7 +139,7 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local
     penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")

     for token in T.serial(num_token):
-        with T.block("block"):
+        with T.sblock("block"):
             vp = T.axis.spatial(num_token, token)
             logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
                 penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
@@ -189,7 +189,7 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local

     for p0 in T.thread_binding(0, (num_token + tx - 1) // tx, "blockIdx.x"):
         for p1 in T.thread_binding(0, tx, "threadIdx.x"):
-            with T.block("block"):
+            with T.sblock("block"):
                 vp = T.axis.spatial(num_token, p0 * tx + p1)
                 T.where(p0 * tx + p1 < num_token)
                 # Penalties: (presence_penalty, frequency_penalty, repetition_penalty)
@@ -230,7 +230,7 @@ def _apply_bitmask_inplace(
     bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")

     for token in T.serial(num_seq * vocab_size):
-        with T.block("block"):
+        with T.sblock("block"):
             vs = T.axis.spatial(num_seq, (token) // vocab_size)
             vv = T.axis.spatial(vocab_size, (token) % vocab_size)

@@ -272,7 +272,7 @@ def _apply_bitmask_inplace(

     for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + tx - 1) // tx, "blockIdx.x"):
         for fused_s_v_1 in T.thread_binding(0, tx, "threadIdx.x"):
-            with T.block("block"):
+            with T.sblock("block"):
                 vs = T.axis.spatial(num_seq, (fused_s_v_0 * tx + fused_s_v_1) // vocab_size)
                 vv = T.axis.spatial(vocab_size, (fused_s_v_0 * tx + fused_s_v_1) % vocab_size)
                 T.where(fused_s_v_0 * tx + fused_s_v_1 < num_seq * vocab_size)
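
These kernels flatten a 1-D workload over blockIdx.x × threadIdx.x and mask the tail iteration with T.where. A toy instance of the same guarded pattern under T.sblock (illustrative sizes; assumes the pinned 3rdparty/tvm):

from tvm.script import tir as T

@T.prim_func
def inc(A: T.Buffer((100,), "float32")):
    for p0 in T.thread_binding(0, 4, "blockIdx.x"):  # ceil(100 / 32) blocks
        for p1 in T.thread_binding(0, 32, "threadIdx.x"):
            with T.sblock("block"):
                vi = T.axis.spatial(100, p0 * 32 + p1)
                T.where(p0 * 32 + p1 < 100)  # mask the out-of-range tail threads
                A[vi] = A[vi] + T.float32(1)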

python/mlc_llm/compiler_pass/attach_sampler.py

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ def full(var_result: T.handle, value: T.int32):
     batch_size = T.int32(is_size_var=True)
     result = T.match_buffer(var_result, (batch_size, 1), "int32")
     for i in T.serial(batch_size):
-        with T.block("block"):
+        with T.sblock("block"):
             vi = T.axis.spatial(batch_size, i)
             result[vi, 0] = value

@@ -305,7 +305,7 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument
     top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,), "float32")
     top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
     for i in T.serial(num_positions + num_samples):
-        with T.block("block"):
+        with T.sblock("block"):
             vi = T.axis.spatial(num_positions + num_samples, i)
             if vi < num_positions:
                 row = T.floordiv(top_prob_offsets[vi], vocab_size)

python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py

Lines changed: 8 additions & 8 deletions
@@ -131,7 +131,7 @@ def chunk_lse( # pylint: disable=too-many-locals
     temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32")

     for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
-        with T.block("pad"):
+        with T.sblock("pad"):
             v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
             A_pad[v0, v1, v2] = T.Select(
                 v1 * T.int64(chunk_size) + v2
@@ -144,13 +144,13 @@ def chunk_lse( # pylint: disable=too-many-locals
                 T.min_value("float32"),
             )
     for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
-        with T.block("max"):
+        with T.sblock("max"):
             v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
             with T.init():
                 temp_max[v0, v1] = T.min_value("float32")
             temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
     for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
-        with T.block("sum_exp"):
+        with T.sblock("sum_exp"):
             v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
             with T.init():
                 temp_sum[v0, v1] = T.float32(0)
@@ -165,7 +165,7 @@ def chunk_lse( # pylint: disable=too-many-locals
                 T.float32(0),
             )
     for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
-        with T.block("log"):
+        with T.sblock("log"):
             v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
             chunked_sum[v0, v1] = T.Select(
                 temperature[v0] > T.float32(1e-5),
@@ -194,13 +194,13 @@ def softmax_with_chunked_sum(
     temp_max = T.alloc_buffer((batch_size,), dtype="float32")
     temp_sum = T.alloc_buffer((batch_size,), dtype="float32")
     for l0, l1 in T.grid(batch_size, num_chunks):
-        with T.block("max"):
+        with T.sblock("max"):
             v0, v1 = T.axis.remap("SR", [l0, l1])
             with T.init():
                 temp_max[v0] = T.min_value("float32")
             temp_max[v0] = T.max(temp_max[v0], chunked_max[v0, v1])
     for l0, l1 in T.grid(batch_size, num_chunks):
-        with T.block("sum_exp"):
+        with T.sblock("sum_exp"):
             v0, v1 = T.axis.remap("SR", [l0, l1])
             with T.init():
                 temp_sum[v0] = T.float32(0)
@@ -210,7 +210,7 @@ def softmax_with_chunked_sum(
             T.cast(chunked_max[v0, v1] == temp_max[v0], "float32") * chunked_sum[v0, v1],
         )
     for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
-        with T.block("log_pad"):
+        with T.sblock("log_pad"):
             v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
             if v1 * T.int64(chunk_size) + v2 < vocab_size:
                 softmax[v0, v1 * T.int64(chunk_size) + v2] = T.Select(
@@ -248,7 +248,7 @@ def apply_gpu_schedule(target, sch):
     sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)

     for block_name in ["sum_exp", "max"]:
-        block = sch.get_block(block_name)
+        block = sch.get_sblock(block_name)
         sch.set_scope(block, buffer_index=0, storage_scope="shared")
         sch.compute_at(block, bx)
         r_loop = sch.get_loops(block)[-1]
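
The schedule-side rename mirrors the script-side one: Schedule.get_block becomes Schedule.get_sblock. A minimal sketch (the toy kernel and block name are illustrative; assumes the pinned 3rdparty/tvm):

import tvm
from tvm.script import tir as T

@T.prim_func
def fill_zero(A: T.Buffer((128,), "float32")):
    for i in range(128):
        with T.sblock("fill"):
            vi = T.axis.spatial(128, i)
            A[vi] = T.float32(0)

sch = tvm.tir.Schedule(tvm.IRModule({"main": fill_zero}))
blk = sch.get_sblock("fill")  # formerly sch.get_block("fill")
(i,) = sch.get_loops(blk)
bx, tx = sch.split(i, factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")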

python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
     indices = T.match_buffer(var_indices, (batch_size,), "int32")
     dst = T.match_buffer(var_dst, (m, n), dtype)
     for b, j in T.grid(batch_size, n):
-        with T.block("scatter_2d"):
+        with T.sblock("scatter_2d"):
             vb, vj = T.axis.remap("SS", [b, j])
             dst[indices[vb], vj] = src[vb, vj]

@@ -64,7 +64,7 @@ def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
     indices = T.match_buffer(var_indices, (batch_size,), "int32")
     dst = T.match_buffer(var_dst, (batch_size, n), dtype)
     for b, j in T.grid(batch_size, n):
-        with T.block("gather_2d"):
+        with T.sblock("gather_2d"):
             vb, vj = T.axis.remap("SS", [b, j])
             dst[vb, vj] = src[indices[vb], vj]

python/mlc_llm/compiler_pass/fuse_add_norm.py

Lines changed: 12 additions & 12 deletions
@@ -41,25 +41,25 @@ def decode_add_rms( # pylint: disable=too-many-locals
         annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
     ):
         for i in range(add_local_size):
-            with T.block("T_add"):
+            with T.sblock("T_add"):
                 bx = T.axis.spatial(batch_size, v_bx)
                 h = T.axis.spatial(hidden_size, i * TX + v_tx)
                 add_local[h // TX] = A[bx, 0, h] + B[bx, 0, h]
-            with T.block("T_write_back"):
+            with T.sblock("T_write_back"):
                 bx = T.axis.spatial(batch_size, v_bx)
                 v_ax1 = T.axis.spatial(1, 0)
                 h = T.axis.spatial(hidden_size, i * TX + v_tx)
                 add[bx, v_ax1, h] = add_local[h // TX]
-        with T.block("T_multiply_red_rf_init"):
+        with T.sblock("T_multiply_red_rf_init"):
             tx, bx = T.axis.remap("SS", [v_tx, v_bx])
             sum_local[tx, bx, 0] = T.float32(0)
         for v_i, _j in T.grid(add_local_size, 1):
-            with T.block("T_multiply_red_rf_update"):
+            with T.sblock("T_multiply_red_rf_update"):
                 tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                 sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i])
         for _j in range(1):
             for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
-                with T.block("T_multiply_red"):
+                with T.sblock("T_multiply_red"):
                     tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                     T.reads(sum_local[tx, bx, 0])
                     T.writes(sum_shared[bx, 0])
@@ -68,7 +68,7 @@ def decode_add_rms( # pylint: disable=too-many-locals
                     sum_shared[bx, 0] += sum_local[tx, bx, 0]
         for i in range(add_local_size):
             for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
-                with T.block("T_cast_2"):
+                with T.sblock("T_cast_2"):
                     bx = T.axis.spatial(batch_size, v_bx)
                     h = T.axis.spatial(hidden_size, i * TX + v_tx_2)
                     O[bx, 0, h] = T.cast(
@@ -109,31 +109,31 @@ def prefill_add_rms( # pylint: disable=too-many-locals
         annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1},
     ):
         for v_i in range(add_local_size):
-            with T.block("T_add"):
+            with T.sblock("T_add"):
                 bx = T.axis.spatial(seq_len, v_bx)
                 h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
                 add_local[h // TX] = A[0, bx, h] + B[0, bx, h]
-            with T.block("T_write_back"):
+            with T.sblock("T_write_back"):
                 bx = T.axis.spatial(seq_len, v_bx)
                 h = T.axis.spatial(hidden_size, v_i * TX + v_tx)
                 add[0, bx, h] = add_local[h // TX]
-        with T.block("T_multiply_red_rf_init"):
+        with T.sblock("T_multiply_red_rf_init"):
             tx, bx = T.axis.remap("SS", [v_tx, v_bx])
             sum_local[tx, 0, bx] = T.float32(0)
         for v_i, _j in T.grid(add_local_size, 1):
-            with T.block("T_multiply_red_rf_update"):
+            with T.sblock("T_multiply_red_rf_update"):
                 tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                 sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i])
         for _j in range(1):
             for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
-                with T.block("T_multiply_red"):
+                with T.sblock("T_multiply_red"):
                     tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                     with T.init():
                         sum_shared[0, bx] = T.float32(0)
                     sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
         for v_i in range(add_local_size):
             for v_tx_2 in T.thread_binding(TX, thread="threadIdx.x"):
-                with T.block("T_cast_2"):
+                with T.sblock("T_cast_2"):
                     bx = T.axis.spatial(seq_len, v_bx)
                     v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2)
                     O[0, bx, v1] = T.cast(
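
The rf_init / rf_update / final-reduction trio above is a cross-thread sum-of-squares reduction. A toy reduction block with the same T.init idiom, under the renamed T.sblock (illustrative shapes; assumes the pinned 3rdparty/tvm):

from tvm.script import tir as T

@T.prim_func
def row_sum(A: T.Buffer((4, 8), "float32"), S: T.Buffer((4,), "float32")):
    for i, k in T.grid(4, 8):
        with T.sblock("sum"):
            vi, vk = T.axis.remap("SR", [i, k])  # spatial row, reduction column
            with T.init():
                S[vi] = T.float32(0)  # executed once per reduction group
            S[vi] = S[vi] + A[vi, vk]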

python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py

Lines changed: 3 additions & 3 deletions
@@ -74,7 +74,7 @@ def visit_call_( # pylint: disable=arguments-renamed
             or not isinstance(dequantize_tir_func.body.block.body, tir.SeqStmt)
             or len(dequantize_tir_func.body.block.body) != 2
             or not isinstance(dequantize_tir_func.body.block.body[1], tir.For)
-            or not isinstance(dequantize_tir_func.body.block.body[1].body.body, tir.BlockRealize)
+            or not isinstance(dequantize_tir_func.body.block.body[1].body.body, tir.SBlockRealize)
             or dequantize_tir_func.body.block.body[1].body.body.block.name_hint != "T_transpose"
         ):
             return call
@@ -85,10 +85,10 @@ def visit_call_( # pylint: disable=arguments-renamed
         new_func_buffers[-1] = dequantize_tir_func.body.block.alloc_buffers[0]
         new_func = tir.PrimFunc(
             params=new_func_buffers,
-            body=tir.BlockRealize(
+            body=tir.SBlockRealize(
                 iter_values=[],
                 predicate=True,
-                block=tir.Block(
+                block=tir.SBlock(
                     iter_vars=[],
                     reads=[],
                     writes=[],
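
Pattern checks like the one above walk the PrimFunc body: func.body is the realize node over the root block, and func.body.block is the root block itself. A hedged sketch of that navigation (the helper name is ours, not this repo's; assumes the pinned 3rdparty/tvm):

from tvm import tir

def root_block_of(func: tir.PrimFunc) -> tir.SBlock:
    # A well-formed PrimFunc body is an SBlockRealize wrapping the root block.
    assert isinstance(func.body, tir.SBlockRealize)
    return func.body.block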

python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py

Lines changed: 3 additions & 3 deletions
@@ -93,7 +93,7 @@ def remove_global_buf_alloc(
     func: tir.PrimFunc,
 ) -> Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]:
     """Remove the global buffer allocation for a given TIR PrimFunc."""
-    assert isinstance(func.body, tir.BlockRealize)
+    assert isinstance(func.body, tir.SBlockRealize)
     params = list(func.params)
     buffer_map = dict(func.buffer_map)
     tensor_sinfo = []
@@ -124,7 +124,7 @@ def remove_global_buf_alloc(
     assert len(prev_root_block.match_buffers) == 0
     assert prev_root_block.name_hint == "root"
     assert prev_root_block.init is None
-    root_block = tir.Block(
+    root_block = tir.SBlock(
         iter_vars=[],
         reads=[],
         writes=[],
@@ -136,7 +136,7 @@ def remove_global_buf_alloc(

     updated_func = tir.PrimFunc(
         params=params,
-        body=tir.BlockRealize(iter_values=[], predicate=True, block=root_block),
+        body=tir.SBlockRealize(iter_values=[], predicate=True, block=root_block),
         ret_type=func.ret_type,
         buffer_map=buffer_map,
         attrs=func.attrs,

python/mlc_llm/compiler_pass/low_batch_specialization.py

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ def transform_module(
                 low_batch_funcs[i].body,
                 body,
             )
-            body = tir.Block([], [], [], "root", body)
-            body = tir.BlockRealize([], True, body)
+            body = tir.SBlock([], [], [], "root", body)
+            body = tir.SBlockRealize([], True, body)
             new_func = func.with_body(body)
             new_func = new_func.with_attr("tir.is_scheduled", 1)
             new_func = new_func.with_attr("tir.HoistIfThenElseExprWithBlock", 1)
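
Passes that build TIR programmatically wrap a statement in a root block the same way; only the constructor names change (tir.Block → tir.SBlock, tir.BlockRealize → tir.SBlockRealize). A minimal sketch with a placeholder body (assumes the pinned 3rdparty/tvm):

from tvm import tir

stmt = tir.Evaluate(0)  # placeholder for a real function body
root = tir.SBlock(      # formerly tir.Block
    iter_vars=[], reads=[], writes=[], name_hint="root", body=stmt
)
wrapped = tir.SBlockRealize(iter_values=[], predicate=True, block=root)  # formerly tir.BlockRealize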

python/mlc_llm/model/phi3v/phi3v_image.py

Lines changed: 3 additions & 3 deletions
@@ -98,7 +98,7 @@ def dyn_repeat_4d_tensor_func( # pylint disable=too-many-locals
     for n_idx in T.thread_binding(n * ch0, thread="blockIdx.x"):
         for c_idx in T.thread_binding(c * ch1, thread="blockIdx.y"):
             for h_idx, w_idx in T.grid(h * ch2, w * ch3):
-                with T.block("dyn_repeat_4d_tensor"):
+                with T.sblock("dyn_repeat_4d_tensor"):
                     T.reads(input_tensor_buf[n_idx, c_idx, h_idx, w_idx])
                     T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
                     out_buf[n_idx, c_idx, h_idx, w_idx] = input_tensor_buf[
@@ -129,7 +129,7 @@ def dyn_concate_dim_2_func(input_1: T.handle, input_2: T.handle, output: T.handl
     for n_idx in T.thread_binding(n, thread="blockIdx.x"):
         for c_idx in T.thread_binding(c, thread="blockIdx.y"):
             for h_idx, w_idx in T.grid(h1 + h2, w):
-                with T.block("dyn_concate_dim_2"):
+                with T.sblock("dyn_concate_dim_2"):
                     T.reads(input_1_buf[n_idx, c_idx, h_idx, w_idx])
                     T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
                     if h_idx < h1:
@@ -167,7 +167,7 @@ def dyn_concate_dim_1_func(input_1: T.handle, input_2: T.handle, output: T.handl

     for c_idx in T.thread_binding(c, thread="blockIdx.y"):
         for h_idx, w_idx in T.grid(h1 + h2, w):
-            with T.block("dyn_concate_dim_1"):
+            with T.sblock("dyn_concate_dim_1"):
                 T.reads(input_1_buf[c_idx, h_idx, w_idx])
                 T.writes(out_buf[c_idx, h_idx, w_idx])
                 if h_idx < h1:
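
These kernels spell out their access regions with T.reads / T.writes rather than letting the parser infer them. A minimal sketch of the same pattern under the renamed T.sblock (the copy kernel and shapes are illustrative; assumes the pinned 3rdparty/tvm):

from tvm.script import tir as T

@T.prim_func
def copy_2d(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
    for i, j in T.grid(4, 4):
        with T.sblock("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])   # explicit read region
            T.writes(B[vi, vj])  # explicit write region
            B[vi, vj] = A[vi, vj]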
