diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 3de431fb9574..f17596dd8e0d 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2113,7 +2113,8 @@ class IterMapToExprNormalizer : public ExprMutator {
     }
     if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
       return source * expr->scale;
-    } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
+    } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent) ||
+               analyzer_->CanProve(expr->source->extent == expr->extent * expr->lower_factor)) {
       // Simplify if `expr` is always 0. The 2nd condition guarantees that we do not aggressively
       // simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis
       // like tensorization.
@@ -2122,8 +2123,8 @@
       }
       return floordiv(source, expr->lower_factor) * expr->scale;
     } else {
-      return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
-             expr->scale;
+      PrimExpr full_extent = analyzer_->canonical_simplify(expr->extent * expr->lower_factor);
+      return floordiv(floormod(source, full_extent), expr->lower_factor) * expr->scale;
     }
   }
 
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index b2c64e65e568..6b6307e2b726 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -922,15 +922,16 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
     bits = std::max(bits, loops[i]->loop_var.dtype().bits());
   }
   suffix += "_fused";
+  Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits));
   ffi::Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr lower = 1;
   for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
-    substitute_value.Set(i, is_one(loops[i]->extent)
-                                ? 0
-                                : floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
-    lower = lower * loops[i]->extent;
+    PrimExpr next_lower = analyzer.canonical_simplify(loops[i]->extent * lower);
+    substitute_value.Set(
+        i, is_one(loops[i]->extent) ? 0 : floordiv(floormod(fused_var, next_lower), lower));
+    lower = next_lower;
   }
   substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
   Stmt new_stmt = loops.back()->body;
@@ -947,7 +948,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
       SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
   // Step 3. Generate a loop to replace the original loops
   PrimExpr fused_extent = 1;
-  for (int i = 0; i < n; i++) {
+  for (int i = 0; i < n; ++i) {
     fused_extent *= loops[i]->extent;
   }
   fused_extent = analyzer.Simplify(fused_extent);
diff --git a/tests/python/dlight/test_gpu_fallback.py b/tests/python/dlight/test_gpu_fallback.py
index a4eaa3ad748c..07be6067b7f2 100644
--- a/tests/python/dlight/test_gpu_fallback.py
+++ b/tests/python/dlight/test_gpu_fallback.py
@@ -161,8 +161,8 @@ def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_tabl
         for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen + 1023) // 1024, thread="blockIdx.x"):
             for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                 with T.block("block"):
-                    v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
-                    v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
+                    v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // (nhead * seqlen))
+                    v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (nhead * seqlen) // seqlen)
                     v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % seqlen)
                     T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < nlayer * nhead * seqlen)
                     T.reads(pages[page_table_values[page_table_indptr[seq_id] + v2 // page_size], v0, v1, v2 % page_size], page_table_values[page_table_indptr[seq_id] + v2 // page_size], page_table_indptr[seq_id])
diff --git a/tests/python/dlight/test_gpu_general_reduction.py b/tests/python/dlight/test_gpu_general_reduction.py
index aafe76f900e4..941fc2ee37b4 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -615,7 +615,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
                     with T.block("max"):
                         v0 = T.axis.spatial(
                             batch_size,
-                            ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
+                            ax0_ax1_fused // num_chunks + ax0,
                         )
                         v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                         v2 = T.axis.reduce(
@@ -645,7 +645,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
                     with T.block("sum_exp"):
                         v0 = T.axis.spatial(
                             batch_size,
-                            ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
+                            ax0_ax1_fused // num_chunks + ax0,
                         )
                         v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                         v2 = T.axis.reduce(
@@ -678,9 +678,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
                 },
             ):
                 with T.block("log"):
-                    v0 = T.axis.spatial(
-                        batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks
-                    )
+                    v0 = T.axis.spatial(batch_size, ax0_ax1_fused // num_chunks)
                     v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
                     v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1)
                     T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
diff --git a/tests/python/meta_schedule/test_meta_schedule_space_cuda.py b/tests/python/meta_schedule/test_meta_schedule_space_cuda.py
index d05ade960164..683721ab8e82 100644
--- a/tests/python/meta_schedule/test_meta_schedule_space_cuda.py
+++ b/tests/python/meta_schedule/test_meta_schedule_space_cuda.py
@@ -722,7 +722,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5
                     for ax0_ax1_ax2_ax3_fused in range(rh_0 % 2 * 96 + 96):
                         with T.block("PadInput_shared"):
                             v0 = T.axis.spatial(1, 0)
-                            v1 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (rh_0 % 2 + 1)) // 96)
+                            v1 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (rh_0 % 2 * 96 + 96) // 96)
                             v2 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
                             v3 = T.axis.spatial(512, rc_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
                             T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
index f09f7417baf6..b4e55355b271 100644
--- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
@@ -822,5 +822,40 @@ def before(a: T.handle):
     assert warning_msg in captured
 
 
+def test_fused_symbolic_2D_tiling():
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
+        A = T.match_buffer(a, (M, N))
+        B = T.match_buffer(b, (M, N))
+        for i, j in T.grid(M, N):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+
+    @T.prim_func
+    def expected(a: T.handle, b: T.handle, M: T.int32, N: T.int32):
+        A = T.match_buffer(a, (M, N))
+        B = T.match_buffer(b, (M, N))
+        for i_0_j_0_fused, i_1, j_1 in T.grid(((M + 63) // 64) * ((N + 15) // 16), 64, 16):
+            with T.block("B"):
+                vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1)
+                vj = T.axis.spatial(N, i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1)
+                T.where(
+                    i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1 < M
+                    and i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1 < N
+                )
+                B[vi, vj] = A[vi, vj] * T.float32(2.0)
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j = sch.get_loops(block_b)
+    i0, i1 = sch.split(i, factors=[None, 64])
+    j0, j1 = sch.split(j, factors=[None, 16])
+    sch.reorder(i0, j0, i1, j1)
+    sch.fuse(i0, j0)
+    assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
index 33f3933cb4e5..936282b94966 100644
--- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py
+++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
@@ -54,7 +54,7 @@ def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), v
         for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
             for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) // T.int64(262144)):
                 with T.block("T_broadcast_to"):
-                    v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * x_0) // x_1)
+                    v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) // x_1)
                     v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
                     T.where((ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
                     T_broadcast_to[v_ax0, v_ax1] = rxplaceholder[v_ax0, T.int64(0)]
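
Note: every updated test expectation above relies on the same identity: for a fused iterator f that provably stays inside [0, A*B), the outer index f % (A*B) // B simplifies to f // B. The arith change lets the prover accept the product in either operand order (lower_factor * extent or extent * lower_factor), which is what matters for symbolic extents such as ((M + 63) // 64) * ((N + 15) // 16), where canonicalization can reorder the factors. Below is a minimal standalone Python check of that identity; it is not part of the patch, and the helper names are illustrative only.

# Standalone sketch (not part of the patch): exhaustively verifies the
# index rewrite over small extents. Helper names are hypothetical.
def old_outer_index(f, a, b):
    # Pre-patch lowering of the outer fused axis: floormod, then floordiv.
    return f % (a * b) // b

def new_outer_index(f, a, b):
    # Post-patch lowering: the floormod is dropped once the analyzer can
    # prove the source extent equals a * b (in either operand order).
    return f // b

for a in range(1, 9):           # outer extent, e.g. (M + 63) // 64
    for b in range(1, 9):       # inner extent, e.g. (N + 15) // 16
        for f in range(a * b):  # fused iterator stays within [0, a*b)
            assert old_outer_index(f, a, b) == new_outer_index(f, a, b)
print("f % (a*b) // b == f // b holds for all f in [0, a*b)")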