Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2113,7 +2113,8 @@ class IterMapToExprNormalizer : public ExprMutator {
}
if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
return source * expr->scale;
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent) ||
analyzer_->CanProve(expr->source->extent == expr->extent * expr->lower_factor)) {
// Simplify if `expr` is always 0. The 2nd condition guarantees that we do not aggressively
// simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis
// like tensorization.
Expand All @@ -2122,8 +2123,8 @@ class IterMapToExprNormalizer : public ExprMutator {
}
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
expr->scale;
PrimExpr full_extent = analyzer_->canonical_simplify(expr->extent * expr->lower_factor);
return floordiv(floormod(source, full_extent), expr->lower_factor) * expr->scale;
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -922,15 +922,16 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
bits = std::max(bits, loops[i]->loop_var.dtype().bits());
}
suffix += "_fused";

Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits));
ffi::Array<PrimExpr> substitute_value;
substitute_value.resize(loops.size());
PrimExpr lower = 1;
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
substitute_value.Set(i, is_one(loops[i]->extent)
? 0
: floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
lower = lower * loops[i]->extent;
PrimExpr next_lower = analyzer.canonical_simplify(loops[i]->extent * lower);
substitute_value.Set(
i, is_one(loops[i]->extent) ? 0 : floordiv(floormod(fused_var, next_lower), lower));
lower = next_lower;
}
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
Stmt new_stmt = loops.back()->body;
Expand All @@ -947,7 +948,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
// Step 3. Generate a loop to replace the original loops
PrimExpr fused_extent = 1;
for (int i = 0; i < n; i++) {
for (int i = 0; i < n; ++i) {
fused_extent *= loops[i]->extent;
}
fused_extent = analyzer.Simplify(fused_extent);
Expand Down
4 changes: 2 additions & 2 deletions tests/python/dlight/test_gpu_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_tabl
for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen + 1023) // 1024, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
with T.block("block"):
v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // (nhead * seqlen))
v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (nhead * seqlen) // seqlen)
v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % seqlen)
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < nlayer * nhead * seqlen)
T.reads(pages[page_table_values[page_table_indptr[seq_id] + v2 // page_size], v0, v1, v2 % page_size], page_table_values[page_table_indptr[seq_id] + v2 // page_size], page_table_indptr[seq_id])
Expand Down
8 changes: 3 additions & 5 deletions tests/python/dlight/test_gpu_general_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
with T.block("max"):
v0 = T.axis.spatial(
batch_size,
ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
ax0_ax1_fused // num_chunks + ax0,
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
v2 = T.axis.reduce(
Expand Down Expand Up @@ -645,7 +645,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
with T.block("sum_exp"):
v0 = T.axis.spatial(
batch_size,
ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
ax0_ax1_fused // num_chunks + ax0,
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
v2 = T.axis.reduce(
Expand Down Expand Up @@ -678,9 +678,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
},
):
with T.block("log"):
v0 = T.axis.spatial(
batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks
)
v0 = T.axis.spatial(batch_size, ax0_ax1_fused // num_chunks)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1)
T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5
for ax0_ax1_ax2_ax3_fused in range(rh_0 % 2 * 96 + 96):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (rh_0 % 2 + 1)) // 96)
v1 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (rh_0 % 2 * 96 + 96) // 96)
v2 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
v3 = T.axis.spatial(512, rc_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
Expand Down
35 changes: 35 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,5 +822,40 @@ def before(a: T.handle):
assert warning_msg in captured


def test_fused_symbolic_2D_tiling():
    """Regression test: fusing the outer loops of a symbolically-split 2D tiling.

    Splits dynamic loops i (tile 64) and j (tile 16), reorders to a tiled
    order, then fuses the two outer loops. Verifies that the resulting fused
    iterator is decomposed back into `vi`/`vj` with consistently ordered
    extent products (see `expected` below) and that a `T.where` predicate
    guards the padded tail.
    """

    @T.prim_func
    def before(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
        # Elementwise doubling over a dynamic (M, N) iteration space.
        A = T.match_buffer(a, (M, N))
        B = T.match_buffer(b, (M, N))
        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

    @T.prim_func
    def expected(a: T.handle, b: T.handle, M: T.int32, N: T.int32):
        A = T.match_buffer(a, (M, N))
        B = T.match_buffer(b, (M, N))
        # Fused outer extent is ceildiv(M, 64) * ceildiv(N, 16); inner loops
        # iterate the 64x16 tile.
        for i_0_j_0_fused, i_1, j_1 in T.grid(((M + 63) // 64) * ((N + 15) // 16), 64, 16):
            with T.block("B"):
                # Recover the original axes: floordiv selects the i-tile,
                # floormod the j-tile of the fused iterator.
                vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1)
                vj = T.axis.spatial(N, i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1)
                # Guard against the padded tail when M or N is not a
                # multiple of its tile size.
                T.where(
                    i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1 < M
                    and i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1 < N
                )
                B[vi, vj] = A[vi, vj] * T.float32(2.0)

    # split -> reorder -> fuse with symbolic (None) outer factors exercises
    # the symbolic-extent normalization path under test.
    sch = tir.Schedule(before, debug_mask="all")
    block_b = sch.get_block("B")
    i, j = sch.get_loops(block_b)
    i0, i1 = sch.split(i, factors=[None, 64])
    j0, j1 = sch.split(j, factors=[None, 16])
    sch.reorder(i0, j0, i1, j1)
    sch.fuse(i0, j0)
    # Structural equality (ignoring global symbol) pins the exact lowered IR;
    # the trace roundtrip checks the schedule is replayable.
    assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=before)


if __name__ == "__main__":
tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), v
for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) // T.int64(262144)):
with T.block("T_broadcast_to"):
v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * x_0) // x_1)
v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) // x_1)
v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
T.where((ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
Expand Down
Loading