Skip to content

Commit f2ffe95

Browse files
Fix index term order in loop fusion
1 parent f532b89 commit f2ffe95

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/tir/schedule/primitive/loop_transformation.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -929,8 +929,8 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
929929
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
930930
substitute_value.Set(i, is_one(loops[i]->extent)
931931
? 0
932-
: floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
933-
lower = lower * loops[i]->extent;
932+
: floordiv(floormod(fused_var, loops[i]->extent * lower), lower));
933+
lower = loops[i]->extent * lower;
934934
}
935935
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
936936
Stmt new_stmt = loops.back()->body;
@@ -947,7 +947,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
947947
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
948948
// Step 3. Generate a loop to replace the original loops
949949
PrimExpr fused_extent = 1;
950-
for (int i = 0; i < n; i++) {
950+
for (int i = n - 1; i >= 0; --i) {
951951
fused_extent *= loops[i]->extent;
952952
}
953953
fused_extent = analyzer.Simplify(fused_extent);

tests/python/tir-schedule/test_tir_schedule_split_fuse.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,5 +822,38 @@ def before(a: T.handle):
822822
assert warning_msg in captured
823823

824824

825+
826+
def test_fused_symbolic_2D_tiling():
    """Fuse the outer tiles of a symbolically shaped 2D loop nest.

    Loop ``i`` (extent ``M``) is split by 64 and loop ``j`` (extent ``N``) by 16,
    the two outer tile loops are reordered to the front and fused, and the
    scheduled function is compared structurally against ``expected``. The
    schedule trace is then round-tripped from ``before``.
    """

    # Input: element-wise doubling over an (M, N) buffer with symbolic M and N.
    @T.prim_func
    def before(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
        A = T.match_buffer(a, (M, N))
        B = T.match_buffer(b, (M, N))
        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

    # Reference result: the fused extent is written with the j-tile count,
    # (N + 15) // 16, as the leading factor, and the same term is the
    # divisor/modulus used to recover the two outer tile indices.
    @T.prim_func
    def expected(a: T.handle, b: T.handle, M: T.int32, N: T.int32):
        A = T.match_buffer(a, (M, N))
        B = T.match_buffer(b, (M, N))
        for i_0_j_0_fused, i_1, j_1 in T.grid((N + 15) // 16 * ((M + 63) // 64), 64, 16):
            with T.block("B"):
                vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1)
                vj = T.axis.spatial(N, i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1)
                T.where(i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1 < M and i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1 < N)
                B[vi, vj] = A[vi, vj] * T.float32(2.0)

    schedule = tir.Schedule(before, debug_mask="all")
    loop_i, loop_j = schedule.get_loops(schedule.get_block("B"))

    # Tile i first, then j, so the primitives are recorded in the same order.
    outer_i, inner_i = schedule.split(loop_i, factors=[None, 64])
    outer_j, inner_j = schedule.split(loop_j, factors=[None, 16])

    # Bring both outer tile loops to the front, then collapse them into one loop.
    schedule.reorder(outer_i, outer_j, inner_i, inner_j)
    schedule.fuse(outer_i, outer_j)

    assert_structural_equal_ignore_global_symbol(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=before)
856+
857+
825858
if __name__ == "__main__":
826859
tvm.testing.main()

0 commit comments

Comments
 (0)