diff --git a/apps/android_rpc/tests/android_rpc_test.py b/apps/android_rpc/tests/android_rpc_test.py index b1548df3e177..9fbb66985f11 100644 --- a/apps/android_rpc/tests/android_rpc_test.py +++ b/apps/android_rpc/tests/android_rpc_test.py @@ -58,7 +58,7 @@ def test_rpc_module(): mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]).with_attr("global_symbol", "myadd")) sch = tvm.tir.Schedule(mod) - (x,) = sch.get_loops(block=sch.get_block("B")) + (x,) = sch.get_loops(block=sch.get_sblock("B")) xo, xi = sch.split(i, [None, 32]) sch.bind(xo, "blockIdx.x") sch.bind(xi, "threadIdx.x") diff --git a/apps/ios_rpc/tests/ios_rpc_test.py b/apps/ios_rpc/tests/ios_rpc_test.py index df850812e527..eb55d1b78962 100644 --- a/apps/ios_rpc/tests/ios_rpc_test.py +++ b/apps/ios_rpc/tests/ios_rpc_test.py @@ -52,7 +52,7 @@ def test_rpc_module(host, port, key, mode): temp = utils.tempdir() mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]).with_attr("global_symbol", "myadd")) sch = tvm.tir.Schedule(mod) - (i,) = sch.get_loops(block=sch.get_block("B")) + (i,) = sch.get_loops(block=sch.get_sblock("B")) i0, i1 = sch.split(i, [None, 32]) sch.bind(i0, "blockIdx.x") sch.bind(i1, "threadIdx.x") diff --git a/docs/deep_dive/relax/learning.rst b/docs/deep_dive/relax/learning.rst index 702b0e0a9f29..6c16ff944b10 100644 --- a/docs/deep_dive/relax/learning.rst +++ b/docs/deep_dive/relax/learning.rst @@ -135,13 +135,13 @@ for the end-to-end model execution. The code block below shows a TVMScript imple Z = T.match_buffer(z, (M, N), "float32") Y = T.alloc_buffer((M, N), "float32") for i, j, k in T.grid(M, N, K): - with T.block("Y"): + with T.sblock("Y"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[v_i, v_j] = T.float32(0.0) Y[v_i, v_j] = Y[v_i, v_j] + X[v_i, v_k] * W[v_k, v_j] for i, j in T.grid(M, N): - with T.block("Z"): + with T.sblock("Z"): v_i, v_j = T.axis.remap("SS", [i, j]) Z[v_i, v_j] = Y[v_i, v_j] + B[v_j] @@ -151,7 +151,7 @@ for the end-to-end model execution. The code block below shows a TVMScript imple X = T.match_buffer(x, (M, N), "float32") Y = T.match_buffer(y, (M, N), "float32") for i, j in T.grid(M, N): - with T.block("Y"): + with T.sblock("Y"): v_i, v_j = T.axis.remap("SS", [i, j]) Y[v_i, v_j] = T.max(X[v_i, v_j], T.float32(0.0)) diff --git a/docs/deep_dive/relax/tutorials/relax_creation.py b/docs/deep_dive/relax/tutorials/relax_creation.py index f6278e3b65b1..7b2d9a1ad2b0 100644 --- a/docs/deep_dive/relax/tutorials/relax_creation.py +++ b/docs/deep_dive/relax/tutorials/relax_creation.py @@ -76,7 +76,7 @@ def relu(x: T.handle, y: T.handle): X = T.match_buffer(x, (n, m), "float32") Y = T.match_buffer(y, (n, m), "float32") for i, j in T.grid(n, m): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) Y[vi, vj] = T.max(X[vi, vj], T.float32(0)) @@ -170,13 +170,13 @@ def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle): B = T.match_buffer(b, (N,), "float32") Z = T.match_buffer(z, (M, N), "float32") for i, j, k in T.grid(M, N, K): - with T.block("linear"): + with T.sblock("linear"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Z[vi, vj] = 0 Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk] for i, j in T.grid(M, N): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) Z[vi, vj] = Z[vi, vj] + B[vj] diff --git a/docs/deep_dive/tensor_ir/abstraction.rst b/docs/deep_dive/tensor_ir/abstraction.rst index a832fef995f1..86536b1dea6f 100644 --- a/docs/deep_dive/tensor_ir/abstraction.rst +++ b/docs/deep_dive/tensor_ir/abstraction.rst @@ -38,7 +38,7 @@ the compute statements themselves. C: T.Buffer((128,), "float32"), ) -> None: for i in range(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) C[vi] = A[vi] + B[vi] @@ -60,7 +60,7 @@ computations rely on the loop's sequence. Fortunately, the majority of primitive functions we focus on possess favorable properties, such as independence among loop iterations. For instance, the aforementioned program includes block and iteration annotations: -- The **block annotation** ``with T.block("C")`` signifies that the block is the fundamental +- The **block annotation** ``with T.sblock("C")`` signifies that the block is the fundamental computation unit designated for scheduling. A block may encompass a single computation statement, multiple computation statements with loops, or opaque intrinsics such as Tensor Core instructions. diff --git a/docs/deep_dive/tensor_ir/learning.rst b/docs/deep_dive/tensor_ir/learning.rst index b76f87f58c38..ace7ebe9de9f 100644 --- a/docs/deep_dive/tensor_ir/learning.rst +++ b/docs/deep_dive/tensor_ir/learning.rst @@ -67,7 +67,7 @@ language called TVMScript, which is a domain-specific dialect embedded in python C: T.Buffer((128, 128), "float32")): Y = T.alloc_buffer((128, 128), dtype="float32") for i, j, k in T.grid(128, 128, 128): - with T.block("Y"): + with T.sblock("Y"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(128, j) vk = T.axis.reduce(128, k) @@ -75,7 +75,7 @@ language called TVMScript, which is a domain-specific dialect embedded in python Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(128, j) C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) @@ -142,7 +142,7 @@ A significant distinction lies in computational statements: .. code:: python # TensorIR - with T.block("Y"): + with T.sblock("Y"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(128, j) vk = T.axis.reduce(128, k) @@ -206,7 +206,7 @@ error because the loop expects an iterator of size 128, but we only bound it to # wrong program due to loop and block iteration mismatch for i in range(127): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ error here due to iterator size mismatch @@ -242,12 +242,12 @@ So we can also write the programs as follows. C: T.Buffer((128, 128), "float32")): Y = T.alloc_buffer((128, 128), dtype="float32") for i, j, k in T.grid(128, 128, 128): - with T.block("Y"): + with T.sblock("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py index 74b4406061b9..a35b9515fd79 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py @@ -70,7 +70,7 @@ def mm_relu( for i in range(128): for j in range(128): for k in range(128): - with T.block("Y"): + with T.sblock("Y"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(128, j) vk = T.axis.reduce(128, k) @@ -81,7 +81,7 @@ def mm_relu( Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i in range(128): for j in range(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(128, j) T.reads(Y[vi, vj]) @@ -111,13 +111,13 @@ def mm_relu( ): Y = T.alloc_buffer((128, 128), dtype="float32") for i, j, k in T.grid(128, 128, 128): - with T.block("Y"): + with T.sblock("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) @@ -150,13 +150,13 @@ def mm_relu( ): Y = T.alloc_buffer((M, N), dtype) for i, j, k in T.grid(M, N, K): - with T.block("Y"): + with T.sblock("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.cast(T.float32(0), dtype) Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(M, N): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype)) @@ -188,13 +188,13 @@ def mm_relu(a: T.handle, b: T.handle, c: T.handle): C = T.match_buffer(c, [M, N], dtype) Y = T.alloc_buffer((M, N), dtype) for i, j, k in T.grid(M, N, K): - with T.block("Y"): + with T.sblock("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.cast(T.float32(0), dtype) Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(M, N): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype)) diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py index eb1b2eb02029..599290cf9671 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py @@ -51,13 +51,13 @@ def main( T.func_attr({"tir.noalias": True}) Y = T.alloc_buffer((128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("Y"): + with T.sblock("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) @@ -103,7 +103,7 @@ def evaluate(mod: tvm.IRModule): # Subsequently, we execute the requisite operations to acquire a reference to # block **Y** and its associated loops. -block_Y = sch.get_block("Y") +block_Y = sch.get_sblock("Y") i, j, k = sch.get_loops(block_Y) ###################################################################### @@ -136,7 +136,7 @@ def evaluate(mod: tvm.IRModule): # variant. First, we employ a primitive known as **reverse_compute_at** to relocate block # **C** to an inner loop of **Y**. -block_C = sch.get_block("C") +block_C = sch.get_sblock("C") sch.reverse_compute_at(block_C, j0) sch.mod.show() diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py b/docs/how_to/tutorials/cross_compilation_and_rpc.py index ef1ca629ce4c..53e092268b95 100644 --- a/docs/how_to/tutorials/cross_compilation_and_rpc.py +++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py @@ -233,7 +233,7 @@ def run_opencl(): # create schedule for the above "add one" compute declaration mod = tvm.IRModule.from_expr(te.create_prim_func([A, B])) sch = tvm.tir.Schedule(mod) - (x,) = sch.get_loops(block=sch.get_block("B")) + (x,) = sch.get_loops(block=sch.get_sblock("B")) xo, xi = sch.split(x, [None, 32]) sch.bind(xo, "blockIdx.x") sch.bind(xi, "threadIdx.x") diff --git a/docs/reference/api/python/tir/tir.rst b/docs/reference/api/python/tir/tir.rst index 3f82fe8261ac..14a64d5592d2 100644 --- a/docs/reference/api/python/tir/tir.rst +++ b/docs/reference/api/python/tir/tir.rst @@ -20,4 +20,4 @@ tvm.tir .. automodule:: tvm.tir :members: :imported-members: - :exclude-members: PrimExpr, const, StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError + :exclude-members: PrimExpr, const, StmtSRef, SBlockScope, ScheduleState, Schedule, ScheduleError diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h index aa3df4e7d443..15ed73716873 100644 --- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h +++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h @@ -59,7 +59,7 @@ ffi::Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // * \param max_threads_per_block The maximum number of threads allowed. * \param get_factor A function that returns the tiling factor. */ -void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block, // +void BindBlockThreadIdx(tir::Schedule sch, tir::SBlockRV block, // int64_t max_threadblocks, int64_t max_threads_per_block, std::function get_factor = nullptr); diff --git a/include/tvm/meta_schedule/schedule/generic/winograd.h b/include/tvm/meta_schedule/schedule/generic/winograd.h index dc9b32fd10de..4a891fbaf1fc 100644 --- a/include/tvm/meta_schedule/schedule/generic/winograd.h +++ b/include/tvm/meta_schedule/schedule/generic/winograd.h @@ -29,7 +29,7 @@ namespace meta_schedule { * If there is a constant winograd transform matrix, inline it. * \return The only producer block. */ -tir::BlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::BlockRV block); +tir::SBlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::SBlockRV block); } // namespace meta_schedule } // namespace tvm diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index be9074acbde7..259b6ac12483 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -60,7 +60,7 @@ class ScheduleRuleNode : public runtime::Object { * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ - virtual ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; + virtual ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block) = 0; /*! * \brief Deep clone the schedule rule. @@ -90,7 +90,7 @@ class ScheduleRule : public runtime::ObjectRef { * \return The list of schedules generated by applying the schedule rule. */ using FApply = - ffi::TypedFunction(const tir::Schedule&, const tir::BlockRV&)>; + ffi::TypedFunction(const tir::Schedule&, const tir::SBlockRV&)>; /*! * \brief Get the schedule rule as string with name. * \return The string of the schedule rule. @@ -151,7 +151,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param reuse_read Data reuse configuration for reading. std::nullopt means no reuse. * \param reuse_write Data reuse configuration for writing. std::nullopt means no reuse. * \param filter_fn A function that can be passed to overwrite the default condition for applying - * MultiLevelTiling to a block. Its signature must be (Schedule, BlockRV) -> bool. + * MultiLevelTiling to a block. Its signature must be (Schedule, SBlockRV) -> bool. * This is useful if there is a need to apply MultiLevelTiling to an operation / block which is * ignored by default. This function should return True for a block that should be tiled. * \return The schedule rule created @@ -343,7 +343,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode { } void InitializeWithTuneContext(const TuneContext& context) final; - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block) final; ScheduleRule Clone() const final; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyScheduleRule", PyScheduleRuleNode, ScheduleRuleNode); diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index b2f946328917..092a7a53f103 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -594,7 +594,7 @@ TVM_DLL bool WellFormed(ffi::Variant obj, bool check_struct_ * from the object (block or buffer) to it's index map transformation. */ -TVM_DLL ffi::Map> SuggestLayoutTransforms( +TVM_DLL ffi::Map> SuggestLayoutTransforms( const Function& fn, ffi::Array write_buffer_transformations); /* \brief Collect variables whose value can be computed at compile-time diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index ddb618e06b1f..26a6ab228c52 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -149,7 +149,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { return true; } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (op->name_hint == "root") { StmtExprVisitor::VisitStmt_(op); return; diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 8c5209982b10..93e4d10317fe 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -56,7 +56,7 @@ namespace ir_builder { * using T = tvm::script::ir_builder::tir; * With _(...); * { - * With _2(...); + * With _2(...); * Buffer buffer = T::MatchBuffer(...); * } * diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 5d6bcc8a2c2f..898e318950cc 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -147,7 +147,7 @@ class FunctionFrame : public SeqExprFrame { }; /*! \brief The ir_builder frame for relax binding blocks. */ -class BlockFrameNode : public RelaxFrameNode { +class BindingBlockFrameNode : public RelaxFrameNode { public: /*! \brief The flag that indicates whether the block is a dataflow block. */ bool is_dataflow; @@ -167,26 +167,27 @@ class BlockFrameNode : public RelaxFrameNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("is_dataflow", &BlockFrameNode::is_dataflow) - .def_ro("emitted_vars", &BlockFrameNode::emitted_vars) - .def_ro("output_vars", &BlockFrameNode::output_vars); + refl::ObjectDef() + .def_ro("is_dataflow", &BindingBlockFrameNode::is_dataflow) + .def_ro("emitted_vars", &BindingBlockFrameNode::emitted_vars) + .def_ro("output_vars", &BindingBlockFrameNode::output_vars); // `block_ended` is not registered as it's not visited. } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.BlockFrame", BlockFrameNode, - RelaxFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.BindingBlockFrame", + BindingBlockFrameNode, RelaxFrameNode); public: void EnterWithScope() final; void ExitWithScope() final; }; -class BlockFrame : public RelaxFrame { +class BindingBlockFrame : public RelaxFrame { public: - explicit BlockFrame(ObjectPtr data) : RelaxFrame(data) { + explicit BindingBlockFrame(ObjectPtr data) : RelaxFrame(data) { TVM_FFI_ICHECK(data != nullptr); } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockFrame, RelaxFrame, BlockFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BindingBlockFrame, RelaxFrame, + BindingBlockFrameNode); }; /*! diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 80b70daffd0b..ac26ddc036a3 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -77,13 +77,13 @@ TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); * \brief Start a binding block frame. * \return The created ir_builder Block frame. */ -TVM_DLL BlockFrame BindingBlock(); +TVM_DLL BindingBlockFrame BindingBlock(); /*! * \brief Start a dataflow binding block frame. * \return The created ir_builder Block frame. */ -TVM_DLL BlockFrame Dataflow(); +TVM_DLL BindingBlockFrame Dataflow(); /*! * \brief Expose the dataflow block output variables as global ones diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index db5776890ab9..1255a67335fa 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -125,9 +125,9 @@ class PrimFuncFrame : public TIRFrame { /*! * \brief A frame that represents the block. * - * \sa BlockFrame + * \sa SBlockFrame */ -class BlockFrameNode : public TIRFrameNode { +class SBlockFrameNode : public TIRFrameNode { public: /*! \brief The name of the block. */ ffi::String name; @@ -157,20 +157,20 @@ class BlockFrameNode : public TIRFrameNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &BlockFrameNode::name) - .def_ro("iter_vars", &BlockFrameNode::iter_vars) - .def_ro("reads", &BlockFrameNode::reads) - .def_ro("writes", &BlockFrameNode::writes) - .def_ro("init", &BlockFrameNode::init) - .def_ro("alloc_buffers", &BlockFrameNode::alloc_buffers) - .def_ro("match_buffers", &BlockFrameNode::match_buffers) - .def_ro("annotations", &BlockFrameNode::annotations) - .def_ro("iter_values", &BlockFrameNode::iter_values) - .def_ro("predicate", &BlockFrameNode::predicate) - .def_ro("no_realize", &BlockFrameNode::no_realize); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BlockFrame", BlockFrameNode, + refl::ObjectDef() + .def_ro("name", &SBlockFrameNode::name) + .def_ro("iter_vars", &SBlockFrameNode::iter_vars) + .def_ro("reads", &SBlockFrameNode::reads) + .def_ro("writes", &SBlockFrameNode::writes) + .def_ro("init", &SBlockFrameNode::init) + .def_ro("alloc_buffers", &SBlockFrameNode::alloc_buffers) + .def_ro("match_buffers", &SBlockFrameNode::match_buffers) + .def_ro("annotations", &SBlockFrameNode::annotations) + .def_ro("iter_values", &SBlockFrameNode::iter_values) + .def_ro("predicate", &SBlockFrameNode::predicate) + .def_ro("no_realize", &SBlockFrameNode::no_realize); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.SSBlockFrame", SBlockFrameNode, TIRFrameNode); public: @@ -182,18 +182,18 @@ class BlockFrameNode : public TIRFrameNode { }; /*! - * \brief Managed reference to BlockFrameNode. + * \brief Managed reference to SBlockFrameNode. * - * \sa BlockFrameNode + * \sa SBlockFrameNode */ -class BlockFrame : public TIRFrame { +class SBlockFrame : public TIRFrame { public: - explicit BlockFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + explicit SBlockFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { TVM_FFI_ICHECK(data != nullptr); data_ = std::move(data); } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockFrame, TIRFrame, BlockFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockFrame, TIRFrame, SBlockFrameNode); }; /*! @@ -207,7 +207,7 @@ class BlockInitFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BlockInitFrame", BlockInitFrameNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.SBlockInitFrame", BlockInitFrameNode, TIRFrameNode); public: diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 07c7fe262bb3..788eb9615c82 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -119,10 +119,10 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, /*! * \brief The block declaration statement. * \param name The name of the block. - * \param no_realize The flag whether to construct BlockRealize or Block. - * \return The BlockFrame. + * \param no_realize The flag whether to construct SBlockRealize or SBlock. + * \return The SBlockFrame. */ -BlockFrame Block(ffi::String name, bool no_realize = false); +SBlockFrame Block(ffi::String name, bool no_realize = false); /*! * \brief The block initialization statement. diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 0f4b6afd62fb..8e71d5a1a5f4 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -233,8 +233,8 @@ TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit); * - second: write regions * - third: opaque regions */ -TVM_DLL ffi::Array> GetBlockAccessRegion( - const Block& block, const ffi::Map& buffer_var_map); +TVM_DLL ffi::Array> GetSBlockAccessRegion( + const SBlock& block, const ffi::Map& buffer_var_map); /*! * \brief Auto detect the block read/write region according to its body stmt. An opaque access will @@ -244,8 +244,8 @@ TVM_DLL ffi::Array> GetBlockAccessRegion( * It is a map from buffer var to the buffer * \return An array only consisting of the read regions and write regions of the input block */ -TVM_DLL ffi::Array> GetBlockReadWriteRegion( - const Block& block, const ffi::Map& buffer_var_map); +TVM_DLL ffi::Array> GetSBlockReadWriteRegion( + const SBlock& block, const ffi::Map& buffer_var_map); /*! \brief Helper struct for return value of IdentifyMemCpy * @@ -329,7 +329,7 @@ TVM_DLL ffi::Map> DetectBufferAccessLCA(const PrimFu * * - Each variable has a single point of definition. * - * - Expressions within a tir::Block may not reference variables + * - Expressions within a tir::SBlock may not reference variables * defined outside the block. For example, for a block with iter * vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement * `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop @@ -379,7 +379,7 @@ const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var); * \param mod The input TIR module. * \return The anchor block if found, nullptr otherwise. */ -const tir::BlockNode* FindAnchorBlock(const IRModule& mod); +const tir::SBlockNode* FindAnchorBlock(const IRModule& mod); // Pass variants of verification analysis // directly throws RuntimeError when verification fails. diff --git a/include/tvm/tir/block_dependence_info.h b/include/tvm/tir/block_dependence_info.h index b1fd8998645a..2e56058eff56 100644 --- a/include/tvm/tir/block_dependence_info.h +++ b/include/tvm/tir/block_dependence_info.h @@ -18,7 +18,7 @@ */ /*! * \file tvm/tir/block_dependence_info.h - * \brief Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects to + * \brief Define BlockDependenceInfoNode that uses the SBlockScope and StmtSRef objects to * store the block level dependences * \sa BlockDependenceInfoNode */ @@ -41,10 +41,10 @@ namespace tir { /** * @brief An object that helps build and query block level dependences using the 2 core objects - * BlockScope and StmtSRef + * SBlockScope and StmtSRef * * The data structures exposed are: - * 1) sref2scope: Mapping from the srefs to its corresponding BlockScope + * 1) sref2scope: Mapping from the srefs to its corresponding SBlockScope * 2) stmt2ref: Mapping from blocks to corresponding StmtSRefs * * Note that this object does not store SRefs to loops as the purpose is only to expose block level @@ -54,10 +54,10 @@ namespace tir { class BlockDependenceInfoNode : public Object { public: /*! - * \brief Mapping from a block sref to its corresponding BlockScope, + * \brief Mapping from a block sref to its corresponding SBlockScope, * tracking the dependency inside the block scope, */ - std::unordered_map sref2scope; + std::unordered_map sref2scope; /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */ std::unordered_map stmt2ref; @@ -65,17 +65,17 @@ class BlockDependenceInfoNode : public Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockDependenceInfo", BlockDependenceInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockDependenceInfo", BlockDependenceInfoNode, Object); /*! - * \brief Get the BlockScope corresponding to the sref of scope root block + * \brief Get the SBlockScope corresponding to the sref of scope root block * \param scope_root The block sref to be retrieved - * \return The corresponding BlockScope + * \return The corresponding SBlockScope */ - BlockScope GetBlockScope(const StmtSRef& scope_root) const { + SBlockScope GetSBlockScope(const StmtSRef& scope_root) const { auto it = sref2scope.find(scope_root); CHECK(it != sref2scope.end()) - << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" + << "IndexError: Cannot find the corresponding SBlockScope to the block sref:\n" << ffi::GetRef(scope_root->stmt); return it->second; } diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h index f1120c7837ff..d356643cda11 100644 --- a/include/tvm/tir/block_scope.h +++ b/include/tvm/tir/block_scope.h @@ -18,9 +18,9 @@ */ /*! * \file tvm/tir/block_scope.h - * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope. + * \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, SBlockScope. * \sa StmtSRefNode - * \sa BlockScopeNode + * \sa SBlockScopeNode */ #ifndef TVM_TIR_BLOCK_SCOPE_H_ #define TVM_TIR_BLOCK_SCOPE_H_ @@ -41,7 +41,7 @@ namespace tir { * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref". * * Glossary - * - Block sref: A StmtSRef that points to a TensorIR block. + * - SBlock sref: A StmtSRef that points to a TensorIR SBlock. * - Loop sref: A StmtSRef that points to a TensorIR for loop. * - Parent sref: The parent reference of an sref is the block or loop reference to the closest schedulable statement. We define closest to be the nearest schedulable statement of an ancestor in @@ -86,7 +86,7 @@ class StmtSRefNode : public Object { * \brief Get the referenced statement with proper type checking. * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt` * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably - * tvm::tir::BlockNode or tvm::tir::ForNode + * tvm::tir::SBlockNode or tvm::tir::ForNode * \return nullptr if type check fails, otherwise the casted result for `this->stmt` */ template @@ -177,7 +177,7 @@ class SRefTreeCreator : private StmtVisitor { void VisitStmt_(const ForNode* loop) final; - void VisitStmt_(const BlockRealizeNode* realize) final; + void VisitStmt_(const SBlockRealizeNode* realize) final; void VisitStmt_(const SeqStmtNode* seq_stmt) final; @@ -243,14 +243,14 @@ class Dependency : public ObjectRef { * For example even leaf nodes have a scope node, even though they have no dependencies. * * Glossary: - * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref, + * - SBlock scope: A contiguous subtree of the sref tree, rooted at each SBlock sref, * whose components are: * - scope root: a block sref * - internal srefs: loop srefs * - scope leaves: block srefs * - Child block: The scope leaf blocks under the scope root or a specific internal sref */ -class BlockScopeNode : public Object { +class SBlockScopeNode : public Object { public: /*! * \brief Lookup table for the `src` of dependencies @@ -265,9 +265,9 @@ class BlockScopeNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockScope", BlockScopeNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockScope", SBlockScopeNode, Object); public: /******** Dependency ********/ @@ -286,29 +286,29 @@ class BlockScopeNode : public Object { }; /*! - * \brief Managed reference to BlockScopeNode - * \sa BlockScopeNode + * \brief Managed reference to SBlockScopeNode + * \sa SBlockScopeNode */ -class BlockScope : public ObjectRef { +class SBlockScope : public ObjectRef { public: /*! - * \brief Constructor from ObjectPtr. + * \brief Constructor from ObjectPtr. * \param data The object pointer. */ - explicit BlockScope(ObjectPtr data) : ObjectRef(data) { + explicit SBlockScope(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } /*! \brief The constructor creating an empty block scope with on dependency information */ - TVM_DLL BlockScope(); + TVM_DLL SBlockScope(); /*! * \brief Create the object with the specific leaf blocks, and compute the dependency information * between the leaf blocks. * \param child_block_srefs The srefs to the leaf blocks * \note We assume the leaf blocks are given in pre-DFS order */ - TVM_DLL explicit BlockScope(const ffi::Array& child_block_srefs); + TVM_DLL explicit SBlockScope(const ffi::Array& child_block_srefs); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockScope, ObjectRef, BlockScopeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockScope, ObjectRef, SBlockScopeNode); }; } // namespace tir diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 88398cf06f06..e100eeb59029 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -51,8 +51,8 @@ class DataTypeLegalizer : public StmtExprMutator { protected: Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; - Stmt VisitStmt_(const BlockRealizeNode* op) override; - Stmt VisitStmt_(const BlockNode* op) override; + Stmt VisitStmt_(const SBlockRealizeNode* op) override; + Stmt VisitStmt_(const SBlockNode* op) override; Stmt VisitStmt_(const LetStmtNode* op) override; PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const SelectNode* op) override; @@ -101,8 +101,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { using Parent::VisitExpr_; using Parent::VisitStmt_; - Stmt VisitStmt_(const BlockRealizeNode* op) override; - Stmt VisitStmt_(const BlockNode* op) override; + Stmt VisitStmt_(const SBlockRealizeNode* op) override; + Stmt VisitStmt_(const SBlockNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 97701d16b097..956254bbebc0 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -223,7 +223,7 @@ class TensorIntrin : public ObjectRef { * A = T.match_buffer(a, (m, n), "float32") * B = T.match_buffer(b, (m, n), "float32") * for i, j in T.grid(m, n): - * with T.block(): + * with T.sblock(): * vi, vj = T.axis.remap("SS", [i, j]) * B[vi, vj] = A[vi, vj] * \endcode @@ -243,7 +243,7 @@ class TensorIntrin : public ObjectRef { * A = T.match_buffer(a, (16, 16), "float32") * B = T.match_buffer(b, (16, 16), "float32") * for i, j in T.grid(16, 16): - * with T.block(): + * with T.sblock(): * vi, vj = T.axis.remap("SS", [i, j]) * B[vi, vj] = A[vi, vj] * \endcode diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index b6e283f400fb..c4ee3ce03d15 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -92,7 +92,7 @@ class InstructionKindNode : public runtime::Object { ffi::String name; /*! * \brief Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule - * state. For example, the instruction `GetBlock` is pure because it changes + * state. For example, the instruction `GetSBlock` is pure because it changes * nothing, while `ComputeInline` is not because removing it leads to a different resulting * schedule. */ @@ -148,7 +148,7 @@ class InstructionNode : public runtime::Object { /*! * \brief The input random variables of the instruction, and the type of each element can be one * of the following: - * - BlockRV + * - SBlockRV * - LoopRV * - ExprRV * - double @@ -160,12 +160,12 @@ class InstructionNode : public runtime::Object { /*! * \brief The attributes of the instruction. Similar to attributes of an operator, * attributes of an instruction are arbitrary constant metadata required by the instructions. - * For example, the name of the block to be retrieved in `GetBlock`. + * For example, the name of the block to be retrieved in `GetSBlock`. */ ffi::Array attrs; /*! \brief The output random variables of the instruction, and the type of each element can be one * of the following: - * - BlockRV + * - SBlockRV * - LoopRV * - ExprRV, atomic variables only, won't be constants or composite PrimExpr */ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index a768a7dd4f31..e346eb458b9f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -45,27 +45,27 @@ enum class BufferIndexType : int32_t { kWrite = 1, }; -/**************** Random variable: BlockRV ****************/ +/**************** Random variable: SBlockRV ****************/ /*! \brief A random variable that evaluates to a TensorIR block */ -class BlockRVNode : public runtime::Object { +class SBlockRVNode : public runtime::Object { public: static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockRV", BlockRVNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockRV", SBlockRVNode, runtime::Object); }; /*! - * \brief Managed reference to BlockRVNode - * \sa BlockRVNode + * \brief Managed reference to SBlockRVNode + * \sa SBlockRVNode */ -class BlockRV : public runtime::ObjectRef { +class SBlockRV : public runtime::ObjectRef { public: /*! \brief Constructor */ - TVM_DLL BlockRV(); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockRV, runtime::ObjectRef, BlockRVNode); + TVM_DLL SBlockRV(); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockRV, runtime::ObjectRef, SBlockRVNode); }; /**************** Random variable: LoopRV ****************/ @@ -129,11 +129,11 @@ class ScheduleNode : public runtime::Object { * their names are "main", users will have to call this method to explicitly specify which * function to work on. * - * This sugar function will guide the `GetBlock` method if its `func_name` is not specified. + * This sugar function will guide the `GetSBlock` method if its `func_name` is not specified. * * \param func_name The name of the function to be working on * - * \sa GetBlock + * \sa GetSBlock */ virtual void WorkOn(const ffi::String& func_name) = 0; /*! @@ -156,11 +156,11 @@ class ScheduleNode : public runtime::Object { public: /******** Lookup/Remove random variables ********/ /*! - * \brief Get the block corresponding to the specific BlockRV - * \param block_rv The BlockRV to be looked up + * \brief Get the block corresponding to the specific SBlockRV + * \param block_rv The SBlockRV to be looked up * \return The corresponding block */ - virtual Block Get(const BlockRV& block_rv) const = 0; + virtual SBlock Get(const SBlockRV& block_rv) const = 0; /*! * \brief Get the for loop corresponding to the specific LoopRV * \param loop_rv The LoopRV to be looked up @@ -174,11 +174,11 @@ class ScheduleNode : public runtime::Object { */ virtual PrimExpr Get(const ExprRV& expr_rv) const = 0; /*! - * \brief Get the block sref corresponding to the specific BlockRV - * \param block_rv The BlockRV to be looked up + * \brief Get the block sref corresponding to the specific SBlockRV + * \param block_rv The SBlockRV to be looked up * \return The corresponding block sref */ - virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0; + virtual StmtSRef GetSRef(const SBlockRV& block_rv) const = 0; /*! * \brief Get the loop sref corresponding to the specific LoopRV * \param loop_rv The LoopRV to be looked up @@ -186,11 +186,11 @@ class ScheduleNode : public runtime::Object { */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; /*! - * \brief Check the existance of a specific BlockRV - * \param block_rv The BlockRV to be looked up + * \brief Check the existance of a specific SBlockRV + * \param block_rv The SBlockRV to be looked up * \return Whether the corresponding block exists */ - virtual bool HasBlock(const BlockRV& block_rv) const = 0; + virtual bool HasBlock(const SBlockRV& block_rv) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -207,7 +207,7 @@ class ScheduleNode : public runtime::Object { * \brief Remove a block random variable from the symbol table * \param block_rv The random variable to be removed */ - virtual void RemoveRV(const BlockRV& block_rv) = 0; + virtual void RemoveRV(const SBlockRV& block_rv) = 0; /*! * \brief Remove a loop random variable from the symbol table * \param loop_rv The random variable to be removed @@ -266,7 +266,7 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The sampled loop where the input block is to be computed at */ - virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, + virtual LoopRV SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision = std::nullopt) = 0; /******** Schedule: Get blocks & loops ********/ @@ -284,40 +284,40 @@ class ScheduleNode : public runtime::Object { * * \sa WorkOn */ - virtual BlockRV GetBlock(const ffi::String& name, - const ffi::Optional& func_name = std::nullopt) = 0; + virtual SBlockRV GetSBlock(const ffi::String& name, + const ffi::Optional& func_name = std::nullopt) = 0; /*! * \brief Get the parent loops of the block in its scope, from outer to inner * \param block_rv The query block * \return A list of loops above the given block in its scope, from outer to inner */ - virtual ffi::Array GetLoops(const BlockRV& block_rv) = 0; + virtual ffi::Array GetLoops(const SBlockRV& block_rv) = 0; /*! * \brief Get the leaf blocks of a specific scope * \param block_rv The block where the scope is rooted * \return A list of child blocks */ - virtual ffi::Array GetChildBlocks(const BlockRV& block_rv) = 0; + virtual ffi::Array GetChildBlocks(const SBlockRV& block_rv) = 0; /*! * \brief Get the leaf blocks of under a specific loop * \param loop_rv The loop under which collecting is conducted * \return A list of child blocks */ - virtual ffi::Array GetChildBlocks(const LoopRV& loop_rv) = 0; + virtual ffi::Array GetChildBlocks(const LoopRV& loop_rv) = 0; /*! * \brief Get the producer of a specific block, under the same block scope * \param block_rv The block in the query * \return A list of blocks, the producers of the given block under the same scope of the given * block */ - virtual ffi::Array GetProducers(const BlockRV& block_rv) = 0; + virtual ffi::Array GetProducers(const SBlockRV& block_rv) = 0; /*! * \brief Get the consumers of a specific block, under the same block scope * \param block_rv The block to be queried * \return A list of blocks, the consumers of the given block under the same scope of the given * block */ - virtual ffi::Array GetConsumers(const BlockRV& block_rv) = 0; + virtual ffi::Array GetConsumers(const SBlockRV& block_rv) = 0; /*! * \brief Get the list of output blocks within the given scope * An output block is a block which has atleast one buffer being written @@ -326,7 +326,7 @@ class ScheduleNode : public runtime::Object { * \return A list of all blocks that write to some output buffer * block */ - virtual ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; + virtual ffi::Array GetOutputBlocks(const SBlockRV& scope_block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Merge a list of loops into one. The loops under their LCA requires: @@ -395,14 +395,14 @@ class ScheduleNode : public runtime::Object { * \param block_rv The block to be transformed. * \param new_order The new itervar order. */ - virtual void ReorderBlockIterVar(const BlockRV& block_rv, + virtual void ReorderBlockIterVar(const SBlockRV& block_rv, const ffi::Array new_order) = 0; /*! * \brief Create a new unit loop on top of the specific block. * \param block_rv The block above which the new loop is created * \return The new loop created */ - virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0; + virtual LoopRV AddUnitLoop(const SBlockRV& block_rv) = 0; /*! * \brief Create a new unit loop on top of the specific loop. * \param loop_rv The loop above which the new loop is created @@ -458,9 +458,9 @@ class ScheduleNode : public runtime::Object { * \param consumer_blocks An optional list of consumers of the cache to rewrite. * \return The cache stage block. */ - virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks = {}) = 0; + virtual SBlockRV CacheRead(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. @@ -471,9 +471,9 @@ class ScheduleNode : public runtime::Object { * \param consumer_blocks An optional list of consumers to read from cache directly. * \return The cache stage block. */ - virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks = {}) = 0; + virtual SBlockRV CacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) = 0; /*! * \brief Create a block that reads a buffer region into a read cache. It requires: * 1) There is at most one block who writes the buffer in the scope. @@ -486,8 +486,9 @@ class ScheduleNode : public runtime::Object { * vars. * \return The cache stage block. */ - virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, const IndexMap& index_map) = 0; + virtual SBlockRV ReindexCacheRead(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const IndexMap& index_map) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. @@ -500,9 +501,9 @@ class ScheduleNode : public runtime::Object { * vars. * \return The cache stage block. */ - virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const ffi::String& storage_scope, - const IndexMap& index_map) = 0; + virtual SBlockRV ReindexCacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const IndexMap& index_map) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the target block both read & write the target buffer. @@ -511,8 +512,8 @@ class ScheduleNode : public runtime::Object { * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. */ - virtual ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) = 0; + virtual ffi::Array CacheInplace(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) = 0; /*! * \brief Create a block to cache precomputed index for later use. * if there is no index computation, keep unchanged. @@ -521,8 +522,8 @@ class ScheduleNode : public runtime::Object { * \param cse_thresh The repeat threshold that determines a common sub expr * \return The cache stage blocks. */ - virtual ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, - int cse_thresh) = 0; + virtual ffi::Array CacheIndex(const SBlockRV& block_rv, + const ffi::String& storage_scope, int cse_thresh) = 0; /*! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. * The layout of the cache will be the same as by the iterators of the block that reads/writes the @@ -534,13 +535,13 @@ class ScheduleNode : public runtime::Object { * \param buffer_index_type The type of the buffer index, kRead or kWrite. * \return The reindex stage block. */ - virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) = 0; + virtual SBlockRV ReIndex(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) = 0; /******** Schedule: Data movement ********/ - virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) = 0; - virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + virtual SBlockRV ReadAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope) = 0; + virtual SBlockRV WriteAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -562,7 +563,7 @@ class ScheduleNode : public runtime::Object { * - `index = -2` means inserted into the first possible insertion point; * - Otherwise, `index` is a nonnegative number that indicates the insertion point */ - virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + virtual void ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) = 0; /*! * \brief Move a consumer block under the specific loop, and regenerate the @@ -583,7 +584,7 @@ class ScheduleNode : public runtime::Object { * - `index = -2` means inserted into the first possible insertion point; * - Otherwise, `index` is a nonnegative number that indicates the insertion point */ - virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + virtual void ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) = 0; /*! * \brief Inline a block into its consumer(s). It requires: @@ -595,7 +596,7 @@ class ScheduleNode : public runtime::Object { * and no variables other than those indexing variables are allowed in the statement. * \param block The block to be inlined to its consumer(s) */ - virtual void ComputeInline(const BlockRV& block) = 0; + virtual void ComputeInline(const SBlockRV& block) = 0; /*! * \brief Inline a block into its only producer. It requires: * 1) The block is a complete non-root block, which only produces and consumers one buffer @@ -607,14 +608,14 @@ class ScheduleNode : public runtime::Object { * and no variables other than those indexing variables are allowed in the statement. * \param block The block to be inlined to its producer */ - virtual void ReverseComputeInline(const BlockRV& block) = 0; + virtual void ReverseComputeInline(const SBlockRV& block) = 0; /*! * \brief Fuse an epilogue block into a reduction block * \param reduction_block The reduction block (e.g., matmul) * \param epilogue_block The epilogue block to be fused (e.g., bias add) */ - virtual void FuseReductionEpilogue(const BlockRV& reduction_block, - const BlockRV& epilogue_block) = 0; + virtual void FuseReductionEpilogue(const SBlockRV& reduction_block, + const SBlockRV& epilogue_block) = 0; /******** Schedule: Reduction ********/ /*! * \brief Decompose a reduction block into two separate blocks. @@ -631,7 +632,7 @@ class ScheduleNode : public runtime::Object { * \param loop_rv The loop above which the init block is inserted before. * \return The init block */ - virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0; + virtual SBlockRV DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) = 0; /*! * \brief Factorize an associative reduction block by the specified loop. * \details An associative reduction cannot be parallelized directly, @@ -649,8 +650,8 @@ class ScheduleNode : public runtime::Object { * ndim(B)]`, and the negative index will be normalized to a non-negative one * \return The rfactor block */ - virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; - /******** Schedule: Block annotation ********/ + virtual SBlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; + /******** Schedule: SBlock annotation ********/ /*! * \brief Set alignment requirement for specific dimension such that * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for @@ -663,7 +664,7 @@ class ScheduleNode : public runtime::Object { * \param factor The factor multiple of alignment * \param offset The required offset factor */ - virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + virtual void StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor, int offset) = 0; /*! * \brief Set the storage scope of a buffer, where the buffer is specified by a block and a @@ -672,7 +673,7 @@ class ScheduleNode : public runtime::Object { * \param buffer_index The index of the buffer in block's write region * \param storage_scope The storage scope to be set */ - virtual void SetScope(const BlockRV& block_rv, int buffer_index, + virtual void SetScope(const SBlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) = 0; /*! * \brief Set the data type of a buffer, where the buffer is specified by a block and a @@ -683,7 +684,7 @@ class ScheduleNode : public runtime::Object { * \param buffer_index the index of the buffer in block's write region * \param dtype The data type to be set */ - virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, + virtual void UnsafeSetDType(const SBlockRV& block_rv, int buffer_index, const ffi::String& dtype) = 0; /******** Schedule: Blockize & Tensorize ********/ /*! @@ -692,14 +693,15 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return the new block */ - virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0; + virtual SBlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0; /*! * \brief Convert specified blocks into a nested block. * \param blocks the specified block to construct the new block * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return the new block */ - virtual BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters = true) = 0; + virtual SBlockRV Blockize(const ffi::Array& blocks, + bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. * \param loop_rv The loop to be tensorized @@ -714,7 +716,7 @@ class ScheduleNode : public runtime::Object { * \param intrin Name of the tensor intrinsic * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + virtual void Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters = true) = 0; /******** Schedule: Annotation ********/ @@ -731,7 +733,7 @@ class ScheduleNode : public runtime::Object { * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ - virtual void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, + virtual void Annotate(const SBlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) = 0; /*! * \brief Unannotate a loop's annotation with key ann_key @@ -744,7 +746,7 @@ class ScheduleNode : public runtime::Object { * \param block_rv The block to be unannotated * \param ann_key The annotation key */ - virtual void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) = 0; + virtual void Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) = 0; /******** Schedule: Layout transformation ********/ /*! @@ -778,7 +780,7 @@ class ScheduleNode : public runtime::Object { * to ensure the index map is injective, otherwise, the correctness of the schedule is not * guaranteed. */ - virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, + virtual void TransformLayout(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, const ffi::Optional& pad_value = std::nullopt, bool assume_injective_transform = false) = 0; @@ -791,7 +793,7 @@ class ScheduleNode : public runtime::Object { * \param block_rv The block to be transformed * \param index_map The transformation to apply. */ - virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0; + virtual void TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) = 0; /*! * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read @@ -801,7 +803,7 @@ class ScheduleNode : public runtime::Object { * \param buffer_index_type The type of the buffer index, kRead or kWrite. * \param axis_separators The axis separator of the buffer */ - virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, + virtual void SetAxisSeparator(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array& axis_separators) = 0; @@ -813,7 +815,7 @@ class ScheduleNode : public runtime::Object { * \param loop_rv The loop above which the const filling block is inserted before. * \return The const pad value filling block. */ - virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0; + virtual SBlockRV DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) = 0; /*! * \brief Pad the computation of Einsum. @@ -832,7 +834,7 @@ class ScheduleNode : public runtime::Object { * The size of the producer buffers are infered from the padding size of the Einsum computation. * The producer buffers are padded by the initial value of the corresponding reduction. */ - virtual void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) = 0; + virtual void PadEinsum(const SBlockRV& block_rv, const ffi::Array& padding) = 0; /******** Schedule: Buffer transformation ********/ /*! @@ -849,7 +851,7 @@ class ScheduleNode : public runtime::Object { * \param block_rv The producer block of the buffer. * \param write_buffer_index The index of the buffer in block's write region. */ - virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; + virtual void RollingBuffer(const SBlockRV& block_rv, int write_buffer_index) = 0; /*! * \brief Annotate the buffer access of a block @@ -858,7 +860,7 @@ class ScheduleNode : public runtime::Object { * \param buffer_index_type The type of the buffer index, kRead or kWrite. * \param index_map The index map that defines the new read or write region */ - virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + virtual void AnnotateBufferAccess(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) = 0; @@ -872,7 +874,7 @@ class ScheduleNode : public runtime::Object { * \param buf_type The buffer type: read/write * \param buf_index_array The array of buffer indices we hide access. */ - virtual void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + virtual void UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type, const ffi::Array& buf_index_array) = 0; }; diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 4467463912e8..fffa25e19fd6 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -41,9 +41,9 @@ namespace tir { * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it * reads are completely covered by their producers, etc. */ -struct BlockInfo { +struct SBlockInfo { /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */ - BlockScope scope{ffi::UnsafeInit()}; + SBlockScope scope{ffi::UnsafeInit()}; // The properties below are information about the current block realization under its parent scope /*! \brief Property of a block, indicating the block realization binding is quasi-affine */ bool affine_binding{false}; @@ -59,14 +59,14 @@ struct BlockInfo { * 1) The region cover property holds for every of its child blocks * 2) No write-after-read dependency or opaque dependency, only read-after-write and * write-after-write are allowed - * 3) All the statements in the scope are schedulable statements, i.e. Block and For + * 3) All the statements in the scope are schedulable statements, i.e. SBlock and For */ bool stage_pipeline{false}; - BlockInfo() = default; + SBlockInfo() = default; - explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false, - bool stage_pipeline = false) + explicit SBlockInfo(SBlockScope scope, bool affine_binding = false, bool region_cover = false, + bool stage_pipeline = false) : scope(std::move(scope)), // affine_binding(affine_binding), // region_cover(region_cover), @@ -101,11 +101,11 @@ class ScheduleStateNode : public Object { /*! \brief The AST of the module being scheduled */ IRModule mod; /*! - * \brief Mapping from a block sref to its correpsonding BlockInfo, + * \brief Mapping from a block sref to its correpsonding SBlockInfo, * tracking the dependency inside the block scope, * and storing necessary information flags for scheduling */ - std::unordered_map block_info; + std::unordered_map block_info; /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */ std::unordered_map stmt2ref; /*! @@ -134,9 +134,9 @@ class ScheduleStateNode : public Object { * the only copy to the IRModule and IR nodes. * * Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`. - * 1) Block -> Block + * 1) SBlock -> SBlock * 2) Loop -> Loop - * 3) Loop -> BlockRealize + * 3) Loop -> SBlockRealize * * \param src_sref The sref to the statement to be replaced * \param tgt_stmt The statement to be replaced in @@ -147,7 +147,7 @@ class ScheduleStateNode : public Object { * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars. */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, - const ffi::Map& block_sref_reuse); + const ffi::Map& block_sref_reuse); /*! * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. @@ -160,21 +160,21 @@ class ScheduleStateNode : public Object { TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ScheduleState", ScheduleStateNode, Object); /******** Property of blocks ********/ - /*! \brief Returns the BlockInfo correpsonding to the block sref */ - TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const; + /*! \brief Returns the SBlockInfo correpsonding to the block sref */ + TVM_DLL SBlockInfo GetSBlockInfo(const StmtSRef& block_sref) const; /*! - * \brief Recalculate the BlockInfo recursively under stmt. - * If stmt is a Block itself, we will not reset its affine binding flag unless it doesn't + * \brief Recalculate the SBlockInfo recursively under stmt. + * If stmt is a SBlock itself, we will not reset its affine binding flag unless it doesn't * have block vars, since the affine flag depends on the outer scope of stmt. */ - TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt); + TVM_DLL void UpdateScopeSBlockInfo(const Stmt& stmt); /*! - * \brief Get the BlockScope correpsonding to the sref of scope root block + * \brief Get the SBlockScope correpsonding to the sref of scope root block * \param scope_root The block sref to be retrieved - * \return The corresponding BlockScope + * \return The corresponding SBlockScope */ - BlockScope GetBlockScope(const StmtSRef& scope_root) const { - return GetBlockInfo(scope_root).scope; + SBlockScope GetSBlockScope(const StmtSRef& scope_root) const { + return GetSBlockInfo(scope_root).scope; } /*! * \brief Check a cached flag indicating if the specific block has quasi-affine bindings @@ -182,7 +182,7 @@ class ScheduleStateNode : public Object { * \return A boolean flag indicating if the block has quasi-affine bindings */ bool IsAffineBlockBinding(const StmtSRef& block_sref) const { - return GetBlockInfo(block_sref).affine_binding; + return GetSBlockInfo(block_sref).affine_binding; } /*! * \brief Check a cached flag indicating if each of the specific consumer block's read region @@ -191,15 +191,15 @@ class ScheduleStateNode : public Object { * \return A boolean flag indicating if the block has quasi-affine bindings */ bool IsRegionCoveredConsumer(const StmtSRef& consumer_block_sref) const { - return GetBlockInfo(consumer_block_sref).region_cover; + return GetSBlockInfo(consumer_block_sref).region_cover; } /*! * \brief Check a cached flag indicating if a block scope is an equivalence of a stage pipeline * \param scope_root The block sref to be retrieved - * \return The corresponding BlockScope + * \return The corresponding SBlockScope */ bool IsStagePipeline(const StmtSRef& scope_root) const { - return GetBlockInfo(scope_root).stage_pipeline; + return GetSBlockInfo(scope_root).stage_pipeline; } }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0831b84cf6fe..b64dc4beecff 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -916,10 +916,10 @@ class MatchBufferRegion : public ObjectRef { /*! * \brief A block is a basic schedule unit in TIR. - * \note Block's body is parameterized by iter vars. + * \note SBlock's body is parameterized by iter vars. * \code * - * with T.block(name): + * with T.sblock(name): * v0 = T.axis.S(domain, value0) * v1 = T.axis.R(domain, value1) * ... @@ -935,7 +935,7 @@ class MatchBufferRegion : public ObjectRef { * * \endcode */ -class BlockNode : public StmtNode { +class SBlockNode : public StmtNode { public: /*! \brief The variables of the block. */ ffi::Array iter_vars; @@ -964,27 +964,27 @@ class BlockNode : public StmtNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("iter_vars", &BlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("reads", &BlockNode::reads) - .def_ro("writes", &BlockNode::writes) - .def_ro("name_hint", &BlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("alloc_buffers", &BlockNode::alloc_buffers) - .def_ro("match_buffers", &BlockNode::match_buffers) - .def_ro("annotations", &BlockNode::annotations) - .def_ro("init", &BlockNode::init) - .def_ro("body", &BlockNode::body); + refl::ObjectDef() + .def_ro("iter_vars", &SBlockNode::iter_vars, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("reads", &SBlockNode::reads) + .def_ro("writes", &SBlockNode::writes) + .def_ro("name_hint", &SBlockNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("alloc_buffers", &SBlockNode::alloc_buffers) + .def_ro("match_buffers", &SBlockNode::match_buffers) + .def_ro("annotations", &SBlockNode::annotations) + .def_ro("init", &SBlockNode::init) + .def_ro("body", &SBlockNode::body); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Block", BlockNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlock", SBlockNode, StmtNode); }; /*! - * \brief Managed reference to BlockNode. - * \sa BlockNode + * \brief Managed reference to SBlockNode. + * \sa SBlockNode */ -class Block : public Stmt { +class SBlock : public Stmt { public: - TVM_DLL explicit Block( + TVM_DLL explicit SBlock( ffi::Array iter_vars, ffi::Array reads, ffi::Array writes, ffi::String name_hint, Stmt body, ffi::Optional init = std::nullopt, @@ -993,14 +993,14 @@ class Block : public Stmt { ffi::Map annotations = ffi::Map(), Span span = Span()); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Block, Stmt, BlockNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlock, Stmt, SBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockNode); }; /*! * \brief A block realization node represents execution of the block at the binding values. */ -class BlockRealizeNode : public StmtNode { +class SBlockRealizeNode : public StmtNode { public: /*! \brief The corresponding values of the iter vars. */ ffi::Array iter_values; @@ -1010,29 +1010,29 @@ class BlockRealizeNode : public StmtNode { */ PrimExpr predicate; /*! \brief The block to be realized. */ - Block block; + SBlock block; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("iter_values", &BlockRealizeNode::iter_values) - .def_ro("predicate", &BlockRealizeNode::predicate) - .def_ro("block", &BlockRealizeNode::block); + refl::ObjectDef() + .def_ro("iter_values", &SBlockRealizeNode::iter_values) + .def_ro("predicate", &SBlockRealizeNode::predicate) + .def_ro("block", &SBlockRealizeNode::block); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockRealize", BlockRealizeNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockRealize", SBlockRealizeNode, StmtNode); }; /*! * \brief Managed reference to BlockRealizeNode * \sa BlockRealizeNode */ -class BlockRealize : public Stmt { +class SBlockRealize : public Stmt { public: - TVM_DLL explicit BlockRealize(ffi::Array iter_values, PrimExpr predicate, Block block, - Span span = Span()); + TVM_DLL explicit SBlockRealize(ffi::Array iter_values, PrimExpr predicate, SBlock block, + Span span = Span()); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BlockRealize, Stmt, BlockRealizeNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SBlockRealize, Stmt, SBlockRealizeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockRealizeNode); }; /*! \brief namespace of possible attributes in AttrStmt.attr_key */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index b3c43bdc1459..d8d06ae1f353 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -95,8 +95,8 @@ class StmtFunctor { virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SBlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); TVM_FFI_UNREACHABLE(); @@ -119,8 +119,8 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); - IR_STMT_FUNCTOR_DISPATCH(BlockNode); - IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode); + IR_STMT_FUNCTOR_DISPATCH(SBlockNode); + IR_STMT_FUNCTOR_DISPATCH(SBlockRealizeNode); vtable.Finalize(); return vtable; } @@ -160,8 +160,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; - void VisitStmt_(const BlockNode* op) override; - void VisitStmt_(const BlockRealizeNode* op) override; + void VisitStmt_(const SBlockNode* op) override; + void VisitStmt_(const SBlockRealizeNode* op) override; }; /*! @@ -258,8 +258,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; - Stmt VisitStmt_(const BlockNode* op) override; - Stmt VisitStmt_(const BlockRealizeNode* op) override; + Stmt VisitStmt_(const SBlockNode* op) override; + Stmt VisitStmt_(const SBlockRealizeNode* op) override; /*! * \brief Alternative advance method for SeqStmtNode. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index bf100dc49c4c..d2953f1fb48e 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -458,7 +458,7 @@ TVM_DLL Pass LiftThreadBinding(); * \code * * for i in range(0, 16): - * with T.block(): + * with T.sblock(): * B = T.alloc_buffer(16, 16) * for j in range(0, 16): * B[i, j] = A[i, j] + 1 @@ -474,7 +474,7 @@ TVM_DLL Pass LiftThreadBinding(); * \code * * for i in range(0, 16): - * with T.block(): + * with T.sblock(): * B = T.alloc_buffer(1, 16) * for j in range(0, 16): * B[0, j] = A[i, j] + 1 @@ -580,8 +580,8 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * are overlapped with the information provided in loop annotations, which enables optimization * techniques like prefetching and pipeline parallelism. * - * The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize, - * Block, SeqStmt), and the number of children is denoted by `n` in the documentation. + * The pipeline scope consists of the direct children of the annotated loop (ignoring SBlockRealize, + * SBlock, SeqStmt), and the number of children is denoted by `n` in the documentation. * * The following annotations are used to guide the loop transformation: * @@ -590,7 +590,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * where max_stage is the maximum (inclusive) stage. * 2) Loop annotation `software_pipeline_order` defines the pipeline order. * An array of `n` integers, a permutation of [0, 1, ..., num_components - 1]; - * 3) Block annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of + * 3) SBlock annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of * read/write dependency. It's an integer index of the write regions of the block. * * Every annotated loop is transformed into a loop with three blocks as its direct children: @@ -620,15 +620,15 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * annotations={"software_pipeline_stage": [0, 1], * "software_pipeline_order": [0, 1]} * ): - * with T.block(): + * with T.sblock(): * T.reads(A[tx, i]) * T.writes(C[tx, i]) * B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - * with T.block("B"): + * with T.sblock("B"): * T.reads(A[tx, i]) * T.writes(B[tx, 0]) * B[tx, 0] = A[tx, i] * T.float32(2) - * with T.block("C"): + * with T.sblock("C"): * T.reads(B[tx, 0]) * T.writes(C[tx, i]) * C[tx, i] = B[tx, 0] + T.float32(1) @@ -641,27 +641,27 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * @T.prim_func * def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - * with T.block(): + * with T.sblock(): * T.reads([A[tx, 0:16]]) * T.writes([C[tx, 0:16]]) * B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - * with T.block("prologue"): + * with T.sblock("prologue"): * T.reads([A[tx, 0]]) * T.writes([B[0, tx, 0]]) * B[0, tx, 0] = A[tx, 0] * T.float32(2) - * with T.block("body"): + * with T.sblock("body"): * T.reads([A[tx, 1:16], B[0:2, tx, 0]]) * T.writes([B[0:2, tx, 0], C[tx, 0:15]]) * for i in T.serial(0, 15): - * with T.block("B"): + * with T.sblock("B"): * T.reads([A[tx, i + 1]]) * T.writes([B[(i + 1) % 2, tx, 0]]) * B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) - * with T.block("C"): + * with T.sblock("C"): * T.reads([B[i % 2, tx, 0]]) * T.writes([C[tx, i]]) * C[tx, i] = B[i % 2, tx, 0] + T.float32(1) - * with T.block("epilogue"): + * with T.sblock("epilogue"): * T.reads([B[1, tx, 0]]) * T.writes([C[tx, 15]]) * C[tx, 15] = B[1, tx, 0] + T.float32(1) diff --git a/include/tvm/tir/utils.h b/include/tvm/tir/utils.h index e85a973eb2a8..a62b13621990 100644 --- a/include/tvm/tir/utils.h +++ b/include/tvm/tir/utils.h @@ -32,7 +32,7 @@ namespace tir { * then check if the downcasting succeeded. * \param Result The result variable, used for checking * \param SRef The SRef to be cast - * \param Type The type to be cast to, can be Block or For + * \param Type The type to be cast to, can be SBlock or For */ #define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \ SRef->StmtAs(); \ @@ -46,9 +46,9 @@ namespace tir { * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_BLOCK(SRef) \ +#define TVM_SREF_TO_SBLOCK(SRef) \ [&]() { \ - auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode) \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::SBlockNode) \ << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \ << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ return result; \ @@ -101,15 +101,15 @@ namespace tir { * \param stmt The statement, or the realize node of the statement whose sref to be set * \param seq_index The seq_index to be set * \param include_loops Ignore ForNodes if this value is false - * \note The method is NOP for statements that are not schedulable, i.e. not For or Block + * \note The method is NOP for statements that are not schedulable, i.e. not For or SBlock */ inline void SetSeqIndex(std::unordered_map& stmt2ref, // NOLINT(*) const Stmt& stmt, int seq_index, bool include_loops = true) { - if (const auto* realize = stmt.as()) { - const BlockNode* block = realize->block.get(); + if (const auto* realize = stmt.as()) { + const SBlockNode* block = realize->block.get(); ICHECK(stmt2ref.count(block)); stmt2ref.at(block)->seq_index = seq_index; - } else if (const auto* block = stmt.as()) { + } else if (const auto* block = stmt.as()) { ICHECK(stmt2ref.count(block)); stmt2ref.at(block)->seq_index = seq_index; } else if (const auto* loop = stmt.as()) { diff --git a/jvm/core/src/test/scripts/prepare_test_libs.py b/jvm/core/src/test/scripts/prepare_test_libs.py index 550082adb816..b3d82cf07886 100644 --- a/jvm/core/src/test/scripts/prepare_test_libs.py +++ b/jvm/core/src/test/scripts/prepare_test_libs.py @@ -67,7 +67,7 @@ def prepare_gpu_lib(base_path): mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, C]).with_attr("global_symbol", "myadd")) sch = tvm.tir.Schedule(mod) sch.work_on("myadd") - (i,) = sch.get_loops(block=sch.get_block("C")) + (i,) = sch.get_loops(block=sch.get_sblock("C")) i0, i1 = sch.split(i, [None, 32]) sch.bind(i0, "blockIdx.x") sch.bind(i1, "threadIdx.x") diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index 3d42d1972dcc..6ef108e64d8c 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -19,7 +19,7 @@ from . import adreno from . import cpu from .analysis import ( - BlockInfo, + SBlockInfo, IterInfo, normalize_prim_func, ) diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py index 462fb2550bd4..87b1f3641fb9 100644 --- a/python/tvm/dlight/adreno/convolution.py +++ b/python/tvm/dlight/adreno/convolution.py @@ -30,7 +30,7 @@ class Conv2d(AdrenoScheduleRule): """The schedule rule for convolution computation""" @staticmethod - def schedule_conv2d(sch: tir.Schedule, blk: tir.schedule.BlockRV): + def schedule_conv2d(sch: tir.Schedule, blk: tir.schedule.SBlockRV): n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk) bz, vz, tz = sch.split(oc, [None, 8, 1], preserve_unit_iters=True) @@ -81,12 +81,12 @@ def apply( # pylint: disable=too-many-locals,missing-docstring root_block = analysis.get_root_block(sch, sch.func_working_on) blocks = sch.get_child_blocks(root_block) reduction_blocks = list( - filter(lambda block: analysis.get_block_info(sch, block).is_reduction(), blocks) + filter(lambda block: analysis.get_sblock_info(sch, block).is_reduction(), blocks) ) remaining_blocks = [blk for blk in blocks if blk not in reduction_blocks] def is_convolution(blk): - block_info = analysis.get_block_info(sch, blk) + block_info = analysis.get_sblock_info(sch, blk) return "conv2d_NCHWc" in block_info.name if len(reduction_blocks) != 1 or not is_convolution(reduction_blocks[0]): diff --git a/python/tvm/dlight/adreno/fallback.py b/python/tvm/dlight/adreno/fallback.py index 050781a21f32..e60ccf12873f 100644 --- a/python/tvm/dlight/adreno/fallback.py +++ b/python/tvm/dlight/adreno/fallback.py @@ -47,8 +47,8 @@ class Fallback(AdrenoScheduleRule): @staticmethod def schedule_inline_blocks( - sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] - ) -> List[tir.schedule.BlockRV]: + sch: tir.Schedule, blocks: List[tir.schedule.SBlockRV] + ) -> List[tir.schedule.SBlockRV]: """ Auto Inlines Injective and Element-wise Operations while trying to omit data pad blocks... """ @@ -59,7 +59,7 @@ def schedule_inline_blocks( remaining_blocks = [] for blk in blocks: - block_info = analysis.get_block_info(sch, blk) + block_info = analysis.get_sblock_info(sch, blk) if block_info.is_injective() and not block_info.is_data_pad(sch): if len(sch.get_consumers(blk)) == 1: try: @@ -87,8 +87,8 @@ def schedule_inline_blocks( return remaining_blocks @staticmethod - def schedule_default(sch: tir.Schedule, blk: tir.schedule.BlockRV): - block_info = analysis.get_block_info(sch, blk) + def schedule_default(sch: tir.Schedule, blk: tir.schedule.SBlockRV): + block_info = analysis.get_sblock_info(sch, blk) s_loops, r_loops, o_loops = [], [], [] v_loop = block_info.write_bufs(sch)[0].assoc_lps[-1] @@ -139,8 +139,8 @@ def schedule_fallback(sch): schedule_blocks = [ blk for blk in blocks - if analysis.get_block_info(sch, blk).is_reduction() - or analysis.get_block_info(sch, blk).is_data_pad(sch) + if analysis.get_sblock_info(sch, blk).is_reduction() + or analysis.get_sblock_info(sch, blk).is_data_pad(sch) ] remaining_blocks = [blk for blk in blocks if blk not in schedule_blocks] @@ -169,7 +169,7 @@ def apply( # pylint: disable=too-many-locals if any(len(sch.get_child_blocks(block)) != 0 for block in blocks): return None - block_infos = [analysis.get_block_info(sch, block) for block in blocks] + block_infos = [analysis.get_sblock_info(sch, block) for block in blocks] if not any("texture" in block.write_bufs(sch)[0].get_scope() for block in block_infos): return None diff --git a/python/tvm/dlight/adreno/layout_transform.py b/python/tvm/dlight/adreno/layout_transform.py index bf3c446d59ff..6c2dfdff728e 100644 --- a/python/tvm/dlight/adreno/layout_transform.py +++ b/python/tvm/dlight/adreno/layout_transform.py @@ -58,7 +58,7 @@ def apply( # pylint: disable=too-many-locals return None blk = sch.get_child_blocks(root_block)[0] - block_info = analysis.get_block_info(sch, blk) + block_info = analysis.get_sblock_info(sch, blk) if not ( (self.use_op_name and block_info.name == "te_layout_transform") or (not self.use_op_name and block_info.is_layout_transform(sch)) diff --git a/python/tvm/dlight/adreno/pool.py b/python/tvm/dlight/adreno/pool.py index e3709caad96a..7026ab297491 100644 --- a/python/tvm/dlight/adreno/pool.py +++ b/python/tvm/dlight/adreno/pool.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -""" Pool schedule rule for Adreno operators.""" +"""Pool schedule rule for Adreno operators.""" from tvm import tir from tvm.target import Target @@ -33,7 +33,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring _: bool, ) -> tir.Schedule: sch = tir.Schedule(func) - root = sch.get_block(name="root", func_name="main") + root = sch.get_sblock(name="root", func_name="main") blocks = sch.get_child_blocks(root) blocks_names = [sch.get(blk).name_hint for blk in blocks] @@ -41,7 +41,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not "adaptive_pool_sum" in blocks_names and not "pool_max" in blocks_names: return None - def schedule_pad(blk: tir.schedule.BlockRV): + def schedule_pad(blk: tir.schedule.SBlockRV): lps, veclp = sch.get_loops(blk)[:-1], sch.get_loops(blk)[-1] sch.vectorize(veclp) b = sch.fuse(*lps) @@ -50,8 +50,8 @@ def schedule_pad(blk: tir.schedule.BlockRV): sch.bind(bx, "blockIdx.x") sch.bind(tx, "threadIdx.x") - def schedule_max_pool(blk: tir.schedule.BlockRV): - block_info = analysis.get_block_info(sch, blk) + def schedule_max_pool(blk: tir.schedule.SBlockRV): + block_info = analysis.get_sblock_info(sch, blk) iters_kind = "".join([_iter.kind for _iter in block_info.iters]) if iters_kind != "SSSSSRR": return None diff --git a/python/tvm/dlight/adreno/utils.py b/python/tvm/dlight/adreno/utils.py index 73597fe8578c..615fd82dcb03 100644 --- a/python/tvm/dlight/adreno/utils.py +++ b/python/tvm/dlight/adreno/utils.py @@ -22,10 +22,10 @@ from tvm.target import Target from tvm import tir -from ..analysis import BlockInfo +from ..analysis import SBlockInfo -def get_texture_storage(block_info: BlockInfo): +def get_texture_storage(block_info: SBlockInfo): """ Returns the texture layout acceptable for the shape @@ -67,13 +67,13 @@ def get_texture_storage(block_info: BlockInfo): return "global" -def schedule_inline_blocks(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] = None): +def schedule_inline_blocks(sch: tir.Schedule, blocks: List[tir.schedule.SBlockRV] = None): from .fallback import Fallback return Fallback.schedule_inline_blocks(sch, blocks) -def schedule_default(sch, blocks: List[tir.schedule.BlockRV] = None): +def schedule_default(sch, blocks: List[tir.schedule.SBlockRV] = None): from .fallback import Fallback ret = [] diff --git a/python/tvm/dlight/analysis/__init__.py b/python/tvm/dlight/analysis/__init__.py index 0df8abb2bf5c..a06de8d89339 100644 --- a/python/tvm/dlight/analysis/__init__.py +++ b/python/tvm/dlight/analysis/__init__.py @@ -16,7 +16,7 @@ # under the License. """Base infra""" from .common_analysis import ( - BlockInfo, + SBlockInfo, IterInfo, collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, @@ -24,7 +24,7 @@ is_broadcast_epilogue, normalize_prim_func, get_root_block, - get_block_info, + get_sblock_info, ) from .gemv import ( is_gemv, diff --git a/python/tvm/dlight/analysis/common_analysis.py b/python/tvm/dlight/analysis/common_analysis.py index 161deaf53772..79c466bcef42 100644 --- a/python/tvm/dlight/analysis/common_analysis.py +++ b/python/tvm/dlight/analysis/common_analysis.py @@ -26,7 +26,7 @@ from tvm import ir, tir from tvm.target.target import Target from tvm.tir import Schedule -from tvm.tir.schedule import BlockRV +from tvm.tir.schedule import SBlockRV from tvm.runtime import DataType @@ -63,7 +63,7 @@ def __repr__(self) -> str: return str(self) -get_blockrealize = get_global_func("tir.schedule.GetBlockRealize") +get_sblockrealize = get_global_func("tir.schedule.GetSBlockRealize") # BufferIndex Types Index = namedtuple("Index", ["sub"]) # c RemIndex = namedtuple("RemIndex", ["sub", "div"]) # c%len @@ -74,6 +74,7 @@ def __repr__(self) -> str: class BufferInfo: "Information about Buffer. Provides useful analysis" + buf_region: tir.BufferRegion shape: Tuple[int] assoc_lps: List[Union[tir.schedule.LoopRV, None]] @@ -82,7 +83,7 @@ class BufferInfo: def __init__( self, sch: tir.Schedule, - block_rv: tir.schedule.BlockRV, + block_rv: tir.schedule.SBlockRV, buf_region: tir.BufferRegion, lps: Union[List[tir.schedule.LoopRV], None], ): @@ -91,7 +92,7 @@ def __init__( lps = sch.get_loops(block_rv) loops = [sch.get(lp) for lp in lps] iter_vars = [Var.var for Var in block.iter_vars] - iter_values = get_blockrealize(sch, block_rv).iter_values + iter_values = get_sblockrealize(sch, block_rv).iter_values lpvar_lp = dict([loop.loop_var, lp] for loop, lp in zip(loops, lps)) var_lp = dict(zip(iter_vars, [lpvar_lp.get(val, None) for val in iter_values])) @@ -163,22 +164,22 @@ def __repr__(self) -> str: return str(self) -class BlockInfo: +class SBlockInfo: """Information about a TIR block.""" name: str iters: List[IterInfo] - block_rv: tir.schedule.BlockRV + block_rv: tir.schedule.SBlockRV _reduction_block: bool def __init__( self, name: str, iters: List[IterInfo], - block_rv: tir.schedule.BlockRV, + block_rv: tir.schedule.SBlockRV, reduction_block: bool = False, ): - """Construct a BlockInfo object.""" + """Construct a SBlockInfo object.""" self.name = name self.block_rv = block_rv self.iters = iters @@ -203,11 +204,11 @@ def dom_kind(self) -> str: return "".join(i.kind for i in self.iters) def is_injective(self) -> bool: - """Whether the block is injective, i.e. all its iteration domains are injective.""" + """Whether the SBlock is injective, i.e. all its iteration domains are injective.""" return all(k == "S" for k in self.dom_kind()) def is_elementwise(self, sch: tir.Schedule) -> bool: - """Whether the block is elementwise, i.e. trivial mapping between read/write region""" + """Whether the SBlock is elementwise, i.e. trivial mapping between read/write region""" def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: return dom.min.same_as(var) and dom.extent == 1 @@ -230,12 +231,12 @@ def get_loops(self) -> List[tir.schedule.LoopRV]: return [iter_info.loop_rv for iter_info in self.iters] def is_reduction(self) -> bool: - """Whether the block is a reduction workload.""" + """Whether the SBlock is a reduction workload.""" # TODO(@junrushao): distinguish GEMV and reduction return self._reduction_block def is_layout_transform(self, sch: tir.Schedule) -> bool: - """Whether the Block can be considered having a Layout Transform Pattern""" + """Whether the SBlock can be considered having a Layout Transform Pattern""" return ( all(k == "S" for k in self.dom_kind()) and len(self.write_bufs(sch)) == 1 @@ -245,7 +246,7 @@ def is_layout_transform(self, sch: tir.Schedule) -> bool: ) def is_data_pad(self, sch: tir.Schedule) -> bool: - """Whether the Block can be considered having a data pad pattern""" + """Whether the SBlock can be considered having a data pad pattern""" return ( all(k == "S" for k in self.dom_kind()) and len(self.write_bufs(sch)) == 1 @@ -257,23 +258,23 @@ def is_data_pad(self, sch: tir.Schedule) -> bool: ) def is_convolution(self) -> bool: - """Whether a Block can be considered having Convolution Pattern""" + """Whether a SBlock can be considered having Convolution Pattern""" raise NotImplementedError def is_pool(self) -> bool: - """Whether a Block can be considered having Pooling Pattern""" + """Whether a SBlock can be considered having Pooling Pattern""" raise NotImplementedError def is_gemv(self) -> bool: - """Whether the block is a GEMV workload.""" + """Whether the SBlock is a GEMV workload.""" raise NotImplementedError def is_gemm(self) -> bool: - """Whether the block is a GEMM workload.""" + """Whether the SBlock is a GEMM workload.""" raise NotImplementedError def __str__(self) -> str: - return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})' + return f'SBlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})' def __repr__(self) -> str: return str(self) @@ -282,7 +283,7 @@ def __repr__(self) -> str: _normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc") -def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]: +def normalize_prim_func(sch: tir.Schedule) -> Optional[List[SBlockInfo]]: """Normalize the primfunc to normal form""" try: result = _normalize_prim_func(sch) @@ -297,10 +298,10 @@ def _iter_kind(i: tir.IterVar) -> str: tir.IterVar.CommReduce: "R", }.get(i.iter_type, "O") - blocks: List[BlockInfo] = [] + blocks: List[SBlockInfo] = [] for block, loops, iters, is_reduction in zip(*result): blocks.append( - BlockInfo( + SBlockInfo( name=sch.get(block).name_hint, iters=[ IterInfo( @@ -318,17 +319,17 @@ def _iter_kind(i: tir.IterVar) -> str: return blocks -def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo: +def get_sblock_info(sch: tir.Schedule, block: tir.schedule.SBlockRV) -> SBlockInfo: def _iter_kind(loop: tir.IterVar) -> str: return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") - def _is_reduction_block(block: tir.schedule.BlockRV): + def _is_reduction_block(block: tir.schedule.SBlockRV): for iter_var in sch.get(block).iter_vars: if _iter_kind(iter_var) == "R": return True return False - return BlockInfo( + return SBlockInfo( name=sch.get(block).name_hint, iters=[ IterInfo( @@ -370,7 +371,7 @@ def get_max_shared_memory_per_block(target: Target) -> int: return int(max_shared_memory_per_block) -def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: +def get_root_block(sch: Schedule, func_name: str = "main") -> SBlockRV: try: block = sch.mod[func_name].body.block except: @@ -378,11 +379,11 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: f"The function body is expected to be the root block, but got:\n" f"{sch.mod[func_name].body}" ) - return sch.get_block(block.name_hint) + return sch.get_sblock(block.name_hint) def collect_block_iter_vars_used_in_access_region( - block: tir.Block, region: List[ir.Range] + block: tir.SBlock, region: List[ir.Range] ) -> Set[tir.Var]: """Collect the block iter variables used in the access region of a buffer region.""" tir_vars = set() @@ -405,7 +406,7 @@ def _collect_tir_var(expr): return tir_vars -def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: +def detect_dominant_read(block: tir.SBlock) -> tir.PrimExpr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 @@ -421,8 +422,8 @@ def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: def is_broadcast_epilogue( sch: tir.Schedule, - block: tir.schedule.BlockRV, - epilogue: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, + epilogue: tir.schedule.SBlockRV, ) -> bool: """Check if the epilogue block is a broadcast pattern""" write_buffers = {r.buffer for r in sch.get(block).writes} diff --git a/python/tvm/dlight/analysis/gemv.py b/python/tvm/dlight/analysis/gemv.py index c502081ba320..74910091e99e 100644 --- a/python/tvm/dlight/analysis/gemv.py +++ b/python/tvm/dlight/analysis/gemv.py @@ -20,14 +20,14 @@ from tvm import arith, ir, tir from .common_analysis import ( - BlockInfo, + SBlockInfo, collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, detect_dominant_read, ) -def get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: +def get_reduction_expr(block: tir.SBlock) -> Optional[tir.PrimExpr]: """Extracts the reduction expression from a TIR block. This function checks whether the given TIR block follows a reduction pattern @@ -35,7 +35,7 @@ def get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: Parameters: ---------- - block : tir.Block + block : tir.SBlock The TIR block to analyze. Returns: @@ -58,7 +58,7 @@ def get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: return buffer_store.value.b -def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: +def is_gemv(sch: tir.Schedule, block_info: SBlockInfo) -> Optional[List[tir.Buffer]]: """Check if the block is a GEMV. Parameters @@ -67,7 +67,7 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe sch : tir.Schedule The schedule - block_info : BlockInfo + block_info : SBlockInfo The block info to be checked @@ -102,10 +102,10 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe def normalize( sch: tir.Schedule, - block_info: BlockInfo, + block_info: SBlockInfo, ) -> Optional[bool]: """Normalize the main block.""" - block_stmt: tir.Block = sch.get(block_info.block_rv) + block_stmt: tir.SBlock = sch.get(block_info.block_rv) access = arith.normalize_to_iter_sum( detect_dominant_read(block_stmt), input_iters={i.var: i.dom for i in block_stmt.iter_vars}, diff --git a/python/tvm/dlight/base/common_schedules.py b/python/tvm/dlight/base/common_schedules.py index c205b78390bc..552474f71256 100644 --- a/python/tvm/dlight/base/common_schedules.py +++ b/python/tvm/dlight/base/common_schedules.py @@ -19,25 +19,25 @@ from tvm import tir -from ..analysis import BlockInfo +from ..analysis import SBlockInfo def try_inline( sch: tir.Schedule, - blocks: List[BlockInfo], -) -> List[BlockInfo]: + blocks: List[SBlockInfo], +) -> List[SBlockInfo]: """Try to inline as many blocks as possible, and return the remaining blocks. Parameters ---------- sch : tir.Schedule The TIR schedule used to inline blocks. - blocks : List[BlockInfo] + blocks : List[SBlockInfo] The blocks to be inlined. Returns ------- - remaining : List[BlockInfo] + remaining : List[SBlockInfo] The remaining blocks that cannot be inlined. """ @@ -62,20 +62,20 @@ def _trial(func: Callable): def try_inline_contiguous_spatial( sch: tir.Schedule, - block_infos: List[BlockInfo], -) -> List[BlockInfo]: + block_infos: List[SBlockInfo], +) -> List[SBlockInfo]: """Try to inline contiguous spatial blocks in a schedule Parameters ---------- sch : tir.Schedule The TIR schedule used to inline blocks. - block_infos : List[BlockInfo] + block_infos : List[SBlockInfo] The blocks to be try. Returns ------- - remaining : List[BlockInfo] + remaining : List[SBlockInfo] The remaining blocks that cannot be inlined. """ @@ -83,7 +83,7 @@ def try_inline_contiguous_spatial( return None results = [] spatial_blocks = [] - block: BlockInfo + block: SBlockInfo for block in block_infos: if block.is_injective(): spatial_blocks.append(block) diff --git a/python/tvm/dlight/cpu/gemv.py b/python/tvm/dlight/cpu/gemv.py index 15b47de919a7..bfeb6384dd3a 100644 --- a/python/tvm/dlight/cpu/gemv.py +++ b/python/tvm/dlight/cpu/gemv.py @@ -20,7 +20,7 @@ from tvm import tir from tvm.target import Target -from ..analysis import BlockInfo, normalize_prim_func +from ..analysis import SBlockInfo, normalize_prim_func from ..analysis.gemv import is_gemv, normalize from ..base import get_extent, try_inline_contiguous_spatial from .base import CPUScheduleRule @@ -77,9 +77,9 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, too-many-positio self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], ): """Schedule the inner reduction block.""" diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py index bcbfda791fb3..177322c9749b 100644 --- a/python/tvm/dlight/gpu/fallback.py +++ b/python/tvm/dlight/gpu/fallback.py @@ -50,7 +50,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring return None block_infos = try_inline(sch, block_infos) - reduction_blocks: List[Tuple[tir.schedule.BlockRV, tir.schedule.LoopRV]] = [] + reduction_blocks: List[Tuple[tir.schedule.SBlockRV, tir.schedule.LoopRV]] = [] for block in block_infos: s_loops: List[tir.schedule.LoopRV] = [] r_loops: List[tir.schedule.LoopRV] = [] diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index ebb19ad72c3a..cf76f645372e 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -22,7 +22,7 @@ from tvm.target import Target from ..analysis import ( - BlockInfo, + SBlockInfo, is_broadcast_epilogue, is_gemv, normalize, @@ -87,9 +87,9 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, un self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], ): """Schedule the inner reduction block.""" @@ -427,9 +427,9 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], ): """Schedule the outer reduction block.""" @@ -632,9 +632,9 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], ): """Schedule the outer reduction block.""" # NOTE: Only Android is supported so far diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index f5e3669ad0f3..e503b05eecbc 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -22,7 +22,7 @@ from tvm.target import Target from ..analysis import ( - BlockInfo, + SBlockInfo, collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, is_broadcast_epilogue, @@ -32,7 +32,7 @@ from .base import GPUScheduleRule -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: +def _get_reduction_expr(block: tir.SBlock) -> Optional[tir.PrimExpr]: # Detect and return `Y` in `X[...] = X[...] + Y` buffer_store = block.body if not isinstance(buffer_store, tir.BufferStore): @@ -48,7 +48,7 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: return buffer_store.value.b -def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: +def is_gemv(sch: tir.Schedule, block_info: SBlockInfo) -> Optional[List[tir.Buffer]]: """Check if the block is a low batch GEMM. Parameters @@ -57,7 +57,7 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe sch : tir.Schedule The schedule - block_info : BlockInfo + block_info : SBlockInfo The block info to be checked @@ -109,7 +109,7 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe return ret if 0 < len(ret) < len(block_stmt.reads) else None -def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> tir.PrimExpr: +def detect_dominant_read(block: tir.SBlock, const_iter_vars: Set[tir.Var]) -> tir.PrimExpr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 @@ -128,10 +128,10 @@ def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> tir def normalize( sch: tir.Schedule, - block_info: BlockInfo, + block_info: SBlockInfo, ) -> Optional[bool]: """Normalize the main block.""" - block_stmt: tir.Block = sch.get(block_info.block_rv) + block_stmt: tir.SBlock = sch.get(block_info.block_rv) const_iter_vars = set( iter_var.var for iter_var in block_stmt.iter_vars @@ -288,11 +288,11 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, un self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, - dequantize_block: Optional[tir.schedule.BlockRV], - pad_input_block: Optional[tir.schedule.BlockRV], + block: tir.schedule.SBlockRV, + dequantize_block: Optional[tir.schedule.SBlockRV], + pad_input_block: Optional[tir.schedule.SBlockRV], vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], batch_pad: int, ): """Schedule the inner reduction block.""" @@ -600,11 +600,11 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, - dequantize_block: Optional[tir.schedule.BlockRV], - pad_input_block: Optional[tir.schedule.BlockRV], + block: tir.schedule.SBlockRV, + dequantize_block: Optional[tir.schedule.SBlockRV], + pad_input_block: Optional[tir.schedule.SBlockRV], vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], batch_pad: int, ): """Schedule the outer reduction block.""" @@ -615,7 +615,7 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un def apply( sch: tir.Schedule, - main_block: tir.schedule.BlockRV, + main_block: tir.schedule.SBlockRV, TAG_S: Literal["threadIdx.x", "threadIdx.y"], TAG_R: Literal["threadIdx.x", "threadIdx.y"], TS: int, diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 368552c88d43..d5f3f758a0ac 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -26,13 +26,13 @@ from tvm.target import Target from tvm.tir import IterVar, PrimExpr, Var from tvm.tir.analysis import undefined_vars -from tvm.tir.schedule.schedule import BlockRV +from tvm.tir.schedule.schedule import SBlockRV -from ..analysis import BlockInfo, IterInfo, get_root_block +from ..analysis import SBlockInfo, IterInfo, get_root_block from .base import GPUScheduleRule -def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): +def _collect_producers(sch: tir.Schedule, block: tir.schedule.SBlockRV): result = [] for producer in sch.get_producers(block): result.append(producer) @@ -40,7 +40,7 @@ def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): return result -def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.SBlockRV): result = [] for consumer in sch.get_consumers(block): result.append(consumer) @@ -50,7 +50,7 @@ def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): def auto_inline_producers( sch: tir.Schedule, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, ): while True: inlined_cnt = 0 @@ -67,7 +67,7 @@ def auto_inline_producers( def auto_inline_consumers( sch: tir.Schedule, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, ): while True: inlined_cnt = 0 @@ -90,7 +90,7 @@ def auto_inline_consumers( def auto_inline_consumer_chain( sch: tir.Schedule, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, ): auto_inline_consumers(sch, block) remaining_consumers = sch.get_consumers(block) @@ -165,12 +165,12 @@ def make_iter_fusion_index_map( return tir.IndexMap(input_iters, final_indices, None) -def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]: +def detect_iter_traits(block: tir.SBlock) -> Optional[Tuple[List[IterTrait]]]: """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] Parameters ---------- - block : tir.Block + block : tir.SBlock The block to be analyzed Returns @@ -235,12 +235,12 @@ def get_access_axes(region: List[Range]) -> Set[Var]: return A_traits, B_traits, C_traits, block_traits -def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]: +def get_index_map(block: tir.SBlock) -> Optional[Tuple[tir.IndexMap, ...]]: """Get index maps for the block Parameters ---------- - block : tir.Block + block : tir.SBlock The block to be analyzed Returns @@ -274,17 +274,17 @@ def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]: ) -def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo: +def get_sblock_info(sch: tir.Schedule, block: tir.schedule.SBlockRV) -> SBlockInfo: def _iter_kind(loop: tir.IterVar) -> str: return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") - def _is_reduction_block(block: tir.schedule.BlockRV): + def _is_reduction_block(block: tir.schedule.SBlockRV): for iter_var in sch.get(block).iter_vars: if _iter_kind(iter_var) == "R": return True return False - return BlockInfo( + return SBlockInfo( name=sch.get(block).name_hint, iters=[ IterInfo( @@ -302,12 +302,12 @@ def _is_reduction_block(block: tir.schedule.BlockRV): def get_reduction_blocks(sch, blocks) -> bool: # Get the main computation block - def is_reduction(block: BlockRV) -> bool: + def is_reduction(block: SBlockRV) -> bool: block_stmt = sch.get(block) iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} return iter_types == {IterVar.CommReduce, IterVar.DataPar} - def is_spatial(block: BlockRV) -> bool: + def is_spatial(block: SBlockRV) -> bool: block_stmt = sch.get(block) iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} return iter_types == {IterVar.DataPar} @@ -325,7 +325,7 @@ def is_spatial(block: BlockRV) -> bool: return reduction_blocks -def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: +def get_in_out_dtypes(block: tir.SBlock) -> Tuple[str]: """ Detect In/Out data types for the given block based on the analysis if read/write buffers. """ @@ -453,7 +453,7 @@ def fetch_to_shared(block, idx): ) sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i)) - def tensorize_block(block: tir.schedule.BlockRV, intrin: str): + def tensorize_block(block: tir.schedule.SBlockRV, intrin: str): *_, i, j = sch.get_loops(block) io, ii = sch.split(i, [None, micro_size]) jo, ji = sch.split(j, [None, micro_size]) @@ -981,7 +981,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring main_block = reduction_blocks[0] block_stmt = sch.get(main_block) - main_block_info = get_block_info(sch, main_block) + main_block_info = get_sblock_info(sch, main_block) iter_infos = main_block_info.iters if not get_index_map(block_stmt): return None @@ -1127,8 +1127,8 @@ def sch_outer_reduction( self, sch: tir.Schedule, config: Config, - reduction_block: tir.schedule.BlockRV, - blocks: List[tir.schedule.BlockRV], + reduction_block: tir.schedule.SBlockRV, + blocks: List[tir.schedule.SBlockRV], ) -> Optional[tir.Schedule]: """Get vectorization factor""" diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index 4faaa1cab94a..d07f43f738e1 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -22,7 +22,7 @@ from tvm.target import Target from ..analysis import ( - BlockInfo, + SBlockInfo, detect_dominant_read, is_broadcast_epilogue, normalize_prim_func, @@ -31,7 +31,7 @@ from .base import GPUScheduleRule -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: +def _get_reduction_expr(block: tir.SBlock) -> Optional[tir.PrimExpr]: # Detect and return `Y` in `X[...] = X[...] + Y` buffer_store = block.body if not isinstance(buffer_store, tir.BufferStore): @@ -113,7 +113,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- def _normalize( # pylint: disable=too-many-branches self, sch: tir.Schedule, - block_info: BlockInfo, + block_info: SBlockInfo, access: arith.IterSumExpr, ) -> Tuple[Optional[bool], Optional[int], Optional[Mapping[int, int]], Optional[int]]: if access.base != 0: @@ -177,9 +177,9 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments self, sch: tir.Schedule, target: Target, - block: tir.schedule.BlockRV, + block: tir.schedule.SBlockRV, unroll_spatial_factor: Optional[int], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], loop_order, s_split_index, ): @@ -235,10 +235,10 @@ def _sch_inner_spatial( self, sch: tir.Schedule, _: Target, - block: tir.schedule.BlockRV, - block_info: BlockInfo, + block: tir.schedule.SBlockRV, + block_info: SBlockInfo, unroll_spatial_factor: Optional[int], - epilogue_info: Optional[BlockInfo], + epilogue_info: Optional[SBlockInfo], loop_order, s_split_index, ): diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 5dc6887c782c..7f58f3aac3ca 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -20,13 +20,13 @@ import tvm from tvm import tir from tvm.target import Target -from tvm.tir import Block, BufferStore +from tvm.tir import SBlock, BufferStore from tvm.tir.expr import BufferLoad, Call, Cast from ..base import ScheduleRule -def identify_cast_or_load_block(block: Block) -> bool: +def identify_cast_or_load_block(block: SBlock) -> bool: if len(block.reads) != 1 or len(block.writes) != 1: return False @@ -55,7 +55,7 @@ def identify_cast_or_load_block(block: Block) -> bool: return True -def identify_rsqrt_block(block: Block) -> bool: +def identify_rsqrt_block(block: SBlock) -> bool: if len(block.reads) != 1 or len(block.writes) != 1: return False @@ -88,7 +88,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring num_tx = 64 sch = tir.Schedule(func) - root = sch.get_block(name="root", func_name="main") + root = sch.get_sblock(name="root", func_name="main") blocks = sch.get_child_blocks(root) diff --git a/python/tvm/dlight/gpu/transpose.py b/python/tvm/dlight/gpu/transpose.py index 125af538cdb8..83ae0b64371a 100644 --- a/python/tvm/dlight/gpu/transpose.py +++ b/python/tvm/dlight/gpu/transpose.py @@ -20,7 +20,7 @@ from tvm import arith, tir from tvm.target import Target from tvm.tir import Schedule -from tvm.tir.schedule import BlockRV +from tvm.tir.schedule import SBlockRV from ..analysis import detect_dominant_read, normalize_prim_func from ..base import try_inline_contiguous_spatial @@ -30,7 +30,7 @@ class Transpose(GPUScheduleRule): """Schedule rule for transpose""" - def is_transpose(self, sch: Schedule, block_rv: BlockRV): + def is_transpose(self, sch: Schedule, block_rv: SBlockRV): block = sch.get(block_rv) if isinstance(block.body, tir.BufferStore): rhs = block.body.value diff --git a/python/tvm/exec/gpu_memory_bandwidth.py b/python/tvm/exec/gpu_memory_bandwidth.py index e7e38638d715..836c041d5345 100644 --- a/python/tvm/exec/gpu_memory_bandwidth.py +++ b/python/tvm/exec/gpu_memory_bandwidth.py @@ -158,7 +158,7 @@ def _schedule( len_vec: int, ): # pylint: disable=invalid-name - block = sch.get_block("B") + block = sch.get_sblock("B") xo, xi, k = sch.get_loops(block) bx, xo = sch.split(xo, factors=[len_bx, None]) xi, tx, vec = sch.split(xi, factors=[None, len_tx, len_vec]) diff --git a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py index 58540839397d..d5dca7d8e895 100644 --- a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py +++ b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py @@ -22,7 +22,7 @@ import tvm from tvm import meta_schedule -from tvm.tir.schedule import BlockRV, ExprRV, LoopRV +from tvm.tir.schedule import SBlockRV, ExprRV, LoopRV ## Tiling layout transforms: # Assume we have an input shape of [A, B, C, D] and want to layout transform @@ -85,13 +85,13 @@ def tile_layout_transform( sch: tvm.tir.Schedule, - block_read: BlockRV, - block_write: BlockRV, + block_read: SBlockRV, + block_write: SBlockRV, src_layout: str, dst_layout: str, input_shape: List[int], tile_size: ExprRV, -) -> Tuple[BlockRV, BlockRV]: +) -> Tuple[SBlockRV, SBlockRV]: """ High level tiling for layout transform block. Mutates sch in place. @@ -237,7 +237,7 @@ def factor_dim_in_order( return loops, cur_loop_extants def get_high_level_loop_structure( - block_read: BlockRV, input_shape: List[int], src_layout: str, dst_layout: str + block_read: SBlockRV, input_shape: List[int], src_layout: str, dst_layout: str ): """Runs the factorization described above.""" # index 0 ... rank - 1 will always correspond to original loops @@ -329,11 +329,11 @@ def get_high_level_loop_structure( def create_cached_read( sch: tvm.tir.Schedule, - block_write: BlockRV, + block_write: SBlockRV, orig_input_shape: List[int], orig_src_layout: str, orig_dst_layout: str, -) -> Tuple[BlockRV, List[int], str, str]: +) -> Tuple[SBlockRV, List[int], str, str]: """ Creates the cached read block with expected structure. @@ -446,7 +446,7 @@ def unpack_list(target_list) -> List: return block_read, unpack_list(input_shape), new_src_layout_str, new_dst_layout_str -def auto_inline_into(sch: tvm.tir.Schedule, start_block: BlockRV) -> BlockRV: +def auto_inline_into(sch: tvm.tir.Schedule, start_block: SBlockRV) -> SBlockRV: """ Inlines given start_block's consumers and future dependencies into start_block. @@ -503,7 +503,7 @@ def get_max_tile_size() -> int: @tvm.register_global_func("meta_schedule.cuda.layout_transform") def cuda_layout_transform_schedule_rule( - sch: tvm.tir.Schedule, block: BlockRV, testing_tile_sizes: Optional[List[int]] = None + sch: tvm.tir.Schedule, block: SBlockRV, testing_tile_sizes: Optional[List[int]] = None ) -> List[tvm.tir.Schedule]: """ Applies tiling scheme to layout transform task (potentially fused with other injective funcs). diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 41c97a7862b4..d570af14895f 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -17,7 +17,7 @@ """Multi-level tiling with reuse.""" from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Callable -from tvm.tir.schedule import Schedule, BlockRV +from tvm.tir.schedule import Schedule, SBlockRV from tvm_ffi import register_object from .. import _ffi_api @@ -63,7 +63,7 @@ class MultiLevelTiling(ScheduleRule): Data reuse configuration for reading. None means no reuse. reuse_write : Optional[ReuseType] Data reuse configuration for writing. None means no reuse. - filter_fn: Optional[Callable[[Schedule, BlockRV], bool]] + filter_fn: Optional[Callable[[Schedule, SBlockRV], bool]] A function that can be passed to overwrite the default condition for applying MultiLevelTiling to a block. This is useful if there is a need to apply MultiLevelTiling to an operation / block which is ignored by default. This function should return True @@ -78,7 +78,7 @@ def __init__( vector_load_lens: Optional[List[int]] = None, reuse_read: Optional[ReuseType] = None, reuse_write: Optional[ReuseType] = None, - filter_fn: Optional[Callable[[Schedule, BlockRV], bool]] = None, + filter_fn: Optional[Callable[[Schedule, SBlockRV], bool]] = None, ) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 98c81e5b8f30..8c2ab1cf1dec 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -27,7 +27,7 @@ from tvm_ffi import register_object from tvm.runtime import Object -from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule import SBlockRV, Schedule from .. import _ffi_api from ..utils import _get_default_str @@ -52,14 +52,14 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: self, context ) - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: """Apply a schedule rule to the specific block in the given schedule. Parameters ---------- sch : tvm.tir.Schedule The schedule to be modified. - block : BlockRV + block : SBlockRV The specific block to apply the schedule rule. Returns @@ -162,14 +162,14 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: """ raise NotImplementedError - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: """Apply a schedule rule to the specific block in the given schedule. Parameters ---------- sch : Schedule The schedule to be modified. - block : BlockRV + block : SBlockRV The specific block to apply the schedule rule. Returns diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 8d40d3d42780..9be223a6b063 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -30,7 +30,7 @@ from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo from tvm.relax.expr import DataflowBlock, Var, GlobalVar, Expr, Function, Call, Binding -from tvm.tir import IndexMap, PrimFunc, Block, Buffer +from tvm.tir import IndexMap, PrimFunc, SBlock, Buffer from . import _ffi_api @@ -517,7 +517,7 @@ def _get_prim_func_default_dtype(func: PrimFunc): def suggest_layout_transforms( func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]] -) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]: +) -> Dict[SBlock, Dict[Union[SBlock, Buffer], IndexMap]]: """Suggest Layout transformations of blocks and buffers in a PrimFunc. Parameters @@ -531,7 +531,7 @@ def suggest_layout_transforms( Returns ------- - ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]] + ret: Dict[SBlock, Dict[Union[SBlock, Buffer], IndexMap]] Suggested transforms per block in `func`. For each block the returned value is a map from the object (block or buffer) to it's index map transformation. """ diff --git a/python/tvm/relax/backend/gpu_generic/cumsum.py b/python/tvm/relax/backend/gpu_generic/cumsum.py index 914dfb6bd231..324a256bc5e8 100644 --- a/python/tvm/relax/backend/gpu_generic/cumsum.py +++ b/python/tvm/relax/backend/gpu_generic/cumsum.py @@ -91,7 +91,7 @@ def block_inclusive_inside_block( ): for by in T.thread_binding(batch, thread="blockIdx.y"): for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"): - with T.block(): + with T.sblock(): local_buf = T.alloc_buffer((thread_elem,), out_dtype, scope="local") shared_buf = T.alloc_buffer((block_elem,), out_dtype, scope="shared") for ty in T.thread_binding(TY, thread="threadIdx.y"): diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py b/python/tvm/relax/backend/gpu_generic/sampling.py index 9a0d01ef2331..ba48f13d9629 100644 --- a/python/tvm/relax/backend/gpu_generic/sampling.py +++ b/python/tvm/relax/backend/gpu_generic/sampling.py @@ -137,7 +137,7 @@ def block_adjacent_difference_left( source_local: T.Buffer, output_local: T.Buffer, ): - with T.block(): + with T.sblock(): shared_buf = T.alloc_buffer((TX * TY,), "bool", scope="shared") tx_idx = ty * TX + tx shared_buf[tx_idx] = source_local[thread_elem - 1] @@ -166,7 +166,7 @@ def block_reduce_with_mask( reduce_op: Callable, # T.macro mask_local: Optional[T.Buffer] = None, ): - with T.block(): + with T.sblock(): local_sum = T.alloc_buffer((), dtype, scope="local") shared_buf = T.alloc_buffer((TX * TY,), dtype, scope="shared") idx = ty * TX + tx @@ -198,7 +198,7 @@ def single_batch_sampling( uniform_sample, sample_id_local, ): - with T.block(): + with T.sblock(): prob_gt_threshold = T.alloc_buffer((thread_elem,), prob_dtype, scope="local") cumsum = T.alloc_buffer((block_elem,), prob_dtype, scope="shared") greater_than_u = T.alloc_buffer((thread_elem,), "bool", scope="local") @@ -326,7 +326,7 @@ def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): output_index = T.match_buffer(D, (out_batch, 1), dtype) for ax0, ax1 in T.grid(out_batch, vocab_size): - with T.block("T_get_sample_index"): + with T.sblock("T_get_sample_index"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.writes(output_index[v_ax0, 0]) if ( diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 8c777eb53756..7b0bbff01652 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -482,9 +482,9 @@ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") compute = T.match_buffer(var_compute, [128, 128], dtype="float32") # body - # with T.block("root") + # with T.sblock("root") for i0, i1 in T.grid(128, 128): - with T.block("compute"): + with T.sblock("compute"): i, j = T.axis.remap("SS", [i0, i1]) T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]]) T.writes([compute[i, j]]) @@ -526,9 +526,9 @@ def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> N dtype="float32") compute = T.match_buffer(var_compute, [n + T.int64(1)], dtype="float32") # body - # with T.block("root") + # with T.sblock("root") for i0 in T.serial(0, n + T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): i = T.axis.spatial(n + T.int64(1), i0) T.reads([rxplaceholder[i]]) T.writes([compute[i]]) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index e94d5c42957b..6b6029630d85 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -702,13 +702,13 @@ def tir_kv_cache_transpose_append( ) for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): if position_map[global_pos] != T.int32(-1): - with T.block("k_transpose_append"): + with T.sblock("k_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]) position: T.int32 = position_map[vgpos] # type: ignore pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): + with T.sblock("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]) @@ -743,7 +743,7 @@ def tir_kv_cache_transpose_append_mla( ) for global_pos, f in T.grid(ntoken, d_qk): if position_map[global_pos] != T.int32(-1): - with T.block("k_transpose_append"): + with T.sblock("k_transpose_append"): vgpos, vf = T.axis.remap("SS", [global_pos, f]) T.reads(position_map[vgpos], kv_data[vgpos, vf]) T.writes(pages[position_map[vgpos] // page_size, position_map[vgpos] % page_size, vf]) @@ -781,7 +781,7 @@ def tir_kv_cache_debug_get_kv( k_data = T.match_buffer(var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) v_data = T.match_buffer(var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): - with T.block("copy0"): + with T.sblock("copy0"): vp, vh, vd = T.axis.remap("SSS", [p, h, d]) T.reads(position_map[vp], pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd]) T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) @@ -818,7 +818,7 @@ def tir_kv_cache_debug_get_kv_mla( ) compressed_kv_with_k_pe_data = T.match_buffer(var_compressed_kv_with_k_pe_data, (num_hidden_layers, seqlen, d_qk), dtype) for p, d in T.grid(seqlen, d_qk): - with T.block("copy0"): + with T.sblock("copy0"): vp, vd = T.axis.remap("SS", [p, d]) T.reads(position_map[vp], pages[position_map[vp] // page_size, position_map[vp] % page_size, vd]) T.writes(compressed_kv_with_k_pe_data[layer_id, vp, vd]) @@ -961,7 +961,7 @@ def batch_prefill_paged_kv_cpu( for h_qo in T.serial(h_q): for b_idx in T.serial(batch_size): - with T.block("attn"): + with T.sblock("attn"): O_local = T.alloc_buffer((d, ), "float32") Q_local = T.alloc_buffer((d, ), "float32") K_local = T.alloc_buffer((d, ), "float32") @@ -1183,18 +1183,18 @@ def apply_to_md(sch, block): sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_gemm(sch, sch.get_sblock("S_gemm"), tile_s, k_major=True) + apply_to_gemm(sch, sch.get_sblock("O_gemm"), tile_o, k_major=False) + apply_to_so_ewise(sch, sch.get_sblock("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_sblock("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_sblock("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_sblock("Q_load")) if not merged_qk_load: - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) + apply_to_qkv_load(sch, sch.get_sblock("K_load")) + apply_to_qkv_load(sch, sch.get_sblock("V_load")) else: - apply_to_qkv_load(sch, sch.get_block("KV_load")) - apply_to_md(sch, sch.get_block("lse_store")) + apply_to_qkv_load(sch, sch.get_sblock("KV_load")) + apply_to_md(sch, sch.get_sblock("lse_store")) return sch @@ -1280,7 +1280,7 @@ def batch_prefill_paged_kv( for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): + with T.sblock("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() @@ -1344,14 +1344,14 @@ def batch_prefill_paged_kv( d_smem[row] = 1.0 for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): + with T.sblock("O_init"): i, j = T.axis.remap("SS", [li, lj]) O_local[i, j] = 0.0 T.tvm_storage_sync("shared") # Load Q from gmem to smem for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): + with T.sblock("Q_load"): i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() @@ -1370,7 +1370,7 @@ def batch_prefill_paged_kv( for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): L_kv_start: T.int32 = iterator * tile_z for lz, ly in T.grid(tile_z, tile_y): - with T.block("K_load"): + with T.sblock("K_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -1388,7 +1388,7 @@ def batch_prefill_paged_kv( K_smem[i, j] = 0.0 T.tvm_storage_sync("shared") for lz, ly in T.grid(tile_z, tile_y): - with T.block("V_load"): + with T.sblock("V_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -1403,16 +1403,16 @@ def batch_prefill_paged_kv( T.tvm_storage_sync("shared") # Compute S - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): + with T.sblock("S_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): + with T.sblock("S_store"): i, j = T.axis.remap("SS", [li, lj]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") @@ -1421,7 +1421,7 @@ def batch_prefill_paged_kv( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update1"): + with T.sblock("update1"): m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S @@ -1437,7 +1437,7 @@ def batch_prefill_paged_kv( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: @@ -1454,7 +1454,7 @@ def batch_prefill_paged_kv( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): d_new[i] += S_smem[row, j] m_smem[row] = m_new[i] @@ -1463,9 +1463,9 @@ def batch_prefill_paged_kv( T.tvm_storage_sync("shared") # Update O - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): + with T.sblock("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) @@ -1473,7 +1473,7 @@ def batch_prefill_paged_kv( # Store O from smem to gmem for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): + with T.sblock("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size @@ -1482,7 +1482,7 @@ def batch_prefill_paged_kv( # Store LSE to gmem for li in T.grid(tile_x): - with T.block("lse_store"): + with T.sblock("lse_store"): i = T.axis.remap("S", [li]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size @@ -1575,7 +1575,7 @@ def batch_decode_paged_kv( ) for b in T.serial(B): - with T.block("attn"): + with T.sblock("attn"): O_local = T.alloc_buffer((D,), "float32") Q_local = T.alloc_buffer((D,), "float32") K_local = T.alloc_buffer((D,), "float32") @@ -1753,7 +1753,7 @@ def batch_decode_paged_kv( for ty in T.thread_binding(bdy, thread="threadIdx.y"): for tx in T.thread_binding(bdx, thread="threadIdx.x"): for tz in T.thread_binding(bdz, thread="threadIdx.z"): - with T.block("attn"): + with T.sblock("attn"): Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") @@ -1807,7 +1807,7 @@ def batch_decode_paged_kv( tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore # load KV from global memory to shared memory for j in T.serial(tile_size_per_bdx): - with T.block("KV_load"): + with T.sblock("KV_load"): T.reads() T.writes() row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore @@ -1837,7 +1837,7 @@ def batch_decode_paged_kv( for vec in T.unroll(VEC_SIZE): S_reduce_local[0] += QK_local[vec] - with T.block("block_cross_thread"): + with T.sblock("block_cross_thread"): T.reads(S_reduce_local[0]) T.writes(t0[0]) T.attr( @@ -1932,7 +1932,7 @@ def merge_state_inplace_cpu( for n in T.serial(N): for h in T.serial(H): - with T.block("merge"): + with T.sblock("merge"): s_val = _var_cpu("float32") s_other_val = _var_cpu("float32") s_max = _var_cpu("float32") @@ -1987,7 +1987,7 @@ def merge_state_inplace( for by in T.thread_binding(gdy, thread="blockIdx.y"): for ty in T.thread_binding(bdy, thread="threadIdx.y"): for tx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("merge"): + with T.sblock("merge"): s_val = _var("float32") s_other_val = _var("float32") s_max = _var("float32") @@ -2070,7 +2070,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): + with T.sblock("attn"): vbx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() @@ -2110,14 +2110,14 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches d_smem[row] = 1.0 for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): + with T.sblock("O_init"): i, j = T.axis.remap("SS", [li, lj]) O_local[i, j] = 0.0 T.tvm_storage_sync("shared") # Load Q from gmem to smem for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): + with T.sblock("Q_load"): i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() @@ -2133,7 +2133,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches L_kv_start: T.int32 = iterator * tile_z L_kv_base: T.int32 = 0 for lz, ly in T.grid(tile_z, tile_y): - with T.block("K_load"): + with T.sblock("K_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -2146,7 +2146,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches K_smem[i, j] = 0.0 T.tvm_storage_sync("shared") for lz, ly in T.grid(tile_z, tile_y): - with T.block("V_load"): + with T.sblock("V_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -2160,9 +2160,9 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches T.tvm_storage_sync("shared") # Compute S - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): + with T.sblock("S_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 @@ -2174,7 +2174,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches ) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): + with T.sblock("S_store"): i, j = T.axis.remap("SS", [li, lj]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") @@ -2183,7 +2183,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update1"): + with T.sblock("update1"): m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S @@ -2205,7 +2205,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: @@ -2228,7 +2228,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): d_new[i] += S_smem[row, j] m_smem[row] = m_new[i] @@ -2237,9 +2237,9 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches T.tvm_storage_sync("shared") # Update O - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): + with T.sblock("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2( @@ -2251,7 +2251,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches # Store O from smem to gmem for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): + with T.sblock("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = 0 + (LH_start + i) // group_size cur_H_qo: T.int32 = ( @@ -2264,7 +2264,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches # Store LSE to gmem for li in T.grid(tile_x): - with T.block("lse_store"): + with T.sblock("lse_store"): i = T.axis.remap("S", [li]) cur_L: T.int32 = 0 + (LH_start + i) // group_size cur_H_qo: T.int32 = ( @@ -2333,7 +2333,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable for b in T.serial(batch_size): - with T.block("attn"): + with T.sblock("attn"): softmax_sum = T.alloc_buffer([h_q], "float32") m_prev = T.alloc_buffer([h_q], "float32") m_new = T.alloc_buffer([h_q], "float32") @@ -2471,7 +2471,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): + with T.sblock("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() @@ -2529,14 +2529,14 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches d_smem[row] = 1.0 for li, lj in T.grid(tile_x, d_v): - with T.block("O_init"): + with T.sblock("O_init"): i, j = T.axis.remap("SS", [li, lj]) O_local[i, j] = 0.0 T.tvm_storage_sync("shared") # Load Q from gmem to smem for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): + with T.sblock("Q_load"): i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() @@ -2556,7 +2556,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches L_kv_start: T.int32 = iterator * tile_z L_kv_base: T.int32 = kv_indptr[b_idx] for lz, ly in T.grid(tile_z, tile_y): - with T.block("K_load"): + with T.sblock("K_load"): i, j = T.axis.remap("SS", [lz, ly]) cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: @@ -2569,7 +2569,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches K_smem[i, j] = 0.0 T.tvm_storage_sync("shared") for lz, ly in T.grid(tile_z, d_v): - with T.block("V_load"): + with T.sblock("V_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -2581,16 +2581,16 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches T.tvm_storage_sync("shared") # Compute S - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): + with T.sblock("S_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): + with T.sblock("S_store"): i, j = T.axis.remap("SS", [li, lj]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") @@ -2599,7 +2599,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update1"): + with T.sblock("update1"): m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S @@ -2615,7 +2615,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: @@ -2632,7 +2632,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): d_new[i] += S_smem[row, j] m_smem[row] = m_new[i] @@ -2641,9 +2641,9 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches T.tvm_storage_sync("shared") # Update O - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, d_v, tile_z): - with T.block("O_gemm"): + with T.sblock("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) @@ -2651,7 +2651,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # Store O from smem to gmem for li, lj in T.grid(tile_x, d_v): - with T.block("O_store"): + with T.sblock("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size @@ -2660,7 +2660,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # Store LSE to gmem for li in T.grid(tile_x): - with T.block("lse_store"): + with T.sblock("lse_store"): i = T.axis.remap("S", [li]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size @@ -2748,7 +2748,7 @@ def batch_prefill_paged_kv_mla( for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): + with T.sblock("attn"): bx, ty, tx = T.axis.remap("SSS", [lbx, lty, ltx]) T.reads() T.writes() @@ -2811,14 +2811,14 @@ def batch_prefill_paged_kv_mla( d_smem[row] = 1.0 for li, lj in T.grid(tile_x, d_latent): - with T.block("O_init"): + with T.sblock("O_init"): i, j = T.axis.remap("SS", [li, lj]) O_local[i, j] = 0.0 T.tvm_storage_sync("shared") # Load Q from gmem to smem for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): + with T.sblock("Q_load"): i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() @@ -2833,7 +2833,7 @@ def batch_prefill_paged_kv_mla( for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): L_kv_start: T.int32 = iterator * tile_z for lz, ly in T.grid(tile_z, tile_y): - with T.block("KV_load"): + with T.sblock("KV_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -2848,16 +2848,16 @@ def batch_prefill_paged_kv_mla( T.tvm_storage_sync("shared") # Compute S - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): + with T.sblock("S_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(KV_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): + with T.sblock("S_store"): i, j = T.axis.remap("SS", [li, lj]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") @@ -2866,7 +2866,7 @@ def batch_prefill_paged_kv_mla( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update1"): + with T.sblock("update1"): m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S @@ -2882,7 +2882,7 @@ def batch_prefill_paged_kv_mla( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: @@ -2899,7 +2899,7 @@ def batch_prefill_paged_kv_mla( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): d_new[i] += S_smem[row, j] m_smem[row] = m_new[i] @@ -2908,9 +2908,9 @@ def batch_prefill_paged_kv_mla( T.tvm_storage_sync("shared") # Update O - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, d_latent, tile_z): - with T.block("O_gemm"): + with T.sblock("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) @@ -2918,7 +2918,7 @@ def batch_prefill_paged_kv_mla( # Store O from smem to gmem for li, lj in T.grid(tile_x, d_latent): - with T.block("O_store"): + with T.sblock("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = (LH_start + i) % group_size @@ -2927,7 +2927,7 @@ def batch_prefill_paged_kv_mla( # Store LSE to gmem for li in T.grid(tile_x): - with T.block("lse_store"): + with T.sblock("lse_store"): i = T.axis.remap("S", [li]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = (LH_start + i) % group_size @@ -2969,7 +2969,7 @@ def copy_single_page( (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" ): for t in T.thread_binding(tx, thread="threadIdx.x"): - with T.block("copy"): + with T.sblock("copy"): T.where(b * tx + t < copy_length * num_heads * head_dim) vh = T.axis.spatial( num_heads, @@ -3011,7 +3011,7 @@ def copy_single_page_mla( for b in T.thread_binding((copy_length * head_dim + tx - 1) // tx, thread="blockIdx.x"): for t in T.thread_binding(tx, thread="threadIdx.x"): - with T.block("copy"): + with T.sblock("copy"): T.where(b * tx + t < copy_length * head_dim) vp = T.axis.spatial(copy_length, (b * tx + t) // head_dim) vd = T.axis.spatial(head_dim, T.Cast("int32", (b * tx + t) % head_dim)) @@ -3036,7 +3036,7 @@ def copy_single_page_cpu( for b in T.serial((copy_length * num_heads * head_dim + tx - 1) // tx): for t in T.serial(tx): - with T.block("copy"): + with T.sblock("copy"): T.where(b * tx + t < copy_length * num_heads * head_dim) vh = T.axis.spatial( num_heads, @@ -3094,7 +3094,7 @@ def compact_kv_copy( elem_offset=copy_src_dst_pos_elem_offset, ) - with T.block("root"): + with T.sblock("root"): for bhd_o in T.thread_binding( (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" ): @@ -3145,7 +3145,7 @@ def compact_kv_copy_cpu( elem_offset=copy_src_dst_pos_elem_offset, ) - with T.block("root"): + with T.sblock("root"): for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx): for bhd_i in T.serial(tx): b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 60808a6b35fd..b90b4bfecd12 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -388,7 +388,7 @@ def fused_rope( # pylint: disable=too-many-locals k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype) v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype) for iters in T.grid(batch_size, seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): + with T.sblock("llama_fused_rope"): b, s, h, d = T.axis.remap("SSSS", iters) if h < num_q_heads: q[b, s, h, d] = T.if_then_else( @@ -524,7 +524,7 @@ def fused_rope( # pylint: disable=too-many-locals var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset ) for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): + with T.sblock("llama_fused_rope"): s, h, d = T.axis.remap("SSS", iters) if h < num_q_heads: q[s, h, d] = T.if_then_else( @@ -573,7 +573,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals if seq_len > original_max_position_embeddings: for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): + with T.sblock("llama_fused_rope"): s, h, d = T.axis.remap("SSS", iters) if h < num_q_heads: q[s, h, d] = T.if_then_else( @@ -605,7 +605,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] else: for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): + with T.sblock("llama_fused_rope"): s, h, d = T.axis.remap("SSS", iters) if h < num_q_heads: q[s, h, d] = T.if_then_else( @@ -744,7 +744,7 @@ def fused_rope( # pylint: disable=too-many-locals var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset ) for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): + with T.sblock("llama_fused_rope"): s, h, d = T.axis.remap("SSS", iters) if h < num_q_heads: q[s, h, d] = T.if_then_else( @@ -786,7 +786,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset ) for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): + with T.sblock("llama_fused_rope"): s, h, d = T.axis.remap("SSS", iters) if h < num_q_heads: q[s, h, d] = T.if_then_else( diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 635b7a5d505a..385c0a51abf3 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -161,7 +161,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable for b in T.serial(batch_size_plus_1 - 1): - with T.block("attn"): + with T.sblock("attn"): softmax_sum = T.alloc_buffer([h_q], "float32") m_prev = T.alloc_buffer([h_q], "float32") @@ -382,7 +382,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): + with T.sblock("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() @@ -440,14 +440,14 @@ def batch_tree_attn( # pylint: disable=too-many-branches d_smem[row] = 1.0 for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): + with T.sblock("O_init"): i, j = T.axis.remap("SS", [li, lj]) O_local[i, j] = 0.0 T.tvm_storage_sync("shared") # Load Q from gmem to smem for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): + with T.sblock("Q_load"): i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() @@ -467,7 +467,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches L_kv_start: T.int32 = iterator * tile_z L_kv_base: T.int32 = kv_indptr[b_idx] for lz, ly in T.grid(tile_z, tile_y): - with T.block("KV_load"): + with T.sblock("KV_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -485,16 +485,16 @@ def batch_tree_attn( # pylint: disable=too-many-branches T.tvm_storage_sync("shared") # Compute S - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): + with T.sblock("S_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): + with T.sblock("S_store"): i, j = T.axis.remap("SS", [li, lj]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") @@ -503,7 +503,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update1"): + with T.sblock("update1"): m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S @@ -522,7 +522,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: @@ -542,7 +542,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): d_new[i] += S_smem[row, j] m_smem[row] = m_new[i] @@ -551,9 +551,9 @@ def batch_tree_attn( # pylint: disable=too-many-branches T.tvm_storage_sync("shared") # Update O - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): + with T.sblock("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) @@ -561,7 +561,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches # Store O from smem to gmem for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): + with T.sblock("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size @@ -570,7 +570,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches # Store LSE to gmem for li in T.grid(tile_x): - with T.block("lse_store"): + with T.sblock("lse_store"): i = T.axis.remap("S", [li]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size @@ -640,15 +640,15 @@ def apply_to_md(sch, block): tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("KV_load")) - - apply_to_md(sch, sch.get_block("lse_store")) + apply_to_gemm(sch, sch.get_sblock("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_sblock("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_sblock("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_sblock("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_sblock("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_sblock("Q_load")) + apply_to_qkv_load(sch, sch.get_sblock("KV_load")) + + apply_to_md(sch, sch.get_sblock("lse_store")) return sch.mod["main"].with_attr("tir.is_scheduled", True) @@ -753,7 +753,7 @@ def tree_attn_paged_kv_cpu( for h_qo in T.serial(h_q): for b_idx in T.serial(batch_size): - with T.block("attn"): + with T.sblock("attn"): T.reads() T.writes() O_local = T.alloc_buffer((d, ), "float32") @@ -998,7 +998,7 @@ def tree_attn_paged_kv( for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): + with T.sblock("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() @@ -1076,14 +1076,14 @@ def tree_attn_paged_kv( d_smem[row] = 1.0 for li, lj in T.grid(tile_x, tile_y): - with T.block("O_init"): + with T.sblock("O_init"): i, j = T.axis.remap("SS", [li, lj]) O_local[i, j] = 0.0 T.tvm_storage_sync("shared") # Load Q from gmem to smem for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): + with T.sblock("Q_load"): i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() @@ -1111,7 +1111,7 @@ def tree_attn_paged_kv( for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): L_kv_start: T.int32 = iterator * tile_z for lz, ly in T.grid(tile_z, tile_y): - with T.block("K_load"): + with T.sblock("K_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -1128,7 +1128,7 @@ def tree_attn_paged_kv( T.tvm_storage_sync("shared") for lz, ly in T.grid(tile_z, tile_y): - with T.block("V_load"): + with T.sblock("V_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() T.writes() @@ -1145,9 +1145,9 @@ def tree_attn_paged_kv( T.tvm_storage_sync("shared") # Compute S - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): + with T.sblock("S_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 @@ -1159,7 +1159,7 @@ def tree_attn_paged_kv( ) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): + with T.sblock("S_store"): i, j = T.axis.remap("SS", [li, lj]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") @@ -1168,7 +1168,7 @@ def tree_attn_paged_kv( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update1"): + with T.sblock("update1"): m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S @@ -1193,7 +1193,7 @@ def tree_attn_paged_kv( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: @@ -1219,7 +1219,7 @@ def tree_attn_paged_kv( for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: - with T.block("update"): + with T.sblock("update"): for j in T.serial(tile_z): d_new[i] += S_smem[row, j] m_smem[row] = m_new[i] @@ -1228,9 +1228,9 @@ def tree_attn_paged_kv( T.tvm_storage_sync("shared") # Update O - with T.block(): + with T.sblock(): for li, lj, lk in T.grid(tile_x, tile_y, tile_z): - with T.block("O_gemm"): + with T.sblock("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2( @@ -1242,7 +1242,7 @@ def tree_attn_paged_kv( # Store O from smem to gmem for li, lj in T.grid(tile_x, tile_y): - with T.block("O_store"): + with T.sblock("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = ( q_indptr[b_idx] + (LH_start + i) // group_size @@ -1257,7 +1257,7 @@ def tree_attn_paged_kv( # Store LSE to gmem for li in T.grid(tile_x): - with T.block("lse_store"): + with T.sblock("lse_store"): i = T.axis.remap("S", [li]) cur_L: T.int32 = ( q_indptr[b_idx] + (LH_start + i) // group_size @@ -1332,13 +1332,13 @@ def apply_to_md(sch, block): tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) - apply_to_md(sch, sch.get_block("lse_store")) + apply_to_gemm(sch, sch.get_sblock("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_sblock("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_sblock("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_sblock("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_sblock("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_sblock("Q_load")) + apply_to_qkv_load(sch, sch.get_sblock("K_load")) + apply_to_qkv_load(sch, sch.get_sblock("V_load")) + apply_to_md(sch, sch.get_sblock("lse_store")) return sch.mod["main"].with_attr("tir.is_scheduled", True) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 50d4772d8ca1..7cff79660fa1 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2802,7 +2802,7 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): top_k = T.match_buffer(C, (batch, 1), index_dtype) renorm_prob = T.match_buffer(D, (batch, 1), prob_dtype) for ax0, ax1 in T.grid(batch, vocab_size): - with T.block("T_get_renorm_prob"): + with T.sblock("T_get_renorm_prob"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) if not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0): renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0] @@ -2826,7 +2826,7 @@ def _get_index_from_sorted( output_index = T.match_buffer(F, (out_batch, 1), index_dtype) for ax0, ax1 in T.grid(out_batch, vocab_size): - with T.block("T_get_index_from_sorted"): + with T.sblock("T_get_index_from_sorted"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.writes(output_index[v_ax0, 0]) if ( @@ -2909,7 +2909,7 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T. top_k = T.match_buffer(D, (batch, 1), top_k_dtype) cutoff = T.match_buffer(E, (batch, 1), prob_dtype) for ax0, ax1 in T.grid(batch, vocab_size): - with T.block("T_get_renorm_cutoff"): + with T.sblock("T_get_renorm_cutoff"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0: cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0] diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index bfd7dbf87d70..a352c4555c82 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1151,7 +1151,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(2, 3): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) @@ -1165,7 +1165,7 @@ def multiply( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(2, 3): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index a6bb68e2507c..c76acf71fca7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -55,7 +55,7 @@ class IRBuilderFrame(_Object): with IRBuilder() as builder: with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame) # to `builder`'s stack of frames - with T.block(...): # pushes a BlockFrame (subclass of IRBuilderFrame) + with T.sblock(...): # pushes a BlockFrame (subclass of IRBuilderFrame) # to `builder`'s stack of frames buffer = T.match_buffer(...) """ diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py index ed4d948ff972..db16683a8a3b 100644 --- a/python/tvm/script/ir_builder/relax/frame.py +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -35,8 +35,8 @@ class FunctionFrame(SeqExprFrame): """The ir_builder frame for the relax function.""" -@_register_object("script.ir_builder.relax.BlockFrame") -class BlockFrame(RelaxFrame): +@_register_object("script.ir_builder.relax.BindingBlockFrame") +class BindingBlockFrame(RelaxFrame): """The ir_builder frame for relax binding blocks.""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 5410c3c03a43..9a961ec3862f 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -377,11 +377,11 @@ def replacement(A: R.Tensor): ############################# BindingBlock ############################## -def dataflow() -> frame.BlockFrame: +def dataflow() -> frame.BindingBlockFrame: """Start a dataflow binding block frame. Returns ------- - frame: frame.BlockFrame + frame: frame.BindingBlockFrame The created ir_builder Block frame. """ return _ffi_api.Dataflow() # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index f43b4cf6ed67..ddecd005c85d 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -33,12 +33,12 @@ class PrimFuncFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.BlockFrame") -class BlockFrame(TIRFrame): +@_register_object("script.ir_builder.tir.SSBlockFrame") +class SBlockFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.BlockInitFrame") +@_register_object("script.ir_builder.tir.SBlockInitFrame") class BlockInitFrame(TIRFrame): ... diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a08e66789fa3..bf8a08180137 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -92,7 +92,7 @@ _block_name_suffix = threading.local() -def _get_block_name_suffix() -> str: +def _get_sblock_name_suffix() -> str: """Get the current block name suffix for macro expansion.""" return getattr(_block_name_suffix, "value", "") @@ -367,23 +367,23 @@ def match_buffer( ) -def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: - """The block declaration statement. +def sblock(name: str = "", no_realize: bool = False) -> frame.SBlockFrame: + """The sblock declaration statement. Parameters ---------- name : str - The name of the block. + The name of the sblock. no_realize : bool - The flag whether to construct BlockRealize or Block. + The flag whether to construct SBlockRealize or SBlock. Returns ------- - res : frame.BlockFrame - The BlockFrame. + res : frame.SBlockFrame + The SBlockFrame. """ - block_suffix = _get_block_name_suffix() + block_suffix = _get_sblock_name_suffix() if block_suffix and name: name = name + block_suffix return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member @@ -458,7 +458,7 @@ def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: _ffi_api.Writes(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member -def block_attr(attrs: Dict[str, Any]) -> None: +def sblock_attr(attrs: Dict[str, Any]) -> None: """The block annotation statement. Parameters @@ -2168,13 +2168,13 @@ def wrapped(*args, **kwargs): "func_attr", "func_ret", "match_buffer", - "block", + "sblock", "block_name_suffix_context", "init", "where", "reads", "writes", - "block_attr", + "sblock_attr", "alloc_buffer", "axis", "serial", diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 011136d5d377..426a9d9346fb 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -24,7 +24,7 @@ from tvm.ir import GlobalVar, structural_equal from tvm.relax import Expr, StructInfo from tvm.relax.utils import convert_to_expr -from tvm.script.ir_builder.relax.frame import BlockFrame +from tvm.script.ir_builder.relax.frame import BindingBlockFrame from ...ir_builder import ir as I from ...ir_builder import relax as R @@ -360,7 +360,7 @@ def visit_with(self: Parser, node: doc.With) -> None: with self.var_table.with_frame(): with frame: self.visit(node.body) - if isinstance(frame, BlockFrame) and frame.is_dataflow: + if isinstance(frame, BindingBlockFrame) and frame.is_dataflow: output_vars = frame.output_vars for var in output_vars: self.var_table.add(var.name_hint, var, allow_shadowing=True) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 91d3e2b81cc9..7df3dfb3be84 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -372,7 +372,7 @@ def before_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -588,7 +588,7 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block(): + with T.sblock(): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index 6842f1e519f3..b87fd2a58b6e 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -42,7 +42,7 @@ def mma_schedule( ir_module = tvm.IRModule({"main": workload}) sch = tvm.tir.Schedule(ir_module) - block = sch.get_block("C") + block = sch.get_sblock("C") i, j, k = sch.get_loops(block) i, i_tc = sch.split(i, factors=[None, 16]) j, j_tc = sch.split(j, factors=[None, 16]) @@ -101,7 +101,7 @@ def fetch_to_shared(block, idx, ndim): sch.reorder(io, jo, ii, ji) sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) - block_init_c = sch.get_block("C_init") + block_init_c = sch.get_sblock("C_init") def tile_wmma_fragment(block_read, height, width): i, j = sch.get_loops(block_read)[-2:] @@ -158,7 +158,7 @@ def mfma_schedule( wmma_n = 16 wmma_k = k_inner warp_size = 64 - block = sch.get_block("C") + block = sch.get_sblock("C") i, j, k = sch.get_loops(block) i, i_tc = sch.split(i, factors=[None, wmma_m]) j, j_tc = sch.split(j, factors=[None, wmma_n]) @@ -212,7 +212,7 @@ def fetch_to_shared(block, idx, ndim): sch.reorder(io, jo, ii, ji) sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) - block_init_c = sch.get_block("C_init") + block_init_c = sch.get_sblock("C_init") def tile_wmma_fragment(block_read, height, width): i, j = sch.get_loops(block_read)[-2:] diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 74f2f6b2f757..7f01624d3e60 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -40,7 +40,7 @@ from .stmt import SeqStmt from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list -from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize +from .stmt import BufferRegion, MatchBufferRegion, SBlock, SBlockRealize from .function import PrimFunc, TensorIntrin, IndexMap @@ -115,8 +115,8 @@ from .op import ignore_loop_partition from .generic import add, subtract, multiply -from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError -from .block_dependence_info import BlockDependenceInfo +from .schedule import StmtSRef, SBlockScope, ScheduleState, Schedule, ScheduleError +from .block_dependence_info import SBlockDependenceInfo from . import schedule from . import ir_builder diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 8a84d3ee51fa..78b8d6b804cb 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -21,7 +21,7 @@ import tvm from tvm.ir import IRModule from tvm.tir.expr import Var -from tvm.tir.stmt import Block, BufferRegion, PrimExpr +from tvm.tir.stmt import SBlock, BufferRegion, PrimExpr from .. import Buffer, Stmt from ..function import PrimFunc @@ -116,15 +116,15 @@ def verify_gpu_code(func: PrimFunc, constraints: Dict[str, int]) -> None: return _ffi_api.verify_gpu_code(func, constraints) # type: ignore -def get_block_access_region( - block: Block, buffer_var_map: Dict[Var, Buffer] +def get_sblock_access_region( + block: SBlock, buffer_var_map: Dict[Var, Buffer] ) -> List[List[BufferRegion]]: """Detect which regions of tensors in this block are read or written to. Regions are sorted by order of appearance in the AST. Parameters ---------- - block: tvm.tir.Block + block: tvm.tir.SBlock The block in which we are detecting read/write regions. buffer_var_map : Dict[Var, Buffer] @@ -138,18 +138,18 @@ def get_block_access_region( - second: write regions - third: opaque regions """ - return _ffi_api.GetBlockAccessRegion(block, buffer_var_map) # type: ignore + return _ffi_api.GetSBlockAccessRegion(block, buffer_var_map) # type: ignore -def get_block_read_write_region( - block: Block, buffer_var_map: Dict[Var, Buffer] +def get_sblock_read_write_region( + block: SBlock, buffer_var_map: Dict[Var, Buffer] ) -> List[List[BufferRegion]]: """Auto detect the block read/write region according to its body stmt. An opaque access will be counted as both a read and a write access Parameters ---------- - block: tvm.tir.Block + block: tvm.tir.SBlock The block in which we are detecting read/write regions. buffer_var_map : Dict[Var, Buffer] @@ -160,7 +160,7 @@ def get_block_read_write_region( result : List[List[BufferRegion]] An array only consisting of the read regions and write regions of the input block """ - return _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map) # type: ignore + return _ffi_api.GetSBlockReadWriteRegion(block, buffer_var_map) # type: ignore def calculate_allocated_bytes( @@ -274,7 +274,7 @@ def OOBChecker(): return _ffi_api.OOBChecker() # type: ignore -def find_anchor_block(mod: IRModule) -> Block: +def find_anchor_sblock(mod: IRModule) -> SBlock: """Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an init statement and (2) having @@ -295,10 +295,10 @@ def find_anchor_block(mod: IRModule) -> Block: The input TIR module. Returns ------- - anchor_block: Block + anchor_block: SBlock The anchor block if found, None otherwise. """ - return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member + return _ffi_api.find_anchor_sblock(mod) # type: ignore # pylint: disable=no-member def has_if_then_else(stmt: Stmt) -> bool: diff --git a/python/tvm/tir/block_dependence_info.py b/python/tvm/tir/block_dependence_info.py index 7bd6b418fc72..8deba7e3a79f 100644 --- a/python/tvm/tir/block_dependence_info.py +++ b/python/tvm/tir/block_dependence_info.py @@ -14,28 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects +"""Define BlockDependenceInfoNode that uses the SBlockScope and StmtSRef objects to store the block level dependences""" from typing import Union, Optional from tvm_ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object -from tvm.tir import Block, PrimFunc +from tvm.tir import SBlock, PrimFunc -from .block_scope import BlockScope, StmtSRef +from .block_scope import SBlockScope, StmtSRef from . import _ffi_api -@register_object("tir.BlockDependenceInfo") -class BlockDependenceInfo(Object): +@register_object("tir.SBlockDependenceInfo") +class SBlockDependenceInfo(Object): """ - BlockDependenceInfo + SBlockDependenceInfo An object that helps build and query block level dependences using the 2 core objects - BlockScope and StmtSRef + SBlockScope and StmtSRef The data structures exposed are: - 1) sref2scope: Mapping from the srefs to its corresponding BlockScope + 1) sref2scope: Mapping from the srefs to its corresponding SBlockScope 2) stmt2ref: Mapping from blocks to corresponding StmtSRefs Note that this object does not store SRefs to loops as the purpose is only to expose block level @@ -51,11 +51,11 @@ def __init__(self, mod: Union[IRModule, PrimFunc]): if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") self.__init_handle_by_constructor__( - _ffi_api.BlockDependenceInfo, # type: ignore # pylint: disable=no-member + _ffi_api.SBlockDependenceInfo, # type: ignore # pylint: disable=no-member mod, ) - def get_sref(self, block: Block) -> Optional[StmtSRef]: + def get_sref(self, block: SBlock) -> Optional[StmtSRef]: """Return the corresponding sref that points to the block Parameters @@ -68,10 +68,10 @@ def get_sref(self, block: Block) -> Optional[StmtSRef]: sref : StmtSRef The corresponding sref """ - return _ffi_api.BlockDependenceInfoGetSRef(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.SBlockDependenceInfoGetSRef(self, block) # type: ignore # pylint: disable=no-member - def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: - """Get the BlockScope correpsonding to the block sref + def get_sblock_scope(self, block_sref: StmtSRef) -> SBlockScope: + """Get the SBlockScope correpsonding to the block sref Parameters ---------- @@ -81,8 +81,8 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: Returns ------- scope : StmtSRef - The corresponding BlockScope + The corresponding SBlockScope """ - return _ffi_api.BlockDependenceInfoGetBlockScope( # type: ignore # pylint: disable=no-member + return _ffi_api.SBlockDependenceInfoGetSBlockScope( # type: ignore # pylint: disable=no-member self, block_sref ) diff --git a/python/tvm/tir/block_scope.py b/python/tvm/tir/block_scope.py index d63771fae93e..d8bc9b16e9c8 100644 --- a/python/tvm/tir/block_scope.py +++ b/python/tvm/tir/block_scope.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.""" +"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, SBlockScope.""" from enum import IntEnum from typing import List, Optional, Union from tvm_ffi import register_object from tvm.runtime import Object -from tvm.tir import Block, For +from tvm.tir import SBlock, For from . import _ffi_api @@ -30,7 +30,7 @@ class StmtSRef(Object): """An object that refers to schedulable elements in the TensorIR, aka "sref". Glossary - - Block sref: An StmtSref that points to a TensorIR block. + - SBlock sref: An StmtSref that points to a TensorIR block. - Loop sref: An StmtSRef that points to a TensorIR for loop. - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest schedulable statement of its ancestors on the TensorIR AST. @@ -43,7 +43,7 @@ class StmtSRef(Object): seq_index: int @property - def stmt(self) -> Optional[Union[Block, For]]: + def stmt(self) -> Optional[Union[SBlock, For]]: """The block/for stmt the object refers to""" return _ffi_api.StmtSRefStmt(self) # type: ignore # pylint: disable=no-member @@ -107,21 +107,21 @@ class Dependency(Object): kind: DepKind -@register_object("tir.BlockScope") -class BlockScope(Object): +@register_object("tir.SBlockScope") +class SBlockScope(Object): """An object corresponds to each block sref in the sref tree, which tracks the producer-consumer dependency between blocks. Glossary: - - Block scope: A contiguous subtree of the sref tree, rooted at - each block sref, whose components are: + - SBlock scope: A contiguous subtree of the sref tree, rooted at + each SBlock sref, whose components are: - - scope root: a block sref + - scope root: a SBlock sref - internal srefs: loop srefs - - scope leaves: block srefs + - scope leaves: SBlock srefs - - Child block: The scope leaf blocks under the scope root or a specific internal sref + - Child SBlock: The scope leaf SBlocks under the scope root or a specific internal sref """ def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: @@ -137,7 +137,7 @@ def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: blocks: List[Dependency] The dependencies """ - return _ffi_api.BlockScopeGetDepsBySrc(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.SBlockScopeGetDepsBySrc(self, block) # type: ignore # pylint: disable=no-member def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: """Get all dependencies whose `dst` is the target `block`. @@ -152,4 +152,4 @@ def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: blocks: List[Dependency] The dependencies """ - return _ffi_api.BlockScopeGetDepsByDst(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.SBlockScopeGetDepsByDst(self, block) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 5b365e124cfc..779b9c374ea6 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -140,7 +140,7 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: B = T.match_buffer(b, (m, n), "float32") for i, j in T.grid(m, n): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -163,7 +163,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (16, 16), "float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index c2594835fedf..1c403ce92759 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -63,8 +63,8 @@ AllocateConst, AssertStmt, AttrStmt, - Block, - BlockRealize, + SBlock, + SBlockRealize, BufferRealize, BufferStore, DeclBuffer, @@ -173,7 +173,7 @@ def __init__( f_visit_seq_stmt: Callable = None, f_visit_evaluate: Callable = None, f_visit_block: Callable = None, - f_visit_block_realize: Callable = None, + f_visit_sblock_realize: Callable = None, # PrimExpr f_visit_var: Callable = None, f_visit_size_var: Callable = None, @@ -229,7 +229,7 @@ def __init__( f_visit_seq_stmt, f_visit_evaluate, f_visit_block, - f_visit_block_realize, + f_visit_sblock_realize, # PrimExpr f_visit_var, f_visit_size_var, @@ -293,8 +293,8 @@ class PyStmtExprVisitor: "visit_assert_stmt_", "visit_seq_stmt_", "visit_evaluate_", - "visit_block_", - "visit_block_realize_", + "visit_sblock_", + "visit_sblock_realize_", # PrimExpr "visit_var_", "visit_size_var_", @@ -521,30 +521,30 @@ def visit_evaluate_(self, op: Evaluate) -> None: print("visit_evaluate_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_block_(self, op: Block) -> None: - """Visit Block. - Users can customize this function to overwrite VisitStmt_(const BlockNode* op) + def visit_sblock_(self, op: SBlock) -> None: + """Visit SBlock. + Users can customize this function to overwrite VisitStmt_(const SBlockNode* op) on the C++ side. Parameters ---------- - op : Block - The Block to be visited. + op : SBlock + The SBlock to be visited. """ - print("visit_block_", op) + print("visit_sblock_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_block_realize_(self, op: BlockRealize) -> None: + def visit_sblock_realize_(self, op: SBlockRealize) -> None: """Visit BlockRealize. - Users can customize this function to overwrite VisitStmt_(const BlockRealizeNode* op) + Users can customize this function to overwrite VisitStmt_(const SBlockRealizeNode* op) on the C++ side. Parameters ---------- - op : BlockRealize + op : SBlockRealize The BlockRealize to be visited. """ - print("visit_block_realize_", op) + print("visit_sblock_realize_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_var_(self, op: Var) -> None: @@ -1007,7 +1007,7 @@ def __init__( f_visit_seq_stmt: Callable = None, f_visit_evaluate: Callable = None, f_visit_block: Callable = None, - f_visit_block_realize: Callable = None, + f_visit_sblock_realize: Callable = None, # PrimExpr f_visit_var: Callable = None, f_visit_size_var: Callable = None, @@ -1063,7 +1063,7 @@ def __init__( f_visit_seq_stmt, f_visit_evaluate, f_visit_block, - f_visit_block_realize, + f_visit_sblock_realize, # PrimExpr f_visit_var, f_visit_size_var, @@ -1127,8 +1127,8 @@ class PyStmtExprMutator: "visit_assert_stmt_", "visit_seq_stmt_", "visit_evaluate_", - "visit_block_", - "visit_block_realize_", + "visit_sblock_", + "visit_sblock_realize_", # PrimExpr "visit_var_", "visit_size_var_", @@ -1421,15 +1421,15 @@ def visit_evaluate_(self, op: Evaluate) -> Stmt: """ return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_block_(self, op: Block) -> Stmt: - """Visit Block. - Users can customize this function to overwrite VisitStmt_(const BlockNode* op) + def visit_sblock_(self, op: SBlock) -> Stmt: + """Visit SBlock. + Users can customize this function to overwrite VisitStmt_(const SBlockNode* op) on the C++ side. Parameters ---------- - op : Block - The Block to be visited. + op : SBlock + The SBlock to be visited. Returns ------- @@ -1438,15 +1438,15 @@ def visit_block_(self, op: Block) -> Stmt: """ return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_block_realize_(self, op: BlockRealize) -> Stmt: + def visit_sblock_realize_(self, op: SBlockRealize) -> Stmt: """Visit BlockRealize. - Users can customize this function to overwrite VisitStmt_(const BlockRealizeNode* op) + Users can customize this function to overwrite VisitStmt_(const SBlockRealizeNode* op) on the C++ side. Parameters ---------- - op : BlockRealize - The BlockRealize to be visited. + op : SBlockRealize + The SBlockRealize to be visited. Returns ------- diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 1f68c487c063..170d6dd9abc2 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -17,9 +17,9 @@ # pylint: disable=unused-import """Namespace for the TensorIR schedule API.""" -from ..block_scope import BlockScope, Dependency, DepKind, StmtSRef +from ..block_scope import SBlockScope, Dependency, DepKind, StmtSRef from .instruction import Instruction, InstructionKind -from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError +from .schedule import SBlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 66eab497eb5a..d8fcc7213e6d 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -26,7 +26,7 @@ from ..function import IndexMap, PrimFunc from . import _ffi_api -from .schedule import Schedule, BlockRV +from .schedule import Schedule, SBlockRV def suggest_index_map( @@ -68,7 +68,7 @@ class TensorizeInfo(Object): def get_tensorize_loop_mapping( - sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool = False + sch: Schedule, block: SBlockRV, desc_func: PrimFunc, allow_padding: bool = False ) -> Optional[TensorizeInfo]: """Establish a mapping between loops in a target block and an intrinsic description @@ -76,7 +76,7 @@ def get_tensorize_loop_mapping( ---------- sch : Schedule The schedule to be tensorized - block : BlockRV + block : SBlockRV The target block to match against desc_func : PrimFunc The prim func describing the computation to be tensorized @@ -96,7 +96,7 @@ class AutoTensorizeMappingInfo(Object): def get_auto_tensorize_mapping_info( - sch: Schedule, block: BlockRV, desc_func: PrimFunc + sch: Schedule, block: SBlockRV, desc_func: PrimFunc ) -> Optional[AutoTensorizeMappingInfo]: """Get mapping info between a target block and an intrinsic description including layout transformations to apply. @@ -105,7 +105,7 @@ def get_auto_tensorize_mapping_info( ---------- sch : Schedule The schedule to be tensorized - block : BlockRV + block : SBlockRV The compute block for auto tensorization desc_func : PrimFunc The prim func describing the computation to be tensorized @@ -142,14 +142,14 @@ def has_block(sch: Schedule, block_name: str) -> bool: return _ffi_api.HasBlock(sch, block_name) # type: ignore -def is_output_block(sch: Schedule, block: BlockRV) -> bool: +def is_output_block(sch: Schedule, block: SBlockRV) -> bool: """Check whether the given block is an output block Parameters ---------- sch : Schedule The schedule object of the block - block : BlockRV + block : SBlockRV The blockRV to be checked Returns diff --git a/python/tvm/tir/schedule/instruction.py b/python/tvm/tir/schedule/instruction.py index 918292a7bbaa..26428d320dde 100644 --- a/python/tvm/tir/schedule/instruction.py +++ b/python/tvm/tir/schedule/instruction.py @@ -60,7 +60,7 @@ class InstructionKind(Object): @property def is_pure(self) -> bool: """Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule - state. For example, the instruction `GetBlock` is pure because it changes + state. For example, the instruction `GetSBlock` is pure because it changes nothing, while `ComputeInline` is not because removing it leads to a different resulting schedule. @@ -99,7 +99,7 @@ class Instruction(Object): inputs : List[INPUT_RV_TYPE] The input random variables of the instruction, and the type of each element can be one of the following: - - BlockRV + - SBlockRV - LoopRV - ExprRV - float @@ -109,11 +109,11 @@ class Instruction(Object): attrs : List[ATTR_TYPE] The attributes of the instruction. Similar to attributes of an operator, attributes of an instruction are arbitrary constant metadata required by the instructions. - For example, the name of the block to be retrieved in `GetBlock`. + For example, the name of the block to be retrieved in `GetSBlock`. outputs : List[OUTPUT_RV_TYPE] The output random variables of the instruction, and the type of each element can be one of the following: - - BlockRV + - SBlockRV - LoopRV - ExprRV, atomic variables only, won't be constants or composite PrimExpr """ @@ -139,7 +139,7 @@ def __init__( inputs : List[INPUT_RV_TYPE] The input random variables of the instruction, and the type of each element can be one of the following: - - BlockRV + - SBlockRV - LoopRV - ExprRV - float @@ -149,11 +149,11 @@ def __init__( attrs : List[ATTR_TYPE] The attributes of the instruction. Similar to attributes of an operator, attributes of an instruction are arbitrary constant metadata required by the - instructions. For example, the name of the block to be retrieved in `GetBlock`. + instructions. For example, the name of the block to be retrieved in `GetSBlock`. outputs : List[OUTPUT_RV_TYPE] The output random variables of the instruction, and the type of each element can be one of the following: - - BlockRV + - SBlockRV - LoopRV - ExprRV, atomic variables only, won't be constants or composite PrimExpr """ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b1e1a3f5d532..9226bbe30b6a 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -22,7 +22,7 @@ from tvm.error import TVMError, register_error from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object -from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc +from tvm.tir import SBlock, Buffer, FloatImm, For, IntImm, PrimFunc from ..function import IndexMap from . import _ffi_api @@ -47,14 +47,14 @@ def __init__(self) -> None: ) -@_register_object("tir.BlockRV") -class BlockRV(Object): +@_register_object("tir.SBlockRV") +class SBlockRV(Object): """A random variable that refers to a block""" def __init__(self) -> None: - """Construct a new BlockRV.""" + """Construct a new SBlockRV.""" self.__init_handle_by_constructor__( - _ffi_api.BlockRV # type: ignore # pylint: disable=no-member + _ffi_api.SBlockRV # type: ignore # pylint: disable=no-member ) @@ -64,7 +64,7 @@ def __init__(self) -> None: # A random variable that evaluates to an integer ExprRV = Union[PrimExpr] # pylint: disable=invalid-name -RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name +RAND_VAR_TYPE = Union[ExprRV, SBlockRV, LoopRV] # pylint: disable=invalid-name _ERROR_RENDER_LEVEL: Dict[Literal["detail", "fast", "none"], int] = { "detail": 0, @@ -98,7 +98,7 @@ def _parse_seed(seed: Optional[int]) -> int: return seed -def _get_block_default_dtype(block: Block) -> str: +def _get_sblock_default_dtype(block: SBlock) -> str: for i in block.iter_vars: return i.var.dtype for buffer_region in list(block.reads) + list(block.writes): @@ -228,7 +228,7 @@ def work_on(self, func_name: str) -> None: of their names are "main", users will have to call this method to explicitly specify which function to work on. - This sugar function will guide the `GetBlock` method if its `func_name` is not specified. + This sugar function will guide the `GetSBlock` method if its `func_name` is not specified. Parameters ---------- @@ -297,22 +297,22 @@ def show(self, *args, **kwargs) -> None: @type_checked def get( self, rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef] - ) -> Optional[Union[int, Block, For]]: + ) -> Optional[Union[int, SBlock, For]]: """Returns: - - the corresponding Block that a BlockRV evaluates to; + - the corresponding SBlock that a SBlockRV evaluates to; - the corresponding For that a LoopRV evaluates to; - the corresponding integer that a ExprRV evaluates to; - - the corresponding Block that a block sref points to; + - the corresponding SBlock that a SBlock sref points to; - the corresponding For that a loop sref points to; Parameters ---------- - rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef] + rand_var_or_sref : Union[ExprRV, SBlockRV, LoopRV, StmtSRef] The random variable / sref to be evaluated Returns ------- - result : Optional[Union[int, Block, For]] + result : Optional[Union[int, SBlock, For]] The corresponding result """ if isinstance(rand_var_or_sref, StmtSRef): @@ -324,16 +324,18 @@ def get( return result @type_checked - def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]: + def get_sref( + self, rand_var_or_stmt: Union[SBlockRV, LoopRV, SBlock, For] + ) -> Optional[StmtSRef]: """Returns the corresponding sref to the given 1) LoopRV - 2) BlockRV + 2) SBlockRV 3) Block 4) For Parameters ---------- - rand_var_or_stmt : Union[BlockRV, LoopRV, Block, For] + rand_var_or_stmt : Union[SBlockRV, LoopRV, SBlock, For] The random variable / sref to be evaluated Returns @@ -351,7 +353,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: Parameters ---------- - rand_var : Union[BlockRV, LoopRV, ExprRV] + rand_var : Union[SBlockRV, LoopRV, ExprRV] The random variable to be removed """ return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member @@ -456,13 +458,13 @@ def sample_partitioned_tile( @type_checked def sample_compute_location( - self, block: Union[BlockRV, str], decision: Optional[int] = None + self, block: Union[SBlockRV, str], decision: Optional[int] = None ) -> LoopRV: """Sample a compute-at location of the given block Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block whose compute-at location is to be sampled decision : Optional[int] The sampling decision @@ -480,7 +482,7 @@ def sample_compute_location( ########## Schedule: Get blocks & loops ########## @type_checked - def get_block(self, name: str, func_name: Optional[str] = None) -> BlockRV: + def get_sblock(self, name: str, func_name: Optional[str] = None) -> SBlockRV: """Retrieve a block in a specific function with its name By default, if `func_name` is not specified, the schedule will search for the block in the @@ -496,21 +498,21 @@ def get_block(self, name: str, func_name: Optional[str] = None) -> BlockRV: Returns ------- - block : BlockRV + block : SBlockRV The block retrieved IndexError is raised if 0 or multiple blocks exist with the specific name. """ - return _ffi_api.ScheduleGetBlock( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetSBlock( # type: ignore # pylint: disable=no-member self, name, func_name ) @type_checked - def get_loops(self, block: Union[BlockRV, str]) -> List[LoopRV]: + def get_loops(self, block: Union[SBlockRV, str]) -> List[LoopRV]: """Get the parent loops of the block in its scope, from outer to inner Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The query block Returns @@ -523,12 +525,12 @@ def get_loops(self, block: Union[BlockRV, str]) -> List[LoopRV]: return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore @type_checked - def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]: + def get_child_blocks(self, block_or_loop: Union[SBlockRV, LoopRV]) -> List[SBlockRV]: """Get the leaf blocks of a specific block/loop Parameters ---------- - block_or_loop : Union[BlockRV, LoopRV] + block_or_loop : Union[SBlockRV, LoopRV] The query block/loop Returns @@ -540,17 +542,17 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore @type_checked - def get_producers(self, block: Union[BlockRV, str]) -> List[BlockRV]: + def get_producers(self, block: Union[SBlockRV, str]) -> List[SBlockRV]: """Get the producers of a specific block Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block in the query Returns ------- - producers : List[BlockRV] + producers : List[SBlockRV] A list of producers of the given block """ block = self._normalize_block_arg(block) @@ -558,17 +560,17 @@ def get_producers(self, block: Union[BlockRV, str]) -> List[BlockRV]: return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore @type_checked - def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: + def get_consumers(self, block: Union[SBlockRV, str]) -> List[SBlockRV]: """Get the consumers of a specific block Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block in the query Returns ------- - consumers : List[BlockRV] + consumers : List[SBlockRV] A list of consumers of the given block """ block = self._normalize_block_arg(block) @@ -576,19 +578,19 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore @type_checked - def get_output_blocks(self, scope_block: Union[BlockRV, str]) -> List[BlockRV]: + def get_output_blocks(self, scope_block: Union[SBlockRV, str]) -> List[SBlockRV]: """Get the list of output blocks within the given scope An output block is a block which has atleast one buffer being written to, but is not allocated within the PrimFunc Parameters ---------- - scope_block : Union[BlockRV, str], + scope_block : Union[SBlockRV, str], The scope block from which output blocks are collected Returns ------- - output_blocks : List[BlockRV] + output_blocks : List[SBlockRV] A list of all blocks that write to some output buffer """ @@ -628,11 +630,11 @@ def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 @@ -641,8 +643,8 @@ def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_fuse) - i1, _ = sch.get_loops(sch.get_block("B")) - i2, _ = sch.get_loops(sch.get_block("C")) + i1, _ = sch.get_loops(sch.get_sblock("B")) + i2, _ = sch.get_loops(sch.get_sblock("C")) sch.merge(i1, i2) print(sch.mod["main"].script()) @@ -658,13 +660,13 @@ def after_fuse(a: T.handle, b: T.handle, c: T.handle) -> None: # the 2 loops are merged into 1 for i_m in range(128): for j in range(128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i_m, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for j in range(128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i_m, j]) T.reads(A[vi, vj]) T.writes(C[vi, vj]) @@ -702,7 +704,7 @@ def before_fuse(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -711,7 +713,7 @@ def before_fuse(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_fuse) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.fuse(i, j) print(sch.mod["main"].script()) @@ -725,7 +727,7 @@ def after_fuse(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the 2 loops are fused into 1 for i_j_fused in T.serial(0, 16384): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, T.floordiv(i_j_fused, 128)) vj = T.axis.S(128, T.floormod(i_j_fused, 128)) B[vi, vj] = A[vi, vj] * 2.0 @@ -789,7 +791,7 @@ def before_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -798,7 +800,7 @@ def before_split(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_split) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.split(i, factors=[2, 64]) print(sch.mod["main"].script()) @@ -812,7 +814,7 @@ def after_split(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the original loop is split into 2 loops for i0, i1, j in T.grid(2, 64, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i0 * 64 + i1) vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 @@ -875,7 +877,7 @@ def before_partition(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -884,7 +886,7 @@ def before_partition(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_partition) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.partition(i, factors=[2, 64]) print(sch.mod["main"].script()) @@ -896,37 +898,37 @@ def after_partition(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) # the original loop is partition into 3 loops - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block("B_i_common"): + with T.sblock("B_i_common"): T.reads() T.writes() - with T.block("B_i0_partition"): + with T.sblock("B_i0_partition"): T.reads() T.writes() for i0, j in T.grid(2, 128): - with T.block("B_i0"): + with T.sblock("B_i0"): vi, vj = T.axis.remap("SS", [i0, j]) T.reads(A[0:2, 0:128]) T.writes(B[0:2, 0:128]) B[vi, vj] = A[vi, vj] * T.float32(2) - with T.block("B_i1_partition"): + with T.sblock("B_i1_partition"): T.reads() T.writes() for i1 in range(2, 66): for j in range(128): - with T.block("B_i1"): + with T.sblock("B_i1"): vi, vj = T.axis.remap("SS", [i1, j]) T.reads(A[2:66, 0:128]) T.writes(B[2:66, 0:128]) B[vi, vj] = A[vi, vj] * T.float32(2) - with T.block("B_partition_2"): + with T.sblock("B_partition_2"): T.reads() T.writes() for i2 in range(66, 128): for j in range(128): - with T.block("B_i2"): + with T.sblock("B_i2"): vi, vj = T.axis.remap("SS", [i2, j]) T.reads(A[66:128, 0:128]) T.writes(B[66:128, 0:128]) @@ -968,7 +970,7 @@ def before_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -977,7 +979,7 @@ def before_reorder(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_reorder) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.reorder(j, i) print(sch.mod["main"].script()) @@ -991,7 +993,7 @@ def after_reorder(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # Here j and i are reordered for j, i in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -999,12 +1001,12 @@ def after_reorder(a: T.handle, b: T.handle) -> None: _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member @type_checked - def reorder_block_iter_var(self, block: BlockRV, new_order: List[int]) -> None: + def reorder_block_iter_var(self, block: SBlockRV, new_order: List[int]) -> None: """Reorder the itervars inside a given block. Parameters ---------- - block : BlockRV + block : SBlockRV The block to be transformed. new_order : List[int] The new block itervar order. @@ -1023,7 +1025,7 @@ def matmul( C: T.Buffer((128, 128), "float32"), ) -> None: for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -1034,7 +1036,7 @@ def matmul( .. code-block:: python sch = tir.Schedule(matmul) - C = sch.get_block("C") + C = sch.get_sblock("C") sch.reorder_block_iter_var(C, [2, 1, 0]) After applying reorder_block_iter_var, the IR becomes: @@ -1048,7 +1050,7 @@ def matmul_after_reorder_block_iter_var( C: T.Buffer((128, 128), "float32"), ): for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vk, vj, vi = T.axis.remap("RSS", [k, j, i]) T.reads(A[vi, vk], B[vj, vk]) T.writes(C[vi, vj]) @@ -1064,12 +1066,12 @@ def matmul_after_reorder_block_iter_var( _ffi_api.ScheduleReorderBlockIterVar(self, block, new_order) # type: ignore @type_checked - def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV: + def add_unit_loop(self, block_or_loop: Union[LoopRV, SBlockRV]) -> LoopRV: """Create a new unit loop on top of the specific block or loop. Parameters ---------- - block_or_loop : Union[LoopRV, BlockRV] + block_or_loop : Union[LoopRV, SBlockRV] The block above which the new loop is created Returns @@ -1090,7 +1092,7 @@ def before_add_unit_loop( B: T.Buffer((), "int32"), C: T.Buffer((), "int32"), ) -> None: - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] @@ -1099,7 +1101,7 @@ def before_add_unit_loop( .. code-block:: python sch = tir.Schedule(before_add_unit_loop) - sch.add_unit_loop(sch.get_block("C")) + sch.add_unit_loop(sch.get_sblock("C")) print(sch.mod["main"].script()) After applying add-unit-loop, the IR becomes: @@ -1113,7 +1115,7 @@ def after_add_unit_loop( C: T.Buffer((), "int32"), ) -> None: for u in T.serial(1): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] """ @@ -1148,7 +1150,7 @@ def before_parallel(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1157,7 +1159,7 @@ def before_parallel(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_parallel) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.parallel(i) After applying parallel, the IR becomes: @@ -1170,7 +1172,7 @@ def after_parallel(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.parallel(0, 128): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1203,7 +1205,7 @@ def before_vectorize(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1212,7 +1214,7 @@ def before_vectorize(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_vectorize) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.vectorize(j) After applying vectorize, the IR becomes: @@ -1225,7 +1227,7 @@ def after_vectorize(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.serial(0, 128): for j in T.vectorized(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1266,7 +1268,7 @@ def before_bind(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1275,7 +1277,7 @@ def before_bind(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_bind) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.bind(i, "blockIdx.x") sch.bind(j, "threadIdx.x") @@ -1289,7 +1291,7 @@ def after_bind(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.thread_binding(0, 128, thread = "blockIdx.x"): for j in T.thread_binding(0, 128, thread = "threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1317,7 +1319,7 @@ def before_unroll(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1326,7 +1328,7 @@ def before_unroll(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_unroll) - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.unroll(i) After applying unroll, the IR becomes: @@ -1339,7 +1341,7 @@ def after_unroll(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.unroll(0, 128): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1351,11 +1353,11 @@ def after_unroll(a: T.handle, b: T.handle) -> None: @type_checked def cache_read( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], read_buffer_index: Union[int, str, Buffer], storage_scope: str, - consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, - ) -> BlockRV: + consumer_blocks: Optional[List[Union[SBlockRV, str]]] = None, + ) -> SBlockRV: """Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one block who write the buffer in the scope. @@ -1364,7 +1366,7 @@ def cache_read( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The consumer block of the target buffer. buffer: Union[int, str, Buffer] @@ -1375,13 +1377,13 @@ def cache_read( storage_scope: str The target storage scope. - consumer_blocks: Optional[List[Union[BlockRV, str]]] + consumer_blocks: Optional[List[Union[SBlockRV, str]]] An optional list of consumers that should read from the cache. If not specified, all consumers will use the cache. Returns ------- - cached_block : BlockRV + cached_block : SBlockRV The block of the cache stage Examples @@ -1395,7 +1397,7 @@ def before_cache_read(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1404,7 +1406,7 @@ def before_cache_read(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_cache_read) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") sch.cache_read(block_b, 0, "local") print(sch.mod["main"].script()) @@ -1418,11 +1420,11 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) A_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("A_local"): + with T.sblock("A_local"): vi, vj = T.axis.remap("SS", [i, j]) A_local[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_local[vi, vj] * 2.0 @@ -1430,7 +1432,7 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: if consumer_blocks is None: consumer_blocks = [] - # Convert any string block names into Block RVs. + # Convert any string SBlock names into SBlock RVs. consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] block = self._normalize_block_arg(block) @@ -1445,11 +1447,11 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: @type_checked def cache_write( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], write_buffer_index: Union[int, str, Buffer], storage_scope: str, - consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, - ) -> BlockRV: + consumer_blocks: Optional[List[Union[SBlockRV, str]]] = None, + ) -> SBlockRV: """Create a block that reads a buffer region into a write cache. It requires: 1) There is only one block who write the buffer in the scope. @@ -1458,7 +1460,7 @@ def cache_write( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The producer block of the target buffer. write_buffer_index: int @@ -1469,13 +1471,13 @@ def cache_write( storage_scope: str The target storage scope. - consumer_blocks: Optional[List[Union[BlockRV, str]]] + consumer_blocks: Optional[List[Union[SBlockRV, str]]] An optional list of consumers that should read directly from the cache. If not specified, all consumers will read from the original buffer. Returns ------- - cached_block : BlockRV + cached_block : SBlockRV The block of the cache stage Examples @@ -1489,7 +1491,7 @@ def before_cache_write(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1498,7 +1500,7 @@ def before_cache_write(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_cache_write) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") sch.cache_write(block_b, 0, "local") print(sch.mod["main"].script()) @@ -1512,11 +1514,11 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("A_local"): + with T.sblock("A_local"): vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_local[vi, vj] @@ -1524,7 +1526,7 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: if consumer_blocks is None: consumer_blocks = [] - # Convert any string block names into Block RVs. + # Convert any string SBlock names into SBlock RVs. consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] block = self._normalize_block_arg(block) @@ -1539,11 +1541,11 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: @type_checked def reindex_cache_read( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], read_buffer_index: int, storage_scope: str, index_map: Union[IndexMap, Callable], - ) -> BlockRV: + ) -> SBlockRV: """Create a block that reads a buffer region into a read cache using customized indices specified by index map. The read region of the buffer must be a single point. @@ -1557,7 +1559,7 @@ def reindex_cache_read( Parameters ---------- - block : BlockRV + block : SBlockRV The consumer block of the target buffer. read_buffer_index: int The index of the buffer in block's read region. @@ -1568,7 +1570,7 @@ def reindex_cache_read( Returns ------- - cached_block : BlockRV + cached_block : SBlockRV The block of the cache stage Examples @@ -1582,7 +1584,7 @@ def before_reindex_cache_read(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1591,7 +1593,7 @@ def before_reindex_cache_read(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_cache_read) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi)) print(sch.mod["main"].script()) @@ -1605,11 +1607,11 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) A_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("A_local"): + with T.sblock("A_local"): vi, vj = T.axis.remap("SS", [i, j]) A_local[vj, vi] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_local[vj, vi] * 2.0 @@ -1621,13 +1623,13 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: cache_read reindex """ - # Convert any string block names into Block RVs. + # Convert any string SBlock names into SBlock RVs. block = self._normalize_block_arg(block) if callable(index_map): index_map = IndexMap.from_func( index_map, - index_dtype=_get_block_default_dtype(self.get(block)), + index_dtype=_get_sblock_default_dtype(self.get(block)), ) return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member self, block, read_buffer_index, storage_scope, index_map @@ -1636,11 +1638,11 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: @type_checked def reindex_cache_write( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], write_buffer_index: int, storage_scope: str, index_map: Union[Callable, IndexMap], - ) -> BlockRV: + ) -> SBlockRV: r"""Create a block that reads a buffer region into a write cache using customized indices specified by index map. The write region of the buffer must be a single point. @@ -1654,7 +1656,7 @@ def reindex_cache_write( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The consumer block of the target buffer. write_buffer_index: int The index of the buffer in block's write region. @@ -1662,13 +1664,13 @@ def reindex_cache_write( The target storage scope. index_map: Union[Callable, IndexMap] User defined indices to access allocated cache buffer, maps from block iter vars. - consumer_blocks: Optional[List[Union[BlockRV, str]]] + consumer_blocks: Optional[List[Union[SBlockRV, str]]] An optional list of consumers that should read directly from the cache. If not specified, all consumers will read from the original buffer. Returns ------- - cached_block : BlockRV + cached_block : SBlockRV The block of the cache stage Examples @@ -1682,7 +1684,7 @@ def before_reindex_cache_write(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -1691,7 +1693,7 @@ def before_reindex_cache_write(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_cache_write) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj)) print(sch.mod["main"].script()) @@ -1705,11 +1707,11 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (64, 2, 128)) B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("A_local"): + with T.sblock("A_local"): vi, vj = T.axis.remap("SS", [i, j]) B_local[vi % 2, vi // 2, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_local[vi % 2, vi // 2, vj] @@ -1721,13 +1723,13 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: cache_write reindex """ - # Convert any string block names into Block RVs. + # Convert any string SBlock names into SBlock RVs. block = self._normalize_block_arg(block) if callable(index_map): index_map = IndexMap.from_func( index_map, - index_dtype=_get_block_default_dtype(self.get(block)), + index_dtype=_get_sblock_default_dtype(self.get(block)), ) return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member self, block, write_buffer_index, storage_scope, index_map @@ -1736,17 +1738,17 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: @type_checked def cache_inplace( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], read_buffer_index: Union[int, str, Buffer], storage_scope: str, - ) -> List[BlockRV]: + ) -> List[SBlockRV]: """Create blocks that reads & write a buffer region into a cache block. It requires the target block both read & write the target buffer. Mainly for inplace operation. Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The target block operates on the target buffer. read_buffer_index: int @@ -1760,7 +1762,7 @@ def cache_inplace( Returns ------- - cached_blocks : List[BlockRV] + cached_blocks : List[SBlockRV] The blocks of the cache stage, read cache first, write cache second Examples @@ -1772,7 +1774,7 @@ def cache_inplace( @T.prim_func def before_cache_inplace(data_io: T.Buffer((64), "int32")): for i0 in T.serial(1): - with T.block("A"): + with T.sblock("A"): T.reads(data_io[:64]) T.writes(data_io[:64]) T.evaluate(T.call_extern("call_impl", data_io.data, dtype="")) @@ -1782,7 +1784,7 @@ def before_cache_inplace(data_io: T.Buffer((64), "int32")): .. code-block:: python sch = tir.Schedule(before_cache_inplace) - block_a = sch.get_block("A") + block_a = sch.get_sblock("A") sch.cache_inplace(block_a, 0, "local") print(sch.mod["main"].script()) @@ -1795,17 +1797,17 @@ def cache_inplace(data_io: T.Buffer(64, "int32")) -> None: data_io_local = T.alloc_buffer([64], dtype="int32", scope="local") for i0 in T.serial(1): for ax0 in T.serial(64): - with T.block("data_io_local"): + with T.sblock("data_io_local"): v0 = T.axis.spatial(64, ax0) T.reads(data_io[v0]) T.writes(data_io_local[v0]) data_io_local[v0] = data_io[v0] - with T.block("A"): + with T.sblock("A"): T.reads(data_io_local[0 : 64]) T.writes(data_io_local[0 : 64]) T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype="")) for ax0 in T.serial(64): - with T.block("data_io_local"): + with T.sblock("data_io_local"): v0 = T.axis.spatial(64, ax0) T.reads(data_io_local[v0]) T.writes(data_io[v0]) @@ -1824,14 +1826,14 @@ def cache_inplace(data_io: T.Buffer(64, "int32")) -> None: @type_checked def cache_index( - self, block: Union[BlockRV, str], storage_scope: str, cse_thresh: int = 0 - ) -> List[BlockRV]: + self, block: Union[SBlockRV, str], storage_scope: str, cse_thresh: int = 0 + ) -> List[SBlockRV]: """Create a block to cache precomputed index for later use. if there is no index computation, keep unchanged. Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The target block operates on the target buffer. storage_scope: str @@ -1844,7 +1846,7 @@ def cache_index( Returns ------- - cached_blocks : List[BlockRV] + cached_blocks : List[SBlockRV] The blocks of the stage writing the cache buffers Examples @@ -1858,7 +1860,7 @@ def resize(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (1, 3, 40, 40)) B = T.match_buffer(b, (1, 3, 80, 80)) for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): - with T.block("A"): + with T.sblock("A"): n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3]) B[n, c, vi, vj] = A[n, c, vi//4 + vj//4, vj//2] @@ -1867,7 +1869,7 @@ def resize(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(resize) - block_a = sch.get_block("A") + block_a = sch.get_sblock("A") sch.cache_index(block_a, "global", 1) print(sch.mod["main"].script()) @@ -1882,20 +1884,20 @@ def resize_cache_index( index_var_0 = T.alloc_buffer([80, 80], dtype="int32", strides=[1]) index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1]) for ax0, ax1 in T.grid(80, 80): - with T.block("index_0"): + with T.sblock("index_0"): v0 = T.axis.spatial(80, ax0) v1 = T.axis.spatial(80, ax1) T.reads() T.writes(index_var_0[v0, v1]) index_var_0[v0, v1] = v0 // 4 + v1 // 4 for ax0 in T.serial(80): - with T.block("index_1"): + with T.sblock("index_1"): v0 = T.axis.spatial(80, ax0) T.reads() T.writes(index_var_1[v0]) index_var_1[v0] = v0 // 2 for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): - with T.block("A"): + with T.sblock("A"): n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[n, c, vi // 4 + vj // 4, vj // 2]) T.writes(B[n, c, vi, vj]) @@ -1910,8 +1912,8 @@ def resize_cache_index( @type_checked def reindex( - self, block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer] - ) -> BlockRV: + self, block: Union[SBlockRV, str], buffer: Union[Tuple[str, int], str, Buffer] + ) -> SBlockRV: """Create a block that read/write a buffer region into a read/write cache with reindexing. The layout of the cache will be the same as by the iterators of the block that reads/writes the buffer. It requires: @@ -1920,7 +1922,7 @@ def reindex( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block that accesses the target buffer. If a string, this must uniquely identify a block. @@ -1944,7 +1946,7 @@ def reindex( Returns ------- - reindex_block : BlockRV + reindex_block : SBlockRV The block of the reindex stage Examples @@ -1960,7 +1962,7 @@ def before_reindex( B: T.Buffer((128, 128), "float32") ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] * 2.0 @@ -1969,7 +1971,7 @@ def before_reindex( .. code-block:: python sch = tir.Schedule(before_reindex) - block = sch.get_block("B") + block = sch.get_sblock("B") sch.reindex(block, ("read", 0)) After applying reindex, the IR becomes: @@ -1983,11 +1985,11 @@ def after_reindex( ) -> None: A_reindex = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("A_reindex"): + with T.sblock("A_reindex"): vi, vj = T.axis.remap("SS", [i, j]) A_reindex[vi, vj] = A[vj, vi] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_reindex[vi, vj] * 2.0 @@ -2003,15 +2005,15 @@ def after_reindex( ########## Schedule: Data movement ########## def read_at( - self, loop: LoopRV, block: BlockRV, read_buffer_index: int, storage_scope: str - ) -> BlockRV: + self, loop: LoopRV, block: SBlockRV, read_buffer_index: int, storage_scope: str + ) -> SBlockRV: return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member self, loop, block, read_buffer_index, storage_scope ) def write_at( - self, loop: LoopRV, block: BlockRV, write_buffer_index: int, storage_scope: str - ) -> BlockRV: + self, loop: LoopRV, block: SBlockRV, write_buffer_index: int, storage_scope: str + ) -> SBlockRV: return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member self, loop, block, write_buffer_index, storage_scope ) @@ -2021,7 +2023,7 @@ def write_at( @type_checked def compute_at( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, index: int = -1, @@ -2045,7 +2047,7 @@ def compute_at( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block to be moved loop: LoopRV @@ -2073,11 +2075,11 @@ def before_compute_at(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2086,8 +2088,8 @@ def before_compute_at(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_compute_at) - block = sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B") + loop, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=False) print(sch.mod["main"].script()) @@ -2102,11 +2104,11 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2119,7 +2121,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: @type_checked def reverse_compute_at( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, index: int = -1, @@ -2140,7 +2142,7 @@ def reverse_compute_at( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block to be moved loop: LoopRV @@ -2168,11 +2170,11 @@ def before_reverse_compute_at(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2181,8 +2183,8 @@ def before_reverse_compute_at(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_reverse_compute_at) - block = sch.get_block("C") - loop, _ = sch.get_loops(sch.get_block("B")) + block = sch.get_sblock("C") + loop, _ = sch.get_loops(sch.get_sblock("B")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) print(sch.mod["main"].script()) @@ -2197,11 +2199,11 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2212,7 +2214,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: ) @type_checked - def compute_inline(self, block: Union[BlockRV, str]) -> None: + def compute_inline(self, block: Union[SBlockRV, str]) -> None: """Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block, which only produces one buffer @@ -2227,7 +2229,7 @@ def compute_inline(self, block: Union[BlockRV, str]) -> None: Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block to be inlined to its consumer(s) Examples @@ -2243,11 +2245,11 @@ def before_inline(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2256,7 +2258,7 @@ def before_inline(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_inline) - sch.compute_inline(sch.get_block("B")) + sch.compute_inline(sch.get_sblock("B")) print(sch.mod["main"].script()) After applying compute-inline, the IR becomes: @@ -2268,7 +2270,7 @@ def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -2277,7 +2279,7 @@ def after_inline(a: T.handle, c: T.handle) -> None: _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member @type_checked - def reverse_compute_inline(self, block: Union[BlockRV, str]) -> None: + def reverse_compute_inline(self, block: Union[SBlockRV, str]) -> None: """Inline a block into its only producer. It requires: 1) The block is a complete non-root block, which only produces and consumes one buffer @@ -2295,7 +2297,7 @@ def reverse_compute_inline(self, block: Union[BlockRV, str]) -> None: Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block to be inlined to its producer Examples @@ -2311,11 +2313,11 @@ def before_inline(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2324,7 +2326,7 @@ def before_inline(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_inline) - sch.reverse_compute_inline(sch.get_block("C")) + sch.reverse_compute_inline(sch.get_sblock("C")) print(sch.mod["main"].script()) After applying reverse-compute-inline, the IR becomes: @@ -2336,7 +2338,7 @@ def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -2348,8 +2350,8 @@ def after_inline(a: T.handle, c: T.handle) -> None: @type_checked def fuse_reduction_epilogue( self, - reduction_block: Union[BlockRV, str], - epilogue_block: Union[BlockRV, str], + reduction_block: Union[SBlockRV, str], + epilogue_block: Union[SBlockRV, str], ) -> None: """Fuse an epilogue block into a reduction block. @@ -2383,9 +2385,9 @@ def fuse_reduction_epilogue( Parameters ---------- - reduction_block : Union[BlockRV, str] + reduction_block : Union[SBlockRV, str] The reduction block (e.g., matmul) - epilogue_block : Union[BlockRV, str] + epilogue_block : Union[SBlockRV, str] The epilogue block to be fused (e.g., bias add, ReLU, clipping) Examples @@ -2402,7 +2404,7 @@ def fuse_reduction_epilogue( ########## Schedule: Reduction ########## @type_checked - def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV: + def decompose_reduction(self, block: Union[SBlockRV, str], loop: LoopRV) -> SBlockRV: """Decompose a reduction block into two separate blocks. a) The init block, which is translated from the init statement of the reduction block; @@ -2421,14 +2423,14 @@ def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> Block Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The reduction block to be decomposed loop : LoopRV The loop above which the init block is inserted before. Returns ------- - init_block : BlockRV + init_block : SBlockRV The init block Examples @@ -2453,7 +2455,7 @@ def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: .. code-block:: python sch = tir.Schedule(before_decompose) - C = sch.get_block("C") + C = sch.get_sblock("C") i, j, k = sch.get_loops(C) sch.decompose_reduction(C, i) print(sch.mod["main"].script()) @@ -2481,7 +2483,7 @@ def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore @type_checked - def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV: + def rfactor(self, loop: LoopRV, factor_axis: int) -> SBlockRV: """Factorize an associative reduction block by the specified loop. An associative reduction cannot be parallelized directly, @@ -2550,7 +2552,7 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV: Returns ------- - rf_block : BlockRV + rf_block : SBlockRV The block which computes partial results over each slices (i.e., the first block as described in the above illustration) @@ -2566,7 +2568,7 @@ def before_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128,)) for ii, i, j in T.grid(128, 128, 128): - with T.block("B"): + with T.sblock("B"): vii, vi, vj = T.axis.remap("SRR", [ii, i, j]) with T.init(): B[vii] = 0.0 @@ -2577,7 +2579,7 @@ def before_rfactor(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_rfactor) - _, _, k = sch.get_loops(sch.get_block("B")) + _, _, k = sch.get_loops(sch.get_sblock("B")) sch.rfactor(k, 0) print(sch.mod["main"].script()) @@ -2591,13 +2593,13 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128]) B_rf = T.alloc_buffer([128, 128]) for i2, ii, i in T.grid(128, 128, 128): - with T.block("B_rf"): + with T.sblock("B_rf"): vi2, vii, vi = T.axis.remap("SSR", [i2, ii, i]) with T.init(): B_rf[vi2, vii] = 0.0 B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) for ii, i2 in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vii, vi2 = T.axis.remap("SR", [ii, i2]) with T.init(): B[vii] = 0.0 @@ -2630,11 +2632,11 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: # pylint: disable-next=no-member return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore - ######## Schedule: Block annotation ######## + ######## Schedule: SBlock annotation ######## @type_checked def storage_align( # pylint: disable=too-many-arguments - self, block: Union[BlockRV, str], buffer_index: int, axis: int, factor: int, offset: int + self, block: Union[SBlockRV, str], buffer_index: int, axis: int, factor: int, offset: int ) -> None: """Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more @@ -2643,7 +2645,7 @@ def storage_align( # pylint: disable=too-many-arguments Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The producer block of the buffer. buffer_index : int The index of the buffer in block's write region. @@ -2667,11 +2669,11 @@ def before_storage_align(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2680,7 +2682,7 @@ def before_storage_align(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_storage_align) - sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1) + sch.storage_align(sch.get_sblock("B"), buffer_index=0, axis=0, factor=128, offset=1) print(sch.mod["main"].script()) After applying storage_align, the IR becomes: @@ -2693,12 +2695,12 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): - T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + with T.sblock("B"): + T.sblock_attr({"buffer_dim_align": [[[0, 128, 1]]]}) vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2715,14 +2717,14 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: @type_checked def set_scope( - self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer], storage_scope: str + self, block: Union[SBlockRV, str], buffer_index: Union[int, str, Buffer], storage_scope: str ) -> None: """Set the storage scope of a buffer, where the buffer is specified by the a block and a write-index. Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The producer block of the buffer buffer_index : int The index of the buffer in block's write region @@ -2743,11 +2745,11 @@ def before_set_scope( B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -2756,7 +2758,7 @@ def before_set_scope( .. code-block:: python sch = tir.Schedule(before_set_scope) - sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared") + sch.set_scope(sch.get_sblock("B"), buffer_index=0, storage_scope="shared") print(sch.mod["main"].script()) After applying set_scope, the IR becomes: @@ -2770,11 +2772,11 @@ def after_set_scope( B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_shared[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B_shared[vi, vj] + T.float32(1) @@ -2792,7 +2794,7 @@ def after_set_scope( ) @type_checked - def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None: + def unsafe_set_dtype(self, block: Union[SBlockRV, str], buffer_index: int, dtype: str) -> None: """Set the data type of a buffer, where the buffer is specified by the a block and write-index. @@ -2801,7 +2803,7 @@ def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The producer block of the buffer buffer_index : int The index of the buffer in block's write region @@ -2822,11 +2824,11 @@ def before_set_dtype( B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j] C[vi, vj] = B[vi, vj] + 1.0 @@ -2849,11 +2851,11 @@ def after_set_dtype( B = T.alloc_buffer((128, 128), dtype="float16") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16") for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j] C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0 @@ -2871,20 +2873,20 @@ def after_set_dtype( @type_checked def blockize( - self, target: Union[LoopRV, List[BlockRV]], preserve_unit_iters: bool = True - ) -> BlockRV: + self, target: Union[LoopRV, List[SBlockRV]], preserve_unit_iters: bool = True + ) -> SBlockRV: """Convert multiple blocks or the subtree rooted at a specific loop into a block. Parameters ---------- - target : LoopRV or List[BlockRV] + target : LoopRV or List[SBlockRV] The root of the subtree or the specified blocks. preserve_unit_iters : bool Whether or not to preserve unit iterators in block bindings Returns ------- - result : BlockRV + result : SBlockRV The new block. Examples @@ -2900,7 +2902,7 @@ def before_blockize( B: T.Buffer((128, 128), "float32") ) -> None: for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + j_1) T.reads(A[vi, vj]) @@ -2912,7 +2914,7 @@ def before_blockize( .. code-block:: python sch = tir.Schedule(before_blockize) - B = sch.get_block("B") + B = sch.get_sblock("B") _, _, i1, _ = sch.get_loops(B) sch.blockize(i1) print(sch.mod["main"].script()) @@ -2927,12 +2929,12 @@ def after_blockize( B: T.Buffer((128, 128), "float32") )-> None: for i_0, j_0 in T.grid(8, 8): - with T.block("B_o"): + with T.sblock("B_o"): vio, vjo = T.axis.remap("SS", [i_0, j_0]) T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) for i_1, j_1 in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i_1, j_1]) T.reads(A[vio * 16 + vi, vjo * 16 + vj]) T.writes(B[vio * 16 + vi, vjo * 16 + vj]) @@ -2951,7 +2953,7 @@ def after_blockize( @type_checked def tensorize( self, - block_or_loop: Union[BlockRV, LoopRV], + block_or_loop: Union[SBlockRV, LoopRV], tensor_intrin: str, preserve_unit_iters: bool = True, ) -> None: @@ -2959,7 +2961,7 @@ def tensorize( Parameters ---------- - block_or_loop : Union[BlockRV, LoopRV] + block_or_loop : Union[SBlockRV, LoopRV] The loop to be tensorized. tensor_intrin : str The tensor intrin or the name of the tensor intrin. @@ -2980,9 +2982,9 @@ def before_tensorize( C: T.Buffer((128, 128), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16): - with T.block("update"): + with T.sblock("update"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + j_1) vk = T.axis.reduce(128, k_0 * 16 + k_1) @@ -3000,11 +3002,11 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) T.writes(C[0 : 16, 0 : 16]) for i, j, k in T.grid(16, 16, 16): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -3015,7 +3017,7 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) T.writes(C[0 : 16, 0 : 16]) T.evaluate( @@ -3039,7 +3041,7 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_tensorize) - update = sch.get_block("update") + update = sch.get_sblock("update") _, _, _, i1, _, _ = sch.get_loops(update) sch.tensorize(i1, "test_mma_intrin") print(sch.mod["main"].script()) @@ -3055,9 +3057,9 @@ def after_tensorize( C: T.Buffer((128, 128), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i_0, j_0, k_0 in T.grid(8, 8, 8): - with T.block("update_o"): + with T.sblock("update_o"): vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0]) T.reads( C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], @@ -3112,13 +3114,13 @@ def after_tensorize( @type_checked def annotate( - self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str, ann_val: AnnotationValueT + self, block_or_loop: Union[SBlockRV, LoopRV], ann_key: str, ann_val: AnnotationValueT ) -> None: """Annotate a block/loop with a key value pair Parameters ---------- - block_or_loop: Union[BlockRV, LoopRV] + block_or_loop: Union[SBlockRV, LoopRV] The block/loop to be annotated ann_key : str The annotation key @@ -3137,7 +3139,7 @@ def before_annotate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -3146,7 +3148,7 @@ def before_annotate(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_annotate) - sch.annotate(sch.get_block("B"), "ann_key", "ann_value") + sch.annotate(sch.get_sblock("B"), "ann_key", "ann_value") print(sch.mod["main"].script()) After applying annotate, the IR becomes: @@ -3158,9 +3160,9 @@ def after_annotate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"ann_key", "ann_value"}) + T.sblock_attr({"ann_key", "ann_value"}) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -3169,12 +3171,12 @@ def after_annotate(a: T.handle, b: T.handle) -> None: ) @type_checked - def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: + def unannotate(self, block_or_loop: Union[SBlockRV, LoopRV], ann_key: str) -> None: """Unannotate a block/loop's annotation with key ann_key Parameters ---------- - block_or_loop: Union[BlockRV, LoopRV] + block_or_loop: Union[SBlockRV, LoopRV] The block/loop to be unannotated ann_key : str The annotation key @@ -3191,9 +3193,9 @@ def before_unannotate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"ann_key", "ann_value"}) + T.sblock_attr({"ann_key", "ann_value"}) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do annotate: @@ -3201,7 +3203,7 @@ def before_unannotate(a: T.handle, b: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_unannotate) - sch.unannotate(sch.get_block("B"), "ann_key") + sch.unannotate(sch.get_sblock("B"), "ann_key") print(sch.mod["main"].script()) After applying unannotate, the IR becomes: @@ -3213,7 +3215,7 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -3224,19 +3226,19 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: ########## Schedule: Layout transformation ########## - def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV: + def _normalize_block_arg(self, block: Union[SBlockRV, str]) -> SBlockRV: if isinstance(block, str): - return self.get_block(block) + return self.get_sblock(block) return block def _normalize_buffer_arg( self, - block: BlockRV, + block: SBlockRV, buffer: Union[Tuple[str, int], int, str, Buffer], required_buffer_type=None, ) -> Tuple[str, int, Buffer]: - block_obj: Block = self.get(block) + block_obj: SBlock = self.get(block) block_name = block_obj.name_hint def iter_buffers(): @@ -3301,7 +3303,7 @@ def iter_buffers(): @type_checked def transform_layout( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], buffer: Union[Tuple[str, int], str, Buffer], index_map: Union[IndexMap, Callable], pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] = None, @@ -3312,7 +3314,7 @@ def transform_layout( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block that accesses the target buffer. If a string, this must uniquely identify a block. @@ -3392,11 +3394,11 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -3405,7 +3407,7 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None: .. code-block:: python sch = tir.Schedule(before_storage_align) - sch.transform_layout(sch.get_block("B"), buffer=("write",0), + sch.transform_layout(sch.get_sblock("B"), buffer=("write",0), index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) @@ -3419,11 +3421,11 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> B = T.alloc_buffer((8, 8, 16, 16), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 @@ -3436,7 +3438,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> index_map, axis_separators = IndexMap.from_func_with_separators( index_map, ndim=ndim, - index_dtype=_get_block_default_dtype(self.get(block)), + index_dtype=_get_sblock_default_dtype(self.get(block)), ) else: axis_separators = [] @@ -3447,7 +3449,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> pad_value = IndexMap.from_func( pad_value, ndim=len(index_map.final_indices), - index_dtype=_get_block_default_dtype(self.get(block)), + index_dtype=_get_sblock_default_dtype(self.get(block)), ) elif not isinstance(pad_value, IndexMap): # Explicitly convert python int/float arguments to the @@ -3461,7 +3463,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> pad_value = IndexMap.from_func( lambda *indices: pad_value, ndim=len(index_map.final_indices), - index_dtype=_get_block_default_dtype(self.get(block)), + index_dtype=_get_sblock_default_dtype(self.get(block)), ) buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 @@ -3481,13 +3483,13 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> @type_checked def transform_block_layout( - self, block: Union[BlockRV, str], index_map: Union[IndexMap, Callable] + self, block: Union[SBlockRV, str], index_map: Union[IndexMap, Callable] ) -> None: """Apply a transformation represented by IndexMap to block Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block to be transformed index_map : Union[IndexMap, Callable] @@ -3506,7 +3508,7 @@ def before_transform_block_layout( B: T.Buffer((16, 16), "float32") ) -> None: for i, j in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -3515,7 +3517,7 @@ def before_transform_block_layout( .. code-block:: python sch = tir.Schedule(before_transform_block_layout) - sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,)) + sch.transform_block_layout(sch.get_sblock("B"), lambda i, j: (i * 16 + j,)) print(sch.mod["main"].script()) After applying transform_block_layout, the IR becomes: @@ -3528,7 +3530,7 @@ def after_transform_block_layout( B: T.Buffer((16, 16), "float32") ) -> None: for i in range(256): - with T.block("B"): + with T.sblock("B"): vi, = T.axis.remap("S", [i]) B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0 """ @@ -3536,7 +3538,7 @@ def after_transform_block_layout( if callable(index_map): index_map = IndexMap.from_func( index_map, - index_dtype=_get_block_default_dtype(self.get(block)), + index_dtype=_get_sblock_default_dtype(self.get(block)), ) _ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member self, block, index_map @@ -3544,7 +3546,7 @@ def after_transform_block_layout( def set_axis_separator( self, - block: Union[BlockRV, str], + block: Union[SBlockRV, str], buffer: Union[Tuple[str, int], str, Buffer], axis_separators: Optional[List[int]], ) -> None: @@ -3553,7 +3555,7 @@ def set_axis_separator( Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block that accesses the target buffer. If a string, this must uniquely identify a block. @@ -3593,11 +3595,11 @@ def before_set_axis_separator( B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -3606,7 +3608,7 @@ def before_set_axis_separator( .. code-block:: python sch = tir.Schedule(before_set_axis_separator) - sch.set_axis_separators(sch.get_block("B"), buffer=("write", 0), + sch.set_axis_separators(sch.get_sblock("B"), buffer=("write", 0), axis_separators=[1]) print(sch.mod["main"].script()) @@ -3621,11 +3623,11 @@ def after_set_axis_separators( B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1]) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) """ @@ -3641,7 +3643,7 @@ def after_set_axis_separators( ########## Schedule: Padding decomposition ######### @type_checked - def decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV: + def decompose_padding(self, block: Union[SBlockRV, str], loop: LoopRV) -> SBlockRV: """Decompose a block of padding computation pattern into two separate blocks. a) The block which fill const pad values into full write region; @@ -3660,14 +3662,14 @@ def decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The padding block to be decomposed. loop : LoopRV The loop above which the pad value filling block is inserted before. Returns ------- - pad_value_block : BlockRV + pad_value_block : SBlockRV The block filling const pad values. Examples @@ -3679,7 +3681,7 @@ def decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV @T.prim_func def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): for i in range(140): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32") @@ -3688,7 +3690,7 @@ def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): .. code-block:: python sch = tir.Schedule(before_decompose, debug_mask="all") - block = sch.get_block("block") + block = sch.get_sblock("block") sch.decompose_padding(block, sch.get_loops(block)[0]) print(sch.mod["main].script()) @@ -3699,11 +3701,11 @@ def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): @T.prim_func def after_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): for i in T.serial(140): - with T.block("block_pad_const"): + with T.sblock("block_pad_const"): vi = T.axis.spatial(140, i) y[vi] = 0 for i in T.serial(128): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial(128, i) y[vi + 6] = x[vi] """ @@ -3713,13 +3715,13 @@ def after_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): ) @type_checked - def can_decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> bool: + def can_decompose_padding(self, block: Union[SBlockRV, str], loop: LoopRV) -> bool: """Check whether the block match padding pattern and can be decomposed.""" # pylint: disable-next=no-member return _ffi_api.CanDecomposePadding(self, block, loop) # type: ignore @type_checked - def pad_einsum(self, block: Union[BlockRV, str], padding: List[int]) -> None: + def pad_einsum(self, block: Union[SBlockRV, str], padding: List[int]) -> None: """Pad the computation of Einsum. On a block with trivial binding, this primitive pads the iteration domain of the block by @@ -3732,7 +3734,7 @@ def pad_einsum(self, block: Union[BlockRV, str], padding: List[int]) -> None: Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The block that matches the Einsum pattern. padding : List[int] @@ -3752,7 +3754,7 @@ def before_pad_einsum( C: T.Buffer((127, 127), "float32"), ) -> None: for i0, i1, i2 in T.grid(127, 127, 127): - with T.block("C_shared"): + with T.sblock("C_shared"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) with T.init(): C[i, j] = T.float32(0) @@ -3763,7 +3765,7 @@ def before_pad_einsum( .. code-block:: python sch = tir.Schedule(before_pad_einsum, debug_mask="all") - block = sch.get_block("C_shared") + block = sch.get_sblock("C_shared") sch.pad_einsum(block, [32, 32, 32]) print(sch.mod["main"].script()) @@ -3777,12 +3779,12 @@ def main( B: T.Buffer((127, 127), "float32"), C: T.Buffer((127, 127), "float32"), ): - # with T.block("root"): + # with T.sblock("root"): A_pad = T.alloc_buffer((128, 128)) B_pad = T.alloc_buffer((128, 128)) C_pad = T.alloc_buffer((128, 128)) for i0, i1 in T.grid(128, 128): - with T.block("A_pad"): + with T.sblock("A_pad"): v0, v1 = T.axis.remap("SS", [i0, i1]) A_pad[v0, v1] = T.if_then_else( v0 < 127 and v1 < 127, @@ -3790,7 +3792,7 @@ def main( T.float32(0), ) for i0, i1 in T.grid(128, 128): - with T.block("B_pad"): + with T.sblock("B_pad"): v0, v1 = T.axis.remap("SS", [i0, i1]) B_pad[v0, v1] = T.if_then_else( v0 < 127 and v1 < 127, @@ -3798,13 +3800,13 @@ def main( T.float32(0), ) for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("C_shared"): + with T.sblock("C_shared"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) with T.init(): C_pad[i, j] = T.float32(0) C_pad[i, j] = C_pad[i, j] + A_pad[i, k] * B_pad[k, j] for i0, i1 in T.grid(127, 127): - with T.block("C_pad"): + with T.sblock("C_pad"): v0, v1 = T.axis.remap("SS", [i0, i1]) C[v0, v1] = C_pad[v0, v1] @@ -3817,7 +3819,7 @@ def main( ######## Schedule: Buffer transformation ######## @type_checked - def rolling_buffer(self, block: Union[BlockRV, str], write_buffer_index: int) -> None: + def rolling_buffer(self, block: Union[SBlockRV, str], write_buffer_index: int) -> None: """Compute the target buffer via rolling buffering, select the outermost rollable axis with a positive bound overlap that appears in the block's ancestor loops as `rolling axis`, fold and circularize the buffer along the rolling dimension, @@ -3835,7 +3837,7 @@ def rolling_buffer(self, block: Union[BlockRV, str], write_buffer_index: int) -> Parameters ---------- - block : Union[BlockRV, str] + block : Union[SBlockRV, str] The producer block of the buffer. write_buffer_index : int The index of the buffer in block's write region. @@ -3852,11 +3854,11 @@ def before_rolling_buffer( A: T.Buffer((12, 12), "int8"), C: T.Buffer((8, 8), "int8") ) -> None: # body - # with T.block("root") + # with T.sblock("root") B = T.alloc_buffer([10, 10], dtype="int8") for i0, i1 in T.grid(2, 2): for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3): - with T.block("B"): + with T.sblock("B"): ax0_1 = T.axis.spatial(10, i0 * 4 + ax0) ax1_1 = T.axis.spatial(10, i1 * 4 + ax1) rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) @@ -3864,7 +3866,7 @@ def before_rolling_buffer( B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1] ) for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0_1 = T.axis.spatial(8, i0 * 4 + ax0) ax1_1 = T.axis.spatial(8, i1 * 4 + ax1) rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) @@ -3877,7 +3879,7 @@ def before_rolling_buffer( .. code-block:: python sch = tir.Schedule(before_rolling_buffer) - sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0) + sch.rolling_buffer(sch.get_sblock("B"), write_buffer_index=0) print(sch.mod["main"].script()) After applying rolling_buffer, the IR becomes: @@ -3890,11 +3892,11 @@ def after_rolling_buffer( C: T.Buffer((8, 8), "int8") ) -> None: # body - # with T.block("root") + # with T.sblock("root") B = T.alloc_buffer([6, 10], dtype="int8") for i0, i1 in T.grid(2, 2): for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1)) ax0_1 = T.axis.spatial(10, i0 * 4 + ax0) ax1_1 = T.axis.spatial(10, i1 * 4 + ax1) @@ -3903,7 +3905,7 @@ def after_rolling_buffer( B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1] ) for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0_1 = T.axis.spatial(8, i0 * 4 + ax0) ax1_1 = T.axis.spatial(8, i1 * 4 + ax1) rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) @@ -3928,13 +3930,13 @@ def enter_postproc(self) -> None: @type_checked def unsafe_hide_buffer_access( - self, block: BlockRV, buf_type: str, buf_index_array: List[int] + self, block: SBlockRV, buf_type: str, buf_index_array: List[int] ) -> None: """Hide some buffer access in a given block. This is an unsafe schedule primitive. Parameters ---------- - block : BlockRV + block : SBlockRV The block where we hide read access. buf_type : str The buffer type: "read"/"write". @@ -3959,13 +3961,13 @@ def unsafe_hide_buffer_access( @type_checked def annotate_buffer_access( - self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable + self, block: SBlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable ) -> None: """Annotate the read or write region of a block Parameters ---------- - block : BlockRV + block : SBlockRV The block to be annotated buffer_index : int The index of the buffer in block's read or write region @@ -3993,11 +3995,11 @@ def before_annotate_buffer_access( ) -> None: B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -4006,7 +4008,7 @@ def before_annotate_buffer_access( .. code-block:: python sch = tir.Schedule(before_annotate_buffer_access) - block = sch.get_block("B") + block = sch.get_sblock("B") sch.annotate_buffer_access(block, 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1))) print(sch.mod["main"].script()) @@ -4022,14 +4024,14 @@ def after_annotate_buffer_access( ) -> None: B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1]) T.writes(B[vi, vj]) - T.block_attr({"explicit_read_region": 0}) + T.sblock_attr({"explicit_read_region": 0}) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index 36436fe95783..fe980b0848f1 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -23,10 +23,10 @@ from tvm_ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object -from tvm.tir import Block, BlockRealize, For, PrimFunc +from tvm.tir import SBlock, SBlockRealize, For, PrimFunc from . import _ffi_api -from ..block_scope import BlockScope, StmtSRef +from ..block_scope import SBlockScope, StmtSRef CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"]) @@ -133,7 +133,7 @@ def __init__( _parse_enable_checks(enable_check), ) - def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: + def get_sref(self, stmt: Union[SBlock, For]) -> Optional[StmtSRef]: """Return the corresponding sref that points to the stmt Parameters @@ -148,8 +148,8 @@ def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: """ return _ffi_api.ScheduleStateGetSRef(self, stmt) # type: ignore # pylint: disable=no-member - def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: - """Get the BlockScope correpsonding to the block sref + def get_sblock_scope(self, block_sref: StmtSRef) -> SBlockScope: + """Get the SBlockScope correpsonding to the block sref Parameters ---------- @@ -161,7 +161,7 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: sref : StmtSRef The corresponding sref """ - return _ffi_api.ScheduleStateGetBlockScope( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleStateGetSBlockScope( # type: ignore # pylint: disable=no-member self, block_sref ) @@ -198,8 +198,8 @@ def _get_cached_flags(self, block_sref: StmtSRef) -> CachedFlags: def replace( self, src_sref: StmtSRef, - tgt_stmt: Union[Block, For, BlockRealize], - block_sref_reuse: Optional[Dict[Block, Block]] = None, + tgt_stmt: Union[SBlock, For, SBlockRealize], + block_sref_reuse: Optional[Dict[SBlock, SBlock]] = None, ) -> None: """ Replace the part of the AST, as being pointed to by `src_sref`, @@ -208,7 +208,7 @@ def replace( the only copy to the IRModule and IR nodes. Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`. - 1) Block -> Block + 1) SBlock -> SBlock 2) Loop -> Loop 3) Loop -> BlockRealize diff --git a/python/tvm/tir/schedule/transform.py b/python/tvm/tir/schedule/transform.py index fbaca81197e5..96a030f41f9a 100644 --- a/python/tvm/tir/schedule/transform.py +++ b/python/tvm/tir/schedule/transform.py @@ -17,12 +17,12 @@ """Transformation on TIR schedule.""" from typing import Optional -from tvm.tir.schedule import Schedule, BlockRV, LoopRV +from tvm.tir.schedule import Schedule, SBlockRV, LoopRV from . import _ffi_api def tile_with_tensor_intrin( - sch: Schedule, block: BlockRV, intrin_name: str, allow_padding: bool = False + sch: Schedule, block: SBlockRV, intrin_name: str, allow_padding: bool = False ) -> Optional[LoopRV]: """Tile a subset of loops in the block according to the given tensor intrinsic. @@ -30,7 +30,7 @@ def tile_with_tensor_intrin( ---------- sch : Schedule The schedule to which tiling is applied - block : BlockRV + block : SBlockRV The block whose subset of loops will be tiled intrin_name : str The name of a tensor intrinsic, must be registerd via TensorIntrin.register(...) beforehand diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 448ace3ade63..cd8f6e92a706 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -605,9 +605,9 @@ def __init__(self, buffer: Buffer, source: BufferRegion) -> None: ) -@tvm_ffi.register_object("tir.Block") -class Block(Stmt): - """Block node. +@tvm_ffi.register_object("tir.SBlock") +class SBlock(Stmt): + """SBlock node. Parameters ---------- @@ -673,7 +673,7 @@ def __init__( if annotations is None: annotations = {} self.__init_handle_by_constructor__( - _ffi_api.Block, # type: ignore + _ffi_api.SBlock, # type: ignore iter_vars, reads, writes, @@ -687,9 +687,9 @@ def __init__( ) # type: ignore -@tvm_ffi.register_object("tir.BlockRealize") -class BlockRealize(Stmt): - """BlockRealize node. +@tvm_ffi.register_object("tir.SBlockRealize") +class SBlockRealize(Stmt): + """SBlockRealize node. Parameters ---------- @@ -699,7 +699,7 @@ class BlockRealize(Stmt): predicate : Union[PrimExpr, bool] The predicate of the block. - block : Block + block : SBlock The block to realize span : Optional[Span] @@ -708,20 +708,20 @@ class BlockRealize(Stmt): iter_values: List[PrimExpr] predicate: PrimExpr - block: Block + block: SBlock span: Optional[Span] def __init__( self, iter_values: List[PrimExpr], predicate: Union[PrimExpr, bool], - block: Block, + block: SBlock, span: Optional[Span] = None, ) -> None: if isinstance(predicate, bool): predicate = const(predicate, "bool") self.__init_handle_by_constructor__( - _ffi_api.BlockRealize, # type: ignore + _ffi_api.SBlockRealize, # type: ignore iter_values, predicate, block, diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 0a5c0ea3a51a..aeb0883fed3d 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -42,12 +42,12 @@ def neon_4x4_i8i8i32_desc( B: T.Buffer((4, 4), "int8", offset_factor=1), C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) for i in T.serial(0, 4): for k in T.serial(0, 4): - with T.block("update"): + with T.sblock("update"): vi, vk = T.axis.remap("SR", [i, k]) C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") @@ -58,7 +58,7 @@ def neon_4x4_i8i8i32_impl( B: T.Buffer((4, 4), "int8", offset_factor=1), C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) @@ -123,12 +123,12 @@ def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1) B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1) C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) for i in T.serial(0, 4): for k in T.serial(0, 4): - with T.block("update"): + with T.sblock("update"): vi, vk = T.axis.remap("SR", [i, k]) C[vi] = C[vi] + T.cast(A[vk], dtype=out_dtype) * T.cast( B[vi, vk], dtype=out_dtype @@ -139,7 +139,7 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1) B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1) C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) @@ -258,11 +258,11 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows): def desc(a: T.handle, a_t: T.handle) -> None: A = T.match_buffer(a, (SVF2, SVF2), dtype="float32", offset_factor=1) A_t = T.match_buffer(a_t, (SVF2, SVF2), dtype="float32", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:SVF2, 0:SVF2]) T.writes(A_t[0:SVF2, 0:SVF2]) for k, m in T.grid(SVF2, SVF2): - with T.block("transpose"): + with T.sblock("transpose"): v_m, v_k = T.axis.remap("SS", [m, k]) A_t[v_k, v_m] = A[v_m, v_k] @@ -285,7 +285,7 @@ def impl(): strides=[T.int32(), 1], ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:SVF2, 0:SVF2]) T.writes(A_t[0:SVF2, 0:SVF2]) @@ -393,11 +393,11 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin(): def desc(a: T.handle, a_t: T.handle) -> None: A = T.match_buffer(a, (SVF2, SVF), dtype="float16", offset_factor=1) A_t = T.match_buffer(a_t, (SVF, SVF2), dtype="float16", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:SVF2, 0:SVF]) T.writes(A_t[0:SVF, 0:SVF2]) for k, m in T.grid(SVF, SVF2): - with T.block("transpose"): + with T.sblock("transpose"): v_m, v_k = T.axis.remap("SS", [m, k]) A_t[v_k, v_m] = A[v_m, v_k] @@ -417,7 +417,7 @@ def impl(): ptrue_fp16 = _create_ptrue_mask("float16") ptrue_fp32 = _create_ptrue_mask("float32") - with T.block("root"): + with T.sblock("root"): T.reads(A[0:SVF2, 0:SVF]) T.writes(A_t[0:SVF, 0:SVF2]) @@ -596,11 +596,11 @@ def desc(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (K, SVF2), dtype=in_dtype, offset_factor=1) C = T.match_buffer(c, (SVF2, SVF2), dtype="float32", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2]) T.writes(C[0:SVF2, 0:SVF2]) for m, n, k in T.grid(SVF2, SVF2, K): - with T.block("gemm"): + with T.sblock("gemm"): v_m, v_n, v_k = T.axis.remap("SSR", [m, n, k]) C[v_m, v_n] += T.Cast("float32", A[v_k, v_m]) * T.Cast("float32", B[v_k, v_n]) @@ -621,7 +621,7 @@ def impl(): ptrue = _create_ptrue_mask(in_dtype) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:SVF2, 0:SVF2], A[0:K, 0:SVF2], B[0:K, 0:SVF2]) T.writes(C[0:SVF2, 0:SVF2]) @@ -723,18 +723,18 @@ def get_sme_init_intrin(): @T.prim_func def desc(c: T.handle) -> None: C = T.match_buffer(c, (SVF2, SVF2), "float32", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C[0:SVF2, 0:SVF2]) for m, n in T.grid(SVF2, SVF2): - with T.block("init"): + with T.sblock("init"): v_m, v_n = T.axis.remap("SS", [m, n]) C[v_m, v_n] = T.float32(0) @T.prim_func def impl(c: T.handle) -> None: C = T.match_buffer(c, (SVF2, SVF2), "float32", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C[0:SVF2, 0:SVF2]) clear_all_tiles = T.int32(255) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 7b0c71583b1a..d063e9381eb3 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -164,12 +164,12 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: scope="warp", ) - with T.block("root"): + with T.sblock("root"): T.reads(shared[0:smem_tile_row, 0:smem_tile_col]) T.writes(warp[0:WARP_SIZE, 0:local_size]) for ax0, ax1 in T.grid(smem_tile_row, smem_tile_col): - with T.block("shared_warp"): + with T.sblock("shared_warp"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(shared[v0, v1]) @@ -199,7 +199,7 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: scope="warp", ) - with T.block("root"): + with T.sblock("root"): T.reads(shared[0:smem_tile_row, 0:smem_tile_col]) T.writes(warp[0:WARP_SIZE, 0:local_size]) for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): @@ -361,7 +361,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: scope="warp", ) - with T.block("root"): + with T.sblock("root"): T.reads( C[0:WARP_SIZE, 0:local_size_out], A[0:WARP_SIZE, 0:local_size], @@ -370,7 +370,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: T.writes(C[0:WARP_SIZE, 0:local_size_out]) for i, j, k in T.grid(M_DIM, N_DIM, k_dim): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) a_row_ind, a_col_ind = T.meta_var(swap_if_flag(i, k, a_transposed)) b_row_ind, b_col_ind = T.meta_var(swap_if_flag(k, j, b_transposed)) @@ -417,7 +417,7 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: scope="warp", ) - with T.block("root"): + with T.sblock("root"): T.reads( C[0:WARP_SIZE, 0:local_size_out], A[0:WARP_SIZE, 0:local_size], @@ -554,11 +554,11 @@ def get_mma_fill_intrin(dtype, local_size): def mma_fill_desc(a: T.handle) -> None: C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C_warp[0:WARP_SIZE, 0:local_size]) for i0, i1 in T.grid(M_DIM, N_DIM): - with T.block("C_warp"): + with T.sblock("C_warp"): i, j = T.axis.remap("SS", [i0, i1]) thread_id, local_id = T.meta_var(index_map(i, j)) T.reads() @@ -571,7 +571,7 @@ def mma_fill_impl(a: T.handle) -> None: a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C_warp[0:WARP_SIZE, 0:local_size]) @@ -601,11 +601,11 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope) - with T.block("root"): + with T.sblock("root"): T.reads(C_warp[0:WARP_SIZE, 0:local_size]) T.writes(C[0:M_DIM, 0:N_DIM]) for i0, i1 in T.grid(M_DIM, N_DIM): - with T.block("C_warp"): + with T.sblock("C_warp"): v0, v1 = T.axis.remap("SS", [i0, i1]) thread_id, local_id = T.meta_var(index_map(v0, v1)) T.reads(C_warp[thread_id, local_id]) @@ -626,7 +626,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] ) - with T.block("root"): + with T.sblock("root"): T.reads(C_warp[0:WARP_SIZE, 0:local_size]) T.writes(C[0:M_DIM, 0:N_DIM]) @@ -657,7 +657,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] ) - with T.block("root"): + with T.sblock("root"): T.reads(C_warp[0:WARP_SIZE, 0:local_size]) T.writes(C[0:M_DIM, 0:N_DIM]) @@ -842,11 +842,11 @@ def wmma_load_desc(a: T.handle, c: T.handle) -> None: offset_factor=offset_factor, scope=wmma_fragment_scope, ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:frag_m, 0:frag_n]) T.writes(C[0:frag_m, 0:frag_n]) for i, j in T.grid(frag_m, frag_n): - with T.block("load"): + with T.sblock("load"): vii, vjj = T.axis.remap("SS", [i, j]) C[vii, vjj] = A[vii, vjj] @@ -874,7 +874,7 @@ def wmma_load_impl(a: T.handle, c: T.handle) -> None: scope=wmma_fragment_scope, strides=[d1, d0], ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:frag_m, 0:frag_n]) T.writes(C[0:frag_m, 0:frag_n]) T.evaluate( @@ -911,11 +911,11 @@ def wmma_fill_desc(c: T.handle) -> None: offset_factor=offset_factor, scope="wmma.accumulator", ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C[0:m_dim, 0:n_dim]) for i, j in T.grid(m_dim, n_dim): - with T.block("init"): + with T.sblock("init"): vii, vjj = T.axis.remap("SS", [i, j]) C[vii, vjj] = zero @@ -932,7 +932,7 @@ def wmma_fill_impl(c: T.handle) -> None: scope="wmma.accumulator", strides=[d1, d0], ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C[0:m_dim, 0:n_dim]) T.evaluate( @@ -969,11 +969,11 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None: C = T.match_buffer( c, (m_dim, n_dim), dtype, align=64, offset_factor=offset_factor, scope=scope ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:m_dim, 0:n_dim]) T.writes(C[0:m_dim, 0:n_dim]) for i, j in T.grid(m_dim, n_dim): - with T.block("store"): + with T.sblock("store"): vii, vjj = T.axis.remap("SS", [i, j]) C[vii, vjj] = A[vii, vjj] @@ -1001,7 +1001,7 @@ def wmma_store_impl(a: T.handle, c: T.handle) -> None: scope=scope, strides=[s1, s0], ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:m_dim, 0:n_dim]) T.writes(C[0:m_dim, 0:n_dim]) T.evaluate( @@ -1069,11 +1069,11 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: scope="wmma.accumulator", ) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1]) T.writes(C[0:m_dim, 0:n_dim]) for i, j, k in T.grid(m_dim, n_dim, k_dim): - with T.block(""): + with T.sblock(""): vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) B_index_0, B_index_1 = T.meta_var(maybe_swap(vkk, vjj)) C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast( @@ -1117,7 +1117,7 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: strides=[c1, c0], ) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1]) T.writes(C[0:m_dim, 0:n_dim]) T.evaluate( @@ -1483,11 +1483,11 @@ def mma_init_desc(c: T.handle) -> None: dst = T.match_buffer( c, (m_dim, n_dim), dtype, align=64, offset_factor=1, scope="m16n8k8.matrixC" ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(dst[0:m_dim, 0:n_dim]) for i, j in T.grid(m_dim, n_dim): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) dst[vi, vj] = zero @@ -1497,7 +1497,7 @@ def mma_init_impl(c: T.handle) -> None: c, (m_dim, n_dim), dtype, align=64, offset_factor=1, scope="m16n8k8.matrixC" ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(dst[0:m_dim, 0:n_dim]) @@ -1538,11 +1538,11 @@ def mma_load_desc(a: T.handle, c: T.handle) -> None: c, (frag_m, frag_n), dtype, align=64, offset_factor=1, scope=mma_fragment_scope ) - with T.block("root"): + with T.sblock("root"): T.reads(src[0:frag_m, 0:frag_n]) T.writes(dst[0:frag_m, 0:frag_n]) for i, j in T.grid(frag_m, frag_n): - with T.block("root"): + with T.sblock("root"): vi, vj = T.axis.remap("SS", [i, j]) dst[vi, vj] = src[vi, vj] @@ -1571,7 +1571,7 @@ def mma_load_impl(a: T.handle, c: T.handle) -> None: strides=[d0, d1], ) - with T.block("root"): + with T.sblock("root"): T.reads(src[0:frag_m, 0:frag_n]) T.writes(dst[0:frag_m, 0:frag_n]) @@ -1621,11 +1621,11 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: c, (m_dim, n_dim), out_dtype, align=64, offset_factor=1, scope="m16n8k8.matrixC" ) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:B_shape_0, 0:B_shape_1]) T.writes(C[0:m_dim, 0:n_dim]) for i, j, k in T.grid(m_dim, n_dim, k_dim): - with T.block("m16n8k8_sync"): + with T.sblock("m16n8k8_sync"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) B_index_0, B_index_1 = T.meta_var(maybe_swap(vk, vj)) C[vi, vj] = C[vi, vj] + maybe_cast(A[vi, vk]) * maybe_cast( @@ -1668,7 +1668,7 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: strides=[c0, c1], ) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:B_shape_0, 0:B_shape_1]) T.writes(C[0:m_dim, 0:n_dim]) T.evaluate( @@ -1708,11 +1708,11 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: c, (m_dim, n_dim), dtype, align=64, offset_factor=1, scope="shared.dyn" ) - with T.block("root"): + with T.sblock("root"): T.reads(src[0:m_dim, 0:n_dim]) T.writes(dst[0:m_dim, 0:n_dim]) for i, j in T.grid(m_dim, n_dim): - with T.block("m16n8k8_store"): + with T.sblock("m16n8k8_store"): vi, vj = T.axis.remap("SS", [i, j]) dst[vi, vj] = src[vi, vj] diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py index db10422c8ef4..19e7646f73da 100644 --- a/python/tvm/tir/tensor_intrin/dot_product_common.py +++ b/python/tvm/tir/tensor_intrin/dot_product_common.py @@ -32,11 +32,11 @@ def dp4a_desc( B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"), C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0], A[0:4], B[0:4]) T.writes(C[0]) for i in range(0, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.remap("R", [i]) C[0] = C[0] + T.cast(A[vi], dtype_c) * T.cast(B[vi], dtype_c) @@ -46,7 +46,7 @@ def dp4a_impl( B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"), C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0], A[0:4], B[0:4]) T.writes(C[0]) diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 631d6b353240..7488ae7f3ae6 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -30,11 +30,11 @@ def generate_dma_load_intrin( def sync_dma_load_desc(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (size), dtype, offset_factor=1, scope="global") C = T.match_buffer(c, (size), dtype, offset_factor=1, scope="global.vtcm") - with T.block("root"): + with T.sblock("root"): T.reads(A[0:size]) T.writes(C[0:size]) for i in T.serial(size): - with T.block("load"): + with T.sblock("load"): vii = T.axis.remap("S", [i]) C[vii] = A[vii] @@ -42,7 +42,7 @@ def sync_dma_load_desc(a: T.handle, c: T.handle) -> None: def sync_dma_load_impl(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (size), dtype, offset_factor=1, scope="global") C = T.match_buffer(c, (size), dtype, offset_factor=1, scope="global.vtcm") - with T.block("root"): + with T.sblock("root"): T.reads(A[0:size]) T.writes(C[0:size]) T.evaluate( @@ -81,12 +81,12 @@ def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) for i in T.serial(0, 32): for k in T.serial(0, 4): - with T.block("update"): + with T.sblock("update"): vi, vk = T.axis.remap("SR", [i, k]) C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") @@ -95,7 +95,7 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) @@ -122,12 +122,12 @@ def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) for i in T.serial(0, 32): for k in T.serial(0, 4): - with T.block("update"): + with T.sblock("update"): vi, vk = T.axis.remap("SR", [i, k]) C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") @@ -136,7 +136,7 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) @@ -163,12 +163,12 @@ def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> No A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope) B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope) C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32], A[0:2], B[0:32, 0:2]) T.writes(C[0:32]) for i in T.serial(0, 32): for k in T.serial(0, 2): - with T.block("update"): + with T.sblock("update"): vi, vk = T.axis.remap("SR", [i, k]) C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") @@ -177,7 +177,7 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope) B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope) C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32], A[0:2], B[0:32, 0:2]) T.writes(C[0:32]) diff --git a/python/tvm/tir/tensor_intrin/metal.py b/python/tvm/tir/tensor_intrin/metal.py index be34a9e266c8..feb640da6409 100644 --- a/python/tvm/tir/tensor_intrin/metal.py +++ b/python/tvm/tir/tensor_intrin/metal.py @@ -42,11 +42,11 @@ def get_make_filled_simdgroup_matrix_intrin( @T.prim_func def desc(a: T.handle) -> None: A = T.match_buffer(a, (col, row), dtype, scope="metal.simdgroup", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(A[0:col, 0:row]) for i, j in T.grid(col, row): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = T.float32(0) @@ -56,7 +56,7 @@ def impl(a: T.handle) -> None: A = T.match_buffer( a, (col, row), dtype, scope="metal.simdgroup", strides=[d1, d0], offset_factor=1 ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(A[0:col, 0:row]) T.make_filled_simdgroup_matrix( @@ -85,11 +85,11 @@ def desc(a: T.handle, c: T.handle) -> None: C = T.match_buffer( c, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1 ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:col, 0:row]) T.writes(C[0:col, 0:row]) for i, j in T.grid(col, row): - with T.block("load"): + with T.sblock("load"): vii, vjj = T.axis.remap("SS", [i, j]) if transpose_matrix: # C[vii, vjj] = A[vjj, vii] @@ -118,7 +118,7 @@ def impl(a: T.handle, c: T.handle) -> None: strides=[d1, d0], offset_factor=1, ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:col, 0:row]) T.writes(C[0:col, 0:row]) T.simdgroup_load( @@ -149,11 +149,11 @@ def desc(a: T.handle, c: T.handle) -> None: a, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1 ) C = T.match_buffer(c, (col, row), dtype, align=align, scope=scope, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:col, 0:row]) T.writes(C[0:col, 0:row]) for i, j in T.grid(col, row): - with T.block("store"): + with T.sblock("store"): vii, vjj = T.axis.remap("SS", [i, j]) if transpose_matrix: C[vjj, vii] = A[vii, vjj] @@ -175,7 +175,7 @@ def impl(a: T.handle, c: T.handle) -> None: C = T.match_buffer( c, (col, row), dtype, align=align, scope=scope, strides=[d1, d0], offset_factor=1 ) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:col, 0:row]) T.writes(C[0:col, 0:row]) T.simdgroup_store( @@ -199,11 +199,11 @@ def desc(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (m_dim, k_dim), dtype, scope="metal.simdgroup", offset_factor=1) B = T.match_buffer(b, (k_dim, n_dim), dtype, scope="metal.simdgroup", offset_factor=1) C = T.match_buffer(c, (m_dim, n_dim), dtype, scope="metal.simdgroup", offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim]) T.writes(C[0:m_dim, 0:n_dim]) for i, j, k in T.grid(m_dim, n_dim, k_dim): - with T.block(""): + with T.sblock(""): vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) C[vii, vjj] += A[vii, vkk] * B[vkk, vjj] @@ -219,7 +219,7 @@ def impl(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer( c, (m_dim, n_dim), dtype, scope="metal.simdgroup", strides=[c1, c0], offset_factor=1 ) - with T.block("root"): + with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim]) T.writes(C[0:m_dim, 0:n_dim]) T.simdgroup_multiply_accumulate( diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py index e0782ada4cc1..a44279f95b19 100644 --- a/python/tvm/tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -74,12 +74,12 @@ def rvv_vec_dot_prod_desc( B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1), C: T.Buffer((n_lanes,), out_dtype, offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems]) T.writes(C[0:n_lanes]) for j in T.serial(0, n_lanes): for k in T.serial(0, n_elems): - with T.block("update"): + with T.sblock("update"): vj, vk = T.axis.remap("SR", [j, k]) C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype) @@ -106,7 +106,7 @@ def rvv_vec_dot_prod_impl( B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1), C: T.Buffer((n_lanes,), out_dtype, offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems]) T.writes(C[0:n_lanes]) @@ -118,7 +118,7 @@ def rvv_vec_dot_prod_impl( T.int64(n_elems)) for i in range(n_lanes): - with T.block("update"): + with T.sblock("update"): T.reads(B[i, 0:n_elems]) T.writes(C[i]) diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py index bfac2ca1d25b..b6cce551f95f 100644 --- a/python/tvm/tir/tensor_intrin/rocm.py +++ b/python/tvm/tir/tensor_intrin/rocm.py @@ -33,7 +33,7 @@ def sdot4( B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0], A[0:4], B[0:4]) T.writes(C[0]) @@ -125,11 +125,11 @@ def get_mma_fill_intrin(dtype, local_size): def mma_fill_desc(a: T.handle) -> None: C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C_warp[0:WARP_SIZE, 0:local_size]) for i0, i1 in T.grid(M_DIM, N_DIM): - with T.block("C_warp"): + with T.sblock("C_warp"): i, j = T.axis.remap("SS", [i0, i1]) thread_id, local_id = T.meta_var(index_map(i, j)) T.reads() @@ -142,7 +142,7 @@ def mma_fill_impl(a: T.handle) -> None: a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 ) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(C_warp[0:WARP_SIZE, 0:local_size]) tx = T.env_thread("threadIdx.x") @@ -212,12 +212,12 @@ def mfma_load_desc(reg_handle: T.handle, memory_handle: T.handle) -> None: reg_handle, (WARP_SIZE, local_size), dtype, offset_factor=1, scope="warp" ) - with T.block("root"): + with T.sblock("root"): T.reads(memory[0:row_dim, 0:col_dim]) T.writes(reg[0:WARP_SIZE, 0:local_size]) for ax0, ax1 in T.grid(row_dim, col_dim): - with T.block("memory_reg"): + with T.sblock("memory_reg"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(memory[v0, v1]) @@ -243,7 +243,7 @@ def mfma_load_impl(reg_handle: T.handle, memory_handle: T.handle) -> None: reg_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=1, scope="warp" ) - with T.block("root"): + with T.sblock("root"): T.reads(memory[0:row_dim, 0:col_dim]) T.writes(reg[0:WARP_SIZE, 0:local_size]) tx = T.env_thread("threadIdx.x") @@ -291,7 +291,7 @@ def mfma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads( C[0:WARP_SIZE, 0:local_size_out], A[0:WARP_SIZE, 0:local_size], @@ -300,7 +300,7 @@ def mfma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: T.writes(C[0:WARP_SIZE, 0:local_size_out]) for i, j, k in T.grid(M_DIM, N_DIM, k_dim): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j)) @@ -325,7 +325,7 @@ def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads( A[0:WARP_SIZE, 0:local_size], B[0:WARP_SIZE, 0:local_size], @@ -351,7 +351,7 @@ def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads( A[0:WARP_SIZE, 0:local_size], B[0:WARP_SIZE, 0:local_size], @@ -387,11 +387,11 @@ def mfma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope) - with T.block("root"): + with T.sblock("root"): T.reads(C_warp[0:WARP_SIZE, 0:local_size]) T.writes(C[0:M_DIM, 0:N_DIM]) for i0, i1 in T.grid(M_DIM, N_DIM): - with T.block("C_warp"): + with T.sblock("C_warp"): v0, v1 = T.axis.remap("SS", [i0, i1]) thread_id, local_id = T.meta_var(index_map(v0, v1)) T.reads(C_warp[thread_id, local_id]) @@ -410,7 +410,7 @@ def mfma_store_impl(a: T.handle, c: T.handle) -> None: c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] ) - with T.block("root"): + with T.sblock("root"): T.reads(C_warp[0:WARP_SIZE, 0:local_size]) T.writes(C[0:M_DIM, 0:N_DIM]) tx = T.env_thread("threadIdx.x") diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index 8f9518ce459f..42965fee1e1a 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -30,12 +30,12 @@ def dot_product_16x4_u8i8i32_desc( B: T.Buffer((16, 4), "int8", offset_factor=1), C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) for i in T.serial(0, 16): for k in T.serial(0, 4): - with T.block("update"): + with T.sblock("update"): vi, vk = T.axis.remap("SR", [i, k]) C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") @@ -46,7 +46,7 @@ def dot_product_16x4_u8i8i32_vnni( B: T.Buffer((16, 4), "int8", offset_factor=1), C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) @@ -72,7 +72,7 @@ def dot_product_16x4_u8i8i32_avx512( B: T.Buffer((16, 4), "int8", offset_factor=1), C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 88cf4720d3a6..86d79dc6badd 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -834,7 +834,7 @@ def CompactBufferAllocation(is_strict: bool = True): .. code-block:: python for i in range(0, 16): - with T.block(): + with T.sblock(): B = T.alloc_buffer(16, 16) for j in range(0, 16): B[i, j] = A[i, j] + 1 @@ -849,7 +849,7 @@ def CompactBufferAllocation(is_strict: bool = True): .. code-block:: python for i in range(0, 16): - with T.block(): + with T.sblock(): B = T.alloc_buffer(1, 16) for j in range(0, 16): B[0, j] = A[i, j] + 1 diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 59b0b0546dab..8fb6dba8764a 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -67,7 +67,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { return StmtExprMutator::VisitStmt_(op); } -Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const SBlockNode* op) { for (const auto& iter_var : op->iter_vars) { analyzer_->Bind(iter_var->var, iter_var->dom); iter_vars_.Set(iter_var->var, iter_var->dom); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 28f8e600d38e..5b5ac7e6cd2a 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -51,7 +51,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { // override functions that need to populate the context information. tir::Stmt VisitStmt_(const tir::ForNode* op) override; - tir::Stmt VisitStmt_(const tir::BlockNode* op) override; + tir::Stmt VisitStmt_(const tir::SBlockNode* op) override; tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override; tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index dba4567f88ec..88eff9fc2c42 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -36,7 +36,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { StmtExprVisitor::VisitStmt_(op); } -void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) { +void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) { for (const auto& iter_var : op->iter_vars) { analyzer_.Bind(iter_var->var, iter_var->dom); } diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 416b2af196bd..cd2b9bfdec26 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -40,7 +40,7 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { using StmtExprVisitor::VisitStmt_; void VisitStmt_(const tir::ForNode* op); - void VisitStmt_(const tir::BlockNode* op); + void VisitStmt_(const tir::SBlockNode* op); void VisitStmt_(const tir::LetStmtNode* op); void VisitStmt_(const tir::IfThenElseNode* op); void VisitStmt_(const tir::AttrStmtNode* op); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 9072ccf62a94..aed8e21ffa42 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -1209,7 +1209,7 @@ class WorkloadEmbeddingExtractor : private StmtVisitor { } private: - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { StmtVisitor::VisitStmt_(block); std::string name = block->name_hint; std::for_each(name.begin(), name.end(), [](char& c) { c = ::tolower(c); }); @@ -1327,7 +1327,7 @@ class PerStoreFeatureCollector : private StmtVisitor { feature.group5 = std::make_unique(loop_nest_); } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { StmtVisitor::VisitStmt_(block); for (const Buffer& buffer : block->alloc_buffers) { HandleBufferAlloc(buffer); diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 8eb1f46b0b22..ce25417dcd58 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -56,7 +56,7 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = tir::FindAnchorBlock(mod); if (anchor_block) { - return ffi::StructuralHash::Hash(ffi::GetRef(anchor_block), + return ffi::StructuralHash::Hash(ffi::GetRef(anchor_block), /*map_free_vars=*/false, /*skip_tensor_content=*/true); } @@ -66,8 +66,8 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { auto anchor_block_lhs = tir::FindAnchorBlock(lhs); auto anchor_block_rhs = tir::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return tvm::ffi::StructuralEqual::Equal(ffi::GetRef(anchor_block_lhs), - ffi::GetRef(anchor_block_rhs), + return tvm::ffi::StructuralEqual::Equal(ffi::GetRef(anchor_block_lhs), + ffi::GetRef(anchor_block_rhs), /*map_free_vars=*/false, /*skip_tensor_content=*/true); } diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 4ad979648aca..02aa6d898e27 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -92,7 +92,7 @@ std::vector MutateComputeLocationNode::Fin if (inst->kind.same_as(inst_sample_compute_location)) { // Step 1. Extract the instruction input and the old decision. ICHECK_EQ(inputs.size(), 1); - tir::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); + tir::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); int old_decision = Downcast(decision)->value; // Step 2. Collect all the compute_at locations. diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 66266dd2a539..fa056b27444c 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -60,13 +60,13 @@ Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { * \param inst The instruction to be checked * \return The output of the instruction Get-Block */ -const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { - static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock"); - if (!inst->kind.same_as(inst_get_block)) { +const SBlockRVNode* GetInstGetSBlockOutput(const Instruction& inst) { + static const InstructionKind& inst_get_sblock = InstructionKind::Get("GetSBlock"); + if (!inst->kind.same_as(inst_get_sblock)) { return nullptr; } ICHECK_EQ(inst->outputs.size(), 1); - const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode); + const SBlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], SBlockRVNode); return block; } @@ -82,13 +82,13 @@ std::vector> AnalyzeParallel(const ScheduleState& self, const ffi::String& block_name, const ffi::String& func_name, int64_t limit) { ffi::Array block_srefs = - tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); + tir::GetSBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); ICHECK_EQ(block_srefs.size(), 1); - const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]); - ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_srefs[0]); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(ffi::GetRef(block)); std::vector> results; results.reserve(info.realizes.size()); - for (const BlockRealize& realize : info.realizes) { + for (const SBlockRealize& realize : info.realizes) { // Step 1. Extract static loop extents for spatial loops std::vector loop_extents; const ForNode* loop = nullptr; @@ -217,18 +217,18 @@ struct MutateParallelNode::Candidate { */ bool FindParallelDecision(const Trace& trace, TRandState* rand_state, MutateParallelNode::Candidate* candidate) { - using tir::BlockRVNode; using tir::InstructionNode; - std::unordered_map get_block_insts; + using tir::SBlockRVNode; + std::unordered_map get_sblock_insts; std::vector ann_insts; - get_block_insts.reserve(trace->insts.size()); + get_sblock_insts.reserve(trace->insts.size()); ann_insts.reserve(trace->insts.size()); for (const Instruction& inst : trace->insts) { if (tir::IsAnnotateWithParallel(inst)) { ann_insts.push_back(inst.get()); } - if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) { - get_block_insts[block_rv] = inst.get(); + if (const SBlockRVNode* block_rv = tir::GetInstGetSBlockOutput(inst)) { + get_sblock_insts[block_rv] = inst.get(); } } int n_ann_insts = ann_insts.size(); @@ -237,13 +237,13 @@ bool FindParallelDecision(const Trace& trace, TRandState* rand_state, } const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; ICHECK_EQ(ann_inst->inputs.size(), 2); - const InstructionNode* get_block_inst = - get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); - ICHECK_EQ(get_block_inst->attrs.size(), 2); + const InstructionNode* get_sblock_inst = + get_sblock_insts.at(Downcast(ann_inst->inputs[0]).get()); + ICHECK_EQ(get_sblock_inst->attrs.size(), 2); candidate->inst = ffi::GetRef(ann_inst); candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; - candidate->block_name = Downcast(get_block_inst->attrs[0]); - candidate->func_name = Downcast(get_block_inst->attrs[1]); + candidate->block_name = Downcast(get_sblock_inst->attrs[0]); + candidate->func_name = Downcast(get_sblock_inst->attrs[1]); return true; } diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index ae7b693efd94..6b2fa17bd20b 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -52,8 +52,8 @@ ffi::Optional ParseThreadBinding(const Schedule& sch, const Instruction * \param vector_lane The number of vector lane in vectorized cooperative fetching * \return std::nullopt if parsing fails; Otherwise, the annotated block */ -ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, - int64_t* vector_lane) { +ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, + int64_t* vector_lane) { static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); if (!inst->kind.same_as(inst_kind_annotate)) { return std::nullopt; @@ -65,7 +65,7 @@ ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& ins return std::nullopt; } *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; - return Downcast(inst->inputs[0]); + return Downcast(inst->inputs[0]); } /*! @@ -85,7 +85,7 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { return ann_key == attr::warp_execution; } -size_t GetMaxUsedDtypeBytes(Block block) { +size_t GetMaxUsedDtypeBytes(SBlock block) { size_t max_bytes = 1; static auto q_multiply_shift_per_axis = Op::Get("tir.q_multiply_shift_per_axis"); static auto q_multiply_shift = Op::Get("tir.q_multiply_shift"); @@ -168,7 +168,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { thread_extent_x = thread_warp_size_; continue; } - ffi::Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); + ffi::Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); if (!opt_block_rv.defined()) { continue; } diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 17acdcc9bf2f..88fe0419b5ca 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -34,7 +34,7 @@ class BufferReadPosCollector : public StmtExprVisitor { public: explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {} - const std::pair& GetBufferLocation() const { return buffer_loc_; } + const std::pair& GetBufferLocation() const { return buffer_loc_; } const ffi::Optional GetBufferIndexMap() const { return buffer_index_map_; } @@ -45,8 +45,8 @@ class BufferReadPosCollector : public StmtExprVisitor { loop_stack_.pop_back(); } - void VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize outer_block_realize = ffi::GetRef(op); + void VisitStmt_(const SBlockRealizeNode* op) final { + SBlockRealize outer_block_realize = ffi::GetRef(op); std::swap(outer_block_realize, cur_realize_); StmtVisitor::VisitStmt_(op); std::swap(cur_realize_, outer_block_realize); @@ -78,7 +78,7 @@ class BufferReadPosCollector : public StmtExprVisitor { } } - static int GetReadBufferIndex(const Block& block, const Buffer& buffer) { + static int GetReadBufferIndex(const SBlock& block, const Buffer& buffer) { for (size_t i = 0; i < block->reads.size(); i++) { if (block->reads[i]->buffer.same_as(buffer)) { return i; @@ -91,7 +91,7 @@ class BufferReadPosCollector : public StmtExprVisitor { /*! \brief The buffer of interest. */ const BufferNode* buffer_; /*! \brief The block that consumes the buffer and the corresponding read index. */ - std::pair buffer_loc_; + std::pair buffer_loc_; /*! \brief The proposed IndexMap. */ ffi::Optional buffer_index_map_; @@ -100,12 +100,12 @@ class BufferReadPosCollector : public StmtExprVisitor { /*! \brief Arithmetic analyzer. */ arith::Analyzer analyzer_; /*! \brief Current BlockRealize scope, used in recursive visit */ - BlockRealize cur_realize_; + SBlockRealize cur_realize_; }; class LayoutFreeBufferCollector : public StmtVisitor { public: - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { StmtVisitor::VisitStmt_(block); if (auto ann = block->annotations.Get("layout_free_placeholders")) { for (Buffer buffer : Downcast>(ann.value())) { @@ -138,7 +138,7 @@ ffi::Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { return layout_free_buffers; } -std::optional> GetSuggestedIndexMap( +std::optional> GetSuggestedIndexMap( Buffer buffer, const PrimFuncNode* prim_func) { BufferReadPosCollector collector(buffer); collector(prim_func->body); @@ -160,7 +160,7 @@ std::vector GetCacheReadChain(const Buffer& buf, const PrimFuncNode public: explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {} - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_. if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 && op->reads[0]->buffer.get() == cur_buffer_) { @@ -183,8 +183,8 @@ std::vector GetCacheReadChain(const Buffer& buf, const PrimFuncNode bool RewriteLayout(const Schedule& sch) { std::vector> results; - auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) { - BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); + auto add_layout_rewrite_block = [&sch](SBlockRV consumer_block_rv, int buffer_index) { + SBlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, true); }; @@ -205,7 +205,7 @@ bool RewriteLayout(const Schedule& sch) { if (tup_opt == std::nullopt) continue; auto [anchor_block, buffer_index, index_map] = *tup_opt; - auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name); + auto anchor_block_rv = sch->GetSBlock(anchor_block->name_hint, func_name); add_layout_rewrite_block(anchor_block_rv, buffer_index); sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map, std::nullopt); @@ -213,20 +213,20 @@ bool RewriteLayout(const Schedule& sch) { // When the layout-free buffer is consumed by cache_read, we need to find the index map // for a cache-read buffer that is directly consumed by an anchor op. The last buffer // in cache_read_chain corresponds to that buffer. - Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name)); + SBlock cache_read_block = sch->Get(sch->GetSBlock(cache_read_chain.back(), func_name)); ICHECK_EQ(cache_read_block->writes.size(), 1); auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func); if (tup_opt == std::nullopt) continue; auto [anchor_block, buffer_index, index_map] = *tup_opt; // Transform the layout of the last cache-read buffer. - sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index, + sch->TransformLayout(sch->GetSBlock(anchor_block->name_hint, func_name), buffer_index, BufferIndexType::kRead, index_map, std::nullopt); // Propagate the layout transformation over cache_read_chain, starting from // the next-to-last cache-read buffer. for (int i = static_cast(cache_read_chain.size()) - 1; i >= 0; --i) { - BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name); + SBlockRV cache_read_block_rv = sch->GetSBlock(cache_read_chain[i], func_name); if (i == 0) { // Before the first cache_read that consumes the layout-free buffer, insert // a layout-rewrite block. Another cache-read buffer is added, and its layout is diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 5950ef742d49..b2cbf2701043 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -102,7 +102,7 @@ struct ParsedAnnotation { int num_vectorize_loops; }; -bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { +bool ParseAnnotation(const SBlock& block, ParsedAnnotation* parsed) { bool found = false; *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; for (const auto& ann : block->annotations) { @@ -131,7 +131,8 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { return found; } -void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { +void RemoveParsedAnn(const Schedule& sch, const SBlockRV& block_rv, + const ParsedAnnotation& parsed) { if (parsed.max_parallel_extent != -1) { sch->Unannotate(block_rv, attr::meta_schedule_parallel); } @@ -173,7 +174,7 @@ int CalculateNumRewritableLoops(const ffi::Array& loop_srefs, return rw_loops_num; } -void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, +void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv, const ffi::Array& loop_rvs, ParsedAnnotation* parsed) { StmtSRef block_sref = sch->GetSRef(block_rv); if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { @@ -197,7 +198,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, } } // check the maximal number of axes that are vectorizable (contiguous memory access) - BlockRealize realize = GetBlockRealize(sch->state(), block_sref); + SBlockRealize realize = GetSBlockRealize(sch->state(), block_sref); ffi::Array buffer_access(realize->block->reads); buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), realize->block->writes.end()); @@ -337,17 +338,17 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, } } -bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { +bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, SBlockRV* root_rv) { IRModule mod = sch->mod(); for (const auto& kv : mod->functions) { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { - const BlockRealizeNode* block_realize = prim_func->body.as(); + const SBlockRealizeNode* block_realize = prim_func->body.as(); if (block_realize != nullptr) { - Block block = block_realize->block; + SBlock block = block_realize->block; if (ParseAnnotation(block, parsed)) { - *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); + *root_rv = sch->GetSBlock(block->name_hint, g_var->name_hint); RemoveParsedAnn(sch, *root_rv, *parsed); return true; } @@ -392,7 +393,7 @@ void RewriteVectorize(const Schedule& sch, size_t n, ffi::Array* loop_rv } } -void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const BlockRV& block, +void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const SBlockRV& block, const LoopRV& loop) { // Do not unroll for pure spatial block. if (max_step <= 0 || IsSpatial(sch->GetSRef(block))) { @@ -415,9 +416,9 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { bool Apply(const Schedule& sch) final { tir::ParsedAnnotation parsed_root; - tir::BlockRV root_rv{ffi::UnsafeInit()}; + tir::SBlockRV root_rv{ffi::UnsafeInit()}; while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { - for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { + for (tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { ffi::Array loop_rvs = sch->GetLoops(block_rv); if (loop_rvs.empty()) { continue; diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index fffef8ba6856..f65ca90e7783 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -35,7 +35,7 @@ struct ReductionBlockFinder : private StmtVisitor { if (const auto* prim_func = base_func.as()) { ReductionBlockFinder finder; finder(prim_func->body); - for (const BlockNode* block : finder.results_) { + for (const SBlockNode* block : finder.results_) { results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); } } @@ -52,19 +52,19 @@ struct ReductionBlockFinder : private StmtVisitor { StmtVisitor::VisitStmt_(loop); } - void VisitStmt_(const BlockRealizeNode* realize) final { + void VisitStmt_(const SBlockRealizeNode* realize) final { if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { results_.push_back(realize->block.get()); } StmtVisitor::VisitStmt_(realize); } - bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { + bool AllReductionIterVarAreUnbound(const SBlockRealizeNode* realize) const { if (thread_bound_loop_vars_.empty()) { return true; } auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; - const BlockNode* block = realize->block.get(); + const SBlockNode* block = realize->block.get(); ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); int n = block->iter_vars.size(); for (int i = 0; i < n; ++i) { @@ -80,7 +80,7 @@ struct ReductionBlockFinder : private StmtVisitor { } /*! \brief The results of the collection */ - std::vector results_; + std::vector results_; /*! \brief Loop variables that are bound to threads */ std::unordered_set thread_bound_loop_vars_; }; @@ -142,9 +142,9 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { if (decompose_point == -1) { continue; } - tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + tir::SBlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); ffi::Array loop_rvs = sch->GetLoops(block_rv); - tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + tir::SBlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); // Rewrite auto tensorization related annotations if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize) diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 473731b5a7b5..e3490af29072 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -26,21 +26,21 @@ namespace tvm { namespace meta_schedule { -using tir::BlockRV; using tir::LoopRV; +using tir::SBlockRV; void CollectTensorizationJobs( const tir::Schedule& sch, const ffi::String& func_name, const tir::PrimFuncNode* func, bool vectorize_init_loop, - std::vector>>* jobs) { + std::vector>>* jobs) { tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { - if (const auto* block = obj.as()) { + if (const auto* block = obj.as()) { tir::StmtSRef block_sref = sch->GetSRef(block); - std::string block_name = block_sref->StmtAs()->name_hint; + std::string block_name = block_sref->StmtAs()->name_hint; if (ffi::Optional intrin_name = tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { if (intrin_name.value() != "") { - jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { + jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::SBlockRV block) { try { sch->Tensorize(block, intrin_name.value()); } catch (const std::exception& e) { @@ -48,8 +48,8 @@ void CollectTensorizationJobs( } }); } else if (block_name.find("init") && vectorize_init_loop) { - jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) { - ffi::Array child_blocks = sch->GetChildBlocks(block); + jobs->emplace_back(block_name, func_name, [sch](tir::SBlockRV block) { + ffi::Array child_blocks = sch->GetChildBlocks(block); ICHECK(child_blocks.size() == 1); ffi::Array init_loops = sch->GetLoops(child_blocks[0]); ICHECK(init_loops.size() == 1); @@ -85,7 +85,7 @@ class RewriteTensorizeNode : public PostprocNode { bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { // The rewriting jobs, 3-tuple (block_name, func_name, job_func) - std::vector>> jobs; + std::vector>> jobs; for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; @@ -97,7 +97,7 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { const ffi::String& block_name = std::get<0>(job); const ffi::String& func_name = std::get<1>(job); const auto& job_func = std::get<2>(job); - BlockRV block = sch->GetBlock(block_name, func_name); + SBlockRV block = sch->GetSBlock(block_name, func_name); sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); job_func(block); } diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 98e3db2522f1..08580830965b 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -34,7 +34,7 @@ class UnboundBlockFinder : private StmtVisitor { BaseFunc base_func = kv.second; if (const auto* prim_func = base_func.as()) { finder.global_var_name_ = g_var->name_hint; - finder(Downcast(prim_func->body)->block->body); + finder(Downcast(prim_func->body)->block->body); } } return std::move(finder.blocks_); @@ -58,7 +58,7 @@ class UnboundBlockFinder : private StmtVisitor { } } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_); } @@ -119,9 +119,9 @@ class RewriteUnboundBlockNode : public PostprocNode { }; bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { - using tir::BlockRV; using tir::ExprRV; using tir::LoopRV; + using tir::SBlockRV; using tir::Schedule; ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { @@ -132,7 +132,7 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { for (const auto& kv : unbound_blocks) { tir::StmtSRef block_sref = kv.first; ffi::String global_var_name = kv.second; - BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + SBlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); } return true; diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 04a9cf2ea79b..bdcb1af1fe41 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -68,7 +68,7 @@ class ThreadExtentChecker : private StmtVisitor { StmtVisitor::VisitStmt_(loop); } - void VisitStmt_(const BlockNode* block) { + void VisitStmt_(const SBlockNode* block) { int old_thread_idx_x = thread_idx_x; if (block->annotations.count(attr::warp_execution)) { thread_idx_x = thread_warp_size_; diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index c3fd12e282b3..dfa5a3969118 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -26,7 +26,7 @@ namespace meta_schedule { using namespace tvm::tir; -static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, +static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::SBlockRV block, std::vector tiled, std::vector unrolled) { using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); @@ -64,9 +64,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> ffi::Array { - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + [](Schedule sch, SBlockRV data_pack) -> ffi::Array { + SBlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + SBlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), /*preserve_unit_loops=*/true); @@ -75,15 +75,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { return {sch}; }) .def("meta_schedule.cpu.conv2d_nhwc_winograd_inverse", - [](Schedule sch, BlockRV block) -> ffi::Array { + [](Schedule sch, SBlockRV block) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); return {sch}; }) .def("meta_schedule.cpu.conv2d_nchw_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> ffi::Array { - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + [](Schedule sch, SBlockRV data_pack) -> ffi::Array { + SBlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + SBlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), /*preserve_unit_loops=*/true); @@ -92,7 +92,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { return {sch}; }) .def("meta_schedule.cpu.conv2d_nchw_winograd_inverse", - [](Schedule sch, BlockRV block) -> ffi::Array { + [](Schedule sch, SBlockRV block) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); return {sch}; diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index 2a042553d6b9..d80fefc6cc5d 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -84,7 +84,7 @@ ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_thread } } -void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // +void BindBlockThreadIdx(tir::Schedule sch, tir::SBlockRV block_rv, // int64_t max_threadblocks, int64_t max_threads_per_block, std::function get_factor) { using namespace tvm::tir; diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 74a70da58b36..62d8c767e293 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -29,7 +29,7 @@ namespace meta_schedule { using namespace tvm::tir; -static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, +static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::SBlockRV block, std::vector tiled, std::vector unrolled) { // This method is used for NHWC layout only. Will likely be refactored into a more schedule using namespace tvm::tir; @@ -68,12 +68,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> ffi::Array { - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + [](Schedule sch, SBlockRV data_pack) -> ffi::Array { + SBlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + SBlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ffi::Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); { - BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); + SBlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true); } { @@ -92,7 +92,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { return {sch}; }) .def("meta_schedule.cuda.conv2d_nhwc_winograd_inverse", - [](Schedule sch, BlockRV inverse) -> ffi::Array { + [](Schedule sch, SBlockRV inverse) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, inverse); ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); int64_t max_threadblocks = 256; @@ -104,11 +104,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { return {sch}; }) .def("meta_schedule.cuda.conv2d_nchw_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> ffi::Array { + [](Schedule sch, SBlockRV data_pack) -> ffi::Array { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; - BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); - BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); + SBlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); + SBlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); LoopRV outer{ffi::UnsafeInit()}; { ffi::Array loops = sch->GetLoops(data_pack); @@ -123,7 +123,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .back(); } { - BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); + SBlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); sch->ReverseComputeAt(data_pack_local, outer, /*preserve_unit_loops=*/true); } { @@ -134,14 +134,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { return {sch}; }) .def("meta_schedule.cuda.conv2d_nchw_winograd_inverse", - [](Schedule sch, BlockRV inverse) -> ffi::Array { + [](Schedule sch, SBlockRV inverse) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, inverse); // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] int64_t tile_size = Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; LoopRV outer{ffi::UnsafeInit()}; { - BlockRV output = sch->GetConsumers(inverse)[0]; + SBlockRV output = sch->GetConsumers(inverse)[0]; ffi::Array nchw = sch->GetLoops(output); ICHECK_EQ(nchw.size(), 4); ffi::Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); diff --git a/src/meta_schedule/schedule/generic/winograd.cc b/src/meta_schedule/schedule/generic/winograd.cc index fe41e1e686f1..a3c75f33cb53 100644 --- a/src/meta_schedule/schedule/generic/winograd.cc +++ b/src/meta_schedule/schedule/generic/winograd.cc @@ -28,10 +28,10 @@ using namespace tvm::tir; * If there is a constant winograd transform matrix, inline it. * \return The only producer block. */ -BlockRV GetWinogradProducerAndInlineConst(Schedule sch, BlockRV block) { - ffi::Array producers = sch->GetProducers(block); - ffi::Array results; - for (const BlockRV& producer : producers) { +SBlockRV GetWinogradProducerAndInlineConst(Schedule sch, SBlockRV block) { + ffi::Array producers = sch->GetProducers(block); + ffi::Array results; + for (const SBlockRV& producer : producers) { if (sch->Get(producer)->reads.empty()) { sch->ComputeInline(producer); } else { diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index fad3279eb792..2b730b0138a2 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -36,7 +36,7 @@ class AddRFactorNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv); // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { @@ -78,7 +78,7 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, } ffi::Array AddRFactorNode::Apply(const tir::Schedule& sch, - const tir::BlockRV& block_rv) { + const tir::SBlockRV& block_rv) { tir::StmtSRef block_sref = sch->GetSRef(block_rv); if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, max_parallel_basic_)) { @@ -106,7 +106,7 @@ ffi::Array AddRFactorNode::Apply(const tir::Schedule& sch, tir::Schedule sch_tmp = sch->Copy(); sch_tmp->Seed(sch->ForkSeed()); try { - const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); + const tir::SBlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); ffi::Array axes = sch_tmp->GetLoops(block_rf); ICHECK_GT(axes.size(), num_spatial_loops); diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 927ce3656c2f..bdfd9b525690 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -36,7 +36,7 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { CHECK(this->target_.defined()) << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target."; ffi::Array keys = this->target_.value()->keys; diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 1ab276c5bec7..2fbf013e82da 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -40,7 +40,7 @@ class AutoBindNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { @@ -64,7 +64,7 @@ class AutoBindNode : public ScheduleRuleNode { }; ffi::Array AutoBindNode::Apply(const tir::Schedule& sch, - const tir::BlockRV& block_rv) { + const tir::SBlockRV& block_rv) { ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = MakeFactorSampler(sch, this->thread_extents_); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 3d5fc8798c13..5c065e6b4738 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -38,7 +38,7 @@ bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sr const StmtSRefNode* sref = block_sref.get(); for (; sref->parent != nullptr; sref = sref->parent) { } - ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance()); + ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance()); return IsSpatialPrimFunc(ffi::GetRef(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr))); } @@ -46,13 +46,13 @@ bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sr class AutoInlineNode : public ScheduleRuleNode { public: /*! \brief Checks if the specific block should be inlined */ - inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv); + inline InlineType CheckInline(const tir::Schedule& sch, const tir::SBlockRV& block_rv); // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { InlineType inline_type = CheckInline(sch, block_rv); if (inline_type == InlineType::kInlineIntoConsumer) { sch->ComputeInline(block_rv); @@ -99,13 +99,13 @@ class AutoInlineNode : public ScheduleRuleNode { }; inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, - const tir::BlockRV& block_rv) { + const tir::SBlockRV& block_rv) { using namespace tvm::tir; StmtSRef block_sref = sch->GetSRef(block_rv); bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref); ScheduleState state = sch->state(); - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - BlockRealize realize = GetBlockRealize(state, block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + SBlockRealize realize = GetSBlockRealize(state, block_sref); // Cond 1. The block has only one write buffer if (block->writes.size() != 1) { return InlineType::kNoInline; @@ -205,7 +205,7 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { public: void InitializeWithTuneContext(const TuneContext& context) final {} - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { // Look for a block of the form // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) { // reads([]) diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 17e9552dcb60..1d70f21199a4 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -49,7 +49,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { // Step 0. Check the conditions of this rule. if (max_threads_per_block == -1 || warp_size == -1) { return {sch}; @@ -68,10 +68,10 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // block to its consumers. We want to fuse as much as possible because it results in // significantly faster schedule. // `target_loop` is the loop position where the input block will be computed at. - // `target_block` is the consumer block that we want to compute-at the input block to. + // `target_sblock` is the consumer block that we want to compute-at the input block to. // `tgt_block_innermost_loop` is the innermost loop outside the target block. - auto [fusible, target_loop, target_block, tgt_block_innermost_loop] = + auto [fusible, target_loop, target_sblock, tgt_block_innermost_loop] = GetComputeTargetLoopAndBlock(tmp_sch, block_rv); // Step 3. Try block fusion. @@ -79,15 +79,16 @@ class CrossThreadReductionNode : public ScheduleRuleNode { ffi::Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { - ICHECK(target_block.defined()); + ICHECK(target_sblock.defined()); ICHECK(target_loop.defined()); // Step 3.1. - // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should first - // bound the innermost outer loop of `target_block` to threadIdx. Possibly we need to split + // - If the outer loops of `target_sblock` haven't been bound to "threadIdx.x", we should + // first + // bound the innermost outer loop of `target_sblock` to threadIdx. Possibly we need to split // the loop before binding. // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. - if (!InThreadScope(tmp_sch, target_block)) { + if (!InThreadScope(tmp_sch, target_sblock)) { const ffi::Array& split_res = tmp_sch->Split(tgt_block_innermost_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); @@ -130,7 +131,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \param block The block to be checked * \return A boolean indicating whether the block is in thread scope. */ - bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) { + bool InThreadScope(const tir::Schedule& sch, const tir::SBlockRV& block) { const ffi::Array& axes = sch->GetLoops(block); for (const tir::LoopRV& loop_rv : axes) { const tir::For& loop = sch->Get(loop_rv); @@ -193,23 +194,23 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * 3. the first block under the target loop when fusible, or a null block random variable; * 4. the innermost loop outside the target block when fusible, or a null block random variable. */ - std::tuple GetComputeTargetLoopAndBlock( - const tir::Schedule& sch, const tir::BlockRV& block_rv) { + std::tuple GetComputeTargetLoopAndBlock( + const tir::Schedule& sch, const tir::SBlockRV& block_rv) { // Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing // a tuple reduction, fusion is temporarily not supported. if (sch->Get(block_rv)->writes.size() != 1) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, - tir::LoopRV{ffi::UnsafeInit()}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, + tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); } // Step 1. Get all the consumers of the input block. - ffi::Array consumers = sch->GetConsumers(block_rv); + ffi::Array consumers = sch->GetConsumers(block_rv); // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is // not fusible. if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, - tir::LoopRV{ffi::UnsafeInit()}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, + tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); } // Step 3. Calculate the lowest common ancestor of all the consumers. @@ -219,10 +220,10 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // fusible; // - If the lowest common ancestor is a loop, the target block is also the first consumer. const tir::StmtSRef& lca_sref = - tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); - if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, - tir::LoopRV{ffi::UnsafeInit()}); + tir::GetSRefLowestCommonAncestor(tir::SBlockRVs2StmtSRefs(sch, consumers)); + if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, + tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); } // Step 4. Get the outer loops of the target block, and get the compute-at position index. @@ -231,8 +232,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 5. A negative position index means not fusible, and vice-versa. if (pos < 0) { - return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, - tir::LoopRV{ffi::UnsafeInit()}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, + tir::SBlockRV{ffi::UnsafeInit()}, tir::LoopRV{ffi::UnsafeInit()}); } else { return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index ea78c4f6e3d3..c1002b0ce2c0 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -31,7 +31,7 @@ namespace tvm { namespace tir { std::vector GetReadBufferNDims(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); const BufferNode* write_buffer = block->writes[0]->buffer.get(); int n = block->reads.size(); std::vector results(n, -1); @@ -50,14 +50,14 @@ std::vector GetReadBufferNDims(const StmtSRef& block_sref) { namespace tvm { namespace meta_schedule { -using tir::BlockRV; using tir::IterVarType; using tir::LoopRV; +using tir::SBlockRV; using tir::Schedule; TVM_FFI_STATIC_INIT_BLOCK() { MultiLevelTilingNode::RegisterReflection(); } -State::State(tir::Schedule sch, tir::BlockRV block_rv, ffi::Array> tiles) { +State::State(tir::Schedule sch, tir::SBlockRV block_rv, ffi::Array> tiles) { ObjectPtr node = ffi::make_object(); node->sch = std::move(sch); node->block_rv = std::move(block_rv); @@ -103,7 +103,7 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) } // Entry of the mega rule; Inherited from ScheduleRuleNode -ffi::Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { +ffi::Array MultiLevelTilingNode::Apply(const Schedule& sch, const SBlockRV& block_rv) { if ((filter_fn_ && filter_fn_.value()(sch, sch->GetSRef(block_rv)).cast()) || NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); @@ -149,7 +149,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { std::vector results; if (req == ReuseType::kMayReuse) { // Case 1. If the write cache is already there, we don't need to add another. - ffi::Array consumer_rvs = state->sch->GetConsumers(state->block_rv); + ffi::Array consumer_rvs = state->sch->GetConsumers(state->block_rv); if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) { for (int level : levels) { State new_state = state->Copy(); @@ -168,7 +168,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { } // Case 3. Add one write cache - BlockRV write_cache = + SBlockRV write_cache = state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0, /*storage_scope=*/config.scope); state->write_reuse.emplace(0, write_cache); @@ -182,7 +182,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { } std::pair, ffi::Array> MultiLevelTilingNode::SplitLoop( - const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { + const Schedule& sch, SBlockRV block, LoopRV loop, int n_tiles) const { ffi::Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, @@ -195,10 +195,10 @@ std::pair, ffi::Array> MultiLevelTilingNode std::vector MultiLevelTilingNode::TileLoopNest(State state, int tile_inner_most_space_loop_num) const { Schedule& sch = state->sch; - const BlockRV& block_rv = state->block_rv; + const SBlockRV& block_rv = state->block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types ffi::Array loops = sch->GetLoops(block_rv); - std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv)); + std::vector iter_types = GetSBlockVarTypes(sch->GetSRef(state->block_rv)); ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; @@ -290,7 +290,7 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { return {std::move(state)}; } ICHECK(config.req != ReuseType::kMayReuse); - const BlockRV& block_rv = state->block_rv; + const SBlockRV& block_rv = state->block_rv; std::vector results; results.reserve(config.levels.size()); for (int level : config.levels) { @@ -305,7 +305,7 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { continue; } // Do cache_read - BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope, {block_rv}); + SBlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope, {block_rv}); // Insert cache_read block to the proper place sch->ComputeAt(cache_read_block, loop_rv, true); // Fuse the iterators of the cache_read @@ -358,9 +358,9 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { } void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, - const tir::BlockRV& block) const { + const tir::SBlockRV& block) const { // Filter out invalid vector lanes according to the data type. - const tir::BlockNode* block_node = (*sch)->GetSRef(block)->StmtAs(); + const tir::SBlockNode* block_node = (*sch)->GetSRef(block)->StmtAs(); ICHECK_EQ(block_node->writes.size(), 1); const runtime::DataType dtype = block_node->writes[0]->buffer->dtype; std::function f_filter = nullptr; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 028d1aecbf45..19bfbd51c187 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -107,15 +107,15 @@ class StateNode : public Object { /*! \brief The schedule to date */ tir::Schedule sch; /*! \brief The block to be tiled */ - tir::BlockRV block_rv; + tir::SBlockRV block_rv; /*! \brief The loop tiles */ ffi::Array> tiles; /*! \brief The factors of the loop tiles. */ ffi::Array> tile_factors; /*! \brief The mapping from buffer index to read cache block. */ - std::unordered_map read_reuse; + std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. */ - std::unordered_map write_reuse; + std::unordered_map write_reuse; /*! * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that @@ -131,7 +131,7 @@ class StateNode : public Object { class State : public ObjectRef { public: /*! \brief Default constructor */ - explicit State(tir::Schedule sch, tir::BlockRV block_rv, + explicit State(tir::Schedule sch, tir::SBlockRV block_rv, ffi::Array> tiles = {}); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(State, ObjectRef, StateNode); }; @@ -174,7 +174,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final; // Entry of the mega rule; Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) override; // Inherited from ScheduleRuleNode ScheduleRule Clone() const override; @@ -183,10 +183,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode { virtual std::vector ApplySubRules(std::vector states); virtual std::pair, ffi::Array> SplitLoop( - const tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop, int n_tiles) const; + const tir::Schedule& sch, tir::SBlockRV block, tir::LoopRV loop, int n_tiles) const; // Annotate a block to use cooperative fetching - void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; + void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::SBlockRV& block) const; public: /*! diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index c58e81dc3343..85705ea99876 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -30,9 +30,9 @@ namespace tvm { namespace meta_schedule { -using tir::BlockRV; using tir::IterVarType; using tir::LoopRV; +using tir::SBlockRV; using tir::Schedule; struct TensorCoreIntrinGroup { @@ -79,11 +79,11 @@ class TensorCoreStateNode : public StateNode { /*! \brief The auto tensorization maping info. */ tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()}; /*! \brief The Tensor Core reindex block A for Tensor Core computation */ - tir::BlockRV tensor_core_reindex_A; + tir::SBlockRV tensor_core_reindex_A; /*! \brief The Tensor Core reindex block B for Tensor Core computation */ - tir::BlockRV tensor_core_reindex_B; + tir::SBlockRV tensor_core_reindex_B; /*! \brief The Tensor Core reindex store block for Tensor Core computation */ - tir::BlockRV tensor_core_reindex_store; + tir::SBlockRV tensor_core_reindex_store; /*! \brief Flag to indicate its a WMMA or MMA intrin group */ bool is_mma; /*! \brief Flag to indicate whether to use async software pipeline */ @@ -104,7 +104,7 @@ class TensorCoreState : public State { public: explicit TensorCoreState(TensorCoreIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, bool use_async, + SBlockRV block_rv, bool use_async, ffi::Array> tiles = {}); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorCoreState, State, TensorCoreStateNode); @@ -112,7 +112,7 @@ class TensorCoreState : public State { TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, bool use_async, + SBlockRV block_rv, bool use_async, ffi::Array> tiles) { ObjectPtr node = ffi::make_object(); node->intrin_group = intrin_group; @@ -154,7 +154,7 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { inline std::vector AddSoftwarePipeline(TensorCoreState state) const; // Subrule: split loop for mma using sample partitioned tile inline std::pair, ffi::Array> MMASplitLoop( - const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles, int partition_pos, + const Schedule& sch, SBlockRV block, LoopRV loop, int n_tiles, int partition_pos, int innerpart_factor) const; // Subrule: tile loop nest for mma // Basically same with MultiLevelTilingNode::TileLoopNest, but change SamplePerfectTile to @@ -165,7 +165,7 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { std::vector ApplySubRules(std::vector states) final; // Override Apply to apply tensorization-specific analysis before applying sub-rules - ffi::Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + ffi::Array Apply(const Schedule& sch, const SBlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { @@ -188,7 +188,7 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { * \param block_rv The block to be tensorized * \param intrin_name The name of the tensor intrin */ - void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, + void TileAndAnnotateTensorize(Schedule* sch, const SBlockRV& block_rv, const ffi::String& intrin_name, const ffi::String& permuted_layout_annotate_value) const; @@ -211,7 +211,7 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { // Entry of the mega rule; Inherited from ScheduleRuleNode ffi::Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, - const BlockRV& block_rv) { + const SBlockRV& block_rv) { if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { return {sch}; } @@ -286,11 +286,11 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); ICHECK(loop.defined()); - BlockRV blockized_outer = (*sch)->Blockize(loop.value()); + SBlockRV blockized_outer = (*sch)->Blockize(loop.value()); (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); if (!permuted_layout_annotate_value.empty()) { (*sch)->Annotate(blockized_outer, "permuted_layout", permuted_layout_annotate_value); @@ -303,7 +303,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta return {std::move(state)}; } ICHECK(config.req != ReuseType::kMayReuse); - const BlockRV& block_rv = state->block_rv; + const SBlockRV& block_rv = state->block_rv; std::vector results; results.reserve(config.levels.size()); for (int level : config.levels) { @@ -318,7 +318,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta continue; } // Do cache_read - BlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope); + SBlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope); new_state->read_reuse.emplace(i, cache_read_block); if (state->is_mma) { new_state->sch->Annotate( @@ -332,7 +332,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta } std::pair, ffi::Array> -MultiLevelTilingTensorCoreNode::MMASplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, +MultiLevelTilingTensorCoreNode::MMASplitLoop(const Schedule& sch, SBlockRV block, LoopRV loop, int n_tiles, int partition_pos, int innerpart_factor) const { ffi::Array factors = sch->SamplePartitionedTile( @@ -347,14 +347,14 @@ MultiLevelTilingTensorCoreNode::MMASplitLoop(const Schedule& sch, BlockRV block, std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreState state) const { Schedule& sch = state->sch; - const BlockRV& block_rv = state->block_rv; + const SBlockRV& block_rv = state->block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types ffi::Array loops = sch->GetLoops(block_rv); if (!(loops.size() == 3 || !state->is_mma)) { LOG(DEBUG) << "The MMA tensor core only supports SSR loops now"; return {}; } - std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv)); + std::vector iter_types = GetSBlockVarTypes(sch->GetSRef(state->block_rv)); ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; @@ -440,8 +440,8 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // Get the shape of the wmma accumulator auto [frag_shape_m, frag_shape_n] = [&]() { - tir::Block intrin_block = - Downcast( + tir::SBlock intrin_block = + Downcast( tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) ->block; tir::For loop_m = Downcast(intrin_block->body); @@ -561,7 +561,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n]. // Get the loops other than the innermost two loops (accum_m and accum_n). - auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { + auto f_get_loops = [&](const SBlockRV& block_rv) -> std::array { ffi::Array buffer_loops = sch->GetLoops(block_rv); ICHECK_GT(buffer_loops.size(), 6); return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], @@ -613,12 +613,12 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( state->intrin_group.load_b_intrin); for (int i = 0; i < 2; ++i) { - const tir::BlockRV cache_read = state->read_reuse.at(i); + const tir::SBlockRV cache_read = state->read_reuse.at(i); // Inline the reindex / padding block sch->ComputeInline(sch->GetProducers(cache_read)[0]); - const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); + const tir::SBlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer( - sch->state(), ffi::GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); + sch->state(), ffi::GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { sch->StorageAlign(cache_read, 0, -2, 32, 8); @@ -658,7 +658,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( } for (int i = 0; i < 2; ++i) { - const tir::BlockRV cache_read = state->read_reuse.at(i); + const tir::SBlockRV cache_read = state->read_reuse.at(i); if (state->is_mma) { // Add vector bytes for memhammer sch->Annotate(cache_read, tir::attr::vector_bytes, Integer(16)); @@ -763,14 +763,14 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( TensorCoreStateNode* state, const ffi::String& intrin_name) const { - BlockRV block_rv = state->block_rv; + SBlockRV block_rv = state->block_rv; const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); // Add reindex stages - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const tir::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); // Hold the reference of the block before reindex - const tir::Block block_before_reindex = ffi::GetRef(block); + const tir::SBlock block_before_reindex = ffi::GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { // only matmul-like computation is allowed return std::nullopt; @@ -848,9 +848,9 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( } visited_buffers.insert(lhs_buffer); // Refresh block pointer (block sref is not invalidated) - block = TVM_SREF_TO_BLOCK(block_sref); + block = TVM_SREF_TO_SBLOCK(block_sref); const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( - state->sch->state(), ffi::GetRef(block), buffer_index, index_type); + state->sch->state(), ffi::GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, @@ -865,7 +865,7 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( } // Transform the layout of current block and reindex blocks - auto f_transform_reindex_block_layout = [&](const BlockRV& block_rv, + auto f_transform_reindex_block_layout = [&](const SBlockRV& block_rv, tir::BufferIndexType buffer_type) { tir::Buffer buffer = tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 080e1c9c0fbf..8a1ac2bae8d4 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -27,8 +27,8 @@ namespace tvm { namespace meta_schedule { -using tir::BlockRV; using tir::LoopRV; +using tir::SBlockRV; using tir::Schedule; /*! @@ -56,17 +56,17 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { } std::pair, ffi::Array> SplitLoop(const Schedule& sch, - BlockRV block, LoopRV loop, + SBlockRV block, LoopRV loop, int n_tiles) const; }; std::pair, ffi::Array> -MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, +MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv, LoopRV loop_rv, int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv); - const tir::BlockNode* block_node = block_sref->StmtAs(); - const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref); + const tir::SBlockNode* block_node = block_sref->StmtAs(); + const tir::SBlockRealize block_realize = tir::GetSBlockRealize(sch->state(), block_sref); ICHECK(block_node && block_node->writes.size() == 1); const auto out_dtype = block_node->writes[0]->buffer->dtype; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 4a375689e493..8167a6f8974b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -31,14 +31,14 @@ namespace meta_schedule { * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate * the tiled block for tensorization by postproc rewrite. */ -ffi::Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, - const std::string& intrin_name) { +ffi::Optional TileForIntrin(tir::Schedule sch, tir::SBlockRV block, + const std::string& intrin_name) { ffi::Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); if (!tiled_loop_rv) { return std::nullopt; } ICHECK(tiled_loop_rv.defined()); - tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); + tir::SBlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, ffi::String(intrin_name)); return outer_block; } @@ -48,7 +48,7 @@ ffi::Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, */ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { protected: - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 9216c70e3328..d1e931e42434 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -23,12 +23,12 @@ namespace tvm { namespace tir { -bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) { +bool IsRootBlock(const Schedule& sch, const SBlockRV& block_rv) { StmtSRef block_sref = sch->GetSRef(block_rv); return block_sref->parent == nullptr; } -bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) { +bool CheckSpatialPrimFunc(const Schedule& sch, const SBlockRV& root_block_rv) { return IsSpatialPrimFunc( ffi::GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); } @@ -51,7 +51,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& root_rv) { // Currently only mark the root block with annotations. if (!tir::IsRootBlock(sch, root_rv)) { return {sch}; diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 2c9975fcf916..89a9f722a816 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -29,7 +29,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::SBlockRV& block_rv) final { if (!CheckConditions(sch, block_rv)) { return {sch}; } @@ -40,7 +40,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer // access the input block. Hence we collect its producer ahead of time. // - Note that only single producer is allowed in this case. - ffi::Array producers{nullptr}; + ffi::Array producers{nullptr}; if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, true)) { producers = sch->GetProducers(block_rv); @@ -66,9 +66,9 @@ class RandomComputeLocationNode : public ScheduleRuleNode { } private: - bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { + bool CheckConditions(const tir::Schedule sch, const tir::SBlockRV& block_rv) const { tir::StmtSRef block_sref = sch->GetSRef(block_rv); - TVM_SREF_TO_BLOCK(block_sref); + TVM_SREF_TO_SBLOCK(block_sref); // Cond 1. The block is not the root block. if (block_sref->parent == nullptr) { @@ -106,7 +106,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { * \param block_rv The block whose compute-at location is to be sampled * \return The TIR schedule after transformation */ - tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::SBlockRV& block_rv) { tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); sch->ComputeAt(block_rv, compute_at_loc, true); return sch; diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 9eac4ad57b20..f5ef5da48c2d 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -31,7 +31,7 @@ void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { } ffi::Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, - const tir::BlockRV& block) { + const tir::SBlockRV& block) { ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; return f_apply(sch, block); } diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index e3786a4d6188..44a365031894 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -47,7 +47,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { } ffi::Array GenerateDesignSpace(const IRModule& mod) final { - using ScheduleAndUnvisitedBlocks = std::pair>; + using ScheduleAndUnvisitedBlocks = std::pair>; CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply"; tir::Schedule sch = tir::Schedule::Traced( /*mod=*/mod, @@ -57,7 +57,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { std::vector stack; ffi::Array result{sch}; - ffi::Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); + ffi::Array all_blocks = SBlockCollector::Collect(sch, f_block_filter_); for (ScheduleRule sch_rule : sch_rules.value()) { for (const tir::Schedule& sch : result) { @@ -74,7 +74,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { continue; } // otherwise, get the last block that is not visited - tir::BlockRV block_rv = blocks.back(); + tir::SBlockRV block_rv = blocks.back(); blocks.pop_back(); if (!sch->HasBlock(block_rv)) { stack.emplace_back(sch, blocks); diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index d6300afcf9eb..aef04f7da19c 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -38,7 +38,7 @@ namespace meta_schedule { using namespace tir; // Returns true if b1 is an ancestor of b2 -bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { +bool IsAncestor(SBlockRV b1, SBlockRV b2, Schedule sch) { if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) { return true; } @@ -50,14 +50,14 @@ bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) { // Inline or reverse inline spatial blocks after the anchor block void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { - static auto kind_get_block = InstructionKind::Get("GetBlock"); + static auto kind_get_sblock = InstructionKind::Get("GetSBlock"); // We let blocks whose names are referenced in the anchor trace be scheduled by the anchor trace. // We record such block names to avoid inlining them here. - std::unordered_set get_block_names; + std::unordered_set get_sblock_names; for (const auto& inst : anchor_trace->insts) { - if (inst->kind.same_as(kind_get_block)) { + if (inst->kind.same_as(kind_get_sblock)) { auto block_name = Downcast(inst->attrs[0]); - get_block_names.insert(block_name); + get_sblock_names.insert(block_name); } } @@ -66,15 +66,15 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { std::vector inline_todos; std::optional last_block_idx{std::nullopt}; - for (auto name : GetBlockNames(sch->mod())) { - auto block = sch->GetBlock(name); + for (auto name : GetSBlockNames(sch->mod())) { + auto block = sch->GetSBlock(name); if (anchor_block) { - auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint); + auto anchor_block_rv = sch->GetSBlock(anchor_block->name_hint); if (IsAncestor(block, anchor_block_rv, sch)) continue; } // Spatial blocks which are not referenced in the anchor trace will be inlined here. auto block_sref = sch->GetSRef(block); - if (IsSpatial(block_sref) && !get_block_names.count(name)) { + if (IsSpatial(block_sref) && !get_sblock_names.count(name)) { StmtSRef scopeRoot = (name != "root") ? GetScopeRoot(sch->state(), block_sref, false) : block_sref; if (IsOutputBlock(sch->state(), block_sref, scopeRoot)) { @@ -93,24 +93,24 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { auto inline_rule = GetDefaultAutoInline(target->kind->name); for (auto name : inline_todos) { - inline_rule->Apply(sch, sch->GetBlock(name)); + inline_rule->Apply(sch, sch->GetSBlock(name)); } } // Apply instructions from the anchor trace to the target schedule, and returns blocks // that remain unscheduled. -std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { +std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks"); - static auto kind_get_block = InstructionKind::Get("GetBlock"); + static auto kind_get_sblock = InstructionKind::Get("GetSBlock"); static auto kind_compute_inline = InstructionKind::Get("ComputeInline"); static auto kind_reverse_compute_inline = InstructionKind::Get("ReverseComputeInline"); - const auto block_names_orig = GetBlockNames(sch->mod()); + const auto block_names_orig = GetSBlockNames(sch->mod()); const auto sch_orig = sch->Copy(); std::unordered_map rv_map; // Blocks and loops that appear in the anchor trace but are not part of the target schedule. - std::unordered_set foreign_blocks; + std::unordered_set foreign_blocks; std::unordered_set foreign_loops; // Instructions in the anchor trace can be applied only if all inputs are part of the target @@ -118,7 +118,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) { for (auto input : inst->inputs) { if (input == nullptr) continue; - if ((input.as() && foreign_blocks.count(Downcast(input))) || + if ((input.as() && foreign_blocks.count(Downcast(input))) || (input.as() && foreign_loops.count(Downcast(input)))) { return false; } @@ -131,8 +131,8 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { // If we find an instruction that is not applicable, its outputs are recorded as "foreign" // to the target schedule. for (auto output : inst->outputs) { - if (output.as()) { - foreign_blocks.insert(Downcast(output)); + if (output.as()) { + foreign_blocks.insert(Downcast(output)); } else if (output.as()) { foreign_loops.insert(Downcast(output)); } @@ -142,17 +142,17 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { ffi::Array inputs = TranslateInputRVs(inst->inputs, rv_map); - if (inst->kind.same_as(kind_get_block) && + if (inst->kind.same_as(kind_get_sblock) && !HasBlock(sch, Downcast(inst->attrs[0]))) { - // The anchor trace does get_block on a block that is not part of the target schedule. - auto block = Downcast(inst->outputs[0]); + // The anchor trace does get_sblock on a block that is not part of the target schedule. + auto block = Downcast(inst->outputs[0]); foreign_blocks.insert(block); continue; } else if (inst->kind.same_as(kind_reverse_compute_inline)) { // The anchor trace does reverse_compute_inline on a block, but the block with the same name // in the target schedule cannot be reverse compute inline-ed. // In such cases, it should be possible to apply compute_inline instead. - auto block = Downcast(inputs[0]); + auto block = Downcast(inputs[0]); auto block_sref = sch->GetSRef(block); if (!CanReverseComputeInline(sch->state(), block_sref)) { ICHECK(CanComputeInline(sch->state(), block_sref)); @@ -161,7 +161,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { } } else if (inst->kind.same_as(kind_compute_inline)) { // Similar to the reverse_compute_inline case above. - auto block = Downcast(inputs[0]); + auto block = Downcast(inputs[0]); auto block_sref = sch->GetSRef(block); auto state = sch->state(); if (!CanComputeInline(state, block_sref)) { @@ -194,8 +194,8 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { } auto is_scheduled = [=](const std::string& block_name) { - auto loops = sch->GetLoops(sch->GetBlock(block_name)); - auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name)); + auto loops = sch->GetLoops(sch->GetSBlock(block_name)); + auto loops_orig = sch_orig->GetLoops(sch_orig->GetSBlock(block_name)); if (loops.size() != loops_orig.size()) { return true; } @@ -209,12 +209,12 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { return false; }; - const auto block_names_now = GetBlockNames(sch->mod()); - std::vector unscheduled_blocks; + const auto block_names_now = GetSBlockNames(sch->mod()); + std::vector unscheduled_blocks; for (auto name : block_names_orig) { if (block_names_now.count(name) && name != "root" && !is_scheduled(name)) { - unscheduled_blocks.push_back(sch->GetBlock(name)); + unscheduled_blocks.push_back(sch->GetSBlock(name)); } } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 96b18fb7e755..0ed65afa2fb2 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -285,16 +285,16 @@ inline std::string Concat(const ffi::Array& strs, const std::string } /*! - * \brief Get the BlockRV from a block StmtSRef + * \brief Get the SBlockRV from a block StmtSRef * \param sch The schedule * \param block_sref The block StmtSRef * \param global_var_name The global variable name - * \return The BlockRV + * \return The SBlockRV */ -inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, - const ffi::String& global_var_name) { - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - return sch->GetBlock(block->name_hint, global_var_name); +inline tir::SBlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, + const ffi::String& global_var_name) { + const tir::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + return sch->GetSBlock(block->name_hint, global_var_name); } /*! @@ -578,24 +578,24 @@ inline double Sum(const ffi::Array& arr) { } /*! \brief Collecting all the blocks */ -class BlockCollector : public tir::StmtVisitor { +class SBlockCollector : public tir::StmtVisitor { public: - static ffi::Array Collect(const tir::Schedule& sch, - const ffi::Function f_block_filter = nullptr) { // - return BlockCollector(sch, f_block_filter).Run(); + static ffi::Array Collect(const tir::Schedule& sch, + const ffi::Function f_block_filter = nullptr) { // + return SBlockCollector(sch, f_block_filter).Run(); } private: /*! \brief Entry point */ - ffi::Array Run() { - std::vector results; + ffi::Array Run() { + std::vector results; auto f_collect = [this, &results](tir::PrimFunc func, ffi::String func_name) { func_name_ = func_name; block_names_.clear(); blocks_to_collect_.clear(); VisitStmt(func->body); for (const ffi::String& name : blocks_to_collect_) { - results.push_back(sch_->GetBlock(name, func_name_)); + results.push_back(sch_->GetSBlock(name, func_name_)); } }; @@ -615,10 +615,10 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) + explicit SBlockCollector(const tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} /*! \brief Override the Stmt visiting behaviour */ - void VisitStmt_(const tir::BlockNode* block) override { + void VisitStmt_(const tir::SBlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); CHECK(block_names_.count(block->name_hint) == 0) << "Duplicated block name " << block->name_hint << " in function " << func_name_ @@ -629,7 +629,7 @@ class BlockCollector : public tir::StmtVisitor { // Otherwise collect all blocks. Bool collect_block = Bool(true); if (f_block_filter_ != nullptr) { - collect_block = f_block_filter_(ffi::GetRef(block)).cast(); + collect_block = f_block_filter_(ffi::GetRef(block)).cast(); } if (collect_block) { blocks_to_collect_.push_back(block->name_hint); diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index c87d0891ccfe..7eed4cd4aa9d 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -318,7 +318,8 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr */ class BlockAnalyzer : public StmtExprVisitor { public: - explicit BlockAnalyzer(const Block& block, const ffi::Map& transformation_cache, + explicit BlockAnalyzer(const SBlock& block, + const ffi::Map& transformation_cache, IndexMap write_transformation) : can_transform_block_(true), write_transformation_(write_transformation), @@ -465,7 +466,7 @@ class BlockAnalyzer : public StmtExprVisitor { } } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { // Blocks with nested blocks cannot be handled yet. LOG(WARNING) << "[LayoutInference] Nested blocks are not supported for layout inference yet"; can_transform_block_ = false; @@ -515,7 +516,7 @@ class BlockAnalyzer : public StmtExprVisitor { public: bool CanBeTransformed() { return can_transform_block_; } - IndexMap GetBlockTransformation() { return block_transformation_; } + IndexMap GetSBlockTransformation() { return block_transformation_; } ffi::Map GetReadBufferTransformations() { return read_buffer_transformations_; } private: @@ -524,7 +525,7 @@ class BlockAnalyzer : public StmtExprVisitor { ffi::Map spatial_dom_; arith::Analyzer arith_analyzer_; - Block block_; + SBlock block_; IndexMap block_transformation_; ffi::Map read_buffer_transformations_; @@ -557,8 +558,8 @@ class PrimFuncAnalyzer : public StmtExprVisitor { } VisitStmt(func->body); } - ffi::Map> GetSuggestedTransforms() { - ffi::Map> result; + ffi::Map> GetSuggestedTransforms() { + ffi::Map> result; for (const auto& [block, index_map] : block_transformations_) { ffi::Map block_transformations; block_transformations.Set(block, index_map); @@ -571,14 +572,14 @@ class PrimFuncAnalyzer : public StmtExprVisitor { } private: - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (op->name_hint == "root") { // Skip the root block StmtVisitor::VisitStmt_(op); return; } - Block block = ffi::GetRef(op); + SBlock block = ffi::GetRef(op); // Get block write buffer transformation. if (block->writes.size() != 1) return; auto write_buffer = block->writes[0]->buffer; @@ -588,7 +589,7 @@ class PrimFuncAnalyzer : public StmtExprVisitor { if (!block_analyzer.CanBeTransformed()) return; // Collect the suggested transformations - block_transformations_.Set(block, block_analyzer.GetBlockTransformation()); + block_transformations_.Set(block, block_analyzer.GetSBlockTransformation()); for (const auto& [buffer, index_map] : block_analyzer.GetReadBufferTransformations()) { // BlockAnalyzer makes sure that it does not propose transformation for a buffer for which a @@ -602,11 +603,11 @@ class PrimFuncAnalyzer : public StmtExprVisitor { private: ffi::Map buffer_transformation_cache_; - ffi::Map block_transformations_; - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; + ffi::Map block_transformations_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; }; -ffi::Map> SuggestLayoutTransforms( +ffi::Map> SuggestLayoutTransforms( const PrimFunc& prim_func, ffi::Array write_buffer_transformations) { // No changes to the PrimFunc are required if no transformations on output buffers. if (write_buffer_transformations.empty()) return {}; diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 58c47529a103..3a3e0e6697bc 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -43,7 +43,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { } private: - bool IsOutputBlock(const BlockNode* block) { + bool IsOutputBlock(const SBlockNode* block) { for (const BufferRegion& write_region : block->writes) { if (param_buffers_.count(write_region->buffer)) { return true; @@ -68,7 +68,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { ExprVisitor::VisitExpr_(op); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (op->name_hint == "root") { // Skip the root block StmtVisitor::VisitStmt(op->body); @@ -369,17 +369,17 @@ bool HasReshapePattern(const PrimFunc& func) { ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); // To detect the reshape pattern, we require each For to have // either another For or a BlockRealize as body. - if (!(loop->body->IsInstance() || loop->body->IsInstance())) { + if (!(loop->body->IsInstance() || loop->body->IsInstance())) { return; } this->VisitStmt(loop->body); } - void VisitStmt_(const BlockRealizeNode* block_realize) final { + void VisitStmt_(const SBlockRealizeNode* block_realize) final { // Constructing the mapping from block iterators to iterator // binding values. The mapping will be used in the substitution of // the flattened buffer access index. - const Block& block = block_realize->block; + const SBlock& block = block_realize->block; const ffi::Array& block_iter = block->iter_vars; const ffi::Array& iter_values = block_realize->iter_values; ICHECK_EQ(block_iter.size(), iter_values.size()); @@ -395,7 +395,7 @@ bool HasReshapePattern(const PrimFunc& func) { this->VisitStmt(block); } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { // Step 0. If the block body is a ForNode, recurse into it. if (block->body->IsInstance()) { this->VisitStmt(block->body); @@ -535,7 +535,7 @@ bool HasReshapePattern(const PrimFunc& func) { // To detect the reshape pattern, we require each For to have // either another For or a BlockRealize as body. - ICHECK(func->body->IsInstance()); + ICHECK(func->body->IsInstance()); return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body); } diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index 71c024b9d7a0..d9da06ae6e30 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -51,14 +51,14 @@ using meta_schedule::ModuleHash; */ class BlockCounter : public tir::StmtVisitor { public: - static size_t GetBlockCount(const tir::PrimFunc& func) { + static size_t GetSBlockCount(const tir::PrimFunc& func) { BlockCounter counter; counter(func->body); return counter.count; } private: - void VisitStmt_(const tir::BlockNode* op) final { + void VisitStmt_(const tir::SBlockNode* op) final { ++count; StmtVisitor::VisitStmt_(op); } @@ -120,7 +120,7 @@ class TaskExtractor : public ExprVisitor { // count the PrinFunc number of blocks and leave only the function with the smallest number of // blocks. This way, "nn_conv2d_add_nn_relu" will have a smaller number of blocks than // "nn_conv2d_add_add_nn_relu" and will be selected to tune. - if (BlockCounter::GetBlockCount(func) < BlockCounter::GetBlockCount(alt_func)) { + if (BlockCounter::GetSBlockCount(func) < BlockCounter::GetSBlockCount(alt_func)) { weight += it->second->weight; func2task_.erase(it->first); } diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 837f2f0a5dcb..a21304b90152 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -64,19 +64,19 @@ class DistBufferReplacer : public StmtExprMutator { return load; } - Stmt VisitStmt_(const BlockNode* _block) final { - Block old_block = ffi::GetRef(_block); - Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); - ObjectPtr new_block = ffi::make_object(*block.get()); + Stmt VisitStmt_(const SBlockNode* _block) final { + SBlock old_block = ffi::GetRef(_block); + SBlock block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = ffi::make_object(*block.get()); new_block->reads = ReplaceBuffer(new_block->reads, buffer_map_); new_block->writes = ReplaceBuffer(new_block->writes, buffer_map_); - return Block(new_block); + return SBlock(new_block); } ffi::Map buffer_map_; }; -class DistBlockInfoCollector : public StmtExprVisitor { +class DistSBlockInfoCollector : public StmtExprVisitor { private: void VisitStmt_(const BufferStoreNode* op) final { buffer_access_indices[op->buffer].push_back(op->indices); @@ -88,7 +88,7 @@ class DistBlockInfoCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { for (const auto& iter_var : op->iter_vars) { if (iter_var->iter_type == kCommReduce) { ICHECK(op->writes.size() == 1); @@ -201,8 +201,9 @@ class DistributedBufferCompactor : StmtExprMutator { } ffi::Array ShardIterVar( - Block block, const std::unordered_map>, ObjectPtrHash, - ObjectPtrEqual>& buffer_access_indices) { + SBlock block, + const std::unordered_map>, ObjectPtrHash, + ObjectPtrEqual>& buffer_access_indices) { std::vector buffers; for (const auto& read : block->reads) { buffers.push_back(read->buffer); @@ -271,9 +272,9 @@ class DistributedBufferCompactor : StmtExprMutator { return Buffer(new_buffer); } - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - DistBlockInfoCollector collector; + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); + DistSBlockInfoCollector collector; collector(block); ffi::Array new_iter_vars = ShardIterVar(block, collector.buffer_access_indices); ffi::Array new_alloc_buffers; @@ -294,7 +295,7 @@ class DistributedBufferCompactor : StmtExprMutator { break; } } - ObjectPtr new_block = ffi::make_object(*block.operator->()); + ObjectPtr new_block = ffi::make_object(*block.operator->()); new_block->iter_vars = new_iter_vars; new_block->alloc_buffers = new_alloc_buffers; if (new_block->name_hint == "root") { @@ -303,13 +304,13 @@ class DistributedBufferCompactor : StmtExprMutator { allocated_buffer_under_root.end()); } new_block->body = DistBufferReplacer::BufferReplace(block->body, buffer_map); - return Block(new_block); + return SBlock(new_block); } void AddAllReduceBlock(std::string reduce_kind) { add_allreduce_kind_ = reduce_kind; } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockRealizeNode* op) final { + SBlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); for (int i = 0; i < static_cast(realize->iter_values.size()); i++) { PrimExpr iter_value = realize->iter_values[i]; diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 09f404d29cbd..56146e80d063 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -166,9 +166,9 @@ class BlockBuilderImpl : public BlockBuilderNode { return it->second; } - void BeginDataflowBlock() final { block_stack_.emplace_back(BlockFrame{{}, true}); } + void BeginDataflowBlock() final { block_stack_.emplace_back(BindingBlockFrame{{}, true}); } - void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); } + void BeginBindingBlock() final { block_stack_.emplace_back(BindingBlockFrame{{}, false}); } void BeginScope(ffi::Optional> params) final { // The current implementation handles the collection of shape var @@ -230,17 +230,17 @@ class BlockBuilderImpl : public BlockBuilderNode { void EndScope() final { scope_stack_.pop_back(); } BindingBlock EndBlock() final { - BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlockFrame* cur_frame = CurrentBindingBlockFrame(); BindingBlock ret = cur_frame->is_dataflow ? DataflowBlock(cur_frame->bindings) : BindingBlock(cur_frame->bindings); block_stack_.pop_back(); return ret; } - bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; } + bool CurrentBlockIsDataFlow() final { return CurrentBindingBlockFrame()->is_dataflow; } Var Emit(Expr expr, ffi::String name_hint) final { - return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); + return this->Emit(expr, CurrentBindingBlockFrame()->is_dataflow, name_hint); } Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint) final { @@ -252,7 +252,7 @@ class BlockBuilderImpl : public BlockBuilderNode { << GetStructInfo(value) << ", given struct info: " << struct_info; // NOTE: do match cast checking later in a pass. - BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlockFrame* cur_frame = CurrentBindingBlockFrame(); Var var = CreateVar(cur_frame->is_dataflow, name_hint); UpdateStructInfo(var, struct_info); @@ -266,7 +266,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } Var EmitOutput(Expr output, ffi::String name_hint) final { - BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlockFrame* cur_frame = CurrentBindingBlockFrame(); ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; @@ -274,7 +274,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } void EmitNormalized(Binding binding) final { - BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlockFrame* cur_frame = CurrentBindingBlockFrame(); if (const auto* var_binding = binding.as()) { if (!cur_frame->is_dataflow) { @@ -313,7 +313,7 @@ class BlockBuilderImpl : public BlockBuilderNode { * to build a binding block, and a boolean to indicate if the * block being built is a DataflowBlock or not. */ - struct BlockFrame { + struct BindingBlockFrame { /*! * \brief List of bindings */ @@ -345,7 +345,7 @@ class BlockBuilderImpl : public BlockBuilderNode { }; /*! \brief A stack to store block frames. */ - std::vector block_stack_; + std::vector block_stack_; /*! \brief A stack to store scope frames. */ std::vector scope_stack_; @@ -368,7 +368,7 @@ class BlockBuilderImpl : public BlockBuilderNode { * or other scope calls this value can change if the block stack get updated, * then the block frame is no longer valid. */ - BlockFrame* CurrentBlockFrame() { + BindingBlockFrame* CurrentBindingBlockFrame() { ICHECK(!block_stack_.empty()) << "no block is being built"; return &block_stack_.back(); } @@ -399,7 +399,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // set the values UpdateStructInfo(var, Downcast(expr->struct_info_.value())); - CurrentBlockFrame()->bindings.push_back(VarBinding(var, expr)); + CurrentBindingBlockFrame()->bindings.push_back(VarBinding(var, expr)); // update the binding table binding_table_[var->vid] = expr; @@ -553,7 +553,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctornormalize_binding_map.find(arg); if (it != cur_frame->normalize_binding_map.end()) { return it->second; @@ -567,7 +567,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorEmit(post, ""); // NOTE: current frame addr can change due to underlying vector // re-allocation, redo lookup - CurrentBlockFrame()->normalize_binding_map[arg] = var; + CurrentBindingBlockFrame()->normalize_binding_map[arg] = var; return var; } else { return post; @@ -606,7 +606,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctornormalize_binding_map.find(expr); if (it != cur_frame->normalize_binding_map.end()) { return it->second; diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 3b56d6ca1d81..0b9eeb8341c5 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -730,8 +730,8 @@ tir::Stmt RemapBuffers(const tir::Stmt& stmt, return node; } - tir::Stmt VisitStmt_(const tir::BlockNode* op) final { - auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + tir::Stmt VisitStmt_(const tir::SBlockNode* op) final { + auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); auto* node_cow = node.CopyOnWrite(); // need the lambdas because class methods are not first-class (how ironic) node_cow->alloc_buffers = diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 549cd2197b4b..b33957cc41c4 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -226,8 +226,8 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { } } - Stmt VisitStmt_(const BlockNode* _op) final { - Block block = Downcast(StmtMutator::VisitStmt_(_op)); + Stmt VisitStmt_(const SBlockNode* _op) final { + SBlock block = Downcast(StmtMutator::VisitStmt_(_op)); // Define the mutation functions. @@ -283,7 +283,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { n->writes = std::move(writes); n->match_buffers = std::move(match_buffers); n->alloc_buffers = std::move(alloc_buffers); - return Block(n); + return SBlock(n); } } @@ -339,10 +339,10 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { }; /*! \brief A mutator which detect block name duplication and deduplicate the names. */ -class BlockNameDeduplicator : public tir::StmtMutator { +class SBlockNameDeduplicator : public tir::StmtMutator { private: - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(tir::StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(tir::StmtMutator::VisitStmt_(op)); ffi::String name = GetUniqueName(block->name_hint); @@ -350,7 +350,7 @@ class BlockNameDeduplicator : public tir::StmtMutator { return block; } else { - ObjectPtr n = CopyOnWrite(block.get()); + ObjectPtr n = CopyOnWrite(block.get()); n->name_hint = std::move(name); return Stmt(n); } @@ -683,10 +683,10 @@ class FusedTIRConstructor : public ExprVisitor { // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block // TODO(Siyuan): support un-schedulable functions. - ICHECK(prim_func->body->IsInstance()) + ICHECK(prim_func->body->IsInstance()) << "Only schedulable functions (whose body is the root block) can be fused"; - const tir::BlockRealize& root_realize = Downcast(prim_func->body); - const tir::Block& root_block = root_realize->block; + const tir::SBlockRealize& root_realize = Downcast(prim_func->body); + const tir::SBlock& root_block = root_realize->block; // Step 4. Add all the original alloc_buffers and body to the fused function. func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(), @@ -1003,11 +1003,11 @@ class FusedTIRConstructor : public ExprVisitor { alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf)); } } - tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies)); + tir::Stmt body = tir::SBlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies)); body = subst.Substitute(body); - body = tir::Block({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); - body = tir::BlockRealize({}, Bool(true), Downcast(body)); + body = tir::SBlock({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); + body = tir::SBlockRealize({}, Bool(true), Downcast(body)); tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, DictAttrs(attr_map)); // Renew function defs to prevent using the same symbolic vars in different functions diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 8ecfabd7c27a..9749599c2f85 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -341,7 +341,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } void VisitBindingBlock_(const BindingBlockNode* binding_block) final { - BindingBlockScope new_scope; + BindingSBlockScope new_scope; std::swap(new_scope, current_block_scope_); for (const auto& binding : binding_block->bindings) { VisitBinding(binding); @@ -597,7 +597,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { * lifting. They are initialized lazily when a binding that can be lifted is encountered. * They are reset to nullptr when an unsupported operation is encountered. */ - struct BindingBlockScope { + struct BindingSBlockScope { FuncBuilder* capture_builder = nullptr; // The builder for the capture function }; @@ -611,7 +611,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { // The IRModule IRModule mod_; // States of the current block scope - BindingBlockScope current_block_scope_; + BindingSBlockScope current_block_scope_; // States of the current function scope FunctionScope current_function_scope_; // Variables whose buffer address is fixed diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 00c6efb192a3..376759984816 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -59,7 +59,7 @@ class ForMatcher : public TensorizeComparator { } bool Match(const For& top) { - const ForNode* pattern_top = pattern_->body.as()->block->body.as(); + const ForNode* pattern_top = pattern_->body.as()->block->body.as(); ICHECK(pattern_top) << "Invalid pattern function"; if (!VisitStmt(top, ffi::GetRef(pattern_top))) { return false; @@ -248,10 +248,10 @@ class ForMatcher : public TensorizeComparator { loop_stack_lhs_.push_back(ffi::GetRef(op)); loop_stack_rhs_.push_back(ffi::GetRef(rhs)); // The body of loop must be loop or BlockRealize - if (!op->body->IsInstance() && !op->body->IsInstance()) { + if (!op->body->IsInstance() && !op->body->IsInstance()) { return false; } - if (!rhs->body->IsInstance() && !rhs->body->IsInstance()) { + if (!rhs->body->IsInstance() && !rhs->body->IsInstance()) { return false; } // Build mapping between the loop vars @@ -266,8 +266,8 @@ class ForMatcher : public TensorizeComparator { return VisitStmt(op->body, rhs->body); } - bool VisitStmt_(const tir::BlockNode* op, const Stmt& other) final { - const auto* rhs = other.as(); + bool VisitStmt_(const tir::SBlockNode* op, const Stmt& other) final { + const auto* rhs = other.as(); // Check block equality. // All iter vars and buffer regions including the order should match. // When checking iter vars, DefEqual is used to remap variables. @@ -295,8 +295,8 @@ class ForMatcher : public TensorizeComparator { return VisitStmt(op->body, rhs->body); } - bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) final { - const auto* rhs = other.as(); + bool VisitStmt_(const SBlockRealizeNode* op, const Stmt& other) final { + const auto* rhs = other.as(); // Only allow trivial bindings for (size_t i = 0; i < op->iter_values.size(); ++i) { if (!op->iter_values[i].same_as(loop_stack_lhs_[i]->loop_var)) return false; @@ -448,7 +448,7 @@ class FunctionPartitioner : public StmtExprVisitor { /*! \brief alloc_buffers for the second function */ std::unordered_set allocs2; /*! \brief whether the current block is in the first function */ - ffi::Map block_partition; + ffi::Map block_partition; /*! \brief input buffers for the first function */ std::unordered_set input1; /*! \brief input buffers for the second function */ @@ -461,7 +461,7 @@ class FunctionPartitioner : public StmtExprVisitor { bool fail = false; private: - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { block_counter_++; bool is_matching_ = block_counter_ <= num_matched_ops_; if (block_counter_ == num_matched_ops_) { @@ -489,7 +489,7 @@ class FunctionPartitioner : public StmtExprVisitor { input2.insert(write->buffer); } } - block_partition.Set(ffi::GetRef(op), Bool(is_matching_)); + block_partition.Set(ffi::GetRef(op), Bool(is_matching_)); } // The number of matched ops in the function size_t num_matched_ops_; @@ -500,7 +500,7 @@ class FunctionPartitioner : public StmtExprVisitor { class BlockRemover : public StmtExprMutator { public: static Stmt RemoveBlockByPartition( - Stmt stmt, const ffi::Map& block_partition, + Stmt stmt, const ffi::Map& block_partition, const std::unordered_set& allocs, bool is_library_part) { BlockRemover remover(block_partition, allocs, is_library_part); @@ -508,17 +508,17 @@ class BlockRemover : public StmtExprMutator { } private: - BlockRemover(const ffi::Map& block_partition, + BlockRemover(const ffi::Map& block_partition, const std::unordered_set& allocs, bool is_library_part) : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {} - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - ObjectPtr n = ffi::make_object(*block.operator->()); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); + ObjectPtr n = ffi::make_object(*block.operator->()); if (op->name_hint != "root") { - ICHECK(block_partition.count(ffi::GetRef(op))); - bool block_is_library = block_partition[ffi::GetRef(op)]->value; + ICHECK(block_partition.count(ffi::GetRef(op))); + bool block_is_library = block_partition[ffi::GetRef(op)]->value; if (!(is_library_part_ ^ block_is_library)) { n->body = block->body; } else { @@ -532,7 +532,7 @@ class BlockRemover : public StmtExprMutator { } } n->alloc_buffers = alloc_buffers; - return Block(n); + return SBlock(n); } Stmt VisitStmt_(const SeqStmtNode* op) final { @@ -549,7 +549,7 @@ class BlockRemover : public StmtExprMutator { } bool erased_ = false; - ffi::Map block_partition; + ffi::Map block_partition; std::unordered_set allocs_; bool is_library_part_ = false; }; @@ -568,9 +568,9 @@ std::pair> SplitFunctions( PrimFunc func, std::vector>* arg_partition, ffi::Array patterns, FCodegen f_codegen) { // Step 1. Find the library kernel and the rest. - Stmt body = func->body.as()->block->body; + Stmt body = func->body.as()->block->body; ffi::Array match_results = - TIRPatternMatcher::Match(patterns, func->body.as()->block->body); + TIRPatternMatcher::Match(patterns, func->body.as()->block->body); if (match_results.empty()) { return {func, std::nullopt}; } diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 1da49c1d7de3..c3e1e3aebe6e 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -36,8 +36,9 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { public: explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} std::tuple, PrimFunc> Transform(const PrimFunc& func) { - ICHECK(func->body.as()) << "The body of the primfunc should be a root block."; - const auto& block = func->body.as()->block; + ICHECK(func->body.as()) + << "The body of the primfunc should be a root block."; + const auto& block = func->body.as()->block; visit_root_block(block.get()); if (layout_rewrite_preproc_stmts_.size() > 0) { return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func()); @@ -75,12 +76,12 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { << "There should be at least one layout rewrite preproc stmt."; Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] : SeqStmt(layout_rewrite_preproc_stmts_); - body = BlockRealize( + body = SBlockRealize( /*iter_values=*/ffi::Array(), /*predicate=*/const_true(), /*block=*/ - Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/"root", body)); + SBlock(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"root", body)); ffi::Map dict; for (const auto& [key, original_value] : original_func_->attrs->dict) { @@ -108,7 +109,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { // Step 2: Create the body for the new PrimFunc Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] : SeqStmt(compute_stmts_); - Block original_block = original_func_->body.as()->block; + SBlock original_block = original_func_->body.as()->block; ffi::Array alloc_buffers; for (const auto& buffer : original_block->alloc_buffers) { auto it = @@ -119,14 +120,14 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { } } - body = BlockRealize( + body = SBlockRealize( /*iter_values=*/ffi::Array(), /*predicate=*/const_true(), /*block=*/ - Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/"root", body, - /*init=*/std::nullopt, - /*alloc_buffers=*/alloc_buffers)); + SBlock(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"root", body, + /*init=*/std::nullopt, + /*alloc_buffers=*/alloc_buffers)); ffi::Map dict; for (const auto& [key, original_value] : original_func_->attrs->dict) { @@ -142,7 +143,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { return RenewDefs(func); } - void visit_root_block(const BlockNode* op) { + void visit_root_block(const SBlockNode* op) { Stmt body = op->body; if (const auto* seq_stmt = body.as()) { for (const auto& stmt : seq_stmt->seq) { @@ -162,8 +163,8 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { << "There should be a compute block if there is only one subtree under the root."; } } - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtMutator::VisitStmt_(op)); auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc); bool is_layout_rewrite_preproc = it != op->annotations.end() && is_one(Downcast((*it).second)); @@ -199,9 +200,9 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { auto new_annotations = op->annotations; new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc); - auto n = ffi::make_object(*block.get()); + auto n = ffi::make_object(*block.get()); n->annotations = new_annotations; - return Block(n); + return SBlock(n); } return block; } diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index acd1784c88f0..c57ca041b328 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -34,19 +34,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { RelaxFrameNode::RegisterReflection(); SeqExprFrameNode::RegisterReflection(); FunctionFrameNode::RegisterReflection(); - BlockFrameNode::RegisterReflection(); + BindingBlockFrameNode::RegisterReflection(); IfFrameNode::RegisterReflection(); ThenFrameNode::RegisterReflection(); ElseFrameNode::RegisterReflection(); } void SeqExprFrameNode::ExitWithScope() { - // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call - // its `ExitBlockFrame` and check if there is any more unended BlockFrame. - if (ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame()) { + // At this moment, there should be at most one BindingBlockFrame which hasn't ended. In this case, + // call its `ExitBindingBlockFrame` and check if there is any more unended BindingBlockFrame. + if (ffi::Optional block_frame = + IRBuilder::Current()->GetLastFrame()) { block_frame.value()->ExitWithScope(); - ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) - << "ValueError: There is some remaining BlockFrame that is not properly popped out."; + ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) + << "ValueError: There is some remaining BindingBlockFrame that is not properly popped out."; } RelaxFrameNode::ExitWithScope(); } @@ -105,14 +106,15 @@ void FunctionFrameNode::ExitWithScope() { } } -void BlockFrameNode::EnterWithScope() { +void BindingBlockFrameNode::EnterWithScope() { // Step 1. If the last frame is a block frame. The start of a new block frame marks the end of the // last block frame. - ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional block_frame = + IRBuilder::Current()->GetLastFrame(); if (block_frame.defined()) { block_frame.value()->ExitWithScope(); // Block frames cannot appear consecutively. - ICHECK(!IRBuilder::Current()->GetLastFrame()); + ICHECK(!IRBuilder::Current()->GetLastFrame()); } // Step 2. Deal with the new block frame. RelaxFrameNode::EnterWithScope(); @@ -147,7 +149,7 @@ class VarReplacer : public tvm::relax::ExprMutator { } }; -void BlockFrameNode::ExitWithScope() { +void BindingBlockFrameNode::ExitWithScope() { // Step 1. Pop the current frame out of the frame stack. RelaxFrameNode::ExitWithScope(); @@ -191,7 +193,7 @@ void BlockFrameNode::ExitWithScope() { // Step 4. Since we popped out any possible block frame when entering the "with" scope of the // current frame, the last frame cannot be a block frame. - ICHECK(!last_frame->IsInstance()); + ICHECK(!last_frame->IsInstance()); // Step 5. Push the block frame into the corresponding field of the last frame. if (const auto* seq_frame = last_frame.as()) { @@ -212,7 +214,7 @@ void BlockFrameNode::ExitWithScope() { void IfFrameNode::EnterWithScope() { const ffi::Array& frames = IRBuilder::Current()->frames; for (const IRBuilderFrame& frame : frames) { - const auto* block_frame = frame.as(); + const auto* block_frame = frame.as(); if (block_frame && block_frame->is_dataflow) { LOG(FATAL) << "ValueError: Cannot create an IfFrame inside a dataflow block."; } diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index db77d4db5b26..55f473a7ba0a 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -125,7 +125,7 @@ void FuncRetValue(const tvm::relax::Expr& value) { // a function body. Therefore if there is any unended block frame when dealing with function // return, we should end the block frame. - if (auto opt = ir_builder->GetLastFrame()) { + if (auto opt = ir_builder->GetLastFrame()) { auto block_frame = opt.value(); for (const auto& var : tvm::relax::FreeVars(normalized_value)) { if (var->IsInstance()) { @@ -159,23 +159,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { ///////////////////////////// BindingBlock ////////////////////////////// -BlockFrame Dataflow() { - ObjectPtr n = ffi::make_object(); +BindingBlockFrame Dataflow() { + ObjectPtr n = ffi::make_object(); n->is_dataflow = true; n->block_ended = false; - return BlockFrame(n); + return BindingBlockFrame(n); } -BlockFrame BindingBlock() { - ObjectPtr n = ffi::make_object(); +BindingBlockFrame BindingBlock() { + ObjectPtr n = ffi::make_object(); n->is_dataflow = false; n->block_ended = false; - return BlockFrame(n); + return BindingBlockFrame(n); } void DataflowBlockOutput(const ffi::Array& vars) { // Step 1. Check that we're in a Dataflow block that is not ended. - ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional block_frame = + IRBuilder::Current()->GetLastFrame(); CHECK(block_frame.defined() && block_frame.value()->is_dataflow) << "ValueError: `R.output` should appear inside a dataflow block. However, the current " "innermost block is not a dataflow block."; @@ -210,7 +211,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { tvm::relax::Var Emit(const tvm::relax::Expr& expr, const ffi::Optional& annotate_struct_info) { using tvm::relax::GetStructInfo; - BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); if (annotate_struct_info.defined()) { const auto& sinfo = annotate_struct_info.value(); @@ -229,7 +230,7 @@ tvm::relax::Var Emit(const tvm::relax::Expr& expr, tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, const tvm::relax::StructInfo& struct_info) { - BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); @@ -238,7 +239,7 @@ tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, } tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { - BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); block_builder->EmitNormalized(binding); block_frame->emitted_vars.push_back(binding->var); diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index e24b4a27593d..30ca9753d497 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -57,11 +57,12 @@ inline tvm::relax::BlockBuilder GetBlockBuilder() { return frame.value()->block_builder; } -inline BlockFrame CheckBlockFrameExistAndUnended() { +inline BindingBlockFrame CheckBindingBlockFrameExistAndUnended() { // We check if the current block is "ended" - if a block is ended, it is not allowed to emit new // bindings into this block, and we should throw exceptions. - ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional block_frame = + IRBuilder::Current()->GetLastFrame(); CHECK(block_frame.defined()) << "ValueError: Block frame not find"; CHECK(!block_frame.value()->block_ended) << "ValueError: New binding is not allowed after dataflow block output."; diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 7c10b6cdc8d1..2236e4f8b2bc 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -31,7 +31,7 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); PrimFuncFrameNode::RegisterReflection(); - BlockFrameNode::RegisterReflection(); + SBlockFrameNode::RegisterReflection(); BlockInitFrameNode::RegisterReflection(); ForFrameNode::RegisterReflection(); AssertFrameNode::RegisterReflection(); @@ -84,7 +84,7 @@ void PrimFuncFrameNode::ExitWithScope() { } } -void BlockFrameNode::ExitWithScope() { +void SBlockFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); ffi::Array tir_alloc_buffers; for (const tvm::tir::Buffer& buffer : alloc_buffers) { @@ -94,21 +94,21 @@ void BlockFrameNode::ExitWithScope() { if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) { attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access)); } - tvm::tir::Block block(iter_vars, reads.value_or(ffi::Array()), - writes.value_or(ffi::Array()), name, AsStmt(stmts), - init, tir_alloc_buffers, match_buffers, attrs); + tvm::tir::SBlock block(iter_vars, reads.value_or(ffi::Array()), + writes.value_or(ffi::Array()), name, AsStmt(stmts), + init, tir_alloc_buffers, match_buffers, attrs); if (no_realize) { CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`"; CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`"; AddToParent(block); } else { - AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block)); + AddToParent(tvm::tir::SBlockRealize(iter_values, predicate.value_or(Bool(true)), block)); } } void BlockInitFrameNode::EnterWithScope() { - BlockFrame frame = FindBlockFrame("T.init"); + SBlockFrame frame = FindSBlockFrame("T.init"); if (frame->init.defined()) { LOG(FATAL) << "ValueError: Duplicate block init declaration"; } @@ -117,7 +117,7 @@ void BlockInitFrameNode::EnterWithScope() { void BlockInitFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - BlockFrame frame = FindBlockFrame("T.init"); + SBlockFrame frame = FindSBlockFrame("T.init"); frame->init = AsStmt(stmts); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 00f9c28475b4..4b5d5d2fb121 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -142,11 +142,11 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, } LOG(FATAL) << "ValueError: Can not bind non-input param to buffer."; } else if (const auto* buffer_load = param.as()) { - BlockFrame frame = FindBlockFrame("T.match_buffer"); + SBlockFrame frame = FindSBlockFrame("T.match_buffer"); frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( buffer, BufferRegionFromLoad(ffi::GetRef(buffer_load)))); } else if (const auto* buffer_region = param.as()) { - BlockFrame frame = FindBlockFrame("T.match_buffer"); + SBlockFrame frame = FindSBlockFrame("T.match_buffer"); frame->match_buffers.push_back( tvm::tir::MatchBufferRegion(buffer, ffi::GetRef(buffer_region))); } else { @@ -155,8 +155,8 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, return buffer; } -BlockFrame Block(ffi::String name, bool no_realize) { - ObjectPtr n = ffi::make_object(); +SBlockFrame Block(ffi::String name, bool no_realize) { + ObjectPtr n = ffi::make_object(); n->name = name; n->iter_vars.clear(); n->reads = std::nullopt; @@ -168,13 +168,13 @@ BlockFrame Block(ffi::String name, bool no_realize) { n->iter_values.clear(); n->predicate = std::nullopt; n->no_realize = no_realize; - return BlockFrame(n); + return SBlockFrame(n); } BlockInitFrame Init() { return BlockInitFrame(ffi::make_object()); } void Where(PrimExpr predicate) { - BlockFrame frame = FindBlockFrame("T.where"); + SBlockFrame frame = FindSBlockFrame("T.where"); if (frame->predicate.defined()) { LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is " << frame->predicate; @@ -184,7 +184,7 @@ void Where(PrimExpr predicate) { void Reads(ffi::Array buffer_slices) { using namespace tvm::tir; - BlockFrame frame = FindBlockFrame("T.reads"); + SBlockFrame frame = FindSBlockFrame("T.reads"); if (frame->reads.defined()) { LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; } @@ -203,7 +203,7 @@ void Reads(ffi::Array buffer_slices) { void Writes(ffi::Array buffer_slices) { using namespace tvm::tir; - BlockFrame frame = FindBlockFrame("T.writes"); + SBlockFrame frame = FindSBlockFrame("T.writes"); if (frame->writes.defined()) { LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " << frame->writes; @@ -253,7 +253,7 @@ ffi::Map MergeAnnotations(const ffi::Map& ne } void BlockAttrs(ffi::Map attrs) { - BlockFrame frame = FindBlockFrame("T.block_attr"); + SBlockFrame frame = FindSBlockFrame("T.sblock_attr"); // Case 1: the block has no annotations, set the new annotations if (!frame->annotations.defined()) { frame->annotations = attrs; @@ -270,25 +270,25 @@ Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::Optional frame = builder->FindFrame()) { + if (ffi::Optional frame = builder->FindFrame()) { frame.value()->alloc_buffers.push_back(buffer); } else if (ffi::Optional frame = builder->GetLastFrame()) { frame.value()->root_alloc_buffers.push_back(buffer); } else { LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure " - "'T.alloc_buffer' is called under T.block() or T.prim_func()"; + "'T.alloc_buffer' is called under T.sblock() or T.prim_func()"; } return buffer; } namespace axis { IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { - if (ffi::Optional opt_frame = IRBuilder::Current()->GetLastFrame()) { - BlockFrame frame = opt_frame.value(); + if (ffi::Optional opt_frame = IRBuilder::Current()->GetLastFrame()) { + SBlockFrame frame = opt_frame.value(); frame->iter_vars.push_back(iter_var); frame->iter_values.push_back(binding); } else { - LOG(FATAL) << "TypeError: The last frame is not BlockFrame"; + LOG(FATAL) << "TypeError: The last frame is not SBlockFrame"; } return iter_var; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index d7c272ae5138..cabf418a10af 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -76,21 +76,21 @@ inline PrimFuncFrame FindPrimFuncFrame(const ffi::String& method) { } /*! - * \brief Check whether the top frame in IRBuilder frame stack is BlockFrame. + * \brief Check whether the top frame in IRBuilder frame stack is SBlockFrame. * \param method The method name to be printed when throwing exception. - * \return The top frame of BlockFrame. + * \return The top frame of SBlockFrame. */ -inline BlockFrame FindBlockFrame(const ffi::String& method) { - if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { +inline SBlockFrame FindSBlockFrame(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); - } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). " + } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { + LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.sblock(). " << "While " << method << " did occur within the block \"" << frame.value()->name - << "\", other frames (e.g. if/else/let) had been introduced since the T.block(\"" + << "\", other frames (e.g. if/else/let) had been introduced since the T.sblock(\"" << frame.value()->name << "\") frame"; } else { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(), " - << "but " << method << " occurred outside of any T.block() frame"; + LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.sblock(), " + << "but " << method << " occurred outside of any T.sblock() frame"; } throw; } diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 1a33d760a9d5..a5b6141dd040 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -22,12 +22,12 @@ namespace tvm { namespace script { namespace printer { -Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // - ffi::Optional opt_realize, +Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // + ffi::Optional opt_realize, ffi::Optional opt_realize_p) { With frame(d, block); ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); - const tir::BlockRealizeNode* realize = + const tir::SBlockRealizeNode* realize = opt_realize.defined() ? opt_realize.value().get() : nullptr; AccessPath realize_p = *opt_realize_p; // Step 1. Handle block var and block bindings @@ -174,7 +174,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // // Step 4. Handle block attributes if (!block->annotations.empty()) { (*frame)->stmts.push_back(ExprStmtDoc( - TIR(d, "block_attr") + TIR(d, "sblock_attr") ->Call({d->AsDoc(block->annotations, block_p->Attr("annotations"))}))); } // Step 5. Handle `alloc_buffer` @@ -210,15 +210,15 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // kwargs_values.push_back(LiteralDoc::Boolean(true, std::nullopt)); } return ScopeDoc(std::nullopt, - TIR(d, "block") // + TIR(d, "sblock") // ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))}, kwargs_keys, kwargs_values), (*frame)->stmts); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](tir::BlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](tir::SBlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p); // since we do not have d->AsDoc for realize->block, // we should add possible doc decoration manually. @@ -227,12 +227,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Block block, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::SBlock block, AccessPath p, IRDocsifier d) -> Doc { return PrintBlock(d, block, p, std::nullopt, std::nullopt); }); -TVM_SCRIPT_REPR(tir::BlockNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::BlockRealizeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SBlockNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tir::SBlockRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index c5083b57c2d0..bfa999a5c68d 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -143,19 +143,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // Step 4. Handle `func->body` - ffi::Optional implicit_root_block = [&]() -> ffi::Optional { - const tir::BlockRealizeNode* root_block_realize = func->body.as(); + ffi::Optional implicit_root_block = [&]() -> ffi::Optional { + const tir::SBlockRealizeNode* root_block_realize = func->body.as(); if (root_block_realize && !root_block_realize->iter_values.size() && tir::is_one(root_block_realize->predicate)) { - tir::Block root_block = root_block_realize->block; + tir::SBlock root_block = root_block_realize->block; if (!root_block->annotations.size() && !root_block->match_buffers.size() && !root_block->reads.size() && !root_block->writes.size() && !root_block->init.defined()) { - const tir::BlockRealizeNode* block_realize = - root_block->body.as(); + const tir::SBlockRealizeNode* block_realize = + root_block->body.as(); if (root_block->alloc_buffers.size() || (block_realize && block_realize->block->iter_vars.size()) || - (!block_realize && tir::ContainsNode(root_block->body))) { + (!block_realize && tir::ContainsNode(root_block->body))) { return root_block; } } @@ -163,9 +163,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return std::nullopt; }(); if (d->cfg->syntax_sugar && implicit_root_block) { - tir::Block root_block = implicit_root_block.value(); + tir::SBlock root_block = implicit_root_block.value(); AccessPath root_block_p = p->Attr("body")->Attr("block"); - (*f)->stmts.push_back(CommentDoc("with T.block(\"root\"):")); + (*f)->stmts.push_back(CommentDoc("with T.sblock(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { tir::Buffer buffer = root_block->alloc_buffers[i]; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index fa84ab3863fb..2b326c745a75 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -149,9 +149,9 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices); } - Stmt VisitStmt_(const BlockNode* _block) final { - Block block = Downcast(StmtMutator::VisitStmt_(_block)); - BlockNode* n = block.CopyOnWrite(); + Stmt VisitStmt_(const SBlockNode* _block) final { + SBlock block = Downcast(StmtMutator::VisitStmt_(_block)); + SBlockNode* n = block.CopyOnWrite(); if (auto opt_ann = n->annotations.Get(topi_attr)) { ffi::Array new_buffers; for (Buffer buffer : Downcast>(opt_ann.value())) { @@ -541,18 +541,18 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Stmt init = GenerateInitStmt(leaf.store_indices, buffers, reduce, leaf.axes_remap, info); Stmt body = GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, expr_body, info, analyzer); - seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings, - /*predicate=*/Bool(true), - /*block=*/ - Block(/*iter_vars=*/leaf.block_iters, - /*reads=*/{}, - /*writes=*/{}, - /*name_hint=*/info->FreshName(compute_op->name), - /*body=*/body, - /*init=*/init, - /*alloc_buffers=*/{}, - /*match_buffers=*/{}, - /*annotations=*/annotations))); + seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, + /*predicate=*/Bool(true), + /*block=*/ + SBlock(/*iter_vars=*/leaf.block_iters, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/info->FreshName(compute_op->name), + /*body=*/body, + /*init=*/init, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations))); } else { for (int i = 0; i < compute_op->num_outputs(); ++i) { @@ -563,18 +563,18 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in PrimExpr expr_body = compute_op->body[i]; Stmt body = GenerateBodyStmt(leaf.store_indices, {buffers[i]}, leaf.axes_remap, expr_body, info, analyzer); - seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings, - /*predicate=*/Bool(true), - /*block=*/ - Block(/*iter_vars=*/leaf.block_iters, - /*reads=*/{}, - /*writes=*/{}, - /*name_hint=*/info->FreshName(buffers[i]->name), - /*body=*/body, - /*init=*/std::nullopt, - /*alloc_buffers=*/{}, - /*match_buffers=*/{}, - /*annotations=*/annotations))); + seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, + /*predicate=*/Bool(true), + /*block=*/ + SBlock(/*iter_vars=*/leaf.block_iters, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/info->FreshName(buffers[i]->name), + /*body=*/body, + /*init=*/std::nullopt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations))); } } Stmt body = SeqStmt::Flatten(seq_stmt); @@ -596,18 +596,18 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in } // wrap nested block - body = BlockRealize(/*iter_values=*/cur.bindings, - /*predicate=*/Bool(true), - /*block=*/ - Block(/*iter_vars=*/block_iters, - /*reads=*/{}, - /*writes=*/{}, - /*name_hint=*/block_name, - /*body=*/body, - /*init=*/init, - /*alloc_buffers=*/{}, - /*match_buffers=*/{}, - /*annotations=*/annotations)); + body = SBlockRealize(/*iter_values=*/cur.bindings, + /*predicate=*/Bool(true), + /*block=*/ + SBlock(/*iter_vars=*/block_iters, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/block_name, + /*body=*/body, + /*init=*/init, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations)); } for (size_t j = cur.loop_vars.size(); j > 0; --j) { const auto& [loop_var, dom] = cur.loop_vars[j - 1]; @@ -646,7 +646,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf // be generated with the later application of "script.Complete" in // GenerateAndCompletePrimFunc. Waiting until later also handles // the case where there is only a single BlockNode, which then - // becomes the root Block of the function, and should not have + // becomes the root SBlock of the function, and should not have // reads/writes filled in. BufferSubstituter substituter(var_map, input_buffer_map); @@ -656,18 +656,18 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf Stmt body = transformer(substituted_body); // Step 4. Generate opaque block as body. - return BlockRealize(/*iter_values=*/{}, - /*predicate=*/Bool(true), - /*block=*/ - Block(/*iter_vars=*/{}, - /*reads=*/{}, - /*writes=*/{}, - /*name_hint=*/info->FreshName(extern_op->name), - /*body=*/std::move(body), - /*init=*/std::nullopt, - /*alloc_buffers=*/{}, - /*match_buffers=*/{}, - /*annotations=*/extern_op->attrs)); + return SBlockRealize(/*iter_values=*/{}, + /*predicate=*/Bool(true), + /*block=*/ + SBlock(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/info->FreshName(extern_op->name), + /*body=*/std::move(body), + /*init=*/std::nullopt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/extern_op->attrs)); } ffi::Array CollectOrderedOps(const ffi::Array& arg_list) { diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index aca06ad595bc..cbd40a57a398 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -114,7 +114,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; - void VisitStmt_(const BlockRealizeNode* op) override; + void VisitStmt_(const SBlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const LetStmtNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; @@ -123,7 +123,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { - const auto* block = stmt.as(); + const auto* block = stmt.as(); ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { const Var& target_var = match_buffer->buffer->data; @@ -253,7 +253,7 @@ void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { StmtVisitor::VisitStmt_(op); } -void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { +void BlockReadWriteDetector::VisitStmt_(const SBlockRealizeNode* op) { /*! \note detector will not visit child block recursively, so it will stop here */ std::unordered_map vmap; for (size_t i = 0; i < op->block->iter_vars.size(); ++i) { @@ -371,8 +371,8 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { } } -ffi::Array> GetBlockAccessRegion( - const Block& block, const ffi::Map& buffer_var_map) { +ffi::Array> GetSBlockAccessRegion( + const SBlock& block, const ffi::Map& buffer_var_map) { BlockReadWriteDetector detector(buffer_var_map); detector(block); ffi::Array writes = detector.CollectWrites(); @@ -388,8 +388,8 @@ ffi::Array> GetBlockAccessRegion( return {reads, writes, opaques}; } -ffi::Array> GetBlockReadWriteRegion( - const Block& block, const ffi::Map& buffer_var_map) { +ffi::Array> GetSBlockReadWriteRegion( + const SBlock& block, const ffi::Map& buffer_var_map) { BlockReadWriteDetector detector(buffer_var_map); detector(block); ffi::Array opaques = detector.CollectOpaques(); @@ -414,8 +414,8 @@ ffi::Array> GetBlockReadWriteRegion( TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.analysis.GetBlockAccessRegion", GetBlockAccessRegion) - .def("tir.analysis.GetBlockReadWriteRegion", GetBlockReadWriteRegion); + .def("tir.analysis.GetSBlockAccessRegion", GetSBlockAccessRegion) + .def("tir.analysis.GetSBlockReadWriteRegion", GetSBlockReadWriteRegion); } } // namespace tir diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 67e8bda6f670..467f8123c4e1 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -107,8 +107,8 @@ class LCADetector : public StmtExprVisitor { loop_scope_map_.erase(op->loop_var.get()); } - void VisitStmt_(const BlockRealizeNode* op) final { - const BlockNode* block = op->block.get(); + void VisitStmt_(const SBlockRealizeNode* op) final { + const SBlockNode* block = op->block.get(); int n = ancestor_scopes_.size(); for (const Buffer& buf : block->alloc_buffers) { buffer_var_map_.emplace(buf->data.get(), buf.get()); @@ -137,7 +137,7 @@ class LCADetector : public StmtExprVisitor { ancestor_scopes_.pop_back(); } - void UpdateDominateScopeOfNonDataParIter(const BlockRealizeNode* block_realize) { + void UpdateDominateScopeOfNonDataParIter(const SBlockRealizeNode* block_realize) { // map iter var to the scope which dominate all loop carried dependencies. std::unordered_map opaque_var_scope; // maintain highest scope which dominate all reduce loop iters. null denotes non-reduce block. @@ -168,7 +168,7 @@ class LCADetector : public StmtExprVisitor { // collect non-data-parallel block iteration's dominate scope. // for reduction iter type, we maintain the highest dominate scope for all reduce iters. // for other iter type, we maintain the dict for each individual iter. - const Block& block = block_realize->block; + const SBlock& block = block_realize->block; bool is_reduce_block = false; for (size_t i = 0; i < block_realize->iter_values.size(); ++i) { const IterVar& iter_var = block->iter_vars[i]; @@ -324,7 +324,7 @@ class LCADetector : public StmtExprVisitor { return lhs; } - /*! \brief The ancestor scope stacks info (Block and For). The + /*! \brief The ancestor scope stacks info (SBlock and For). The * first element is initialized in LCADetector::Detect to represent * the root scope. */ diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 8d001dd1e459..f498c039ef9e 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -1422,7 +1422,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional flow_fr // Predecessors, if any, are unvisited. return {}; } else if (block.predecessors.size() == 1) { - // Block has only a single predecessor + // SBlock has only a single predecessor return states[0]; } @@ -1553,7 +1553,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ // Successors, if any, are unvisited. return {}; } else if (block.successors.size() == 1) { - // Block has only a single successor + // SBlock has only a single successor return states[0]; } diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 3fe33cdf2af2..6957ee578c90 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -145,10 +145,10 @@ class FlopEstimator : private ExprFunctor, return result; } TResult VisitStmt_(const BufferStoreNode* store) override { return VisitExpr(store->value); } - TResult VisitStmt_(const BlockRealizeNode* block) override { + TResult VisitStmt_(const SBlockRealizeNode* block) override { return VisitStmt(block->block->body); } - TResult VisitStmt_(const BlockNode* block) override { + TResult VisitStmt_(const SBlockNode* block) override { TResult result; if (block->init.defined()) { result += VisitStmt(block->init.value()); diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 9f6f4da7eaf3..58879277e9d9 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -64,22 +64,22 @@ const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) return nullptr; } -Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) { +Stmt GetEnclosingLoop(const SBlockNode* block, Stmt func_body) { struct GetRootSeqStmt : public StmtVisitor { void VisitStmt_(const SeqStmtNode* seq) override { result = seq; } const SeqStmtNode* result; }; struct BlockFinder : public StmtVisitor { - explicit BlockFinder(const BlockNode* tgt) : target(tgt) {} + explicit BlockFinder(const SBlockNode* tgt) : target(tgt) {} - void VisitStmt_(const BlockNode* block) override { + void VisitStmt_(const SBlockNode* block) override { if (block == target) { found = true; } } - const BlockNode* target; + const SBlockNode* target; bool found = false; }; @@ -98,23 +98,23 @@ Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) { } } - LOG(FATAL) << "Enclosing loop not found for a block " << ffi::GetRef(block); + LOG(FATAL) << "Enclosing loop not found for a block " << ffi::GetRef(block); TVM_FFI_UNREACHABLE(); } -const BlockNode* FindAnchorBlock(const IRModule& mod) { - struct ReductionBlockCollector : public StmtVisitor { - void VisitStmt_(const BlockNode* block) override { +const SBlockNode* FindAnchorBlock(const IRModule& mod) { + struct ReductionSBlockCollector : public StmtVisitor { + void VisitStmt_(const SBlockNode* block) override { if (block->init) { blocks.push_back(block); } StmtVisitor::VisitStmt(block->body); } - std::vector blocks; + std::vector blocks; }; if (auto prim_func = FindEntryFunc(mod, nullptr)) { - ReductionBlockCollector collector; + ReductionSBlockCollector collector; collector(prim_func->body); const auto& candidates = collector.blocks; @@ -142,12 +142,12 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.analysis.find_anchor_block", [](const IRModule& mod) { + refl::GlobalDef().def("tir.analysis.find_anchor_sblock", [](const IRModule& mod) { auto ret = FindAnchorBlock(mod); if (ret) { - return ffi::Optional(ffi::GetRef(ret)); + return ffi::Optional(ffi::GetRef(ret)); } - return ffi::Optional(std::nullopt); + return ffi::Optional(std::nullopt); }); } diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index c10931d1bd10..e50b60d55c7e 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -176,7 +176,7 @@ class BlockVarAccessVerifier : public StmtExprVisitor { loop_vars_.erase(op->loop_var.get()); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { // Do not check boundary if it's a opaque block. bool is_non_opaque = op->iter_vars.size(); if (is_non_opaque) { @@ -218,7 +218,7 @@ class BlockVarAccessVerifier : public StmtExprVisitor { /*! \brief Whether it's in assert mode. */ bool assert_mode_; /*! \brief Current nested block stack level. */ - std::vector block_stack_; + std::vector block_stack_; /*! \brief Whether there is error. */ bool has_error_{false}; }; diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index 3cda278d0a71..deff65e8cf19 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -27,7 +27,7 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { BlockDependenceInfoNode::RegisterReflection(); } /** - * @brief A helper class to collect and build Block Dependences using BlockScope class + * @brief A helper class to collect and build SBlock Dependences using SBlockScope class */ class BlockDependenceInfoCollector : private StmtVisitor { public: @@ -41,19 +41,19 @@ class BlockDependenceInfoCollector : private StmtVisitor { block_frames_.emplace_back(); } - void MakeBlockScope(StmtSRef scope) { + void MakeSBlockScope(StmtSRef scope) { ffi::Array child_block_srefs = std::move(block_frames_.back()); - self_->sref2scope[scope] = BlockScope(child_block_srefs); + self_->sref2scope[scope] = SBlockScope(child_block_srefs); } - void VisitStmt_(const BlockRealizeNode* realize) final { + void VisitStmt_(const SBlockRealizeNode* realize) final { block_frames_.emplace_back(); - const BlockNode* block = realize->block.get(); + const SBlockNode* block = realize->block.get(); // Recursive visit VisitStmt(block->body); // `block->init` is not visited - // Create BlockInfo for the block + // Create SBlockInfo for the block auto sref = self_->stmt2ref.at(block); - MakeBlockScope(sref); + MakeSBlockScope(sref); // Update parent scope block_frames_.pop_back(); block_frames_.back().push_back(sref); @@ -90,10 +90,11 @@ BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.BlockDependenceInfo", + .def("tir.SBlockDependenceInfo", [](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }) - .def_method("tir.BlockDependenceInfoGetBlockScope", &BlockDependenceInfoNode::GetBlockScope) - .def("tir.BlockDependenceInfoGetSRef", + .def_method("tir.SBlockDependenceInfoGetSBlockScope", + &BlockDependenceInfoNode::GetSBlockScope) + .def("tir.SBlockDependenceInfoGetSRef", [](BlockDependenceInfo self, Stmt stmt) -> ffi::Optional { auto it = self->stmt2ref.find(stmt.get()); return it != self->stmt2ref.end() ? it->second : ffi::Optional(std::nullopt); diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index 676f162076ce..8b2675936f92 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -26,7 +26,7 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { StmtSRefNode::RegisterReflection(); DependencyNode::RegisterReflection(); - BlockScopeNode::RegisterReflection(); + SBlockScopeNode::RegisterReflection(); } /******** Utility functions ********/ @@ -41,7 +41,7 @@ using SMap = std::unordered_map; * \param kind Type of the dependency * \note This method is effectively NOP on self-loops */ -void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) { +void AddDependency(SBlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) { if (!src.same_as(dst)) { Dependency dep(src, dst, kind); self->src2deps[src].push_back(dep); @@ -77,14 +77,14 @@ Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) { data_ = std::move(node); } -BlockScope::BlockScope() { data_ = ffi::make_object(); } +SBlockScope::SBlockScope() { data_ = ffi::make_object(); } -BlockScope::BlockScope(const ffi::Array& child_block_srefs) { - ObjectPtr n = ffi::make_object(); +SBlockScope::SBlockScope(const ffi::Array& child_block_srefs) { + ObjectPtr n = ffi::make_object(); SMap> buffer_readers; SMap>& buffer_writers = n->buffer_writers; for (const StmtSRef& child_block_sref : child_block_srefs) { - const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); + const SBlockNode* child_block = TVM_SREF_TO_SBLOCK(child_block_sref); // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer for (const BufferRegion& region : child_block->reads) { buffer_readers[region->buffer].push_back(child_block_sref); @@ -125,7 +125,7 @@ BlockScope::BlockScope(const ffi::Array& child_block_srefs) { /******** Dependency ********/ -ffi::Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { +ffi::Array SBlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { auto iter = this->src2deps.find(block_sref); if (iter != this->src2deps.end()) { return iter->second; @@ -134,7 +134,7 @@ ffi::Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) } } -ffi::Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { +ffi::Array SBlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { auto iter = this->dst2deps.find(block_sref); if (iter != this->dst2deps.end()) { return iter->second; @@ -178,8 +178,8 @@ void SRefTreeCreator::VisitStmt_(const ForNode* loop) { } } -void SRefTreeCreator::VisitStmt_(const BlockRealizeNode* realize) { - const BlockNode* block = realize->block.get(); +void SRefTreeCreator::VisitStmt_(const SBlockRealizeNode* realize) { + const SBlockNode* block = realize->block.get(); PushSRef(block); VisitStmt(block->body); // `block->init` is not visited PopAndRecordSRef(); @@ -206,8 +206,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("tir.StmtSRefRootMark", StmtSRef::RootMark) .def("tir.StmtSRefInlineMark", StmtSRef::InlineMark) - .def_method("tir.BlockScopeGetDepsBySrc", &BlockScopeNode::GetDepsBySrc) - .def_method("tir.BlockScopeGetDepsByDst", &BlockScopeNode::GetDepsByDst); + .def_method("tir.SBlockScopeGetDepsBySrc", &SBlockScopeNode::GetDepsBySrc) + .def_method("tir.SBlockScopeGetDepsByDst", &SBlockScopeNode::GetDepsByDst); } } // namespace tir diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 393ac7ee57d0..d6d7c1e60998 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -50,8 +50,8 @@ Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { return For(n); } -Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { - BlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); +Stmt DataTypeLegalizer::VisitStmt_(const SBlockRealizeNode* op) { + SBlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); ffi::Array new_iter_values; bool changed = false; for (int i = 0; i < static_cast(op->iter_values.size()); ++i) { @@ -69,8 +69,8 @@ Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { return realize; } -Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { - Block new_block = Downcast(StmtExprMutator::VisitStmt_(op)); +Stmt DataTypeLegalizer::VisitStmt_(const SBlockNode* op) { + SBlock new_block = Downcast(StmtExprMutator::VisitStmt_(op)); ffi::Array new_iter_vars = MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) { auto dtype = iter->var.dtype(); @@ -302,7 +302,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) { return decl_buffer; } -Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { +Stmt IndexDataTypeRewriter::VisitStmt_(const SBlockRealizeNode* op) { bool is_condition = is_condition_; is_condition_ = true; auto new_predicate = VisitExpr(op->predicate); @@ -313,10 +313,10 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { auto new_iter_values = op->iter_values.Map([this](const PrimExpr& e) { return this->VisitExpr(e); }); is_enabled_ = is_enabled; - Block new_body = Downcast(this->VisitStmt(op->block)); + SBlock new_body = Downcast(this->VisitStmt(op->block)); if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) || !new_body.same_as(op->block)) { - BlockRealize new_block_realize = ffi::GetRef(op); + SBlockRealize new_block_realize = ffi::GetRef(op); auto* n = new_block_realize.CopyOnWrite(); n->predicate = std::move(new_predicate); n->iter_values = std::move(new_iter_values); @@ -328,7 +328,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { } } -Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { +Stmt IndexDataTypeRewriter::VisitStmt_(const SBlockNode* op) { ffi::Array new_alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); }); ffi::Array new_match_buffers = @@ -360,8 +360,8 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) || !new_writes.same_as(op->writes) || new_iter_vars.same_as(op->iter_vars) || !new_annotations.same_as(op->annotations)) { - Block new_block = ffi::GetRef(op); - BlockNode* n = new_block.CopyOnWrite(); + SBlock new_block = ffi::GetRef(op); + SBlockNode* n = new_block.CopyOnWrite(); n->alloc_buffers = std::move(new_alloc_buffers); n->match_buffers = std::move(new_match_buffers); n->reads = std::move(new_reads); diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 61bdfb15e70e..990aa87717c1 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -196,10 +196,10 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { ffi::Function f_visit_seq_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */ ffi::Function f_visit_evaluate{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */ + /*! \brief The packed function to the `VisitStmt_(const SBlockNode* op)` function. */ ffi::Function f_visit_block{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */ - ffi::Function f_visit_block_realize{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const SBlockRealizeNode* op)` function. */ + ffi::Function f_visit_sblock_realize{nullptr}; using StmtExprVisitor::VisitExpr; using StmtExprVisitor::VisitStmt; @@ -237,8 +237,8 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { PY_STMT_VISITOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt); PY_STMT_VISITOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt); PY_STMT_VISITOR_DISPATCH(EvaluateNode, f_visit_evaluate); - PY_STMT_VISITOR_DISPATCH(BlockNode, f_visit_block); - PY_STMT_VISITOR_DISPATCH(BlockRealizeNode, f_visit_block_realize); + PY_STMT_VISITOR_DISPATCH(SBlockNode, f_visit_block); + PY_STMT_VISITOR_DISPATCH(SBlockRealizeNode, f_visit_sblock_realize); // Expression functions PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var); PY_EXPR_VISITOR_DISPATCH(SizeVarNode, f_visit_size_var); @@ -330,8 +330,8 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { PY_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(EvaluateNode); - PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockNode); - PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockRealizeNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(SBlockNode); + PY_STMT_VISITOR_DEFAULT_DISPATCH(SBlockRealizeNode); vtable.Finalize(); return vtable; } @@ -362,7 +362,7 @@ class PyStmtExprVisitor : public ObjectRef { ffi::Function f_visit_seq_stmt, // ffi::Function f_visit_evaluate, // ffi::Function f_visit_block, // - ffi::Function f_visit_block_realize, // + ffi::Function f_visit_sblock_realize, // ffi::Function f_visit_var, // ffi::Function f_visit_size_var, // ffi::Function f_visit_buffer_load, // @@ -414,7 +414,7 @@ class PyStmtExprVisitor : public ObjectRef { n->f_visit_seq_stmt = std::move(f_visit_seq_stmt); n->f_visit_evaluate = std::move(f_visit_evaluate); n->f_visit_block = std::move(f_visit_block); - n->f_visit_block_realize = std::move(f_visit_block_realize); + n->f_visit_sblock_realize = std::move(f_visit_sblock_realize); // Set expression functions n->f_visit_var = std::move(f_visit_var); n->f_visit_size_var = std::move(f_visit_size_var); @@ -563,10 +563,10 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { ffi::Function f_visit_seq_stmt{nullptr}; /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */ ffi::Function f_visit_evaluate{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */ + /*! \brief The packed function to the `VisitStmt_(const SBlockNode* op)` function. */ ffi::Function f_visit_block{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */ - ffi::Function f_visit_block_realize{nullptr}; + /*! \brief The packed function to the `VisitStmt_(const SBlockRealizeNode* op)` function. */ + ffi::Function f_visit_sblock_realize{nullptr}; using StmtExprMutator::VisitExpr; using StmtExprMutator::VisitStmt; @@ -604,8 +604,8 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { PY_STMT_MUTATOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt); PY_STMT_MUTATOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt); PY_STMT_MUTATOR_DISPATCH(EvaluateNode, f_visit_evaluate); - PY_STMT_MUTATOR_DISPATCH(BlockNode, f_visit_block); - PY_STMT_MUTATOR_DISPATCH(BlockRealizeNode, f_visit_block_realize); + PY_STMT_MUTATOR_DISPATCH(SBlockNode, f_visit_block); + PY_STMT_MUTATOR_DISPATCH(SBlockRealizeNode, f_visit_sblock_realize); // Expression functions PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var); PY_EXPR_MUTATOR_DISPATCH(SizeVarNode, f_visit_size_var); @@ -697,8 +697,8 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { PY_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(EvaluateNode); - PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockNode); - PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockRealizeNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(SBlockNode); + PY_STMT_MUTATOR_DEFAULT_DISPATCH(SBlockRealizeNode); vtable.Finalize(); return vtable; } @@ -730,7 +730,7 @@ class PyStmtExprMutator : public ObjectRef { ffi::Function f_visit_seq_stmt, // ffi::Function f_visit_evaluate, // ffi::Function f_visit_block, // - ffi::Function f_visit_block_realize, // + ffi::Function f_visit_sblock_realize, // ffi::Function f_visit_var, // ffi::Function f_visit_size_var, // ffi::Function f_visit_buffer_load, // @@ -782,7 +782,7 @@ class PyStmtExprMutator : public ObjectRef { n->f_visit_seq_stmt = std::move(f_visit_seq_stmt); n->f_visit_evaluate = std::move(f_visit_evaluate); n->f_visit_block = std::move(f_visit_block); - n->f_visit_block_realize = std::move(f_visit_block_realize); + n->f_visit_sblock_realize = std::move(f_visit_sblock_realize); // Expression functions n->f_visit_var = std::move(f_visit_var); n->f_visit_size_var = std::move(f_visit_size_var); diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index bf2b333f2501..4c2ccab58e10 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -41,7 +41,7 @@ class ScriptCompleter : public StmtMutator { private: ffi::Map* buffer_var_map_; - Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt VisitStmt_(const SBlockRealizeNode* op) final { for (const PrimExpr& value : op->iter_values) { CHECK(value.dtype().is_int()) << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); @@ -49,7 +49,7 @@ class ScriptCompleter : public StmtMutator { return StmtMutator::VisitStmt_(op); } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { // Buffers allocated in the block can be accessed by its body. for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); @@ -61,7 +61,7 @@ class ScriptCompleter : public StmtMutator { bool is_root_block = this->is_root_block_; this->is_root_block_ = false; - Block block = Downcast(StmtMutator::VisitStmt_(op)); + SBlock block = Downcast(StmtMutator::VisitStmt_(op)); this->is_root_block_ = is_root_block; // Remove buffers allocated inside block to detect its access region @@ -81,7 +81,7 @@ class ScriptCompleter : public StmtMutator { } // ignore root block or blocks which already has reads/writes regions if (mask != 0) { - auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); + auto access_region = GetSBlockAccessRegion(block, *buffer_var_map_); const ffi::Array& reads = access_region[0]; const ffi::Array& writes = access_region[1]; const ffi::Array& opaque = access_region[2]; @@ -95,7 +95,7 @@ class ScriptCompleter : public StmtMutator { } n->annotations = op->annotations; n->annotations.erase(attr::script_parsing_detect_access); - return Block(n); + return SBlock(n); } else { return block; } @@ -134,19 +134,19 @@ PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates) if (root_allocates.size()) { return true; } - auto* block_realize = func->body.as(); + auto* block_realize = func->body.as(); if (block_realize && block_realize->block->iter_vars.size()) { return true; } - if (!block_realize && ContainsNode(func->body)) { + if (!block_realize && ContainsNode(func->body)) { return true; } return false; }(); if (should_insert_root) { - Block root_block({}, {}, {}, "root", std::move(res), std::nullopt, root_allocates); - res = BlockRealize({}, Bool(true), std::move(root_block)); + SBlock root_block({}, {}, {}, "root", std::move(res), std::nullopt, root_allocates); + res = SBlockRealize({}, Bool(true), std::move(root_block)); } // generate surrounding loops automatically diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 083dd8dedf31..de314257706c 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -113,14 +113,14 @@ class PrimFuncSpecializer : public StmtExprMutator { } private: - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { // Step.0. Define buffer mappings which is allocated inside the block ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); }); // Step.1. Recursively visit block body Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); ICHECK(op != nullptr); ffi::Array reads = @@ -130,9 +130,9 @@ class PrimFuncSpecializer : public StmtExprMutator { if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && writes.same_as(op->writes)) { - return ffi::GetRef(op); + return ffi::GetRef(op); } else { - ObjectPtr n = CopyOnWrite(op); + ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); n->reads = std::move(reads); n->writes = std::move(writes); @@ -296,7 +296,7 @@ class PrimFuncSpecializer : public StmtExprMutator { << "(see discussion on https://github.com/apache/tvm/pull/14565 for more details). " << "Please add a definition for this buffer, " << "either in the PrimFunc's buffer_map, " - << "in a tir::Block's alloc_buffer, " + << "in a tir::SBlock's alloc_buffer, " << "or in a DeclBuffer statement."; return old_buffer; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b7e28e84e748..d332741eeacf 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -49,8 +49,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { WhileNode::RegisterReflection(); BufferRegionNode::RegisterReflection(); MatchBufferRegionNode::RegisterReflection(); - BlockNode::RegisterReflection(); - BlockRealizeNode::RegisterReflection(); + SBlockNode::RegisterReflection(); + SBlockRealizeNode::RegisterReflection(); } // LetStmt @@ -660,12 +660,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // Block -Block::Block(ffi::Array iter_vars, ffi::Array reads, - ffi::Array writes, ffi::String name_hint, Stmt body, - ffi::Optional init, ffi::Array alloc_buffers, - ffi::Array match_buffers, ffi::Map annotations, - Span span) { - ObjectPtr node = ffi::make_object(); +SBlock::SBlock(ffi::Array iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init, ffi::Array alloc_buffers, + ffi::Array match_buffers, ffi::Map annotations, + Span span) { + ObjectPtr node = ffi::make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); node->writes = std::move(writes); @@ -681,25 +681,25 @@ Block::Block(ffi::Array iter_vars, ffi::Array reads, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Block", + refl::GlobalDef().def("tir.SBlock", [](ffi::Array iter_vars, ffi::Array reads, ffi::Array writes, ffi::String name_hint, Stmt body, ffi::Optional init, ffi::Array alloc_buffers, ffi::Array match_buffers, ffi::Map annotations, Span span) { - return Block(iter_vars, reads, writes, name_hint, body, init, - alloc_buffers, match_buffers, annotations, span); + return SBlock(iter_vars, reads, writes, name_hint, body, init, + alloc_buffers, match_buffers, annotations, span); }); } // BlockRealize -BlockRealize::BlockRealize(ffi::Array values, PrimExpr predicate, Block block, - Span span) { +SBlockRealize::SBlockRealize(ffi::Array values, PrimExpr predicate, SBlock block, + Span span) { CHECK_EQ(block->iter_vars.size(), values.size()) << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values"; CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1)) << "TypeError: Expect Block.predicate to be a bool expression"; - ObjectPtr node = ffi::make_object(); + ObjectPtr node = ffi::make_object(); node->iter_values = std::move(values); node->predicate = std::move(predicate); node->block = std::move(block); @@ -709,9 +709,9 @@ BlockRealize::BlockRealize(ffi::Array values, PrimExpr predicate, Bloc TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BlockRealize", [](ffi::Array iter_values, PrimExpr predicate, - Block block, Span span) { - return BlockRealize(iter_values, predicate, block, span); + refl::GlobalDef().def("tir.SBlockRealize", [](ffi::Array iter_values, + PrimExpr predicate, SBlock block, Span span) { + return SBlockRealize(iter_values, predicate, block, span); }); } diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e6666cc63816..a2f3c46abaae 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -104,7 +104,7 @@ void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } -void StmtVisitor::VisitStmt_(const BlockNode* op) { +void StmtVisitor::VisitStmt_(const SBlockNode* op) { auto fvisit_buffer_region = [this](const BufferRegion& s) { for (const auto& range : s->region) { this->VisitExpr(range->min); @@ -127,7 +127,7 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { +void StmtVisitor::VisitStmt_(const SBlockRealizeNode* op) { VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->predicate); this->VisitStmt(op->block); @@ -466,7 +466,7 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } } -Stmt StmtMutator::VisitStmt_(const BlockNode* op) { +Stmt StmtMutator::VisitStmt_(const SBlockNode* op) { ffi::Array iter_vars = Internal::Mutate(this, op->iter_vars); ffi::Array reads = Internal::Mutate(this, op->reads); ffi::Array writes = Internal::Mutate(this, op->writes); @@ -479,7 +479,7 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) && body.same_as(op->body) && init.same_as(op->init) && match_buffers.same_as(op->match_buffers)) { - return ffi::GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_vars = std::move(iter_vars); @@ -492,7 +492,7 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { } } -Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { +Stmt StmtMutator::VisitStmt_(const SBlockRealizeNode* op) { ffi::Array v = Internal::Mutate(this, op->iter_values); PrimExpr pred = this->VisitExpr(op->predicate); Stmt block = this->VisitStmt(op->block); @@ -502,7 +502,7 @@ Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { auto n = CopyOnWrite(op); n->iter_values = std::move(v); n->predicate = std::move(pred); - n->block = Downcast(block); + n->block = Downcast(block); return Stmt(n); } } diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 638340e0bd2f..712d4d88eb26 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -277,7 +277,7 @@ void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); } -void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const SBlockNode* op, AccessPath path) { std::vector, DefContext, DefContext>> context; { @@ -319,7 +319,7 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, AccessPath path) { while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::VisitStmt_(const BlockRealizeNode* op, AccessPath path) { +void TIRVisitorWithPath::VisitStmt_(const SBlockRealizeNode* op, AccessPath path) { Visit(op->iter_values, path->Attr("iter_values")); Visit(op->predicate, path->Attr("predicate")); Visit(op->block, path->Attr("block")); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 65673d1f2b34..1409fb39f52e 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -113,8 +113,8 @@ class TIRVisitorWithPath void VisitStmt_(const AssertStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const SeqStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const EvaluateNode* op, ffi::reflection::AccessPath path) override; - void VisitStmt_(const BlockNode* op, ffi::reflection::AccessPath path) override; - void VisitStmt_(const BlockRealizeNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const SBlockNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) override; using ExprFunctor::VisitExpr; void VisitExpr_(const VarNode* op, ffi::reflection::AccessPath path) override; diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 1285c2c5f0ab..a68da6d5b883 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -101,7 +101,7 @@ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool requ */ struct ScopeBlockLoopInfo { /*! \brief A list of the leaf blocks, from left to right */ - std::vector realizes; + std::vector realizes; /*! \brief The loop vars bound to spatial block iters */ std::unordered_set spatial_vars; /*! \brief The loop vars bound to non-spatial block iters */ @@ -113,7 +113,7 @@ struct ScopeBlockLoopInfo { * \param scope_block The root block of the scope * \return The information of the scope */ -ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const SBlock& scope_block); /*! * \brief Checks whether the block is a complete block under the scope @@ -212,7 +212,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, * \param block_sref The block to be checked * \return A vector of types of the block vars */ -std::vector GetBlockVarTypes(const StmtSRef& block_sref); +std::vector GetSBlockVarTypes(const StmtSRef& block_sref); /*! * \brief Checks if a block could be considered as a "write cache" @@ -230,7 +230,7 @@ bool IsWriteCache(const StmtSRef& block_sref); * \param analyzer The analyzer * \return A boolean flag indicating if the binding is affine */ -bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& loop_var_ranges, +bool IsAffineBinding(const SBlockRealize& realize, const ffi::Map& loop_var_ranges, arith::Analyzer* analyzer); /*! @@ -240,7 +240,7 @@ bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& lo * \param block The block to be checked * \throw ScheduleError If the input block does not have an affine binding */ -void CheckAffineBinding(const ScheduleState& self, Block block); +void CheckAffineBinding(const ScheduleState& self, SBlock block); /*! * \brief Check whether a block has an affine binding under the high exclusive sref node, @@ -250,7 +250,7 @@ void CheckAffineBinding(const ScheduleState& self, Block block); * \param high_exclusive The highest sref node * \throw ScheduleError If the input block does not have an affine binding */ -void CheckPartialAffineBinding(const ScheduleState& self, Block block, +void CheckPartialAffineBinding(const ScheduleState& self, SBlock block, const ffi::Optional& high_exclusive); /*! @@ -273,7 +273,7 @@ ffi::Map LoopDomainOfSRefTreePath( * \param realize The BlockRealize to be analyzed * \return The block var binding */ -ffi::Map GetBindings(const BlockRealize& realize); +ffi::Map GetBindings(const SBlockRealize& realize); /*! * \brief Get the vars involved in the bindings of data parallel block vars and reduction block @@ -284,7 +284,7 @@ ffi::Map GetBindings(const BlockRealize& realize); * \return A boolean indicating whether the block has block iters that is neither a data parallel * block iter nor a reduction block iter */ -bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, +bool GetVarsTouchedByBlockIters(const SBlockRealize& block_realize, std::unordered_set* data_par_vars, std::unordered_set* reduce_vars); @@ -324,7 +324,7 @@ ffi::Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of leaf BlockRealize */ -ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); +ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); /*! * \brief Get the BlockRealize of the single child block of the block or loop specified by @@ -334,8 +334,8 @@ ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_s * \return The BlockRealize of the single child block * \throw ScheduleError If there is 0 or multiple child blocks */ -BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, - const StmtSRef& parent_sref); +SBlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref); /*! * \brief Get the BlockRealize of the input block @@ -343,7 +343,7 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self * \param block_sref The StmtSRef of the queried block * \return The BlockRealize of the input block */ -BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +SBlockRealize GetSBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); /*! * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to @@ -386,7 +386,7 @@ std::pair, std::vector> CollectComputeLocation( * \param scope The block scope where the given block is in * \return The producer blocks of the specified block */ -ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope); +ffi::Array GetProducers(const StmtSRef& block_sref, const SBlockScope& scope); /*! * \brief Get the consumer blocks to the given block under the given scope @@ -394,7 +394,7 @@ ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& * \param scope The block scope where the given block is in * \return The consumer blocks of the specified block */ -ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); +ffi::Array GetConsumers(const StmtSRef& block_sref, const SBlockScope& scope); /*! * \brief Get the list of output blocks within the given scope @@ -404,7 +404,7 @@ ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& * \return A list of all blocks that write to some output buffer * block */ -ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); +ffi::Array GetOutputBlocks(const ScheduleState& self, const SBlockNode* scope_block); /*! * \brief A solution to split a ordered list of subtrees into two parts, @@ -435,7 +435,7 @@ struct ProducerConsumerSplit { const ScheduleState& state, const ffi::Array& subtrees, const ffi::Array& producer_block_srefs, const ffi::Array& consumer_block_srefs, - std::unordered_map* block2realize); + std::unordered_map* block2realize); }; /******** Block-buffer relation ********/ @@ -449,7 +449,7 @@ struct ProducerConsumerSplit { * \return The buffer of the n-th read/write region of the block. * \throw ScheduleError If the buffer index is out of bound. */ -Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, +Buffer GetNthAccessBuffer(const ScheduleState& self, const SBlock& block, int n, BufferIndexType index_type); /*! @@ -461,7 +461,7 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, * \return The n-th read/write region of the block. * \throw ScheduleError If the buffer index is out of bound. */ -BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n, +BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const SBlock& block, int n, BufferIndexType index_type); /*! @@ -474,7 +474,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer); -/******** Reduction Block Related ********/ +/******** Reduction SBlock Related ********/ /*! * \brief Get the init values and the BufferStore updates from the input reduction block @@ -484,7 +484,7 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& b * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block */ std::pair, ffi::Array> GetInitValuesAndUpdatesFromReductionBlock( - const ffi::Optional& self, Block block); + const ffi::Optional& self, SBlock block); /*! * \brief Check whether the input array of IterVars only contains data-parallel and reduction block @@ -502,7 +502,7 @@ bool ContainsOnlyDataParAndReductionBlockIter(const ffi::Array& iters); * \return A boolean indicating whether the block's reduction block iters are not used to index the * block's output buffer */ -bool ReductionIterNotIndexOutputBuffer(const Block& block); +bool ReductionIterNotIndexOutputBuffer(const SBlock& block); /*! * \brief Given a list of reduction identities and a list of reduction combiners, detect the @@ -616,7 +616,7 @@ bool HasOp(const Stmt& stmt, const ffi::Array& ops); * 1) IfThenElse statement * 2) Select expression * 3) The operator `tir.if_then_else` - * 4) non-constant-true Block predicates + * 4) non-constant-true SBlock predicates * \param stmt The AST statement to be checked * \return A boolean indicating whether the statement contains the if-then-else pattern */ @@ -790,9 +790,9 @@ class AutoTensorizeMappingInfoNode : public Object { ffi::Map lhs_buffer_map; /*! \brief Buffer indices on RHS */ ffi::Map> rhs_buffer_indices; - /*! \brief Block iters on LHS */ + /*! \brief SBlock iters on LHS */ ffi::Array lhs_iters; - /*! \brief Block iters on RHS */ + /*! \brief SBlock iters on RHS */ ffi::Array rhs_iters; static void RegisterReflection() { @@ -841,7 +841,7 @@ ffi::Optional GetAutoTensorizeMappingInfo(const Schedu * \param desc_func The prim func describing the computation to be tensorized * \return true if basic conditions are met. */ -bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& block_rv, +bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::SBlockRV& block_rv, const tir::PrimFunc& desc_func); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 75cbd5f3e4c1..ee193b80b3cb 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -37,7 +37,7 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* func = base_func.as()) { - if (const auto* realize = func->body.as()) { + if (const auto* realize = func->body.as()) { if (realize->block.get() == root_block) { if (result_g_var != nullptr) { *result_g_var = g_var; @@ -73,7 +73,7 @@ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, class NotStagePipelineError : public ScheduleError { public: - explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit NotStagePipelineError(IRModule mod, SBlock block) : mod_(mod), block_(block) {} IRModule mod() const final { return mod_; } ffi::String FastErrorString() const final { return "ScheduleError: The scope root is not a stage pipeline"; @@ -84,12 +84,12 @@ Definition of a scope that is a stage pipeline: - The region cover property holds for every of its child blocks - No write-after-read dependency or opaque dependency, - only read-after-write and write-after-write are allowed -- All the statements in the scope are schedulable statements, i.e. Block and For +- All the statements in the scope are schedulable statements, i.e. SBlock and For )"; } ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; }; StmtSRef scope_root_sref{nullptr}; @@ -99,7 +99,7 @@ Definition of a scope that is a stage pipeline: const StmtSRefNode* p = sref->parent; const StmtSRefNode* subtree = sref.get(); for (; p != nullptr; subtree = p, p = p->parent) { - if (p->stmt->IsInstance()) { + if (p->stmt->IsInstance()) { scope_root_sref = ffi::GetRef(p); scope_root_subtree = ffi::GetRef(subtree); break; @@ -111,19 +111,19 @@ Definition of a scope that is a stage pipeline: } // Step 2. Handle `require_stage_pipeline` if (require_stage_pipeline && self->enable_check) { - bool stage_pipeline = self->GetBlockInfo(scope_root_sref).stage_pipeline; + bool stage_pipeline = self->GetSBlockInfo(scope_root_sref).stage_pipeline; if (stage_pipeline == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); - throw NotStagePipelineError(self->mod, ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(scope_root_sref); + throw NotStagePipelineError(self->mod, ffi::GetRef(block)); } } return scope_root_sref; } -ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const SBlock& scope_block) { struct Collector : public StmtVisitor { - void VisitStmt_(const BlockRealizeNode* realize) final { - result.realizes.push_back(ffi::GetRef(realize)); + void VisitStmt_(const SBlockRealizeNode* realize) final { + result.realizes.push_back(ffi::GetRef(realize)); const ffi::Array& iter_vars = realize->block->iter_vars; const ffi::Array& iter_values = realize->iter_values; ICHECK_EQ(iter_vars.size(), iter_values.size()); @@ -177,22 +177,22 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, const StmtSRef& block_sref) { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; CheckSRefHigherOrEqual(scope_root_sref, block_sref); - const BlockNode* maybe_root_block = scope_root_sref->StmtAs(); + const SBlockNode* maybe_root_block = scope_root_sref->StmtAs(); if (maybe_root_block) { - BlockScope scope = self->GetBlockScope(scope_root_sref); + SBlockScope scope = self->GetSBlockScope(scope_root_sref); buffer_writers = scope->buffer_writers; } else { // Collect all child blocks of root sub-tree, and merge their buffer writers. ffi::Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); for (const StmtSRef& child_block_sref : child_block_srefs) { - BlockScope child_scope = self->GetBlockScope(child_block_sref); + SBlockScope child_scope = self->GetSBlockScope(child_block_sref); for (const auto& it : child_scope->buffer_writers) { buffer_writers.insert(it); } } } // Check whether the input block is the only writer of its outputs - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); for (const BufferRegion& write_region : block->writes) { if (buffer_writers.count(write_region->buffer)) { if (buffer_writers.at(write_region->buffer).size() != 1) { @@ -215,7 +215,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { // Cond 1. All block vars are data parallel - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != kDataPar) { return 1; @@ -273,7 +273,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { class IncompleteBlockError : public ScheduleError { public: - explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) + explicit IncompleteBlockError(IRModule mod, SBlock block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} ffi::String FastErrorString() const final { return "ScheduleError: Incomplete block"; } ffi::String DetailRenderTemplate() const final { @@ -285,14 +285,14 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; int violated_cond_; }; int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw IncompleteBlockError(self->mod, ffi::GetRef(block), error_code); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw IncompleteBlockError(self->mod, ffi::GetRef(block), error_code); } } @@ -307,7 +307,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, */ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); // Cond 1. The block has the `init` statement. if (!block->init.defined()) { return 1; @@ -327,7 +327,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc return 4; } // Cond 5. The reduction block vars are not used to index the output buffers. - return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)) ? 0 : 5; + return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)) ? 0 : 5; } bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -337,17 +337,17 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tir.schedule.IsReductionBlock", [](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) { - return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); - }); + refl::GlobalDef().def("tir.schedule.IsReductionBlock", [](Schedule sch, SBlockRV block_rv, + SBlockRV scope_block_rv) { + return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); + }); } void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { class NotReductionBlockError : public ScheduleError { public: - explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) + explicit NotReductionBlockError(IRModule mod, SBlock block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} ffi::String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } ffi::String DetailRenderTemplate() const final { @@ -359,14 +359,14 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; int violated_cond_; }; int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotReductionBlockError(self->mod, ffi::GetRef(block), error_code); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw NotReductionBlockError(self->mod, ffi::GetRef(block), error_code); } } @@ -374,7 +374,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl const StmtSRef& scope_root_sref) { class NotCompleteOrReductionBlockError : public ScheduleError { public: - explicit NotCompleteOrReductionBlockError(IRModule mod, Block block, + explicit NotCompleteOrReductionBlockError(IRModule mod, SBlock block, int complete_block_error_code, int reduction_block_error_code) : mod_(mod), @@ -399,7 +399,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; int complete_block_error_code_; int reduction_block_error_code_; }; @@ -412,22 +412,22 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl if (reduction_block_error_code == 0) { return; } - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotCompleteOrReductionBlockError(self->mod, ffi::GetRef(block), + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw NotCompleteOrReductionBlockError(self->mod, ffi::GetRef(block), complete_block_error_code, reduction_block_error_code); } void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) { class NotCompactDataFlowError : public ScheduleError { public: - explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block, + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, SBlock violate_block, int local_complete_block_code, int local_reduction_block_code) : mod_(std::move(mod)), subtree_root_(std::move(subtree_root)), violate_block_(std::move(violate_block)), local_complete_block_code_(local_complete_block_code), local_reduction_block_code_(local_reduction_block_code) { - ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); } ffi::String FastErrorString() const final { return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " @@ -454,7 +454,7 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt IRModule mod_; Stmt subtree_root_; - Block violate_block_; + SBlock violate_block_; int local_complete_block_code_; int local_reduction_block_code_; }; @@ -464,9 +464,9 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); if (local_complete_block_code != 0 && local_reduction_block_code != 0) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); throw NotCompactDataFlowError(self->mod, ffi::GetRef(subtree_root->stmt), - ffi::GetRef(block), local_complete_block_code, + ffi::GetRef(block), local_complete_block_code, local_reduction_block_code); } } @@ -474,8 +474,8 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* scope_root = TVM_SREF_TO_SBLOCK(scope_root_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); std::unordered_set scope_allocated; scope_allocated.reserve(scope_root->alloc_buffers.size()); for (const Buffer& buffer : scope_root->alloc_buffers) { @@ -493,7 +493,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { class OutputBlockError : public ScheduleError { public: - explicit OutputBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit OutputBlockError(IRModule mod, SBlock block) : mod_(mod), block_(block) {} ffi::String FastErrorString() const final { return "ScheduleError: Cannot operate on an output block"; } @@ -502,15 +502,15 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; }; if (IsOutputBlock(self, block_sref, scope_root_sref)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw OutputBlockError(self->mod, ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw OutputBlockError(self->mod, ffi::GetRef(block)); } } -std::vector GetBlockVarTypes(const BlockNode* block) { +std::vector GetSBlockVarTypes(const SBlockNode* block) { std::vector results; results.reserve(block->iter_vars.size()); for (const IterVar& iter_var : block->iter_vars) { @@ -519,13 +519,13 @@ std::vector GetBlockVarTypes(const BlockNode* block) { return results; } -std::vector GetBlockVarTypes(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - return GetBlockVarTypes(block); +std::vector GetSBlockVarTypes(const StmtSRef& block_sref) { + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + return GetSBlockVarTypes(block); } bool IsWriteCache(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); if (block->writes.size() != 1) { return false; } @@ -547,7 +547,7 @@ bool IsWriteCache(const StmtSRef& block_sref) { /******** Binding ********/ -bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& loop_var_ranges, +bool IsAffineBinding(const SBlockRealize& realize, const ffi::Map& loop_var_ranges, arith::Analyzer* analyzer) { if (loop_var_ranges.empty()) { return true; @@ -571,11 +571,11 @@ bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& lo return true; } -void CheckPartialAffineBinding(const ScheduleState& self, Block block, +void CheckPartialAffineBinding(const ScheduleState& self, SBlock block, const ffi::Optional& high_exclusive) { class NotAffineBindingError : public ScheduleError { public: - explicit NotAffineBindingError(IRModule mod, Block block, + explicit NotAffineBindingError(IRModule mod, SBlock block, ffi::Optional high_exclusive) : mod_(std::move(mod)), block_(std::move(block)) { if (high_exclusive.defined()) { @@ -605,7 +605,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; const ForNode* high_exclusive_loop_{nullptr}; }; @@ -619,21 +619,21 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, arith::Analyzer analyzer; ffi::Map dom_map = LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent), high_exclusive); - if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) { + if (IsAffineBinding(GetSBlockRealize(self, block_sref), dom_map, &analyzer)) { return; } } throw NotAffineBindingError(self->mod, std::move(block), high_exclusive); } -void CheckAffineBinding(const ScheduleState& self, Block block) { +void CheckAffineBinding(const ScheduleState& self, SBlock block) { CheckPartialAffineBinding(self, std::move(block), std::nullopt); } void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { class NotTrivialBindingError : public ScheduleError { public: - explicit NotTrivialBindingError(IRModule mod, Block block) + explicit NotTrivialBindingError(IRModule mod, SBlock block) : mod_(std::move(mod)), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -651,11 +651,11 @@ void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc private: IRModule mod_; - Block block_; + SBlock block_; }; if (!IsTrivialBinding(self, block_sref)) { - throw NotTrivialBindingError(self->mod, ffi::GetRef(block_sref->StmtAs())); + throw NotTrivialBindingError(self->mod, ffi::GetRef(block_sref->StmtAs())); } } @@ -688,8 +688,8 @@ ffi::Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, return result; } -ffi::Map GetBindings(const BlockRealize& realize) { - const BlockNode* block = realize->block.get(); +ffi::Map GetBindings(const SBlockRealize& realize) { + const SBlockNode* block = realize->block.get(); const ffi::Array& all_lhs = block->iter_vars; const ffi::Array& all_rhs = realize->iter_values; ICHECK_EQ(all_lhs.size(), all_rhs.size()); @@ -702,10 +702,10 @@ ffi::Map GetBindings(const BlockRealize& realize) { return result; } -bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, +bool GetVarsTouchedByBlockIters(const SBlockRealize& block_realize, std::unordered_set* data_par_vars, std::unordered_set* reduce_vars) { - Block block = block_realize->block; + SBlock block = block_realize->block; ICHECK(block_realize->block.same_as(block)) << "ValueError: The input `block_realize` is required to be the exact BlockRealize of the " "input block"; @@ -769,49 +769,49 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre ffi::Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref) { - ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); ffi::Array child_block_srefs; child_block_srefs.reserve(child_block_realize.size()); - for (BlockRealize realize : child_block_realize) { + for (SBlockRealize realize : child_block_realize) { child_block_srefs.push_back(self->stmt2ref.at(realize->block.get())); } return child_block_srefs; } -ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { +ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { - static ffi::Array Collect(const Stmt& stmt) { + static ffi::Array Collect(const Stmt& stmt) { Collector collector; collector(stmt); return std::move(collector.result_); } - void VisitStmt_(const BlockRealizeNode* block_realize) final { - result_.push_back(ffi::GetRef(block_realize)); + void VisitStmt_(const SBlockRealizeNode* block_realize) final { + result_.push_back(ffi::GetRef(block_realize)); } - ffi::Array result_; + ffi::Array result_; }; if (parent_sref->stmt->IsInstance()) { const auto* loop = static_cast(parent_sref->stmt); return Collector::Collect(loop->body); - } else if (parent_sref->stmt->IsInstance()) { - const auto* block = static_cast(parent_sref->stmt); + } else if (parent_sref->stmt->IsInstance()) { + const auto* block = static_cast(parent_sref->stmt); return Collector::Collect(block->body); } ICHECK(false) << "Unreachable"; throw; } -BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, - const StmtSRef& parent_sref) { +SBlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { class NonSingleChildBlockError : public ScheduleError { public: explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) : mod_(std::move(mod)), stmt_(ffi::GetRef(sref->stmt)) { - sref_type_ = stmt_.as() != nullptr ? "block" : "loop"; + sref_type_ = stmt_.as() != nullptr ? "block" : "loop"; } ffi::String FastErrorString() const final { @@ -834,17 +834,17 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self ffi::String sref_type_; }; - ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); if (child_block_realize.size() != 1) { throw NonSingleChildBlockError(self->mod, parent_sref); } return child_block_realize[0]; } -BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) { +SBlockRealize GetSBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) { struct BlockRealizeFinder : public StmtVisitor { - explicit BlockRealizeFinder(const BlockNode* target_block) - : target_block(target_block), result(nullptr) {} + explicit BlockRealizeFinder(const SBlockNode* target_sblock) + : target_sblock(target_sblock), result(nullptr) {} void VisitStmt(const Stmt& stmt) final { if (result != nullptr) { @@ -853,34 +853,34 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr StmtVisitor::VisitStmt(stmt); } - void VisitStmt_(const BlockRealizeNode* block_realize) final { - if (block_realize->block.get() == target_block) { + void VisitStmt_(const SBlockRealizeNode* block_realize) final { + if (block_realize->block.get() == target_sblock) { result = block_realize; } // No need to visit recursively, since the deeper BlockRealizes must not be the result. } - const BlockNode* target_block; - const BlockRealizeNode* result; + const SBlockNode* target_sblock; + const SBlockRealizeNode* result; }; - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); if (block_sref->parent == nullptr) { const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); - return Downcast(func->body); + return Downcast(func->body); } else { BlockRealizeFinder finder(block); finder(ffi::GetRef(block_sref->parent->stmt)); ICHECK(finder.result != nullptr) - << "InternalError: Cannot find the BlockRealize of block " << ffi::GetRef(block); - return ffi::GetRef(finder.result); + << "InternalError: Cannot find the BlockRealize of block " << ffi::GetRef(block); + return ffi::GetRef(finder.result); } } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.GetBlockRealize", [](Schedule sch, BlockRV block_rv) { - return GetBlockRealize(sch->state(), sch->GetSRef(block_rv)); + refl::GlobalDef().def("tir.schedule.GetSBlockRealize", [](Schedule sch, SBlockRV block_rv) { + return GetSBlockRealize(sch->state(), sch->GetSRef(block_rv)); }); } @@ -891,8 +891,8 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref) { int n_reduce = 0; int n_other = 0; auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { - if (const auto* realize = obj.as()) { - const BlockNode* block = realize->block.get(); + if (const auto* realize = obj.as()) { + const SBlockNode* block = realize->block.get(); // Number of block vars and their bindings ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); size_t n = realize->iter_values.size(); @@ -1039,7 +1039,7 @@ std::pair, std::vector> CollectComputeLocation( /******** Producer-consumer relation ********/ -ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { +ffi::Array GetProducers(const StmtSRef& block_sref, const SBlockScope& scope) { ffi::Array edges = scope->GetDepsByDst(block_sref); ffi::Array results; std::unordered_set result_set; @@ -1054,7 +1054,7 @@ ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& return results; } -ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { +ffi::Array GetConsumers(const StmtSRef& block_sref, const SBlockScope& scope) { ffi::Array edges = scope->GetDepsBySrc(block_sref); ffi::Array results; std::unordered_set result_set; @@ -1069,11 +1069,11 @@ ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& return results; } -ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { - struct OutputBlockCollector : public StmtVisitor { - explicit OutputBlockCollector(const ScheduleState& self) : self_(self) {} +ffi::Array GetOutputBlocks(const ScheduleState& self, const SBlockNode* scope_block) { + struct OutputSBlockCollector : public StmtVisitor { + explicit OutputSBlockCollector(const ScheduleState& self) : self_(self) {} - void VisitStmt_(const BlockNode* block) override { + void VisitStmt_(const SBlockNode* block) override { auto it = self_->stmt2ref.find(block); ICHECK(it != self_->stmt2ref.end()); auto block_sref = it->second; @@ -1090,7 +1090,7 @@ ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* const ScheduleState& self_; ffi::Array results_; }; - OutputBlockCollector collector(self); + OutputSBlockCollector collector(self); collector(scope_block->body); auto results = collector.results_; return results; @@ -1100,7 +1100,7 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( const ScheduleState& self, const ffi::Array& subtrees, const ffi::Array& producer_block_srefs, const ffi::Array& consumer_block_srefs, - std::unordered_map* block2realize) { + std::unordered_map* block2realize) { class InsertionPointNotFoundError : public ScheduleError { public: explicit InsertionPointNotFoundError(IRModule mod, int last_producer_position, @@ -1134,8 +1134,8 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( class Finder : public StmtVisitor { public: - void VisitStmt_(const BlockRealizeNode* realize) final { - const BlockNode* block = realize->block.get(); + void VisitStmt_(const SBlockRealizeNode* realize) final { + const SBlockNode* block = realize->block.get(); if (block2realize_) { block2realize_->emplace(block, realize); } @@ -1147,7 +1147,7 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( } } - std::unordered_map* block2realize_; + std::unordered_map* block2realize_; std::unordered_set producer_blocks_; std::unordered_set consumer_blocks_; int n_producers_visited_ = 0; @@ -1196,11 +1196,11 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( /******** Block-buffer relation ********/ -BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n, +BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const SBlock& block, int n, BufferIndexType index_type) { class BufferIndexOutOfRangeError : public ScheduleError { public: - explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, + explicit BufferIndexOutOfRangeError(IRModule mod, SBlock block, int buffer_index, BufferIndexType index_type) : mod_(std::move(mod)), block_(std::move(block)), @@ -1237,7 +1237,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl private: IRModule mod_; - Block block_; + SBlock block_; int buffer_index_; BufferIndexType index_type_; }; @@ -1251,7 +1251,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl return access_region[n]; } -Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, +Buffer GetNthAccessBuffer(const ScheduleState& self, const SBlock& block, int n, BufferIndexType index_type) { return GetNthAccessBufferRegion(self, block, n, index_type)->buffer; } @@ -1262,7 +1262,7 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& b // match_buffers. const StmtSRefNode* defining_site_sref = block_sref.get(); while (defining_site_sref != nullptr) { - const auto* block = defining_site_sref->StmtAs(); + const auto* block = defining_site_sref->StmtAs(); // If this sref is not a block sref, skip it. if (block == nullptr) { defining_site_sref = defining_site_sref->parent; @@ -1340,7 +1340,7 @@ bool HasIfThenElse(const Stmt& stmt) { // stop visiting return false; } - if (const auto* realize = obj.as()) { + if (const auto* realize = obj.as()) { // Case 1: BlockRealize if (!is_one(realize->predicate)) { has_branch = true; @@ -1475,7 +1475,7 @@ void CheckStorageScope(const ScheduleState& self, ffi::String storage_scope) { } bool IsSpatial(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != IterVarType::kDataPar) { return false; @@ -1485,9 +1485,9 @@ bool IsSpatial(const StmtSRef& block_sref) { } bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { - TVM_SREF_TO_BLOCK(block_sref); + TVM_SREF_TO_SBLOCK(block_sref); ffi::Array loops = GetLoops(block_sref); - ffi::Array binds = GetBlockRealize(self, block_sref)->iter_values; + ffi::Array binds = GetSBlockRealize(self, block_sref)->iter_values; if (loops.size() != binds.size()) { return false; } @@ -1502,7 +1502,7 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.IsTrivialBinding", [](Schedule sch, BlockRV block_rv) { + refl::GlobalDef().def("tir.schedule.IsTrivialBinding", [](Schedule sch, SBlockRV block_rv) { return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); }); } @@ -1511,7 +1511,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref if (HasBeenMultiLevelTiled(block_sref)) { return false; } - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || !IsTrivialBinding(self, block_sref)) { return false; @@ -1574,7 +1574,7 @@ bool IsSpatialPrimFunc(const PrimFunc& func) { if (result == false) { return false; } - if (const auto* block = obj.as()) { + if (const auto* block = obj.as()) { for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != IterVarType::kDataPar) { result = false; @@ -1623,7 +1623,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // const tir::StmtSRef& block_sref, // int64_t max_parallel_extent, // int64_t max_parallel_basic) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); ffi::Array loops = tir::GetLoops(block_sref); // Cond 1. The block must have at lease one write buffer @@ -1665,7 +1665,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // return false; } } else { - const auto* block_realize = loop_i->body.as(); + const auto* block_realize = loop_i->body.as(); if (!block_realize || block_realize->block.get() != block) { return false; } @@ -1705,7 +1705,7 @@ struct TensorIntrinDescInfo { /*! \brief The block of the description function, which is the (unique) direct child of the root * block. */ - const BlockRealizeNode* desc_block = nullptr; + const SBlockRealizeNode* desc_block = nullptr; /*! \brief The loops of the description function, in the order from outer loops to inner ones. */ std::vector desc_loops; /*! \brief The loop variables. */ @@ -1721,12 +1721,12 @@ struct TensorIntrinDescInfo { TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, const PrimFunc& desc_func) { TensorIntrinDescInfo info; - const auto* desc_scope_realize = desc_func->body.as(); + const auto* desc_scope_realize = desc_func->body.as(); ICHECK(desc_scope_realize); { auto f_visit = [&](const ObjectRef& obj) -> bool { // Extract the block - if (const auto* block = obj.as()) { + if (const auto* block = obj.as()) { info.desc_block = block; return false; } @@ -1752,12 +1752,12 @@ ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& s const tir::PrimFunc& desc_func, bool allow_padding) { arith::Analyzer analyzer; - const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + const tir::SBlockRealize& block = tir::GetSBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. Collect loops from block_sref const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); - TVM_SREF_TO_BLOCK(scope_sref); + TVM_SREF_TO_SBLOCK(scope_sref); std::vector block_loops; std::unordered_set block_loop_vars; { @@ -1777,7 +1777,7 @@ ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& s // Step 3. Map from block loops to desc block loops const std::vector& desc_loops = desc_info.desc_loops; const std::unordered_set& desc_loop_vars = desc_info.desc_loop_vars; - const BlockRealizeNode* desc_block = desc_info.desc_block; + const SBlockRealizeNode* desc_block = desc_info.desc_block; ObjectPtr ret = ffi::make_object(); const int n_block_vars = block->iter_values.size(); const int n_desc_vars = desc_block->iter_values.size(); @@ -1789,8 +1789,8 @@ ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& s return std::nullopt; } - const std::vector iter_types_block = GetBlockVarTypes(block_sref); - const std::vector iter_types_desc = GetBlockVarTypes(desc_block->block.get()); + const std::vector iter_types_block = GetSBlockVarTypes(block_sref); + const std::vector iter_types_desc = GetSBlockVarTypes(desc_block->block.get()); ICHECK(desc_loops.size() == static_cast(n_desc_vars)); ICHECK(block_loops.size() == iter_types_block.size()); @@ -1912,7 +1912,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.IsSpatialPrimFunc", IsSpatialPrimFunc) - .def("tir.schedule.GetTensorizeLoopMapping", [](Schedule sch, BlockRV block, + .def("tir.schedule.GetTensorizeLoopMapping", [](Schedule sch, SBlockRV block, PrimFunc desc_func, bool allow_padding) { return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); }); @@ -2106,14 +2106,14 @@ bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRe // Step 1. Analyze desc_func, extract its block, loops and loop vars // Step 2. Check if `desc_block` matches `block` // Ignore the scope of buffers when comparing, since we can do cache_read/write - const BlockRealize& block = tir::GetBlockRealize(state, block_sref); + const SBlockRealize& block = tir::GetSBlockRealize(state, block_sref); arith::Analyzer analyzer; auto desc_info = tir::ExtractTensorIntrinDescInfo(&analyzer, desc_func); return extractor->VisitStmt(block->block, desc_info.desc_block->block); } -bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& block_rv, +bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::SBlockRV& block_rv, const tir::PrimFunc& desc_func) { AutoTensorizeComparator extractor(sch->state()->mod); return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor); @@ -2145,12 +2145,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.GetAutoTensorizeMappingInfo", - [](Schedule sch, BlockRV block, PrimFunc desc_func) { + [](Schedule sch, SBlockRV block, PrimFunc desc_func) { return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }) .def("tir.schedule.HasBlock", HasBlock) .def("tir.schedule.IsOutputBlock", - [](Schedule sch, BlockRV block) { + [](Schedule sch, SBlockRV block) { auto state = sch->state(); auto block_sref = sch->GetSRef(block); return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index 085a4a33de87..17e668d13ac8 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -286,7 +286,7 @@ class PatternMatcher : public ExprVisitor { std::unordered_map filled_map_; }; -/******** Reduction Block Related ********/ +/******** Reduction SBlock Related ********/ static const char* kRFactorCrossThreadReductionApplicableBlockDef = R"(Definition of a reduction block that is applicable by RFactor and Cross-Thread Reduction: @@ -304,10 +304,10 @@ static const char* kRFactorCrossThreadReductionApplicableBlockDef = 12) The indices of all BufferStores in the reduction block should be the same)"; void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional& self, - Block block, int violated_cond) { + SBlock block, int violated_cond) { class RFactorNotApplicableError : public ScheduleError { public: - explicit RFactorNotApplicableError(IRModule mod, Block block, int violated_cond) + explicit RFactorNotApplicableError(IRModule mod, SBlock block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} ffi::String FastErrorString() const final { @@ -327,7 +327,7 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; int violated_cond_; }; @@ -352,7 +352,7 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional& self, Block block, +void ExtractReductionUpdates(const ffi::Optional& self, SBlock block, const LetStmtNode* let, int n_buffers, ffi::Array* updates, std::unordered_map* buf2index) { @@ -429,7 +429,7 @@ void ExtractReductionUpdates(const ffi::Optional& self, Block blo } std::pair, ffi::Array> GetInitValuesAndUpdatesFromReductionBlock( - const ffi::Optional& self, Block block) { + const ffi::Optional& self, SBlock block) { ffi::Array inits; ffi::Array updates; @@ -522,7 +522,7 @@ bool ContainsOnlyDataParAndReductionBlockIter(const ffi::Array& iters) return true; } -bool ReductionIterNotIndexOutputBuffer(const Block& block) { +bool ReductionIterNotIndexOutputBuffer(const SBlock& block) { // Step 1. Collect the reduction block iters. std::unordered_set reduction_block_iters; reduction_block_iters.reserve(block->iter_vars.size()); @@ -559,7 +559,7 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) { if (affected) { return false; } - const auto* block_node = obj.as(); + const auto* block_node = obj.as(); if (block_node) { for (const MatchBufferRegion& region : block_node->match_buffers) { match_buffer_sources[region->buffer.get()] = region->source->buffer.get(); diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index f9a09552c21c..77c6bb605c8b 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -47,7 +47,7 @@ class SRefTreeVerifier : public StmtVisitor { ICHECK_EQ(n_block_sref_visited_, static_cast(self_->block_info.size())); } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { if (init_block_depth_) { ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its " "corresponding sref, which is not allowed"; @@ -115,7 +115,7 @@ class SRefTreeVerifier : public StmtVisitor { for (int i = 0; i < n; ++i) { const Stmt& child = seq_stmt->seq[i]; StmtSRef sref{nullptr}; - if (const auto* realize = child.as()) { + if (const auto* realize = child.as()) { const auto* block = realize->block.get(); ICHECK(self_->stmt2ref.count(block)); sref = self_->stmt2ref.at(block); @@ -157,13 +157,13 @@ void VerifyCachedFlags(const ScheduleState& self) { if (stmt->IsInstance() || !self->stmt2ref.count(stmt)) { continue; } - const BlockInfo& new_block_info = new_state->block_info.at(new_sref); + const SBlockInfo& new_block_info = new_state->block_info.at(new_sref); const StmtSRef& old_sref = self->stmt2ref.at(stmt); if (!self->block_info.count(old_sref)) { block_info_not_found.push_back(new_sref); continue; } - const BlockInfo& old_block_info = self->block_info.at(old_sref); + const SBlockInfo& old_block_info = self->block_info.at(old_sref); if (new_block_info.affine_binding != old_block_info.affine_binding) { block_info_wrong_affine_binding.emplace_back(new_sref, // new_block_info.affine_binding, @@ -191,9 +191,9 @@ void VerifyCachedFlags(const ScheduleState& self) { } std::ostringstream os; if (has_not_found) { - os << "- BlockInfo not found:"; + os << "- SBlockInfo not found:"; for (const StmtSRef& block_sref : block_info_not_found) { - const auto* block = block_sref->StmtAs(); + const auto* block = block_sref->StmtAs(); ICHECK(block); os << " " << block->name_hint; } @@ -205,7 +205,7 @@ void VerifyCachedFlags(const ScheduleState& self) { const StmtSRef& block_sref = std::get<0>(record); bool expected = std::get<1>(record); bool actual = std::get<2>(record); - const auto* block = block_sref->StmtAs(); + const auto* block = block_sref->StmtAs(); ICHECK(block); os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; } @@ -217,7 +217,7 @@ void VerifyCachedFlags(const ScheduleState& self) { const StmtSRef& block_sref = std::get<0>(record); bool expected = std::get<1>(record); bool actual = std::get<2>(record); - const auto* block = block_sref->StmtAs(); + const auto* block = block_sref->StmtAs(); ICHECK(block); os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; } @@ -229,7 +229,7 @@ void VerifyCachedFlags(const ScheduleState& self) { const StmtSRef& block_sref = std::get<0>(record); bool expected = std::get<1>(record); bool actual = std::get<2>(record); - const auto* block = block_sref->StmtAs(); + const auto* block = block_sref->StmtAs(); ICHECK(block); os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 00f421e733e2..e4a236e2ce0e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -139,17 +139,17 @@ class ScheduleCopier { } /*! \brief Copy SMap */ - SMap Copy(const SMap& scopes) { - SMap result; + SMap Copy(const SMap& scopes) { + SMap result; for (const auto& kv : scopes) { const StmtSRef& old_sref = kv.first; - const BlockInfo& old_info = kv.second; - BlockInfo new_info = old_info; - ObjectPtr scope = ffi::make_object(); + const SBlockInfo& old_info = kv.second; + SBlockInfo new_info = old_info; + ObjectPtr scope = ffi::make_object(); scope->src2deps = Copy(old_info.scope->src2deps); scope->dst2deps = Copy(old_info.scope->dst2deps); scope->buffer_writers = Copy(old_info.scope->buffer_writers); - new_info.scope = BlockScope(std::move(scope)); + new_info.scope = SBlockScope(std::move(scope)); result[Copy(old_sref)] = std::move(new_info); } return result; @@ -264,7 +264,7 @@ ffi::Array ConcreteScheduleNode::SamplePartitionedTile( throw; } -LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, +LoopRV ConcreteScheduleNode::SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV( @@ -275,16 +275,16 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV ConcreteScheduleNode::GetBlock(const ffi::String& name, - const ffi::Optional& func_name) { +SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name, + const ffi::Optional& func_name) { class NotSingleResult : public ScheduleError { public: explicit NotSingleResult(ffi::String name, IRModule mod, const ffi::Array& blocks) : name_(name), mod_(mod), blocks_{} { blocks_.reserve(blocks.size()); for (const StmtSRef& block_sref : blocks) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - blocks_.push_back(ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + blocks_.push_back(ffi::GetRef(block)); } } @@ -311,7 +311,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const ffi::String& name, ffi::String name_; IRModule mod_; - ffi::Array blocks_; + ffi::Array blocks_; }; GlobalVar gv = NullValue(); if (func_name.has_value()) { @@ -319,58 +319,58 @@ BlockRV ConcreteScheduleNode::GetBlock(const ffi::String& name, } else if (func_working_on_.has_value()) { gv = this->func_working_on_.value(); } else { - LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " + LOG(FATAL) << "ValueError: `get_sblock` does not know which function to be working on. Please " "specify the function name explicitly, or call `work_on` to specify the function " - "before using `get_block`."; + "before using `get_sblock`."; } - ffi::Array blocks = tir::GetBlocks(this->state_, name, gv); + ffi::Array blocks = tir::GetSBlocks(this->state_, name, gv); if (blocks.size() != 1) { TVM_TIR_SCHEDULE_BEGIN(); throw NotSingleResult(name, this->state_->mod, blocks); TVM_TIR_SCHEDULE_END("get-block", this->error_render_level_); } - return CreateRV(blocks[0]); + return CreateRV(blocks[0]); } -ffi::Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { +ffi::Array ConcreteScheduleNode::GetLoops(const SBlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } -ffi::Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { - ffi::Array result; +ffi::Array ConcreteScheduleNode::GetChildBlocks(const SBlockRV& block_rv) { + ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); - result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv))); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); this->state_->DebugVerify(); return result; } -ffi::Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { - ffi::Array result; +ffi::Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); - result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); this->state_->DebugVerify(); return result; } -ffi::Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) { +ffi::Array ConcreteScheduleNode::GetProducers(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv))); + return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_); throw; } -ffi::Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { +ffi::Array ConcreteScheduleNode::GetConsumers(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv))); + return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_); throw; } -ffi::Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { +ffi::Array ConcreteScheduleNode::GetOutputBlocks(const SBlockRV& scope_block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); + return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); throw; } @@ -595,7 +595,7 @@ void ConcreteScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { this->state_->DebugVerify(); } -void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, +void ConcreteScheduleNode::ReorderBlockIterVar(const SBlockRV& block_rv, const ffi::Array new_order) { TVM_TIR_SCHEDULE_BEGIN(); tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order); @@ -603,7 +603,7 @@ void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, this->state_->DebugVerify(); } -LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { +LoopRV ConcreteScheduleNode::AddUnitLoop(const SBlockRV& block_rv) { LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(block_rv))); @@ -657,13 +657,13 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { /******** Schedule: Insert cache stages ********/ -BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks) { +SBlockRV ConcreteScheduleNode::CacheRead(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { StmtSRef result{nullptr}; // Create a new array of SRefs from the consumer block list. ffi::Array consumer_block_refs = {}; - for (BlockRV block : consumer_blocks) { + for (SBlockRV block : consumer_blocks) { consumer_block_refs.push_back(this->GetSRef(block)); } TVM_TIR_SCHEDULE_BEGIN(); @@ -671,16 +671,16 @@ BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer consumer_block_refs); TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks) { +SBlockRV ConcreteScheduleNode::CacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { StmtSRef result{nullptr}; // Create a new array of SRefs from the consumer block list. ffi::Array consumer_block_refs = {}; - for (BlockRV block : consumer_blocks) { + for (SBlockRV block : consumer_blocks) { consumer_block_refs.push_back(this->GetSRef(block)); } TVM_TIR_SCHEDULE_BEGIN(); @@ -688,99 +688,99 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff consumer_block_refs); TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const IndexMap& index_map) { +SBlockRV ConcreteScheduleNode::ReindexCacheRead(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const IndexMap& index_map) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, index_map); TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const ffi::String& storage_scope, - const IndexMap& index_map) { +SBlockRV ConcreteScheduleNode::ReindexCacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const IndexMap& index_map) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope, index_map); TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -ffi::Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, - int write_buffer_index, - const ffi::String& storage_scope) { +ffi::Array ConcreteScheduleNode::CacheInplace(const SBlockRV& block_rv, + int write_buffer_index, + const ffi::String& storage_scope) { ffi::Array results; TVM_TIR_SCHEDULE_BEGIN(); results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_); this->state_->DebugVerify(); - ffi::Array return_blocks; - return_blocks.push_back(CreateRV(results[0])); - return_blocks.push_back(CreateRV(results[1])); + ffi::Array return_blocks; + return_blocks.push_back(CreateRV(results[0])); + return_blocks.push_back(CreateRV(results[1])); return return_blocks; } -ffi::Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, - const ffi::String& storage_scope, - int cse_thresh) { +ffi::Array ConcreteScheduleNode::CacheIndex(const SBlockRV& block_rv, + const ffi::String& storage_scope, + int cse_thresh) { ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh); TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_); this->state_->DebugVerify(); - ffi::Array return_blocks; + ffi::Array return_blocks; for (const StmtSRef& blockrv : result) { - return_blocks.push_back(CreateRV(blockrv)); + return_blocks.push_back(CreateRV(blockrv)); } return return_blocks; } -BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) { +SBlockRV ConcreteScheduleNode::ReIndex(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } /******** Schedule: Data movement ********/ -BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int read_buffer_index, const ffi::String& storage_scope) { +SBlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const SBlockRV& block_rv, + int read_buffer_index, const ffi::String& storage_scope) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int write_buffer_index, const ffi::String& storage_scope) { +SBlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const SBlockRV& block_rv, + int write_buffer_index, const ffi::String& storage_scope) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } /******** Schedule: Compute location ********/ -void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, +void ConcreteScheduleNode::ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index) { static StmtSRef inline_mark = StmtSRef::InlineMark(); static StmtSRef root_mark = StmtSRef::RootMark(); @@ -799,7 +799,7 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop this->state_->DebugVerify(); } -void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, +void ConcreteScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index) { static StmtSRef inline_mark = StmtSRef::InlineMark(); static StmtSRef root_mark = StmtSRef::RootMark(); @@ -818,22 +818,22 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR this->state_->DebugVerify(); } -void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { +void ConcreteScheduleNode::ComputeInline(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); tir::ComputeInline(state_, this->GetSRef(block_rv)); TVM_TIR_SCHEDULE_END("compute-inline", this->error_render_level_); this->state_->DebugVerify(); } -void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { +void ConcreteScheduleNode::ReverseComputeInline(const SBlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); TVM_TIR_SCHEDULE_END("reverse-compute-inline", this->error_render_level_); this->state_->DebugVerify(); } -void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv, - const BlockRV& epilogue_block_rv) { +void ConcreteScheduleNode::FuseReductionEpilogue(const SBlockRV& reduction_block_rv, + const SBlockRV& epilogue_block_rv) { TVM_TIR_SCHEDULE_BEGIN(); tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv), this->GetSRef(epilogue_block_rv)); @@ -841,9 +841,9 @@ void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_ this->state_->DebugVerify(); } -/******** Schedule: Block Annotation ********/ +/******** Schedule: SBlock Annotation ********/ -void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, +void ConcreteScheduleNode::StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor, int offset) { TVM_TIR_SCHEDULE_BEGIN(); tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); @@ -851,7 +851,7 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde this->state_->DebugVerify(); } -void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, +void ConcreteScheduleNode::SetScope(const SBlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) { TVM_TIR_SCHEDULE_BEGIN(); tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope); @@ -859,7 +859,7 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, this->state_->DebugVerify(); } -void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, +void ConcreteScheduleNode::UnsafeSetDType(const SBlockRV& block_rv, int buffer_index, const ffi::String& dtype) { TVM_TIR_SCHEDULE_BEGIN(); tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); @@ -869,42 +869,42 @@ void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_in /******** Schedule: Reduction ********/ -BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { +SBlockRV ConcreteScheduleNode::DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { +SBlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); TVM_TIR_SCHEDULE_END("rfactor", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } /******** Schedule: Blockize & Tensorize ********/ -BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { +SBlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); - return CreateRV(result); + return CreateRV(result); } -BlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks, - bool preserve_unit_iters) { +SBlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks, + bool preserve_unit_iters) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); - return CreateRV(result); + return CreateRV(result); } void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, @@ -916,7 +916,7 @@ void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& i TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } -void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const ffi::String& intrin, +void ConcreteScheduleNode::Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), @@ -999,7 +999,7 @@ void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); } -void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const ffi::String& ann_key, +void ConcreteScheduleNode::Annotate(const SBlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); tir::Annotate(state_, this->GetSRef(block_rv), ann_key, @@ -1008,7 +1008,7 @@ void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const ffi::String& TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } -void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) { +void ConcreteScheduleNode::Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) { TVM_TIR_SCHEDULE_BEGIN(); tir::Unannotate(state_, this->GetSRef(block_rv), ann_key); this->state_->DebugVerify(); @@ -1016,7 +1016,7 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String } /******** Schedule: Layout transformation ********/ -void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, +void ConcreteScheduleNode::TransformLayout(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, const ffi::Optional& pad_value, @@ -1036,7 +1036,7 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_); } -void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv, +void ConcreteScheduleNode::TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) { TVM_TIR_SCHEDULE_BEGIN(); tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map); @@ -1044,7 +1044,7 @@ void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv, TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_); } -void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, +void ConcreteScheduleNode::SetAxisSeparator(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array& axis_separators) { TVM_TIR_SCHEDULE_BEGIN(); @@ -1056,16 +1056,16 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_ /******** Schedule: Padding ********/ -BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) { +SBlockRV ConcreteScheduleNode::DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); TVM_TIR_SCHEDULE_END("decompose-padding", this->error_render_level_); this->state_->DebugVerify(); - return CreateRV(result); + return CreateRV(result); } -void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) { +void ConcreteScheduleNode::PadEinsum(const SBlockRV& block_rv, const ffi::Array& padding) { TVM_TIR_SCHEDULE_BEGIN(); tir::PadEinsum(state_, this->GetSRef(block_rv), padding); TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_); @@ -1074,7 +1074,7 @@ void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::ArrayGetSRef(block_rv), write_buffer_index); TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_); @@ -1083,7 +1083,7 @@ void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buff /******** Schedule: Misc ********/ -void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, +void ConcreteScheduleNode::UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type, const ffi::Array& buf_index_array) { TVM_TIR_SCHEDULE_BEGIN(); @@ -1092,7 +1092,7 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, this->state_->DebugVerify(); } -void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, +void ConcreteScheduleNode::AnnotateBufferAccess(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { TVM_TIR_SCHEDULE_BEGIN(); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 7ee54961415b..52591fad4cb2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,15 +68,15 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Lookup random variables ********/ - inline Block Get(const BlockRV& block_rv) const final; + inline SBlock Get(const SBlockRV& block_rv) const final; inline For Get(const LoopRV& loop_rv) const final; inline PrimExpr Get(const ExprRV& expr_rv) const final; - inline StmtSRef GetSRef(const BlockRV& block_rv) const final; + inline StmtSRef GetSRef(const SBlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; - inline bool HasBlock(const BlockRV& block_rv) const final; - inline ffi::Array GetSRefs(const ffi::Array& rvs) const; + inline bool HasBlock(const SBlockRV& block_rv) const final; + inline ffi::Array GetSRefs(const ffi::Array& rvs) const; inline ffi::Array GetSRefs(const ffi::Array& rvs) const; - void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } + void RemoveRV(const SBlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } using ScheduleNode::GetSRef; @@ -91,16 +91,16 @@ class ConcreteScheduleNode : public ScheduleNode { ffi::Array SamplePartitionedTile( const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, ffi::Optional> decision = std::nullopt) override; - LoopRV SampleComputeLocation(const BlockRV& block_rv, + LoopRV SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision = std::nullopt) override; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const ffi::String& name, const ffi::Optional& func_name) override; - ffi::Array GetLoops(const BlockRV& block_rv) override; - ffi::Array GetChildBlocks(const BlockRV& block_rv) override; - ffi::Array GetChildBlocks(const LoopRV& loop_rv) override; - ffi::Array GetProducers(const BlockRV& block_rv) override; - ffi::Array GetConsumers(const BlockRV& block_rv) override; - ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) override; + SBlockRV GetSBlock(const ffi::String& name, const ffi::Optional& func_name) override; + ffi::Array GetLoops(const SBlockRV& block_rv) override; + ffi::Array GetChildBlocks(const SBlockRV& block_rv) override; + ffi::Array GetChildBlocks(const LoopRV& loop_rv) override; + ffi::Array GetProducers(const SBlockRV& block_rv) override; + ffi::Array GetConsumers(const SBlockRV& block_rv) override; + ffi::Array GetOutputBlocks(const SBlockRV& scope_block_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) override; LoopRV Merge(const ffi::Array& loop_rvs) override; @@ -110,8 +110,8 @@ class ConcreteScheduleNode : public ScheduleNode { const ffi::Array>& factors, bool preserve_unit_iters) override; void Reorder(const ffi::Array& ordered_loop_rvs) override; - void ReorderBlockIterVar(const BlockRV& block_rv, const ffi::Array new_order) override; - LoopRV AddUnitLoop(const BlockRV& block_rv) override; + void ReorderBlockIterVar(const SBlockRV& block_rv, const ffi::Array new_order) override; + LoopRV AddUnitLoop(const SBlockRV& block_rv) override; LoopRV AddUnitLoop(const LoopRV& loop_rv) override; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) override; @@ -119,75 +119,77 @@ class ConcreteScheduleNode : public ScheduleNode { void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) override; void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ - BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks = {}) override; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + SBlockRV CacheRead(const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope, - const ffi::Array consumer_blocks = {}) override; - BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, const IndexMap& index_map) override; - BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const ffi::Array consumer_blocks = {}) override; + SBlockRV CacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) override; + SBlockRV ReindexCacheRead(const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope, const IndexMap& index_map) override; - ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) override; - ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, - int cse_thresh) override; - BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) override; + SBlockRV ReindexCacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, const IndexMap& index_map) override; + ffi::Array CacheInplace(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) override; + ffi::Array CacheIndex(const SBlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) override; + SBlockRV ReIndex(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) override; /******** Schedule: Data movement ********/ - BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) override; - BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + SBlockRV ReadAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope) override; + SBlockRV WriteAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope) override; /******** Schedule: Compute location ********/ - void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + void ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) override; - void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + void ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) override; - void ComputeInline(const BlockRV& block) override; - void ReverseComputeInline(const BlockRV& block) override; - void FuseReductionEpilogue(const BlockRV& reduction_block, - const BlockRV& epilogue_block) override; + void ComputeInline(const SBlockRV& block) override; + void ReverseComputeInline(const SBlockRV& block) override; + void FuseReductionEpilogue(const SBlockRV& reduction_block, + const SBlockRV& epilogue_block) override; /******** Schedule: Reduction ********/ - BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; - BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override; - void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) override; - /******** Schedule: Block annotation ********/ - void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + SBlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; + SBlockRV DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) override; + void PadEinsum(const SBlockRV& block_rv, const ffi::Array& padding) override; + /******** Schedule: SBlock annotation ********/ + void StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; - void SetScope(const BlockRV& block_rv, int buffer_index, + void SetScope(const SBlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) override; - void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const ffi::String& dtype) override; + void UnsafeSetDType(const SBlockRV& block_rv, int buffer_index, + const ffi::String& dtype) override; /******** Schedule: Blockize & Tensorize ********/ - BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; - BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) override; - void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + SBlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; + SBlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) override; + void Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) override; void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) override; void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) override; - void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; - void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) override; + void Annotate(const SBlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) override; /******** Schedule: Layout transformation ********/ - void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const ffi::Optional& pad_value, + void TransformLayout(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map, + const ffi::Optional& pad_value, bool assume_injective_transform = false) override; - void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; - void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, + void TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) override; + void SetAxisSeparator(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array& axis_separators) override; /******** Schedule: Padding decomposition ********/ - BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override; + SBlockRV DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) override; /******** Schedule: Buffer transformation ********/ - void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override; + void RollingBuffer(const SBlockRV& block_rv, int write_buffer_index) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} - void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + void UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type, const ffi::Array& buf_index_array) override; - void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + void AnnotateBufferAccess(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) override; protected: @@ -244,10 +246,9 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Lookup random variables ********/ -inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { +inline SBlock ConcreteScheduleNode::Get(const SBlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); - const BlockNode* block = TVM_SREF_TO_BLOCK(sref); - return ffi::GetRef(block); + return ffi::GetRef(TVM_SREF_TO_SBLOCK(sref)); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { @@ -269,7 +270,7 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { return this->analyzer_->Simplify(transformed); } -inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { +inline bool ConcreteScheduleNode::HasBlock(const SBlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { return false; @@ -282,15 +283,15 @@ inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { return true; } -inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { +inline StmtSRef ConcreteScheduleNode::GetSRef(const SBlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv; + LOG(FATAL) << "IndexError: Cannot find corresponding SBlockRV: " << block_rv; } const ObjectRef& obj = (*it).second; const auto* sref = obj.as(); if (sref == nullptr) { - LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: " + LOG(FATAL) << "ValueError: SBlockRV's corresponding type is invalid: " << (obj.defined() ? obj->GetTypeKey() : "None"); } if (sref->stmt == nullptr) { @@ -335,7 +336,7 @@ inline ffi::Array GetSRefsHelper(const ConcreteScheduleNode* sch, return result; } -inline ffi::Array ConcreteScheduleNode::GetSRefs(const ffi::Array& rvs) const { +inline ffi::Array ConcreteScheduleNode::GetSRefs(const ffi::Array& rvs) const { return GetSRefsHelper(this, rvs); } diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h index 39c9cc203fcf..daea23518e77 100644 --- a/src/tir/schedule/error.h +++ b/src/tir/schedule/error.h @@ -57,7 +57,7 @@ class ScheduleError : public tvm::runtime::Error { class LoopPositionError : public ScheduleError { public: - explicit LoopPositionError(IRModule mod, For loop, Block block, const std::string& primitive) + explicit LoopPositionError(IRModule mod, For loop, SBlock block, const std::string& primitive) : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)), @@ -79,7 +79,7 @@ class LoopPositionError : public ScheduleError { IRModule mod_; For loop_; - Block block_; + SBlock block_; std::string primitive_; }; diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 02c866e0b605..5cf128b25201 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -72,7 +72,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) inputs.push_back(ffi::String("None")); } else if (auto opt_str = obj.as()) { inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); - } else if (obj.as() || obj.as()) { + } else if (obj.as() || obj.as()) { inputs.push_back(ffi::String("_")); } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { inputs.push_back(obj); diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index bef35387cbaa..ae476ae0a2b1 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -34,7 +34,7 @@ class TensorIntrinMismatchError : public ScheduleError { lhs_stmt_(std::move(lhs_stmt)), rhs_stmt_(std::move(rhs_stmt)), error_messages_(std::move(error_messages)) { - ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); + ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); } ffi::String FastErrorString() const final { @@ -67,7 +67,7 @@ class TensorIntrinMismatchError : public ScheduleError { bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { bool equal = n.same_as(other) || ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other)); - if (!equal && assert_mode_ && (n->IsInstance() || n->IsInstance())) { + if (!equal && assert_mode_ && (n->IsInstance() || n->IsInstance())) { throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_)); } return equal; @@ -183,8 +183,8 @@ bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& othe return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); } -bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { - const auto* rhs = other.as(); +bool TensorizeComparator::VisitStmt_(const SBlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); if (!is_scope_block) { if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) { if (assert_mode_) { @@ -199,8 +199,8 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); } -bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { - const auto* rhs = other.as(); +bool TensorizeComparator::VisitStmt_(const SBlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); for (const IterVar& iter : op->iter_vars) { lhs_analyzer_.Bind(iter->var, iter->dom); } @@ -623,8 +623,8 @@ bool AutoTensorizeComparator::VisitStmtDefault_(const Object* op, const Stmt& ot return false; } -bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { - const auto* rhs = other.as(); +bool AutoTensorizeComparator::VisitStmt_(const SBlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); // Check block equality. // All iter vars and buffer regions including the order should match. // When checking iter vars, DefEqual is used to remap variables. @@ -643,7 +643,7 @@ bool AutoTensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) inner_iter_dom_map_.Set(block_iter->var, arith::IntSet::FromRange(block_iter->dom)); } } else { - auto collect_iter = [&](const BlockNode* op, std::vector& iters) -> bool { + auto collect_iter = [&](const SBlockNode* op, std::vector& iters) -> bool { for (const auto& iter : op->iter_vars) { analyzer_.Bind(iter->var, iter->dom); if (iter->iter_type == IterVarType::kDataPar || diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 665d093b2fa4..dbf773922f48 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -50,8 +50,8 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool VisitStmt_(const ForNode* op, const Stmt& other) override; bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; - bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; - bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + bool VisitStmt_(const SBlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const SBlockNode* op, const Stmt& other) override; bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; @@ -135,7 +135,7 @@ class AutoTensorizeComparator : public TensorizeComparator { bool VisitExprDefault_(const Object* op, const PrimExpr& other) override; bool VisitStmtDefault_(const Object* op, const Stmt& other) override; - bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + bool VisitStmt_(const SBlockNode* op, const Stmt& other) override; bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; @@ -147,9 +147,9 @@ class AutoTensorizeComparator : public TensorizeComparator { public: // Additional information extracted from LHS (the workload) and RHS (the tensor intrin). - /*! \brief Block iters in the LHS stmt. */ + /*! \brief SBlock iters in the LHS stmt. */ std::vector lhs_iters_; - /*! \brief Block iters in the RHS stmt. */ + /*! \brief SBlock iters in the RHS stmt. */ std::vector rhs_iters_; /*! \brief The buffer and its access indices in the LHS stmt. */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 1af0033791f4..cc06f7f0d1b4 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -158,8 +158,8 @@ TVM_DLL tir::StmtSRef SampleComputeLocation( * \param gvar The function to be retrieved * \return A list of blocks with the specific name */ -ffi::Array GetBlocks(const ScheduleState& self, const ffi::String& name, - const GlobalVar& gv); +ffi::Array GetSBlocks(const ScheduleState& self, const ffi::String& name, + const GlobalVar& gv); /*! * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state @@ -547,7 +547,7 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sr * \return The sref of the rfactor block */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); -/******** Schedule: Block annotation ********/ +/******** Schedule: SBlock annotation ********/ /*! * \brief Set alignment requirement for specific dimension such that diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index c398a46418a6..25c86431be4f 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -27,7 +27,7 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_k const ffi::Map* annotations = nullptr; if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; - } else if (const auto* block = sref->StmtAs()) { + } else if (const auto* block = sref->StmtAs()) { annotations = &block->annotations; } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); @@ -44,11 +44,11 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_k ObjectPtr n = ffi::make_object(*loop); n->annotations = std::move(new_ann); self->Replace(sref, For(n), {}); - } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = ffi::make_object(*block); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = ffi::make_object(*block); n->annotations = std::move(new_ann); - Block p(n); - self->Replace(sref, p, {{ffi::GetRef(block), p}}); + SBlock p(n); + self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; @@ -60,7 +60,7 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann const ffi::Map* annotations = nullptr; if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; - } else if (const auto* block = sref->StmtAs()) { + } else if (const auto* block = sref->StmtAs()) { annotations = &block->annotations; } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); @@ -75,11 +75,11 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann ObjectPtr n = ffi::make_object(*loop); n->annotations = std::move(new_ann); self->Replace(sref, For(n), {}); - } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = ffi::make_object(*block); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = ffi::make_object(*block); n->annotations = std::move(new_ann); - Block p(n); - self->Replace(sref, p, {{ffi::GetRef(block), p}}); + SBlock p(n); + self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; @@ -97,13 +97,14 @@ struct AnnotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, Any ann_val, ffi::String ann_key) { - if (auto block = block_or_loop_rv.as()) { + if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } if (auto loop = block_or_loop_rv.as()) { return sch->Annotate(loop.value(), ann_key, ann_val); } - LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); throw; } @@ -131,13 +132,14 @@ struct UnannotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ffi::String ann_key) { - if (auto block = block_or_loop_rv.as()) { + if (auto block = block_or_loop_rv.as()) { return sch->Unannotate(block.value(), ann_key); } if (auto loop = block_or_loop_rv.as()) { return sch->Unannotate(loop.value(), ann_key); } - LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); throw; } diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/tir/schedule/primitive/annotate_buffer_access.cc index 84672dede70d..c358fd84d6b2 100644 --- a/src/tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/tir/schedule/primitive/annotate_buffer_access.cc @@ -30,8 +30,8 @@ class AnnotateRegionRewriter : public StmtExprMutator { new_region_(new_region), buffer_index_type_(buffer_index_type) {} - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); ffi::Array regions = buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads; @@ -39,7 +39,7 @@ class AnnotateRegionRewriter : public StmtExprMutator { ICHECK_LT(buffer_index_, static_cast(regions.size())) << "Buffer index out of range"; regions.Set(buffer_index_, new_region_); - ObjectPtr n = CopyOnWrite(block.get()); + ObjectPtr n = CopyOnWrite(block.get()); if (buffer_index_type_ == BufferIndexType::kWrite) { n->writes = std::move(regions); } else { @@ -70,7 +70,7 @@ class AnnotateRegionRewriter : public StmtExprMutator { } n->annotations = std::move(new_annotations); - return Block(n); + return SBlock(n); } private: @@ -82,9 +82,9 @@ class AnnotateRegionRewriter : public StmtExprMutator { void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, buffer_index_type); arith::Analyzer analyzer; ffi::Array block_iter_vars; @@ -105,7 +105,7 @@ void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int bu AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type); Stmt new_stmt = mutator(ffi::GetRef(block_sref->stmt)); - self->Replace(block_sref, new_stmt, {{ffi::GetRef(block), Downcast(new_stmt)}}); + self->Replace(block_sref, new_stmt, {{ffi::GetRef(block), Downcast(new_stmt)}}); } struct AnnotateBufferAccessTraits : public UnpackedInstTraits { @@ -117,7 +117,7 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraitsAnnotateBufferAccess(block, buffer_index->value, static_cast(buffer_index_type->value), diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 2bf62d409e2d..7810eb81b6dc 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -136,7 +136,7 @@ class StorageAlignInvalidFactorError : public ScheduleError { class StorageAlignInvalidAnnotationError : public ScheduleError { public: - explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) + explicit StorageAlignInvalidAnnotationError(IRModule mod, SBlock block) : mod_(std::move(mod)), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -153,7 +153,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { return os.str(); } - static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) { + static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const SBlock& block) { // Get existing annotation value. auto it = block->annotations.find(attr::buffer_dim_align); if (it != block->annotations.end()) { @@ -172,12 +172,12 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { IRModule mod() const final { return mod_; } private: - static bool IsValidAnnotation(const Block& block, const Any& anno_value) { + static bool IsValidAnnotation(const SBlock& block, const Any& anno_value) { return anno_value.try_cast>>().has_value(); } IRModule mod_; - Block block_; + SBlock block_; }; /*! @@ -193,17 +193,18 @@ class StorageScopeMutator : private ReplaceBufferMutator { * \param block_sref_reuse The block sref reuse map to be updated * \return The new block after the mutation */ - static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, - const ffi::String& storage_scope, ffi::Map* block_sref_reuse) { + static SBlock Mutate(const SBlock& allocate_site, const Buffer& old_buffer, + const ffi::String& storage_scope, + ffi::Map* block_sref_reuse) { Buffer new_buffer = WithScope(old_buffer, storage_scope); StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse); Stmt new_block = mutator.VisitStmt(allocate_site); - return Downcast(new_block); + return Downcast(new_block); } private: StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, ffi::String storage_scope, - ffi::Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {} MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { @@ -221,8 +222,8 @@ class StorageScopeMutator : private ReplaceBufferMutator { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, + const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); + Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, BufferIndexType::kWrite); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); @@ -231,7 +232,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind // Step 1: Get existing or create new annotation value. StorageAlignAnnotation storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, - ffi::GetRef(block_ptr)); + ffi::GetRef(block_ptr)); // Step 2: Update the annotation value bool found = false; @@ -249,15 +250,15 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind } // Step 3: Replace the block with the new annotation - Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); - self->Replace(block_sref, new_block, {{ffi::GetRef(block_ptr), new_block}}); + SBlock new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); + self->Replace(block_sref, new_block, {{ffi::GetRef(block_ptr), new_block}}); } void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, const ffi::String& storage_scope) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return. if (buffer.scope() == storage_scope) { @@ -270,13 +271,13 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, // Step 3. Get the allocation site of the target buffer. StmtSRef alloc_site_sref = NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); - const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); + const SBlockNode* alloc_site = TVM_SREF_TO_SBLOCK(alloc_site_sref); // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given // storage scope. In the meanwhile, collect the block sref reuse information. - ffi::Map block_reuse_map; - Block new_block = StorageScopeMutator::Mutate(ffi::GetRef(alloc_site), buffer, - storage_scope, &block_reuse_map); + ffi::Map block_reuse_map; + SBlock new_block = StorageScopeMutator::Mutate(ffi::GetRef(alloc_site), buffer, + storage_scope, &block_reuse_map); self->Replace(alloc_site_sref, new_block, block_reuse_map); } @@ -293,17 +294,17 @@ class DTypeMutator : private ReplaceBufferMutator { * \param block_sref_reuse The block sref reuse map to be updated * \return The new block after the mutation */ - static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype, - ffi::Map* block_sref_reuse) { + static SBlock Mutate(const SBlock& allocate_site, const Buffer& old_buffer, const DataType& dtype, + ffi::Map* block_sref_reuse) { Buffer new_buffer = WithDType(old_buffer, dtype); DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse); Stmt new_block = mutator.VisitStmt(allocate_site); - return Downcast(new_block); + return Downcast(new_block); } private: DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype, - ffi::Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse), src_dtype_(old_buffer->dtype), tgt_dtype_(dtype) {} @@ -344,9 +345,9 @@ class DTypeMutator : private ReplaceBufferMutator { void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, const ffi::String& dtype) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); DataType target_dtype(ffi::StringToDLDataType(dtype)); // Step 1. If `dtype` equals the original data type, just return. @@ -357,13 +358,13 @@ void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_i // Step 2. Get the allocation site of the target buffer. StmtSRef alloc_site_sref = NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); - const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); + const SBlockNode* alloc_site = TVM_SREF_TO_SBLOCK(alloc_site_sref); // Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given // dtype, and insert data type conversions. - ffi::Map block_reuse_map; - Block new_block = - DTypeMutator::Mutate(ffi::GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); + ffi::Map block_reuse_map; + SBlock new_block = + DTypeMutator::Mutate(ffi::GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); self->Replace(alloc_site_sref, new_block, block_reuse_map); } @@ -378,7 +379,7 @@ struct StorageAlignTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 4; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, Integer buffer_index, Integer axis, Integer factor, Integer offset) { return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, offset->value); @@ -409,7 +410,7 @@ struct SetScopeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, Integer buffer_index, ffi::String storage_scope) { return sch->SetScope(block_rv, buffer_index->value, storage_scope); } @@ -436,7 +437,7 @@ struct UnsafeSetDTypeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, Integer buffer_index, ffi::String dtype) { return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype); } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 2ae32ea66a6a..1ae2b8e7bfb4 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -47,7 +47,7 @@ T DeepCopy(const T& stmt) { */ class SubspaceNotDivisibleError : public ScheduleError { public: - explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block) + explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, SBlock inner_block) : mod_(std::move(mod)), scope_loop_(std::move(scope_loop)), inner_block_(std::move(inner_block)) {} @@ -68,7 +68,7 @@ class SubspaceNotDivisibleError : public ScheduleError { private: IRModule mod_; For scope_loop_; - Block inner_block_; + SBlock inner_block_; }; /*! @@ -154,7 +154,7 @@ ffi::Array> TrivialSubspaceDivision( * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \param loop_sref_as_outer Whether loop_sref is divided into outer or inner */ -ffi::Array> SubspaceDivide(const BlockRealize& realize, +ffi::Array> SubspaceDivide(const SBlockRealize& realize, const StmtSRef& block_sref, // const StmtSRef& loop_sref, // std::vector* loops, @@ -283,12 +283,12 @@ ffi::Map DeriveBlockBinding( * `iter_vars`, `init` and `reads` fields. * \return The inner block created. */ -BlockRealize GenerateInner(bool is_write_reduction, - const ffi::Array& iter_vars, // - const ffi::Array& iter_values, // - const PrimExpr& predicate, // - Block block) { - BlockNode* n = block.CopyOnWrite(); +SBlockRealize GenerateInner(bool is_write_reduction, + const ffi::Array& iter_vars, // + const ffi::Array& iter_values, // + const PrimExpr& predicate, // + SBlock block) { + SBlockNode* n = block.CopyOnWrite(); n->iter_vars = iter_vars; n->init = std::nullopt; if (is_write_reduction) { @@ -298,8 +298,8 @@ BlockRealize GenerateInner(bool is_write_reduction, reads.insert(reads.end(), block->reads.begin(), block->reads.end()); n->reads = std::move(reads); } - return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate, - /*block=*/block); + return SBlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate, + /*block=*/block); } /*! @@ -309,9 +309,9 @@ BlockRealize GenerateInner(bool is_write_reduction, * \param loops The inner loops after blockize. * \return The subtree of the init block and its outer loops. */ -Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize, +Stmt GenerateOuterInit(const Stmt& block_init, const SBlockRealize& inner_realize, const std::vector& loops, ffi::String block_name) { - const Block& inner_block = inner_realize->block; + const SBlock& inner_block = inner_realize->block; ffi::Map subst_map; // Step 1: Create new block vars for the block inside the init stmt of outer block // A iter is used in the block if @@ -336,16 +336,16 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize } } // Step 2: Generate the block inside init stmt of outer block - Stmt stmt = BlockRealize( + Stmt stmt = SBlockRealize( /*iter_values=*/iter_values, /*predicate=*/inner_realize->predicate, /*block=*/ - Block(/*iter_vars=*/iter_vars, - /*reads=*/{}, - /*writes=*/inner_block->writes, - /*name_hint=*/block_name, - /*body=*/block_init, - /*init=*/std::nullopt)); + SBlock(/*iter_vars=*/iter_vars, + /*reads=*/{}, + /*writes=*/inner_block->writes, + /*name_hint=*/block_name, + /*body=*/block_init, + /*init=*/std::nullopt)); // Step 3. Create the loop nest on top of the block for (const ForNode* loop : loops) { bool is_init_loop = false; @@ -376,10 +376,10 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize * \return The substituted stmt. */ Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, - ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) { + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) { struct Replacer : public StmtExprMutator { - explicit Replacer(const ffi::Map& sub, ffi::Map* block_sref_reuse, - arith::Analyzer* analyzer) + explicit Replacer(const ffi::Map& sub, + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {} PrimExpr VisitExpr(const PrimExpr& op) final { @@ -397,9 +397,9 @@ Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, return StmtExprMutator::VisitExpr_(op); } - Stmt VisitStmt_(const BlockNode* op) final { - Block src = ffi::GetRef(op); - Block tgt = Downcast(StmtExprMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock src = ffi::GetRef(op); + SBlock tgt = Downcast(StmtExprMutator::VisitStmt_(op)); if (!src.same_as(tgt)) { block_sref_reuse_->Set(src, tgt); } @@ -407,7 +407,7 @@ Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, } const ffi::Map& sub_; - ffi::Map* block_sref_reuse_; + ffi::Map* block_sref_reuse_; arith::Analyzer* analyzer_; }; return Replacer(sub, block_sref_reuse, analyzer)(stmt); @@ -485,13 +485,13 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { return stmt; } -BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, - ffi::Map* block_sref_reuse, arith::Analyzer* analyzer, - bool preserve_unit_iters) { +SBlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer, + bool preserve_unit_iters) { TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. - BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); - Block block = block_realize->block; + SBlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); + SBlock block = block_realize->block; StmtSRef block_sref = self->stmt2ref.at(block.get()); // Step 2: Derive subspace division std::vector loops; @@ -518,8 +518,8 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom)); analyzer->Bind(iter->var, iter->dom); } - Block block_subst = - Downcast(Substitute(block, block_var_subst, block_sref_reuse, analyzer)); + SBlock block_subst = + Downcast(Substitute(block, block_var_subst, block_sref_reuse, analyzer)); // Step 5: Generate the inner block. The write regions of the inner blocks will be reduction if // 1. The original block has init stmt. // 2. There are outer reduction iter vars. @@ -532,46 +532,46 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, } } } - BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction, - /*iter_vars=*/inner_iter_vars, - /*iter_values*/ inner_bindings, - /*predicate=*/inner_predicate, - /*block=*/block_subst); + SBlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction, + /*iter_vars=*/inner_iter_vars, + /*iter_values*/ inner_bindings, + /*predicate=*/inner_predicate, + /*block=*/block_subst); block_sref_reuse->Set(block, inner_realize->block); // Step 6: Generate the outer block. - return BlockRealize( + return SBlockRealize( /*iter_values=*/std::move(outer_bindings), /*predicate=*/std::move(outer_predicate), /*block=*/ - Block(/*iter_vars=*/std::move(outer_iter_vars), - /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom), - /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom), - /*name_hint=*/block_subst->name_hint + "_o", - /*body=*/MakeLoopNest(inner_realize, loops), - /*init=*/ - block_subst->init.defined() // - ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops, - block_subst->name_hint + "_init") - : ffi::Optional(std::nullopt))); + SBlock(/*iter_vars=*/std::move(outer_iter_vars), + /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom), + /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom), + /*name_hint=*/block_subst->name_hint + "_o", + /*body=*/MakeLoopNest(inner_realize, loops), + /*init=*/ + block_subst->init.defined() // + ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops, + block_subst->name_hint + "_init") + : ffi::Optional(std::nullopt))); } StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) { arith::Analyzer analyzer; - ffi::Map block_sref_reuse; - BlockRealize blockized = + ffi::Map block_sref_reuse; + SBlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters); self->Replace(loop_sref, blockized, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); - self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, scope_root)); self->block_info[scope_root].affine_binding = scope_block_affine_binding; return result; } -BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array& block_srefs, - const StmtSRef& lca, ffi::Map* block_sref_reuse, - bool preserve_unit_iters) { +SBlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array& block_srefs, + const StmtSRef& lca, ffi::Map* block_sref_reuse, + bool preserve_unit_iters) { ffi::Array seq_body; PrimExpr outer_predicate{nullptr}; ffi::Array outer_iter_vars{nullptr}; @@ -582,7 +582,7 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array loop_var_subst; arith::Analyzer analyzer; for (const auto& block_sref : block_srefs) { - auto block_realize = GetBlockRealize(self, block_sref); + auto block_realize = GetSBlockRealize(self, block_sref); auto block = block_realize->block; // Step 1: Derive subspace division std::vector loops; @@ -613,8 +613,8 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Arrayvar, arith::IntSet::FromRange(dom)); analyzer.Bind(iter->var, dom); } - Block block_subst = - Downcast(Substitute(block, block_var_subst, block_sref_reuse, &analyzer)); + SBlock block_subst = + Downcast(Substitute(block, block_var_subst, block_sref_reuse, &analyzer)); auto reads = EvalSetRegions(block_subst->reads, inner_iter_dom); auto writes = EvalSetRegions(block_subst->writes, inner_iter_dom); read_regions.insert(read_regions.end(), reads.begin(), reads.end()); @@ -632,11 +632,11 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::ArraySet(block, inner_realize->block); Stmt stmt = inner_realize; for (const ForNode* loop : loops) { @@ -648,29 +648,29 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array(std::nullopt))); + SBlock(/*iter_vars=*/std::move(outer_iter_vars), + /*reads=*/UnionRegions(read_regions), + /*writes=*/UnionRegions(write_regions), + /*name_hint=*/outer_block_name, + /*body=*/SeqStmt(seq_body), + /*init=*/ffi::Optional(std::nullopt))); } class BlockizeRewriter : public StmtMutator { public: static Stmt Rewrite(const StmtSRef& lca, const ffi::Array& blocks, - const BlockRealize& blockized) { + const SBlockRealize& blockized) { BlockizeRewriter rewriter(lca, blocks, blockized); return rewriter(ffi::GetRef(lca->stmt)); } private: explicit BlockizeRewriter(const StmtSRef& lca, const ffi::Array& blocks, - const BlockRealize& blockized) + const SBlockRealize& blockized) : lca_(lca), blocks_(blocks), blockized_(blockized) {} Stmt RewriteSeq(const Stmt& stmt) { @@ -708,11 +708,11 @@ class BlockizeRewriter : public StmtMutator { return StmtMutator::VisitStmt_(loop); } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { if (block == lca_->stmt) { - return Block(block->iter_vars, block->reads, block->writes, block->name_hint, - RewriteSeq(block->body), block->init, block->alloc_buffers, block->match_buffers, - block->annotations, block->span); + return SBlock(block->iter_vars, block->reads, block->writes, block->name_hint, + RewriteSeq(block->body), block->init, block->alloc_buffers, + block->match_buffers, block->annotations, block->span); } for (const StmtSRef& block_sref : blocks_) { if (block_sref->stmt == block) { @@ -725,38 +725,38 @@ class BlockizeRewriter : public StmtMutator { StmtSRef lca_; ffi::Array blocks_; - BlockRealize blockized_; + SBlockRealize blockized_; bool target_in_ = false; }; StmtSRef Blockize(ScheduleState self, const ffi::Array& blocks, bool preserve_unit_iters) { - ffi::Map block_sref_reuse; + ffi::Map block_sref_reuse; auto lca = GetSRefLowestCommonAncestor(blocks); - BlockRealize blockized = + SBlockRealize blockized = BlockizeBlocks(self, blocks, lca, &block_sref_reuse, preserve_unit_iters); auto new_root = BlockizeRewriter::Rewrite(lca, blocks, blockized); self->Replace(lca, new_root, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); - self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, scope_root)); return result; } void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin, bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed - BlockRealize block_realize{nullptr}; - ffi::Optional old_block = std::nullopt; - if (sref->stmt->IsInstance()) { - block_realize = GetBlockRealize(self, sref); + SBlockRealize block_realize{nullptr}; + ffi::Optional old_block = std::nullopt; + if (sref->stmt->IsInstance()) { + block_realize = GetSBlockRealize(self, sref); old_block = block_realize->block; } else if (sref->stmt->IsInstance()) { arith::Analyzer analyzer; - ffi::Map block_sref_reuse; + ffi::Map block_sref_reuse; block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters); } else { - LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: " + LOG(FATAL) << "TypeError: Tensorize only support For or SBlock, but gets: " << ffi::GetRef(sref->stmt); throw; } @@ -799,7 +799,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int impl2cur[impl] = comparator.rhs_buffer_map_[desc]; } std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; - Block impl_block = Downcast(intrin_impl->body)->block; + SBlock impl_block = Downcast(intrin_impl->body)->block; for (const BufferRegion& read : impl_block->reads) { impl2region.emplace(read->buffer, read->region); } @@ -833,7 +833,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } // Step 5: Replace the subtree in the original IR with the tensor intrin impl. { - BlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite(); + SBlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite(); block->body = impl_block->body; block->match_buffers = std::move(match_buffer_regions); for (const auto& [key, val] : impl_block->annotations) { @@ -854,7 +854,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int // Step 6: Update the cached flags. StmtSRef result = self->stmt2ref.at(block_realize->block.get()); StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); - self->UpdateScopeBlockInfo(scope_root->StmtAs()->body); + self->UpdateScopeSBlockInfo(scope_root->StmtAs()->body); } /******** InstructionKind Registration ********/ @@ -868,13 +868,14 @@ struct BlockizeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, Bool preserve_unit_iters) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, + Bool preserve_unit_iters) { if (auto loop = target.as()) { return sch->Blockize(loop.value(), preserve_unit_iters.operator bool()); - } else if (auto blocks = target.as>()) { + } else if (auto blocks = target.as>()) { return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); } - LOG(FATAL) << "TypeError: expect Loop or list of Blocks, but gets:" << target->GetTypeKey(); + LOG(FATAL) << "TypeError: expect Loop or list of SBlocks, but gets:" << target->GetTypeKey(); } static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef target, @@ -901,12 +902,12 @@ struct TensorizeTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ffi::String intrin, Bool preserve_unit_iters) { - if (auto block = block_or_loop_rv.as()) { + if (auto block = block_or_loop_rv.as()) { sch->Tensorize(block.value(), intrin, preserve_unit_iters.operator bool()); } else if (auto loop = block_or_loop_rv.as()) { sch->Tensorize(loop.value(), intrin, preserve_unit_iters.operator bool()); } else { - LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " + LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); } } diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc index 156f2ae4c59c..788f34e883dd 100644 --- a/src/tir/schedule/primitive/cache_index.cc +++ b/src/tir/schedule/primitive/cache_index.cc @@ -30,7 +30,7 @@ namespace tir { /*! \brief The auxiliary info used for the insertion point and content of the cache stage. */ struct IndexInfo { /*! \brief The target block to perform cache_index */ - StmtSRef target_block; + StmtSRef target_sblock; /*! \brief Record the common subexpr extract threshold */ size_t cse_thresh; /*! \brief The cache buffer to store the precomputed index */ @@ -48,7 +48,7 @@ struct IndexInfo { /*! \brief The cache stage to be inserted. */ Stmt cache_stage; /*! \brief The map used for ScheduleStateNode::Replace. */ - ffi::Map block_reuse; + ffi::Map block_reuse; }; /*! @@ -112,10 +112,10 @@ class IndexInfoCollector : public StmtExprVisitor { } } - void VisitStmt_(const BlockNode* block) final { - visiting_target_block = static_cast(block_sref_->stmt == block); + void VisitStmt_(const SBlockNode* block) final { + visiting_target_sblock = static_cast(block_sref_->stmt == block); StmtVisitor::VisitStmt_(block); - visiting_target_block = false; + visiting_target_sblock = false; if (block == scope_sref_->stmt) { // The block vistied is the current parent scope // Handling cases when no SeqStmt in the scope @@ -142,7 +142,7 @@ class IndexInfoCollector : public StmtExprVisitor { void VisitStmt_(const BufferStoreNode* store) final { // Only analyze the cache candidate for stores in target block - if (visiting_target_block) { + if (visiting_target_sblock) { auto IsEligibleComputation = [](const PrimExpr& expr) { return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 && (expr.as() == nullptr) && (expr.as() == nullptr)); @@ -205,7 +205,7 @@ class IndexInfoCollector : public StmtExprVisitor { /*! \brief The flag whether we have visited the target block */ bool visited_block_{false}; /*! \brief The flag indicating currently visiting target block */ - bool visiting_target_block{false}; + bool visiting_target_sblock{false}; /*! \brief The index to insert the cache_index stage */ int loc_pos_{-1}; /*! \brief The flag indicating the right scope to update seq pos */ @@ -220,8 +220,8 @@ class IndexInfoCollector : public StmtExprVisitor { * \param storage_scope The storage scope of the cached buffer (only used in naming here) * \returns A block indicating the body of the loop nesting. */ -ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& storage_scope) { - ffi::Array blocks; +ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& storage_scope) { + ffi::Array blocks; ffi::Array bodies; bodies.reserve(info->index_exprs.size()); info->cache_buffer.reserve(info->index_exprs.size()); @@ -308,7 +308,7 @@ ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& storag // Create the index computing block PrimExpr new_expr = Substitute(index_expr, block_var_map); - Block block( + SBlock block( /*iter_vars=*/std::move(block_vars), /*reads=*/{}, /*writes=*/{BufferRegion(info->cache_buffer[expr_index], access_region)}, @@ -321,9 +321,9 @@ ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& storag /*annotations=*/{}); blocks.push_back(block); // Create the block realize node - Stmt body = BlockRealize(/*values=*/iter_values, - /*predicate=*/const_true(), - /*block=*/block); + Stmt body = SBlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/block); // Create surrounding loops for (size_t i = loop_vars.size(); i >= 1; --i) { body = For(/*loop_var=*/loop_vars[i - 1], @@ -385,22 +385,22 @@ class CacheIndexRewriter : public StmtExprMutator { } } - Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = ffi::GetRef(block); + Stmt VisitStmt_(const SBlockNode* block) final { + SBlock old_stmt = ffi::GetRef(block); // Mutate the body - visiting_target_block = static_cast(block == info_->target_block->stmt); - Block stmt = Downcast(StmtMutator::VisitStmt_(block)); - visiting_target_block = false; + visiting_target_sblock = static_cast(block == info_->target_sblock->stmt); + SBlock stmt = Downcast(StmtMutator::VisitStmt_(block)); + visiting_target_sblock = false; // Check if it is the block corresponding to the parent scope if (block == scope_sref_->stmt) { // If so, put buffer allocation and insert cache stages on the parent scope - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertIndexStage(n->body, info_->loc_pos, info_->cache_stage); for (const Buffer& it : info_->cache_buffer) { n->alloc_buffers.push_back(it); } - stmt = Block(n); + stmt = SBlock(n); } info_->block_reuse.Set(old_stmt, stmt); return stmt; @@ -409,7 +409,7 @@ class CacheIndexRewriter : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* store) final { Stmt ret_stmt = StmtMutator::VisitStmt_(store); // Replace common sub expr for target block, with cached buffer load - if (visiting_target_block) { + if (visiting_target_sblock) { for (size_t i = 0; i < info_->index_exprs.size(); i++) { PrimExpr& computation = info_->index_exprs[i]; std::function predicate_selector = @@ -433,7 +433,7 @@ class CacheIndexRewriter : public StmtExprMutator { /*! \brief The indices for the cache buffer */ std::vector> cache_indices_; /*! \brief Indicating whether cache stage is inserted, only do index replacement afterwards*/ - bool visiting_target_block{false}; + bool visiting_target_sblock{false}; }; ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, @@ -449,7 +449,7 @@ ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, // Step 0. Checking index, getting the target buffer and the parent scope IndexInfo info; - info.target_block = block_sref; + info.target_sblock = block_sref; CHECK_GE(cse_thresh, 0) << "cse_thresh should not be negative number"; info.cse_thresh = cse_thresh; StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); @@ -458,9 +458,9 @@ ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, IndexInfoCollector::Collect(self, block_sref, scope_sref, &info); // Step 2. Create cache stages and rewrite the stmt. - BlockRealize realize = GetBlockRealize(self, block_sref); + SBlockRealize realize = GetSBlockRealize(self, block_sref); info.var_binding = GetBindings(realize); - ffi::Array cache_stages = MakeIndexCacheStage(&info, storage_scope); + ffi::Array cache_stages = MakeIndexCacheStage(&info, storage_scope); Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); bool old_stage_pipeline = self->block_info[block_sref].stage_pipeline; @@ -468,10 +468,10 @@ ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, // Step 3. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); ffi::Array result_block_srefs; - for (const Block& it : cache_stages) { + for (const SBlock& it : cache_stages) { StmtSRef result_block_sref = self->stmt2ref.at(it.get()); result_block_srefs.push_back(result_block_sref); - BlockInfo& block_info = self->block_info[result_block_sref]; + SBlockInfo& block_info = self->block_info[result_block_sref]; bool affine_binding = false; if (result_block_sref->parent == nullptr) { @@ -479,7 +479,7 @@ ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, } else { arith::Analyzer analyzer; StmtSRef parent_sref = ffi::GetRef(result_block_sref->parent); - affine_binding = IsAffineBinding(/*realize=*/GetBlockRealize(self, result_block_sref), + affine_binding = IsAffineBinding(/*realize=*/GetSBlockRealize(self, result_block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), /*analyzer=*/&analyzer); } @@ -503,9 +503,9 @@ struct CacheIndexTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, - ffi::String storage_scope, - Integer cse_thresh) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block, + ffi::String storage_scope, + Integer cse_thresh) { return sch->CacheIndex(block, storage_scope, cse_thresh->value); } diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index a2479a0d28ff..5cae3749c55f 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -35,8 +35,8 @@ class NotSingleWriteBlock : public ScheduleError { ICHECK_GT(write_blocks.size(), 1); write_blocks_.reserve(write_blocks.size()); for (const StmtSRef& block_sref : write_blocks) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - write_blocks_.push_back(ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + write_blocks_.push_back(ffi::GetRef(block)); } } @@ -58,7 +58,7 @@ class NotSingleWriteBlock : public ScheduleError { private: IRModule mod_; Buffer buffer_; - ffi::Array write_blocks_; + ffi::Array write_blocks_; }; /******** Helper Functions/Classes ********/ @@ -78,7 +78,7 @@ struct CacheStageInfo { /*! \brief The cache_read/cache_write stage to be inserted. */ Stmt cache_stage; /*! \brief The map used for ScheduleStateNode::Replace. */ - ffi::Map block_reuse; + ffi::Map block_reuse; /*! \brief A set of blocks that will consume the new cache. */ std::unordered_set consumer_blocks; /*! \brief cache region for the buffer to be cached */ @@ -113,7 +113,7 @@ struct ReindexCacheStageInfo : CacheStageInfo { * reindex_cache_read/write. */ class NotSinglePointAccess : public ScheduleError { public: - explicit NotSinglePointAccess(IRModule mod, Block block, BufferRegion cache_region, + explicit NotSinglePointAccess(IRModule mod, SBlock block, BufferRegion cache_region, bool is_cache_read) : mod_(std::move(mod)), block_(std::move(block)), cache_region_(cache_region) { primitive_name_ = is_cache_read ? "reindex_cache_read" : "reindex_cache_write"; @@ -136,7 +136,7 @@ class NotSinglePointAccess : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; BufferRegion cache_region_; ffi::String primitive_name_; }; @@ -150,8 +150,8 @@ class NotSinglePointAccess : public ScheduleError { * \returns A block indicating the body of the loop nesting. */ template -Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, - const ffi::String& storage_scope) { +SBlock MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, + const ffi::String& storage_scope) { // loop variables std::vector loop_vars; // block variables @@ -196,7 +196,7 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI } // Create New Block - Block block( + SBlock block( /*iter_vars*/ std::move(block_vars), /*reads=*/{BufferRegion(info->read_buffer, read_access_region)}, /*writes=*/{BufferRegion(info->write_buffer, write_access_region)}, @@ -208,10 +208,10 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*buf_doms=*/{}); - // Create Block Realize node - Stmt body = BlockRealize(/*values=*/iter_values, - /*predicate=*/const_true(), - /*block=*/block); + // Create SBlock Realize node + Stmt body = SBlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/block); // Create surrounding loops for (size_t i = loop_vars.size(); i >= 1; --i) { body = For(/*loop_var=*/loop_vars[i - 1], @@ -236,8 +236,8 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI * full region or compact region. * \returns A block indicating the body of the loop nesting. */ -Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, - const ffi::String& storage_scope, bool cache_full_region = true) { +SBlock MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, + const ffi::String& storage_scope, bool cache_full_region = true) { // loop variables std::vector loop_vars; // bindings in block realize @@ -296,7 +296,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, // reads = [read_buffer[access_region]] // writes = [write_buffer[access_region]] // write_buffer[access_indices] = read_buffer[access_indices] - Block block( + SBlock block( /*iter_vars=*/std::move(block_vars), /*reads=*/{BufferRegion(info->read_buffer, read_access_region)}, /*writes=*/{BufferRegion(info->write_buffer, write_access_region)}, @@ -309,9 +309,9 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, /*match_buffers=*/{}, /*annotations=*/{}); // Create the block realize node - Stmt body = BlockRealize(/*values=*/iter_values, - /*predicate=*/const_true(), - /*block=*/block); + Stmt body = SBlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/block); // Create surrounding loops for (size_t i = loop_vars.size(); i >= 1; --i) { body = For(/*loop_var=*/loop_vars[i - 1], @@ -342,10 +342,10 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, * \param buffer_index_type The type of buffer index * \return The reindex block. */ -Block MakeReIndexStage(const Block& block, CacheStageInfo* info, - const std::unordered_set& covered, - const ffi::Array& original_indices, int buffer_index, - BufferIndexType buffer_index_type) { +SBlock MakeReIndexStage(const SBlock& block, CacheStageInfo* info, + const std::unordered_set& covered, + const ffi::Array& original_indices, int buffer_index, + BufferIndexType buffer_index_type) { // iters of the reindex block ffi::Array new_block_iters; // the substitution map from the original block iter to the iters of the reindex block @@ -395,7 +395,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, } // Create the body block - Block new_block( + SBlock new_block( /*iter_vars=*/new_block_iters, /*reads=*/{BufferRegion::FromPoint(info->read_buffer, src_indices)}, /*writes=*/{BufferRegion::FromPoint(info->write_buffer, dst_indices)}, @@ -418,9 +418,9 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, } // Create the block realize node - Stmt body = BlockRealize(/*values=*/iter_values, - /*predicate=*/const_true(), - /*block=*/new_block); + Stmt body = SBlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/new_block); // Create the chain of loops for (int i = static_cast(new_block_iters.size()) - 1; i >= 0; --i) { @@ -445,7 +445,7 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) } arith::Analyzer analyzer; StmtSRef parent_sref = ffi::GetRef(block_sref->parent); - return IsAffineBinding(/*realize=*/GetBlockRealize(self, block_sref), + return IsAffineBinding(/*realize=*/GetSBlockRealize(self, block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), /*analyzer=*/&analyzer); } @@ -508,7 +508,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { */ ffi::Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, const Buffer& buffer) { - BlockScope scope = self->GetBlockScope(scope_sref); + SBlockScope scope = self->GetSBlockScope(scope_sref); auto it = scope->buffer_writers.find(buffer); if (it == scope->buffer_writers.end()) { return std::nullopt; @@ -535,9 +535,9 @@ ffi::Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& sc bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sref, StmtSRef stmt_sref) { // Collect all children blocks of the target stmt. - std::unordered_set blocks_under_target; + std::unordered_set blocks_under_target; for (const StmtSRef& block_sref : GetChildBlocks(self, stmt_sref)) { - const auto* block = block_sref->StmtAs(); + const auto* block = block_sref->StmtAs(); ICHECK(block != nullptr); blocks_under_target.insert(block); } @@ -546,7 +546,7 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre // input buffer, check if it is also a child block of the // target stmt. for (const StmtSRef& block_sref : GetChildBlocks(self, scope_sref)) { - const auto* block = block_sref->StmtAs(); + const auto* block = block_sref->StmtAs(); ICHECK(block != nullptr); if (GetBufferRegionFromBuffer(block->reads, buffer).defined()) { if (blocks_under_target.find(block) == blocks_under_target.end()) { @@ -569,7 +569,7 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region, const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive, const StmtSRef& dom_high_exclusive) { - BlockRealize realize = GetBlockRealize(self, block_sref); + SBlockRealize realize = GetSBlockRealize(self, block_sref); ffi::Map binding = GetBindings(realize); const Buffer& buffer = buffer_region->buffer; arith::Analyzer analyzer; @@ -613,14 +613,14 @@ class CacheLocDetector : public StmtVisitor { related_blocks.emplace_back(consumer); } } else { - for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { + for (const Dependency& def : self->GetSBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { if (def->kind == DepKind::kRAW) { related_blocks.push_back(def->dst); } } } } else { - for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { + for (const Dependency& def : self->GetSBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { if (def->kind == DepKind::kRAW) { if (info->consumer_blocks.count(def->dst)) { continue; @@ -638,7 +638,7 @@ class CacheLocDetector : public StmtVisitor { } else { info->loc_sref = scope_sref; - auto block_body = scope_sref->StmtAs()->body; + auto block_body = scope_sref->StmtAs()->body; // Find the SeqStmtNode within (potentially nested) AllocateConstNodes while (true) { if (auto* ptr = block_body.as()) { @@ -692,7 +692,7 @@ class CacheLocDetector : public StmtVisitor { visited_block_ = visited_block_ || previous_visited_block; } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { // Only visit the current scope under buffer writer's parent block if (block == scope_sref_->stmt) { // The block visited is the current parent scope @@ -792,7 +792,7 @@ class CacheInplaceLocDetector : public StmtVisitor { } } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { // Only visit the current scope under buffer writer's parent block if (block == scope_sref_->stmt) { // The block visited is the current parent scope @@ -916,15 +916,15 @@ class CacheReadRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) override { - Block old_stmt = ffi::GetRef(block); + Stmt VisitStmt_(const SBlockNode* block) override { + SBlock old_stmt = ffi::GetRef(block); // Check if this block is one of the specified consumers. // If no consumer blocks are specified, all blocks should be considered consumers. bool is_consumer = info_->consumer_blocks.empty(); // Otherwise check if this is one of the specified blocks. for (StmtSRef consumer_sref : info_->consumer_blocks) { - const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = ffi::GetRef(consumer_node); + const SBlockNode* consumer_node = TVM_SREF_TO_SBLOCK(consumer_sref); + SBlock consumer_block = ffi::GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { is_consumer = true; } @@ -937,22 +937,22 @@ class CacheReadRewriter : public StmtExprMutator { return old_stmt; } // Mutate the body - Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + SBlock stmt = Downcast(StmtMutator::VisitStmt_(block)); // Check the insertion point if (block == info_->loc_sref->stmt) { // Insert cache stage into the block if it is the right place - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - stmt = Block(n); + stmt = SBlock(n); } // Check if it is the block corresponding to the parent scope if (block == scope_sref_->stmt) { // If so, put buffer allocation on the parent scope - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); // In cache_inplace case, alloc_buffer may be already exits. if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); - stmt = Block(n); + stmt = SBlock(n); } } else { // Otherwise, update read regions and match_buffers @@ -962,10 +962,10 @@ class CacheReadRewriter : public StmtExprMutator { ffi::Array reads = update_access_regions(stmt->reads); ffi::Array match_buffers = update_match_buffers(stmt->match_buffers); if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) { - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); - stmt = Block(n); + stmt = SBlock(n); } } } @@ -1173,14 +1173,14 @@ class CacheWriteRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) override { - Block old_stmt = ffi::GetRef(block); + Stmt VisitStmt_(const SBlockNode* block) override { + SBlock old_stmt = ffi::GetRef(block); // Check if this block is one of the specified cache consumers. // update the read buffer to the cache. for (StmtSRef consumer_sref : info_->consumer_blocks) { - const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = ffi::GetRef(consumer_node); + const SBlockNode* consumer_node = TVM_SREF_TO_SBLOCK(consumer_sref); + SBlock consumer_block = ffi::GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { ffi::Array writes = update_access_regions(block->writes); ffi::Array reads = update_access_regions(block->reads); @@ -1192,7 +1192,7 @@ class CacheWriteRewriter : public StmtExprMutator { n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); n->body = VisitStmt(block->body); - Block new_consumer = Block(n); + SBlock new_consumer = SBlock(n); info_->block_reuse.Set(old_stmt, new_consumer); return new_consumer; } @@ -1208,22 +1208,22 @@ class CacheWriteRewriter : public StmtExprMutator { // Mutate the body bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt; std::swap(under_scope, under_writer_block_); - Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + SBlock stmt = Downcast(StmtMutator::VisitStmt_(block)); std::swap(under_scope, under_writer_block_); // Find the insertion point if (block == info_->loc_sref->stmt) { - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - stmt = Block(n); + stmt = SBlock(n); } // Put buffer allocation on the parent scope if (block == scope_sref_->stmt) { - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); // In cache_inplace case, alloc_buffer may be already exits. if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); - stmt = Block(n); + stmt = SBlock(n); } } else { // Since cache_write changes the block, we need to update the buffer it writes @@ -1232,11 +1232,11 @@ class CacheWriteRewriter : public StmtExprMutator { auto match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); - stmt = Block(n); + stmt = SBlock(n); } } info_->block_reuse.Set(old_stmt, stmt); @@ -1420,7 +1420,7 @@ Buffer CreateReindexBuffer(const Buffer& buffer, const ffi::Array& bloc */ class NotLeafBlockError : public ScheduleError { public: - NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} + NotLeafBlockError(IRModule mod, SBlock block) : mod_(std::move(mod)), block_(std::move(block)) {} ffi::String FastErrorString() const final { return "ScheduleError: The target block is not a leaf block."; } @@ -1432,7 +1432,7 @@ class NotLeafBlockError : public ScheduleError { IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; }; /*! \brief The schedule error that the buffer access is invalid for reindex. */ @@ -1444,7 +1444,7 @@ class InvalidBufferAccessError : public ScheduleError { kOpaqueAccess, // opaque access to the buffer }; - InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind) + InvalidBufferAccessError(IRModule mod, Buffer buffer, SBlock block, ErrorKind kind) : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {} ffi::String FastErrorString() const final { return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The " @@ -1471,7 +1471,7 @@ class InvalidBufferAccessError : public ScheduleError { private: IRModule mod_; Buffer buffer_; - Block block_; + SBlock block_; ErrorKind kind_; }; @@ -1479,7 +1479,7 @@ class InvalidBufferAccessError : public ScheduleError { class ReIndexCollector : public StmtExprVisitor { public: static ffi::Array Collect(const IRModule& mod, const Buffer& buffer, - const Block& block) { + const SBlock& block) { ReIndexCollector collector(mod, buffer, block); collector(block->body); if (!collector.buffer_access_indices_.defined()) { @@ -1490,7 +1490,7 @@ class ReIndexCollector : public StmtExprVisitor { } private: - explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const Block& block) + explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const SBlock& block) : mod_(mod), buffer_(buffer), block_(block) {} void VisitExpr_(const BufferLoadNode* load) final { @@ -1500,7 +1500,7 @@ class ReIndexCollector : public StmtExprVisitor { } } - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { // no sub-blocks under this block throw NotLeafBlockError(mod_, block_); } @@ -1535,7 +1535,7 @@ class ReIndexCollector : public StmtExprVisitor { /*! \brief The buffer to rewrite */ Buffer buffer_; /*! \brief The block to visit */ - Block block_; + SBlock block_; /*! \brief The indices of buffer acess to rewrite */ ffi::Optional> buffer_access_indices_; }; @@ -1557,16 +1557,16 @@ class ReIndexRewriter : public StmtExprMutator { old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer : info->read_buffer; } - Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = ffi::GetRef(block); + Stmt VisitStmt_(const SBlockNode* block) final { + SBlock old_stmt = ffi::GetRef(block); if (is_scope_) { is_scope_ = false; - Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + SBlock stmt = Downcast(StmtExprMutator::VisitStmt_(block)); // Insert cache stage into the loop - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); n->alloc_buffers.push_back(info_->alloc.value()); - stmt = Block(n); + stmt = SBlock(n); info_->block_reuse.Set(old_stmt, stmt); return stmt; } @@ -1580,7 +1580,7 @@ class ReIndexRewriter : public StmtExprMutator { region_.push_back(Range::FromMinExtent(iter->var, IntImm(iter->var->dtype, 1))); } } - Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + SBlock stmt = Downcast(StmtExprMutator::VisitStmt_(block)); // Update block reads/writes to use the intermediate reindex buffer auto writes = ReplaceBufferRegion(block->writes, old_buffer_, BufferRegion{new_buffer_, region_}); @@ -1590,11 +1590,11 @@ class ReIndexRewriter : public StmtExprMutator { BufferRegion{new_buffer_, region_}); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - ObjectPtr n = ffi::make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); - stmt = Block(n); + stmt = SBlock(n); } info_->block_reuse.Set(old_stmt, stmt); return stmt; @@ -1643,7 +1643,7 @@ class ReIndexRewriter : public StmtExprMutator { void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) { class NotRegionCoverError : public ScheduleError { public: - explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit NotRegionCoverError(IRModule mod, SBlock block) : mod_(mod), block_(block) {} IRModule mod() const final { return mod_; } ffi::String FastErrorString() const final { return "ScheduleError: The scope root's region cover is not complete."; @@ -1655,16 +1655,16 @@ The region cover property require to hold for every of its child blocks } ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; }; for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) { - const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); + const SBlockNode* child_block = TVM_SREF_TO_SBLOCK(child_block_sref); for (const BufferRegion& region : child_block->reads) { if (region->buffer.same_as(read_buffer)) { if (!self->block_info.at(child_block_sref).region_cover) { - const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); - throw NotRegionCoverError(self->mod, ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(scope_root); + throw NotRegionCoverError(self->mod, ffi::GetRef(block)); } } } @@ -1690,13 +1690,13 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff CheckStorageScope(self, storage_scope); // Step 1. Check index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer read_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + Buffer read_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check required region cover for cache_read CheckRegionCover(self, scope_sref, read_buffer); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); // Step 2. Create CacheStageInfo CacheStageInfo info; @@ -1716,7 +1716,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); - const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); + const SBlockNode* write_block = TVM_SREF_TO_SBLOCK(write_block_sref); // Find the producing region BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); StmtSRef parent_sref = ffi::GetRef(write_block_sref->parent); @@ -1737,7 +1737,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff } // Step 4. Making new cache stage block and rewrite readers. - bool cache_full_region = info.loc_sref->StmtAs() == nullptr || + bool cache_full_region = info.loc_sref->StmtAs() == nullptr || !AllConsumersUnderStmt(self, read_buffer, scope_sref, info.loc_sref); info.cache_region = cache_region; info.write_buffer = WithScope(read_buffer, storage_scope); @@ -1751,7 +1751,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff } info.alloc = info.write_buffer; - Block cache_read_stage = + SBlock cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, /*storage_scope=*/storage_scope, /*cache_full_region=*/cache_full_region); Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info, @@ -1760,7 +1760,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 5. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); - BlockInfo& block_info = self->block_info[result_block_sref]; + SBlockInfo& block_info = self->block_info[result_block_sref]; block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.stage_pipeline = true; @@ -1784,8 +1784,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer write_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), write_buffer_index, + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + Buffer write_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), write_buffer_index, BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); @@ -1813,7 +1813,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu BufferRegion cache_region = RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); - bool cache_full_region = info.loc_sref->StmtAs() == nullptr || + bool cache_full_region = info.loc_sref->StmtAs() == nullptr || !AllConsumersUnderStmt(self, write_buffer, scope_sref, info.loc_sref); info.cache_region = cache_region; info.read_buffer = WithScope(write_buffer, storage_scope); @@ -1828,7 +1828,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu info.alloc = info.read_buffer; // Step 5. Making new cache stage block and rewrite readers. - Block cache_write_stage = + SBlock cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, /*storage_scope=*/storage_scope, /*cache_full_region=*/cache_full_region); Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, @@ -1838,7 +1838,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 6. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); - BlockInfo& block_info = self->block_info[result_block_sref]; + SBlockInfo& block_info = self->block_info[result_block_sref]; block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.stage_pipeline = true; @@ -1861,7 +1861,7 @@ ffi::Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSR */ class ReindexCacheReadWriteNotMatchError : public ScheduleError { public: - ReindexCacheReadWriteNotMatchError(IRModule mod, Block block, Var var, + ReindexCacheReadWriteNotMatchError(IRModule mod, SBlock block, Var var, ffi::Array old_indices, ffi::Array new_indices, bool is_cache_read, bool appears_in_old) @@ -1891,7 +1891,7 @@ class ReindexCacheReadWriteNotMatchError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; ffi::String primitive_name_; - Block block_; + SBlock block_; Var var_; ffi::Array appears_indices_; ffi::Array other_indices_; @@ -1913,8 +1913,8 @@ class ReindexCacheReadWriteNotMatchError : public ScheduleError { template void CollectReindexCacheStageInfoAndCreateBuffer( ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, - const ffi::String& storage_scope, const IndexMap& index_map, const Block& block, - const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { + const ffi::String& storage_scope, const IndexMap& index_map, const SBlock& block, + const SBlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { arith::Analyzer analyzer; ffi::Array block_iter_vars, block_shape; for (const IterVar& iter_var : block->iter_vars) { @@ -1983,7 +1983,7 @@ void CollectReindexCacheStageInfoAndCreateBuffer( /*! \brief Check whether given cache_region is a single point access. */ template -void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion& cache_region) { +void CheckSinglePoint(ScheduleState self, const SBlock& block, const BufferRegion& cache_region) { bool single_point = true; for (const Range& range : cache_region->region) { const auto* ext_int = range->extent.as(); @@ -2013,8 +2013,8 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re CheckStorageScope(self, storage_scope); // Step 1. Check index, getting the target buffer and the parent scope - Block block = ffi::GetRef(TVM_SREF_TO_BLOCK(block_sref)); - BlockRealize realize = GetBlockRealize(self, block_sref); + SBlock block = ffi::GetRef(TVM_SREF_TO_SBLOCK(block_sref)); + SBlockRealize realize = GetSBlockRealize(self, block_sref); Buffer read_buffer = GetNthAccessBuffer(self, block, read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); @@ -2051,7 +2051,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re cache_region); // Step 6. Making new cache stage block and rewrite readers. - Block cache_read_stage = + SBlock cache_read_stage = MakeReindexCacheStage(/*cache_region=*/cache_region, /*info=*/&info, /*storage_scope=*/storage_scope); @@ -2060,7 +2060,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re // Step 7. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); - BlockInfo& block_info = self->block_info[result_block_sref]; + SBlockInfo& block_info = self->block_info[result_block_sref]; block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.stage_pipeline = true; @@ -2084,8 +2084,8 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - Block block = ffi::GetRef(TVM_SREF_TO_BLOCK(block_sref)); - BlockRealize realize = GetBlockRealize(self, block_sref); + SBlock block = ffi::GetRef(TVM_SREF_TO_SBLOCK(block_sref)); + SBlockRealize realize = GetSBlockRealize(self, block_sref); Buffer write_buffer = GetNthAccessBuffer(self, block, write_buffer_index, BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); @@ -2113,7 +2113,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CheckSinglePoint(self, block, cache_region); // Step 6. Making new cache stage block and rewrite readers. - Block cache_write_stage = + SBlock cache_write_stage = MakeReindexCacheStage(/*cache_region=*/cache_region, /*info=*/&info, /*storage_scope=*/storage_scope); @@ -2124,7 +2124,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w // Step 7. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); - BlockInfo& block_info = self->block_info[result_block_sref]; + SBlockInfo& block_info = self->block_info[result_block_sref]; block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.stage_pipeline = true; @@ -2134,7 +2134,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w /*! \brief The schedule error that the target block doesn't both read&write target buffer. */ class NotReadWriteError : public ScheduleError { public: - NotReadWriteError(IRModule mod, Block block, Buffer buffer) + NotReadWriteError(IRModule mod, SBlock block, Buffer buffer) : mod_(std::move(mod)), block_(std::move(block)), buffer_(std::move(buffer)) {} ffi::String FastErrorString() const final { return "ScheduleError: The target block does not both read & write target buffer."; @@ -2147,7 +2147,7 @@ class NotReadWriteError : public ScheduleError { IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_, buffer_}; } IRModule mod_; - Block block_; + SBlock block_; Buffer buffer_; }; @@ -2161,8 +2161,8 @@ ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref CheckStorageScope(self, storage_scope); // Check 1. Check index, get the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); @@ -2170,11 +2170,11 @@ ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref CheckRegionCover(self, scope_sref, buffer); // Check 4. Check if target block both read & write target buffer. - const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* rw_block = TVM_SREF_TO_SBLOCK(block_sref); ffi::Optional read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer); ffi::Optional write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer); if (!read_region.defined() || !write_region.defined()) { - throw NotReadWriteError(self->mod, ffi::GetRef(rw_block), buffer); + throw NotReadWriteError(self->mod, ffi::GetRef(rw_block), buffer); } ffi::Array results_block_sref; @@ -2195,14 +2195,14 @@ ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info); // Cache read step 2. Making new cache stage block and rewrite readers. - Block cache_read_stage = MakeCacheStage(/*cache_region=*/read_region.value(), /*info=*/&info, - /*storage_scope=*/storage_scope); + SBlock cache_read_stage = MakeCacheStage(/*cache_region=*/read_region.value(), /*info=*/&info, + /*storage_scope=*/storage_scope); Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); // Cache read step 3. Replacing and updating flags for cache read. self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); - BlockInfo& block_info_read = self->block_info[result_block_sref]; + SBlockInfo& block_info_read = self->block_info[result_block_sref]; block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info_read.region_cover = true; block_info_read.stage_pipeline = false; @@ -2223,15 +2223,15 @@ ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref info.loc_pos += 1; // Cache write step 2. Making new cache stage block and rewrite readers. - Block cache_write_stage = MakeCacheStage(/*cache_region=*/write_region.value(), /*info=*/&info, - /*storage_scope=*/storage_scope); + SBlock cache_write_stage = MakeCacheStage(/*cache_region=*/write_region.value(), /*info=*/&info, + /*storage_scope=*/storage_scope); new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, /*writer_block_sref=*/block_sref, /*info=*/&info); // Cache write step 4. Replacing and updating flags for cache write. self->Replace(scope_sref, new_scope, info.block_reuse); result_block_sref = self->stmt2ref.at(cache_write_stage.get()); - BlockInfo& block_info_write = self->block_info[result_block_sref]; + SBlockInfo& block_info_write = self->block_info[result_block_sref]; block_info_write.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info_write.region_cover = true; block_info_write.stage_pipeline = false; @@ -2242,8 +2242,8 @@ ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - Block block = ffi::GetRef(block_ptr); + const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); + SBlock block = ffi::GetRef(block_ptr); Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); arith::Analyzer analyzer; @@ -2299,14 +2299,14 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde } // Step 4. Making new reindex stage block and rewrite - Block reindex_stage = + SBlock reindex_stage = MakeReIndexStage(block, &info, covered, original_indices, buffer_index, buffer_index_type); Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info, covered); // Step 5. Replacing and updating flags self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get()); - BlockInfo& block_info = self->block_info[result_block_sref]; + SBlockInfo& block_info = self->block_info[result_block_sref]; block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.stage_pipeline = true; @@ -2324,9 +2324,9 @@ struct CacheReadTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, - ffi::Array consumer_blocks, - Integer read_buffer_index, ffi::String storage_scope) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, + ffi::Array consumer_blocks, + Integer read_buffer_index, ffi::String storage_scope) { return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks); } @@ -2358,9 +2358,9 @@ struct CacheWriteTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, - ffi::Array consumer_blocks, - Integer write_buffer_index, ffi::String storage_scope) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, + ffi::Array consumer_blocks, + Integer write_buffer_index, ffi::String storage_scope) { return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks); } @@ -2392,9 +2392,9 @@ struct CacheInplaceTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Integer read_buffer_index, - ffi::String storage_scope) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block, + Integer read_buffer_index, + ffi::String storage_scope) { return sch->CacheInplace(block, read_buffer_index->value, storage_scope); } @@ -2421,8 +2421,8 @@ struct ReIndexTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index, - Integer buffer_index_type) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, Integer buffer_index, + Integer buffer_index_type) { return sch->ReIndex(block, buffer_index.IntValue(), static_cast(buffer_index_type->value)); } @@ -2452,8 +2452,8 @@ struct ReindexCacheReadTraits : public UnpackedInstTraitsReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); } @@ -2482,8 +2482,8 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraitsReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); } diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index cd56ff8b9ddf..420876637de0 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -37,8 +37,8 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError { : mod_(mod), num_not_visited_(num_not_visited) { required_.reserve(required.size()); for (const StmtSRef& block_sref : required) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - required_.push_back(ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + required_.push_back(ffi::GetRef(block)); } } @@ -68,7 +68,7 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError { private: IRModule mod_; int num_not_visited_; - ffi::Array required_; + ffi::Array required_; }; /*! @@ -110,11 +110,11 @@ class NotInSameScopeError : public ScheduleError { private: explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref) : mod_(mod), - block_(ffi::GetRef(block_sref->StmtAs())), + block_(ffi::GetRef(block_sref->StmtAs())), loop_(ffi::GetRef(loop_sref->StmtAs())) {} IRModule mod_; - Block block_; + SBlock block_; For loop_; }; @@ -138,11 +138,10 @@ class NotInSameScopeError : public ScheduleError { * \throws ScheduleError if there is no such insertion point found */ template -int FindInsertionPoint(const ScheduleState& self, const ffi::Array& subtrees, - const ffi::Array& producer_srefs, - const ffi::Array& consumer_srefs, - std::unordered_map* block2realize, - int index) { +int FindInsertionPoint( + const ScheduleState& self, const ffi::Array& subtrees, + const ffi::Array& producer_srefs, const ffi::Array& consumer_srefs, + std::unordered_map* block2realize, int index) { ProducerConsumerSplit split = ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize); // Step 1. Check if all the producers are visited in the subtrees, if required to @@ -239,7 +238,7 @@ struct BlockVarDomainInfo { */ class ScopeReconstructor : private StmtMutator { public: - explicit ScopeReconstructor(Block scope_root, Block block, For loop) + explicit ScopeReconstructor(SBlock scope_root, SBlock block, For loop) : scope_root_(scope_root), block_(block), loop_(loop) {} using StmtMutator::operator(); @@ -292,7 +291,7 @@ class ScopeReconstructor : private StmtMutator { } } this->new_block_realize_ = - BlockRealize(std::move(iter_values), analyzer->Simplify(predicate), std::move(block_)); + SBlockRealize(std::move(iter_values), analyzer->Simplify(predicate), std::move(block_)); Stmt new_subtree = this->new_block_realize_; for (int i = static_cast(loop_vars.size()) - 1; i >= 0; --i) { const Var& loop_var = loop_vars[i]; @@ -311,12 +310,12 @@ class ScopeReconstructor : private StmtMutator { } private: - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { if (block != scope_root_.get()) { - return ffi::GetRef(block); + return ffi::GetRef(block); } if (block == rm_src_stmt_.get()) { - block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode); + block = TVM_TYPE_AS(rm_tgt_stmt_, SBlockNode); } return StmtMutator::VisitStmt_(block); } @@ -333,15 +332,15 @@ class ScopeReconstructor : private StmtMutator { public: /*! \brief The root block of the block scope */ - Block scope_root_; + SBlock scope_root_; /*! \brief The given block to be moved */ - Block block_; + SBlock block_; /*! \brief The given loop the block and its loop nest to be put under */ For loop_; /*! \brief The new loop to replace the original loop */ For new_loop_{nullptr}; /*! \brief The new block realize to the moved block */ - BlockRealize new_block_realize_{nullptr}; + SBlockRealize new_block_realize_{nullptr}; /*! \brief The plan to remove the given block by replacing this loop/block in the AST */ Stmt rm_src_stmt_{nullptr}; /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */ @@ -659,8 +658,8 @@ std::vector CalculateBlockVarDomain( */ template void CalculateProvidedRequiredRegions( - const BlockNode* block, const StmtSRef& loop_sref, - std::unordered_map block2realize, + const SBlockNode* block, const StmtSRef& loop_sref, + std::unordered_map block2realize, ffi::Array producer_srefs, ffi::Array consumer_srefs, std::unordered_map>* provided_regions, std::unordered_map>* required_regions) { @@ -676,10 +675,10 @@ void CalculateProvidedRequiredRegions( } // Step 2. Calculate the region required by dependent blocks under `loop` for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) { - const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref); + const SBlockNode* required_block = TVM_SREF_TO_SBLOCK(required_block_sref); ICHECK(block2realize.count(required_block)); RelaxBufferRegions( - /*binding=*/GetBindings(ffi::GetRef(block2realize.at(required_block))), + /*binding=*/GetBindings(ffi::GetRef(block2realize.at(required_block))), /*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes, /*relax_path_low_inclusive=*/ffi::GetRef(required_block_sref->parent), /*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions); @@ -693,15 +692,15 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s const StmtSRef& loop_sref, bool preserve_unit_loops, arith::Analyzer* analyzer, bool check_only = false, int index = -1) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Step 1. Bunch of checks // Check condition 1) : scope stage pipeline StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - Block scope_root = ffi::GetRef(scope_root_sref->StmtAs()); + SBlock scope_root = ffi::GetRef(scope_root_sref->StmtAs()); AddShapeVarBounds(self, scope_root_sref.get(), analyzer); - BlockScope scope = self->GetBlockScope(scope_root_sref); + SBlockScope scope = self->GetSBlockScope(scope_root_sref); ffi::Array producer_srefs = GetProducers(block_sref, scope); ffi::Array consumer_srefs = GetConsumers(block_sref, scope); // Check condition 2) : `block` is a complete or reduction block @@ -715,11 +714,11 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s CheckNotOutputBlock(self, block_sref, scope_root_sref); } // Step 2. Plan for the removal of `block` - ScopeReconstructor reconstructor(scope_root, ffi::GetRef(block), ffi::GetRef(loop)); + ScopeReconstructor reconstructor(scope_root, ffi::GetRef(block), ffi::GetRef(loop)); LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_); // Step 3. Find the insertion point under `loop` // Check condition 5): all the required block are under the given loop - std::unordered_map block2realize; + std::unordered_map block2realize; block2realize.reserve(self->block_info.size()); int insert_position = FindInsertionPoint( /*self=*/self, @@ -748,7 +747,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s // Step 6. Create the new scope according to the iteration domain reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms), /*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops); - Block new_scope_root = Downcast(reconstructor(scope_root)); + SBlock new_scope_root = Downcast(reconstructor(scope_root)); // Step 7. Do the actual replacement if (check_only) { @@ -756,7 +755,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s } self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); // Step 8. Update the cached flags - BlockInfo& block_info = self->block_info[block_sref]; + SBlockInfo& block_info = self->block_info[block_sref]; block_info.affine_binding = IsAffineBinding( /*realize=*/reconstructor.new_block_realize_, /*loop_var_ranges=*/LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent)), @@ -812,7 +811,7 @@ struct ComputeAtTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, LoopRV loop_rv, Bool preserve_unit_loops, IntImm index) { return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); } @@ -840,7 +839,7 @@ struct ReverseComputeAtTraits : public UnpackedInstTraitsReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 0ab6d7e2b699..0e3fc5e2a227 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -34,7 +34,7 @@ block should be covered by the producer block.)"; class HasInitBlock : public ScheduleError { public: - explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit HasInitBlock(IRModule mod, SBlock block) : mod_(mod), block_(block) {} ffi::String FastErrorString() const final { return "ScheduleError: The block has init statement"; @@ -47,7 +47,7 @@ class HasInitBlock : public ScheduleError { IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_}; } - static void Check(const IRModule& mod, const Block& block) { + static void Check(const IRModule& mod, const SBlock& block) { if (block->init.defined()) { throw HasInitBlock(mod, block); } @@ -55,12 +55,12 @@ class HasInitBlock : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; }; class NotSingleReadWriteBuffer : public ScheduleError { public: - explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) + explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, SBlock block) : mod_(mod), is_read_(is_read), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -85,9 +85,9 @@ class NotSingleReadWriteBuffer : public ScheduleError { IRModule mod_; bool is_read_; - Block block_; + SBlock block_; - static Buffer GetSingleRead(const ScheduleState& self, const Block& block, + static Buffer GetSingleRead(const ScheduleState& self, const SBlock& block, const StmtSRef& scope_root_sref) { const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers; @@ -110,7 +110,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { return ffi::GetRef(read_buffer); } - static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { + static Buffer GetSingleWrite(const ScheduleState& self, const SBlock& block) { if (block->writes.size() != 1) { throw NotSingleReadWriteBuffer(self->mod, false, block); } @@ -120,7 +120,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { class BodyAnalysisError : public ScheduleError { public: - explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block) + explicit BodyAnalysisError(bool is_reverse, IRModule mod, SBlock block) : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -137,12 +137,12 @@ class BodyAnalysisError : public ScheduleError { bool is_reverse_; IRModule mod_; - Block block_; + SBlock block_; }; class NonSingleProducerError : public ScheduleError { public: - explicit NonSingleProducerError(IRModule mod, Block block) + explicit NonSingleProducerError(IRModule mod, SBlock block) : mod_(mod), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -161,7 +161,7 @@ class NonSingleProducerError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; /*! * \brief Check if the block has a single producer. @@ -173,15 +173,15 @@ class NonSingleProducerError : public ScheduleError { */ static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, const StmtSRef& scope_root_sref) { - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); - const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); + const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_root_sref); + const SBlockNode* consumer_block = TVM_SREF_TO_SBLOCK(consumer_block_sref); Buffer consumer_buffer = NotSingleReadWriteBuffer::GetSingleRead( - self, ffi::GetRef(consumer_block), scope_root_sref); + self, ffi::GetRef(consumer_block), scope_root_sref); class ProducerFinder : public StmtVisitor { public: - static std::vector GetProducer(const ScheduleState& self, - const StmtSRef& scope_root_sref, const Buffer& buffer, - const Block& scope_block) { + static std::vector GetProducer(const ScheduleState& self, + const StmtSRef& scope_root_sref, const Buffer& buffer, + const SBlock& scope_block) { ProducerFinder finder(self, scope_root_sref, buffer); finder(scope_block); return finder.producer_across_scope_.back(); @@ -194,7 +194,7 @@ class NonSingleProducerError : public ScheduleError { producer_across_scope_.push_back({}); } - void VisitStmt_(const BlockNode* node) final { + void VisitStmt_(const SBlockNode* node) final { producer_across_scope_.push_back({}); StmtVisitor::VisitStmt_(node); // not a leaf block @@ -213,9 +213,9 @@ class NonSingleProducerError : public ScheduleError { // Check if the producer block is a complete block StmtSRef producer_block_sref = self_->stmt2ref.at(node); if (!IsCompleteBlock(self_, producer_block_sref, scope_root_sref_)) { - throw NonSingleProducerError(self_->mod, ffi::GetRef(node)); + throw NonSingleProducerError(self_->mod, ffi::GetRef(node)); } - producer_across_scope_.back().push_back(ffi::GetRef(node)); + producer_across_scope_.back().push_back(ffi::GetRef(node)); break; } } @@ -223,12 +223,12 @@ class NonSingleProducerError : public ScheduleError { ScheduleState self_; StmtSRef scope_root_sref_; Buffer buffer_; - std::vector> producer_across_scope_; + std::vector> producer_across_scope_; }; - std::vector producer_across_scope = ProducerFinder::GetProducer( - self, scope_root_sref, consumer_buffer, ffi::GetRef(scope_block)); + std::vector producer_across_scope = ProducerFinder::GetProducer( + self, scope_root_sref, consumer_buffer, ffi::GetRef(scope_block)); if (producer_across_scope.size() != 1) { - throw NonSingleProducerError(self->mod, ffi::GetRef(consumer_block)); + throw NonSingleProducerError(self->mod, ffi::GetRef(consumer_block)); } return self->stmt2ref.at(producer_across_scope[0].get()); } @@ -238,8 +238,8 @@ class OpaqueAccessError : public ScheduleError { public: explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) : mod_(mod), scope_root_(nullptr) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); - this->scope_root_ = ffi::GetRef(scope_root); + const SBlockNode* scope_root = TVM_SREF_TO_SBLOCK(scope_root_sref); + this->scope_root_ = ffi::GetRef(scope_root); } ffi::String FastErrorString() const final { @@ -256,12 +256,12 @@ class OpaqueAccessError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {scope_root_}; } IRModule mod_; - Block scope_root_; + SBlock scope_root_; }; class ProducerHasNonTrivialPredicateError : public ScheduleError { public: - explicit ProducerHasNonTrivialPredicateError(IRModule mod, BlockRealize producer, + explicit ProducerHasNonTrivialPredicateError(IRModule mod, SBlockRealize producer, PrimExpr new_predicate) : mod_(mod), producer_(producer), new_predicate_(new_predicate) {} @@ -281,7 +281,7 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {producer_}; } IRModule mod_; - BlockRealize producer_; + SBlockRealize producer_; PrimExpr new_predicate_; }; @@ -293,7 +293,7 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { */ class BaseInliner : public StmtExprMutator { protected: - explicit BaseInliner(const Buffer& inlined_buffer, const Block& inlined_block, + explicit BaseInliner(const Buffer& inlined_buffer, const SBlock& inlined_block, const StmtSRef& scope_root_sref) : inlined_buffer_(inlined_buffer), inlined_store_(inlined_block->body.as()), @@ -314,15 +314,15 @@ class BaseInliner : public StmtExprMutator { return StmtExprMutator::VisitStmt_(loop); } - Stmt VisitStmt_(const BlockNode* block) { + Stmt VisitStmt_(const SBlockNode* block) { CheckMatchBufferRegion(block); AddBuffersInBlockSignature(block); - Block src_block = ffi::GetRef(block); + SBlock src_block = ffi::GetRef(block); if (src_block.same_as(src_stmt)) { - block = tgt_stmt.as(); + block = tgt_stmt.as(); ICHECK(block != nullptr); } - Block tgt_block = Downcast(StmtExprMutator::VisitStmt_(block)); + SBlock tgt_block = Downcast(StmtExprMutator::VisitStmt_(block)); bool is_scope_root = src_block.get() == scope_root_sref_->stmt; tgt_block = UpdateBuffersInBlockSignature(std::move(tgt_block), is_scope_root); block_reuse.Set(src_block, tgt_block); @@ -335,7 +335,7 @@ class BaseInliner : public StmtExprMutator { * which is used for auto-completion of a block's read/write region * \param block The block whose signature to be added */ - void AddBuffersInBlockSignature(const BlockNode* block) { + void AddBuffersInBlockSignature(const SBlockNode* block) { for (const BufferRegion& buffer_region : block->reads) { const Buffer& buffer = buffer_region->buffer; buffer_var_map_.Set(buffer->data, buffer); @@ -358,7 +358,7 @@ class BaseInliner : public StmtExprMutator { * \param is_scope_root A flag indicating if a block is the scope root of the block to be inlined * \return The updated block */ - Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) { + SBlock UpdateBuffersInBlockSignature(SBlock block, bool is_scope_root) { // Step 1. Update `BlockNode::alloc_buffers` ffi::Array alloc_buffers; if (is_scope_root) { @@ -380,12 +380,12 @@ class BaseInliner : public StmtExprMutator { if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) || std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) { ffi::Array> inspected = - GetBlockReadWriteRegion(block, buffer_var_map_); + GetSBlockReadWriteRegion(block, buffer_var_map_); reads = inspected[0]; writes = inspected[1]; } // Step 3. Assemble the result - BlockNode* n = block.CopyOnWrite(); + SBlockNode* n = block.CopyOnWrite(); n->reads = std::move(reads); n->writes = std::move(writes); n->alloc_buffers = std::move(alloc_buffers); @@ -408,7 +408,7 @@ class BaseInliner : public StmtExprMutator { * This method checks if a block has the disallowed behavior of buffer region match. * \param block The block to be checked */ - void CheckMatchBufferRegion(const BlockNode* block) { + void CheckMatchBufferRegion(const SBlockNode* block) { for (const MatchBufferRegion& match_buffer_region : block->match_buffers) { const Buffer& matched = match_buffer_region->source->buffer; if (matched.same_as(inlined_buffer_)) { @@ -441,7 +441,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief The Stmt to be replaced to when removing the leaf block */ Stmt tgt_stmt{nullptr}; /*! \brief The reuse mapping of block srefs */ - ffi::Map block_reuse; + ffi::Map block_reuse; /*! \brief Indicates if there is any opaque access of the inlined buffer */ bool has_opaque_access{false}; }; @@ -455,11 +455,11 @@ class BaseInliner : public StmtExprMutator { */ class ComputeInliner : public BaseInliner { public: - explicit ComputeInliner(const Buffer& inlined_buffer, const Block& producer_block, + explicit ComputeInliner(const Buffer& inlined_buffer, const SBlock& producer_block, const StmtSRef& scope_root_sref) : BaseInliner(inlined_buffer, producer_block, scope_root_sref) {} - bool BodyPatternAllowInline(const Block& producer_block) { + bool BodyPatternAllowInline(const SBlock& producer_block) { if (inlined_store_ == nullptr) { return false; } @@ -614,8 +614,8 @@ class ReverseComputeInliner : public BaseInliner { }; public: - explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block, - const BlockRealize& consumer_block_realize, + explicit ReverseComputeInliner(const Buffer& inlined_buffer, const SBlockNode* producer_block, + const SBlockRealize& consumer_block_realize, const StmtSRef& scope_root_sref, const IRModule& mod) : BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref), producer_block_(producer_block), @@ -629,8 +629,8 @@ class ReverseComputeInliner : public BaseInliner { } } - bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) { - const Block& consumer_block = consumer_block_realize->block; + bool BodyPatternAllowInline(const SBlockRealize& consumer_block_realize) { + const SBlock& consumer_block = consumer_block_realize->block; if (!is_one(consumer_block_realize->predicate)) { // Failure: Predicate is the consumer block is not supported @@ -709,10 +709,10 @@ class ReverseComputeInliner : public BaseInliner { using BaseInliner::VisitStmt_; /*! \brief Generate the predicate after inlining based on the consumer predicate */ - BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) { + SBlockRealize BuildInlinedConsumerPredicate(SBlockRealize producer_block_realize) { // Bind the producer block iter domains for simplification ffi::Map subst_map; - Block producer_block = producer_block_realize->block; + SBlock producer_block = producer_block_realize->block; for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) { const IterVar& iter = producer_block->iter_vars[i]; const PrimExpr& binding = producer_block_realize->iter_values[i]; @@ -751,12 +751,12 @@ class ReverseComputeInliner : public BaseInliner { auto n = producer_block_realize.CopyOnWrite(); n->block = producer_block; n->predicate = analyzer_.Simplify(outer_predicate); - return ffi::GetRef(n); + return ffi::GetRef(n); } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - Block src_block = op->block; - BlockRealize tgt_block_realize = Downcast(StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockRealizeNode* op) final { + SBlock src_block = op->block; + SBlockRealize tgt_block_realize = Downcast(StmtMutator::VisitStmt_(op)); if (src_block.get() == producer_block_) { tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize); block_reuse.Set(src_block, tgt_block_realize->block); @@ -868,9 +868,9 @@ class ReverseComputeInliner : public BaseInliner { /*! \brief The IterMap representing the indices of the consumer's BufferLoad */ ffi::Array buffer_load_iter_map_{nullptr}; /*! \brief The producer block */ - const BlockNode* producer_block_{nullptr}; + const SBlockNode* producer_block_{nullptr}; /* \brief The consumer block */ - const BlockNode* consumer_block_{nullptr}; + const SBlockNode* consumer_block_{nullptr}; /*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted * as the predicate of the producer block after inlining. */ @@ -881,8 +881,8 @@ class ReverseComputeInliner : public BaseInliner { void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, bool check_only = false) { - const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref); - Block producer_block = ffi::GetRef(_producer_block); + const SBlockNode* _producer_block = TVM_SREF_TO_SBLOCK(producer_block_sref); + SBlock producer_block = ffi::GetRef(_producer_block); HasInitBlock::Check(self->mod, producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block @@ -926,9 +926,9 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, bool check_only = false) { - const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); - Block consumer_block = ffi::GetRef(_consumer_block); - BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref); + const SBlockNode* _consumer_block = TVM_SREF_TO_SBLOCK(consumer_block_sref); + SBlock consumer_block = ffi::GetRef(_consumer_block); + SBlockRealize consumer_block_realize = GetSBlockRealize(self, consumer_block_sref); HasInitBlock::Check(self->mod, consumer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // @@ -943,7 +943,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); CheckNotOutputBlock(self, producer_block_sref, scope_root_sref); // Step 4. Analyze the block body - ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs(), + ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs(), consumer_block_realize, scope_root_sref, self->mod); if (!inliner.BodyPatternAllowInline(consumer_block_realize)) { throw BodyAnalysisError(true, self->mod, consumer_block); @@ -963,9 +963,9 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); // Step 8. Update the cached flags arith::Analyzer analyzer; - BlockInfo& block_info = self->block_info[producer_block_sref]; + SBlockInfo& block_info = self->block_info[producer_block_sref]; block_info.affine_binding = IsAffineBinding( - /*realize=*/GetBlockRealize(self, producer_block_sref), + /*realize=*/GetSBlockRealize(self, producer_block_sref), /*loop_var_ranges=*/ LoopDomainOfSRefTreePath(ffi::GetRef(producer_block_sref->parent)), /*analyzer=*/&analyzer); @@ -997,8 +997,8 @@ enum class EpilogueType { class ReductionEpilogueFuser : public BaseInliner { public: - explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, - const BlockRealize& epilogue_block_realize, + explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const SBlockNode* reduction_block, + const SBlockRealize& epilogue_block_realize, const StmtSRef& scope_root_sref) : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), reduction_block_(reduction_block), @@ -1016,15 +1016,15 @@ class ReductionEpilogueFuser : public BaseInliner { // BaseInliner::CheckOpaqueAccess(buffer_var); // Don't call base class } - bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize); + bool BodyPatternAllowFusion(const SBlockRealize& epilogue_block_realize); // Step 2: Create single fused reduction block - Block CreateFusedReductionBlock(const BlockNode* reduction_block, - const BlockRealizeNode* reduction_realize); + SBlock CreateFusedReductionBlock(const SBlockNode* reduction_block, + const SBlockRealizeNode* reduction_realize); private: bool AnalyzeEpiloguePattern(const PrimExpr& value); - bool IsReductionBlock(const BlockNode* block); + bool IsReductionBlock(const SBlockNode* block); void ExtractEpilogueInfo(); // Helper function to extract BufferLoad nodes from BufferStore static std::vector ExtractBufferLoad(const Buffer& buffer, @@ -1050,8 +1050,8 @@ class ReductionEpilogueFuser : public BaseInliner { return std::move(extractor.result); } - const BlockNode* reduction_block_; - const BlockNode* epilogue_block_; + const SBlockNode* reduction_block_; + const SBlockNode* epilogue_block_; PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C Buffer epilogue_output_buffer_{nullptr}; // Output buffer D ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] @@ -1063,7 +1063,7 @@ class ReductionEpilogueFuser : public BaseInliner { PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping }; -bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { +bool ReductionEpilogueFuser::BodyPatternAllowFusion(const SBlockRealize& epilogue_block_realize) { // 1. Validate predicate if (!is_one(epilogue_block_realize->predicate)) { // Failure: Predicate in epilogue block is not supported @@ -1245,7 +1245,7 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { return false; } -bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) { +bool ReductionEpilogueFuser::IsReductionBlock(const SBlockNode* block) { // Check if block has reduction iter vars for (const IterVar& iter : block->iter_vars) { if (iter->iter_type == kCommReduce) { @@ -1281,9 +1281,9 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() { } } -Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block, - const BlockRealizeNode* reduction_realize) { - ObjectPtr new_block = ffi::make_object(*reduction_block); +SBlock ReductionEpilogueFuser::CreateFusedReductionBlock( + const SBlockNode* reduction_block, const SBlockRealizeNode* reduction_realize) { + ObjectPtr new_block = ffi::make_object(*reduction_block); // 1. Map epilogue block vars to reduction block vars std::vector reduction_data_vars; @@ -1422,13 +1422,13 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti new_block->reads = new_reads; - return Block(new_block); + return SBlock(new_block); } /*! * \brief Check if a buffer is still referenced by other blocks in the scope */ -static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) { +static bool CheckBufferStillUsed(const SBlock& scope_root, const Buffer& buffer) { class BufferUsageChecker : public StmtVisitor { public: explicit BufferUsageChecker(const Buffer& buffer) : buffer_(buffer) {} @@ -1440,7 +1440,7 @@ static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) } private: - void VisitStmt_(const BlockRealizeNode* op) final { + void VisitStmt_(const SBlockRealizeNode* op) final { if (found_usage_) return; if (!op || !op->block.defined()) { @@ -1448,7 +1448,7 @@ static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) return; } - const BlockNode* block = op->block.get(); + const SBlockNode* block = op->block.get(); if (!block) { StmtVisitor::VisitStmt_(op); return; @@ -1474,7 +1474,7 @@ static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) StmtVisitor::VisitStmt_(op); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (found_usage_) return; if (!op) return; @@ -1506,18 +1506,18 @@ static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) */ class SingleBlockFusionReplacer : public StmtMutator { public: - static Block Replace(Block old_scope_root, Block new_fused_block, Block old_reduction_block, - Block old_epilogue_block, Buffer reduction_buffer) { + static SBlock Replace(SBlock old_scope_root, SBlock new_fused_block, SBlock old_reduction_block, + SBlock old_epilogue_block, Buffer reduction_buffer) { SingleBlockFusionReplacer replacer(std::move(new_fused_block), std::move(old_reduction_block), std::move(old_epilogue_block), std::move(reduction_buffer)); - Block result = Downcast(replacer(std::move(old_scope_root))); + SBlock result = Downcast(replacer(std::move(old_scope_root))); // Check if reduction_buffer is still referenced by other blocks bool buffer_still_used = CheckBufferStillUsed(result, reduction_buffer); // Remove intermediate temp buffer only if it's not used by other blocks if (!buffer_still_used) { - BlockNode* p = result.CopyOnWrite(); + SBlockNode* p = result.CopyOnWrite(); ffi::Array new_alloc_buffers; for (const Buffer& buf : p->alloc_buffers) { if (!buf.same_as(reduction_buffer)) { @@ -1531,8 +1531,8 @@ class SingleBlockFusionReplacer : public StmtMutator { } private: - explicit SingleBlockFusionReplacer(Block new_fused_block, Block old_reduction_block, - Block old_epilogue_block, Buffer reduction_buffer) + explicit SingleBlockFusionReplacer(SBlock new_fused_block, SBlock old_reduction_block, + SBlock old_epilogue_block, Buffer reduction_buffer) : new_fused_block_(std::move(new_fused_block)), old_reduction_block_(std::move(old_reduction_block)), old_epilogue_block_(std::move(old_epilogue_block)), @@ -1549,12 +1549,12 @@ class SingleBlockFusionReplacer : public StmtMutator { loop->thread_binding, loop->annotations); } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { if (realize->block.same_as(old_reduction_block_)) { // Replace reduction block with new fused block - ObjectPtr new_realize = ffi::make_object(*realize); + ObjectPtr new_realize = ffi::make_object(*realize); new_realize->block = new_fused_block_; - return BlockRealize(new_realize); + return SBlockRealize(new_realize); } else if (realize->block.same_as(old_epilogue_block_)) { // Remove epilogue block completely return Evaluate(0); @@ -1575,20 +1575,20 @@ class SingleBlockFusionReplacer : public StmtMutator { } private: - Block new_fused_block_; - Block old_reduction_block_; - Block old_epilogue_block_; + SBlock new_fused_block_; + SBlock old_reduction_block_; + SBlock old_epilogue_block_; Buffer reduction_buffer_; }; void FuseReductionEpilogueImpl(ScheduleState self, const StmtSRef& reduction_block_sref, const StmtSRef& epilogue_block_sref, bool check_only = false) { - const BlockNode* _reduction_block = TVM_SREF_TO_BLOCK(reduction_block_sref); - const BlockNode* _epilogue_block = TVM_SREF_TO_BLOCK(epilogue_block_sref); + const SBlockNode* _reduction_block = TVM_SREF_TO_SBLOCK(reduction_block_sref); + const SBlockNode* _epilogue_block = TVM_SREF_TO_SBLOCK(epilogue_block_sref); - Block reduction_block = ffi::GetRef(_reduction_block); - Block epilogue_block = ffi::GetRef(_epilogue_block); - BlockRealize epilogue_block_realize = GetBlockRealize(self, epilogue_block_sref); + SBlock reduction_block = ffi::GetRef(_reduction_block); + SBlock epilogue_block = ffi::GetRef(_epilogue_block); + SBlockRealize epilogue_block_realize = GetSBlockRealize(self, epilogue_block_sref); // Step 1. Get the scope block StmtSRef scope_root_sref = @@ -1614,24 +1614,24 @@ void FuseReductionEpilogueImpl(ScheduleState self, const StmtSRef& reduction_blo } // Step 5. Create single fused reduction block - BlockRealize reduction_realize = GetBlockRealize(self, reduction_block_sref); - Block fused_block = fuser.CreateFusedReductionBlock(_reduction_block, reduction_realize.get()); + SBlockRealize reduction_realize = GetSBlockRealize(self, reduction_block_sref); + SBlock fused_block = fuser.CreateFusedReductionBlock(_reduction_block, reduction_realize.get()); // Step 6. Transform and replace IR - const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); + const SBlockNode* old_scope_root = TVM_SREF_TO_SBLOCK(scope_root_sref); - Block new_scope_root = - SingleBlockFusionReplacer::Replace(ffi::GetRef(old_scope_root), fused_block, + SBlock new_scope_root = + SingleBlockFusionReplacer::Replace(ffi::GetRef(old_scope_root), fused_block, reduction_block, epilogue_block, reduction_buffer); // Step 7. Update schedule state - ffi::Map block_reuse; - block_reuse.Set(ffi::GetRef(old_scope_root), new_scope_root); + ffi::Map block_reuse; + block_reuse.Set(ffi::GetRef(old_scope_root), new_scope_root); block_reuse.Set(reduction_block, fused_block); self->Replace(scope_root_sref, new_scope_root, block_reuse); - // Step 8. Update BlockInfo - self->UpdateScopeBlockInfo(GetBlockRealize(self, scope_root_sref)); + // Step 8. Update SBlockInfo + self->UpdateScopeSBlockInfo(GetSBlockRealize(self, scope_root_sref)); } void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref, @@ -1650,7 +1650,7 @@ struct ComputeInlineTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv) { return sch->ComputeInline(block_rv); } @@ -1673,7 +1673,7 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraitsReverseComputeInline(block_rv); } @@ -1699,8 +1699,8 @@ struct FuseReductionEpilogueTraits : public UnpackedInstTraitsFuseReductionEpilogue(reduction_block_rv, epilogue_block_rv); } diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 7e61fd4eb20a..7cf4466939eb 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -25,7 +25,7 @@ namespace tvm { namespace tir { /*! \brief Information used to create new padding block */ -struct PaddingBlockInfo { +struct PaddingSBlockInfo { /*! \brief In-bound block iter regions, wrt loop vars. */ ffi::Array in_bound_region; /*! \brief In-bound value, wrt block iter vars. */ @@ -38,7 +38,7 @@ struct PaddingBlockInfo { class PaddingPatternMatchError : public ScheduleError { public: - PaddingPatternMatchError(IRModule mod, Block block, const std::string& error_msg) + PaddingPatternMatchError(IRModule mod, SBlock block, const std::string& error_msg) : mod_(std::move(mod)), block_(std::move(block)), error_msg_(error_msg) {} ffi::String FastErrorString() const final { @@ -57,7 +57,7 @@ class PaddingPatternMatchError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; std::string error_msg_; }; @@ -67,9 +67,9 @@ class PaddingPatternMatchError : public ScheduleError { */ class PaddingInfoAnalyzer { public: - static PaddingBlockInfo CheckAndGetPaddingInfo(IRModule mod, const BlockRealizeNode* realize, - const ffi::Map& dom_map, - arith::Analyzer* analyzer) { + static PaddingSBlockInfo CheckAndGetPaddingInfo(IRModule mod, const SBlockRealizeNode* realize, + const ffi::Map& dom_map, + arith::Analyzer* analyzer) { PaddingInfoAnalyzer padding_analyzer(analyzer); if (!padding_analyzer.MatchPadding(realize, dom_map)) { throw PaddingPatternMatchError(mod, realize->block, padding_analyzer.error_msg_); @@ -81,10 +81,10 @@ class PaddingInfoAnalyzer { explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {} /*! \brief Detect padding pattern and update result. */ - bool MatchPadding(const BlockRealizeNode* realize, const ffi::Map& dom_map) { + bool MatchPadding(const SBlockRealizeNode* realize, const ffi::Map& dom_map) { // Step 1. Check match padding computation pattern. // A[...] = T.if_then_else(predicate, B[...], imm) - Block block = realize->block; + SBlock block = realize->block; std::unordered_map iter_values; for (size_t i = 0; i < realize->iter_values.size(); ++i) { Var block_var = block->iter_vars[i]->var; @@ -186,7 +186,7 @@ class PaddingInfoAnalyzer { void SetError(const std::string& msg) { error_msg_ = msg; } /*! \brief padding info analyse result. */ - PaddingBlockInfo info_; + PaddingSBlockInfo info_; /*! \brief current error message. */ std::string error_msg_; /*! \brief arithmetic analyzer. */ @@ -194,12 +194,12 @@ class PaddingInfoAnalyzer { }; /*! \brief Create block to fill constant pad values into full region */ -static std::pair CreateConstBlock(const BlockRealizeNode* realize, - const PaddingBlockInfo& info, - const ffi::Array& loops, - const Stmt& highest_pos_inclusive, - arith::Analyzer* analyzer) { - const Block& block = realize->block; +static std::pair CreateConstBlock(const SBlockRealizeNode* realize, + const PaddingSBlockInfo& info, + const ffi::Array& loops, + const Stmt& highest_pos_inclusive, + arith::Analyzer* analyzer) { + const SBlock& block = realize->block; ffi::Array new_iter_vars; ffi::Map repl_dict; @@ -227,8 +227,8 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re BufferStore store = Downcast(block->body); store.CopyOnWrite()->value = info.pad_value; store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr); - Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/{}, /*writes=*/{write_region}, - /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store)); + SBlock new_block(/*iter_vars=*/new_iter_vars, /*reads=*/{}, /*writes=*/{write_region}, + /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store)); // create new loop vars ffi::Array new_loop_vars; @@ -246,9 +246,9 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re for (size_t i = 0; i < realize->iter_values.size(); ++i) { new_iter_values.push_back(rewrite_expr(realize->iter_values[i])); } - BlockRealize new_realize(/*iter_values=*/new_iter_values, - /*predicate=*/rewrite_expr(realize->predicate), - /*block=*/new_block); + SBlockRealize new_realize(/*iter_values=*/new_iter_values, + /*predicate=*/rewrite_expr(realize->predicate), + /*block=*/new_block); // create new loops Stmt nest_stmt_root = new_realize; @@ -262,13 +262,13 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re } /*! \brief Create block to fill in-bound region values. */ -static std::pair CreateInBoundBlock(const BlockRealizeNode* realize, - const PaddingBlockInfo& info, +static std::pair CreateInBoundBlock(const SBlockRealizeNode* realize, + const PaddingSBlockInfo& info, - const ffi::Array& loops, - const Stmt& highest_pos_inclusive, - arith::Analyzer* analyzer) { - const Block& block = realize->block; + const ffi::Array& loops, + const Stmt& highest_pos_inclusive, + arith::Analyzer* analyzer) { + const SBlock& block = realize->block; ffi::Array new_iter_vars; ffi::Map repl_dict; @@ -330,11 +330,11 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* BufferStore store = Downcast(block->body); store.CopyOnWrite()->value = rewrite_expr(info.in_bound_value); store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr); - Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/reads, /*writes=*/writes, - /*name_hint=*/block->name_hint, /*body=*/std::move(store)); + SBlock new_block(/*iter_vars=*/new_iter_vars, /*reads=*/reads, /*writes=*/writes, + /*name_hint=*/block->name_hint, /*body=*/std::move(store)); PrimExpr new_predicate = rewrite_expr(info.in_bound_predicate); - BlockRealize new_realize(/*iter_values=*/new_iter_binding, /*predicate=*/new_predicate, - /*block=*/new_block); + SBlockRealize new_realize(/*iter_values=*/new_iter_binding, /*predicate=*/new_predicate, + /*block=*/new_block); // create new loops Stmt nest_stmt_root = new_realize; @@ -367,14 +367,14 @@ class DecomposePaddingBlockReplacer : public StmtMutator { /*! \brief highest in bound value filling loop with single child. */ Stmt in_bound_filling_loop; /*! \brief const pad value filling block. */ - BlockRealize const_filling_block; + SBlockRealize const_filling_block; /*! \brief in bound value filling block. */ - BlockRealize in_bound_filling_block; + SBlockRealize in_bound_filling_block; }; - static Block Replace(Block scope_root, const ReplaceDesc& desc) { + static SBlock Replace(SBlock scope_root, const ReplaceDesc& desc) { DecomposePaddingBlockReplacer replacer(desc); - return Downcast(replacer(std::move(scope_root))); + return Downcast(replacer(std::move(scope_root))); } private: @@ -411,8 +411,8 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, * - trim original block to write non-padding part only */ // Condition Checks and Information Collection - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + const SBlockRealizeNode* realize = GetSBlockRealize(self, block_sref).get(); ffi::Map dom_map; arith::Analyzer analyzer; @@ -445,7 +445,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, } } else if (!found_in_bound_filling_pos) { if (!cur_loop->body->IsInstance() && - !cur_loop->body->IsInstance()) { + !cur_loop->body->IsInstance()) { found_in_bound_filling_pos = true; } else { in_bound_filling_pos = cur_loop; @@ -454,12 +454,12 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, } ICHECK(in_bound_filling_pos.defined()); if (!found_const_filling_pos) { - throw LoopPositionError(self->mod, const_filling_pos, ffi::GetRef(block), + throw LoopPositionError(self->mod, const_filling_pos, ffi::GetRef(block), "decompose_padding"); } // Check 3. match padding pattern and return padding operation info. - PaddingBlockInfo info = + PaddingSBlockInfo info = PaddingInfoAnalyzer::CheckAndGetPaddingInfo(self->mod, realize, dom_map, &analyzer); // IR Manipulation @@ -473,8 +473,9 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer); // Step 2. Execute IR replacement. - Block old_scope_root_block = ffi::GetRef(scope_root_sref->StmtAs()); - Block new_scope_root = DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc); + SBlock old_scope_root_block = ffi::GetRef(scope_root_sref->StmtAs()); + SBlock new_scope_root = + DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc); if (check_only) { return block_sref; } @@ -482,11 +483,11 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Step 3. Update schedule states. self->Replace(scope_root_sref, new_scope_root, {{old_scope_root_block, new_scope_root}, - {ffi::GetRef(block), replace_desc.in_bound_filling_block->block}}); + {ffi::GetRef(block), replace_desc.in_bound_filling_block->block}}); auto new_block_sref = self->stmt2ref.at(replace_desc.const_filling_block->block.get()); // Set block info of created const pad value filling block - BlockInfo& block_info = self->block_info[new_block_sref]; + SBlockInfo& block_info = self->block_info[new_block_sref]; block_info.affine_binding = true; block_info.region_cover = true; block_info.stage_pipeline = true; @@ -536,7 +537,7 @@ bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.schedule.CanDecomposePadding", [](Schedule self, BlockRV block_rv, LoopRV loop_rv) { + "tir.schedule.CanDecomposePadding", [](Schedule self, SBlockRV block_rv, LoopRV loop_rv) { return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); }); } @@ -552,7 +553,7 @@ struct DecomposPaddingTraits : public UnpackedInstTraits static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, LoopRV loop_rv) { return sch->DecomposePadding(block_rv, loop_rv); } diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index de550979c18f..01cdb084950f 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -23,7 +23,7 @@ namespace tir { class WrongBlockIterTypeError : public ScheduleError { public: - explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, Block block) + explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, SBlock block) : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) { op_str_ = for_kind == ForKind::kParallel ? "parallel" @@ -56,7 +56,7 @@ class WrongBlockIterTypeError : public ScheduleError { IRModule mod_; std::string op_str_; Var loop_var_; - Block block_; + SBlock block_; }; /*! @@ -78,9 +78,9 @@ class WrongBlockIterTypeError : public ScheduleError { * the input block */ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, - const Var& loop_var, const BlockRealize& block_realize, + const Var& loop_var, const SBlockRealize& block_realize, runtime::ThreadScope thread_scope) { - const Block& block = block_realize->block; + const SBlock& block = block_realize->block; // Cond 1. The block is required to have affine bindings. // TODO(@automation): fix the check @@ -121,14 +121,14 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind for_kind, runtime::ThreadScope thread_scope) { PreOrderVisit(loop, [&](const ObjectRef& node) { - if (const auto* realize = node.as()) { + if (const auto* realize = node.as()) { // If this block doesn't have corresponding StmtSRef in the schedule state, it must be a block // inside `tir.init()`. We don't check the condition for such blocks. if (!self->stmt2ref.count(realize->block.get())) { return false; } CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, - ffi::GetRef(realize), thread_scope); + ffi::GetRef(realize), thread_scope); } return true; }); diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 0ad1d82ee0df..28293624b1d8 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -22,13 +22,13 @@ namespace tvm { namespace tir { -ffi::Array GetBlocks(const ScheduleState& self, const ffi::String& name, - const GlobalVar& gv) { +ffi::Array GetSBlocks(const ScheduleState& self, const ffi::String& name, + const GlobalVar& gv) { struct Finder : public StmtVisitor { explicit Finder(const ScheduleState& self, const ffi::String& name) : self_(self), name_(name) {} - void VisitStmt_(const BlockNode* block) override { + void VisitStmt_(const SBlockNode* block) override { if (block->name_hint == name_) { auto it = self_->stmt2ref.find(block); ICHECK(it != self_->stmt2ref.end()); @@ -61,7 +61,7 @@ ffi::Array GetLoops(const StmtSRef& block_sref) { ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { private: - void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } + void VisitStmt_(const SBlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } public: explicit Collector(const ScheduleState& self) : self(self) {} @@ -73,8 +73,8 @@ ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& p if (parent_sref->stmt->IsInstance()) { const auto* loop = static_cast(parent_sref->stmt); collector(loop->body); - } else if (parent_sref->stmt->IsInstance()) { - const auto* block = static_cast(parent_sref->stmt); + } else if (parent_sref->stmt->IsInstance()) { + const auto* block = static_cast(parent_sref->stmt); collector(block->body); } return std::move(collector.result); @@ -82,23 +82,23 @@ ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& p ffi::Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - return tir::GetProducers(block_sref, self->GetBlockScope(scope_root)); + return tir::GetProducers(block_sref, self->GetSBlockScope(scope_root)); } ffi::Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); + return tir::GetConsumers(block_sref, self->GetSBlockScope(scope_root)); } ffi::Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { - const auto* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + const auto* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); return tir::GetOutputBlocks(self, scope_block); } /******** InstructionKind Registration ********/ -struct GetBlockTraits : public UnpackedInstTraits { - static constexpr const char* kName = "GetBlock"; +struct GetSBlockTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetSBlock"; static constexpr bool kIsPure = true; private: @@ -106,13 +106,13 @@ struct GetBlockTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, ffi::String name, ffi::String func_name) { - return sch->GetBlock(name, func_name); + static SBlockRV UnpackedApplyToSchedule(Schedule sch, ffi::String name, ffi::String func_name) { + return sch->GetSBlock(name, func_name); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String name, ffi::String func_name) { - PythonAPICall py("get_block"); + PythonAPICall py("get_sblock"); py.Input("name", name); py.Input("func_name", func_name); py.SingleOutput(outputs); @@ -132,7 +132,7 @@ struct GetLoopsTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv) { return sch->GetLoops(block_rv); } @@ -156,14 +156,15 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { - if (auto block = block_or_loop_rv.as()) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { + if (auto block = block_or_loop_rv.as()) { return sch->GetChildBlocks(block.value()); } if (auto loop = block_or_loop_rv.as()) { return sch->GetChildBlocks(loop.value()); } - LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); throw; } @@ -188,7 +189,7 @@ struct GetProducersTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv) { return sch->GetProducers(block_rv); } @@ -212,7 +213,7 @@ struct GetConsumersTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv) { return sch->GetConsumers(block_rv); } @@ -236,7 +237,7 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv) { return sch->GetOutputBlocks(block_rv); } @@ -251,7 +252,7 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits friend struct ::tvm::tir::UnpackedInstTraits; }; -TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetSBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits); diff --git a/src/tir/schedule/primitive/hide_buffer_access.cc b/src/tir/schedule/primitive/hide_buffer_access.cc index f5e92b8ba50b..98805845b6ea 100644 --- a/src/tir/schedule/primitive/hide_buffer_access.cc +++ b/src/tir/schedule/primitive/hide_buffer_access.cc @@ -86,7 +86,7 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, * - validity of buf_index_array * - validity of buf_type */ - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); int num_access_regions = 0; if (buf_type == "read") { num_access_regions = block->reads.size(); @@ -130,12 +130,12 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, /* Step 1: Replace old block with the new block */ - auto n = ffi::make_object(*block); + auto n = ffi::make_object(*block); n->reads = reads; n->writes = writes; - Block new_block = Block(n); - ffi::Map blk_map; - blk_map.Set(ffi::GetRef(block), new_block); + SBlock new_block = SBlock(n); + ffi::Map blk_map; + blk_map.Set(ffi::GetRef(block), new_block); self->Replace(block_sref, new_block, blk_map); } @@ -148,7 +148,7 @@ struct UnsafeHideBufferAccessTraits : public UnpackedInstTraits buf_index_array) { sch->UnsafeHideBufferAccess(block, buf_type, buf_index_array); } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index c625d8c153cf..707d4e22a886 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -76,12 +76,12 @@ class TransformLayoutPlanner : private StmtExprVisitor { // Loops within the analyzed block that should be replaced struct ReplacementPlan { ffi::Map replacements; - ffi::Map new_block_to_old; + ffi::Map new_block_to_old; }; // The block to be inserted, along with the location at which it // should be inserted. The location will be either a For or a - // Block, and will be after all writes the transformed buffer. + // SBlock, and will be after all writes the transformed buffer. struct EpiloguePlan { Stmt insert_after; Stmt new_block; @@ -92,7 +92,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { using TransformPlan = std::variant; - static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, + static TransformPlan Plan(SBlock block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, ffi::Optional pad_value, arith::Analyzer* analyzer) { ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) @@ -108,7 +108,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BufferStore store; // The block realize that contains the store, if any. - ffi::Optional innermost_block_realize; + ffi::Optional innermost_block_realize; // The nested loops whose values contribute to the indices used in // the store. Not all loop variables in the loopnest need to @@ -134,8 +134,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const BlockRealizeNode* op) override { - BindBlockRealize context(this, ffi::GetRef(op)); + void VisitStmt_(const SBlockRealizeNode* op) override { + BindBlockRealize context(this, ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); } @@ -221,7 +221,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { public: BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate, const IndexMap& inverse, const ffi::Optional& pad_value, - ffi::Map* new_block_to_old, arith::Analyzer* analyzer) + ffi::Map* new_block_to_old, arith::Analyzer* analyzer) : info(info), new_buffer(new_buffer), new_indices(inverse->initial_indices), @@ -248,7 +248,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return; } - BlockRealize block_realize = info.innermost_block_realize.value(); + SBlockRealize block_realize = info.innermost_block_realize.value(); const auto& block = block_realize->block; const ffi::Array& old_indices = info.store->indices; const auto& old_iter_vars = block->iter_vars; @@ -365,11 +365,11 @@ class TransformLayoutPlanner : private StmtExprVisitor { return StmtExprMutator::VisitStmt_(store.get()); } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockRealizeNode* op) final { + SBlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); if (op == info.innermost_block_realize.get()) { - Block block = realize->block; + SBlock block = realize->block; if (!block->iter_vars.same_as(this->new_iter_vars)) { block.CopyOnWrite()->iter_vars = this->new_iter_vars; RecordReplacement(op->block, block); @@ -386,9 +386,9 @@ class TransformLayoutPlanner : private StmtExprVisitor { return realize; } - Stmt VisitStmt_(const BlockNode* op) final { - Block orig = ffi::GetRef(op); - Block mutated = Downcast(StmtExprMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock orig = ffi::GetRef(op); + SBlock mutated = Downcast(StmtExprMutator::VisitStmt_(op)); RecordReplacement(orig, mutated); return mutated; @@ -403,7 +403,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } } - void RecordReplacement(Block before, Block after) { + void RecordReplacement(SBlock before, SBlock after) { if (before.same_as(after)) { return; } @@ -429,7 +429,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { PrimExpr padding_predicate; const IndexMap& inverse; const ffi::Optional& pad_value; - ffi::Map& new_block_to_old; + ffi::Map& new_block_to_old; bool all_stores_replaced{true}; arith::Analyzer* analyzer; @@ -488,8 +488,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::stringstream block_name; block_name << "buffer_" << new_buffer->name << "_assumptions"; auto read_region = BufferRegion::FromPoint(new_buffer, indices); - stmt = BlockRealize(iter_values, Bool(true), - Block(iter_vars, {read_region}, {}, block_name.str(), stmt)); + stmt = SBlockRealize(iter_values, Bool(true), + SBlock(iter_vars, {read_region}, {}, block_name.str(), stmt)); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { size_t i = (inverse->initial_indices.size() - 1) - rev_i; @@ -509,7 +509,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return std::nullopt; } - ffi::Map new_block_to_old; + ffi::Map new_block_to_old; auto generate_if_then_else_block = [&](const WriteInfo& info) -> ffi::Optional { if (!info.contains_row_major_traversal || !pad_value.defined() || is_zero(padding_predicate)) { @@ -579,8 +579,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::stringstream block_name; block_name << "buffer_" << new_buffer->name << "_padding"; auto write_region = BufferRegion::FromPoint(new_buffer, indices); - stmt = BlockRealize(iter_values, padding_predicate, - Block(iter_vars, {}, {write_region}, block_name.str(), stmt)); + stmt = SBlockRealize(iter_values, padding_predicate, + SBlock(iter_vars, {}, {write_region}, block_name.str(), stmt)); ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { @@ -657,7 +657,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { }; struct BindBlockRealize { - BindBlockRealize(TransformLayoutPlanner* self, BlockRealize block_realize) : self_(self) { + BindBlockRealize(TransformLayoutPlanner* self, SBlockRealize block_realize) : self_(self) { ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size()); for (size_t i = 0; i < block_realize->iter_values.size(); i++) { bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var, @@ -673,7 +673,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BindBlockRealize& operator=(BindBlockRealize&&) = delete; TransformLayoutPlanner* self_{nullptr}; - ffi::Optional cache_; + ffi::Optional cache_; std::vector bound_vars_; }; @@ -707,7 +707,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { * * Used to fill the `WriteInfo::innermost_block_realize` field.. */ - ffi::Optional innermost_block_realize_{std::nullopt}; + ffi::Optional innermost_block_realize_{std::nullopt}; /*! \brief The buffer to be replaced */ Buffer old_buffer_; @@ -719,23 +719,24 @@ class TransformLayoutPlanner : private StmtExprVisitor { */ class ReuseBlocksCollector : public tir::StmtVisitor { public: - static ffi::Map Collect(Block result, ffi::Map new_block_to_old) { + static ffi::Map Collect(SBlock result, + ffi::Map new_block_to_old) { return ReuseBlocksCollector(new_block_to_old).Run(result); } private: /*! \brief Entry point */ - ffi::Map Run(const Block result) { + ffi::Map Run(const SBlock result) { VisitStmt(result); return block_sref_reuse_; } /*! \brief Constructor */ - explicit ReuseBlocksCollector(ffi::Map new_block_to_old) + explicit ReuseBlocksCollector(ffi::Map new_block_to_old) : new_block_to_old_(new_block_to_old) {} /*! \brief Override the Stmt visiting behaviour */ - void VisitStmt_(const tir::BlockNode* block) override { - Block block_ref = ffi::GetRef(block); + void VisitStmt_(const tir::SBlockNode* block) override { + SBlock block_ref = ffi::GetRef(block); auto it = new_block_to_old_.find(block_ref); if (it != new_block_to_old_.end()) { block_sref_reuse_.Set((*it).second, (*it).first); @@ -744,9 +745,9 @@ class ReuseBlocksCollector : public tir::StmtVisitor { } /*! \brief New map to be filled with just blocks from scope block */ - ffi::Map block_sref_reuse_; + ffi::Map block_sref_reuse_; /*! \brief All block replacements collected so far */ - ffi::Map new_block_to_old_; + ffi::Map new_block_to_old_; }; class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { @@ -760,8 +761,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { * \return The new AST rooting at the original parent scope and the map from the old block to the * new block */ - static std::pair> Rewrite( - const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, + static std::pair> Rewrite( + const SBlock& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, const IndexMap& index_map, const ffi::Optional& opt_inverse, const PrimExpr& padding_predicate, const ffi::Optional& pad_value) { arith::Analyzer analyzer; @@ -772,13 +773,13 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { : TransformLayoutPlanner::NoPaddingRequired(); TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer); - Block result = Downcast(rewriter(scope_stmt)); + SBlock result = Downcast(rewriter(scope_stmt)); if (auto plan_ptr = std::get_if(&plan)) { auto write_ptr = result.CopyOnWrite(); write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); } - ffi::Map block_sref_reuse = + ffi::Map block_sref_reuse = ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_); return {result, block_sref_reuse}; @@ -865,9 +866,9 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { (*old_access_regions).MutateByApply(fmutate); } - Stmt VisitStmt_(const BlockNode* op) final { - Block orig = [&]() { - Block block = ffi::GetRef(op); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock orig = [&]() { + SBlock block = ffi::GetRef(op); while (true) { if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) { block = (*it).second; @@ -878,9 +879,9 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { return block; }(); - Block block = Downcast(Parent::VisitStmt_(op)); + SBlock block = Downcast(Parent::VisitStmt_(op)); - auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto infered_access_regions = GetSBlockReadWriteRegion(block, buffer_data_to_buffer_); auto* n = block.CopyOnWrite(); RewriteAccessRegion(&n->reads, infered_access_regions[0]); RewriteAccessRegion(&n->writes, infered_access_regions[1]); @@ -896,7 +897,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { return block; } - void RecordReplacement(Block before, Block after) { + void RecordReplacement(SBlock before, SBlock after) { if (before.same_as(after)) { return; } @@ -919,7 +920,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const IndexMap& index_map_; const TransformLayoutPlanner::TransformPlan& plan_; ffi::Map buffer_data_to_buffer_; - ffi::Map new_block_to_old_; + ffi::Map new_block_to_old_; arith::Analyzer index_simplifier_; }; @@ -1150,9 +1151,9 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ arith::Analyzer analyzer; AddShapeVarBounds(self, block_sref.get(), &analyzer); // Step 1: Input handling and error checking - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); Buffer old_buffer = - GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape); @@ -1174,7 +1175,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ StmtSRef scope_sref = defining_site_sref.defined() ? defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); ffi::Optional opt_inverse = std::nullopt; PrimExpr padding_predicate = Bool(false); @@ -1200,9 +1201,9 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block // alloc_buffers. auto [new_stmt, block_sref_reuse] = - TransformLayoutRewriter::Rewrite(ffi::GetRef(scope_block), old_buffer, new_buffer, + TransformLayoutRewriter::Rewrite(ffi::GetRef(scope_block), old_buffer, new_buffer, index_map, opt_inverse, padding_predicate, pad_value); - Block new_scope_block = Downcast(new_stmt); + SBlock new_scope_block = Downcast(new_stmt); // Step 4: Rewrite buffer_map of the PrimFunc if necessary. if (!defining_site_sref.defined()) { @@ -1287,12 +1288,12 @@ class NotBijectiveAffineIndexMapError : public ScheduleError { class IndexMapNotApplicableToBlockIterError : public ScheduleError { public: - static void Check(const IRModule mod, const Block& block, const IndexMap& index_map) { + static void Check(const IRModule mod, const SBlock& block, const IndexMap& index_map) { if (index_map->initial_indices.size() != block->iter_vars.size()) { throw IndexMapNotApplicableToBlockIterError(mod, block, index_map); } } - explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map) + explicit IndexMapNotApplicableToBlockIterError(IRModule mod, SBlock block, IndexMap index_map) : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {} ffi::String FastErrorString() const final { @@ -1315,13 +1316,13 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; IndexMap index_map_; }; class OpaqueNewIterTypeError : public ScheduleError { public: - explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value) + explicit OpaqueNewIterTypeError(IRModule mod, SBlock block, PrimExpr iter_value) : mod_(std::move(mod)), block_(std::move(block)), iter_value_(std::move(iter_value)) {} ffi::String FastErrorString() const final { @@ -1341,14 +1342,14 @@ class OpaqueNewIterTypeError : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; PrimExpr iter_value_; }; void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const IndexMap& index_map) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - const Block& block = ffi::GetRef(block_ptr); + const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); + const SBlock& block = ffi::GetRef(block_ptr); arith::Analyzer analyzer; AddShapeVarBounds(self, block_sref.get(), &analyzer); @@ -1370,7 +1371,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, scope_sref = block_sref; } - BlockRealize block_realize = GetBlockRealize(self, block_sref); + SBlockRealize block_realize = GetSBlockRealize(self, block_sref); CheckBlockHasTrivialBinding(self, block_sref); // Step 3: Collect information of block iter vars @@ -1410,7 +1411,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); } if (iter_type == kOpaque) { - throw OpaqueNewIterTypeError(self->mod, ffi::GetRef(block_ptr), + throw OpaqueNewIterTypeError(self->mod, ffi::GetRef(block_ptr), transformed_block_iters[i]); } auto dtype = new_block_var.dtype(); @@ -1441,9 +1442,10 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, inverse_subst_map.Set(Downcast(block_vars[i]), inversed_new_block_vars[i]); } } - Block new_block = Downcast(Substitute(ffi::GetRef(block_ptr), inverse_subst_map)); + SBlock new_block = + Downcast(Substitute(ffi::GetRef(block_ptr), inverse_subst_map)); new_block.CopyOnWrite()->iter_vars = new_block_iters; - new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); + new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); // Step 5.3: Create outer loops for each new block iter. @@ -1454,7 +1456,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, } // Make new block realize - BlockRealizeNode* new_block_realize = block_realize.CopyOnWrite(); + SBlockRealizeNode* new_block_realize = block_realize.CopyOnWrite(); new_block_realize->iter_values = new_loop_vars; new_block_realize->block = new_block; @@ -1466,7 +1468,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, } // Step 6: Do the actual replacement - if (scope_sref->StmtAs()) { + if (scope_sref->StmtAs()) { ICHECK(new_loop_vars.empty()) << "Invalid block to loop replacement due to layout transform " << index_map; } @@ -1475,15 +1477,15 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, class BufferAxisSeparatorMutator : private ReplaceBufferMutator { public: - static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer, - ffi::Map* block_sref_reuse) { + static SBlock Mutate(const SBlock& scope_block, const Buffer& old_buffer, Buffer new_buffer, + ffi::Map* block_sref_reuse) { BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse); - return Downcast(mutator.VisitStmt(scope_block)); + return Downcast(mutator.VisitStmt(scope_block)); } private: BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer, - ffi::Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {} MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { @@ -1513,9 +1515,9 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array& axis_separators) { - const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); Buffer old_buffer = - GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); @@ -1524,17 +1526,17 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer StmtSRef scope_sref = defining_site_sref.defined() ? defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); // Step 1: Check and update axis_separators of the buffer. Buffer new_buffer = old_buffer; new_buffer.CopyOnWrite()->axis_separators = axis_separators; - ffi::Map block_sref_reuse; + ffi::Map block_sref_reuse; // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. - Block new_scope_block = BufferAxisSeparatorMutator::Mutate( - ffi::GetRef(scope_block), old_buffer, new_buffer, &block_sref_reuse); + SBlock new_scope_block = BufferAxisSeparatorMutator::Mutate( + ffi::GetRef(scope_block), old_buffer, new_buffer, &block_sref_reuse); if (!defining_site_sref.defined()) { // mutate buffer_map of the PrimFunc GlobalVar g_var; @@ -1567,7 +1569,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits static constexpr size_t kNumAttrs = 4; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, IndexMap index_map, Integer buffer_index, Integer buffer_index_type, ffi::Optional pad_value, Bool assume_injective_transform) { @@ -1636,7 +1638,7 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraitsTransformBlockLayout(block_rv, index_map); } @@ -1676,7 +1678,7 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits axis_separators) { return sch->SetAxisSeparator(block_rv, buffer_index.IntValue(), diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 3cd364b0fd2b..96ea4d2527d1 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -32,11 +32,11 @@ class BlockPredicateAppender : public StmtMutator { private: // For each direct child of type BlockRealizeNode, append the predicate - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { // We do not recursively do this - ObjectPtr n = CopyOnWrite(realize); + ObjectPtr n = CopyOnWrite(realize); n->predicate = n->predicate && to_append_; - return BlockRealize(n); + return SBlockRealize(n); } /*! \brief The predicate to be appended */ @@ -48,7 +48,7 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { public: explicit SubstituteVarAndCollectOpaqueBlock( std::function(const Var&)> vmap, - ffi::Map* opaque_blocks) + ffi::Map* opaque_blocks) : vmap_(vmap), opaque_blocks_(opaque_blocks) {} private: @@ -61,8 +61,8 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { } } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockRealizeNode* op) final { + SBlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); if (realize->block->iter_vars.empty()) { opaque_blocks_->Set(op->block, realize->block); } @@ -72,7 +72,7 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { /*! \brief The substitute function */ std::function(const Var&)> vmap_; /*! \brief The reuse mapping of opaque blocks */ - ffi::Map* opaque_blocks_; + ffi::Map* opaque_blocks_; }; /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ @@ -104,11 +104,11 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { return res; } - Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt VisitStmt_(const SBlockRealizeNode* op) final { // skip opaque block and update mapping if (op->iter_values.empty()) { - Block block = op->block; - BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + SBlock block = op->block; + SBlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); for (const auto& entry : *opaque_blocks_) { if (entry.second.same_as(block)) { opaque_blocks_->at(entry.first) = realize->block; @@ -127,7 +127,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { if (v.same_as(op->iter_values)) { return ffi::GetRef(op); } else { - ObjectPtr n = CopyOnWrite(op); + ObjectPtr n = CopyOnWrite(op); n->iter_values = std::move(v); return Stmt(n); } @@ -160,15 +160,15 @@ class BlockPropertyError : public ScheduleError { : state_(state), top_(top) {} private: - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { for (const IterVar& iter_var : op->iter_vars) { if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - throw BlockPropertyError(state_->mod, ffi::GetRef(op)); + throw BlockPropertyError(state_->mod, ffi::GetRef(op)); } ffi::Optional high_exclusive = top_->parent ? ffi::GetRef(top_->parent) : ffi::Optional(std::nullopt); - CheckPartialAffineBinding(state_, ffi::GetRef(op), high_exclusive); + CheckPartialAffineBinding(state_, ffi::GetRef(op), high_exclusive); } } const ScheduleState& state_; @@ -179,7 +179,7 @@ class BlockPropertyError : public ScheduleError { checker(ffi::GetRef(sref->stmt)); } - explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} + explicit BlockPropertyError(IRModule mod, SBlock block) : mod_(mod), block_(std::move(block)) {} ffi::String FastErrorString() const final { return "ScheduleError: The block under the loops to be reordered have block iter type other " @@ -195,7 +195,7 @@ class BlockPropertyError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Block block_; + SBlock block_; }; class HasAnnotationOrThreadBindingError : public ScheduleError { @@ -426,7 +426,7 @@ ffi::Array Split(ScheduleState self, const StmtSRef& loop_sref, analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor))); new_loop_vars.emplace_back(std::move(var)); } - ffi::Map opaque_block_reuse; + ffi::Map opaque_block_reuse; Stmt new_stmt = loop->body; new_stmt = SubstituteVarAndCollectOpaqueBlock( [&](const Var& v) -> ffi::Optional { @@ -464,7 +464,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor { public: explicit BufferIndicesMapExtractor(Var loop_var) : loop_var_(loop_var) {} - static ffi::Map> Extract(Var loop_var, Block& block) { + static ffi::Map> Extract(Var loop_var, SBlock& block) { BufferIndicesMapExtractor extractor(loop_var); extractor(std::move(block->body)); return extractor.buffer_indices_map; @@ -503,7 +503,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(load); } - void VisitStmt_(const BlockNode* op) final { StmtVisitor::VisitStmt_(op); } + void VisitStmt_(const SBlockNode* op) final { StmtVisitor::VisitStmt_(op); } Var loop_var_; ffi::Map> buffer_indices_map; @@ -536,8 +536,8 @@ class BlockMutator : public StmtExprMutator { : new_loop_var_(new_loop_var), min_(min), extent_(extent) {} private: - Stmt VisitStmt_(const BlockNode* _op) final { - Block new_block = Downcast(StmtMutator::VisitStmt_(_op)); + Stmt VisitStmt_(const SBlockNode* _op) final { + SBlock new_block = Downcast(StmtMutator::VisitStmt_(_op)); // If iter_vars.size() is 0, then the block most probably be an Opaque block if (new_block->iter_vars.size() == 0 || inner_iter_var_index == -1) { @@ -603,7 +603,7 @@ class BlockMutator : public StmtExprMutator { return block_stmt; } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { ffi::Array iter_values = realize->iter_values; for (size_t i = 0; i < iter_values.size(); i++) { if (new_loop_var_.same_as(iter_values[i])) { @@ -612,7 +612,7 @@ class BlockMutator : public StmtExprMutator { break; } } - BlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); + SBlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); return stmt; } @@ -633,10 +633,10 @@ class BlockMutator : public StmtExprMutator { int inner_iter_var_index = -1; }; -const ffi::String get_block_name(Stmt loop_body) { - const BlockRealizeNode* blk_realize = loop_body.as(); +const ffi::String get_sblock_name(Stmt loop_body) { + const SBlockRealizeNode* blk_realize = loop_body.as(); if (blk_realize == nullptr) { - return get_block_name(loop_body.as()->body); + return get_sblock_name(loop_body.as()->body); } return blk_realize->block->name_hint; } @@ -659,7 +659,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref dtype = DataType::Int(bits); } - ffi::String block_name = get_block_name(loop->body) + "_" + loop->loop_var->name_hint; + ffi::String block_name = get_sblock_name(loop->body) + "_" + loop->loop_var->name_hint; int n = factors.size(); PrimExpr min_value = loop->min; PrimExpr extent_value; @@ -681,16 +681,16 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref const auto& partition_block_name = block_name + std::to_string(i) + "_partition"; // Create partition_block for the partitioned for loop - BlockRealize partition_block({}, extent_value > 0, - Block({}, {}, {}, partition_block_name, for_node)); + SBlockRealize partition_block({}, extent_value > 0, + SBlock({}, {}, {}, partition_block_name, for_node)); block_partitions.push_back(partition_block); min_value = extent_value; } // Create common block with all the partitioned blocks as its children blocks - BlockRealize common({}, make_const(DataType::Bool(), 1), - Block({}, {}, {}, block_name + "_common", tir::SeqStmt(block_partitions))); + SBlockRealize common({}, make_const(DataType::Bool(), 1), + SBlock({}, {}, {}, block_name + "_common", tir::SeqStmt(block_partitions))); // Replace existing loop with the newly created common block self->Replace(loop_sref, common, {}); @@ -698,7 +698,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref StmtSRef scope_root = tir::GetScopeRoot(self, scope_sref, /*require_stage_pipeline=*/false); bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); // Update the SRefTree for the newly created common block - self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, scope_root)); self->block_info[scope_root].affine_binding = scope_block_affine_binding; // Collect the SRef for each partitioned loop and return @@ -706,7 +706,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref partition_srefs.reserve(n); for (int i = 0; i < n; i++) { StmtSRef partition_loop_sref = - self->stmt2ref.at(block_partitions[i].as()->block->body.get()); + self->stmt2ref.at(block_partitions[i].as()->block->body.get()); partition_srefs.push_back(partition_loop_sref); } return partition_srefs; @@ -714,7 +714,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref class LoopReconstructor : private StmtMutator { public: - explicit LoopReconstructor(Block scope_root, const std::vector>& loops) + explicit LoopReconstructor(SBlock scope_root, const std::vector>& loops) : scope_root_(scope_root), loops_(loops) {} using StmtMutator::operator(); @@ -752,9 +752,9 @@ class LoopReconstructor : private StmtMutator { } private: - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { if (block != scope_root_.get()) { - return ffi::GetRef(block); + return ffi::GetRef(block); } return StmtMutator::VisitStmt_(block); } @@ -789,7 +789,7 @@ class LoopReconstructor : private StmtMutator { public: /*! \brief The root block of the block scope */ - Block scope_root_; + SBlock scope_root_; /*! \brief The given loops to be merge */ const std::vector>& loops_; /*! \brief The outermost new loop to replace the original loop */ @@ -860,10 +860,10 @@ StmtSRef Merge(ScheduleState self, const ffi::Array& loop_srefs) { } } // Step 2. Create merged loops and replace the original loops - Block scope_root = ffi::GetRef(scope_root_sref->StmtAs()); + SBlock scope_root = ffi::GetRef(scope_root_sref->StmtAs()); LoopReconstructor reconstructor(scope_root, lca_nest_loops); reconstructor.MakeNewLoop(); - Block new_scope_root = Downcast(reconstructor(scope_root)); + SBlock new_scope_root = Downcast(reconstructor(scope_root)); // Step 3. Do the actual replacement self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); return self->stmt2ref.at(reconstructor.new_inner_loop_.get()); @@ -934,7 +934,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array& loop_srefs, } substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower)); Stmt new_stmt = loops.back()->body; - ffi::Map opaque_block_reuse; + ffi::Map opaque_block_reuse; auto f_substitute = [&](const Var& v) -> ffi::Optional { for (int i = 0; i < n; i++) { if (v.same_as(loops[i]->loop_var)) { @@ -1000,7 +1000,7 @@ std::pair GetBoundaryOfReorderRange( } for (const StmtSRefNode* v = loop_sref;; v = v->parent) { // Case 1. If `v` corresponds to a block, stop traversal. - if (v->stmt->IsInstance()) { + if (v->stmt->IsInstance()) { if (scope_block_visited) { throw LoopsNotAChainError(self->mod, std::nullopt, LoopsNotAChainError::ProblemKind::kNotUnderAScope); @@ -1146,10 +1146,10 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { public: explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {} - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { if (realize->block.get() == src_block_) { new_loop_ = For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, - ffi::GetRef(realize)); + ffi::GetRef(realize)); return new_loop_; } return StmtMutator::VisitStmt_(realize); @@ -1166,8 +1166,8 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { if (new_stmt->IsInstance()) { self->Replace(parent_sref, std::move(new_stmt), {}); } else { - Block old_parent_block = ffi::GetRef(parent_sref->StmtAs()); - Block new_parent_block = Downcast(new_stmt); + SBlock old_parent_block = ffi::GetRef(parent_sref->StmtAs()); + SBlock new_parent_block = Downcast(new_stmt); self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}}); } return self->stmt2ref.at(creator.new_loop_.get()); @@ -1364,7 +1364,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) { - if (auto block = rv.as()) { + if (auto block = rv.as()) { return sch->AddUnitLoop(block.value()); } else if (auto loop = rv.as()) { return sch->AddUnitLoop(loop.value()); diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc index f66ee2f63e33..7fd28445a812 100644 --- a/src/tir/schedule/primitive/pad_einsum.cc +++ b/src/tir/schedule/primitive/pad_einsum.cc @@ -67,7 +67,7 @@ ffi::Optional> CheckTrivialBufferAccess(const BufferRegion& buff /*! \brief The schedule error class when the padding size is invalid. */ class InvalidPaddingError : public ScheduleError { public: - InvalidPaddingError(IRModule mod, Block block, ffi::Array padding) + InvalidPaddingError(IRModule mod, SBlock block, ffi::Array padding) : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {} IRModule mod() const final { return mod_; } ffi::Array LocationsOfInterest() const final { return {block_}; } @@ -81,7 +81,7 @@ class InvalidPaddingError : public ScheduleError { return os.str(); } - static void Check(const ScheduleState& self, const Block& block, ffi::Array padding) { + static void Check(const ScheduleState& self, const SBlock& block, ffi::Array padding) { if (padding.size() != block->iter_vars.size()) { throw InvalidPaddingError(self->mod, block, padding); } @@ -94,14 +94,14 @@ class InvalidPaddingError : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; ffi::Array padding_; }; /*! \brief The schedule error class when the block body is not an Einsum pattern. */ class NonEinsumError : public ScheduleError { public: - explicit NonEinsumError(IRModule mod, Block block) + explicit NonEinsumError(IRModule mod, SBlock block) : mod_(std::move(mod)), block_(std::move(block)) {} IRModule mod() const final { return mod_; } @@ -115,7 +115,7 @@ class NonEinsumError : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; }; /*! \brief Data structure that represents a Einsum computation. */ @@ -157,7 +157,7 @@ struct BufferPadding { return result; } - Stmt MakeCopyBlock(bool is_read, ffi::Array* blocks, arith::Analyzer* analyzer) { + Stmt MakeCopyBlock(bool is_read, ffi::Array* blocks, arith::Analyzer* analyzer) { ffi::Array loop_vars; ffi::Array loop_doms; ffi::Array iter_vars; @@ -198,10 +198,11 @@ struct BufferPadding { if (!is_read) { std::swap(read_region, write_region); } - Block new_block(iter_vars, {read_region}, {write_region}, padded_buffer->name, std::move(body)); + SBlock new_block(iter_vars, {read_region}, {write_region}, padded_buffer->name, + std::move(body)); blocks->push_back(new_block); - body = BlockRealize(ffi::Array{loop_vars.begin(), loop_vars.end()}, Bool(true), - new_block); + body = SBlockRealize(ffi::Array{loop_vars.begin(), loop_vars.end()}, Bool(true), + new_block); for (int i = ndim - 1; i >= 0; --i) { body = For(loop_vars[i], loop_doms[i]->min, loop_doms[i]->extent, ForKind::kSerial, std::move(body)); @@ -210,7 +211,7 @@ struct BufferPadding { } }; -Einsum ExtractEinsum(const ScheduleState& self, const Block& block) { +Einsum ExtractEinsum(const ScheduleState& self, const SBlock& block) { Einsum result; std::unordered_set buffer_used; int n_reads = block->reads.size(); @@ -272,7 +273,7 @@ class BufferNotAllocatedInScopeError : public ScheduleError { /*! \brief The schedule error class when the producer block cannot be padded. */ class InvalidProducerError : public ScheduleError { public: - explicit InvalidProducerError(IRModule mod, Block producer) + explicit InvalidProducerError(IRModule mod, SBlock producer) : mod_(std::move(mod)), producer_(std::move(producer)) {} ffi::String FastErrorString() const final { @@ -292,14 +293,14 @@ class InvalidProducerError : public ScheduleError { private: IRModule mod_; Buffer buffer_; - Block producer_; + SBlock producer_; }; class PadEinsumBufferReplacer : public StmtExprMutator { public: - Stmt VisitStmt_(const BlockNode* old_block_ptr) final { - Block old_block = ffi::GetRef(old_block_ptr); - Block block = Downcast(StmtMutator::VisitStmt_(old_block_ptr)); + Stmt VisitStmt_(const SBlockNode* old_block_ptr) final { + SBlock old_block = ffi::GetRef(old_block_ptr); + SBlock block = Downcast(StmtMutator::VisitStmt_(old_block_ptr)); ffi::Array iter_vars; iter_vars.reserve(block->iter_vars.size()); for (const IterVar& iter_var : block->iter_vars) { @@ -329,9 +330,9 @@ class PadEinsumBufferReplacer : public StmtExprMutator { writes.push_back(write); } } - Block new_block = - Block(iter_vars, reads, writes, block->name_hint, block->body, block->init, - /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/block->annotations); + SBlock new_block = + SBlock(iter_vars, reads, writes, block->name_hint, block->body, block->init, + /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/block->annotations); block_sref_reuse_.Set(old_block, new_block); return new_block; } @@ -368,19 +369,19 @@ class PadEinsumBufferReplacer : public StmtExprMutator { ffi::Map iter2padded_extents; ffi::Map loop_var2padded_extent; ffi::Map buffer_map_; - ffi::Map block_sref_reuse_; + ffi::Map block_sref_reuse_; }; void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array& padding) { arith::Analyzer analyzer; // Step 1: Input checking and error handling - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - BlockRealize realize = GetBlockRealize(self, block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + SBlockRealize realize = GetSBlockRealize(self, block_sref); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - InvalidPaddingError::Check(self, ffi::GetRef(block), padding); + const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); + InvalidPaddingError::Check(self, ffi::GetRef(block), padding); // Step 2. Extract the Einsum pattern - ExtractEinsum(self, ffi::GetRef(block)); + ExtractEinsum(self, ffi::GetRef(block)); // Step 3. Figure out the padding needed PadEinsumBufferReplacer replacer; for (int i = 0, n = padding.size(); i < n; ++i) { @@ -430,7 +431,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array< // Step 5. For each buffer, if it needs padding, create a new buffer and a new block ffi::Array read_blocks; ffi::Array write_blocks; - ffi::Array new_copy_blocks; + ffi::Array new_copy_blocks; ffi::Array alloc_buffers; for (const BufferRegion& buffer_region : block->reads) { if (f_needs_padding(buffer_region->region)) { @@ -462,19 +463,19 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array< new_scope_body.insert(new_scope_body.end(), write_blocks.begin(), write_blocks.end()); } // Step 7. Create new scope - Block new_scope_block{nullptr}; + SBlock new_scope_block{nullptr}; { - ObjectPtr n = ffi::make_object(*scope_block); + ObjectPtr n = ffi::make_object(*scope_block); n->body = SeqStmt::Flatten(new_scope_body); n->alloc_buffers.insert(n->alloc_buffers.end(), alloc_buffers.begin(), alloc_buffers.end()); - new_scope_block = Block(n); + new_scope_block = SBlock(n); } - replacer.block_sref_reuse_.Set(ffi::GetRef(scope_block), new_scope_block); + replacer.block_sref_reuse_.Set(ffi::GetRef(scope_block), new_scope_block); // Step 8. Do replacement and update flags self->Replace(scope_sref, new_scope_block, replacer.block_sref_reuse_); - for (const Block& block : new_copy_blocks) { + for (const SBlock& block : new_copy_blocks) { StmtSRef block_sref = self->stmt2ref.at(block.get()); - BlockInfo& block_info = self->block_info[block_sref]; + SBlockInfo& block_info = self->block_info[block_sref]; block_info.affine_binding = true; block_info.region_cover = true; block_info.stage_pipeline = true; @@ -492,7 +493,7 @@ struct PadEinsumTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::Array padding) { + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block, ffi::Array padding) { sch->PadEinsum(block, padding); } diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc index 44a0f9bbe284..a8325c09e692 100644 --- a/src/tir/schedule/primitive/read_write_at.cc +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -51,12 +51,12 @@ void RelaxBufferRegions(const ffi::Array& buffer_regions, class ScopeReplacer : public StmtMutator { public: - static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, - const ForNode* new_loop) { - ObjectPtr new_scope_block = ffi::make_object(*scope_block); + static SBlock Replace(const SBlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = ffi::make_object(*scope_block); new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); new_scope_block->alloc_buffers.push_back(dst); - return Block(new_scope_block); + return SBlock(new_scope_block); } private: @@ -64,7 +64,7 @@ class ScopeReplacer : public StmtMutator { : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); } - Stmt VisitStmt_(const BlockNode* block) final { return ffi::GetRef(block); } + Stmt VisitStmt_(const SBlockNode* block) final { return ffi::GetRef(block); } Stmt VisitStmt_(const ForNode* loop) final { if (loop == old_loop_) { found_ = true; @@ -81,7 +81,7 @@ class ScopeReplacer : public StmtMutator { class ReadWriteAtBufferReplacer : public StmtExprMutator { public: explicit ReadWriteAtBufferReplacer(const Buffer& src, const Buffer& dst, - ffi::Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} private: @@ -105,19 +105,19 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator { return load; } - Stmt VisitStmt_(const BlockNode* _block) final { - Block old_block = ffi::GetRef(_block); - Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); - ObjectPtr new_block = ffi::make_object(*block.get()); + Stmt VisitStmt_(const SBlockNode* _block) final { + SBlock old_block = ffi::GetRef(_block); + SBlock block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = ffi::make_object(*block.get()); new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); - block_sref_reuse_->Set(old_block, Block(new_block)); - return Block(new_block); + block_sref_reuse_->Set(old_block, SBlock(new_block)); + return SBlock(new_block); } const Buffer& src_; const Buffer& dst_; - ffi::Map* block_sref_reuse_; + ffi::Map* block_sref_reuse_; }; struct ReadWriteAtImpl { @@ -125,16 +125,16 @@ struct ReadWriteAtImpl { static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, int buffer_index, const ffi::String& storage_scope, ffi::Map annotations) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer src = GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + Buffer src = GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, is_read ? BufferIndexType::kRead : BufferIndexType::kWrite); Buffer dst = WithScope(src, storage_scope); ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); - std::pair new_loop_block = + std::pair new_loop_block = impl.MakeLoopAndBlock(src->name + "_" + storage_scope); StmtSRef result_block_sref = impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); - impl.UpdateBlockInfo(result_block_sref, !new_loop_block.second->iter_values.empty()); + impl.UpdateSBlockInfo(result_block_sref, !new_loop_block.second->iter_values.empty()); return result_block_sref; } @@ -148,25 +148,25 @@ struct ReadWriteAtImpl { return result; } - StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const SBlockNode* new_block) { StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, /*require_stage_pipeline=*/true); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); - Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); - block_sref_reuse_.Set(ffi::GetRef(scope_block), new_scope_block); + const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_root_sref); + SBlock new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(ffi::GetRef(scope_block), new_scope_block); self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); return self_->stmt2ref.at(new_block); } - void UpdateBlockInfo(const StmtSRef& new_block_sref, bool affine_binding) { - BlockInfo& block_info = self_->block_info[new_block_sref]; + void UpdateSBlockInfo(const StmtSRef& new_block_sref, bool affine_binding) { + SBlockInfo& block_info = self_->block_info[new_block_sref]; block_info.affine_binding = affine_binding; block_info.region_cover = true; block_info.stage_pipeline = true; } template - std::pair MakeLoopAndBlock(const ffi::String& new_block_name_hint) { + std::pair MakeLoopAndBlock(const ffi::String& new_block_name_hint) { ffi::Array subtrees = AsArray(loop_->body); int n_subtrees = subtrees.size(); runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); @@ -182,11 +182,11 @@ struct ReadWriteAtImpl { bool w_visited = false; auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, &scope](const ObjectRef& obj) -> bool { - const BlockRealizeNode* realize = obj.as(); + const SBlockRealizeNode* realize = obj.as(); if (realize == nullptr) { return true; } - const BlockNode* block = realize->block.get(); + const SBlockNode* block = realize->block.get(); bool has_r = HasBuffer(block->reads, src_); bool has_w = HasBuffer(block->writes, src_); r_visited = r_visited || has_r; @@ -200,7 +200,7 @@ struct ReadWriteAtImpl { /*low_inclusive=*/ffi::GetRef(self_->stmt2ref.at(block)->parent), /*high_exclusive=*/loop_sref_, /*extra_relax_scope=*/scope)), - /*bindings=*/GetBindings(ffi::GetRef(realize)), + /*bindings=*/GetBindings(ffi::GetRef(realize)), /*relaxed_regions=*/&relaxed_regions); } return false; @@ -251,19 +251,19 @@ struct ReadWriteAtImpl { subtrees.Set(i, Stmt(nullptr)); subtrees.Set(i, replacer(std::move(stmt))); } - BlockRealize realize = + SBlockRealize realize = is_read - ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) - : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + ? MakeSBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeSBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); subtrees.insert(subtrees.begin() + insert_pos, realize); ObjectPtr new_loop = ffi::make_object(*loop_); new_loop->body = SeqStmt(std::move(subtrees)); return {For(new_loop), realize}; } - BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, - const ffi::String& name_hint, const ffi::Map& loop_domain, - ffi::Array domain) const { + SBlockRealize MakeSBlock(const Buffer& copy_from, const Buffer& copy_to, + const ffi::String& name_hint, const ffi::Map& loop_domain, + ffi::Array domain) const { int n = domain.size(); std::vector loop_vars; loop_vars.reserve(n); @@ -304,18 +304,18 @@ struct ReadWriteAtImpl { for (int i = n - 1; i >= 0; --i) { stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); } - return BlockRealize( + return SBlockRealize( /*values=*/iter_values, /*predicate=*/const_true(), - Block(/*iter_vars=*/iter_vars, - /*reads=*/{BufferRegion(copy_from, domain)}, - /*writes=*/{BufferRegion(copy_to, domain)}, - /*name_hint=*/name_hint, // - /*body=*/std::move(stmt), - /*init=*/std::nullopt, - /*alloc_buffers=*/{}, - /*match_buffers=*/{}, - /*annotations=*/annotations_)); + SBlock(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/std::nullopt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); } explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, @@ -337,7 +337,7 @@ struct ReadWriteAtImpl { const Buffer& src_; const Buffer& dst_; ffi::Map annotations_; - ffi::Map block_sref_reuse_; + ffi::Map block_sref_reuse_; std::unique_ptr analyzer_; }; @@ -366,8 +366,8 @@ struct ReadAtTraits : public UnpackedInstTraits { StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, int buffer_index, const ffi::String& storage_scope); - static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, - Integer read_buffer_index, ffi::String storage_scope) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, SBlockRV block, + Integer read_buffer_index, ffi::String storage_scope) { return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); } @@ -396,8 +396,8 @@ struct WriteAtTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, - Integer write_buffer_index, ffi::String storage_scope) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, SBlockRV block, + Integer write_buffer_index, ffi::String storage_scope) { return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); } diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 0629757a13d8..fafc646682bf 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -37,17 +37,17 @@ class DecomposeReductionBlockReplacer : public StmtMutator { * \param old_reduction_block The reduction block we want to decompose * \return The new block scope and the updated reduction block */ - static std::pair Replace(Block old_scope_root, For target_loop, - Stmt decomposed_body, Block old_reduction_block) { + static std::pair Replace(SBlock old_scope_root, For target_loop, + Stmt decomposed_body, SBlock old_reduction_block) { DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body), std::move(old_reduction_block)); - return std::make_pair(Downcast(replacer(std::move(old_scope_root))), + return std::make_pair(Downcast(replacer(std::move(old_scope_root))), replacer.new_reduction_block_); } private: explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body, - Block old_reduction_block) + SBlock old_reduction_block) : target_loop_(std::move(target_loop)), decomposed_body_(std::move(decomposed_body)), old_reduction_block_(std::move(old_reduction_block)) {} @@ -61,9 +61,9 @@ class DecomposeReductionBlockReplacer : public StmtMutator { } } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { if (block == old_reduction_block_.get()) { - ObjectPtr p_new_block = CopyOnWrite(block); + ObjectPtr p_new_block = CopyOnWrite(block); p_new_block->name_hint = p_new_block->name_hint + "_update"; p_new_block->init = std::nullopt; // Add write regions back to read regions in update block. @@ -81,7 +81,7 @@ class DecomposeReductionBlockReplacer : public StmtMutator { new_reads.push_back(read_access); } p_new_block->reads = new_reads; - new_reduction_block_ = Block(p_new_block); + new_reduction_block_ = SBlock(p_new_block); return new_reduction_block_; } else { return StmtMutator::VisitStmt_(block); @@ -100,14 +100,14 @@ class DecomposeReductionBlockReplacer : public StmtMutator { private: For target_loop_; Stmt decomposed_body_; - Block old_reduction_block_; - Block new_reduction_block_; + SBlock old_reduction_block_; + SBlock new_reduction_block_; }; class LoopHeightError : public ScheduleError { public: - static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block, - const BlockRealizeNode* realize, + static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const SBlockNode* block, + const SBlockRealizeNode* realize, const ffi::Array& loops, const StmtSRef& loop_sref) { for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { @@ -126,13 +126,13 @@ class LoopHeightError : public ScheduleError { const Var& loop_var = higher_loop->StmtAs()->loop_var; if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - throw LoopHeightError(mod, ffi::GetRef(loop), ffi::GetRef(block)); + throw LoopHeightError(mod, ffi::GetRef(loop), ffi::GetRef(block)); } } } } - explicit LoopHeightError(IRModule mod, For loop, Block block) + explicit LoopHeightError(IRModule mod, For loop, SBlock block) : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -152,7 +152,7 @@ class LoopHeightError : public ScheduleError { IRModule mod_; For loop_; - Block block_; + SBlock block_; }; PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set& discarded_loops) { @@ -185,17 +185,17 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, * - generate corresponding init block and update block */ // Condition Checks and Information Collection - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Get the outer loops from high to low ffi::Array loops = GetLoops(block_sref); - const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); + const SBlockRealizeNode* realize = GetSBlockRealize(self, block_sref).get(); StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); if (self->enable_check) { // Cond 0. Check loop_sref is an ancestor of block_sref if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) { - throw LoopPositionError(self->mod, ffi::GetRef(loop), ffi::GetRef(block), + throw LoopPositionError(self->mod, ffi::GetRef(loop), ffi::GetRef(block), "decompose_reduction"); } // Cond 1. Check block is reduction @@ -204,12 +204,12 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); } // IR Manipulation - ObjectPtr init_block = ffi::make_object(); - ObjectPtr init_realize = ffi::make_object(); + ObjectPtr init_block = ffi::make_object(); + ObjectPtr init_realize = ffi::make_object(); init_block->name_hint = block->name_hint + "_init"; init_block->annotations = block->annotations; init_realize->iter_values = {}; - init_realize->block = Block(init_block); + init_realize->block = SBlock(init_block); // Step 1. Create new block vars and their bindings // Maps an old block var to the new corresponding block var std::unordered_map block_var_map; @@ -266,7 +266,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops); // Step 5. Create new loops above init block std::unordered_map loop_var_map; - Stmt body = BlockRealize(init_realize); + Stmt body = SBlockRealize(init_realize); for (int i : chosen_loops) { For old_loop = ffi::GetRef(TVM_SREF_TO_FOR(loops[i])); // Create a new equivalent to the chosen loop @@ -288,13 +288,14 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, } body = Substitute(body, loop_var_map); // Step 6. Mutate IR - const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); + const SBlockNode* old_scope_root = TVM_SREF_TO_SBLOCK(scope_root_sref); auto [new_scope_root, new_reduction_block] = DecomposeReductionBlockReplacer::Replace( - ffi::GetRef(old_scope_root), ffi::GetRef(loop), body, ffi::GetRef(block)); + ffi::GetRef(old_scope_root), ffi::GetRef(loop), body, + ffi::GetRef(block)); self->Replace(scope_root_sref, new_scope_root, - {{ffi::GetRef(old_scope_root), new_scope_root}, - {ffi::GetRef(block), new_reduction_block}}); - self->UpdateScopeBlockInfo(new_scope_root); + {{ffi::GetRef(old_scope_root), new_scope_root}, + {ffi::GetRef(block), new_reduction_block}}); + self->UpdateScopeSBlockInfo(new_scope_root); return self->stmt2ref.at(init_block.get()); } @@ -559,10 +560,10 @@ class LoopPropertyError : public ScheduleError { ffi::Array LocationsOfInterest() const final { return {loop_}; } static void CheckLoopProperty(const ScheduleState& self, const ffi::Array& loops, - const ForNode* rf_loop, const Block& block, + const ForNode* rf_loop, const SBlock& block, const std::unordered_set& data_par_loop_vars, const std::unordered_set& reduce_loop_vars) { - ffi::Array children_of_outermost_loop = + ffi::Array children_of_outermost_loop = GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); if (!children_of_outermost_loop[0]->block.same_as(block)) { throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); @@ -649,7 +650,7 @@ ffi::Array CreateRFactorBuffers(const ffi::Array& buf_store */ class BaseBlockCreator { public: - explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, + explicit BaseBlockCreator(SBlockRealize old_block_realize, For rf_loop, ffi::Array old_reduction_updates, CommReducer reducer, ffi::Array rf_buffers, bool is_rf_block) : old_block_realize_(std::move(old_block_realize)), @@ -695,7 +696,7 @@ class BaseBlockCreator { new_block_name = new_block_name + "_rf"; predicate = old_block_realize_->predicate; } - new_block_ = Block( + new_block_ = SBlock( /*iter_vars=*/iter_vars_, /*reads=*/read_regions_, /*writes=*/write_regions_, @@ -705,7 +706,7 @@ class BaseBlockCreator { /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/old_block_realize_->block->annotations); - new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_); + new_block_realize_ = SBlockRealize(iter_values_, predicate, new_block_); } private: @@ -765,15 +766,15 @@ class BaseBlockCreator { public: /*! \brief The new created block */ - Block new_block_; + SBlock new_block_; /*! \brief The new created block-realize */ - BlockRealize new_block_realize_; + SBlockRealize new_block_realize_; /*! \brief The indices used to access the intermediate rfactor buffer */ ffi::Array rf_buf_access_indices_; protected: /*! \brief The old block-realize */ - BlockRealize old_block_realize_; + SBlockRealize old_block_realize_; /*! \brief The number of block iters in the old block */ int n_block_iters_; /*! \brief The rfactor loop */ @@ -836,7 +837,7 @@ class BaseBlockCreator { */ class RFactorBlockCreator : public BaseBlockCreator { public: - explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, + explicit RFactorBlockCreator(SBlockRealize old_block_realize, For rf_loop, ffi::Array old_reduction_updates, CommReducer reducer, ffi::Array rf_buffers, std::unordered_map loop_vars2loop, @@ -915,7 +916,7 @@ class RFactorBlockCreator : public BaseBlockCreator { for (int i = 0; i < n_buffers_; ++i) { buffer_map.Set(old_reduction_updates_[i]->buffer, rf_buffers_[i]); } - const Block& old_block = old_block_realize_->block; + const SBlock& old_block = old_block_realize_->block; read_regions_.reserve(old_block->reads.size()); for (const BufferRegion& read_region : old_block->reads) { read_regions_.push_back( @@ -961,7 +962,7 @@ class RFactorBlockCreator : public BaseBlockCreator { */ class WriteBackBlockCreator : public BaseBlockCreator { public: - explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, + explicit WriteBackBlockCreator(SBlockRealize old_block_realize, For rf_loop, ffi::Array old_reduction_updates, CommReducer reducer, ffi::Array rf_buffers, IterVar rf_additional_iter, ffi::Array combiner_lhs, @@ -1039,7 +1040,7 @@ class WriteBackBlockCreator : public BaseBlockCreator { * \param loops The loops to be wrapped over the rfactor block * \return A Stmt which is the wrapping result */ -Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const ffi::Array& loops) { +Stmt CreateLoopOutsideRfactorBlock(SBlockRealize rf_block_realize, const ffi::Array& loops) { int n_loops = static_cast(loops.size()); // Step 1. Create new loop vars. @@ -1059,7 +1060,7 @@ Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const ffi::Arr new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); } { - BlockRealizeNode* p_rf_block_realize = rf_block_realize.CopyOnWrite(); + SBlockRealizeNode* p_rf_block_realize = rf_block_realize.CopyOnWrite(); p_rf_block_realize->iter_values = new_bindings; p_rf_block_realize->predicate = Substitute(rf_block_realize->predicate, new_loop_var_map); } @@ -1100,17 +1101,17 @@ class BlockReplacer : public StmtMutator { * \param rf_buffer The rfactor buffer to be added into the scope root's `alloc_buffers` * \return The transformed new scope root block */ - static Block Replace(Block scope_root_block, Stmt rf_body, For outermost_loop, - BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, - std::unordered_set reduce_loop_vars, - std::unordered_map loop_vars2loop, - const ffi::Array& rf_buffers) { + static SBlock Replace(SBlock scope_root_block, Stmt rf_body, For outermost_loop, + SBlockRealize wb_block_realize, SBlockRealize old_block_realize, + For rf_loop, std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop, + const ffi::Array& rf_buffers) { BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), std::move(wb_block_realize), std::move(old_block_realize), std::move(rf_loop), std::move(reduce_loop_vars), std::move(loop_vars2loop)); - Block new_scope_root = Downcast(replacer(std::move(scope_root_block))); - BlockNode* p = new_scope_root.CopyOnWrite(); + SBlock new_scope_root = Downcast(replacer(std::move(scope_root_block))); + SBlockNode* p = new_scope_root.CopyOnWrite(); for (const Buffer& rf_buffer : rf_buffers) { p->alloc_buffers.push_back(rf_buffer); } @@ -1118,8 +1119,8 @@ class BlockReplacer : public StmtMutator { } private: - explicit BlockReplacer(Stmt rf_body, For outermost_loop, BlockRealize wb_block_realize, - BlockRealize old_block_realize, For rf_loop, + explicit BlockReplacer(Stmt rf_body, For outermost_loop, SBlockRealize wb_block_realize, + SBlockRealize old_block_realize, For rf_loop, std::unordered_set reduce_loop_vars, std::unordered_map loop_vars2loop) : rf_body_(std::move(rf_body)), @@ -1154,7 +1155,7 @@ class BlockReplacer : public StmtMutator { return loop == outermost_loop_.get() ? SeqStmt({rf_body_, body}) : body; } - Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* block_realize) final { // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's // block-realize. And we directly return the new `wb_block_realize`. ICHECK_EQ(block_realize, old_block_realize_.get()); @@ -1174,8 +1175,8 @@ class BlockReplacer : public StmtMutator { private: Stmt rf_body_; For outermost_loop_; - BlockRealize wb_block_realize_; - BlockRealize old_block_realize_; + SBlockRealize wb_block_realize_; + SBlockRealize old_block_realize_; For rf_loop_; std::unordered_set reduce_loop_vars_; std::unordered_map loop_vars2loop_; @@ -1187,9 +1188,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // ***************************************************** // Step 1. Check some basic conditions for rfactor. Get the block and block-realize. - BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref); + SBlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref); const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get()); - const Block& block = block_realize->block; + const SBlock& block = block_realize->block; StmtSRef scope_root = GetScopeRoot(self, block_sref, // /*require_stage_pipeline=*/true); if (self->enable_check) { @@ -1271,8 +1272,8 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // ***************************************************** // Step 1. Substitute the old scope root block with the new scope root block. - Block old_scope_root_block = ffi::GetRef(scope_root->StmtAs()); - Block new_scope_root_block = BlockReplacer::Replace( + SBlock old_scope_root_block = ffi::GetRef(scope_root->StmtAs()); + SBlock new_scope_root_block = BlockReplacer::Replace( old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, ffi::GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers); self->Replace( @@ -1283,7 +1284,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax std::vector new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()), self->stmt2ref.at(wb_block_creator.new_block_.get())}; for (const StmtSRef& new_block_sref : new_block_srefs) { - BlockInfo& info = self->block_info[new_block_sref]; + SBlockInfo& info = self->block_info[new_block_sref]; info.affine_binding = true; info.region_cover = true; info.stage_pipeline = true; @@ -1302,7 +1303,7 @@ struct DecomposeReductionTraits : public UnpackedInstTraitsDecomposeReduction(block_rv, loop_rv); } @@ -1328,7 +1329,7 @@ struct RFactorTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) { return sch->RFactor(loop_rv, factor_axis->value); } diff --git a/src/tir/schedule/primitive/reorder_block_iter_var.cc b/src/tir/schedule/primitive/reorder_block_iter_var.cc index 6acc5fa2d924..2a61734c44ef 100644 --- a/src/tir/schedule/primitive/reorder_block_iter_var.cc +++ b/src/tir/schedule/primitive/reorder_block_iter_var.cc @@ -27,7 +27,7 @@ namespace tir { */ class InvalidReorderIndex : public ScheduleError { public: - explicit InvalidReorderIndex(IRModule mod, Block block, ffi::Array new_order) + explicit InvalidReorderIndex(IRModule mod, SBlock block, ffi::Array new_order) : mod_(mod), block_(block), new_order_(new_order) {} IRModule mod() const final { return mod_; } ffi::String FastErrorString() const final { @@ -43,23 +43,23 @@ class InvalidReorderIndex : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; ffi::Array new_order_; }; class BlockIterVarRewriter : public StmtMutator { public: - ffi::Map block_map; - explicit BlockIterVarRewriter(const BlockNode* block_n, std::vector order) + ffi::Map block_map; + explicit BlockIterVarRewriter(const SBlockNode* block_n, std::vector order) : order_(std::move(order)), block_to_rewrite(block_n) {} private: std::vector order_; - const BlockNode* block_to_rewrite; - Stmt VisitStmt_(const BlockRealizeNode* op) final { + const SBlockNode* block_to_rewrite; + Stmt VisitStmt_(const SBlockRealizeNode* op) final { if (op->block.get() == block_to_rewrite) { auto block_n = CopyOnWrite(op->block.get()); - Block block = op->block; + SBlock block = op->block; ffi::Array new_iter_vars; ffi::Array new_iter_values; for (int idx : order_) { @@ -67,12 +67,12 @@ class BlockIterVarRewriter : public StmtMutator { new_iter_values.push_back(op->iter_values[idx]); } block_n->iter_vars = new_iter_vars; - Block new_block(block_n); + SBlock new_block(block_n); block_map.Set(block, new_block); auto block_realize_n = CopyOnWrite(op); block_realize_n->block = new_block; block_realize_n->iter_values = new_iter_values; - return BlockRealize(block_realize_n); + return SBlockRealize(block_realize_n); } else { return StmtMutator::VisitStmt_(op); } @@ -81,7 +81,7 @@ class BlockIterVarRewriter : public StmtMutator { void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, const ffi::Array& new_order) { - const BlockNode* block_n = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block_n = TVM_SREF_TO_SBLOCK(block_sref); std::vector new_order_vec; for (const Integer& x : new_order) { new_order_vec.push_back(x->value); @@ -95,25 +95,25 @@ void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, return x >= 0 && x < static_cast(num_block_itervars); }); if (!is_full || !is_unique || !is_within_boundary) { - throw InvalidReorderIndex(self->mod, ffi::GetRef(block_n), new_order); + throw InvalidReorderIndex(self->mod, ffi::GetRef(block_n), new_order); } // find parent block - const BlockNode* parent_block_n = nullptr; + const SBlockNode* parent_block_n = nullptr; const StmtSRefNode* p = block_sref.get()->parent; while (p != nullptr) { - if (p->stmt->IsInstance()) { - parent_block_n = TVM_SREF_TO_BLOCK(ffi::GetRef(p)); + if (p->stmt->IsInstance()) { + parent_block_n = TVM_SREF_TO_SBLOCK(ffi::GetRef(p)); break; } p = p->parent; } const StmtSRef parent_block_sref = ffi::GetRef(p); - const Block& parent_block = ffi::GetRef(parent_block_n); + const SBlock& parent_block = ffi::GetRef(parent_block_n); // rewrite block and blockrealize BlockIterVarRewriter rewriter(block_n, std::move(new_order_vec)); - Block new_parent_block = Downcast(rewriter(parent_block)); + SBlock new_parent_block = Downcast(rewriter(parent_block)); rewriter.block_map.Set(parent_block, new_parent_block); self->Replace(parent_block_sref, new_parent_block, rewriter.block_map); } @@ -127,7 +127,7 @@ struct ReorderBlockIterVarTraits : public UnpackedInstTraits new_order) { + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block, ffi::Array new_order) { sch->ReorderBlockIterVar(block, new_order); } diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index ff030bbef7a2..2b463207cf41 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -34,10 +34,10 @@ struct RollingBufferInfo { std::vector axis_overlaps; std::vector> axis_iter_vars; /*! \brief The map used for ScheduleStateNode::Replace. */ - ffi::Map block_reuse; + ffi::Map block_reuse; }; -BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, +BufferRegion GetRelaxedBufferRegion(const SBlockRealize& realize, const BufferRegion& buffer_region, const ffi::Map& dom_map) { ffi::Array relaxed_intsets = arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); @@ -52,7 +52,7 @@ BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferReg class RollingBufferDependencyError : public ScheduleError { public: - explicit RollingBufferDependencyError(IRModule mod, Block block) + explicit RollingBufferDependencyError(IRModule mod, SBlock block) : mod_(mod), block_(std::move(block)) {} ffi::String FastErrorString() const final { @@ -75,29 +75,29 @@ class RollingBufferDependencyError : public ScheduleError { */ static void Check(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - BlockScope scope = self->GetBlockScope(scope_root_sref); + SBlockScope scope = self->GetSBlockScope(scope_root_sref); for (const Dependency& producers : scope->GetDepsByDst(block_sref)) { if (!(producers->kind == DepKind::kRAW)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); } } for (const Dependency& consumers : scope->GetDepsBySrc(block_sref)) { if (!(consumers->kind == DepKind::kRAW)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); } } } private: IRModule mod_; - Block block_; + SBlock block_; }; class RollingBufferMatchError : public ScheduleError { public: - RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + RollingBufferMatchError(IRModule mod, SBlock block, BufferRegion buffer_region) : mod_(mod), block_(block), buffer_region_(buffer_region) {} ffi::String FastErrorString() const final { return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" @@ -117,13 +117,13 @@ class RollingBufferMatchError : public ScheduleError { private: IRModule mod_; - Block block_; + SBlock block_; BufferRegion buffer_region_; }; class RollingBufferInsertionError : public ScheduleError { public: - RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + RollingBufferInsertionError(IRModule mod, Buffer buffer, SBlock block) : mod_(mod), buffer_(std::move(buffer)), block_(block) {} ffi::String FastErrorString() const final { return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " @@ -143,7 +143,7 @@ class RollingBufferInsertionError : public ScheduleError { private: IRModule mod_; Buffer buffer_; - Block block_; + SBlock block_; }; class RollingBufferInfoCollector { @@ -153,8 +153,8 @@ class RollingBufferInfoCollector { const BufferRegion& buffer_region) { RollingBufferInfoCollector collector; if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferMatchError(mod, ffi::GetRef(block), buffer_region); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + throw RollingBufferMatchError(mod, ffi::GetRef(block), buffer_region); } return collector.info_; } @@ -291,10 +291,10 @@ class RollingBufferRewriter : public StmtExprMutator { *indices = std::move(new_indices); } - Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = ffi::GetRef(block); - Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); - BlockNode* n = stmt.CopyOnWrite(); + Stmt VisitStmt_(const SBlockNode* block) final { + SBlock old_stmt = ffi::GetRef(block); + SBlock stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + SBlockNode* n = stmt.CopyOnWrite(); if (block == scope_sref_->stmt) { ffi::Array new_alloc_buffers; for (const Buffer& buffer : stmt->alloc_buffers) { @@ -324,7 +324,7 @@ class RollingBufferRewriter : public StmtExprMutator { } } ffi::Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; - auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + auto infered_access_regions = GetSBlockReadWriteRegion(stmt, buffer_data_to_buffer); n->iter_vars = std::move(new_iter_vars); RewriteAccessRegion(&n->reads, infered_access_regions[0]); @@ -334,8 +334,8 @@ class RollingBufferRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { - BlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { + SBlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); // Append block predicate to avoid recomputing elements. if (rewrite_block_predicate_) { rewrite_block_predicate_ = false; @@ -353,7 +353,7 @@ class RollingBufferRewriter : public StmtExprMutator { And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); } } - BlockRealizeNode* n = stmt.CopyOnWrite(); + SBlockRealizeNode* n = stmt.CopyOnWrite(); n->predicate = condition; } return stmt; @@ -401,8 +401,8 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf * - Append block predicate to avoid recomputing overlapping elements. */ ffi::Map dom_map; - const BlockRealize& realize = GetBlockRealize(self, block_sref); - const Block& block = realize->block; + const SBlockRealize& realize = GetSBlockRealize(self, block_sref); + const SBlock& block = realize->block; // Step 1. Checking index, getting the target buffer region and the parent scope. const BufferRegion& buffer_region = @@ -443,7 +443,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf self->Replace(scope_root_sref, new_scope_root, info.block_reuse); // Step 7. Regenerate block info from the root block, because `region_cover` for the target block // and `stage_pipeline` for the root block are no longer satisfied after rolling buffer injection. - self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); + self->UpdateScopeSBlockInfo(tir::GetSBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); } struct RollingBufferTraits : public UnpackedInstTraits { @@ -455,7 +455,7 @@ struct RollingBufferTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index) { + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block, Integer write_buffer_index) { return sch->RollingBuffer(block, write_buffer_index.IntValue()); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index a8042e0c37eb..de09aa03dc0f 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -556,8 +556,8 @@ struct SampleComputeLocationTraits : public UnpackedInstTraits decision) { return sch->SampleComputeLocation(block_rv, decision); } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 35b221561978..636aa4dfc54b 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -25,13 +25,13 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); - BlockRVNode::RegisterReflection(); + SBlockRVNode::RegisterReflection(); LoopRVNode::RegisterReflection(); } /**************** Constructor ****************/ -BlockRV::BlockRV() { this->data_ = ffi::make_object(); } +SBlockRV::SBlockRV() { this->data_ = ffi::make_object(); } LoopRV::LoopRV() { this->data_ = ffi::make_object(); } @@ -66,7 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.schedule.BlockRV", []() { return BlockRV(); }) + .def("tir.schedule.SBlockRV", []() { return SBlockRV(); }) .def("tir.schedule.LoopRV", []() { return LoopRV(); }) .def("tir.schedule.ConcreteSchedule", [](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, @@ -94,7 +94,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto loop_rv = obj.as()) { return self->Get(loop_rv.value()); } - if (auto block_rv = obj.as()) { + if (auto block_rv = obj.as()) { return self->Get(block_rv.value()); } if (auto expr_rv = obj.as()) { @@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto loop_rv = obj.as()) { return self->GetSRef(loop_rv.value()); } - if (auto block_rv = obj.as()) { + if (auto block_rv = obj.as()) { return self->GetSRef(block_rv.value()); } if (auto stmt = obj.as()) { @@ -122,7 +122,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto loop_rv = obj.as()) { return self->RemoveRV(loop_rv.value()); } - if (auto block_rv = obj.as()) { + if (auto block_rv = obj.as()) { return self->RemoveRV(block_rv.value()); } if (auto expr_rv = obj.as()) { @@ -148,11 +148,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("tir.schedule.ScheduleGetBlock", &ScheduleNode::GetBlock) + .def_method("tir.schedule.ScheduleGetSBlock", &ScheduleNode::GetSBlock) .def_method("tir.schedule.ScheduleGetLoops", &ScheduleNode::GetLoops) .def("tir.schedule.ScheduleGetChildBlocks", [](Schedule self, ObjectRef rv) { - if (auto block_rv = rv.as()) { + if (auto block_rv = rv.as()) { return self->GetChildBlocks(block_rv.value()); } if (auto loop_rv = rv.as()) { @@ -179,7 +179,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("tir.schedule.ScheduleAddUnitLoop", [](Schedule self, ObjectRef rv) -> LoopRV { if (auto loop_rv = rv.as()) { return self->AddUnitLoop(loop_rv.value()); - } else if (auto block_rv = rv.as()) { + } else if (auto block_rv = rv.as()) { return self->AddUnitLoop(block_rv.value()); } else { LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " @@ -208,7 +208,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("tir.schedule.ScheduleCacheInplace", &ScheduleNode::CacheInplace) .def_method("tir.schedule.ScheduleCacheIndex", &ScheduleNode::CacheIndex) .def("tir.schedule.ScheduleReIndex", - [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { + [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type) { return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); }); @@ -238,7 +238,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("tir.schedule.ScheduleDecomposeReduction", &ScheduleNode::DecomposeReduction) .def_method("tir.schedule.ScheduleRFactor", &ScheduleNode::RFactor); } -/******** (FFI) Block annotation ********/ +/******** (FFI) SBlock annotation ********/ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() @@ -254,14 +254,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](Schedule self, ObjectRef target, bool preserve_unit_iters) { if (auto loop_rv = target.as()) { return self->Blockize(loop_rv.value(), preserve_unit_iters); - } else if (auto blocks = target.as>()) { + } else if (auto blocks = target.as>()) { return self->Blockize(blocks.value(), preserve_unit_iters); } LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); }) .def("tir.schedule.ScheduleTensorize", [](Schedule self, ObjectRef rv, ffi::String intrin, bool preserve_unit_iters) { - if (auto block_rv = rv.as()) { + if (auto block_rv = rv.as()) { self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); } else if (auto loop_rv = rv.as()) { self->Tensorize(loop_rv.value(), intrin, preserve_unit_iters); @@ -278,7 +278,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("tir.schedule.ScheduleAnnotate", [](Schedule self, ObjectRef rv, const ffi::String& ann_key, const Any& ann_val) { - if (auto block_rv = rv.as()) { + if (auto block_rv = rv.as()) { return self->Annotate(block_rv.value(), ann_key, ann_val); } if (auto loop_rv = rv.as()) { @@ -290,7 +290,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, const ffi::String& ann_key) { - if (auto block_rv = rv.as()) { + if (auto block_rv = rv.as()) { return self->Unannotate(block_rv.value(), ann_key); } if (auto loop_rv = rv.as()) { @@ -307,7 +307,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleTransformLayout", - [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, + [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform) { return self->TransformLayout(block_rv, buffer_index, @@ -316,7 +316,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def_method("tir.schedule.ScheduleTransformBlockLayout", &ScheduleNode::TransformBlockLayout) .def("tir.schedule.ScheduleSetAxisSeparator", - [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, + [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type, const ffi::Array& axis_separators) { return self->SetAxisSeparator(block_rv, buffer_index, static_cast(buffer_index_type), @@ -348,7 +348,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.schedule.ScheduleAnnotateBufferAccess", - [](Schedule self, const BlockRV& block_rv, int buffer_index, + [](Schedule self, const SBlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map) { return self->AnnotateBufferAccess( block_rv, buffer_index, diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index c299f52fde55..47845be9a516 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -137,7 +137,7 @@ bool ProducerCoversConsumer(const ffi::Array& buffer_shape, * \param new_stmt The statement that replaces the statement inside the sref */ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) { - ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); + ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); const StmtNode* old_stmt = sref->stmt; ICHECK_NE(new_stmt, old_stmt); self->stmt2ref[new_stmt] = ffi::GetRef(sref); @@ -146,16 +146,16 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new } /**************** Creation ****************/ -/*! \brief A helper class to update BlockInfo for a ScheduleStateNode */ -class BlockInfoCollector : private StmtVisitor { +/*! \brief A helper class to update SBlockInfo for a ScheduleStateNode */ +class SBlockInfoCollector : private StmtVisitor { public: static void Collect(ScheduleStateNode* self, const Stmt& stmt) { - BlockInfoCollector collector(self); + SBlockInfoCollector collector(self); collector.VisitStmt(stmt); } private: - explicit BlockInfoCollector(ScheduleStateNode* self) + explicit SBlockInfoCollector(ScheduleStateNode* self) : self_(self), srefs_{}, block2realize_{}, block_frames_{} { block_frames_.emplace_back(); } @@ -174,16 +174,16 @@ class BlockInfoCollector : private StmtVisitor { return sref; } - void MakeBlockInfo(StmtSRef scope_root) { + void MakeSBlockInfo(StmtSRef scope_root) { bool is_root_block = srefs_.empty(); - // Calculate `BlockInfo::scope` + // Calculate `SBlockInfo::scope` ffi::Array child_block_srefs = std::move(block_frames_.back()); - BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); + SBlockInfo& info = self_->block_info[scope_root] = SBlockInfo(SBlockScope(child_block_srefs)); // Set `affine_binding` if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, // then we set the affine binding flag as true only if the block has no block vars - const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(scope_root); if (block->iter_vars.empty()) info.affine_binding = true; } else { info.affine_binding = @@ -197,7 +197,7 @@ class BlockInfoCollector : private StmtVisitor { info.stage_pipeline = CheckRegionCoverAndStagePipeline(info, scope_root, child_block_srefs); } - bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& scope_root, + bool CheckRegionCoverAndStagePipeline(const SBlockInfo& info, const StmtSRef& scope_root, const ffi::Array& child_block_srefs) { const StmtSRefNode* limit = scope_root->parent; bool stage_pipeline = true; @@ -207,7 +207,7 @@ class BlockInfoCollector : private StmtVisitor { block_reads_unbound.reserve(child_block_srefs.size()); block_writes_unbound.reserve(child_block_srefs.size()); for (const StmtSRef& block_sref : child_block_srefs) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); ffi::Map binding = GetBindings(block2realize_.at(block)); // Step 1.1. Unbind read regions ffi::Array reads; @@ -228,8 +228,8 @@ class BlockInfoCollector : private StmtVisitor { for (const auto& kv : info.scope->dst2deps) { const StmtSRef& consumer_block_sref = kv.first; const ffi::Array& deps = kv.second; - const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); - const BlockRealize& consumer_realize = block2realize_.at(consumer_block); + const SBlockNode* consumer_block = TVM_SREF_TO_SBLOCK(consumer_block_sref); + const SBlockRealize& consumer_realize = block2realize_.at(consumer_block); bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; // Step 2.1. Extract the path to the scope root std::unordered_map> lca_loc; @@ -277,7 +277,7 @@ class BlockInfoCollector : private StmtVisitor { } // Step 2.3.2. Find all the regions written by each producer for (const StmtSRefNode* producer_block_sref : producer_block_srefs) { - const BlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt); + const SBlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt); StmtSRef parent_sref = ffi::GetRef(producer_block_sref->parent); for (const BufferRegion& region : block_writes_unbound.at(producer_block_sref)) { const BufferNode* buffer = region->buffer.get(); @@ -336,16 +336,16 @@ class BlockInfoCollector : private StmtVisitor { PopSRef(); } - void VisitStmt_(const BlockRealizeNode* realize) final { + void VisitStmt_(const SBlockRealizeNode* realize) final { block_frames_.emplace_back(); - const BlockNode* block = realize->block.get(); - block2realize_.emplace(block, ffi::GetRef(realize)); + const SBlockNode* block = realize->block.get(); + block2realize_.emplace(block, ffi::GetRef(realize)); // Recursive visit PushSRef(block); VisitStmt(block->body); // `block->init` is not visited StmtSRef sref = PopSRef(); - // Create BlockInfo for the block - MakeBlockInfo(sref); + // Create SBlockInfo for the block + MakeSBlockInfo(sref); // Update parent scope block_frames_.pop_back(); block_frames_.back().push_back(sref); @@ -362,7 +362,7 @@ class BlockInfoCollector : private StmtVisitor { /*! \brief The stack frame used to indicate the current scope */ std::vector srefs_; /*! \brief The BlockRealize corresponding to blocks */ - std::unordered_map block2realize_; + std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ std::vector> block_frames_; /*! \brief The auxiliary analyzer */ @@ -390,7 +390,7 @@ ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) { if (auto opt = base_func.as()) { auto func = opt.value(); VerifyWellFormed(func); - BlockInfoCollector::Collect(self, func->body); + SBlockInfoCollector::Collect(self, func->body); } } data_ = std::move(n); @@ -445,11 +445,11 @@ struct ReuseInfo { */ std::unordered_set loop_sref_possible_reuse; /*! - * \brief Kind 2.2. Block sref reuse. - * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`, + * \brief Kind 2.2. SBlock sref reuse. + * Maps an old SBlock in `src_stmt` to a new block in `tgt_stmt`, * indicating the sref to the old block should be reused in the sref to the new block. */ - std::unordered_map block_sref_reuse; + std::unordered_map block_sref_reuse; }; /*! @@ -490,7 +490,7 @@ class ReuseCollector : public StmtVisitor { } } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (self_->stmt2ref.count(op)) { intact_.push_back(op); } else { @@ -523,7 +523,7 @@ class SRefTreePruner : public StmtVisitor { * \param src_stmt The `src_stmt` where stale srefs to be removed * \return Mapping from the reuse elements to reused srefs, more specifically: * 1) Loop reuse: maps a loop var to the reused sref - * 2) Block reuse: maps a block stmt to the reused sref, + * 2) SBlock reuse: maps a block stmt to the reused sref, * where the block comes from the subtree of `tgt_stmt` * 3) Intact reuse: not returned */ @@ -562,19 +562,19 @@ class SRefTreePruner : public StmtVisitor { VisitStmt(op->body); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (reuse_info_.intact.count(op)) { return; } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) << "IndexError: Cannot find corresponding StmtSRef for the block:\n" - << ffi::GetRef(op); + << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse const auto& sref_reuse = reuse_info_.block_sref_reuse; if (auto reuse_it = sref_reuse.find(op); reuse_it != sref_reuse.end()) { - const BlockNode* to_reuse = reuse_it->second; + const SBlockNode* to_reuse = reuse_it->second; // sref can be reused reused_srefs_.emplace(to_reuse, std::move(sref)); } else { @@ -650,7 +650,7 @@ class SRefUpdater : public StmtVisitor { ancestors_.pop_back(); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { StmtSRef& sref = self_->stmt2ref[op]; // Detect intact if (sref.defined()) { @@ -676,7 +676,7 @@ class SRefUpdater : public StmtVisitor { VisitStmt(op->body); ancestors_.pop_back(); // Additionally, need to update the scope because the block is changed - UpdateBlockInfo(sref); + UpdateSBlockInfo(sref); } void VisitStmt_(const SeqStmtNode* seq_stmt) final { @@ -684,16 +684,16 @@ class SRefUpdater : public StmtVisitor { SetSeqIndexInChildren(self_->stmt2ref, seq_stmt); } - void UpdateBlockInfo(const StmtSRef& block_sref) { - using TIter = std::unordered_map::iterator; + void UpdateSBlockInfo(const StmtSRef& block_sref) { + using TIter = std::unordered_map::iterator; // The caller is responsible for correcting the flags - BlockInfo new_info((BlockScope(GetChildBlockSRefOnSRefTree(self_, block_sref)))); + SBlockInfo new_info((SBlockScope(GetChildBlockSRefOnSRefTree(self_, block_sref)))); std::pair insert_result = self_->block_info.emplace(block_sref, new_info); bool inserted = insert_result.second; - BlockInfo& info = insert_result.first->second; + SBlockInfo& info = insert_result.first->second; if (inserted) { // Insertion has happened, update the flags accordingly - BlockInfo& info = insert_result.first->second; + SBlockInfo& info = insert_result.first->second; info.affine_binding = false; info.region_cover = false; info.stage_pipeline = false; @@ -723,11 +723,11 @@ class ChildReplacer : private StmtMutator { static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt, const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) { // Check the invariant - ICHECK(child_src_stmt->IsInstance() || // + ICHECK(child_src_stmt->IsInstance() || // child_src_stmt->IsInstance()); - ICHECK(child_tgt_stmt->IsInstance() || // - child_tgt_stmt->IsInstance() || // - child_tgt_stmt->IsInstance()); + ICHECK(child_tgt_stmt->IsInstance() || // + child_tgt_stmt->IsInstance() || // + child_tgt_stmt->IsInstance()); ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index); replacer.allow_copy_on_write_ = allow_copy_on_write; return replacer.CopyOnWriteAndVisit(parent_stmt); @@ -747,7 +747,7 @@ class ChildReplacer : private StmtMutator { } // Skipping sibling blocks and loops other than `src_stmt_` - Stmt VisitStmt_(const BlockNode* op) final { return ffi::GetRef(op); } + Stmt VisitStmt_(const SBlockNode* op) final { return ffi::GetRef(op); } Stmt VisitStmt_(const ForNode* op) final { return ffi::GetRef(op); } Stmt VisitStmt_(const SeqStmtNode* op) final { @@ -765,13 +765,13 @@ class ChildReplacer : private StmtMutator { if (stmt.get() == src_stmt) { // Case 1. src_stmt is For, stmt is For new_stmt = tgt_stmt_; - } else if (const auto* realize = stmt.as()) { + } else if (const auto* realize = stmt.as()) { // Case 2. stmt is BlockRealize, src_stmt is Block if (realize->block.get() == src_stmt) { - const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode); - ObjectPtr new_realize = ffi::make_object(*realize); - new_realize->block = ffi::GetRef(tgt_block); - new_stmt = BlockRealize(std::move(new_realize)); + const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, SBlockNode); + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = ffi::GetRef(tgt_block); + new_stmt = SBlockRealize(std::move(new_realize)); } } // Move new_stmt to position i @@ -789,11 +789,11 @@ class ChildReplacer : private StmtMutator { // where `body` means the body of either a block or a loop // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt` // and replace it with `child_tgt_stmt` - if (parent_stmt->IsInstance()) { - auto* block = const_cast(static_cast(parent_stmt)); - ObjectPtr new_block = CopyOnWrite(block); + if (parent_stmt->IsInstance()) { + auto* block = const_cast(static_cast(parent_stmt)); + ObjectPtr new_block = CopyOnWrite(block); new_block->body = this->VisitStmt(new_block->body); - return Block(std::move(new_block)); + return SBlock(std::move(new_block)); } else if (parent_stmt->IsInstance()) { auto* loop = const_cast(static_cast(parent_stmt)); ObjectPtr new_loop = CopyOnWrite(loop); @@ -816,13 +816,13 @@ class ChildReplacer : private StmtMutator { }; void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, - const ffi::Map& _block_sref_reuse) { + const ffi::Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; bool input_correct = (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || - (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || - (src_stmt->IsInstance() && tgt_stmt->IsInstance()); + (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || + (src_stmt->IsInstance() && tgt_stmt->IsInstance()); if (!input_correct) { LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey() << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n" @@ -844,7 +844,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ // 3) all `stmt`s are correct, except for the root { // Step 0. Setup block_sref_reuse - std::unordered_map block_sref_reuse; + std::unordered_map block_sref_reuse; block_sref_reuse.reserve(_block_sref_reuse.size() + 1); for (const auto& kv : _block_sref_reuse) { block_sref_reuse.emplace(kv.first.get(), kv.second.get()); @@ -903,7 +903,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are correct // 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet - // 5) `tgt_stmt` is of type Loop, Block or BlockRealize + // 5) `tgt_stmt` is of type Loop, SBlock or BlockRealize // // During step `i`: // 1) Create `parent_stmt` that corresponds to `child_sref->parent` @@ -961,12 +961,12 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ // If `g_func` was unique, after the 3 lines above: // `ref_new_func` points to the same unique function that `g_func` points to // Update the body of the function the sref belongs to Assign - const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode); + const auto* realize = TVM_TYPE_AS(g_func->body, SBlockRealizeNode); // Make `child_tgt_stmt` the root block - const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode); - ObjectPtr new_realize = ffi::make_object(*realize); - new_realize->block = ffi::GetRef(child_block); - new_func->body = BlockRealize(std::move(new_realize)); + const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, SBlockNode); + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = ffi::GetRef(child_block); + new_func->body = SBlockRealize(std::move(new_realize)); // Finally, move the `ref_new_func` back and update `this->mod` new_map->at(g_var) = std::move(ref_new_func); this->mod = ffi::GetRef(new_mod); @@ -992,23 +992,23 @@ void ScheduleStateNode::DebugVerify() const { } } -/**************** BlockInfo-related ****************/ +/**************** SBlockInfo-related ****************/ -BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { - TVM_SREF_TO_BLOCK(block_sref); +SBlockInfo ScheduleStateNode::GetSBlockInfo(const StmtSRef& block_sref) const { + TVM_SREF_TO_SBLOCK(block_sref); auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) - << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" + << "IndexError: Cannot find the corresponding SBlockScope to the block sref:\n" << ffi::GetRef(block_sref->stmt); return it->second; } -void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) { - BlockInfoCollector::Collect(this, stmt); +void ScheduleStateNode::UpdateScopeSBlockInfo(const Stmt& stmt) { + SBlockInfoCollector::Collect(this, stmt); } TVM_DLL ffi::Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { - const BlockInfo& info = self->GetBlockInfo(block_sref); + const SBlockInfo& info = self->GetSBlockInfo(block_sref); return {Bool(info.affine_binding), // Bool(info.region_cover), // Bool(info.stage_pipeline)}; @@ -1023,7 +1023,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { return ScheduleState(mod, debug_mask, enable_check); }) - .def_method("tir.schedule.ScheduleStateGetBlockScope", &ScheduleStateNode::GetBlockScope) + .def_method("tir.schedule.ScheduleStateGetSBlockScope", &ScheduleStateNode::GetSBlockScope) .def_method("tir.schedule.ScheduleStateReplace", &ScheduleStateNode::Replace) .def("tir.schedule.ScheduleStateGetSRef", [](ScheduleState self, Stmt stmt) -> ffi::Optional { diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 371aa0cb092d..abd3cfc522f8 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -76,9 +76,9 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, result.push_back(input); } else if (auto expr = input.as()) { result.push_back(expr.value()); - } else if (input.as() || // RV: block - input.as() || // RV: loop - input.as()) { // RV: var + } else if (input.as() || // RV: block + input.as() || // RV: loop + input.as()) { // RV: var auto it = rv_map.find(input.as()); ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; result.push_back(ffi::GetRef(it->second)); @@ -116,12 +116,12 @@ ffi::Array TranslateInputRVs( } else if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type and not string results.push_back(input); - } else if (input.as() || // RV: block - input.as() || // RV: loop - input.as()) { // RV: var + } else if (input.as() || // RV: block + input.as() || // RV: loop + input.as()) { // RV: var auto it = rv_names.find(input.cast()); if (it != rv_names.end()) { - // Case 1. BlockRV, LoopRV, VarRV + // Case 1. SBlockRV, LoopRV, VarRV results.push_back(it->second); } else { LOG(FATAL) << "IndexError: Random variable is not defined " << input; @@ -209,7 +209,7 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, results.push_back(ffi::String(std::string(name + 1, size - 2))); continue; } - // Case 0 & 1. None, BlockRV, LoopRV, VarRV + // Case 0 & 1. None, SBlockRV, LoopRV, VarRV auto it = named_rvs.find(name); CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; results.push_back(it->second); @@ -244,7 +244,7 @@ ffi::Array TranslateAddOutputRVs( ffi::String result; if (output == nullptr) { result = "_"; - } else if (output.as()) { + } else if (output.as()) { result = "b" + std::to_string(i); } else if (output.as()) { result = "l" + std::to_string(i); @@ -502,7 +502,7 @@ Trace TraceNode::Simplified(bool remove_postproc) const { for (const Any& obj : inst->inputs) { if (obj == nullptr) { continue; - } else if (obj.as() || obj.as() || obj.as()) { + } else if (obj.as() || obj.as() || obj.as()) { used_rvs.insert(obj.as()); continue; } else if (auto prim_expr = obj.as()) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 72606f243d69..178dbbeec8ec 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -100,7 +100,7 @@ ffi::Array TracedScheduleNode::SamplePartitionedTile( return results; } -LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, +LoopRV TracedScheduleNode::SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision) { LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); @@ -116,21 +116,21 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV TracedScheduleNode::GetBlock(const ffi::String& name, - const ffi::Optional& func_name) { +SBlockRV TracedScheduleNode::GetSBlock(const ffi::String& name, + const ffi::Optional& func_name) { GlobalVar gv = NullValue(); if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); } else if (func_working_on_.defined()) { gv = this->func_working_on_.value(); } else { - LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " + LOG(FATAL) << "ValueError: `get_sblock` does not know which function to be working on. Please " "specify the function name explicitly, or call `work_on` to specify the function " - "before using `get_block`."; + "before using `get_sblock`."; } - BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name); + SBlockRV result = ConcreteScheduleNode::GetSBlock(name, func_name); - static const InstructionKind& kind = InstructionKind::Get("GetBlock"); + static const InstructionKind& kind = InstructionKind::Get("GetSBlock"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{}, /*attrs=*/{name, gv->name_hint}, @@ -138,7 +138,7 @@ BlockRV TracedScheduleNode::GetBlock(const ffi::String& name, return result; } -ffi::Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { +ffi::Array TracedScheduleNode::GetLoops(const SBlockRV& block_rv) { ffi::Array results = ConcreteScheduleNode::GetLoops(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetLoops"); @@ -149,8 +149,8 @@ ffi::Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { return results; } -ffi::Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { - ffi::Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); +ffi::Array TracedScheduleNode::GetChildBlocks(const SBlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -160,8 +160,8 @@ ffi::Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) return results; } -ffi::Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { - ffi::Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); +ffi::Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + ffi::Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -171,8 +171,8 @@ ffi::Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { return results; } -ffi::Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { - ffi::Array results = ConcreteScheduleNode::GetProducers(block_rv); +ffi::Array TracedScheduleNode::GetProducers(const SBlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetProducers(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetProducers"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -182,8 +182,8 @@ ffi::Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { return results; } -ffi::Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { - ffi::Array results = ConcreteScheduleNode::GetConsumers(block_rv); +ffi::Array TracedScheduleNode::GetConsumers(const SBlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetConsumers(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetConsumers"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -193,8 +193,8 @@ ffi::Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { return results; } -ffi::Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { - ffi::Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); +ffi::Array TracedScheduleNode::GetOutputBlocks(const SBlockRV& scope_block_rv) { + ffi::Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -280,7 +280,7 @@ void TracedScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { /*outputs=*/{})); } -void TracedScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, +void TracedScheduleNode::ReorderBlockIterVar(const SBlockRV& block_rv, const ffi::Array new_order) { ConcreteScheduleNode::ReorderBlockIterVar(block_rv, new_order); static const InstructionKind& kind = InstructionKind::Get("ReorderBlockIterVar"); @@ -289,7 +289,7 @@ void TracedScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, /*outputs=*/{})); } -LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) { +LoopRV TracedScheduleNode::AddUnitLoop(const SBlockRV& block_rv) { LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv); static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop"); @@ -354,10 +354,10 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { } /******** Schedule: Insert cache stages ********/ -BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks) { - BlockRV result = +SBlockRV TracedScheduleNode::CacheRead(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { + SBlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks); static const InstructionKind& kind = InstructionKind::Get("CacheRead"); @@ -368,11 +368,11 @@ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_i return result; } -BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks) { - BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope, - consumer_blocks); +SBlockRV TracedScheduleNode::CacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { + SBlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope, + consumer_blocks); static const InstructionKind& kind = InstructionKind::Get("CacheWrite"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -382,10 +382,10 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } -BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const IndexMap& index_map) { - BlockRV result = +SBlockRV TracedScheduleNode::ReindexCacheRead(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const IndexMap& index_map) { + SBlockRV result = ConcreteScheduleNode::ReindexCacheRead(block_rv, read_buffer_index, storage_scope, index_map); static const InstructionKind& kind = InstructionKind::Get("ReindexCacheRead"); @@ -398,11 +398,11 @@ BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_b return result; } -BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const ffi::String& storage_scope, - const IndexMap& index_map) { - BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index, - storage_scope, index_map); +SBlockRV TracedScheduleNode::ReindexCacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const IndexMap& index_map) { + SBlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index, + storage_scope, index_map); static const InstructionKind& kind = InstructionKind::Get("ReindexCacheWrite"); trace_->Append( @@ -414,12 +414,13 @@ BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write return result; } -ffi::Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) { - ffi::Array result = +ffi::Array TracedScheduleNode::CacheInplace(const SBlockRV& block_rv, + int read_buffer_index, + const ffi::String& storage_scope) { + ffi::Array result = ConcreteScheduleNode::CacheInplace(block_rv, read_buffer_index, storage_scope); ffi::Array results; - for (const BlockRV& r : result) { + for (const SBlockRV& r : result) { results.push_back(r); } static const InstructionKind& kind = InstructionKind::Get("CacheInplace"); @@ -430,13 +431,13 @@ ffi::Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, in return result; } -ffi::Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, - const ffi::String& storage_scope, - int cse_thresh) { - ffi::Array result = +ffi::Array TracedScheduleNode::CacheIndex(const SBlockRV& block_rv, + const ffi::String& storage_scope, + int cse_thresh) { + ffi::Array result = ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh); ffi::Array outputs; - for (const BlockRV& r : result) { + for (const SBlockRV& r : result) { outputs.push_back(r); } static const InstructionKind& kind = InstructionKind::Get("CacheIndex"); @@ -447,9 +448,9 @@ ffi::Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, return result; } -BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) { - BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); +SBlockRV TracedScheduleNode::ReIndex(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) { + SBlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); static const InstructionKind& kind = InstructionKind::Get("ReIndex"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -461,9 +462,9 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, /******** Schedule: Data movement ********/ -BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int read_buffer_index, const ffi::String& storage_scope) { - BlockRV result = +SBlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const SBlockRV& block_rv, + int read_buffer_index, const ffi::String& storage_scope) { + SBlockRV result = ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("ReadAt"); @@ -474,9 +475,9 @@ BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_r return result; } -BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int write_buffer_index, const ffi::String& storage_scope) { - BlockRV result = +SBlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const SBlockRV& block_rv, + int write_buffer_index, const ffi::String& storage_scope) { + SBlockRV result = ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("WriteAt"); @@ -489,7 +490,7 @@ BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_ /******** Schedule: Compute location ********/ -void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, +void TracedScheduleNode::ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index) { ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index); @@ -501,7 +502,7 @@ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_r /*outputs=*/{})); } -void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, +void TracedScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index) { ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index); @@ -512,7 +513,7 @@ void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& /*outputs=*/{})); } -void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) { +void TracedScheduleNode::ComputeInline(const SBlockRV& block_rv) { ConcreteScheduleNode::ComputeInline(block_rv); static const InstructionKind& kind = InstructionKind::Get("ComputeInline"); @@ -522,7 +523,7 @@ void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) { /*outputs=*/{})); } -void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { +void TracedScheduleNode::ReverseComputeInline(const SBlockRV& block_rv) { ConcreteScheduleNode::ReverseComputeInline(block_rv); static const InstructionKind& kind = InstructionKind::Get("ReverseComputeInline"); @@ -532,8 +533,8 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { /*outputs=*/{})); } -void TracedScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv, - const BlockRV& epilogue_block_rv) { +void TracedScheduleNode::FuseReductionEpilogue(const SBlockRV& reduction_block_rv, + const SBlockRV& epilogue_block_rv) { ConcreteScheduleNode::FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv); static const InstructionKind& kind = InstructionKind::Get("FuseReductionEpilogue"); @@ -545,8 +546,8 @@ void TracedScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv /******** Schedule: Reduction ********/ -BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { - BlockRV result = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv); +SBlockRV TracedScheduleNode::DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) { + SBlockRV result = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv); static const InstructionKind& kind = InstructionKind::Get("DecomposeReduction"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, @@ -555,8 +556,8 @@ BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const Lo return result; } -BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { - BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis); +SBlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { + SBlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis); static const InstructionKind& kind = InstructionKind::Get("RFactor"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{loop_rv}, @@ -565,9 +566,9 @@ BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { return result; } -/******** Schedule: Block annotation ********/ +/******** Schedule: SBlock annotation ********/ -void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, +void TracedScheduleNode::StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor, int offset) { ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset); static const InstructionKind& kind = InstructionKind::Get("StorageAlign"); @@ -578,7 +579,7 @@ void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, /*outputs=*/{})); } -void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, +void TracedScheduleNode::SetScope(const SBlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) { ConcreteScheduleNode::SetScope(block_rv, buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("SetScope"); @@ -589,7 +590,7 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /*outputs=*/{})); } -void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, +void TracedScheduleNode::UnsafeSetDType(const SBlockRV& block_rv, int buffer_index, const ffi::String& dtype) { ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype); static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType"); @@ -602,8 +603,8 @@ void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_inde /******** Schedule: Blockize & Tensorize ********/ -BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { - BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv, preserve_unit_iters); +SBlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { + SBlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Blockize"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, @@ -613,8 +614,9 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_i return new_block; } -BlockRV TracedScheduleNode::Blockize(const ffi::Array& blocks, bool preserve_unit_iters) { - BlockRV new_block = ConcreteScheduleNode::Blockize(blocks, preserve_unit_iters); +SBlockRV TracedScheduleNode::Blockize(const ffi::Array& blocks, + bool preserve_unit_iters) { + SBlockRV new_block = ConcreteScheduleNode::Blockize(blocks, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Blockize"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, @@ -635,7 +637,7 @@ void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& int /*outputs=*/{})); } -void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const ffi::String& intrin, +void TracedScheduleNode::Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); @@ -658,7 +660,7 @@ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const ffi::String& ann_ /*outputs=*/{})); } -void TracedScheduleNode::Annotate(const BlockRV& block_rv, const ffi::String& ann_key, +void TracedScheduleNode::Annotate(const SBlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) { ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val); static const InstructionKind& kind = InstructionKind::Get("Annotate"); @@ -677,7 +679,7 @@ void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& an /*outputs=*/{})); } -void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) { +void TracedScheduleNode::Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) { ConcreteScheduleNode::Unannotate(block_rv, ann_key); static const InstructionKind& kind = InstructionKind::Get("Unannotate"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -688,7 +690,7 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& /******** Schedule: Layout transformation ********/ -void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, +void TracedScheduleNode::TransformLayout(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, const ffi::Optional& pad_value, @@ -706,7 +708,7 @@ void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_ind /*outputs=*/{})); } -void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) { +void TracedScheduleNode::TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) { ConcreteScheduleNode::TransformBlockLayout(block_rv, index_map); static const InstructionKind& kind = InstructionKind::Get("TransformBlockLayout"); trace_->Append( @@ -716,7 +718,7 @@ void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const Ind /*outputs=*/{})); } -void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, +void TracedScheduleNode::SetAxisSeparator(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array& axis_separators) { ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type, @@ -730,8 +732,8 @@ void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_in } /******** Schedule: Padding ********/ -BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) { - BlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv); +SBlockRV TracedScheduleNode::DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) { + SBlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv); static const InstructionKind& kind = InstructionKind::Get("DecomposePadding"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, @@ -741,7 +743,7 @@ BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const Loop return new_block; } -void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) { +void TracedScheduleNode::PadEinsum(const SBlockRV& block_rv, const ffi::Array& padding) { ConcreteScheduleNode::PadEinsum(block_rv, padding); static const InstructionKind& kind = InstructionKind::Get("PadEinsum"); trace_->Append(/*inst=*/Instruction( @@ -753,7 +755,7 @@ void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::ArrayAppend(/*inst=*/Instruction( @@ -774,7 +776,7 @@ void TracedScheduleNode::EnterPostproc() { /*outputs=*/{})); } -void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, +void TracedScheduleNode::UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type, const ffi::Array& buf_index_array) { ConcreteScheduleNode::UnsafeHideBufferAccess(block_rv, buf_type, buf_index_array); @@ -786,7 +788,7 @@ void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, /*outputs=*/{})); } -void TracedScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, +void TracedScheduleNode::AnnotateBufferAccess(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { ConcreteScheduleNode::AnnotateBufferAccess(block_rv, buffer_index, buffer_index_type, index_map); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 8c7b16a47e8d..b3f08ed1f06a 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -52,16 +52,16 @@ class TracedScheduleNode : public ConcreteScheduleNode { ffi::Array SamplePartitionedTile( const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, ffi::Optional> decision = std::nullopt) final; - LoopRV SampleComputeLocation(const BlockRV& block_rv, + LoopRV SampleComputeLocation(const SBlockRV& block_rv, ffi::Optional decision = std::nullopt) final; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const ffi::String& name, const ffi::Optional& func_name) final; - ffi::Array GetLoops(const BlockRV& block_rv) final; - ffi::Array GetChildBlocks(const BlockRV& block_rv) final; - ffi::Array GetChildBlocks(const LoopRV& loop_rv) final; - ffi::Array GetProducers(const BlockRV& block_rv) final; - ffi::Array GetConsumers(const BlockRV& block_rv) final; - ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) final; + SBlockRV GetSBlock(const ffi::String& name, const ffi::Optional& func_name) final; + ffi::Array GetLoops(const SBlockRV& block_rv) final; + ffi::Array GetChildBlocks(const SBlockRV& block_rv) final; + ffi::Array GetChildBlocks(const LoopRV& loop_rv) final; + ffi::Array GetProducers(const SBlockRV& block_rv) final; + ffi::Array GetConsumers(const SBlockRV& block_rv) final; + ffi::Array GetOutputBlocks(const SBlockRV& scope_block_rv) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) final; LoopRV Merge(const ffi::Array& loop_rvs) final; @@ -72,8 +72,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { const ffi::Array>& factor_rvs, bool preserve_unit_iters) final; void Reorder(const ffi::Array& ordered_loop_rvs) final; - void ReorderBlockIterVar(const BlockRV& block_rv, const ffi::Array new_order) final; - LoopRV AddUnitLoop(const BlockRV& block_rv) final; + void ReorderBlockIterVar(const SBlockRV& block_rv, const ffi::Array new_order) final; + LoopRV AddUnitLoop(const SBlockRV& block_rv) final; LoopRV AddUnitLoop(const LoopRV& loop_rv) final; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) final; @@ -81,72 +81,73 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) final; void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ - BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, - const ffi::Array consumer_blocks = {}) final; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + SBlockRV CacheRead(const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope, - const ffi::Array consumer_blocks = {}) final; - BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope, const IndexMap& index_map) final; - BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const ffi::Array consumer_blocks = {}) final; + SBlockRV CacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) final; + SBlockRV ReindexCacheRead(const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope, const IndexMap& index_map) final; - ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) final; - BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) final; - ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, - int cse_thresh) final; + SBlockRV ReindexCacheWrite(const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, const IndexMap& index_map) final; + ffi::Array CacheInplace(const SBlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) final; + SBlockRV ReIndex(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) final; + ffi::Array CacheIndex(const SBlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) final; /******** Schedule: Data movement ********/ - BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const ffi::String& storage_scope) final; - BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + SBlockRV ReadAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope) final; + SBlockRV WriteAt(const LoopRV& loop_rv, const SBlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope) final; /******** Schedule: Compute location ********/ - void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + void ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) final; - void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, + void ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) final; - void ComputeInline(const BlockRV& block_rv) final; - void ReverseComputeInline(const BlockRV& block_rv) final; - void FuseReductionEpilogue(const BlockRV& reduction_block, const BlockRV& epilogue_block) final; + void ComputeInline(const SBlockRV& block_rv) final; + void ReverseComputeInline(const SBlockRV& block_rv) final; + void FuseReductionEpilogue(const SBlockRV& reduction_block, const SBlockRV& epilogue_block) final; /******** Schedule: Reduction ********/ - BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final; - BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; - /******** Schedule: Block annotation ********/ - void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + SBlockRV DecomposeReduction(const SBlockRV& block_rv, const LoopRV& loop_rv) final; + SBlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; + /******** Schedule: SBlock annotation ********/ + void StorageAlign(const SBlockRV& block_rv, int buffer_index, int axis, int factor, int offset) final; - void SetScope(const BlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) final; - void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const ffi::String& dtype) final; + void SetScope(const SBlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) final; + void UnsafeSetDType(const SBlockRV& block_rv, int buffer_index, const ffi::String& dtype) final; /******** Schedule: Blockize & Tensorize ********/ - BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; - BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) final; - void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + SBlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; + SBlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) final; + void Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) final; void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) override; void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) override; - void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; - void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) override; + void Annotate(const SBlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const SBlockRV& block_rv, const ffi::String& ann_key) override; /******** Schedule: Layout transformation ********/ - void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const ffi::Optional& pad_value, + void TransformLayout(const SBlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map, + const ffi::Optional& pad_value, bool assume_injective_transform) override; - void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; - void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, + void TransformBlockLayout(const SBlockRV& block_rv, const IndexMap& index_map) override; + void SetAxisSeparator(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const ffi::Array& axis_separators) final; /******** Schedule: Padding ********/ - BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; - void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) final; + SBlockRV DecomposePadding(const SBlockRV& block_rv, const LoopRV& loop_rv) final; + void PadEinsum(const SBlockRV& block_rv, const ffi::Array& padding) final; /******** Schedule: Buffer transformation ********/ - void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final; + void RollingBuffer(const SBlockRV& block_rv, int write_buffer_index) final; /******** Schedule: Misc ********/ void EnterPostproc() final; - void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + void UnsafeHideBufferAccess(const SBlockRV& block_rv, const ffi::String& buf_type, const ffi::Array& buf_index_array) final; - void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + void AnnotateBufferAccess(const SBlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) final; }; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 9c3da9f32bea..0127a288698b 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -27,13 +27,13 @@ namespace tir { /******** Annotation ********/ -Block WithAnnotation(const BlockNode* block, const ffi::String& attr_key, - const ObjectRef& attr_value) { +SBlock WithAnnotation(const SBlockNode* block, const ffi::String& attr_key, + const ObjectRef& attr_value) { ffi::Map annotations = block->annotations; annotations.Set(attr_key, attr_value); - ObjectPtr new_block = ffi::make_object(*block); + ObjectPtr new_block = ffi::make_object(*block); new_block->annotations = std::move(annotations); - return Block(new_block); + return SBlock(new_block); } /******** Buffer Related ********/ @@ -128,13 +128,13 @@ ffi::Array ReplaceBufferRegion(ffi::Array /******** ReplaceBufferMutator ********/ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, - ffi::Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : block_sref_reuse_(block_sref_reuse) { buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer); } ReplaceBufferMutator::ReplaceBufferMutator(const ffi::Map& buffer_map, - ffi::Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : block_sref_reuse_(block_sref_reuse) { for (const auto& [old_buffer, new_buffer] : buffer_map) { buffer_var_map_[old_buffer->data.get()] = new_buffer; @@ -167,7 +167,7 @@ MatchBufferRegion ReplaceBufferMutator::VisitMatchBufferRegion( } } -Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { +Stmt ReplaceBufferMutator::VisitStmt_(const SBlockNode* block) { // To reduce the number of blocks in block sref reuse map, we check whether the block is really // mutated (i.e., the old buffer appears in the block). If so, we return the block after // mutation. Otherwise we just return the original block. @@ -214,35 +214,35 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block. ffi::Array alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers); // Step 4. Recursively mutate the block. - Block mutated_block = Downcast(StmtMutator::VisitStmt_(block)); + SBlock mutated_block = Downcast(StmtMutator::VisitStmt_(block)); if (mutated_block.get() == block && reads.same_as(mutated_block->reads) && writes.same_as(mutated_block->writes) && alloc_buffers.same_as(mutated_block->alloc_buffers) && match_buffers.same_as(mutated_block->match_buffers)) { - return ffi::GetRef(block); + return ffi::GetRef(block); } else { - ObjectPtr n = CopyOnWrite(mutated_block.get()); + ObjectPtr n = CopyOnWrite(mutated_block.get()); n->reads = std::move(reads); n->writes = std::move(writes); n->alloc_buffers = std::move(alloc_buffers); n->match_buffers = std::move(match_buffers); - Block new_block(n); + SBlock new_block(n); if (block_sref_reuse_ != nullptr) { - block_sref_reuse_->Set(ffi::GetRef(block), new_block); + block_sref_reuse_->Set(ffi::GetRef(block), new_block); } return new_block; } } -/******** Block Removal ********/ +/******** SBlock Removal ********/ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt) { class OnlyLeafError : public ScheduleError { public: - explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root) + explicit OnlyLeafError(IRModule mod, SBlock leaf_block, SBlock scope_root) : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {} ffi::String FastErrorString() const final { @@ -258,8 +258,8 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ ffi::Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } IRModule mod_; - Block leaf_block_; - Block scope_root_; + SBlock leaf_block_; + SBlock scope_root_; }; // Go upwards until find an ancestor with more than one child @@ -278,7 +278,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ break; } } - if (const auto* block = sref->StmtAs()) { + if (const auto* block = sref->StmtAs()) { auto body = block->body; // Peel off AllocateConst nodes at the beginning of the block body. std::vector allocs; @@ -299,7 +299,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } if (const auto* seq = body.as()) { - ObjectPtr n = ffi::make_object(*block); + ObjectPtr n = ffi::make_object(*block); auto new_seq = RemoveFromSeqStmt(ffi::GetRef(seq), ffi::GetRef(last_stmt)); // Re-attach AllocateConst nodes auto new_body = MergeNest(allocs, new_seq); @@ -319,12 +319,12 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } } ICHECK(sref != nullptr && sref->stmt != nullptr); - const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref); - const auto* scope_block = TVM_SREF_TO_BLOCK(sref); - throw OnlyLeafError(self->mod, ffi::GetRef(leaf_block), ffi::GetRef(scope_block)); + const auto* leaf_block = TVM_SREF_TO_SBLOCK(leaf_block_sref); + const auto* scope_block = TVM_SREF_TO_SBLOCK(sref); + throw OnlyLeafError(self->mod, ffi::GetRef(leaf_block), ffi::GetRef(scope_block)); } -ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, +ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::SBlockRV& block_rv, const ffi::String& intrin_name, bool allow_padding) { ffi::Optional opt_tensorize_info = GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv), @@ -346,7 +346,7 @@ ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir:: sch->PadEinsum(block_rv, info->block_iter_paddings.value()); // Now we need to find out all the padded Block's. - ffi::Array inlined_producers, inlined_consumers; + ffi::Array inlined_producers, inlined_consumers; for (const auto& producer : sch->GetProducers(block_rv)) { // PadEinsum will not modify the producer if it does not need padding. if (original_producers.count(sch->GetSRef(producer).get())) { @@ -476,8 +476,8 @@ void BlockBufferAccessSimplifier::SimplifyBufferIndices(ffi::Array* in *indices = this->IterMapSimplifyWithContext(*indices, true); } -Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) { - Block block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); +Stmt BlockBufferAccessSimplifier::VisitStmt_(const SBlockNode* op) { + SBlock block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); auto* n = block.CopyOnWrite(); SimplifyAccessRegion(&n->reads); SimplifyAccessRegion(&n->writes); @@ -498,25 +498,25 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) { /******** PrimFunc-level analysis and transformation ********/ -void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, ffi::Array* leaf_blocks) { - ffi::Array blocks = sch->GetChildBlocks(cur_block_rv); +void GetLeafBlocksHelper(Schedule sch, SBlockRV cur_block_rv, ffi::Array* leaf_blocks) { + ffi::Array blocks = sch->GetChildBlocks(cur_block_rv); if (blocks.empty()) { leaf_blocks->push_back(cur_block_rv); } else { - for (const BlockRV& block : blocks) { + for (const SBlockRV& block : blocks) { GetLeafBlocksHelper(sch, block, leaf_blocks); } } } ffi::Optional NormalizePrimFunc(Schedule sch) { - BlockRV root_block = sch->GetBlock("root"); - ffi::Array leaf_blocks; + SBlockRV root_block = sch->GetSBlock("root"); + ffi::Array leaf_blocks; GetLeafBlocksHelper(sch, root_block, &leaf_blocks); - for (const BlockRV& block : leaf_blocks) { + for (const SBlockRV& block : leaf_blocks) { StmtSRef block_sref = sch->GetSRef(block); ffi::Array loops = GetLoops(block_sref); - ffi::Array binds = GetBlockRealize(sch->state(), block_sref)->iter_values; + ffi::Array binds = GetSBlockRealize(sch->state(), block_sref)->iter_values; if (loops.size() == 0) continue; if (loops.size() != binds.size()) { return std::nullopt; @@ -535,7 +535,7 @@ ffi::Optional NormalizePrimFunc(Schedule sch) { ffi::Array> block_loops; ffi::Array> block_iters; ffi::Array block_is_reduction; - for (const BlockRV& block : leaf_blocks) { + for (const SBlockRV& block : leaf_blocks) { ffi::Array iters = sch->Get(block)->iter_vars; bool has_spatial_iter = false; ffi::Array index_map_inputs; diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 6e26f48320db..23a1dd0486a6 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -41,8 +41,8 @@ namespace tir { * \param attr_value The annotation value to be added * \return A new block with the given annotation as its last annotation */ -Block WithAnnotation(const BlockNode* block, const ffi::String& attr_key, - const ObjectRef& attr_value); +SBlock WithAnnotation(const SBlockNode* block, const ffi::String& attr_key, + const ObjectRef& attr_value); /******** Buffer Related ********/ @@ -131,10 +131,10 @@ class ReplaceBufferMutator : public StmtExprMutator { * sref. */ ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, - ffi::Map* block_sref_reuse); + ffi::Map* block_sref_reuse); ReplaceBufferMutator(const ffi::Map& buffer_map, - ffi::Map* block_sref_reuse); + ffi::Map* block_sref_reuse); protected: using StmtExprMutator::VisitExpr_; @@ -157,7 +157,7 @@ class ReplaceBufferMutator : public StmtExprMutator { virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer); - Stmt VisitStmt_(const BlockNode* block) override; + Stmt VisitStmt_(const SBlockNode* block) override; /*! * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in @@ -165,10 +165,10 @@ class ReplaceBufferMutator : public StmtExprMutator { */ std::unordered_map buffer_var_map_; /*! \brief The block sref reuse map for the following replacement */ - ffi::Map* block_sref_reuse_; + ffi::Map* block_sref_reuse_; }; -/******** Block Removal ********/ +/******** SBlock Removal ********/ /*! * \brief Construct a new AST, with a specific sref tree leaf removed. @@ -218,11 +218,11 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ * block tiled according to the given intrin, std::nullopt if a valid loop mapping is not found */ ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, - const tir::BlockRV& block_rv, + const tir::SBlockRV& block_rv, const ffi::String& intrin_name, bool allow_padding = false); -/******** Block mutation ********/ +/******** SBlock mutation ********/ /*! * \brief Simplifier for indices of buffer access and block buffer access regions. @@ -250,7 +250,7 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { void SimplifyAccessRegion(ffi::Array* old_access_regions); void SimplifyBufferIndices(ffi::Array* indices); - Stmt VisitStmt_(const BlockNode* op) final; + Stmt VisitStmt_(const SBlockNode* op) final; Stmt VisitStmt_(const BufferStoreNode* op) final; PrimExpr VisitExpr_(const BufferLoadNode* op) final; }; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index cd48cb13d5aa..06752a09098e 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -72,11 +72,11 @@ inline ffi::Array LoopSRefs2Loops(const ffi::Array& loop_srefs) { * \param block_rvs The random variables to be converted * \return The conversion result srefs */ -inline ffi::Array BlockRVs2StmtSRefs(const Schedule& sch, - const ffi::Array& block_rvs) { +inline ffi::Array SBlockRVs2StmtSRefs(const Schedule& sch, + const ffi::Array& block_rvs) { ffi::Array block_srefs; block_srefs.reserve(block_rvs.size()); - for (const BlockRV& block_rv : block_rvs) { + for (const SBlockRV& block_rv : block_rvs) { block_srefs.push_back(sch->GetSRef(block_rv)); } return block_srefs; @@ -117,7 +117,7 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { if (to_remove.same_as(stmt)) { continue; } - if (const auto* realize = stmt.as()) { + if (const auto* realize = stmt.as()) { if (to_remove.same_as(realize->block)) { continue; } @@ -275,8 +275,8 @@ template inline ffi::Optional GetAnn(const StmtSRef& sref, const ffi::String& ann_key) { if (const auto* loop = sref->StmtAs()) { return GetAnn(loop, ann_key); - } else if (const auto* block = sref->StmtAs()) { - return GetAnn(block, ann_key); + } else if (const auto* block = sref->StmtAs()) { + return GetAnn(block, ann_key); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; @@ -318,7 +318,7 @@ inline bool HasAnn(const StmtSRef& sref, const ffi::String& ann_key, bool ann_va * \note Before invoking this helper function, make sure that the block has only spatial and * reduction loop axes. */ -inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, +inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::SBlockRV& block_rv, tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) { ffi::Array loops = sch->GetLoops(block_rv); @@ -380,9 +380,9 @@ inline ffi::String BufferIndexType2Str(BufferIndexType buffer_index_type) { /******** Utilities for retrieving information about blocks ********/ /*! \brief Returns the names of the blocks in the provided module. */ -inline std::unordered_set GetBlockNames(const IRModule& mod) { +inline std::unordered_set GetSBlockNames(const IRModule& mod) { struct BlockNameCollector : public tir::StmtVisitor { - void VisitStmt_(const tir::BlockNode* block) override { + void VisitStmt_(const tir::SBlockNode* block) override { block_names.insert(block->name_hint); StmtVisitor::VisitStmt(block->body); } @@ -399,7 +399,7 @@ inline std::unordered_set GetBlockNames(const IRModule& mod) { /*! \brief Query if the given block name exists in the module associated with the schedule */ inline bool HasBlock(const Schedule& sch, const std::string& block_name) { - auto block_names = GetBlockNames(sch->mod()); + auto block_names = GetSBlockNames(sch->mod()); return block_names.count(block_name); } diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 2b4598a99fa7..d62f21be1fdd 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -111,12 +111,12 @@ PrimFunc BindParams(PrimFunc f, const ffi::Array& constants) { } DataType dtype = DataType(constant_map[var]->dtype); - if (n->body->IsInstance()) { - auto* block_realize = n->body.as(); + if (n->body->IsInstance()) { + auto* block_realize = n->body.as(); auto block = block_realize->block; block.CopyOnWrite()->body = tir::AllocateConst(var, dtype, extents, constant_map[var], block->body); - n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block); + n->body = SBlockRealize(block_realize->iter_values, block_realize->predicate, block); } else { n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); } diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 0ba4e75c3004..cc73121b5cdf 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -79,7 +79,7 @@ class Var2BufferCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { for (const Buffer& buffer : op->alloc_buffers) { var2buffer_[buffer->data].insert(buffer); } @@ -224,7 +224,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { // Step 0. Check there is no init part and block is opaque ICHECK(!op->init.defined()); ICHECK_EQ(op->iter_vars.size(), 0) << "CompactBufferRegion only works on opaque blocks"; @@ -291,7 +291,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } } - void VisitStmt_(const BlockRealizeNode* op) final { + void VisitStmt_(const SBlockRealizeNode* op) final { With ctx(op->predicate, &dom_map_, &hint_map_, &pending_conditions_); StmtExprVisitor::VisitStmt_(op); } @@ -562,16 +562,16 @@ class BufferCompactor : public StmtExprMutator { return load; } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { // Step 0. Check there is no Init part. ICHECK(!op->init.defined()); // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo. ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); }); // Step 2. Recursively rewrite BufferLoad/BufferStore. - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); // Step 3. Update block signature. - BlockNode* n = block.CopyOnWrite(); + SBlockNode* n = block.CopyOnWrite(); RewriteBufferRegions(&n->reads); RewriteBufferRegions(&n->writes); RewriteMatchBuffers(&n->match_buffers); diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index f187252b2e31..546de79085d6 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -57,17 +57,17 @@ class OpaqueBlockConverter : public StmtExprMutator { return ffi::GetRef(var); } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { ICHECK(!block->init.defined()) << "Block Init part is not allowed in pass ConvertBlocksToOpaque"; - Block new_block = Downcast(StmtExprMutator::VisitStmt_(block)); + SBlock new_block = Downcast(StmtExprMutator::VisitStmt_(block)); if (!new_block->iter_vars.empty()) { new_block.CopyOnWrite()->iter_vars.clear(); } return new_block; } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { const auto* block_op = realize->block.get(); ICHECK(!block_op->init.defined()); @@ -86,7 +86,7 @@ class OpaqueBlockConverter : public StmtExprMutator { var_substitutes_.emplace(block_var->var.get(), v); } // Step 3. Visit recursively. - Block new_block = Downcast(VisitStmt(realize->block)); + SBlock new_block = Downcast(VisitStmt(realize->block)); // Step 4. Clear the variable bindings for (const auto& block_var : block_op->iter_vars) { @@ -96,9 +96,9 @@ class OpaqueBlockConverter : public StmtExprMutator { // Step 5. Return if (predicate.same_as(realize->predicate) && iter_values.same_as(realize->iter_values) && new_block.same_as(realize->block) && realize->iter_values.size() == 0) { - return ffi::GetRef(realize); + return ffi::GetRef(realize); } else { - return BlockRealize({}, predicate, new_block); + return SBlockRealize({}, predicate, new_block); } } diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 74c299456a4b..9b7442233de3 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -31,7 +31,7 @@ namespace transform { * \param max_thread_per_block The maximum number of threads per block. * \param max_threadblocks The maximum number of threadblocks. */ -void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread_per_block, +void ThreadBind(tir::Schedule sch, const tir::SBlockRV& block, int64_t max_thread_per_block, int64_t max_threadblocks = 256) { // fetch the loops ffi::Array loops = sch->GetLoops(block); @@ -146,8 +146,8 @@ Pass DefaultGPUSchedule() { int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); sch->WorkOn(gv->name_hint); - ffi::Array blocks = meta_schedule::BlockCollector::Collect(sch); - for (const tir::BlockRV& block : blocks) { + ffi::Array blocks = meta_schedule::SBlockCollector::Collect(sch); + for (const tir::SBlockRV& block : blocks) { auto childs = sch->GetChildBlocks(block); if (!childs.empty()) { continue; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 1a9ba390703f..800177fa5ca9 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -60,12 +60,12 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { explicit BufferFlattener(arith::Analyzer* ana) : IRMutatorWithAnalyzer(ana) {} - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { ICHECK_EQ(op->match_buffers.size(), 0) << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. " << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; - Block block = ffi::GetRef(op); + SBlock block = ffi::GetRef(op); ffi::Array alloc_buffers = op->alloc_buffers; alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc index 711c2a739f59..cee6018150a0 100644 --- a/src/tir/transforms/force_narrow_index_to_i32.cc +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -59,8 +59,8 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { return ffi::GetRef(op); } - Stmt VisitStmt_(const BlockNode* block) final { - Block block_ = Downcast(IndexDataTypeNormalizer::VisitStmt_(block)); + Stmt VisitStmt_(const SBlockNode* block) final { + SBlock block_ = Downcast(IndexDataTypeNormalizer::VisitStmt_(block)); // Check if the allocated integer buffers have dtype other than int32. for (const Buffer& buf : block_->alloc_buffers) { if (buf->dtype.is_int() && buf->dtype.bits() > 32) { diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index cdbe17508339..5bd3fb29a88f 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -116,7 +116,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { // Record the mapping from buffer data var to buffer for later lookup for (auto buffer : op->alloc_buffers) { buffer_map_.insert({buffer->data, buffer}); @@ -133,7 +133,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { auto prev_permute = permute_; permute_ = true; - Block block = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); + SBlock block = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); permute_ = prev_permute; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index f4258fc479d6..ab6d0c12d628 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -48,17 +48,17 @@ namespace software_pipeline { * \param buffer_data_to_buffer The map from buffer data to buffer. * \return The result block. */ -Block MakeBlock(const Stmt& body, const ffi::Map& buffer_data_to_buffer) { - if (const BlockRealizeNode* block_realize = body.as()) { +SBlock MakeSBlock(const Stmt& body, const ffi::Map& buffer_data_to_buffer) { + if (const SBlockRealizeNode* block_realize = body.as()) { if (is_one(block_realize->predicate)) { // no need to create a new block return block_realize->block; } } - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body); + SBlock block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body); ffi::Array> access = - GetBlockReadWriteRegion(block, buffer_data_to_buffer); - BlockNode* n = block.CopyOnWrite(); + GetSBlockReadWriteRegion(block, buffer_data_to_buffer); + SBlockNode* n = block.CopyOnWrite(); n->reads = access[0]; n->writes = access[1]; return block; @@ -71,7 +71,7 @@ struct PipelineAnnotation { bool async; }; -using PipelineInfo = std::unordered_map; +using PipelineInfo = std::unordered_map; struct BufferAccessInfo { int def = -1; // the defining stage of the buffer @@ -247,12 +247,12 @@ class PipelineBodyRewriter : public StmtExprMutator { return buffer_region; } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { for (const Buffer& alloc_buffer : op->alloc_buffers) { buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); } - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - BlockNode* n = block.CopyOnWrite(); + SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); + SBlockNode* n = block.CopyOnWrite(); n->reads.MutateByApply([this](const BufferRegion& buffer_region) { return RewritePipelineBufferRegion(buffer_region); }); @@ -354,7 +354,7 @@ class PipelineRewriter : public StmtExprMutator { ordered_stmts_.resize(pipeline_info_.size()); for (const auto& pair : pipeline_info_) { - const Block& block = pair.first; + const SBlock& block = pair.first; int order = pair.second.order; ordered_stmts_.Set(order, block); } @@ -388,9 +388,9 @@ class PipelineRewriter : public StmtExprMutator { alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); buffer_data_to_buffer_.erase(alloc->data); } - Block block = MakeBlock(stmt, buffer_data_to_buffer_); + SBlock block = MakeSBlock(stmt, buffer_data_to_buffer_); block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); - return BlockRealize({}, Bool(true), block); + return SBlockRealize({}, Bool(true), block); } private: @@ -404,7 +404,7 @@ class PipelineRewriter : public StmtExprMutator { GetBufferAccessInfo() { std::unordered_map infos; for (const auto& pair : pipeline_info_) { - const Block& block = pair.first; + const SBlock& block = pair.first; int stage = pair.second.stage; max_stage_ = std::max(max_stage_, stage); @@ -482,7 +482,7 @@ class PipelineRewriter : public StmtExprMutator { // of block_i and block_j overlap. bool need_multi_version = false; for (const auto& pair1 : pipeline_info_) { - const Block& writer_block = pair1.first; + const SBlock& writer_block = pair1.first; const auto& writer_info = pair1.second; auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(), @@ -494,7 +494,7 @@ class PipelineRewriter : public StmtExprMutator { } for (const auto& pair2 : pipeline_info_) { - const Block& reader_block = pair2.first; + const SBlock& reader_block = pair2.first; const auto& reader_info = pair2.second; auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(), [&](const BufferRegion& buffer_region) { @@ -592,16 +592,16 @@ class PipelineRewriter : public StmtExprMutator { }; /*! Structure holding intermediate information for pipeline loop rewriting. */ - struct RewrittenBlockInfo { + struct RewrittenSBlockInfo { int stage; PrimExpr predicate; - Block block; + SBlock block; PrimExpr access_index; bool is_async; }; // Determine where to insert async_wait and the corresponding wait count. - void PopulateWaitCounts(const std::vector& new_blocks, + void PopulateWaitCounts(const std::vector& new_blocks, arith::Analyzer* ana_normalized, const std::unordered_map& buffer_to_commit_group, std::map* async_states_local) { @@ -730,10 +730,10 @@ class PipelineRewriter : public StmtExprMutator { // Given pipelined blocks and async-related information, generate final loop statements with async // scopes (if any). ffi::Array CompletePipelineLoopStatements( - const std::vector& blocks, + const std::vector& blocks, const std::map& async_states_local, arith::Analyzer* ana_normalized) const { - std::vector new_blocks = blocks; + std::vector new_blocks = blocks; std::vector commit_group_indices(new_blocks.size(), -1); for (const auto& [stage_id, state] : async_states_local) { if (!state.commit_groups.empty()) { @@ -748,7 +748,7 @@ class PipelineRewriter : public StmtExprMutator { if (state.pending_wait.valid()) { auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) { auto& block = new_blocks[i].block; - BlockNode* n = block.CopyOnWrite(); + SBlockNode* n = block.CopyOnWrite(); auto zero = make_zero(DataType::Int(32)); n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, @@ -774,7 +774,7 @@ class PipelineRewriter : public StmtExprMutator { for (size_t i = 0; i < new_blocks.size();) { if (commit_group_indices[i] == -1) { // A synchrnous block, not part of any commit group - stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); + stmts.push_back(SBlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); ++i; } else { ffi::Array group_bodies; @@ -795,8 +795,8 @@ class PipelineRewriter : public StmtExprMutator { for (auto body : group_bodies) { auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_commit_queue_scope, stage_id, body); - auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); - stmts.push_back(BlockRealize({}, predicate, new_block)); + auto new_block = MakeSBlock(commit_queue_scope, buffer_data_to_buffer_); + stmts.push_back(SBlockRealize({}, predicate, new_block)); } } } @@ -817,7 +817,7 @@ class PipelineRewriter : public StmtExprMutator { PrimExpr new_loop_var; PrimExpr extent = end - start; - auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); }; + auto make_nop = []() { return SBlockRealize({}, Bool(true), MakeSBlock(Evaluate(0), {})); }; if (analyzer_.CanProve(extent <= 0)) { return make_nop(); @@ -837,13 +837,13 @@ class PipelineRewriter : public StmtExprMutator { ana_normalized.Bind(Downcast(new_loop_var), Range(pipeline_loop_->min, extent)); } - std::vector new_blocks; + std::vector new_blocks; // Async related std::map async_states_local; std::unordered_map buffer_to_commit_group; - for (const Block& block : ordered_stmts_) { + for (const SBlock& block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; PrimExpr skewed_loop_var = new_loop_var - stage; PrimExpr inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && @@ -854,9 +854,9 @@ class PipelineRewriter : public StmtExprMutator { if (analyzer_.CanProve(!inbound)) { continue; } - Block new_block = Downcast(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, - pipeline_loop_, max_stage_ != 1, - fragment_info_)(block)); + SBlock new_block = Downcast( + PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, + max_stage_ != 1, fragment_info_)(block)); PrimExpr delta = start - pipeline_loop_->min; // This variable corresponds to @@ -871,7 +871,7 @@ class PipelineRewriter : public StmtExprMutator { inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); } - new_block = Downcast( + new_block = Downcast( Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); if (pipeline_info_[block].async) { @@ -909,7 +909,7 @@ class PipelineRewriter : public StmtExprMutator { local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound); } - BlockNode* n = new_block.CopyOnWrite(); + SBlockNode* n = new_block.CopyOnWrite(); n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body); } @@ -963,7 +963,7 @@ class PipelineRewriter : public StmtExprMutator { } } - return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); + return SBlockRealize({}, Bool(true), MakeSBlock(std::move(new_loop), buffer_data_to_buffer_)); } arith::Analyzer analyzer_; @@ -975,7 +975,7 @@ class PipelineRewriter : public StmtExprMutator { const std::unordered_map& fragment_info_; int max_stage_ = -1; ffi::Map buffer_remap_; - ffi::Array ordered_stmts_; + ffi::Array ordered_stmts_; std::map async_states; ffi::Map preserved_annotations_; }; @@ -989,16 +989,16 @@ class PipelineRewriter : public StmtExprMutator { * destination to the source. */ void BuildDependencyGraph( - const ffi::Array& blocks, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { - std::unordered_map> buffer_writers; + const ffi::Array& blocks, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { + std::unordered_map> buffer_writers; - for (const Block& block : blocks) { + for (const SBlock& block : blocks) { for (const BufferRegion& read : block->reads) { auto it = buffer_writers.find(read->buffer->data); if (it != buffer_writers.end()) { - for (const Block& writer : it->second) { + for (const SBlock& writer : it->second) { if (dep_src2dst != nullptr) { (*dep_src2dst)[writer].push_back(block); } @@ -1040,12 +1040,12 @@ class PipelineInjector : private StmtExprMutator { * case 2: stage(A) == stage(B) and order(A) < order(B) */ void ValidatePipelineBody(const PipelineInfo& pipeline_info, - const ffi::Array& original_order) { + const ffi::Array& original_order) { std::unordered_set used_orders; std::unordered_map stage_max_order; - std::unordered_map order_to_block; - std::unordered_map block_to_stage; - for (const Block& block : original_order) { + std::unordered_map order_to_block; + std::unordered_map block_to_stage; + for (const SBlock& block : original_order) { const auto& stmt_info = pipeline_info.at(block); int order = stmt_info.order; CHECK(!used_orders.count(order)) @@ -1053,14 +1053,14 @@ class PipelineInjector : private StmtExprMutator { used_orders.insert(order); } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; BuildDependencyGraph(original_order, &dep_src2dst, nullptr); for (const auto& pair : dep_src2dst) { - const Block& src = pair.first; + const SBlock& src = pair.first; const auto& src_info = pipeline_info.at(src); - const ffi::Array& dsts = pair.second; - for (const Block& dst : dsts) { + const ffi::Array& dsts = pair.second; + for (const SBlock& dst : dsts) { const auto& dst_info = pipeline_info.at(dst); CHECK_LE(src_info.stage, dst_info.stage) << "ValueError: statement " << dst << " in stage " << dst_info.stage @@ -1085,7 +1085,7 @@ class PipelineInjector : private StmtExprMutator { // child of the block. Stmt pipeline_body{nullptr}; ffi::Array pipeline_allocs; - if (const auto* realize = for_node->body.as()) { + if (const auto* realize = for_node->body.as()) { const auto& block = realize->block; for (const auto& buffer : block->alloc_buffers) { ICHECK(buffer->IsInstance()); @@ -1105,16 +1105,16 @@ class PipelineInjector : private StmtExprMutator { // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be // converted into a block. PipelineInfo pipeline_info; - ffi::Array original_order; // pipeline body blocks in the original order + ffi::Array original_order; // pipeline body blocks in the original order auto f_add_child = [&](const Stmt& child) { - original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); + original_order.push_back(MakeSBlock(child, buffer_data_to_buffer_)); }; for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { - const auto* nested_block_realize = pipeline_body_seq->seq[i].as(); + const auto* nested_block_realize = pipeline_body_seq->seq[i].as(); if (nested_block_realize && is_one(nested_block_realize->predicate) && nested_block_realize->block->body->IsInstance()) { - const Block& nested_pipeline_block = nested_block_realize->block; + const SBlock& nested_pipeline_block = nested_block_realize->block; ICHECK( nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered for (const auto& buffer : nested_pipeline_block->alloc_buffers) { @@ -1175,7 +1175,7 @@ class PipelineInjector : private StmtExprMutator { pipeline_allocs, ffi::GetRef(op), pipeline_info, fragment_info_, preserved_annotations); - if (const auto* realize = op->body.as()) { + if (const auto* realize = op->body.as()) { const auto& block = realize->block; for (const auto& buffer : block->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); @@ -1189,7 +1189,7 @@ class PipelineInjector : private StmtExprMutator { * \param n The block pointer to which the buffer allocations are added. * \param alloc_buffers The buffer allocations to be added. */ - void AddAllocBuffers(BlockNode* n, const ffi::Array alloc_buffers) { + void AddAllocBuffers(SBlockNode* n, const ffi::Array alloc_buffers) { for (const Buffer& alloc_buffer : alloc_buffers) { n->alloc_buffers.push_back(alloc_buffer); Region region; @@ -1201,7 +1201,7 @@ class PipelineInjector : private StmtExprMutator { } } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { for (const auto& buffer : op->alloc_buffers) { buffer_data_to_buffer_.Set(buffer->data, buffer); } @@ -1214,7 +1214,7 @@ class PipelineInjector : private StmtExprMutator { << buffer_index << " vs. " << op->writes.size() << ")"; double_buffers.insert(op->writes[buffer_index]->buffer); } - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); for (const auto& buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index ce69053311d1..030aac3c75dd 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -120,9 +120,9 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, // We do not currently support inlining of schedulable TIR // functions. To support this use case, repeated names in - // `tir::Block` nodes resulting from multiple calls to the same + // `tir::SBlock` nodes resulting from multiple calls to the same // inlined function will need to be de-duplicated. - bool has_block_node = prim_func->body.as(); + bool has_block_node = prim_func->body.as(); if (has_block_node) return false; return true; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 8bcb2077c677..0d7a217a0a52 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -247,8 +247,8 @@ class IRConvertSSA final : public StmtExprMutator { return decl; } - Stmt VisitStmt_(const BlockNode* op) final { - Block block = ffi::GetRef(op); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = ffi::GetRef(op); // The BlockNode is the point of definition for the IterVar // instances. These re-defines must be present before visiting @@ -276,7 +276,7 @@ class IRConvertSSA final : public StmtExprMutator { write_ptr->iter_vars = iter_vars; } - Stmt output = Downcast(StmtExprMutator::VisitStmt_(block.get())); + Stmt output = Downcast(StmtExprMutator::VisitStmt_(block.get())); while (redefines.size()) redefines.pop_back(); @@ -748,7 +748,7 @@ class StorageAlignCollector : public StmtVisitor { const Stmt& body); /*! \brief For s-stir, the alignment annotations reside in block annotations. */ - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { auto it = op->annotations.find(attr::buffer_dim_align); if (it != op->annotations.end()) { auto storage_align_annotation = Downcast((*it).second); diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 2f7ac3ddb1c0..fb9ffb24db6c 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -72,11 +72,11 @@ bool IsBoundToThreadIdx(const ForNode* loop) { * \param block The block whose dominant property is to be checked * \return A boolean indicating if the block is a dominant block */ -bool IsDominantBlock(const Block& scope_block, const Block& block) { +bool IsDominantBlock(const SBlock& scope_block, const SBlock& block) { // Step 1. Count the number of writers for each buffer written by the scope block. std::unordered_map buffer_writer_cnt; PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) { - if (const auto* block = obj.as()) { + if (const auto* block = obj.as()) { for (const BufferRegion& buffer_region : block->writes) { ++buffer_writer_cnt[buffer_region->buffer.get()]; } @@ -105,9 +105,9 @@ bool IsDominantBlock(const Block& scope_block, const Block& block) { * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the * check again. */ -bool IsReductionBlock(const BlockRealize& realize, const ffi::Map& loop_range_map, - const Block& scope_block, arith::Analyzer* analyzer) { - const auto* block = realize->block.as(); +bool IsReductionBlock(const SBlockRealize& realize, const ffi::Map& loop_range_map, + const SBlock& scope_block, arith::Analyzer* analyzer) { + const auto* block = realize->block.as(); // Cond 1. The block has the `init` statement. if (!block->init.defined()) { return false; @@ -123,11 +123,11 @@ bool IsReductionBlock(const BlockRealize& realize, const ffi::Map& l } // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. - if (!IsDominantBlock(scope_block, ffi::GetRef(block))) { + if (!IsDominantBlock(scope_block, ffi::GetRef(block))) { return false; } // Cond 5. The reduction block vars are not used to index the output buffers. - return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)); + return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)); } /*! @@ -218,7 +218,7 @@ class InThreadReducerMaker : private StmtMutator { } private: - void VisitStmt_(const BlockNode* block) final { + void VisitStmt_(const SBlockNode* block) final { ffi::Array iter_vars = block->iter_vars; for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type == kCommReduce) { @@ -232,22 +232,22 @@ class InThreadReducerMaker : private StmtMutator { ffi::Array reduction_block_vars_; }; - static ffi::Optional Make(const BlockRealizeNode* src_realize, - ffi::Optional tgt_realize, Stmt stmt) { + static ffi::Optional Make(const SBlockRealizeNode* src_realize, + ffi::Optional tgt_realize, Stmt stmt) { return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt)); } private: - explicit InThreadReducerMaker(const BlockRealizeNode* src_realize, - ffi::Optional tgt_realize) + explicit InThreadReducerMaker(const SBlockRealizeNode* src_realize, + ffi::Optional tgt_realize) : src_realize_(src_realize), tgt_realize_(tgt_realize) {} - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { if (realize == src_realize_) { return tgt_realize_.defined() // ? tgt_realize_.value() : Stmt{nullptr}; } - return ffi::GetRef(realize); + return ffi::GetRef(realize); } Stmt VisitStmt_(const ForNode* loop) final { @@ -279,8 +279,8 @@ class InThreadReducerMaker : private StmtMutator { return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts); } - const BlockRealizeNode* src_realize_; - ffi::Optional tgt_realize_; + const SBlockRealizeNode* src_realize_; + ffi::Optional tgt_realize_; }; /*! @@ -295,7 +295,7 @@ class InThreadReducerMaker : private StmtMutator { * \param combiner_rhs The RHS values of the combiner * \param reduction_loops The reduction loops */ -Stmt TransformReductionBlock(const BlockRealizeNode* realize, // +Stmt TransformReductionBlock(const SBlockRealizeNode* realize, // const ffi::Optional>& it_buffers, // const ffi::Array& ct_buffers, // const ffi::Array& wb_buffers, // @@ -304,7 +304,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const ffi::Array& combiner_rhs, // const std::vector& reduction_loops) { int n_buffers = wb_buffers.size(); - const BlockNode* block = realize->block.get(); + const SBlockNode* block = realize->block.get(); auto f_create_buffer_regions = [](ffi::Array buffers) { ffi::Array regions; @@ -335,32 +335,32 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, inits.push_back( BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)})); } - stmts.push_back(BlockRealize(/*iter_values=*/{}, - /*predicate=*/const_true(), - /*block=*/ - Block(/*iter_vars=*/{}, - /*reads=*/{}, - /*writes=*/it_buffer_regions.value(), - /*name_hint=*/block->name_hint + "_in_thread_init", - /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0]))); + stmts.push_back(SBlockRealize(/*iter_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + SBlock(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/it_buffer_regions.value(), + /*name_hint=*/block->name_hint + "_in_thread_init", + /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0]))); } // Stmt 2: do in-thread reduction { - ffi::Optional new_realize = std::nullopt; + ffi::Optional new_realize = std::nullopt; // If need to generate in-thread reduction, // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize // otherwise, directly remove given BlockRealize if (it_buffers.defined()) { - ObjectPtr new_block = ffi::make_object(*block); + ObjectPtr new_block = ffi::make_object(*block); new_block->reads = std::move(new_block->reads); new_block->writes = it_buffer_regions.value(); new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body)); new_block->init = std::nullopt; - ObjectPtr n = ffi::make_object(*realize); - n->block = Block(new_block); - new_realize = BlockRealize(n); + ObjectPtr n = ffi::make_object(*realize); + n->block = SBlock(new_block); + new_realize = SBlockRealize(n); } For loop = ffi::GetRef(reduction_loops[0]); if (ffi::Optional stmt = @@ -408,22 +408,22 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, bindings = realize->iter_values; reads = block->reads; } - stmts.push_back(BlockRealize( + stmts.push_back(SBlockRealize( /*iter_values=*/std::move(bindings), /*predicate=*/const_true(), /*block=*/ - Block(/*iter_vars=*/std::move(iter_vars), - /*reads=*/std::move(reads), - /*writes=*/ct_buffer_regions, - /*name_hint=*/block->name_hint + "_cross_thread", - /*body=*/ - AttrStmt(/*node=*/reducer, - /*attr_key=*/tir::attr::reduce_scope, - /*value=*/make_zero(DataType::Handle()), - /*body=*/ - Evaluate(Call(/*dtype=*/DataType::Handle(), - /*op=*/tir::builtin::tvm_thread_allreduce(), - /*args=*/std::move(parameters))))))); + SBlock(/*iter_vars=*/std::move(iter_vars), + /*reads=*/std::move(reads), + /*writes=*/ct_buffer_regions, + /*name_hint=*/block->name_hint + "_cross_thread", + /*body=*/ + AttrStmt(/*node=*/reducer, + /*attr_key=*/tir::attr::reduce_scope, + /*value=*/make_zero(DataType::Handle()), + /*body=*/ + Evaluate(Call(/*dtype=*/DataType::Handle(), + /*op=*/tir::builtin::tvm_thread_allreduce(), + /*args=*/std::move(parameters))))))); } // Stmt 4: write cross-thread reduction result to the original buffer { @@ -508,15 +508,15 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, } } - stmts.push_back(BlockRealize( + stmts.push_back(SBlockRealize( /*iter_values=*/std::move(bindings), /*predicate=*/wb_predicate, /*block=*/ - Block(/*iter_vars=*/std::move(iter_vars), - /*reads=*/std::move(ct_buffer_regions), - /*writes=*/std::move(wb_regions), - /*name_hint=*/block->name_hint + "_write_back", - /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0]))); + SBlock(/*iter_vars=*/std::move(iter_vars), + /*reads=*/std::move(ct_buffer_regions), + /*writes=*/std::move(wb_regions), + /*name_hint=*/block->name_hint + "_write_back", + /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0]))); } // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx Stmt new_stmt = SeqStmt::Flatten(std::move(stmts)); @@ -537,21 +537,21 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, class CrossThreadReductionTransformer : public StmtMutator { private: // Check if the input block needs cross-thread reduction. - std::vector NeedCrossThreadReduction(const BlockRealizeNode* realize) { + std::vector NeedCrossThreadReduction(const SBlockRealizeNode* realize) { // Step 0. If the block is the root block, just return. if (block_stack_.empty()) { return {}; } // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. - if (!IsReductionBlock(ffi::GetRef(realize), loop_range_map_, - ffi::GetRef(block_stack_.back()), &analyzer_)) { + if (!IsReductionBlock(ffi::GetRef(realize), loop_range_map_, + ffi::GetRef(block_stack_.back()), &analyzer_)) { return {}; } // Step 2. Collect all the vars that appear in the bindings of reduction block iters. std::unordered_set reduction_vars; - GetVarsTouchedByBlockIters(ffi::GetRef(realize), nullptr, &reduction_vars); + GetVarsTouchedByBlockIters(ffi::GetRef(realize), nullptr, &reduction_vars); // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters. // We call these loops "reduction-related". @@ -581,8 +581,8 @@ class CrossThreadReductionTransformer : public StmtMutator { // 3. at least one of the reduction thread vars of the cross-thread reduction // is free to this block (i.e., not bound to the block). std::vector> NeedCrossThreadBroadcast( - const BlockRealizeNode* realize) { - Block block = realize->block; + const SBlockRealizeNode* realize) { + SBlock block = realize->block; // If the block writes to local memory, no rewrite is needed. for (BufferRegion write_region : block->writes) { @@ -632,7 +632,7 @@ class CrossThreadReductionTransformer : public StmtMutator { * - the indices which is used to access the reduction buffers when storing the reduction results */ std::tuple, ffi::Array, ffi::Array> - CheckCanApplyCrossThreadReduction(const BlockNode* block, + CheckCanApplyCrossThreadReduction(const SBlockNode* block, const std::vector& reduction_loops) const { // Condition 1. All the reduction-related loops should be the deepest among all statements // outside the block (ignoring SeqStmt here). @@ -678,7 +678,7 @@ class CrossThreadReductionTransformer : public StmtMutator { ffi::Array combiner_lhs{nullptr}; ffi::Array combiner_rhs{nullptr}; std::tie(init_values, updates) = - GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, ffi::GetRef(block)); + GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, ffi::GetRef(block)); std::tie(reducer, combiner_lhs, combiner_rhs) = GetReducerAndCombinerLhsRhs(std::nullopt, init_values, updates); @@ -706,7 +706,7 @@ class CrossThreadReductionTransformer : public StmtMutator { // Condition 5. The block should be the last block under the first reduction-related loop. bool visit = false; PreOrderVisit(ffi::GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { - if (const auto* realize = obj.as()) { + if (const auto* realize = obj.as()) { CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction " "block isn't the last block under its first reduction-related loop"; if (realize->block.get() == block) { @@ -774,19 +774,19 @@ class CrossThreadReductionTransformer : public StmtMutator { } } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { ffi::Map old_loop_range_map; block_stack_.push_back(block); std::swap(old_loop_range_map, loop_range_map_); - Block new_block = Downcast(StmtMutator::VisitStmt_(block)); + SBlock new_block = Downcast(StmtMutator::VisitStmt_(block)); block_stack_.pop_back(); std::swap(old_loop_range_map, loop_range_map_); // Insert the new allocated buffers into the block's `alloc_buffers` field. auto it = block2new_buffers_.find(block); if (it != block2new_buffers_.end()) { - BlockNode* p_new_block = new_block.CopyOnWrite(); + SBlockNode* p_new_block = new_block.CopyOnWrite(); for (const Buffer& new_buffer : it->second) { if (new_buffer.defined()) { p_new_block->alloc_buffers.push_back(new_buffer); @@ -796,9 +796,9 @@ class CrossThreadReductionTransformer : public StmtMutator { return new_block; } - void MakeCrossThreadReduction(const BlockRealizeNode* realize, + void MakeCrossThreadReduction(const SBlockRealizeNode* realize, const std::vector reduction_loops) { - const BlockNode* block = realize->block.get(); + const SBlockNode* block = realize->block.get(); // Step 1. Check whether cross-thread reduction can be applied. If no, throw an exception on // which condition the block violates. @@ -848,7 +848,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } Stmt MakeCrossThreadBroadcast( - const BlockRealizeNode* realize, + const SBlockRealizeNode* realize, const std::vector>& unbound_thread2range) { // Step 1. Generate loop var for each unbound thread. // Update the block predicate with clauses of `thread_var == min`. @@ -863,7 +863,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } // Step 2. Update the BlockRealize with the new predicate. - ObjectPtr p_realize = ffi::make_object(*realize); + ObjectPtr p_realize = ffi::make_object(*realize); p_realize->predicate = std::move(predicate); // Step 3. Wrap the updated BlockRealize with the new loops. @@ -885,7 +885,7 @@ class CrossThreadReductionTransformer : public StmtMutator { return body; } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const SBlockRealizeNode* realize) final { // Part 1. Check if the block needs cross-thread reduction rewrite. std::vector reduction_loops = NeedCrossThreadReduction(realize); if (!reduction_loops.empty()) { @@ -915,8 +915,8 @@ class CrossThreadReductionTransformer : public StmtMutator { bool has_cross_thread_reduction_ = false; std::vector statement_stack_; std::vector loop_stack_; - std::vector block_stack_; - std::unordered_map> block2new_buffers_; + std::vector block_stack_; + std::unordered_map> block2new_buffers_; std::unordered_map loop2new_stmt_; ffi::Map loop_range_map_; arith::Analyzer analyzer_; diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index 5ae654077316..3ccaa7cea75f 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -33,7 +33,7 @@ namespace tir { class InitBlockLower : public StmtMutator { private: - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const SBlockNode* block) final { if (!block->init.defined()) { return StmtMutator::VisitStmt_(block); } @@ -42,7 +42,7 @@ class InitBlockLower : public StmtMutator { auto n = CopyOnWrite(block); n->init = std::nullopt; n->body = SeqStmt::Flatten(init, body); - return Block(n); + return SBlock(n); } static Stmt DoLowering(const Stmt& init, const ffi::Array& iter_vars) { diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index dc3cc0dbab39..b426f60a450e 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -44,13 +44,13 @@ class MatchBufferLower : public StmtExprMutator { } private: - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { for (const MatchBufferRegion& match_buffer : op->match_buffers) { CheckAndUpdateVarMap(match_buffer); } Stmt stmt = StmtExprMutator ::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); ICHECK(op != nullptr); ffi::Array reads = op->reads.Map(std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index c0363dd8982f..b5d6f35eb8bc 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -31,7 +31,7 @@ namespace tvm { namespace tir { /*! - * \brief Remove Block to ensure that the TIR can not be scheduled again. + * \brief Remove SBlock to ensure that the TIR can not be scheduled again. */ class OpaqueBlockLower : public StmtExprMutator { public: @@ -42,12 +42,12 @@ class OpaqueBlockLower : public StmtExprMutator { } private: - Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt VisitStmt_(const SBlockRealizeNode* op) final { // We have convert blocks into opaque blocks in previous passes. ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " "call pass ConvertBlocksToOpaque before."; // Step 1. Visit the body - Block new_block = Downcast(this->VisitStmt(op->block)); + SBlock new_block = Downcast(this->VisitStmt(op->block)); PrimExpr predicate = this->VisitExpr(op->predicate); // Step 2. Transform the `predicate` to if-then-else Stmt body = new_block->body; diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 8d0b71e75e5d..4addb7823bda 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -50,7 +50,7 @@ class IntermediateStageRewriter { explicit IntermediateStageRewriter(const std::vector& ancestor_loop_or_blocks) : ancestor_loop_or_blocks_(ancestor_loop_or_blocks) {} - std::tuple Rewrite(const BlockNode* block) { + std::tuple Rewrite(const SBlockNode* block) { const BufferStoreNode* store = block->body.as(); CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == runtime::StorageRank::kShared) @@ -73,7 +73,7 @@ class IntermediateStageRewriter { BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; - Block new_block = ffi::GetRef(block); + SBlock new_block = ffi::GetRef(block); new_block.CopyOnWrite()->body = std::move(new_buffer_store); return {target_buffer, new_buffer, new_block, local_stage}; @@ -81,7 +81,7 @@ class IntermediateStageRewriter { private: /*! \brief Collect relaxed outer loops from innermost to outermost */ - std::vector CollectRelaxedOuterLoops(const BlockNode* block, + std::vector CollectRelaxedOuterLoops(const SBlockNode* block, const Buffer& target_buffer) { std::vector relaxed_loops; for (int n = static_cast(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) { @@ -97,15 +97,15 @@ class IntermediateStageRewriter { CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1])) << "ValueError: Expect the ancestor loops to have a single child."; } else { - const BlockRealizeNode* block_realize = ancestor_loop->body.as(); + const SBlockRealizeNode* block_realize = ancestor_loop->body.as(); ICHECK(block_realize != nullptr); CHECK(block_realize != nullptr && block_realize->block.get() == block) << "ValueError: Expect the ancestor loops to have a single child."; } } else { - const BlockRealizeNode* ancestor_block_realize = ancestor.as(); + const SBlockRealizeNode* ancestor_block_realize = ancestor.as(); ICHECK(ancestor_block_realize != nullptr); - const BlockNode* ancestor_block = ancestor_block_realize->block.get(); + const SBlockNode* ancestor_block = ancestor_block_realize->block.get(); auto it = std::find_if( ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(), [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); }); @@ -118,7 +118,7 @@ class IntermediateStageRewriter { } /*! \brief Create the intermediate stage. */ - Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer, + Stmt MakeLocalStage(const SBlockNode* block, const Buffer& new_buffer, ffi::Array local_stage_indices, std::vector relaxed_loops, const BufferStoreNode* store) { // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer. @@ -127,12 +127,12 @@ class IntermediateStageRewriter { // Step 1: Make block and block realize BufferRegion write_buffer_region = BufferRegion::FromPoint(new_buffer, local_stage_indices); local_stage = - Block(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "", - /*body=*/std::move(local_stage)); - local_stage = BlockRealize( + SBlock(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "", + /*body=*/std::move(local_stage)); + local_stage = SBlockRealize( /*iter_values=*/{}, - /*predicate=*/ancestor_loop_or_blocks_.back().as()->predicate, - Downcast(local_stage)); + /*predicate=*/ancestor_loop_or_blocks_.back().as()->predicate, + Downcast(local_stage)); // Step 2: Add outer loops ffi::Map subst_map; @@ -178,14 +178,14 @@ class SharedMemoryLocalStageInserter : public StmtMutator { return new_stmt; } - Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt VisitStmt_(const SBlockRealizeNode* op) final { ancestor_loop_or_blocks_.push_back(ffi::GetRef(op)); Stmt new_stmt = StmtMutator::VisitStmt_(op); ancestor_loop_or_blocks_.pop_back(); return new_stmt; } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { if (op->annotations.count(attr::manifest_shared_memory_local_stage)) { // Rewrite the shared memory access to load from the intermediate buffer. // The annotated block must be a leaf block (will be checked during rewriting). No need to @@ -249,8 +249,8 @@ class SharedMemoryLocalStageInserter : public StmtMutator { new_seq.push_back(body); } - Block new_block = ffi::GetRef(op); - BlockNode* new_block_node = new_block.CopyOnWrite(); + SBlock new_block = ffi::GetRef(op); + SBlockNode* new_block_node = new_block.CopyOnWrite(); // Add new buffer allocations if any. if (new_alloc_buffers.size() > 0) { new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers); diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 498de4796cd4..f4dc6579cd0b 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -211,7 +211,7 @@ class AutoPadder { return store; } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { // To reduce the number of blocks in block sref reuse map, we check whether the block is // really mutated (i.e., the old buffer appears in the block). If so, we return the block // after mutation. Otherwise we just return the original block. @@ -256,13 +256,13 @@ class AutoPadder { } if (changed) { - ObjectPtr block = CopyOnWrite(res.as()); + ObjectPtr block = CopyOnWrite(res.as()); block->reads = std::move(reads); block->writes = std::move(writes); block->match_buffers = std::move(match_buffers); return Stmt(block); } else { - return ffi::GetRef(op); + return ffi::GetRef(op); } } const ffi::Map& buffer_map_; @@ -561,7 +561,7 @@ class AutoPadder { * threadIdx. The iteration space would be {{0, 1, ..., 15}, {0, 1, ..., 15}}. * \param op the call node */ - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (const auto* eval = op->body.as()) { if (const auto* call = eval->value.as()) { if (call->op == builtin::tvm_load_matrix_sync() || @@ -661,11 +661,11 @@ class AutoCopyMutator : public StmtExprMutator { Stmt RewritePaddingBody(const Stmt& stmt) { return padder.RewriteBufferAccess(stmt); } private: - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtMutator::VisitStmt_(op)); // only rewrite the block annotated with "auto_copy" if (!GetAnn(op, tir::attr::auto_copy).value_or(false)) { - BlockNode* n = block.CopyOnWrite(); + SBlockNode* n = block.CopyOnWrite(); n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); return block; } @@ -691,7 +691,7 @@ class AutoCopyMutator : public StmtExprMutator { block->writes[0], // data_bits, // block->annotations); - BlockNode* n = block.CopyOnWrite(); + SBlockNode* n = block.CopyOnWrite(); OutputSet outputs; for (RewriteRule* rule : rules) { n->body = rule->Apply(std::move(n->body), constraints, &outputs); @@ -744,7 +744,7 @@ class ThreadExtentCollector : public StmtVisitor { } private: - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { if (ffi::Optional warp_execution = GetAnn(op, "warp_execution")) { if (warp_execution.value()->value != 0) { thread_extent_.Set("threadIdx.x", Integer(32)); diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index e69ac30366b1..4c03c155db1a 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -145,10 +145,10 @@ Stmt RewriteWmmaLoad(Stmt stmt) { /*buffer_type=*/kDefault); ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); - Stmt wmma_body = BlockRealize( + Stmt wmma_body = SBlockRealize( /*iter_values=*/{}, /*predicate=*/Bool(true), - Block( + SBlock( /*iter_vars=*/{}, /*reads=*/{BufferRegion(src_buffer, read_region)}, /*writes=*/{BufferRegion(tgt_buffer, write_region)}, @@ -254,43 +254,43 @@ Stmt RewriteWmmaStore(Stmt stmt) { ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); - Stmt wmma_body = BlockRealize( + Stmt wmma_body = SBlockRealize( /*iter_values=*/{}, // /*predicate=*/Bool(true), - Block(/*iter_vars=*/{}, - /*reads=*/{BufferRegion(src_buffer, read_region)}, - /*writes=*/{BufferRegion(tgt_buffer, write_region)}, - /*name_hint=*/"wmma_store", - Evaluate(Call( - /*data=*/runtime::DataType::Handle(), - /*op=*/builtin::tvm_store_matrix_sync(), - {/*0:*/ new_src_buffer->data, - /*1:*/ 16, - /*2:*/ 16, - /*3:*/ 16, - /*4:*/ floordiv(new_src_buffer->elem_offset, 256) + - floordiv(floormod(new_src_buffer->elem_offset, 256), 16), - /*5:*/ - Call( - /*data=*/runtime::DataType::Handle(), - /*op=*/builtin::tvm_access_ptr(), - { - /*0:*/ TypeAnnotation(new_tgt_buffer->dtype), - /*1:*/ new_tgt_buffer->data, - /*2:*/ new_tgt_buffer->elem_offset, - /*3:*/ new_tgt_buffer->strides[0] * 16, - /*4:*/ 2, - }), - /*6:*/ new_tgt_buffer->strides[0], - /*7:*/ StringImm("row_major")})), - /*init=*/std::nullopt, - /*alloc_buffers=*/{}, - /*match_buffers=*/ - { - MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), - MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), - }, - /*annotations=*/{})); + SBlock(/*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_store", + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_store_matrix_sync(), + {/*0:*/ new_src_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_src_buffer->elem_offset, 256) + + floordiv(floormod(new_src_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + { + /*0:*/ TypeAnnotation(new_tgt_buffer->dtype), + /*1:*/ new_tgt_buffer->data, + /*2:*/ new_tgt_buffer->elem_offset, + /*3:*/ new_tgt_buffer->strides[0] * 16, + /*4:*/ 2, + }), + /*6:*/ new_tgt_buffer->strides[0], + /*7:*/ StringImm("row_major")})), + /*init=*/std::nullopt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); for (int i = n - 3; i >= 0; i--) { auto new_loop = ffi::GetRef(loops[i]); new_loop.CopyOnWrite()->body = std::move(wmma_body); @@ -481,36 +481,36 @@ Stmt RewriteMmaStore(Stmt stmt) { // tgt[tx // 4, (tx % 4) * 2 + vec] = src[tx // 4, (tx % 4) * 2 + vec] Var tx = Var("tx"); Var vec = Var("vec"); - Stmt mma_body = BlockRealize( + Stmt mma_body = SBlockRealize( /*iter_values=*/{}, // /*predicate=*/Bool(true), - Block(/*iter_vars=*/{}, - /*reads=*/{BufferRegion(src_buffer, read_region)}, - /*writes=*/{BufferRegion(tgt_buffer, write_region)}, - /*name_hint=*/"mma_store", - AttrStmt( - /*node=*/IterVar( - /*dom=*/Range::FromMinExtent(0, 32), - /*var=*/tx, - /*iter_type=*/IterVarType::kThreadIndex, - /*thread_tag=*/"threadIdx.x"), - /*attr_key=*/"thread_extent", - /*value=*/Integer(32), - /*body=*/ - For(vec, 0, 2, ForKind::kVectorized, - /*body=*/ - BufferStore( - new_tgt_buffer, - BufferLoad(new_src_buffer, {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), - {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}))), - /*init=*/std::nullopt, - /*alloc_buffers=*/{}, - /*match_buffers=*/ - { - MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), - MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), - }, - /*annotations=*/{})); + SBlock(/*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"mma_store", + AttrStmt( + /*node=*/IterVar( + /*dom=*/Range::FromMinExtent(0, 32), + /*var=*/tx, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/"threadIdx.x"), + /*attr_key=*/"thread_extent", + /*value=*/Integer(32), + /*body=*/ + For(vec, 0, 2, ForKind::kVectorized, + /*body=*/ + BufferStore( + new_tgt_buffer, + BufferLoad(new_src_buffer, {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), + {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}))), + /*init=*/std::nullopt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); // Step 3.4. wrap outer loops for (int i = n - 3; i >= 0; i--) { diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 7f19a8992998..31e5cb348ec6 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -111,7 +111,7 @@ class DataTypeVisitor final : public StmtExprVisitor { return StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const BlockNode* op) { + void VisitStmt_(const SBlockNode* op) { for (const IterVar& iter : op->iter_vars) { analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); vextent_[iter->var.as()] = iter->dom->extent.dtype(); diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 779076a89f6f..65bd05975b67 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -35,7 +35,7 @@ namespace tir { class CollectManagedAllocations : public StmtExprVisitor { public: - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { for (const auto& buf : op->alloc_buffers) { managed_allocations.insert(buf->data.get()); } @@ -68,7 +68,7 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { buffer_alloc_recorder_.end(); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const SBlockNode* op) final { for (const Buffer& buffer : op->alloc_buffers) { buffer_alloc_recorder_.push_back(buffer); } @@ -160,7 +160,7 @@ class BufferAllocationLocator : public StmtExprMutator { return node; } - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { ICHECK(!op->init.defined()); ffi::Array alloc_buffers; auto it = alloc_buffers_.find(op); @@ -177,7 +177,7 @@ class BufferAllocationLocator : public StmtExprMutator { buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); } Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); ICHECK(op != nullptr); // No longer consider buffers created by match_buffer inside the block when updating access @@ -193,7 +193,7 @@ class BufferAllocationLocator : public StmtExprMutator { } } - ObjectPtr n = CopyOnWrite(op); + ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); // Erase buffer allocated inside the block from access region. n->reads = RemoveRedundantBufferRegion(n->reads); @@ -208,19 +208,19 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt InjectOpaqueBlock(Stmt body, const ffi::Array& alloc_buffers) { ICHECK(!alloc_buffers.empty()); - Block opaque_block(/*iter_vars=*/{}, - /*reads=*/{}, - /*writes=*/{}, - /*name_hint=*/"", - /*body=*/std::move(body), - /*init=*/std::nullopt, - /*alloc_buffers=*/alloc_buffers); - ObjectPtr n = CopyOnWrite(opaque_block.get()); + SBlock opaque_block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/"", + /*body=*/std::move(body), + /*init=*/std::nullopt, + /*alloc_buffers=*/alloc_buffers); + ObjectPtr n = CopyOnWrite(opaque_block.get()); ffi::Array> access = - GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); + GetSBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); n->reads = access[0]; n->writes = access[1]; - BlockRealize realize({}, Bool(true), Block(n)); + SBlockRealize realize({}, Bool(true), SBlock(n)); return realize; } diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 5b2b5704c5c9..86f1ed64007b 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -48,8 +48,8 @@ class RemoveLayoutRewriteBlock : public StmtMutator { } private: - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtMutator::VisitStmt_(op)); auto it = block->annotations.find(attr::meta_schedule_layout_rewrite_preproc); if (it == block->annotations.end() || !is_one(Downcast((*it).second))) { @@ -158,8 +158,8 @@ class AllocateConstRewrite : public StmtExprMutator { skip_tensor_rewrite_(skip_tensor_rewrite) {} private: - Stmt VisitStmt_(const BlockNode* op) final { - Block block = Downcast(StmtMutator::VisitStmt_(op)); + Stmt VisitStmt_(const SBlockNode* op) final { + SBlock block = Downcast(StmtMutator::VisitStmt_(op)); auto n = CopyOnWrite(block.get()); ffi::Array new_reads; for (auto read_region : op->reads) { diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 69002a9e1d78..ee72245184f6 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -103,7 +103,7 @@ class RenewDefMutator : public StmtExprMutator { STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var); STMT_REGENERATE_VAR_DEF(ForNode, loop_var); - Stmt VisitStmt_(const BlockNode* op) final { + Stmt VisitStmt_(const SBlockNode* op) final { // Step 0. Re-define Itervars ffi::Array iter_vars = op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); @@ -130,7 +130,7 @@ class RenewDefMutator : public StmtExprMutator { op->writes.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); // Step 5. Regenerate block. Since the defs are changed, we need to create a new block - auto n = ffi::make_object(*op); + auto n = ffi::make_object(*op); n->iter_vars = std::move(iter_vars); n->alloc_buffers = std::move(alloc_buffers); n->match_buffers = std::move(match_buffers); diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index 60b6ffda3219..31e249394524 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -43,8 +43,8 @@ namespace tir { */ class MmaBufferLayoutTransformer : public StmtExprMutator { public: - Stmt VisitStmt_(const BlockNode* op) { - Block block = ffi::GetRef(op); + Stmt VisitStmt_(const SBlockNode* op) { + SBlock block = ffi::GetRef(op); auto* n = block.CopyOnWrite(); auto fmutate = [this](const Buffer& buffer) { // m16n8k8.matrix[A/B/C] buffers are composed ofseveral small blocks. Assume the block's diff --git a/tests/cpp/data_type_rewriter_test.cc b/tests/cpp/data_type_rewriter_test.cc index 1eec334344b3..b7575812fe6e 100644 --- a/tests/cpp/data_type_rewriter_test.cc +++ b/tests/cpp/data_type_rewriter_test.cc @@ -73,7 +73,7 @@ TEST(DataTypeLegalizer, IfThenElse) { } TEST(DataTypeLegalizer, Block) { - auto block_node = ffi::make_object(); + auto block_node = ffi::make_object(); auto iter_var_node = ffi::make_object(); iter_var_node->var = Var("i", DataType::Int(32)); iter_var_node->dom = @@ -84,17 +84,17 @@ TEST(DataTypeLegalizer, Block) { block_node->writes = {}; block_node->name_hint = "block"; block_node->body = Evaluate(Integer(0)); - auto block_realize_node = ffi::make_object(); + auto block_realize_node = ffi::make_object(); auto loop_var = Var("i", DataType::Int(32)); block_realize_node->iter_values = {loop_var}; block_realize_node->predicate = const_true(); - block_realize_node->block = Block(block_node); + block_realize_node->block = SBlock(block_node); auto for_node = ffi::make_object(); for_node->loop_var = loop_var; for_node->min = IntImm(DataType::Int(64), 0); for_node->extent = IntImm(DataType::Int(64), 10); for_node->kind = ForKind::kSerial; - for_node->body = BlockRealize(block_realize_node); + for_node->body = SBlockRealize(block_realize_node); Stmt stmt = For(for_node); DataTypeLegalizer legalizer; @@ -104,9 +104,9 @@ TEST(DataTypeLegalizer, Block) { ASSERT_EQ(new_for->loop_var.dtype(), target_dtype); ASSERT_EQ(new_for->min.dtype(), target_dtype); ASSERT_EQ(new_for->extent.dtype(), target_dtype); - const BlockRealizeNode* new_block_realize = new_for->body.as(); + const SBlockRealizeNode* new_block_realize = new_for->body.as(); ASSERT_EQ(new_block_realize->iter_values[0].dtype(), target_dtype); - const BlockNode* new_block = new_block_realize->block.as(); + const SBlockNode* new_block = new_block_realize->block.as(); ASSERT_EQ(new_block->iter_vars[0]->dom->min.dtype(), target_dtype); ASSERT_EQ(new_block->iter_vars[0]->dom->extent.dtype(), target_dtype); ASSERT_EQ(new_block->iter_vars[0]->var.dtype(), target_dtype); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index ec7b4111d240..83418d352b5b 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -60,9 +60,9 @@ TEST(IRF, PreOrderVisit) { using namespace tvm::tir; Stmt init = IfThenElse(const_true(), Evaluate(Integer(0)), Evaluate(Integer(0))); Stmt body = Evaluate(Integer(1)); - Block block(/*iter_vars=*/{}, /*reads=*/{}, - /*writes=*/{}, /*name_hint=*/"block", /*body=*/body, - /*init=*/init); + SBlock block(/*iter_vars=*/{}, /*reads=*/{}, + /*writes=*/{}, /*name_hint=*/"block", /*body=*/body, + /*init=*/init); bool init_visited = false; bool stopped_at_if = true; bool body_visited = false; @@ -169,9 +169,9 @@ TEST(IRF, StmtVisitor) { MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); // construct block and block_realize - Block block = - Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region}); - Stmt block_realize = BlockRealize({}, const_true(), block); + SBlock block = SBlock({}, {buffer_region}, {buffer_region}, "block", body, body, {}, + {match_buffer_region}); + Stmt block_realize = SBlockRealize({}, const_true(), block); v.count = 0; v(block_realize); @@ -296,12 +296,12 @@ TEST(IRF, StmtMutator) { BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); // construct block and block_realize - Block block = - Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region}); - Stmt block_realize = BlockRealize({}, const_true(), block); + SBlock block = SBlock({}, {buffer_region}, {buffer_region}, "block", body, body, {}, + {match_buffer_region}); + Stmt block_realize = SBlockRealize({}, const_true(), block); body = v(std::move(block_realize)); // the body should be changed - Block new_block = body.as()->block; + SBlock new_block = body.as()->block; ICHECK(new_block->body.as()->body.as()->extents[1].same_as(x)); ICHECK(new_block->init.as()->body.as()->extents[1].same_as(x)); ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index 09c9fa13386e..31b6511e0e66 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -28,7 +28,7 @@ def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> N B = T.match_buffer(b, [1, d1, d2]) for i, j, k, l in T.grid(1, d1, d2, d3): - with T.block("reduce"): + with T.sblock("reduce"): vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) with T.init(): B[vi, vj, vk] = 0.0 @@ -41,7 +41,7 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) B = T.match_buffer(b, [1, d1, d2]) for i, j, k, l in T.grid(1, d1, d2, d3): - with T.block("reduce"): + with T.sblock("reduce"): vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) with T.init(): B[vi, vj, vk] = T.float32(-3.4028234663852886e38) @@ -65,7 +65,7 @@ def test_allreduce_sum(dims, target, dev): _, _, _d1, _d2, _d3 = reduce.params mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) sch = tvm.tir.Schedule(mod) - blk = sch.get_block("reduce") + blk = sch.get_sblock("reduce") i, j, k, l = sch.get_loops(blk) sch.bind(i, "blockIdx.x") sch.bind(j, "threadIdx.z") @@ -117,7 +117,7 @@ def test_allreduce_sum_compile(optional_metal_compile_callback): _, _, _d1, _d2, _d3 = reduce.params mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) sch = tvm.tir.Schedule(mod) - blk = sch.get_block("reduce") + blk = sch.get_sblock("reduce") i, j, k, l = sch.get_loops(blk) sch.bind(i, "blockIdx.x") sch.bind(j, "threadIdx.z") @@ -132,7 +132,7 @@ def test_allreduce_max(dims, target, dev): _, _, _d1, _d2, _d3 = reduce_max.params mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3}) sch = tvm.tir.Schedule(mod) - blk = sch.get_block("reduce") + blk = sch.get_sblock("reduce") i, j, k, l = sch.get_loops(blk) sch.bind(i, "blockIdx.x") sch.bind(j, "threadIdx.z") diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index fd2f598c924e..5650aabd58c3 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -27,10 +27,10 @@ def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> No tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_local = T.alloc_buffer((32), "float32", scope="local") - with T.block(): + with T.sblock(): T.reads(A[0:16]) T.writes(A_local[0:32]) A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32") diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index ef425dbf73e0..d3e1cd61db6f 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -46,7 +46,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for i in range(vector_length): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(vector_length, i) T.reads(A[v_i], B[v_i]) T.writes(C[v_i]) @@ -55,7 +55,7 @@ def add( ) sch = tvm.tir.Schedule(add) - block = sch.get_block("C") + block = sch.get_sblock("C") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 32]) sch.bind(bx, "blockIdx.x") @@ -127,7 +127,7 @@ def shuffle_reinterpret( ): T.func_attr({"tir.noalias": True}) for i in range(n): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(n, i) T.reads(A[v_i]) T.writes(B[v_i]) @@ -156,7 +156,7 @@ def scalar_reinterpret( ): T.func_attr({"tir.noalias": True}) for i in range(n): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(n, i) T.reads(A[v_i]) T.writes(B[v_i]) @@ -173,7 +173,7 @@ def scalar_reinterpret( func = shuffle_reinterpret if func_type == "shuffle" else scalar_reinterpret sch = tvm.tir.Schedule(func) - block = sch.get_block("C") + block = sch.get_sblock("C") b = sch.get_loops(block) bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) sch.bind(bx, "blockIdx.x") diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index 4ea938cad8ad..9688556b99aa 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -55,14 +55,14 @@ def add( ): T.func_attr({"tir.noalias": True}) for i in range(64): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(64, i) T.reads(A[v_i], B[v_i]) T.writes(C[v_i]) C[v_i] = T.Cast(dtype, T.Cast("float16", A[v_i]) + T.Cast("float16", B[v_i])) sch = tvm.tir.Schedule(add) - block = sch.get_block("C") + block = sch.get_sblock("C") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 32]) sch.bind(bx, "blockIdx.x") @@ -103,27 +103,27 @@ def add( B: T.Buffer((length,), native_dtype), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i in range(length): - with T.block("R"): + with T.sblock("R"): v_i = T.axis.spatial(length, i) T.reads(A[v_i]) T.writes(R[v_i]) R[v_i] = T.reinterpret(packed_dtype, A[v_i]) for i in range(length): - with T.block("B"): + with T.sblock("B"): v_i = T.axis.spatial(length, i) T.reads(R[v_i]) T.writes(B[v_i]) B[v_i] = T.reinterpret(native_dtype, R[v_i]) sch = tvm.tir.Schedule(add) - block = sch.get_block("R") + block = sch.get_sblock("R") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 32]) sch.bind(bx, "blockIdx.x") sch.bind(tx, "threadIdx.x") - block = sch.get_block("B") + block = sch.get_sblock("B") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 32]) sch.bind(bx, "blockIdx.x") @@ -171,9 +171,9 @@ def add( C: T.Buffer((vector_length,), native_dtype), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i in range(vector_length): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(vector_length, i) T.reads(A[v_i], B[v_i]) T.writes(C[v_i]) @@ -182,7 +182,7 @@ def add( ) sch = tvm.tir.Schedule(add) - block = sch.get_block("C") + block = sch.get_sblock("C") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 32]) sch.bind(bx, "blockIdx.x") @@ -228,11 +228,11 @@ def test_half_broadcast(bcast_length): @T.prim_func def vector_broadcast(a: T.Buffer((), dtype), vec: T.Buffer((bcast_length,), dtype)): for t in range(1): - with T.block("broadcast"): + with T.sblock("broadcast"): vec[0:bcast_length] = T.broadcast(a[()], bcast_length) sch = tvm.tir.Schedule(vector_broadcast) - block = sch.get_block("broadcast") + block = sch.get_sblock("broadcast") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 1]) sch.bind(bx, "blockIdx.x") @@ -305,16 +305,16 @@ def add( C: T.Buffer((length,), vec_dtype), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i in range(length): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(length, i) T.reads(A[v_i], B[v_i]) T.writes(C[v_i]) C[v_i] = A[v_i] + B[v_i] sch = tvm.tir.Schedule(add) - block = sch.get_block("C") + block = sch.get_sblock("C") b = sch.get_loops(block) bx, tx = sch.split(b[0], factors=[None, 32]) sch.bind(bx, "blockIdx.x") @@ -576,12 +576,12 @@ def quant_pack( storage_dtype, ), ): - # with T.block("root"): + # with T.sblock("root"): # test = T.alloc_buffer(1, dtype=vec_model_dtype, scope="local") for i0, i1 in T.grid( T.int64(weight_shape[0]), T.int64(weight_shape[1] // vector_length) ): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads( A[v_i0, v_i1 : v_i1 + vector_length], @@ -623,9 +623,9 @@ def dequant( dequantize: T.Buffer(out_shape, model_dtype), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(packed_weight_shape[0]), T.int64(packed_weight_shape[1])): - with T.block("dequantize"): + with T.sblock("dequantize"): v_i0 = T.axis.spatial(T.int64(packed_weight_shape[0]), i0) v_i1 = T.axis.spatial(T.int64(packed_weight_shape[1]), i1) T.reads( @@ -884,7 +884,7 @@ def moe_dequantize_gemv( num_seq = T.int64() x = T.match_buffer(x_handle, (num_seq, reduce_size), "float16") for expert_id in T.thread_binding(2, thread="blockIdx.y"): - with T.block("gemv_o"): + with T.sblock("gemv_o"): e = T.axis.spatial(2, expert_id) T.reads( w[indptr[0, e], 0:spatial_size, 0:reduce_size], @@ -895,13 +895,13 @@ def moe_dequantize_gemv( T.writes(o[e, 0:spatial_size]) y = T.alloc_buffer((spatial_size, reduce_size), "float16") for i1, i2 in T.grid(spatial_size, reduce_size): - with T.block("dequantize"): + with T.sblock("dequantize"): i, j = T.axis.remap("SS", [i1, i2]) T.reads(w[indptr[0, e], i, j], indptr[0, e], scale[0]) T.writes(y[i, j]) y[i, j] = T.Cast("float16", w[indptr[0, e], i, j]) * scale[0] for i1, i2 in T.grid(spatial_size, reduce_size): - with T.block("gemv"): + with T.sblock("gemv"): i, j = T.axis.remap("SR", [i1, i2]) T.reads(x[e, j], y[i, j]) T.writes(o[e, i]) @@ -983,12 +983,12 @@ def func_vectorize( C: T.Buffer((128,), dtype), ) -> None: for i in T.serial(128): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0) sch = tir.Schedule(func_vectorize) - (l,) = sch.get_loops(sch.get_block("compute")) + (l,) = sch.get_loops(sch.get_sblock("compute")) lo, li = sch.split(l, [None, vec_length]) sch.bind(lo, "threadIdx.x") sch.vectorize(li) diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py index b897d50b41c7..d82cecee4e8e 100644 --- a/tests/python/codegen/test_target_codegen_device.py +++ b/tests/python/codegen/test_target_codegen_device.py @@ -36,7 +36,7 @@ def test_large_uint_imm(): sch = tir.Schedule(mod) # Get block and loop - block = sch.get_block("A") + block = sch.get_sblock("A") loop = sch.get_loops(block)[0] # Split and bind @@ -71,8 +71,8 @@ def test_add_pipeline(): sch = tir.Schedule(mod) # Get blocks and loops - c_block = sch.get_block("C") - d_block = sch.get_block("D") + c_block = sch.get_sblock("C") + d_block = sch.get_sblock("D") c_loop = sch.get_loops(c_block)[0] d_loop = sch.get_loops(d_block)[0] diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py index b115fddb57f7..844dbbc129fb 100644 --- a/tests/python/codegen/test_target_codegen_gpu_common.py +++ b/tests/python/codegen/test_target_codegen_gpu_common.py @@ -38,7 +38,7 @@ def run_test(tvm_intrin, np_func, dtype): B = te.compute(A.shape, lambda *i: tvm_intrin(A(*i)), name="B") func = te.create_prim_func([A, B]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("B")) + (x,) = sch.get_loops(sch.get_sblock("B")) sch.bind(x, "threadIdx.x") f = tvm.compile(sch.mod, target=target) a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(A.dtype), dev) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 88b791d1aa52..0597d647f879 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -137,8 +137,8 @@ def test_llvm_multi_parallel(): sch = tir.Schedule(mod) # Get blocks and loops - c_block = sch.get_block("C") - b_block = sch.get_block("B") + c_block = sch.get_sblock("C") + b_block = sch.get_sblock("B") c_loop = sch.get_loops(c_block)[0] # Split and parallelize @@ -180,7 +180,7 @@ def check_llvm(nn, base): sch = tir.Schedule(mod) # Get block and loop - block = sch.get_block("C") + block = sch.get_sblock("C") loop = sch.get_loops(block)[0] # Split and parallelize @@ -216,7 +216,7 @@ def test_llvm_vadd_pipeline(): sch = tir.Schedule(mod) # Get block and loop - block = sch.get_block("C") + block = sch.get_sblock("C") loop = sch.get_loops(block)[0] # Split the loop @@ -245,7 +245,7 @@ def check_llvm(nn, base, stride): sch = tir.Schedule(mod) # Get block and loops - block = sch.get_block("C") + block = sch.get_sblock("C") i_loop, j_loop = sch.get_loops(block) # Split and parallelize @@ -444,7 +444,7 @@ def test_alignment(): sch = tir.Schedule(mod) # Get block and loop - block = sch.get_block("B") + block = sch.get_sblock("B") loop = sch.get_loops(block)[0] # Split and vectorize @@ -679,7 +679,7 @@ def test_dwarf_debug_information(): sch = tir.Schedule(mod) # Get block and loop - block = sch.get_block("C") + block = sch.get_sblock("C") loop = sch.get_loops(block)[0] # Split and parallelize @@ -769,7 +769,7 @@ def dotest(do_vectorize): sch = tir.Schedule(mod) # Get block and loop - block = sch.get_block("D") + block = sch.get_sblock("D") loop = sch.get_loops(block)[0] # Apply vectorization if requested diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index b969f0e0b911..061fe69947e7 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -34,7 +34,7 @@ def check_inf_nan(dev, n, value, dtype): C = te.compute((n,), lambda i: inf_value, name="C") prim_func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(prim_func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.compile(sch.mod, target=target) a = tvm.runtime.empty((n,), A.dtype, dev) @@ -62,7 +62,7 @@ def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")): T.func_attr({"global_symbol": "main"}) for i0_1 in T.thread_binding(3, thread="threadIdx.x"): for i0_0 in T.vectorized(2): - with T.block("block"): + with T.sblock("block"): vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1) B[vi0] = A[vi0 // 3, vi0 % 3] @@ -87,7 +87,7 @@ def check_erf(dev, n, dtype): C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.compile(sch.mod, target=target) a = tvm.runtime.empty((n,), A.dtype, dev) @@ -112,7 +112,7 @@ class IRModule: def main(A: T.Buffer((1, 2), "int32")): T.func_attr({"global_symbol": "main"}) for i in T.thread_binding(1, thread="threadIdx.x"): - with T.block("block"): + with T.sblock("block"): tx = T.axis.spatial(1, i) r = T.ramp(tx, 3, 2) A[0, T.ramp(0, 1, 2)] = r @@ -134,7 +134,7 @@ def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")): T.func_attr({"global_symbol": "main"}) for i0_1 in T.thread_binding(3, thread="threadIdx.x"): for i0_0 in T.vectorized(2): - with T.block("block"): + with T.sblock("block"): vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1) B[vi0] = T.Select((vi0 % 2) == 0, A[vi0], T.float32(0)) @@ -156,7 +156,7 @@ def test_vectorized_uint8(): def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): for i in T.thread_binding(4, thread="threadIdx.x"): for j in T.vectorized(4): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial(16, i * 4 + j) B[vi] = T.Cast("float32", A[vi]) @@ -176,7 +176,7 @@ def test_func_with_trailing_pod_params(): @T.prim_func def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32): for i in T.thread_binding(16, thread="threadIdx.x"): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial(16, i) B[vi] = A[vi] + x diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index 3e0fe7e31e50..b1534cfc1e83 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -36,7 +36,7 @@ def check_if_then_else(dev, n, dtype): func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) a = tvm.runtime.empty((n,), A.dtype, dev) @@ -53,7 +53,7 @@ def check_select(dev, n, dtype): C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C") func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) @@ -83,7 +83,7 @@ def check_inf_nan(dev, n, value, dtype): C = te.compute((n,), lambda i: inf_value, name="C") func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) a = tvm.runtime.empty((n,), A.dtype, dev) @@ -111,7 +111,7 @@ def check_max(dev, n, dtype): C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C") func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) @@ -136,7 +136,7 @@ def check_erf(dev, n, dtype): C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") func = te.create_prim_func([A, C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) @@ -173,7 +173,7 @@ def check_type_casting(ctx, n, dtype): # NOTE: test simple convert pattern func = te.create_prim_func([C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) tx, vx = sch.split(x, factors=[None, block_size]) sch.bind(tx, "threadIdx.x") sch.vectorize(vx) @@ -207,7 +207,7 @@ def _check(target, n, dtype): ) func = te.create_prim_func([C]) sch = tvm.tir.Schedule(func) - (x,) = sch.get_loops(sch.get_block("C")) + (x,) = sch.get_loops(sch.get_sblock("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index 9e2d18e109f9..008c7c2108c0 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -59,7 +59,7 @@ def rvv_with_vscale(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle): A = T.match_buffer(A_handle, (8,), dtype="float32", align=4, offset_factor=1) B = T.match_buffer(B_handle, (4, 8), dtype="float32", align=4, offset_factor=1, strides=[8, 1]) C = T.match_buffer(C_handle, (4,), dtype="float32", align=4, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(A[0:8], B[0:4, 0:8]) zero = T.call_llvm_intrin("float32xvscalex2", "llvm.riscv.vfmv.v.f", T.Broadcast(T.float32(0.0), T.vscale() * 2), C[0], T.uint64(1)) vec_A = T.call_llvm_intrin("float32xvscalex4", "llvm.riscv.vle", T.Broadcast(T.float32(0.0), T.vscale() * 4), T.tvm_access_ptr(T.type_annotation("float32"), A.data, 0, 8, 1), T.int64(8)) diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index cdd84fc57ae1..42dd7cdf9c25 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -98,7 +98,7 @@ def func( for bx in T.thread_binding(1, thread="blockIdx.x"): for tx in T.thread_binding(32, thread="threadIdx.x"): - with T.block("test"): + with T.sblock("test"): A_local = T.alloc_buffer((1,), "float32", scope="local") mask = T.alloc_buffer((1,), "uint32", scope="local") t0 = T.alloc_buffer((1,), "float32", scope="local") @@ -126,7 +126,7 @@ def func( for bx in T.thread_binding(1, thread="blockIdx.x"): for tx in T.thread_binding(1, thread="threadIdx.x"): - with T.block("test"): + with T.sblock("test"): for i in T.vectorized(0, 4): B[i] = T.exp2(A[i]) diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index cf7b46692661..9e6dc52a4961 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -416,7 +416,7 @@ def test_negative_operand_divmod(target, dev): @T.prim_func def func(A: T.Buffer((N, 2), "int32")): for i in T.serial(N): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.spatial(N, i) A[v_i, 0] = T.floordiv(v_i - offset, divisor) A[v_i, 1] = T.floormod(v_i - offset, divisor) @@ -464,7 +464,7 @@ def get_matmul(m, n, k, out_dtype="float32"): M, N, K = 16, 16, 32 func = get_matmul(M, N, K, out_dtype) sch = Schedule(func) - block = sch.get_block("compute") + block = sch.get_sblock("compute") i, j, k = sch.get_loops(block) i_outer, i_inner = sch.split(i, factors=[None, 16]) @@ -596,7 +596,7 @@ def run_test(tvm_intrin, np_func): mod = te.create_prim_func([A, B]) sch = tir.Schedule(mod) - block = sch.get_block("B") + block = sch.get_sblock("B") loop = sch.get_loops(block)[0] bx, tx = sch.split(loop, factors=[None, 64]) sch.bind(bx, "blockIdx.x") diff --git a/tests/python/contrib/test_android/test_meta_schedule.py b/tests/python/contrib/test_android/test_meta_schedule.py index eac5fab30357..7907a1933d24 100644 --- a/tests/python/contrib/test_android/test_meta_schedule.py +++ b/tests/python/contrib/test_android/test_meta_schedule.py @@ -35,7 +35,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 4718fa7e0671..d035b5d6feed 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -30,7 +30,7 @@ def ceildiv(o, d): # defines inner block shape: 8h8w32c -def get_block_shape(): +def get_sblock_shape(): return 8, 8, 32 @@ -44,7 +44,7 @@ def get_filter_block_shape(): def get_packed_shape(logical_shape_nhwc): assert len(logical_shape_nhwc) == 4 physical_shape_nhwc8h8w32c = [logical_shape_nhwc[0]] - block_shape = get_block_shape() + block_shape = get_sblock_shape() off_h, off_w, off_c = block_shape physical_shape_nhwc8h8w32c.append(ceildiv(logical_shape_nhwc[1], off_h)) physical_shape_nhwc8h8w32c.append(ceildiv(logical_shape_nhwc[2], off_w)) @@ -158,7 +158,7 @@ def conv2d_verify(output, ref_output, dtype): def conv2d_compute(X, filt, pad, stride, dilation): """Define conv2d compute""" - block_shape = get_block_shape() + block_shape = get_sblock_shape() block_H, block_W, block_C = block_shape filter_c_io, _, filter_c_ii = get_filter_block_shape() filter_c_i = filter_c_io * filter_c_ii diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index e5fc783510ac..c22eeb34ad4a 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -43,7 +43,7 @@ def conv2d_async_non_contig( # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") p0_global_vtcm = T.alloc_buffer( [T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)], dtype="uint8", @@ -64,7 +64,7 @@ def conv2d_async_non_contig( }, ): for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(1600)): - with T.block("p0_global.vtcm"): + with T.sblock("p0_global.vtcm"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(1), T.int64(0)) v2 = T.axis.spatial( @@ -79,7 +79,7 @@ def conv2d_async_non_contig( T.writes(p0_global_vtcm[v0, v1, v2, v3, v4]) p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4] for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(1152)): - with T.block("fused_constant_global.vtcm"): + with T.sblock("fused_constant_global.vtcm"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(1), T.int64(0)) v2 = T.axis.spatial( @@ -100,7 +100,7 @@ def conv2d_async_non_contig( ] for oh_1, ow_1 in T.grid(T.int64(3), T.int64(6)): for oh_2_init, ow_2_init in T.grid(T.int64(6), T.int64(3)): - with T.block("conv2d_NCHWc_int8_o_init"): + with T.sblock("conv2d_NCHWc_int8_o_init"): v_n = T.axis.spatial(T.int64(1), T.int64(0)) v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0)) v_oh = T.axis.spatial( @@ -114,7 +114,7 @@ def conv2d_async_non_contig( conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)] ) for oc_block_1 in T.vectorized(T.int64(32)): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1) T.reads() T.writes( @@ -128,7 +128,7 @@ def conv2d_async_non_contig( for kh_1, kw_1, oh_2, ow_2 in T.grid( T.int64(3), T.int64(3), T.int64(6), T.int64(3) ): - with T.block("conv2d_NCHWc_int8_o_update"): + with T.sblock("conv2d_NCHWc_int8_o_update"): v_n = T.axis.spatial(T.int64(1), T.int64(0)) v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0)) v_oh = T.axis.spatial( @@ -226,7 +226,7 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8") c_buffer = T.match_buffer(c_output, out_shape, dtype="int32") for n, index_0 in T.grid(size_a, size_w): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_index, vi_index = T.axis.remap("SR", [n, index_0]) T.reads( a_buffer[vn_index, 0:VRMPY_SIZE_B], @@ -301,7 +301,7 @@ def get_fake_conv_vtcm_schedule(size_a, size_w, blocks=2): """Generate fake conv schedule with VTCM.""" sch = conv_approximation(size_a, size_w) - compute_block = sch.get_block("c_buffer") + compute_block = sch.get_sblock("c_buffer") sch.cache_read(compute_block, 1, "global.vtcm") n = sch.get_loops(compute_block)[0] @@ -322,7 +322,7 @@ def get_multi_input_fake_conv_vtcm_schedule(size_a, size_w, blocks=2): """Generate multi input fake Conv using VTCM.""" sch = conv_approximation(size_a, size_w) - compute_block = sch.get_block("c_buffer") + compute_block = sch.get_sblock("c_buffer") n = sch.get_loops(compute_block)[0] n_outer, _ = sch.split(n, [blocks, None]) @@ -425,7 +425,7 @@ def test_loading_vtcm_for_vrmpy( ) sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_block("c_buffer"))[0] + n = sch.get_loops(sch.get_sblock("c_buffer"))[0] sch.annotate(n, "software_pipeline_stage", [0, 1, 2]) sch.annotate(n, "software_pipeline_order", [0, 1, 2]) sch.annotate(n, "software_pipeline_async_stages", [0]) @@ -440,7 +440,7 @@ def test_loading_vtcm_for_vrmpy( ) sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_block("c_buffer"))[0] + n = sch.get_loops(sch.get_sblock("c_buffer"))[0] sch.annotate(n, "software_pipeline_stage", [0, 1, 2]) sch.annotate(n, "software_pipeline_order", [0, 1, 2]) sch.annotate(n, "software_pipeline_async_stages", [0, 2]) @@ -455,7 +455,7 @@ def test_loading_vtcm_for_vrmpy( ) sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_block("c_buffer"))[0] + n = sch.get_loops(sch.get_sblock("c_buffer"))[0] sch.annotate(n, "software_pipeline_stage", [0, 3, 6]) sch.annotate(n, "software_pipeline_order", [0, 1, 2]) sch.annotate(n, "software_pipeline_async_stages", [0, 6]) @@ -470,7 +470,7 @@ def test_loading_vtcm_for_vrmpy( ) sch = get_multi_input_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_block("c_buffer"))[0] + n = sch.get_loops(sch.get_sblock("c_buffer"))[0] sch.annotate(n, "software_pipeline_stage", [0, 0, 1, 2]) sch.annotate(n, "software_pipeline_order", [0, 1, 2, 3]) sch.annotate(n, "software_pipeline_async_stages", [0, 2]) @@ -485,7 +485,7 @@ def test_loading_vtcm_for_vrmpy( ) sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_block("c_buffer"))[0] + n = sch.get_loops(sch.get_sblock("c_buffer"))[0] sch.annotate(n, "software_pipeline_stage", [0, 1, 2]) sch.annotate(n, "software_pipeline_order", [0, 1, 2]) sch.annotate(n, "software_pipeline_async_stages", [2]) @@ -542,12 +542,12 @@ def main( # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") conv2d_nchwc_int8 = T.alloc_buffer([1, 2, 112, 112, 32], dtype="int32", scope="global.vtcm") p0_global_vtcm = T.alloc_buffer([1, 1, 230, 230, 4], dtype="uint8", scope="global.vtcm") p1_global_vtcm = T.alloc_buffer([2, 1, 7, 7, 1, 32, 4], dtype="int8", scope="global.vtcm") for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(2, 1, 7, 7, 1, 32, 4): - with T.block("p1_global.vtcm"): + with T.sblock("p1_global.vtcm"): v0_ind, v1_ind, v2_ind, v3_ind, v4_ind, v5_ind, v6_ind = T.axis.remap( "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6] ) @@ -558,7 +558,7 @@ def main( ] for p_outer in T.serial(4): for index_0 in T.serial(55876): - with T.block("p0_global.vtcm"): + with T.sblock("p0_global.vtcm"): v0_ind = T.axis.spatial(1, 0) v1_ind = T.axis.spatial(1, 0) v2_ind = T.axis.spatial(230, p_outer * 56 + index_0 // 916) @@ -571,7 +571,7 @@ def main( ] for index_0 in T.parallel(28): for index_1, index_2, index_3 in T.grid(2, 14, 8): - with T.block("conv2d_NCHWc_int8_o_init"): + with T.sblock("conv2d_NCHWc_int8_o_init"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(2, index_1) o_height = T.axis.spatial( @@ -582,7 +582,7 @@ def main( T.reads() T.writes(conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32]) for i4_1 in T.vectorized(32): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(32, i4_1) T.reads() T.writes( @@ -594,7 +594,7 @@ def main( n, oc_chunk, o_height, o_width, oc_block_i_init ] = 0 for i1_1, i5_1, i6_1, i2_2, i3_2 in T.grid(2, 7, 7, 14, 8): - with T.block("conv2d_NCHWc_int8_o_update"): + with T.sblock("conv2d_NCHWc_int8_o_update"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(2, i1_1) o_height = T.axis.spatial(112, (p_outer * 28 + index_0) // 14 * 14 + i2_2) @@ -660,7 +660,7 @@ def main( dtype="int32x32", ) for index_0 in T.serial(200704): - with T.block("conv2d_nchwc_int8.vtcm"): + with T.sblock("conv2d_nchwc_int8.vtcm"): ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(2, index_0 % 7168 // 3584) ax2_1 = T.axis.spatial( @@ -694,7 +694,7 @@ def main( T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # buffer definition # body - # with T.block("root") + # with T.sblock("root") conv2d_nchwc_int8 = T.alloc_buffer([1, 2, 112, 112, 32], dtype="int32") for i0_0_i1_0_i2_0_i3_0_fused in T.parallel( 112, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} @@ -703,7 +703,7 @@ def main( for i1_1_init, i2_1_init, i3_1_init, i1_2_init, i2_2_init, i3_2_init in T.grid( 2, 1, 1, 1, 14, 8 ): - with T.block("conv2d_NCHWc_int8_o_init"): + with T.sblock("conv2d_NCHWc_int8_o_init"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(2, i1_1_init + i1_2_init) o_height = T.axis.spatial( @@ -716,7 +716,7 @@ def main( T.reads() T.writes(conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32]) for i4_1 in T.vectorized(32): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(32, i4_1) T.reads() T.writes( @@ -747,7 +747,7 @@ def main( i3_2, i4_0_2, # pylint: disable=unused-variable ) in T.grid(1, 2, 1, 1, 1, 7, 7, 1, 1, 1, 1, 1, 14, 8, 1): - with T.block("conv2d_NCHWc_int8_o_update"): + with T.sblock("conv2d_NCHWc_int8_o_update"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(2, i1_1 + i1_2) o_height = T.axis.spatial( @@ -815,7 +815,7 @@ def main( ) for ax0, ax1, ax2, ax3 in T.grid(1, 2, 14, 8): for ax4_fused in T.vectorized(32): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): ax0_1, ax1_1 = T.axis.remap("SS", [ax0, ax1]) ax2_1 = T.axis.spatial( 112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + ax2 @@ -845,7 +845,7 @@ def test_meta(hexagon_session): base_runtime = evaluate(hexagon_session, sch, a_data, w_data, c_data) sch = tvm.tir.Schedule(ModulePipelined) - compute_block = sch.get_block("conv2d_NCHWc_int8_o_update") + compute_block = sch.get_sblock("conv2d_NCHWc_int8_o_update") outer = sch.get_loops(compute_block)[0] unscheduled_vtcm_runtime = evaluate( @@ -853,7 +853,7 @@ def test_meta(hexagon_session): ) sch = tvm.tir.Schedule(ModulePipelined) - compute_block = sch.get_block("conv2d_NCHWc_int8_o_update") + compute_block = sch.get_sblock("conv2d_NCHWc_int8_o_update") outer = sch.get_loops(compute_block)[0] sch.annotate(outer, "software_pipeline_stage", [0, 1, 2]) diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py index 1592bd020fd6..e0ded63c1f1d 100644 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py @@ -43,7 +43,7 @@ def compute_add_in_vtcm(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (m,), data_type, scope="global.vtcm") C = T.match_buffer(c, (m,), data_type, scope="global.vtcm") for ax0 in T.grid(m): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.remap("S", [ax0]) T.reads(A[v_ax0], B[v_ax0]) T.writes(C[v_ax0]) diff --git a/tests/python/contrib/test_hexagon/test_memory_alloc.py b/tests/python/contrib/test_hexagon/test_memory_alloc.py index ae086a2b4a02..f332df493b5f 100644 --- a/tests/python/contrib/test_hexagon/test_memory_alloc.py +++ b/tests/python/contrib/test_hexagon/test_memory_alloc.py @@ -35,7 +35,7 @@ def elwise(a: T.handle, b: T.handle): b_buffer = T.match_buffer(b, shape, dtype=dtype, axis_separators=axis_separators) for i, j in T.grid(dim0, dim1): - with T.block("compute"): + with T.sblock("compute"): b_buffer[i, j] = a_buffer[i, j] * T.cast(2, dtype=dtype) return elwise diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index 5d9f4128d172..d681bb1ea1c4 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -55,7 +55,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore b_buffer = T.match_buffer(b, (16, 16), "float32") c_buffer = T.match_buffer(c, (16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi_axis, vj_axis, vk_axis = T.axis.remap("SSR", [i, j, k]) with T.init(): c_buffer[vi_axis, vj_axis] = 0.0 # type: ignore @@ -201,13 +201,13 @@ def test_vrmpy_dense(hexagon_launcher): if not do_tune: ir_module = tvm.IRModule({"main": workload}) sch = tvm.tir.Schedule(ir_module) - block = sch.get_block("compute") + block = sch.get_sblock("compute") schedule_dense(sch, block, m_size, do_tune) else: with tempfile.TemporaryDirectory() as work_dir: def schedule_dense_for_tune(sch): - block = sch.get_block("compute") + block = sch.get_sblock("compute") return schedule_dense(sch, block, None, True) target = get_hexagon_target("v69") @@ -251,19 +251,19 @@ def main( # type: ignore 512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} ): for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1): - with T.block("compute_o_init"): + with T.sblock("compute_o_init"): i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init) j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init) T.reads() T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) # type: ignore for i1_1 in T.vectorized(32): - with T.block("compute_init"): + with T.sblock("compute_init"): j_i_init = T.axis.spatial(32, i1_1) T.reads() T.writes(compute[i, j_o * 32 + j_i_init]) compute[i, j_o * 32 + j_i_init] = 0 # type: ignore for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1): - with T.block("compute_o_update"): + with T.sblock("compute_o_update"): i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2) j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1) k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 6abfa812175f..6d55850168ee 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -81,7 +81,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") c_buffer = T.match_buffer(c, [operations, 128], dtype="int16") for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vmpybusv.128B"), @@ -103,7 +103,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") c_buffer = T.match_buffer(c, [operations, 128], dtype="int16") for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vaddubh.128B"), @@ -125,7 +125,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") c_buffer = T.match_buffer(c, [operations, 32], dtype="int32") for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), @@ -215,7 +215,7 @@ def test( ) sch = tvm.tir.Schedule(operator_producer(operation_count)) - block = sch.get_block("c_buffer") + block = sch.get_sblock("c_buffer") b = sch.get_loops(block) b_output, _ = sch.split(b[0], factors=[split_factor, None]) sch.parallel(b_output) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index ceabc6355732..b38ba47d9b76 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -58,7 +58,7 @@ def apply_unroll_vectorize(sch, blocks, unroll_split, vector_split): def apply_vrmpy_parallelization(sch): - block = sch.get_block("c_buffer") + block = sch.get_sblock("c_buffer") b = sch.get_loops(block) b_outer, _ = sch.split(b[0], factors=[4, None]) sch.parallel(b_outer) @@ -66,7 +66,7 @@ def apply_vrmpy_parallelization(sch): def apply_vtcm_cache_read_write(sch): - block = sch.get_block("c_buffer") + block = sch.get_sblock("c_buffer") sch.cache_read(block, 0, "global.vtcm") sch.cache_read(block, 1, "global.vtcm") sch.cache_write(block, 0, "global.vtcm") @@ -83,7 +83,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128) c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128) for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), @@ -119,7 +119,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, scope="global.vtcm" ) for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), @@ -154,15 +154,15 @@ def operator( c_v, [out_size], dtype="int32", align=128, scope="global.vtcm" ) for n, i in T.grid(operations, 128): - with T.block("a_buffer_global.vtcm"): + with T.sblock("a_buffer_global.vtcm"): vn_ind, vi_index = T.axis.remap("SS", [n, i]) a_global_vtcm[vn_ind * 128 + vi_index] = a_buffer[vn_ind, vi_index] for n, i in T.grid(operations, 128): - with T.block("b_buffer_global.vtcm"): + with T.sblock("b_buffer_global.vtcm"): vn_ind, vi_index = T.axis.remap("SS", [n, i]) b_global_vtcm[vn_ind * 128 + vi_index] = b_buffer[vn_ind, vi_index] for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), @@ -177,7 +177,7 @@ def operator( dtype="int32x32", ) for n, i in T.grid(operations, 32): - with T.block("c_buffer_global.vtcm"): + with T.sblock("c_buffer_global.vtcm"): vn_ind, vi_index = T.axis.remap("SS", [n, i]) c_buffer[vn_ind, vi_index] = c_global_vtcm[vn_ind * 32 + vi_index] @@ -260,7 +260,7 @@ def operator( ) ) for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), @@ -446,12 +446,12 @@ def test_loading_vtcm_for_vrmpy( sch = apply_vrmpy_parallelization(sch) sch = apply_unroll_vectorize( sch, - [sch.get_block("a_buffer_global.vtcm"), sch.get_block("b_buffer_global.vtcm")], + [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], unroll_split, vector_split, ) sch = apply_unroll_vectorize( - sch, [sch.get_block("c_buffer_global.vtcm")], unroll_split, c_vector_split_unallocated + sch, [sch.get_sblock("c_buffer_global.vtcm")], unroll_split, c_vector_split_unallocated ) vectorized_runtime, result = setup_and_run( hexagon_session, sch, input_a, input_b, input_c, operations @@ -464,14 +464,14 @@ def test_loading_vtcm_for_vrmpy( sch = apply_vrmpy_parallelization(sch) sch = apply_parallel_unroll_vectorize( sch, - [sch.get_block("a_buffer_global.vtcm"), sch.get_block("b_buffer_global.vtcm")], + [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], outer_split, unroll_split, vector_split, ) sch = apply_parallel_unroll_vectorize( sch, - [sch.get_block("c_buffer_global.vtcm")], + [sch.get_sblock("c_buffer_global.vtcm")], outer_split, unroll_split, c_vector_split_unallocated, @@ -486,12 +486,12 @@ def test_loading_vtcm_for_vrmpy( sch = apply_vrmpy_parallelization(sch) sch = apply_unroll_vectorize( sch, - [sch.get_block("a_buffer_global.vtcm"), sch.get_block("b_buffer_global.vtcm")], + [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], unroll_split, vector_split, ) sch = apply_unroll_vectorize( - sch, [sch.get_block("c_buffer_global.vtcm")], unroll_split, c_vector_split + sch, [sch.get_sblock("c_buffer_global.vtcm")], unroll_split, c_vector_split ) preallocated_vectorized_runtime, result = setup_and_run_preallocated( hexagon_session, sch, input_a, input_b, input_c, operations @@ -504,13 +504,13 @@ def test_loading_vtcm_for_vrmpy( sch = apply_vrmpy_parallelization(sch) sch = apply_parallel_unroll_vectorize( sch, - [sch.get_block("a_buffer_global.vtcm"), sch.get_block("b_buffer_global.vtcm")], + [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], outer_split, unroll_split, vector_split, ) sch = apply_parallel_unroll_vectorize( - sch, [sch.get_block("c_buffer_global.vtcm")], outer_split, unroll_split, c_vector_split + sch, [sch.get_sblock("c_buffer_global.vtcm")], outer_split, unroll_split, c_vector_split ) prealloc_vector_parallelized, result = setup_and_run_preallocated( hexagon_session, sch, input_a, input_b, input_c, operations diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index 60731a8febe0..64150f2e5926 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -41,7 +41,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations], dtype="float64") c_buffer = T.match_buffer(c, [operations], dtype="float64") for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind] = a_buffer[vn_ind] + b_buffer[vn_ind] @@ -58,7 +58,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations], dtype="float64") c_buffer = T.match_buffer(c, [operations], dtype="float64") for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind] = a_buffer[vn_ind] * b_buffer[vn_ind] @@ -75,7 +75,7 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: b_buffer = T.match_buffer(b, [operations], dtype="float64") c_buffer = T.match_buffer(c, [operations], dtype="float64") for n in T.grid(operations): - with T.block("c_buffer"): + with T.sblock("c_buffer"): vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind] = a_buffer[vn_ind] - b_buffer[vn_ind] @@ -153,7 +153,7 @@ def test_add( single_thread_runtime = evaluate(hexagon_session, operations, expected_output_producer, sch) sch = tvm.tir.Schedule(operator_producer(operations)) - block = sch.get_block("c_buffer") + block = sch.get_sblock("c_buffer") b = sch.get_loops(block) b_output, _ = sch.split(b[0], factors=[split_factor, None]) sch.parallel(b_output) diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py index 8a56e91581cb..277d9ed75da4 100644 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py @@ -40,7 +40,7 @@ def add( T.func_attr({"operator_name": "relax.add"}) for ax0 in range(2): for ax1 in range(2): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(2, ax0) v_ax1 = T.axis.spatial(2, ax1) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py index b873f606e619..f2831f6508d2 100644 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -35,7 +35,7 @@ def sigmoid_compute(sigmoid_input): def sigmoid_stir_schedule(sigmoid_input, sigmoid_output): sigmoid_func = te.create_prim_func([sigmoid_input, sigmoid_output]) sch = tir.Schedule(sigmoid_func, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") (n,) = sch.get_loops(block) sch.vectorize(n) diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index b4d2aed433b9..9835b2d8f48c 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -35,8 +35,8 @@ def a_plus_1_primfunc( ): for i in T.serial(outer): for j in T.serial(inner): - with T.block("compute"): - with T.block(): + with T.sblock("compute"): + with T.sblock(): out[i, j] = a_buffer[i, j] + T.cast(1, dtype) return a_plus_1_primfunc @@ -50,8 +50,8 @@ def a_plus_b_plus_1_primfunc( ): for i in T.serial(outer): for j in T.serial(inner): - with T.block("compute"): - with T.block(): + with T.sblock("compute"): + with T.sblock(): out[i, j] = a_buffer[i, j] + b_buffer[i, j] + T.cast(1, dtype) return a_plus_b_plus_1_primfunc @@ -114,7 +114,7 @@ def schedule(self, comp_type, sched_type, outer, inner, dtype, scope): """Generate schedule.""" sch = tir.Schedule(compute(comp_type, outer, inner, dtype)) - compute_block = sch.get_block("compute") + compute_block = sch.get_sblock("compute") i, _ = sch.get_loops(compute_block) if "read" in sched_type: diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py index f61a2560cfad..34cb79f67ac7 100644 --- a/tests/python/contrib/test_hexagon/test_thread_pool.py +++ b/tests/python/contrib/test_hexagon/test_thread_pool.py @@ -41,7 +41,7 @@ def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32): B = T.match_buffer(b, (n,), dtype="float32") C = T.match_buffer(c, (n,), dtype="float32") for i in T.serial(n): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(n, i) C[vi] = A[vi] + B[vi] @@ -52,7 +52,7 @@ def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32): B = T.match_buffer(b, (n,), dtype="float32") C = T.match_buffer(c, (n,), dtype="float32") for i in T.parallel(n): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(n, i) C[vi] = A[vi] + B[vi] diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index eec48a972ea2..ee48f4e00532 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -29,14 +29,14 @@ def scale_by_two(buffer_a: T.Buffer((8192,), "int8"), buffer_c: T.Buffer((8192,) 0, 8192, ): - with T.block("C"): + with T.sblock("C"): buffer_c[i] = buffer_a[i] * T.int8(2) def get_scale_by_two_schedule(): mod = tvm.IRModule.from_expr(scale_by_two.with_attr("global_symbol", "main")) sch = tir.Schedule(mod, debug_mask="all") - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") (flat,) = sch.get_loops(block_c) outer, _, _, _ = sch.split(flat, factors=[8, 4, 2, 128]) cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm") diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 42fca9c153aa..808cb401b907 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -45,7 +45,7 @@ def operator(a: T.handle, a_v: T.handle) -> None: a_buffer = T.match_buffer(a, size, dtype="int8", align=128, scope="global") a_global_vtcm = T.match_buffer(a_v, size, dtype="int8", align=128, scope="global.vtcm") for ax0 in T.serial(size): - with T.block("A_global.vtcm"): + with T.sblock("A_global.vtcm"): v0_ind = T.axis.spatial(size, ax0) T.reads(a_buffer[v0_ind]) T.writes(a_global_vtcm[v0_ind]) @@ -149,7 +149,7 @@ def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vecto # Run with some basic unroll and vectorize scheduling. sch = tvm.tir.Schedule(memcopy_operator(size)) - vtcm_block_a = sch.get_block("A_global.vtcm") + vtcm_block_a = sch.get_sblock("A_global.vtcm") v_block = sch.get_loops(vtcm_block_a) _, vio_a, vii_a = sch.split(v_block[0], factors=[None, unroll_split, vector_split]) sch.unroll(vio_a) @@ -158,7 +158,7 @@ def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vecto # Run with some basic unroll and vectorize scheduling and parallelization. sch = tvm.tir.Schedule(memcopy_operator(size)) - vtcm_block_a = sch.get_block("A_global.vtcm") + vtcm_block_a = sch.get_sblock("A_global.vtcm") v_block = sch.get_loops(vtcm_block_a) vbo_a, _, vio_a, vii_a = sch.split( v_block[0], factors=[outer_split, None, unroll_split, vector_split] @@ -170,7 +170,7 @@ def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vecto # Run with some basic unroll and vectorize scheduling and parallelization. sch = tvm.tir.Schedule(memcopy_operator(size)) - block = sch.get_block("A_global.vtcm") + block = sch.get_sblock("A_global.vtcm") loops = sch.get_loops(block) _, inner = sch.split(loops[0], [None, 128]) sch.tensorize(inner, DMA_READ_128_i8) diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py index 95ccf28fbddb..9ad613a98cee 100644 --- a/tests/python/contrib/test_tir_triton_integration.py +++ b/tests/python/contrib/test_tir_triton_integration.py @@ -63,7 +63,7 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None x = T.match_buffer(x_handle, (m,), "float32") y = T.match_buffer(y_handle, (m,), "float32") output = T.match_buffer(output_handle, (m,), "float32") - with T.block("root"): + with T.sblock("root"): T.reads(x[0:m], y[0:m]) T.writes(output[0:m]) BLOCK_SIZE = T.meta_var(64) @@ -93,7 +93,7 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle): x = T.match_buffer(x_handle, (m,)) y = T.match_buffer(y_handle, (m,)) output = T.match_buffer(output_handle, (m,)) - with T.block("root"): + with T.sblock("root"): T.reads(x[0:m], y[0:m]) T.writes(output[0:m]) T.call_packed( diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py index b98b49591d09..76de7d62724d 100644 --- a/tests/python/disco/test_nvshmem.py +++ b/tests/python/disco/test_nvshmem.py @@ -157,7 +157,7 @@ def test_nvshmem_compile(): def main(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): for i in T.thread_binding(T.int64(8), thread="threadIdx.y"): for j in T.thread_binding(T.int64(16), thread="threadIdx.x"): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v0 = T.axis.spatial(T.int64(8), i) v1 = T.axis.spatial(T.int64(16), j) T.reads(A[v0, v1]) @@ -226,7 +226,7 @@ def query_pe( my_pe_out: T.Buffer((1,), "int32"), n_pes_out: T.Buffer((1,), "int32"), ): - with T.block("root"): + with T.sblock("root"): T.reads() T.writes(my_pe_out[0:1], n_pes_out[0:1]) T.call_kernel( diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index 721115947480..7362a5d1f130 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -201,7 +201,7 @@ class TestMod: @T.prim_func def transpose(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): for i, j in T.grid(16, 8): - with T.block("transpose"): + with T.sblock("transpose"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @@ -245,14 +245,14 @@ class TestMod: @T.prim_func def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): for i, j in T.grid(16, 8): - with T.block("t1"): + with T.sblock("t1"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @T.prim_func def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")): for i, j in T.grid(8, 16): - with T.block("t2"): + with T.sblock("t2"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] diff --git a/tests/python/dlight/test_benchmark.py b/tests/python/dlight/test_benchmark.py index b95a0695a585..f7db509bb46e 100644 --- a/tests/python/dlight/test_benchmark.py +++ b/tests/python/dlight/test_benchmark.py @@ -47,9 +47,9 @@ def full1(var_T_full: T.handle): T.func_attr({"op_pattern": 0, "tir.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads() T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -60,9 +60,9 @@ def full2(var_T_full: T.handle): T.func_attr({"op_pattern": 0, "tir.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads() T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -74,9 +74,9 @@ def matmul1(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.in n = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -105,7 +105,7 @@ def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(409 m = T.int64() inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) - # with T.block("root"): + # with T.sblock("root"): matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared") inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared") @@ -117,7 +117,7 @@ def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(409 for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init) @@ -129,28 +129,28 @@ def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(409 for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(2)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("inp0_reindex_pad_shared"): + with T.sblock("inp0_reindex_pad_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(inp0[v0, v1, v2]) T.writes(inp0_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(4)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("inp1_reindex_shared"): + with T.sblock("inp1_reindex_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(inp1[v2, v1]) T.writes(inp1_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1] for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3) @@ -160,7 +160,7 @@ def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(409 matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3] for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)): for ax2_1_1 in T.vectorized(T.int64(2)): - with T.block("matmul_reindex_pad_local"): + with T.sblock("matmul_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0_1) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1) diff --git a/tests/python/dlight/test_cpu_gemv.py b/tests/python/dlight/test_cpu_gemv.py index dc09e3e2f3e7..54587149c673 100644 --- a/tests/python/dlight/test_cpu_gemv.py +++ b/tests/python/dlight/test_cpu_gemv.py @@ -43,13 +43,13 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") var_T_divide_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") var_T_maximum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") var_T_minimum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") for i0, i1, i2, i3, k in T.grid(1, 32, 1, n, 128): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(lv1637[v_i0, v_i1, v_i2, v_k], lv1638[v_i0, v_i1, v_i3, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -57,25 +57,25 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1637[v_i0, v_i1, v_i2, v_k] * lv1638[v_i0, v_i1, v_i3, v_k] for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): - with T.block("T_maximum"): + with T.sblock("T_maximum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): - with T.block("T_minimum"): + with T.sblock("T_minimum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) for i0, i1, i2, i3 in T.grid(1, 32, 1, n): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -88,14 +88,14 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") for ax0_fused in range(32): for ax1_fused_0 in T.parallel((n + 63) // 64): for ax1_fused_1 in T.vectorized(64): for ax2_fused_0 in T.serial(2, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1, u_0, u_1 in T.grid(64, 1, 1): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v0 = T.axis.spatial(32, ax0_fused) v1 = T.axis.spatial(n, ax1_fused_0 * 64 + ax1_fused_1) v2 = T.axis.reduce(128, ax2_fused_0 * 64 + ax2_fused_1) @@ -106,7 +106,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p var_NT_matmul_intermediate[0, v0, 0, v1] = T.float16(0.0) var_NT_matmul_intermediate[0, v0, 0, v1] = var_NT_matmul_intermediate[0, v0, 0, v1] + lv1637[0, v0, 0, v2] * lv1638[0, v0, v1, v2] for ax0, ax1 in T.grid(32, n): - with T.block("compute"): + with T.sblock("compute"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(var_NT_matmul_intermediate[0, v0, 0, v1], lv1614[0, 0, 0, v1]) T.writes(var_compute_intermediate[0, v0, 0, v1]) @@ -120,16 +120,16 @@ def test_decode_gemv_256_threads(): @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) T.writes(p_output0_intermediate[v_i, v_j]) p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -140,13 +140,13 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for u_fused in range(1): for ax0_fused_0 in T.parallel(172): for ax0_fused_1 in T.vectorized(128): for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v0 = T.axis.spatial(22016, ax0_fused_0 * 128 + ax0_fused_1) v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) T.reads(lv1654[0, 0, v1], lv571[v0, v1 // 8], lv572[v0, v1 // 32]) @@ -168,16 +168,16 @@ def test_decode_gemv1(): @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) T.writes(p_output0_intermediate[v_i, v_j]) p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -188,13 +188,13 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for u_fused in range(1): for ax0_fused_0 in T.parallel(172): for ax0_fused_1 in T.vectorized(128): for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v0 = T.axis.spatial(22016, ax0_fused_0 * 128 + ax0_fused_1) v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) T.reads(lv1654[0, 0, v1], lv571[v0, v1 // 8], lv572[v0, v1 // 32]) @@ -216,17 +216,17 @@ def test_decode_gemv2(): @T.prim_func(private=True) def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((32000, 4096), "float16") var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") for i, j in T.grid(32000, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv771[v_i, v_j // 8], lv772[v_i, v_j // 32]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 32000, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv3216[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -234,7 +234,7 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv3216[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] for i0, i1, i2 in T.grid(1, 1, 32000): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) @@ -243,14 +243,14 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) @T.prim_func(private=True) def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") for u_fused in range(1): for ax0_fused_0 in T.parallel(250): for ax0_fused_1 in T.vectorized(128): for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v0 = T.axis.spatial(32000, ax0_fused_0 * 128 + ax0_fused_1) v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) T.reads(lv3216[0, 0, v1], lv771[v0, v1 // 8], lv772[v0, v1 // 32]) @@ -259,7 +259,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv3216[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv772[v0, v1 // 32]) for ax0 in range(32000): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(32000, ax0) T.reads(var_NT_matmul_intermediate[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) @@ -278,17 +278,17 @@ def test_decode_gemv3(): @T.prim_func(private=True) def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j // T.int64(32)]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -296,7 +296,7 @@ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.B var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv570[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -305,14 +305,14 @@ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.B @T.prim_func(private=True) def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") for u_fused in range(1): for ax0_fused_0 in T.parallel(T.int64(64)): for ax0_fused_1 in T.vectorized(T.int64(64)): for ax1_0_fused_0 in T.serial(T.int64(11), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(T.int64(128), T.int64(1), T.int64(8)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v0 = T.axis.spatial(T.int64(4096), ax0_fused_0 * T.int64(64) + ax0_fused_1) v1 = T.axis.reduce(T.int64(11008), (ax1_0_fused_0 * T.int64(128) + ax1_0_fused_1) * T.int64(8) + ax1_1_0 * T.int64(8) + ax1_1_1) T.where(ax1_0_fused_0 * T.int64(128) + ax1_0_fused_1 < T.int64(1376)) @@ -322,7 +322,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float16(0.0) var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + lv574[T.int64(0), T.int64(0), v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv576[v0, v1 // T.int64(32)]) for ax0 in range(T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v0 = T.axis.spatial(T.int64(4096), ax0) T.reads(lv570[T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0]) T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0]) @@ -340,17 +340,17 @@ def test_autogptq_decode_gemv(): @T.prim_func(private=True) def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv9[v_i // T.int64(8), v_j], lv10[lv12[v_i], v_j // T.int64(8)], lv12[v_i], lv11[lv12[v_i], v_j]) T.writes(decode_intermediate[v_i, v_j]) decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) - (T.Cast("float16", T.bitwise_and(T.shift_right(lv10[lv12[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) + T.float16(1))) * lv11[lv12[v_i], v_j] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv8[v_i0, v_i1, v_k], decode_intermediate[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -358,7 +358,7 @@ def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer( var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv8[v_i0, v_i1, v_k] * decode_intermediate[v_k, v_i2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv1613[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -384,38 +384,38 @@ def before( p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") for i, j in T.grid(11008, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15)))- T.float16(7)) * lv576[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for ax0, ax1, ax2 in T.grid(1, 1, 4096): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] @T.prim_func(private=True) def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") for i, j in T.grid(11008, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv576[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -423,7 +423,7 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for ax0, ax1, ax2 in T.grid(1, 1, 4096): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -444,17 +444,17 @@ def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") for i, j in T.grid(T.int64(4096), v): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v_i // T.int64(32), v_j] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -462,7 +462,7 @@ def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) @@ -475,17 +475,17 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") for i, j in T.grid(T.int64(4096), v): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv613[v_i // T.int64(32), v_j] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -493,7 +493,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) @@ -510,16 +510,16 @@ def test_blockized_gemv(): # fmt: off @T.prim_func(private=True) def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): - # with T.block("root"): + # with T.sblock("root"): for expert_id in T.thread_binding(2, thread="blockIdx.y"): - with T.block("gemv_o"): + with T.sblock("gemv_o"): v_expert_id_o = T.axis.spatial(2, expert_id) vi_o = T.axis.spatial(1, 0) vj_o = T.axis.reduce(1, 0) T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o]) T.writes(o[v_expert_id_o, 0:16384]) for i, j in T.grid(16384, 4096): - with T.block("gemv"): + with T.sblock("gemv"): vi_i, vj_i = T.axis.remap("SR", [i, j]) T.reads(x[0, vj_i], w[indptr[v_expert_id_o], vi_i, vj_i], indptr[v_expert_id_o]) T.writes(o[v_expert_id_o, vi_i]) @@ -530,9 +530,9 @@ def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "flo @T.prim_func(private=True) def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): for expert_id in T.thread_binding(2, thread="blockIdx.y"): - with T.block("gemv_o"): + with T.sblock("gemv_o"): v_expert_id_o = T.axis.spatial(2, expert_id) vi_o = T.axis.spatial(1, 0) vj_o = T.axis.reduce(1, 0) @@ -543,7 +543,7 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f for ax0_fused_1 in T.vectorized(128): for ax1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_fused_1, u_0, u_1 in T.grid(64, 1, 1): - with T.block("gemv"): + with T.sblock("gemv"): v0 = T.axis.spatial(16384, ax0_fused_0 * 128 + ax0_fused_1) v1 = T.axis.reduce(4096, ax1_fused_0 * 64 + ax1_fused_1) T.reads(x[0, v1], w[indptr[v_expert_id_o], v0, v1], indptr[v_expert_id_o]) @@ -565,7 +565,7 @@ def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int6 output_buf = T.match_buffer( var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8 ) - with T.block("exclusive_scan_thrust"): + with T.sblock("exclusive_scan_thrust"): T.reads() T.writes() T.call_packed( diff --git a/tests/python/dlight/test_gpu_conv.py b/tests/python/dlight/test_gpu_conv.py index 90603a8bf293..ad86225d18df 100644 --- a/tests/python/dlight/test_gpu_conv.py +++ b/tests/python/dlight/test_gpu_conv.py @@ -44,11 +44,11 @@ def before( ): pad_A = T.alloc_buffer((14308, 3, 2, 14, 14), "float16") for i0, i1, i2, i3, i4 in T.grid(14308, 3, 2, 14, 14): - with T.block("pad_A"): + with T.sblock("pad_A"): v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) pad_A[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(14308, 1280, 1, 1, 1, 3, 2, 14, 14): - with T.block("C"): + with T.sblock("C"): v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz]) with T.init(): C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0) @@ -57,7 +57,7 @@ def before( @T.prim_func def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): C_reindex_pad_local = T.alloc_buffer((1, 14336, 1280), "float16", scope="local") pad_A_reindex_pad_shared = T.alloc_buffer((1, 14336, 1184), "float16", scope="shared") W_reindex_pad_shared = T.alloc_buffer((1, 1280, 1184), "float16", scope="shared") @@ -69,7 +69,7 @@ def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3 for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(4, 2): for ax2_3_1_init in T.vectorized(2): - with T.block("C_init"): + with T.sblock("C_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) @@ -79,25 +79,25 @@ def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3 for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("pad_A_reindex_pad_shared"): + with T.sblock("pad_A_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(14336, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) pad_A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 14308 and v2 < 1176, A[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("W_reindex_pad_shared"): + with T.sblock("W_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 1176, W[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0)) for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): for ax2_3_1 in T.vectorized(2): - with T.block("C_update"): + with T.sblock("C_update"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) @@ -105,7 +105,7 @@ def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3 C_reindex_pad_local[0, v1, v2] = C_reindex_pad_local[0, v1, v2] + pad_A_reindex_pad_shared[0, v1, v3] * W_reindex_pad_shared[0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): - with T.block("C_reindex_pad_local"): + with T.sblock("C_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) diff --git a/tests/python/dlight/test_gpu_fallback.py b/tests/python/dlight/test_gpu_fallback.py index a4eaa3ad748c..d0fdfee0c575 100644 --- a/tests/python/dlight/test_gpu_fallback.py +++ b/tests/python/dlight/test_gpu_fallback.py @@ -33,11 +33,11 @@ def main( ): B = T.alloc_buffer((1, 1, 32, 128), "float16") for i, j, k, l in T.grid(1, 1, 32, 128): - with T.block("T_transpose"): + with T.sblock("T_transpose"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vk, vj, vl] for i, j, k in T.grid(1, 1, 4096): - with T.block("T_reshape"): + with T.sblock("T_reshape"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] @@ -51,7 +51,7 @@ def main( T.func_attr({"tir.is_scheduled": True}) for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1) T.reads(A[0, v0 // 128, 0, v0 % 128]) T.writes(C[0, 0, v0]) @@ -71,7 +71,7 @@ class Module: @T.prim_func def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), "float32")): for ax0, ax1 in T.grid(1, 6144): - with T.block("block"): + with T.sblock("block"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.reduce(6144, ax1) T.reads(A[v0, v1]) @@ -87,14 +87,14 @@ def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), "float32")): T.func_attr({"tir.is_scheduled": True}) for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): - with T.block("block_init"): + with T.sblock("block_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) T.reads() T.writes(B[0]) B[0] = T.float32(0) for ax1 in range(6144): - with T.block("block_update"): + with T.sblock("block_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.reduce(6144, ax1) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) @@ -132,7 +132,7 @@ def func( values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16") for l, h, pos in T.grid(nlayer, nhead, seqlen): - with T.block("block"): + with T.sblock("block"): vl, vh, vp = T.axis.remap("SSS", [l, h, pos]) values[vl, vh, vp] = pages[ page_table_values[page_table_indptr[seq_id] + T.floordiv(vp, page_size)], @@ -160,7 +160,7 @@ def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_tabl for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen + 1023) // 1024, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("block"): + with T.sblock("block"): v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead)) v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen) v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % seqlen) @@ -191,11 +191,11 @@ def gpu_func( ): B = T.alloc_buffer((1, 1, 32, 128), "float16") for i, j, k, l in T.grid(1, 1, 32, 128): - with T.block("T_transpose"): + with T.sblock("T_transpose"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vk, vj, vl] for i, j, k in T.grid(1, 1, 4096): - with T.block("T_reshape"): + with T.sblock("T_reshape"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] @@ -210,11 +210,11 @@ def cpu_func( T.func_attr({"target": T.target("llvm")}) B = T.alloc_buffer((1, 1, 32, 128), "float16") for i, j, k, l in T.grid(1, 1, 32, 128): - with T.block("T_transpose"): + with T.sblock("T_transpose"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vk, vj, vl] for i, j, k in T.grid(1, 1, 4096): - with T.block("T_reshape"): + with T.sblock("T_reshape"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] @@ -228,7 +228,7 @@ def gpu_func( T.func_attr({"tir.is_scheduled": True}) for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1) T.reads(A[0, v0 // 128, 0, v0 % 128]) T.writes(C[0, 0, v0]) @@ -242,11 +242,11 @@ def cpu_func( T.func_attr({"target": T.target("llvm")}) B = T.alloc_buffer((1, 1, 32, 128), "float16") for i, j, k, l in T.grid(1, 1, 32, 128): - with T.block("T_transpose"): + with T.sblock("T_transpose"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vk, vj, vl] for i, j, k in T.grid(1, 1, 4096): - with T.block("T_reshape"): + with T.sblock("T_reshape"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128] diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index b8662327eed9..818c1081b7cc 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -43,13 +43,13 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") var_T_divide_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") var_T_maximum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") var_T_minimum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") for i0, i1, i2, i3, k in T.grid(1, 32, 1, n, 128): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(lv1637[v_i0, v_i1, v_i2, v_k], lv1638[v_i0, v_i1, v_i3, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -57,25 +57,25 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1637[v_i0, v_i1, v_i2, v_k] * lv1638[v_i0, v_i1, v_i3, v_k] for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): - with T.block("T_maximum"): + with T.sblock("T_maximum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): - with T.block("T_minimum"): + with T.sblock("T_minimum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) for i0, i1, i2, i3 in T.grid(1, 32, 1, n): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -88,7 +88,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1, n), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local") @@ -102,7 +102,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p for ax3_1 in T.thread_binding(1, thread="threadIdx.y"): for ax3_2 in T.thread_binding(64, thread="threadIdx.x"): for ax3_3 in T.vectorized(2): - with T.block("lv1637_shared"): + with T.sblock("lv1637_shared"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1) v2 = T.axis.spatial(1, ax2) @@ -112,7 +112,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3] for ax0_fused_ax1_fused_fused_2_init in range(1): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) // n) v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) % n) @@ -122,7 +122,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0, ax1, ax2_ax3_fused_0 in T.grid(1, 1, 1): for ax2_ax3_fused_1 in T.vectorized(2): - with T.block("lv1638_local"): + with T.sblock("lv1638_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1) v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) @@ -132,7 +132,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3] for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(1, 1): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) // n) v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) % n) @@ -144,7 +144,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p for ax0 in T.thread_binding(64, thread="threadIdx.x"): for ax2_ax3_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_ax3_fused_1_1 in T.vectorized(1): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0) v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) @@ -152,7 +152,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = T.float16(0) for ax1 in range(2): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) @@ -162,7 +162,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p for ax1_ax2_fused_1 in range(1): for ax1_ax2_fused_0 in T.thread_binding(1, thread="threadIdx.y"): for ax0 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(64, ax0) v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) @@ -173,7 +173,7 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p var_NT_matmul_intermediate_local[0, v0, 0, v1] = var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"): for ax0_ax1_fused_1 in range(1): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) T.reads(var_NT_matmul_intermediate_local[0, v0, 0, v1], lv1614[0, 0, 0, v1]) @@ -188,16 +188,16 @@ def test_decode_gemv_256_threads(): @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) T.writes(p_output0_intermediate[v_i, v_j]) p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -208,7 +208,7 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local") lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local") @@ -217,7 +217,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(1): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() @@ -226,7 +226,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax1_0_fused_ax1_1_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused in T.serial(1): for ax0_1 in T.vectorized(1): - with T.block("lv571_local"): + with T.sblock("lv571_local"): v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0) T.reads(lv571[v0, v1]) @@ -234,7 +234,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 lv571_local[v0, v1] = lv571[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 8): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(1): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) @@ -245,14 +245,14 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax0 in T.thread_binding(16, thread="threadIdx.y"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(16, ax0) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(1): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) @@ -261,7 +261,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.x"): for ax0 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(16, ax0) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) @@ -284,16 +284,16 @@ def test_decode_gemv1(): @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) T.writes(p_output0_intermediate[v_i, v_j]) p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -304,7 +304,7 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 22016), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 22016), "float16", scope="local") lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local") @@ -317,7 +317,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax2_1 in T.thread_binding(16, thread="threadIdx.y"): for ax2_2 in T.thread_binding(32, thread="threadIdx.x"): for ax2_3 in T.vectorized(8): - with T.block("lv1654_shared"): + with T.sblock("lv1654_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(4096, ax2_0 * 4096 + ax2_1 * 256 + ax2_2 * 8 + ax2_3) T.reads(lv1654[v0, v1, v2]) @@ -325,7 +325,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 16 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() @@ -334,7 +334,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax1_0_fused_ax1_1_fused_0 in T.serial(16, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(1): for ax0_ax1_fused_1 in T.vectorized(1): - with T.block("lv571_local"): + with T.sblock("lv571_local"): v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 16 + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0) T.reads(lv571[v0, v1]) @@ -342,7 +342,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 lv571_local[v0, v1] = lv571[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 16 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) @@ -353,14 +353,14 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax0 in T.thread_binding(32, thread="threadIdx.x"): for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_2_1 in T.vectorized(1): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(4): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) @@ -369,7 +369,7 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12 for ax1_fused_2 in range(1): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(16, thread="threadIdx.y"): for ax0 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) @@ -392,17 +392,17 @@ def test_decode_gemv2(): @T.prim_func(private=True) def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((32000, 4096), "float16") var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") for i, j in T.grid(32000, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv771[v_i, v_j // 8], lv772[v_i, v_j // 32]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 32000, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv3216[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -410,7 +410,7 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv3216[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] for i0, i1, i2 in T.grid(1, 1, 32000): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) @@ -419,7 +419,7 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) @T.prim_func(private=True) def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32000), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 1, 32000), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((32, 1, 1, 32000), "float16", scope="local") @@ -433,7 +433,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 for ax2_1 in T.thread_binding(16, thread="threadIdx.y"): for ax2_2 in T.thread_binding(32, thread="threadIdx.x"): for ax2_3 in T.vectorized(8): - with T.block("lv3216_shared"): + with T.sblock("lv3216_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(4096, ax2_0 * 4096 + ax2_1 * 256 + ax2_2 * 8 + ax2_3) T.reads(lv3216[v0, v1, v2]) @@ -441,7 +441,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() @@ -450,7 +450,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 for ax1_0_fused_ax1_1_fused_0 in T.serial(16, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(1): for ax0_ax1_fused_1 in T.vectorized(1): - with T.block("lv771_local"): + with T.sblock("lv771_local"): v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 32 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0) T.reads(lv771[v0, v1]) @@ -458,7 +458,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 lv771_local[v0, v1] = lv771[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(128, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) @@ -469,14 +469,14 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 for ax0 in T.thread_binding(32, thread="threadIdx.x"): for ax2_fused_2_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_2_1 in T.vectorized(1): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(32, ax0) v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(4): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) @@ -485,7 +485,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 for ax1_fused_2 in range(1): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(16, thread="threadIdx.y"): for ax0 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(32, ax0) v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) @@ -495,7 +495,7 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] for ax0_fused_0_ax0_fused_1_fused in T.thread_binding(16, thread="threadIdx.y"): for ax0_fused_2 in range(1): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 16 + ax0_fused_0_ax0_fused_1_fused + ax0_fused_2) T.reads(var_NT_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) @@ -515,17 +515,17 @@ def test_decode_gemv3(): @T.prim_func(private=True) def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j // T.int64(32)]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -533,7 +533,7 @@ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.B var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv570[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -542,7 +542,7 @@ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.B @T.prim_func(private=True) def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") @@ -556,7 +556,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_3 in T.vectorized(T.int64(1)): - with T.block("lv574_shared"): + with T.sblock("lv574_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(T.int64(11008), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3) T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(11008)) @@ -565,7 +565,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T lv574_shared[v0, v1, v2] = lv574[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(T.int64(1)): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() @@ -574,7 +574,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_ax1_fused_0 in range(T.int64(1)): for ax0_ax1_fused_1 in T.vectorized(T.int64(1)): - with T.block("lv575_local"): + with T.sblock("lv575_local"): v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1) v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0) T.reads(lv575[v0, v1]) @@ -582,7 +582,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T lv575_local[v0, v1] = lv575[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) @@ -593,14 +593,14 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax2_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(T.int64(1)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1 in range(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]) @@ -609,7 +609,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T for ax1_fused_1 in range(T.int64(1)): for ax1_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0 + ax1_fused_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]) @@ -619,7 +619,7 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] for ax0_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_fused_1 in range(T.int64(1)): - with T.block("T_add"): + with T.sblock("T_add"): v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0 + ax0_fused_1) T.reads(lv570[T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0]) @@ -638,17 +638,17 @@ def test_autogptq_decode_gemv(): @T.prim_func(private=True) def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv9[v_i // T.int64(8), v_j], lv10[lv12[v_i], v_j // T.int64(8)], lv12[v_i], lv11[lv12[v_i], v_j]) T.writes(decode_intermediate[v_i, v_j]) decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) - (T.Cast("float16", T.bitwise_and(T.shift_right(lv10[lv12[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) + T.float16(1))) * lv11[lv12[v_i], v_j] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv8[v_i0, v_i1, v_k], decode_intermediate[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -656,7 +656,7 @@ def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer( var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv8[v_i0, v_i1, v_k] * decode_intermediate[v_k, v_i2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv1613[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -682,28 +682,28 @@ def before( p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") for i, j in T.grid(11008, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15)))- T.float16(7)) * lv576[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for ax0, ax1, ax2 in T.grid(1, 1, 4096): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] @T.prim_func(private=True) def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") var_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 4096), "float16", scope="local") var_matmul_intermediate_rf_local_1 = T.alloc_buffer((4, 1, 1, 4096), "float16", scope="local") @@ -713,7 +713,7 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.x"): for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(8): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) T.reads() @@ -722,7 +722,7 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(86, 1): for ax0, ax1 in T.grid(1, 1): - with T.block("lv576_local"): + with T.sblock("lv576_local"): v0 = T.axis.spatial(344, ax1_0_fused_ax1_1_fused_0 * 4 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0) v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) T.reads(lv576[v0, v1]) @@ -730,14 +730,14 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 lv576_local[v0, v1] = lv576[v0, v1] for ax1_0_fused_ax1_1_fused_3 in range(4): for ax0, ax1 in T.grid(1, 1): - with T.block("lv575_local"): + with T.sblock("lv575_local"): v0 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 4 + ax1_0_fused_ax1_1_fused_3 + ax0) v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) T.reads(lv575[v0, v1]) T.writes(lv575_local[v0, v1]) lv575_local[v0, v1] = lv575[v0, v1] for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) @@ -746,14 +746,14 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] + lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) for ax2 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(4, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) T.reads() T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = T.float16(0) for ax1 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0]) @@ -761,7 +761,7 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0] for ax1 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(4, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1) T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) @@ -771,7 +771,7 @@ def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096 var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.x"): for ax0_fused_1 in range(1): - with T.block("T_add"): + with T.sblock("T_add"): v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) @@ -792,17 +792,17 @@ def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") for i, j in T.grid(T.int64(4096), v): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v_i // T.int64(32), v_j] for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -810,7 +810,7 @@ def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) @@ -823,7 +823,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) - # with T.block("root"): + # with T.sblock("root"): var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16", scope="local") var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(8), T.int64(1), T.int64(1), v), "float16", scope="local") var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), v), "float16", scope="local") @@ -834,7 +834,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(1), thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(T.int64(8)): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v) @@ -847,7 +847,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), for ax2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_3 in T.thread_binding(T.int64(1), thread="threadIdx.y"): for ax2_4 in T.vectorized(T.int64(4)): - with T.block("lv1607_shared"): + with T.sblock("lv1607_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(T.int64(4096), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + (ax2_0 * T.int64(1024) + ax2_1 * T.int64(1024) + ax2_2 * T.int64(4) + ax2_3 * T.int64(4) + ax2_4)) T.where(((ax2_0 + ax2_1) * T.int64(256) + ax2_2 + ax2_3) * T.int64(4) + ax2_4 < T.int64(32)) @@ -856,7 +856,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), lv1607_shared[v0, v1, v2] = lv1607[v0, v1, v2] for ax1_0_fused_ax1_1_fused_1 in range(T.int64(1)): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - with T.block("lv613_local"): + with T.sblock("lv613_local"): v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 + ax0) v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v) @@ -865,7 +865,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), lv613_local[v0, v1] = lv613[v0, v1] for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - with T.block("lv612_local"): + with T.sblock("lv612_local"): v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0) v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v) @@ -873,7 +873,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.writes(lv612_local[v0, v1]) lv612_local[v0, v1] = lv612[v0, v1] for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) @@ -883,7 +883,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused) % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613_local[(vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused) // T.int64(32) + vax1_0_fused_ax1_1_fused_0 + vax1_0_fused_ax1_1_fused_1, v0]) for ax2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(1), ax0) v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v) @@ -891,7 +891,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v) @@ -900,7 +900,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0] for ax1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(1), ax0) v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax1) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax1 < v) @@ -911,7 +911,7 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] for ax0_fused_0 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax0_fused_1 in range(T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax0_fused_0 + ax0_fused_1) T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + (ax0_fused_0 + ax0_fused_1) < v) T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) @@ -929,16 +929,16 @@ def test_blockized_gemv(): # fmt: off @T.prim_func(private=True) def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): - # with T.block("root"): + # with T.sblock("root"): for expert_id in T.thread_binding(2, thread="blockIdx.y"): - with T.block("gemv_o"): + with T.sblock("gemv_o"): v_expert_id_o = T.axis.spatial(2, expert_id) vi_o = T.axis.spatial(1, 0) vj_o = T.axis.reduce(1, 0) T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o]) T.writes(o[v_expert_id_o, 0:16384]) for i, j in T.grid(16384, 4096): - with T.block("gemv"): + with T.sblock("gemv"): vi_i, vj_i = T.axis.remap("SR", [i, j]) T.reads(x[0, vj_i], w[indptr[v_expert_id_o], vi_i, vj_i], indptr[v_expert_id_o]) T.writes(o[v_expert_id_o, vi_i]) @@ -949,9 +949,9 @@ def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "flo @T.prim_func(private=True) def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): for expert_id in T.thread_binding(2, thread="blockIdx.y"): - with T.block("gemv_o"): + with T.sblock("gemv_o"): v_expert_id_o = T.axis.spatial(2, expert_id) vi_o = T.axis.spatial(1, 0) vj_o = T.axis.reduce(1, 0) @@ -965,7 +965,7 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for u_fused_ax0_fused_fused_2_init in range(1): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(1): - with T.block("gemv_rf_init"): + with T.sblock("gemv_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(16, ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() @@ -975,7 +975,7 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f for ax0 in range(1): for ax1_ax2_fused_0 in range(8): for ax1_ax2_fused_1 in T.vectorized(1): - with T.block("w_local"): + with T.sblock("w_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1) v2 = T.axis.spatial(4096, ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax1_ax2_fused_0 + ax1_ax2_fused_1) @@ -984,7 +984,7 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f w_local[v0, v1, v2] = w[indptr[v_expert_id_o] + v0, v1, v2] for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(1, 8): for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(1): - with T.block("gemv_rf_update"): + with T.sblock("gemv_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(16, ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1) v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2]) @@ -995,14 +995,14 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f for ax0 in T.thread_binding(16, thread="threadIdx.y"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): - with T.block("gemv_rf_init"): + with T.sblock("gemv_rf_init"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(16, ax0) v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0]) o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0] = T.float16(0) for ax1 in range(1): - with T.block("gemv_rf_update"): + with T.sblock("gemv_rf_update"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0], o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, v_expert_id_o, v0]) @@ -1011,7 +1011,7 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.x"): for ax0 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("gemv"): + with T.sblock("gemv"): vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(16, ax0) v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1) T.reads(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0]) @@ -1034,7 +1034,7 @@ def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int6 output_buf = T.match_buffer( var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8 ) - with T.block("exclusive_scan_thrust"): + with T.sblock("exclusive_scan_thrust"): T.reads() T.writes() T.call_packed( diff --git a/tests/python/dlight/test_gpu_general_reduction.py b/tests/python/dlight/test_gpu_general_reduction.py index aafe76f900e4..5586cc74096b 100644 --- a/tests/python/dlight/test_gpu_general_reduction.py +++ b/tests/python/dlight/test_gpu_general_reduction.py @@ -43,13 +43,13 @@ def main(p_lv44: T.handle, p_output0: T.handle): n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv44[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) @@ -57,13 +57,13 @@ def main(p_lv44: T.handle, p_output0: T.handle): T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) @@ -71,14 +71,14 @@ def main(p_lv44: T.handle, p_output0: T.handle): T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) + T.sblock_attr({"axis": 3}) var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -92,14 +92,14 @@ def main(p_lv44: T.handle, p_output0: T.handle): n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), T.int64(32), n), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(32), n), scope="shared") for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial((m + T.int64(255)) // T.int64(256), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n + ax0) v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1) v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) + ax2_fused_1) @@ -112,7 +112,7 @@ def main(p_lv44: T.handle, p_output0: T.handle): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial((m + T.int64(255)) // T.int64(256), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n + ax0) v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1) v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) + ax2_fused_1) @@ -124,7 +124,7 @@ def main(p_lv44: T.handle, p_output0: T.handle): T_softmax_expsum_shared[T.int64(0), v0, v1] = T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(lv44[T.int64(0), v0, v1, v2] - T_softmax_maxelem_shared[T.int64(0), v0, v1]) for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_0 in T.serial((m + T.int64(255)) // T.int64(256), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) v2 = T.axis.spatial(m, ax2_0 * T.int64(256) + ax2_1) @@ -142,12 +142,12 @@ def test_softmax_2(): class Before: @T.prim_func def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1))) T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000))) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1))) for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_i1, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1]) @@ -155,13 +155,13 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k]) for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_k]) T.writes(T_softmax_expsum[v_i0, v_i1]) @@ -169,11 +169,11 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof T_softmax_expsum[v_i0, v_i1] = T.float32(0) T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] @@ -182,14 +182,14 @@ class After: @T.prim_func def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(1)), scope="shared") for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax0 in range(T.int64(1)): for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_fused_0 in T.serial(T.int64(125), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.reduce(T.int64(32000), ax1_fused_0 * T.int64(256) + ax1_fused_1) T.reads(A[T.int64(0), T.int64(0), v1]) @@ -200,7 +200,7 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof for ax0 in range(T.int64(1)): for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_fused_0 in T.serial(T.int64(125), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.reduce(T.int64(32000), ax1_fused_0 * T.int64(256) + ax1_fused_1) T.reads(A[T.int64(0), T.int64(0), v1], T_softmax_maxelem_shared[T.int64(0), T.int64(0)]) @@ -210,12 +210,12 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof T_softmax_expsum_shared[T.int64(0), T.int64(0)] = T_softmax_expsum_shared[T.int64(0), T.int64(0)] + T.exp(A[T.int64(0), T.int64(0), v1] - T_softmax_maxelem_shared[T.int64(0), T.int64(0)]) for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(125), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(32000), ax1_0 * T.int64(256) + ax1_1) T.reads(A[T.int64(0), T.int64(0), v1], T_softmax_maxelem_shared[T.int64(0), T.int64(0)], T_softmax_expsum_shared[T.int64(0), T.int64(0)]) T.writes(T_softmax_norm[T.int64(0), T.int64(0), v1]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) T_softmax_norm[T.int64(0), T.int64(0), v1] = T.exp(A[T.int64(0), T.int64(0), v1] - T_softmax_maxelem_shared[T.int64(0), T.int64(0)]) / T_softmax_expsum_shared[T.int64(0), T.int64(0)] # fmt: on @@ -228,12 +228,12 @@ def test_softmax_3(): class Before: @T.prim_func def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")): - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192))) T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192))) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192))) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), T.int64(32)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(input[v_i0, v_i1, v_k, v_i2]) T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) @@ -241,13 +241,13 @@ def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), " T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-340282346638528859811704183484516925440.0) T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], input[v_i0, v_i1, v_k, v_i2]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), T.int64(8192)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(input[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i3]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(input[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i3]) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), T.int64(32)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_k, v_i2]) T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) @@ -255,11 +255,11 @@ def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), " T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0.0) T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_k, v_i2] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), T.int64(8192)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i3]) T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i3] @@ -268,14 +268,14 @@ class After: @T.prim_func def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared") for ax0_ax2_fused in T.thread_binding(T.int64(32768), thread="blockIdx.x"): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192) + ax0) v1 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192) + ax1) v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * T.int64(256) + ax2_fused_1) @@ -288,7 +288,7 @@ def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), " for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax2_fused_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192) + ax0) v1 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192) + ax1) v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * T.int64(256) + ax2_fused_1) @@ -300,14 +300,14 @@ def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), " T_softmax_expsum_shared[T.int64(0), v0, v1] = T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(input[T.int64(0), v0, v2, v1] - T_softmax_maxelem_shared[T.int64(0), v0, v1]) for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192)) v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(256) + ax1_1) v2 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192)) T.where(ax1_0 * T.int64(256) + ax1_1 < T.int64(32)) T.reads(input[T.int64(0), v0, v1, v2], T_softmax_maxelem_shared[T.int64(0), v0, v2], T_softmax_expsum_shared[T.int64(0), v0, v2]) T.writes(T_softmax_norm[T.int64(0), v0, v1, v2]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) T_softmax_norm[T.int64(0), v0, v1, v2] = T.exp(input[T.int64(0), v0, v1, v2] - T_softmax_maxelem_shared[T.int64(0), v0, v2]) / T_softmax_expsum_shared[T.int64(0), v0, v2] # fmt: on _check(Before, After) @@ -323,12 +323,12 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): + # with T.sblock("root"): A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(lv6[v_ax0, v_ax1, v_k2]) T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) @@ -340,13 +340,13 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) @@ -360,14 +360,14 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): + # with T.sblock("root"): A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), scope="shared") A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), scope="shared") for ax0_fused in T.thread_binding(n, thread="blockIdx.x"): for ax0 in range(T.int64(1)): for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_fused_0 in T.serial(T.int64(10), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v0 = T.axis.spatial(n, ax0_fused + ax0) v1 = T.axis.reduce(T.int64(2560), ax1_fused_0 * T.int64(256) + ax1_fused_1) T.reads(lv6[T.int64(0), v0, v1]) @@ -381,7 +381,7 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: A_red_temp_v1_shared[T.int64(0), v0] = v_A_red_temp_v1 for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(10), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(n, ax0_fused) v1 = T.axis.spatial(T.int64(2560), ax1_0 * T.int64(256) + ax1_1) T.reads(lv6[T.int64(0), v0, v1], A_red_temp_v0_shared[T.int64(0), v0], A_red_temp_v1_shared[T.int64(0), v0], weight1[v1], bias[v1]) @@ -401,10 +401,10 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), n)) for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -412,7 +412,7 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm Ared_temp[v_bsz, v_i] = T.float32(0) Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -426,13 +426,13 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): Ared_temp_shared = T.alloc_buffer((T.int64(1), n), scope="shared") for ax0_fused in T.thread_binding(n, thread="blockIdx.x"): for ax0 in range(T.int64(1)): for ax1_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v0 = T.axis.spatial(n, ax0_fused + ax0) v1 = T.axis.reduce(T.int64(4096), ax1_fused_0 * T.int64(256) + ax1_fused_1) T.reads(A[T.int64(0), v0, v1]) @@ -442,7 +442,7 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm Ared_temp_shared[T.int64(0), v0] = Ared_temp_shared[T.int64(0), v0] + T.Cast("float32", A[T.int64(0), v0, v1]) * T.Cast("float32", A[T.int64(0), v0, v1]) for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): for ax1_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v0 = T.axis.spatial(n, ax0_fused) v1 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(256) + ax1_1) T.reads(B[v1], A[T.int64(0), v0, v1], Ared_temp_shared[T.int64(0), v0]) @@ -466,13 +466,13 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: T_reshape_3 = T.alloc_buffer((32, 64)) T_group_norm = T.alloc_buffer((1, 32, 64)) for ax0, ax1, ax2 in T.grid(1, 32, 64): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[0, (v_ax1 * 64 + v_ax2) % 2048]) T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2]) T_reshape_1[v_ax0, v_ax1, v_ax2] = A[0, (v_ax1 * 64 + v_ax2) % 2048] for ax0, ax1, k2 in T.grid(1, 32, 64): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(T_reshape_1[v_ax0, v_ax1, v_k2]) T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) @@ -484,25 +484,25 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 for ax0, ax1 in T.grid(32, 64): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(B[(v_ax0 * 64 + v_ax1) % 2048]) T.writes(T_reshape_2[v_ax0, v_ax1]) T_reshape_2[v_ax0, v_ax1] = B[(v_ax0 * 64 + v_ax1) % 2048] for ax0, ax1 in T.grid(32, 64): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(C[(v_ax0 * 64 + v_ax1) % 2048]) T.writes(T_reshape_3[v_ax0, v_ax1]) T_reshape_3[v_ax0, v_ax1] = C[(v_ax0 * 64 + v_ax1) % 2048] for ax0, ax1, ax2 in T.grid(1, 32, 64): - with T.block("T_group_norm"): + with T.sblock("T_group_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) T.writes(T_group_norm[v_ax0, v_ax1, v_ax2]) T_group_norm[v_ax0, v_ax1, v_ax2] = (T_reshape_1[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.015625)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.015625) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.015625) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.015625)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] for ax0, ax1 in T.grid(1, 2048): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_group_norm[0, v_ax1 % 2048 // 64, v_ax1 % 64]) T.writes(T_reshape[v_ax0, v_ax1]) @@ -513,14 +513,14 @@ class After: @T.prim_func def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: T.Buffer((2048,), "float32"), T_reshape: T.Buffer((1, 2048), "float32")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): A_red_temp_v0_shared = T.alloc_buffer((1, 32), scope="shared") A_red_temp_v1_shared = T.alloc_buffer((1, 32), scope="shared") for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax0 in range(32): for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): for ax1_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v0 = T.axis.spatial(32, ax0) v1 = T.axis.reduce(64, ax1_fused_0 * 256 + ax1_fused_1) T.where(ax1_fused_0 * 256 + ax1_fused_1 < 64) @@ -535,7 +535,7 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: A_red_temp_v1_shared[0, v0] = v_A_red_temp_v1 for ax1_1 in T.thread_binding(256, thread="threadIdx.x"): for ax1_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(2048, ax1_0 * 256 + ax1_1) T.reads(A[0, v1], A_red_temp_v0_shared[0, v1 // 64], A_red_temp_v1_shared[0, v1 // 64], B[v1], C[v1]) @@ -560,7 +560,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): - with T.block("pad"): + with T.sblock("pad"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) A_pad[v0, v1, v2] = T.if_then_else( v1 * T.int64(4096) + v2 < vocab_size, @@ -569,14 +569,14 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): ) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): - with T.block("max"): + with T.sblock("max"): v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) with T.init(): temp_max[v0, v1] = T.min_value("float32") temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): - with T.block("sum_exp"): + with T.sblock("sum_exp"): v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) with T.init(): temp_sum[v0, v1] = T.float32(0) @@ -587,7 +587,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): ) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): - with T.block("log"): + with T.sblock("log"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) blocked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1] @@ -612,7 +612,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): "pragma_unroll_explicit": 1, }, ): - with T.block("max"): + with T.sblock("max"): v0 = T.axis.spatial( batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0, @@ -642,7 +642,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): "pragma_unroll_explicit": 1, }, ): - with T.block("sum_exp"): + with T.sblock("sum_exp"): v0 = T.axis.spatial( batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0, @@ -677,7 +677,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): "pragma_unroll_explicit": 1, }, ): - with T.block("log"): + with T.sblock("log"): v0 = T.axis.spatial( batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks ) diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index ae07a3b7318c..ecfc6b524182 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -31,23 +31,23 @@ def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.B batch_size = T.int64() lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): compute = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16") dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16") for i0, i1 in T.grid(T.int64(4096), T.int64(28672)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(lv429[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(4096), T.int64(28672)): - with T.block("dequantize"): + with T.sblock("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], lv430[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize_intermediate_intermediate[v_i0, v_i1]) dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv430[v_i0, v_i1 // T.int64(32)] for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(4096), T.int64(28672)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv807[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -61,7 +61,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T batch_size = T.int64() lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16", scope="local") NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") @@ -72,7 +72,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.y"): for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) @@ -82,7 +82,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T for ax2_fused_u_fused_0 in T.serial(T.int64(112), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)): for ax0_1 in T.vectorized(T.int64(1)): - with T.block("dequantize"): + with T.sblock("dequantize"): v0 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + ax0_0_1 + ax0_1) v1 = T.axis.spatial(T.int64(28672), ax2_fused_u_fused_0 * T.int64(256) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(8) + ax1) T.reads(lv429[v0, v1 // T.int64(8)], lv430[v0, v1 // T.int64(32)]) @@ -90,7 +90,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T dequantize_intermediate_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv430[v0, v1 // T.int64(32)] for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) @@ -103,7 +103,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): for ax2 in range(T.int64(4)): for ax3_fused_2_1 in T.vectorized(T.int64(2)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) @@ -111,7 +111,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0) for ax1 in range(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) @@ -121,7 +121,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)): for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2) @@ -133,7 +133,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T for ax0 in range(T.int64(4)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax1_fused_2 in range(T.int64(2)): - with T.block("NT_matmul_intermediate_pad"): + with T.sblock("NT_matmul_intermediate_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) @@ -158,9 +158,9 @@ def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), "float16"), va batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(K)), "float16") NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(N)), "float16") - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(N), T.int64(K)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(NT_matmul[v_i0, v_i1, v_i2]) @@ -174,7 +174,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(4096)), "float16") NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): NT_matmul_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") NT_matmul_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") NT_matmul_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") @@ -184,7 +184,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.y"): for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) @@ -194,7 +194,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) @@ -207,7 +207,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): for ax2 in range(T.int64(4)): for ax3_fused_2_1 in T.vectorized(T.int64(2)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) @@ -215,7 +215,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0) for ax1 in range(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) @@ -225,7 +225,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)): for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2) @@ -237,7 +237,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float for ax0 in range(T.int64(4)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax1_fused_2 in range(T.int64(2)): - with T.block("NT_matmul_pad"): + with T.sblock("NT_matmul_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) @@ -259,9 +259,9 @@ def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int kv_seq_len = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len)) B = T.match_buffer(var_B, (T.int64(1), T.int64(32), kv_seq_len, T.int64(128))) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -283,7 +283,7 @@ def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), v A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") for i0, i1, k in T.grid(batch_size, T.int64(8), T.int64(4096)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_i1, v_k]) T.writes(C[v_i0, v_i1]) @@ -298,7 +298,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") - # with T.block("root"): + # with T.sblock("root"): C_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") C_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") C_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") @@ -308,7 +308,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) @@ -319,7 +319,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) @@ -333,7 +333,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2 in range(T.int64(4)): for ax3_fused_2_1 in T.vectorized(T.int64(2)): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) @@ -342,7 +342,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = T.float16(0) for ax1 in range(T.int64(4)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) @@ -353,7 +353,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)): for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) v1 = T.axis.spatial(T.int64(8), ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2) @@ -366,7 +366,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" for ax0 in range(T.int64(4)): for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_fused_2 in range(T.int64(2)): - with T.block("C_pad"): + with T.sblock("C_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8)) @@ -396,15 +396,15 @@ def before( compute = T.alloc_buffer((4096, 6144), "float16") B = T.alloc_buffer((4096, 6144), "float16") for i0, i1 in T.grid(4096, 6144): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0 // 8, v_i1], T.Cast("uint32", v_i0 % 8 * 4)), T.uint32(15))) for i0, i1 in T.grid(4096, 6144): - with T.block("dequantize"): + with T.sblock("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0 // 32, v_i1] for i0, i1, i2, k in T.grid(batch_size, 1, 6144, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): C[v_i0, v_i1, v_i2] = T.float16(0) @@ -416,7 +416,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo batch_size = T.int32() A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 6144), "float16") - # with T.block("root"): + # with T.sblock("root"): B_local = T.alloc_buffer((4096, 6144), "float16", scope="local") A_pad_shared = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 4096), "float16", scope="shared") C_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 6144), "float16", scope="local") @@ -430,7 +430,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for ax0_1_init, ax2_fused_1_ax2_fused_3_fused_1_0_init in T.grid(4, 2): for ax2_fused_1_ax2_fused_3_fused_1_1_init in T.vectorized(4): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(32, ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0_init * 4 + ax2_fused_1_ax2_fused_3_fused_1_1_init) v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) @@ -439,14 +439,14 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0) for ax2_fused_0 in range(32): for ax0_ax1_fused in T.vectorized(4): - with T.block("B0_local"): + with T.sblock("B0_local"): v0 = T.axis.spatial(512, ax2_fused_0 * 16 + ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax0_ax1_fused) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) T.reads(B0[v0, v1]) T.writes(B0_local[v0, v1]) B0_local[v0, v1] = B0[v0, v1] for ax0_ax1_fused in T.vectorized(1): - with T.block("B1_local"): + with T.sblock("B1_local"): v0 = T.axis.spatial(128, ax2_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_0) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) T.reads(B1[v0, v1]) @@ -455,17 +455,17 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo for ax0_ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(2): - with T.block("A_pad"): + with T.sblock("A_pad"): v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 128) v1 = T.axis.spatial(4096, ax2_fused_0 * 128 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 128) T.reads(A[v0, 0, v1]) T.writes(A_pad_shared[v0, 0, v1]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 1]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 1]]}) A_pad_shared[v0, 0, v1] = T.if_then_else(v0 < batch_size, A[v0, 0, v1], T.float16(0)) for ax2_fused_2 in range(4): for ax0_ax1_fused_0 in range(2): for ax0_ax1_fused_1 in T.vectorized(4): - with T.block("dequantize"): + with T.sblock("dequantize"): v0 = T.axis.spatial(4096, ax2_fused_0 * 128 + ax2_fused_1_ax2_fused_3_fused_0 * 32 + ax2_fused_2 * 8 + ax0_ax1_fused_0 * 4 + ax0_ax1_fused_1) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) T.reads(B0_local[v0 // 8, v1], B1_local[v0 // 32, v1]) @@ -473,7 +473,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo B_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0_local[v0 // 8, v1], T.Cast("uint32", v0 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1_local[v0 // 32, v1] for ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(4, 2): for ax2_fused_1_ax2_fused_3_fused_1_1 in T.vectorized(4): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(32, ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_1) v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) @@ -484,7 +484,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo for ax3 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): for ax2_init in range(4): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(4, ax0) v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2_init) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax3) @@ -492,7 +492,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo T.writes(C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]) C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0) for ax2, ax1 in T.grid(4, 8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax3) @@ -502,7 +502,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo for ax1 in range(4): for ax2 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(4, ax0) v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax2) @@ -513,7 +513,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo C_pad_local[v0, 0, v1] = C_pad_local[v0, 0, v1] + C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] for ax0 in range(4): for ax1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("C_pad"): + with T.sblock("C_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1) T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 * 4 + ax0 == 0) and ax0_0 * 4 + ax0 < batch_size) diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index f27d9d370fce..0144f3ebae91 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -41,7 +41,7 @@ def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "f inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) for i0, i1, i2, k in T.grid(T.int64(1), m, T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): matmul[v_i0, v_i1, v_i2] = T.float32(0) @@ -53,7 +53,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), m = T.int64() inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) - # with T.block("root"): + # with T.sblock("root"): matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared") inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared") @@ -65,7 +65,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): for ax2_3_1_init in T.vectorized(T.int64(2)): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) @@ -77,29 +77,29 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(2)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("inp0_reindex_pad_shared"): + with T.sblock("inp0_reindex_pad_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(inp0[v0, v1, v2]) T.writes(inp0_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(4)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("inp1_reindex_shared"): + with T.sblock("inp1_reindex_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(inp1[v2, v1]) T.writes(inp1_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1] for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): for ax2_3_1 in T.vectorized(T.int64(2)): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) @@ -109,7 +109,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3] for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)): for ax2_1_1 in T.vectorized(T.int64(2)): - with T.block("matmul_reindex_pad_local"): + with T.sblock("matmul_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) @@ -128,7 +128,7 @@ def func(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul inp0 = T.match_buffer(var_inp0, (1, m, 4096)) matmul = T.match_buffer(var_matmul, (1, m, 4096)) for i0, i1, i2, k in T.grid(1, m, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): matmul[v_i0, v_i1, v_i2] = T.float32(0) @@ -140,7 +140,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma m = T.int32() inp0 = T.match_buffer(var_inp0, (1, m, 4096)) matmul = T.match_buffer(var_matmul, (1, m, 4096)) - # with T.block("root"): + # with T.sblock("root"): matmul_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="local") inp0_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="shared") inp1_reindex_shared = T.alloc_buffer((1, 4096, 4096), scope="shared") @@ -152,7 +152,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(4, 2): for ax2_3_1_init in T.vectorized(2): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) @@ -164,29 +164,29 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("inp0_reindex_pad_shared"): + with T.sblock("inp0_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(inp0[v0, v1, v2]) T.writes(inp0_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("inp1_reindex_shared"): + with T.sblock("inp1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(inp1[v2, v1]) T.writes(inp1_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1] for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): for ax2_3_1 in T.vectorized(2): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) @@ -196,7 +196,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma matmul_reindex_pad_local[0, v1, v2] = matmul_reindex_pad_local[0, v1, v2] + inp0_reindex_pad_shared[0, v1, v3] * inp1_reindex_shared[0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): - with T.block("matmul_reindex_pad_local"): + with T.sblock("matmul_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) @@ -220,13 +220,13 @@ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T. var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096))) var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096))) for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j]) T.writes(var_decode_intermediate[v_i, v_j]) var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -234,7 +234,7 @@ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T. var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(Out[v_ax0, v_ax1, v_ax2]) @@ -243,7 +243,7 @@ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T. @T.prim_func def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared") @@ -255,7 +255,7 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer(( for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): for ax2_3_1_init in T.vectorized(T.int64(2)): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) @@ -267,29 +267,29 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer(( for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(2)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("A_reindex_shared"): + with T.sblock("A_reindex_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(A[v0, v1, v2]) T.writes(A_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) A_reindex_shared[v0, v1, v2] = A[v0, v1, v2] for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(4)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("var_decode_intermediate_reindex_shared"): + with T.sblock("var_decode_intermediate_reindex_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1]) T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16))) for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): for ax2_3_1 in T.vectorized(T.int64(2)): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) @@ -299,7 +299,7 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer(( var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3] for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)): for ax2_1_1 in T.vectorized(T.int64(2)): - with T.block("var_matmul_intermediate_reindex_local"): + with T.sblock("var_matmul_intermediate_reindex_local"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) @@ -319,13 +319,13 @@ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T. var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096))) var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096))) for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j]) T.writes(var_decode_intermediate[v_i, v_j]) var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -333,7 +333,7 @@ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T. var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(Out[v_ax0, v_ax1, v_ax2]) @@ -354,20 +354,20 @@ def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buff lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16") lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) var_compute_intermediate = T.alloc_buffer((T.int64(4096),)) var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv13[v_i, v_j // T.int64(8)], lv14[v_i, v_j // T.int64(32)]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv48[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -375,25 +375,25 @@ def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buff var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv48[v_i0, v_i1, v_k]) * T.Cast("float32", p_output0_intermediate_1[v_k, v_i2]) for i0 in range(T.int64(4096)): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(T.int64(4096), i0) T.reads(lv13_1[v_i0]) T.writes(var_compute_intermediate[v_i0]) var_compute_intermediate[v_i0] = T.Cast("float32", lv13_1[v_i0]) for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], var_compute_intermediate[v_ax2]) T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + var_compute_intermediate[v_ax2] for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("compute_1"): + with T.sblock("compute_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add_1"): + with T.sblock("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv3[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -406,7 +406,7 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16") lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): var_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") lv48_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") p_output0_intermediate_1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16", scope="shared") @@ -418,7 +418,7 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): for ax2_3_1_init in T.vectorized(T.int64(2)): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) @@ -430,29 +430,29 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(2)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("lv48_reindex_pad_shared"): + with T.sblock("lv48_reindex_pad_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(lv48[v0, v1, v2]) T.writes(lv48_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv48_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv48[v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(4)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("p_output0_intermediate_1_reindex_shared"): + with T.sblock("p_output0_intermediate_1_reindex_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(lv13[v2, v1 // T.int64(8)], lv14[v2, v1 // T.int64(32)]) T.writes(p_output0_intermediate_1_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) p_output0_intermediate_1_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v2, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v2, v1 // T.int64(32)] for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): for ax2_3_1 in T.vectorized(T.int64(2)): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) @@ -462,7 +462,7 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + T.Cast("float32", lv48_reindex_pad_shared[T.int64(0), v1, v3]) * T.Cast("float32", p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3]) for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)): for ax2_1_1 in T.vectorized(T.int64(2)): - with T.block("var_matmul_intermediate_reindex_pad_local"): + with T.sblock("var_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) @@ -483,14 +483,14 @@ def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "floa lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16") lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048))) var_T_multiply_intermediate = T.match_buffer(p_output0, (n, T.int64(2048)), "float16") - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate = T.alloc_buffer((n, T.int64(2048)), "float16") compute = T.alloc_buffer((n, T.int64(2048)), "float16") var_T_multiply_intermediate_1 = T.alloc_buffer((n, T.int64(2048)), "float16") var_T_squeeze_intermediate = T.alloc_buffer((n, T.int64(2048))) var_compute_intermediate = T.alloc_buffer((n, T.int64(2048)), "float16") for i0, i1, k in T.grid(n, T.int64(2048), T.int64(2048)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(lv26[v_i0, v_k], lv9[v_i1, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1]) @@ -498,31 +498,31 @@ def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "floa var_NT_matmul_intermediate[v_i0, v_i1] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1] = var_NT_matmul_intermediate[v_i0, v_i1] + lv26[v_i0, v_k] * lv9[v_i1, v_k] for i0, i1 in T.grid(n, T.int64(2048)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(var_NT_matmul_intermediate[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1]) for ax0, ax1 in T.grid(n, T.int64(2048)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1]) T.writes(var_T_multiply_intermediate_1[v_ax0, v_ax1]) var_T_multiply_intermediate_1[v_ax0, v_ax1] = var_NT_matmul_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1] for ax0, ax1 in T.grid(n, T.int64(2048)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(lv52[T.int64(0), v_ax0, v_ax1]) T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1]) var_T_squeeze_intermediate[v_ax0, v_ax1] = lv52[T.int64(0), v_ax0, v_ax1] for i0, i1 in T.grid(n, T.int64(2048)): - with T.block("compute_1"): + with T.sblock("compute_1"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(var_T_squeeze_intermediate[v_i0, v_i1]) T.writes(var_compute_intermediate[v_i0, v_i1]) var_compute_intermediate[v_i0, v_i1] = T.Cast("float16", var_T_squeeze_intermediate[v_i0, v_i1]) for ax0, ax1 in T.grid(n, T.int64(2048)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(var_compute_intermediate[v_ax0, v_ax1], var_T_multiply_intermediate_1[v_ax0, v_ax1]) T.writes(var_T_multiply_intermediate[v_ax0, v_ax1]) @@ -535,7 +535,7 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16") lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048))) var_T_multiply_intermediate = T.match_buffer(p_output0, (n, T.int64(2048)), "float16") - # with T.block("root"): + # with T.sblock("root"): var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local") lv26_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared") lv9_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(2048)), "float16", scope="shared") @@ -547,7 +547,7 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): for ax2_3_1_init in T.vectorized(T.int64(2)): - with T.block("NT_matmul_init"): + with T.sblock("NT_matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) @@ -559,29 +559,29 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(2)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("lv26_reindex_pad_shared"): + with T.sblock("lv26_reindex_pad_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(lv26[v1, v2]) T.writes(lv26_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv26_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv26[v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(T.int64(4)): for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): - with T.block("lv9_reindex_shared"): + with T.sblock("lv9_reindex_shared"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) T.reads(lv9[v1, v2]) T.writes(lv9_reindex_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv9_reindex_shared[v0, v1, v2] = lv9[v1, v2] for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): for ax2_3_1 in T.vectorized(T.int64(2)): - with T.block("NT_matmul_update"): + with T.sblock("NT_matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) @@ -591,7 +591,7 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + lv26_reindex_pad_shared[T.int64(0), v1, v3] * lv9_reindex_shared[T.int64(0), v2, v3] for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)): for ax2_1_1 in T.vectorized(T.int64(2)): - with T.block("var_NT_matmul_intermediate_reindex_pad_local"): + with T.sblock("var_NT_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) @@ -621,7 +621,7 @@ def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "f inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) for i0, i1, i2, k in T.grid(T.int64(1), m, T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): matmul[v_i0, v_i1, v_i2] = T.float32(0) @@ -633,7 +633,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), m = T.int64() inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) - # with T.block("root"): + # with T.sblock("root"): inp0_reindex_pad = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16))) matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local") inp0_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), scope="local") @@ -643,7 +643,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for i1_2 in T.vectorized(T.int64(16)): - with T.block("inp0_reindex_pad"): + with T.sblock("inp0_reindex_pad"): v0 = T.axis.spatial(T.int64(1), i0) v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2) v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1) @@ -657,7 +657,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): for i0_i1_fused_2_init in T.unroll(T.int64(16)): for i2_2_init in T.vectorized(T.int64(8)): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init) v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) @@ -669,7 +669,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for k_2 in T.unroll(T.int64(8)): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): for ax3 in T.vectorized(T.int64(16)): - with T.block("inp0_reindex_pad_local"): + with T.sblock("inp0_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1) v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2) @@ -680,7 +680,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), inp0_reindex_pad_local[v0, v1, v2, v3] = inp0_reindex_pad[v0, v1, v2, v3] for i0_i1_fused_2 in T.unroll(T.int64(16)): for i2_2 in T.vectorized(T.int64(8)): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2) v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) @@ -691,7 +691,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * inp1[v_k, v_i2] for ax0 in T.unroll(T.int64(16)): for ax1 in T.vectorized(T.int64(8)): - with T.block("matmul_pad"): + with T.sblock("matmul_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) @@ -709,24 +709,24 @@ def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.B seq_len = T.int64() rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") - # with T.block("root"): + # with T.sblock("root"): compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16") for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(lv452[v_i0 // T.int64(8), v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): - with T.block("dequantize"): + with T.sblock("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1]) T.writes(dequantize_intermediate_intermediate[v_i0, v_i1]) dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1] for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2]) T.writes(matmul_intermediate[v_i0, v_i1, v_i2]) @@ -734,7 +734,7 @@ def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.B matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2] for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) @@ -746,7 +746,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T seq_len = T.int64() rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") - # with T.block("root"): + # with T.sblock("root"): dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local") rms_norm130_reindex_pad = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16") matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(12288)), "float16", scope="local") @@ -759,7 +759,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for i1_2 in T.vectorized(T.int64(16)): - with T.block("rms_norm130_reindex_pad"): + with T.sblock("rms_norm130_reindex_pad"): v0 = T.axis.spatial(T.int64(1), i0) v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2) v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1) @@ -773,7 +773,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): for i0_i1_fused_2_init in T.unroll(T.int64(16)): for i2_2_init in T.vectorized(T.int64(8)): - with T.block("matmul_init"): + with T.sblock("matmul_init"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init) v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) @@ -784,7 +784,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T for k_0 in range(T.int64(128)): for ax0 in range(T.int64(1)): for ax1 in T.vectorized(T.int64(8)): - with T.block("lv453_local"): + with T.sblock("lv453_local"): v0 = T.axis.spatial(T.int64(128), k_0 + ax0) v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) T.reads(lv453[v0, v1]) @@ -793,7 +793,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T for k_1 in range(T.int64(4)): for ax0 in range(T.int64(1)): for ax1 in T.vectorized(T.int64(8)): - with T.block("lv452_local"): + with T.sblock("lv452_local"): v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(4) + k_1 + ax0) v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) T.reads(lv452[v0, v1]) @@ -801,7 +801,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T lv452_local[v0, v1] = lv452[v0, v1] for k_2 in T.unroll(T.int64(8)): for ax0 in T.vectorized(T.int64(8)): - with T.block("dequantize"): + with T.sblock("dequantize"): v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) @@ -809,7 +809,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): for ax3 in T.vectorized(T.int64(16)): - with T.block("rms_norm130_reindex_pad_local"): + with T.sblock("rms_norm130_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1) v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2) @@ -820,7 +820,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T rms_norm130_reindex_pad_local[v0, v1, v2, v3] = rms_norm130_reindex_pad[v0, v1, v2, v3] for i0_i1_fused_2 in T.unroll(T.int64(16)): for i2_2 in T.vectorized(T.int64(8)): - with T.block("matmul_update"): + with T.sblock("matmul_update"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2) v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) @@ -831,7 +831,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * dequantize_intermediate_intermediate_local[v_k, v_i2] for ax0 in T.unroll(T.int64(16)): for ax1 in T.vectorized(T.int64(8)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py index 261981c5e46c..c38586dd8a19 100644 --- a/tests/python/dlight/test_gpu_matmul_tensorize.py +++ b/tests/python/dlight/test_gpu_matmul_tensorize.py @@ -40,9 +40,9 @@ class TestMatmulTensorize(BaseBeforeAfter): @T.prim_func def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(256, 256, 256): - with T.block("compute"): + with T.sblock("compute"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(X[v_i, v_k], W[v_j, v_k]) T.writes(compute[v_i, v_j]) @@ -53,7 +53,7 @@ def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16" @T.prim_func def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_a") @@ -65,13 +65,13 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("compute_o_init"): + with T.sblock("compute_o_init"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) T.reads() T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_init_o"): + with T.sblock("compute_init_o"): v1_i_init_o = T.axis.spatial(1, 0) v2_i_init_o = T.axis.spatial(1, 0) T.reads() @@ -83,30 +83,30 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("X_reindex_shared.dyn"): + with T.sblock("X_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(X[v1, v2]) T.writes(X_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("W_reindex_shared.dyn"): + with T.sblock("W_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(W[v1, v2]) T.writes(W_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): + with T.sblock("X_reindex_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(16, ax3_0_0 * 4 + ax3_0_1 + ax1_0) @@ -117,7 +117,7 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): + with T.sblock("W_reindex_shared.dyn_wmma.matrix_b_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) v2_o = T.axis.spatial(16, ax3_0_0 * 4 + ax3_0_1 + ax1_0) @@ -127,14 +127,14 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("compute_o_update"): + with T.sblock("compute_o_update"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) v3_o = T.axis.reduce(16, ax3_0_0 * 4 + ax3_0_1) T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_o"): + with T.sblock("compute_o"): v1_i_o = T.axis.spatial(1, 0) v2_i_o = T.axis.spatial(1, 0) v3_i_o = T.axis.reduce(1, 0) @@ -145,7 +145,7 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): + with T.sblock("compute_reindex_shared.dyn_wmma.accumulator_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) @@ -157,13 +157,13 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 for ax0_ax1_fused_0 in range(8): for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("compute_reindex_shared.dyn"): + with T.sblock("compute_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) T.reads(compute_reindex_shared_dyn[v0, v1, v2]) T.writes(compute[v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] # fmt: on @@ -178,9 +178,9 @@ def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.ha m = T.int32() X = T.match_buffer(var_X, (m, 256), "float16") compute = T.match_buffer(var_compute, (m, 15)) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(m, 15, 256): - with T.block("compute"): + with T.sblock("compute"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(X[v_i, v_k], W[v_j, v_k]) T.writes(compute[v_i, v_j]) @@ -194,7 +194,7 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. m = T.int32() X = T.match_buffer(var_X, (m, 256), "float16") compute = T.match_buffer(var_compute, (m, 15)) - # with T.block("root"): + # with T.sblock("root"): compute_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 64), scope="local") X_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 256), "float16", scope="shared") W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", scope="shared") @@ -206,7 +206,7 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(4, 2): for ax2_3_1_init in T.vectorized(2): - with T.block("compute_init"): + with T.sblock("compute_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) @@ -218,29 +218,29 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("X_reindex_pad_shared"): + with T.sblock("X_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(X[v1, v2]) T.writes(X_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("W_reindex_pad_shared"): + with T.sblock("W_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(W[v1, v2]) T.writes(W_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], T.float16(0)) for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): for ax2_3_1 in T.vectorized(2): - with T.block("compute_update"): + with T.sblock("compute_update"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) @@ -250,7 +250,7 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. compute_reindex_pad_local[0, v1, v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32", X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0, v2, v3]) for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): - with T.block("compute_reindex_pad_local"): + with T.sblock("compute_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) @@ -271,18 +271,18 @@ def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Bu lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16") lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16") p_output0_intermediate = T.match_buffer(p_output0, (T.int32(1), n, T.int32(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((T.int32(4096), T.int32(2048)), "float16") var_NT_matmul_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") var_T_divide_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") for i, j in T.grid(T.int32(4096), T.int32(2048)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv686[v_i, v_j // T.int32(8)], lv687[v_i, v_j // T.int32(32)]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v_i, v_j // T.int32(8)], T.Cast("uint32", v_j % T.int32(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v_i, v_j // T.int32(32)] for i0, i1, i2, k in T.grid(T.int32(1), n, T.int32(4096), T.int32(2048)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv42[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -290,13 +290,13 @@ def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Bu var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv42[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv3[v_ax0, v_ax1, v_ax2]) T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2]) var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * T.float16(0.5) for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -309,7 +309,7 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16") lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16") p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") - # with T.block("root"): + # with T.sblock("root"): lv42_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="shared.dyn") p_output0_intermediate_1_reindex_shared_dyn = T.alloc_buffer((1, 4096, 2048), "float16", scope="shared.dyn") lv42_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="wmma.matrix_a") @@ -321,13 +321,13 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("NT_matmul_o_init"): + with T.sblock("NT_matmul_o_init"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("NT_matmul_init_o"): + with T.sblock("NT_matmul_init_o"): v1_i_init_o = T.axis.spatial(1, 0) v2_i_init_o = T.axis.spatial(1, 0) T.reads() @@ -339,30 +339,30 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("lv42_reindex_pad_shared.dyn"): + with T.sblock("lv42_reindex_pad_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(lv42[v0, v1, v2]) T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0)) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("p_output0_intermediate_1_reindex_shared.dyn"): + with T.sblock("p_output0_intermediate_1_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(lv686[v1, v2 // 8], lv687[v1, v2 // 32]) T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32] for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"): + with T.sblock("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) @@ -373,7 +373,7 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"): + with T.sblock("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) @@ -383,14 +383,14 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), C = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("NT_matmul_o_update"): + with T.sblock("NT_matmul_o_update"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) v3_o = T.axis.reduce(128, ax3_0_0 * 4 + ax3_0_1) T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("NT_matmul_o"): + with T.sblock("NT_matmul_o"): v1_i_o = T.axis.spatial(1, 0) v2_i_o = T.axis.spatial(1, 0) v3_i_o = T.axis.reduce(1, 0) @@ -401,7 +401,7 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"): + with T.sblock("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) @@ -413,14 +413,14 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), for ax0_ax1_fused_0 in range(8): for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"): + with T.sblock("var_NT_matmul_intermediate_reindex_pad_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < n) T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]) T.writes(p_output0_intermediate[0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2] # fmt: on @@ -430,9 +430,9 @@ class TestMatmulInt8Tensorize(BaseBeforeAfter): @T.prim_func def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, r in T.grid(256, 256, 256): - with T.block("compute"): + with T.sblock("compute"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, r]) T.reads(X[v_i, v_k], W[v_j, v_k]) T.writes(compute[v_i, v_j]) @@ -443,7 +443,7 @@ def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), com @T.prim_func def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_a") @@ -455,13 +455,13 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("compute_o_init"): + with T.sblock("compute_o_init"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) T.reads() T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_init_o"): + with T.sblock("compute_init_o"): v1_i_init_o = T.axis.spatial(1, 0) v2_i_init_o = T.axis.spatial(1, 0) T.reads() @@ -473,30 +473,30 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("X_reindex_shared.dyn"): + with T.sblock("X_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(X[v1, v2]) T.writes(X_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] for ax0_ax1_fused_0 in range(1): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("W_reindex_shared.dyn"): + with T.sblock("W_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(W[v1, v2]) T.writes(W_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): + with T.sblock("X_reindex_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) @@ -507,7 +507,7 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): + with T.sblock("W_reindex_shared.dyn_wmma.matrix_b_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) @@ -517,14 +517,14 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("compute_o_update"): + with T.sblock("compute_o_update"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) v3_o = T.axis.reduce(16, ax3_0_0 + ax3_0_1) T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_o"): + with T.sblock("compute_o"): v1_i_o = T.axis.spatial(1, 0) v2_i_o = T.axis.spatial(1, 0) v3_i_o = T.axis.reduce(1, 0) @@ -535,7 +535,7 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): + with T.sblock("compute_reindex_shared.dyn_wmma.accumulator_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) @@ -547,13 +547,13 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c for ax0_ax1_fused_0 in range(8): for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("compute_reindex_shared.dyn"): + with T.sblock("compute_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) T.reads(compute_reindex_shared_dyn[v0, v1, v2]) T.writes(compute[v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] # fmt: on @@ -566,9 +566,9 @@ def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.ha m = T.int32() A = T.match_buffer(var_A, (1, m, 22016), "int8") matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(1, m, 4096, 22016): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(matmul_1[v_i0, v_i1, v_i2]) @@ -582,7 +582,7 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. m = T.int32() A = T.match_buffer(var_A, (1, m, 22016), "int8") matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") - # with T.block("root"): + # with T.sblock("root"): A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="shared.dyn") B_reindex_shared_dyn = T.alloc_buffer((1, 4096, 22016), "int8", scope="shared.dyn") A_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="wmma.matrix_a") @@ -594,13 +594,13 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("matmul_o_init"): + with T.sblock("matmul_o_init"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) T.reads() T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("matmul_init_o"): + with T.sblock("matmul_init_o"): v1_i_init_o = T.axis.spatial(1, 0) v2_i_init_o = T.axis.spatial(1, 0) T.reads() @@ -612,30 +612,30 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("A_reindex_pad_shared.dyn"): + with T.sblock("A_reindex_pad_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(A[v0, v1, v2]) T.writes(A_reindex_pad_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0)) for ax0_ax1_fused_0 in range(1): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("B_reindex_shared.dyn"): + with T.sblock("B_reindex_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(B[v1, v2]) T.writes(B_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("A_reindex_pad_shared.dyn_wmma.matrix_a_o"): + with T.sblock("A_reindex_pad_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) @@ -646,7 +646,7 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "row_major") for ax0_0 in T.unroll(2): for ax1_0 in T.unroll(1): - with T.block("B_reindex_shared.dyn_wmma.matrix_b_o"): + with T.sblock("B_reindex_shared.dyn_wmma.matrix_b_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) @@ -656,14 +656,14 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. C = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "col_major") for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("matmul_o_update"): + with T.sblock("matmul_o_update"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) v3_o = T.axis.reduce(1376, ax3_0_0 + ax3_0_1) T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("matmul_o"): + with T.sblock("matmul_o"): v1_i_o = T.axis.spatial(1, 0) v2_i_o = T.axis.spatial(1, 0) v3_i_o = T.axis.reduce(1, 0) @@ -674,7 +674,7 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A_1.data, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, B_1.data, B_1.elem_offset // B_1.strides[0] // 16 * (B_1.strides[0] // 16) + B_1.elem_offset % B_1.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"): + with T.sblock("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) @@ -686,14 +686,14 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. for ax0_ax1_fused_0 in range(8): for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("matmul_1_reindex_pad_shared.dyn"): + with T.sblock("matmul_1_reindex_pad_shared.dyn"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < m) T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2]) T.writes(matmul_1[0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2] # fmt: on @@ -720,7 +720,7 @@ def before( A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096): - with T.block("C"): + with T.sblock("C"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.writes(C[v_i0, v_i1, v_i2]) with T.init(): @@ -733,7 +733,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha batch_size = T.int32() A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") - # with T.block("root"): + # with T.sblock("root"): A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared") B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared") A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup") @@ -746,7 +746,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha for ax1_1 in T.thread_binding(1, thread="threadIdx.y"): for ax2_1 in T.thread_binding(4, thread="threadIdx.z"): for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1): - with T.block("C_init_o"): + with T.sblock("C_init_o"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0) v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0) @@ -760,7 +760,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_ax2_fused_4 in T.vectorized(4): - with T.block("A_reindex_pad_shared"): + with T.sblock("A_reindex_pad_shared"): v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) @@ -772,7 +772,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_ax2_fused_4 in T.vectorized(4): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) @@ -781,7 +781,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha B_reindex_shared[v0, v1, v2] = B[v1, v2] for ax3_1 in range(4): for ax0_0, ax1_0_1 in T.grid(2, 1): - with T.block("A_reindex_pad_shared_metal.simdgroup_o"): + with T.sblock("A_reindex_pad_shared_metal.simdgroup_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0) v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) @@ -791,7 +791,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) for ax0_0, ax1_0_1 in T.grid(2, 1): - with T.block("B_reindex_shared_metal.simdgroup_o"): + with T.sblock("B_reindex_shared_metal.simdgroup_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0) v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) @@ -801,7 +801,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) for ax1_2, ax2_2 in T.grid(2, 2): - with T.block("C_update_o"): + with T.sblock("C_update_o"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2) v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2) @@ -813,7 +813,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): - with T.block("C_reindex_pad_metal.simdgroup_o"): + with T.sblock("C_reindex_pad_metal.simdgroup_o"): v0_o = T.axis.spatial(1, ax0_1) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1) v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1) @@ -827,7 +827,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_ax2_fused_4 in T.vectorized(4): - with T.block("C_reindex_pad_shared"): + with T.sblock("C_reindex_pad_shared"): v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) @@ -853,15 +853,15 @@ def before( compute = T.alloc_buffer((28672, 4096), "float16") B = T.alloc_buffer((28672, 4096), "float16") for i0, i1 in T.grid(28672, 4096): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15))) for i0, i1 in T.grid(28672, 4096): - with T.block("dequantize"): + with T.sblock("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0, v_i1 // 32] for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) with T.init(): C[v_i0, v_i1, v_i2] = T.float16(0) @@ -873,7 +873,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f batch_size = T.int32() A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") - # with T.block("root"): + # with T.sblock("root"): A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared") B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared") A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup") @@ -886,7 +886,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f for ax1_1 in T.thread_binding(1, thread="threadIdx.y"): for ax2_1 in T.thread_binding(4, thread="threadIdx.z"): for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1): - with T.block("NT_matmul_init_o"): + with T.sblock("NT_matmul_init_o"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0) v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0) @@ -900,7 +900,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_ax2_fused_4 in T.vectorized(4): - with T.block("A_reindex_pad_shared"): + with T.sblock("A_reindex_pad_shared"): v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) @@ -912,7 +912,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_ax2_fused_4 in T.vectorized(4): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) @@ -921,7 +921,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f B_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1[v1, v2 // 32] for ax3_1 in range(4): for ax0_0, ax1_0_1 in T.grid(2, 1): - with T.block("A_reindex_pad_shared_metal.simdgroup_o"): + with T.sblock("A_reindex_pad_shared_metal.simdgroup_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0) v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) @@ -931,7 +931,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) for ax0_0, ax1_0_1 in T.grid(2, 1): - with T.block("B_reindex_shared_metal.simdgroup_o"): + with T.sblock("B_reindex_shared_metal.simdgroup_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0) v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) @@ -941,7 +941,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) for ax1_2, ax2_2 in T.grid(2, 2): - with T.block("NT_matmul_update_o"): + with T.sblock("NT_matmul_update_o"): v0_o = T.axis.spatial(1, ax0) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2) v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2) @@ -953,7 +953,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): - with T.block("C_reindex_pad_metal.simdgroup_o"): + with T.sblock("C_reindex_pad_metal.simdgroup_o"): v0_o = T.axis.spatial(1, ax0_1) v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1) v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1) @@ -967,7 +967,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_ax2_fused_4 in T.vectorized(4): - with T.block("C_reindex_pad_shared"): + with T.sblock("C_reindex_pad_shared"): v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py index 14187e823546..ca51e1e40f97 100644 --- a/tests/python/dlight/test_gpu_reduction.py +++ b/tests/python/dlight/test_gpu_reduction.py @@ -32,16 +32,16 @@ class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) T.writes(B[v_i, v_j]) B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(C[v_i0, v_i1, v_i2]) @@ -59,13 +59,13 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T S = T.match_buffer(S_handle, (4096, 128), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") C = T.match_buffer(C_handle, (1, 1, 4096), "float16") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() C_rf_local = T.alloc_buffer((512, 1, 1, 4096), "float16", scope="local") for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"): for ax1_0_fused_1 in T.thread_binding(512, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused) T.reads() @@ -73,7 +73,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T C_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) for ax1_0_fused_0 in range(1): for ax1_1 in range(8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused) vax1_0_fused_0 = T.axis.reduce(1, ax1_0_fused_0) @@ -83,7 +83,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T C_rf_local[vax1_0_fused_1, 0, 0, v0] = C_rf_local[vax1_0_fused_1, 0, 0, v0] + V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) for ax1_fused in range(1): for ax0 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_0_fused_1 = T.axis.reduce(512, ax0) v0 = T.axis.spatial(4096, ax0_fused) T.reads(C_rf_local[vax1_0_fused_1, 0, 0, v0]) @@ -107,16 +107,16 @@ class Before: @T.prim_func def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j]) T.writes(B[v_i, v_j]) B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2]) T.writes(C[v_i0, v_i1, v_i2]) @@ -130,24 +130,24 @@ class After: @T.prim_func def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") for i2_i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for i2_i0_i1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for k_0_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1) v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + i2_i0_i1_fused_1) C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0) for k_0_fused_0, k_1 in T.grid(32, 8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1) v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + i2_i0_i1_fused_1) vk_0_fused_0, vk_1 = T.axis.remap("RR", [k_0_fused_0, k_1]) C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 32, v_i2]) for ax1_ax2_ax3_fused in T.thread_binding(16, thread="threadIdx.x"): for ax0_fused in T.thread_binding(16, thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vk_0_fused_1 = T.axis.reduce(16, ax0_fused) v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + ax1_ax2_ax3_fused) with T.init(): @@ -170,16 +170,16 @@ class Before: @T.prim_func def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j]) T.writes(B[v_i, v_j]) B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(C[v_i0, v_i1, v_i2]) @@ -196,14 +196,14 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T S = T.match_buffer(S_handle, (128, 4096), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") C = T.match_buffer(C_handle, (1, 1, 4096), "float16") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() C_rf_local = T.alloc_buffer((1024, 1, 1, 4096), "float16", scope="local") for ax0_0_fused in T.thread_binding(512, thread="blockIdx.x"): for ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_1_init in range(8): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax0_1_init) T.reads() @@ -211,7 +211,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T C_rf_local[vax1_fused_1, 0, 0, v0] = T.float16(0) for ax1_fused_0 in range(4): for ax0_1 in range(8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax0_1) vax1_fused_0 = T.axis.reduce(4, ax1_fused_0) @@ -221,7 +221,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T for ax1_fused_0 in range(1): for ax0 in T.thread_binding(1024, thread="threadIdx.x"): for ax1_fused_1 in range(8): - with T.block("matmul"): + with T.sblock("matmul"): vax1_fused_1 = T.axis.reduce(1024, ax0) v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax1_fused_1) T.reads(C_rf_local[vax1_fused_1, 0, 0, v0]) @@ -246,16 +246,16 @@ class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) T.writes(B[v_i, v_j]) B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2]) T.writes(C[v_i0, v_i1, v_i2]) @@ -269,18 +269,18 @@ class After: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") for i2_0_i0_i1_fused_0 in T.thread_binding(32, thread="blockIdx.x"): for i2_0_i0_i1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for k_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for i2_1_init in range(8): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vk_fused_1 = T.axis.spatial(16, k_fused_1) v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init) C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0) for k_fused_0, i2_1 in T.grid(256, 8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vk_fused_1 = T.axis.spatial(16, k_fused_1) v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1) vk_fused_0 = T.axis.reduce(256, k_fused_0) @@ -288,7 +288,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") for ax1_ax2_ax3_fused_0 in T.thread_binding(16, thread="threadIdx.x"): for ax1_ax2_ax3_fused_1 in range(8): for ax0_fused in T.thread_binding(16, thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vk_fused_1 = T.axis.reduce(16, ax0_fused) v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) with T.init(): @@ -311,17 +311,17 @@ class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((4096, 4096), "float16") C = T.alloc_buffer((1, 1, 4096), "float16") for i, j in T.grid(4096, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) T.writes(B[v_i, v_j]) B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(C[v_i0, v_i1, v_i2]) @@ -329,7 +329,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") C[v_i0, v_i1, v_i2] = T.float16(0) C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, v_k] * B[v_i2, v_k] for i0, i1, i2 in T.grid(1, 1, 4096): - with T.block("sigmoid"): + with T.sblock("sigmoid"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(C[v_i0, v_i1, v_i2]) T.writes(D[v_i0, v_i1, v_i2]) @@ -344,14 +344,14 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T S = T.match_buffer(S_handle, (4096, 128), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") D = T.match_buffer(D_handle, (1, 1, 4096), "float16") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() C_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") C_rf_local = T.alloc_buffer((512, 1, 1, 4096), "float16", scope="local") for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"): for ax1_0_fused_1 in T.thread_binding(512, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused) T.reads() @@ -359,7 +359,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T C_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) for ax1_0_fused_0 in range(1): for ax1_1 in range(8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused) vax1_0_fused_0 = T.axis.reduce(1, ax1_0_fused_0) @@ -369,7 +369,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T C_rf_local[vax1_0_fused_1, 0, 0, v0] = C_rf_local[vax1_0_fused_1, 0, 0, v0] + V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) for ax1_fused in range(1): for ax0 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_0_fused_1 = T.axis.reduce(512, ax0) v0 = T.axis.spatial(4096, ax0_fused) T.reads(C_rf_local[vax1_0_fused_1, 0, 0, v0]) @@ -378,7 +378,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T C_local[0, 0, v0] = T.float16(0) C_local[0, 0, v0] = C_local[0, 0, v0] + C_rf_local[vax1_0_fused_1, 0, 0, v0] for ax0 in range(1): - with T.block("sigmoid"): + with T.sblock("sigmoid"): v0 = T.axis.spatial(4096, ax0_fused + ax0) T.reads(C_local[0, 0, v0]) T.writes(D[0, 0, v0]) @@ -400,17 +400,17 @@ class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((4096, 4096), "float16") C_fp32 = T.alloc_buffer((1, 1, 4096), "float32") for i, j in T.grid(4096, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) T.writes(B[v_i, v_j]) B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(C_fp32[v_i0, v_i1, v_i2]) @@ -418,7 +418,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") C_fp32[v_i0, v_i1, v_i2] = T.float16(0) C_fp32[v_i0, v_i1, v_i2] = C_fp32[v_i0, v_i1, v_i2] + T.Cast("float32", V[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i2, v_k]) for i0, i1, i2 in T.grid(1, 1, 4096): - with T.block("cast"): + with T.sblock("cast"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(C_fp32[v_i0, v_i1, v_i2]) T.writes(C[v_i0, v_i1, v_i2]) @@ -433,14 +433,14 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T S = T.match_buffer(S_handle, (4096, 128), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") C = T.match_buffer(C_handle, (1, 1, 4096), "float16") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() C_fp32_local = T.alloc_buffer((1, 1, 4096), scope="local") C_fp32_rf_local = T.alloc_buffer((512, 1, 1, 4096), scope="local") for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"): for ax1_0_fused_1 in T.thread_binding(512, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused) T.reads() @@ -448,7 +448,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float32(0) for ax1_0_fused_0 in range(1): for ax1_1 in range(8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused) vax1_0_fused_0 = T.axis.reduce(1, ax1_0_fused_0) @@ -458,7 +458,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) for ax1_fused in range(1): for ax0 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_0_fused_1 = T.axis.reduce(512, ax0) v0 = T.axis.spatial(4096, ax0_fused) T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]) @@ -467,7 +467,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T C_fp32_local[0, 0, v0] = T.float32(0) C_fp32_local[0, 0, v0] = C_fp32_local[0, 0, v0] + C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] for ax0 in range(1): - with T.block("cast"): + with T.sblock("cast"): v0 = T.axis.spatial(4096, ax0_fused + ax0) T.reads(C_fp32_local[0, 0, v0]) T.writes(C[0, 0, v0]) @@ -490,13 +490,13 @@ def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), T.func_attr({"global_symbol": "main", "tir.noalias": True}) Ared_temp = T.alloc_buffer((1, 1)) for ax0 in range(4096): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v0 = T.axis.reduce(4096, ax0) with T.init(): Ared_temp[0, 0] = T.float32(0) Ared_temp[0, 0] = Ared_temp[0, 0] + T.Cast("float32", A[0, 0, v0]) * T.Cast("float32", A[0, 0, v0]) for ax0 in range(4096): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v0 = T.axis.spatial(4096, ax0) rms_norm[0, v0] = T.Cast("float16", T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp[0, 0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) @@ -508,14 +508,14 @@ def main(A_handle: T.handle, B_handle: T.handle, rms_norm_handle: T.handle): A = T.match_buffer(A_handle, (1, 1, 4096), "float16") B = T.match_buffer(B_handle, (4096,), "float16") rms_norm = T.match_buffer(rms_norm_handle, (1, 4096), "float16") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared") Ared_temp_rf_local = T.alloc_buffer((1024, 1, 1), scope="local") for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("Ared_temp_rf_init"): + with T.sblock("Ared_temp_rf_init"): vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) v0 = T.axis.spatial(T.int64(1), T.int64(0)) T.reads() @@ -523,7 +523,7 @@ def main(A_handle: T.handle, B_handle: T.handle, rms_norm_handle: T.handle): Ared_temp_rf_local[vax1_fused_1, 0, 0] = T.float32(0) for ax1_fused_0 in range(4): for u in range(1): - with T.block("Ared_temp_rf_update"): + with T.sblock("Ared_temp_rf_update"): vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) v0 = T.axis.spatial(T.int64(1), T.int64(0)) vax1_fused_0 = T.axis.reduce(4, ax1_fused_0) @@ -532,7 +532,7 @@ def main(A_handle: T.handle, B_handle: T.handle, rms_norm_handle: T.handle): Ared_temp_rf_local[vax1_fused_1, 0, 0] = Ared_temp_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", A[0, 0, vax1_fused_0 * 1024 + vax1_fused_1]) * T.Cast("float32", A[0, 0, vax1_fused_0 * 1024 + vax1_fused_1]) for ax1_fused in range(T.int64(1)): for ax0 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): vax1_fused_1 = T.axis.reduce(1024, ax0) v0 = T.axis.spatial(T.int64(1), T.int64(0)) T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0]) @@ -542,7 +542,7 @@ def main(A_handle: T.handle, B_handle: T.handle, rms_norm_handle: T.handle): Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] + Ared_temp_rf_local[vax1_fused_1, 0, 0] for ax0_fused_0 in range(4): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1) T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0]) T.writes(rms_norm[0, v0]) @@ -564,13 +564,13 @@ def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), " p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") for i, j in T.grid(11008, 4096): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) @@ -578,7 +578,7 @@ def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), " var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for ax0, ax1, ax2 in T.grid(1, 1, 4096): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) @@ -594,14 +594,14 @@ def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), " for ax0_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax1_0_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1) T.reads() T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) for ax1_0_fused_0, ax1_1 in T.grid(86, 8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1) v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1) vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1]) @@ -610,7 +610,7 @@ def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), " var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv574[0, 0, vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], T.Cast("uint32", (vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0]) for ax1_fused in T.thread_binding(16, thread="threadIdx.x"): for ax0 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_0_fused_1 = T.axis.reduce(16, ax0) v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax1_fused) T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) @@ -620,7 +620,7 @@ def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), " var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] for ax0_fused_0_1 in T.thread_binding(16, thread="threadIdx.x"): for ax0_fused_1 in range(1): - with T.block("T_add"): + with T.sblock("T_add"): v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_0_1 + ax0_fused_1) T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) @@ -643,7 +643,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")) temp_local = T.alloc_buffer((256,)) for j in T.serial(256): for k in T.serial(256): - with T.block("sum"): + with T.sblock("sum"): vj, vk = T.axis.remap("SR", [j, k]) T.reads(A[vk, vj]) T.writes(temp_local[vj]) @@ -651,7 +651,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")) temp_local[vj] = T.float32(0) temp_local[vj] = temp_local[vj] + A[vk, vj] for i, j in T.grid(256, 256): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(temp_local[vj]) T.writes(B[vi, vj]) @@ -667,14 +667,14 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")) for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("sum_rf_init"): + with T.sblock("sum_rf_init"): vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1) T.reads() T.writes(temp_local_rf_local[vax1_fused_1, v0]) temp_local_rf_local[vax1_fused_1, v0] = T.float32(0) for ax1_fused_0, u in T.grid(16, 1): - with T.block("sum_rf_update"): + with T.sblock("sum_rf_update"): vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1) vax1_fused_0 = T.axis.reduce(16, ax1_fused_0) @@ -683,7 +683,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")) temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[vax1_fused_0 * 16 + vax1_fused_1, v0] for ax1_fused in T.thread_binding(16, thread="threadIdx.x"): for ax0 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("sum"): + with T.sblock("sum"): vax1_fused_1 = T.axis.reduce(16, ax0) v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax1_fused) T.reads(temp_local_rf_local[vax1_fused_1, v0]) @@ -694,7 +694,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")) for ax0_ax1_fused_0 in range(16): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("add"): + with T.sblock("add"): v0 = T.axis.spatial(256, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 16) v1 = T.axis.spatial(256, ax0_fused_0 * 16 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 16) T.reads(temp_local_shared[v1]) @@ -718,7 +718,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): temp_local = T.alloc_buffer((256,)) for i in T.serial(256): for k in T.serial(256): - with T.block("sum"): + with T.sblock("sum"): vi, vk = T.axis.remap("SR", [i, k]) T.reads(A[vi, vk]) T.writes(temp_local[vi]) @@ -726,7 +726,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): temp_local[vi] = T.float32(0) temp_local[vi] = temp_local[vi] + A[vi, vk] for i in T.grid(256): - with T.block("add"): + with T.sblock("add"): vi = T.axis.remap("S", [i]) T.reads(temp_local[vi]) T.writes(B[vi,]) @@ -737,25 +737,25 @@ class Expected: @T.prim_func def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): temp_local_local = T.alloc_buffer((256,), scope="local") temp_local_rf_local = T.alloc_buffer((256, 256), scope="local") for ax0_fused in T.thread_binding(256, thread="blockIdx.x"): for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("sum_rf_init"): + with T.sblock("sum_rf_init"): vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused]) T.reads() T.writes(temp_local_rf_local[vax1_fused_1, v0]) temp_local_rf_local[vax1_fused_1, v0] = T.float32(0) for ax1_fused_0, u in T.grid(1, 1): - with T.block("sum_rf_update"): + with T.sblock("sum_rf_update"): vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]) T.reads(temp_local_rf_local[vax1_fused_1, v0], A[v0, vax1_fused_0 * 256 + vax1_fused_1]) T.writes(temp_local_rf_local[vax1_fused_1, v0]) temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[v0, vax1_fused_0 * 256 + vax1_fused_1] for ax1_fused in range(1): for ax0 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum"): + with T.sblock("sum"): vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) T.reads(temp_local_rf_local[vax1_fused_1, v0]) T.writes(temp_local_local[v0]) @@ -763,7 +763,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): temp_local_local[v0] = T.float32(0) temp_local_local[v0] = temp_local_local[v0] + temp_local_rf_local[vax1_fused_1, v0] for ax0 in range(1): - with T.block("add"): + with T.sblock("add"): v0 = T.axis.spatial(256, ax0_fused + ax0) T.reads(temp_local_local[v0]) T.writes(B[v0]) @@ -783,17 +783,17 @@ class Module: @T.prim_func def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float16"), lv1: T.Buffer((1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 2560), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): p_output0_intermediate_1 = T.alloc_buffer((2560, 2560), "float16") var_matmul_intermediate = T.alloc_buffer((1, 2560), "float16") for i, j in T.grid(2560, 2560): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv9[v_i, v_j // 8], lv10[v_i, v_j // 32]) T.writes(p_output0_intermediate_1[v_i, v_j]) p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv10[v_i, v_j // 32] for i0, i1, k in T.grid(1, 2560, 2560): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(lv1[v_i0, v_k], p_output0_intermediate_1[v_k, v_i1]) T.writes(var_matmul_intermediate[v_i0, v_i1]) @@ -801,7 +801,7 @@ def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float var_matmul_intermediate[v_i0, v_i1] = T.float16(0) var_matmul_intermediate[v_i0, v_i1] = var_matmul_intermediate[v_i0, v_i1] + lv1[v_i0, v_k] * p_output0_intermediate_1[v_k, v_i1] for i0, i1 in T.grid(1, 2560): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(var_matmul_intermediate[v_i0, v_i1]) T.writes(p_output0_intermediate[v_i0, v_i1]) @@ -812,21 +812,21 @@ class Expected: @T.prim_func def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float16"), lv1: T.Buffer((1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 2560), "float32")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): var_matmul_intermediate_local = T.alloc_buffer((1, 2560), "float16", scope="local") var_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 2560), "float16", scope="local") for ax0_0_fused_0 in T.thread_binding(20, thread="blockIdx.x"): for ax0_0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): for ax0_1_init in range(8): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax0_0_fused_1 * 8 + ax0_1_init) T.reads() T.writes(var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0]) var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0] = T.float16(0) for ax1_fused_0, ax0_1 in T.grid(160, 8): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax0_0_fused_1 * 8 + ax0_1) vax1_fused_0 = T.axis.reduce(160, ax1_fused_0) @@ -836,7 +836,7 @@ def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float for ax1_fused_0 in T.thread_binding(16, thread="threadIdx.x"): for ax1_fused_1 in range(8): for ax0 in T.thread_binding(16, thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_fused_1 = T.axis.reduce(16, ax0) v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax1_fused_0 * 8 + ax1_fused_1) T.reads(var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0]) @@ -846,7 +846,7 @@ def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float var_matmul_intermediate_local[0, v0] = var_matmul_intermediate_local[0, v0] + var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0] for ax0_fused_0 in T.thread_binding(16, thread="threadIdx.x"): for ax0_fused_1 in range(8): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax0_fused_0 * 8 + ax0_fused_1) T.reads(var_matmul_intermediate_local[0, v0]) T.writes(p_output0_intermediate[0, v0]) @@ -868,9 +868,9 @@ def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64 n = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16") - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(100), n): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -885,12 +885,12 @@ def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64 n = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16") - # with T.block("root"): + # with T.sblock("root"): matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16", scope="local") for ax0_ax1_fused_0 in T.thread_binding(T.int64(320), thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"): for ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) // T.int64(100)) v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) % T.int64(100)) @@ -898,7 +898,7 @@ def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64 T.writes(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) for ax2_fused_0, u in T.grid((n + T.int64(15)) // T.int64(16), 1): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) // T.int64(100)) v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) % T.int64(100)) @@ -909,7 +909,7 @@ def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64 matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] * B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1] for ax1_ax2_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"): for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax2_fused_1 = T.axis.reduce(T.int64(16), ax0) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused_0 // T.int64(10)) v1 = T.axis.spatial(T.int64(100), ax0_ax1_fused_0 % T.int64(10) * T.int64(10) + ax1_ax2_fused) @@ -936,23 +936,23 @@ def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_ast kv_seq_len = T.int64() lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16") astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16") - # with T.block("root"): + # with T.sblock("root"): var_T_repeat_intermediate = T.alloc_buffer((T.int64(1), kv_seq_len, T.int64(32), T.int64(128)), "float16") var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), kv_seq_len, T.int64(128)), "float16") for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), kv_seq_len, T.int64(32), T.int64(128)): - with T.block("T_repeat"): + with T.sblock("T_repeat"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(lv716[v_ax0, v_ax1, v_ax2 // T.int64(4), v_ax3]) T.writes(var_T_repeat_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_repeat_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv716[v_ax0, v_ax1, v_ax2 // T.int64(4), v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), kv_seq_len, T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(var_T_repeat_intermediate[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_repeat_intermediate[v_ax0, v_ax2, v_ax1, v_ax3] for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(astype66[v_i0, v_i1, v_i2, v_k], var_T_transpose_intermediate[v_i0, v_i1, v_k, v_i3]) T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -967,13 +967,13 @@ def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_ast kv_seq_len = T.int64() lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16") astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16") - # with T.block("root"): + # with T.sblock("root"): var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16", scope="local") for ax0_0_ax1_fused_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"): for ax0_0_ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax0_1_init in range(T.int64(4)): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) // T.int64(128) * T.int64(4) + ax0_1_init) v1 = T.axis.spatial(T.int64(128), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) % T.int64(128)) @@ -981,7 +981,7 @@ def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_ast T.writes(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) for ax2_fused_0, ax0_1 in T.grid((kv_seq_len + T.int64(15)) // T.int64(16), T.int64(4)): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) // T.int64(128) * T.int64(4) + ax0_1) v1 = T.axis.spatial(T.int64(128), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) % T.int64(128)) @@ -993,7 +993,7 @@ def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_ast for ax1_0_ax2_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax1_1 in range(T.int64(4)): for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax2_fused_1 = T.axis.reduce(T.int64(16), ax0) v0 = T.axis.spatial(T.int64(32), ax0_0_ax1_fused_0 // T.int64(8) * T.int64(4) + ax1_1) v1 = T.axis.spatial(T.int64(128), ax0_0_ax1_fused_0 % T.int64(8) * T.int64(16) + ax1_0_ax2_fused) @@ -1024,7 +1024,7 @@ def main( C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size)) C_temp = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), vocab_size, T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(B[v_i0, v_i1, v_k], A[v_k, v_i2]) T.writes(C_temp[v_i0, v_i1, v_i2]) @@ -1034,7 +1034,7 @@ def main( C_temp[v_i0, v_i1, v_i2] + B[v_i0, v_i1, v_k] * A[v_k, v_i2] ) for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), vocab_size): - with T.block("epilogue"): + with T.sblock("epilogue"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(C_temp[v_i0, v_i1, v_i2]) T.writes(C[v_i0, v_i1, v_i2]) @@ -1049,20 +1049,20 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), " vocab_size = T.int64() A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16") C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size)) - # with T.block("root"): + # with T.sblock("root"): C_temp_local = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16", scope="local") C_temp_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(1), vocab_size), "float16", scope="local") for ax0_fused_0 in T.thread_binding(vocab_size, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1), thread="threadIdx.x"): for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): - with T.block("matmul_rf_init"): + with T.sblock("matmul_rf_init"): vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1) v0 = T.axis.spatial(vocab_size, ax0_fused_0 + ax0_fused_1) T.reads() T.writes(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0]) C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] = T.float16(0) for ax1_fused_0, u in T.grid(T.int64(256), 1): - with T.block("matmul_rf_update"): + with T.sblock("matmul_rf_update"): vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1) v0 = T.axis.spatial(vocab_size, ax0_fused_0 + ax0_fused_1) vax1_fused_0 = T.axis.reduce(T.int64(256), ax1_fused_0) @@ -1071,7 +1071,7 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), " C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] = C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] + B[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * A[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0] for ax1_fused in T.thread_binding(T.int64(1), thread="threadIdx.x"): for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): - with T.block("matmul"): + with T.sblock("matmul"): vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused_0]) T.reads(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0]) T.writes(C_temp_local[T.int64(0), T.int64(0), v0]) @@ -1080,7 +1080,7 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), " C_temp_local[T.int64(0), T.int64(0), v0] = C_temp_local[T.int64(0), T.int64(0), v0] + C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] for ax0_fused_0_1 in T.thread_binding(T.int64(1), thread="threadIdx.x"): for ax0_fused_1 in range(T.int64(1)): - with T.block("epilogue"): + with T.sblock("epilogue"): v0 = T.axis.spatial(vocab_size, ax0_fused_0) T.reads(C_temp_local[T.int64(0), T.int64(0), v0]) T.writes(C[T.int64(0), T.int64(0), v0]) @@ -1101,13 +1101,13 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer(( T.func_attr({"tir.noalias": True}) NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1)), "float16") for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) with T.init(): NT_matmul_intermediate[v_i0, v_i1] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + A[v_i0, v_k] * weight[v_i1, v_k] for i0, i1 in T.grid(T.int64(1), T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) out[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0, v_i1]) @@ -1121,19 +1121,19 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer(( NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(1024), T.int64(1), T.int64(1)), "float16", scope="local") for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax1_fused_1 = T.axis.spatial(T.int64(1024), ax1_fused_1) v0 = T.axis.spatial(T.int64(1), T.int64(0)) NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = T.float16(0) for ax1_fused_0, u in T.grid(T.int64(2), 1): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax1_fused_1 = T.axis.spatial(T.int64(1024), ax1_fused_1) v0 = T.axis.spatial(T.int64(1), T.int64(0)) vax1_fused_0 = T.axis.reduce(T.int64(2), ax1_fused_0) NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] + A[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1] * weight[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1] for ax1_fused in range(T.int64(1)): for ax0 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax1_fused_1 = T.axis.reduce(T.int64(1024), ax0) v0 = T.axis.spatial(T.int64(1), T.int64(0)) with T.init(): @@ -1141,7 +1141,7 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer(( NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] = NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] + NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] for ax0_fused_0 in range(T.int64(1)): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) out[T.int64(0), T.int64(0)] = T.sigmoid(NT_matmul_intermediate_shared[T.int64(0), T.int64(0)]) @@ -1161,9 +1161,9 @@ class Before: @T.prim_func(private=True) def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2]) T.writes(matmul[v_i0, v_i1, v_i2]) diff --git a/tests/python/dlight/test_gpu_rmsnorm.py b/tests/python/dlight/test_gpu_rmsnorm.py index a186c2d19ae9..6003eed3e92c 100644 --- a/tests/python/dlight/test_gpu_rmsnorm.py +++ b/tests/python/dlight/test_gpu_rmsnorm.py @@ -43,7 +43,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T n = T.int32() data = T.match_buffer(var_data, (1, n, 4096), "float16") T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") - # with T.block("root"): + # with T.sblock("root"): T_cast_1 = T.alloc_buffer((1, n, 4096)) T_multiply = T.alloc_buffer((1, n, 4096)) T_multiply_red = T.alloc_buffer((1, n)) @@ -51,19 +51,19 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T T_cast_2 = T.alloc_buffer((4096,)) T_rms_norm = T.alloc_buffer((1, n, 4096)) for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_cast"): + with T.sblock("T_cast"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(data[v_ax0, v_ax1, v_ax2]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2]) T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data[v_ax0, v_ax1, v_ax2]) for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_cast_1[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2] for ax0, ax1, k2 in T.grid(1, n, 4096): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(T_multiply[v_ax0, v_ax1, v_k2]) T.writes(T_multiply_red[v_ax0, v_ax1]) @@ -71,25 +71,25 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2] for ax0, ax1 in T.grid(1, n): - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0 in range(4096): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): v_ax0 = T.axis.spatial(4096, ax0) T.reads(weight[v_ax0]) T.writes(T_cast_2[v_ax0]) T_cast_2[v_ax0] = T.Cast("float32", weight[v_ax0]) for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2] for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) T.writes(T_cast[v_ax0, v_ax1, v_ax2]) @@ -103,7 +103,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T n = T.int32() data = T.match_buffer(var_data, (1, n, 4096), "float16") T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") - # with T.block("root"): + # with T.sblock("root"): T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local") T_multiply_red_local = T.alloc_buffer((1, n), scope="local") rsqrt_shared = T.alloc_buffer((1, n), scope="shared") @@ -113,7 +113,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): for ax2_1 in range(1): for ax2_2 in T.vectorized(8): - with T.block("data_local"): + with T.sblock("data_local"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(n, ax0_ax1_fused) v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) @@ -121,7 +121,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T T.writes(data_local[v0, v1, v2]) data_local[v0, v1, v2] = data[v0, v1, v2] for ax0 in range(8): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) @@ -129,7 +129,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) T_multiply_local[v_ax0, v_ax1, v_ax2] = T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) for ax0 in range(8): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) @@ -138,7 +138,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T with T.init(): T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) T.reads(T_multiply_red_local[v_ax0, v_ax1]) @@ -146,7 +146,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): for ax0_1, ax0_2 in T.grid(1, 8): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) @@ -154,7 +154,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32", weight[v_ax2]) for ax0 in T.vectorized(8): - with T.block("T_cast_local"): + with T.sblock("T_cast_local"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(n, ax0_ax1_fused) v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) @@ -175,19 +175,19 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T n = T.int32() data = T.match_buffer(var_data, (1, n, 4096)) T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) - # with T.block("root"): + # with T.sblock("root"): T_multiply = T.alloc_buffer((1, n, 4096)) T_multiply_red = T.alloc_buffer((1, n)) rsqrt = T.alloc_buffer((1, n)) T_rms_norm = T.alloc_buffer((1, n, 4096)) for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(data[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = data[v_ax0, v_ax1, v_ax2] * data[v_ax0, v_ax1, v_ax2] for ax0, ax1, k2 in T.grid(1, n, 4096): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(T_multiply[v_ax0, v_ax1, v_k2]) T.writes(T_multiply_red[v_ax0, v_ax1]) @@ -195,19 +195,19 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2] for ax0, ax1 in T.grid(1, n): - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rsqrt[v_ax0, v_ax1], data[v_ax0, v_ax1, v_ax2], weight[v_ax2]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * data[v_ax0, v_ax1, v_ax2] * weight[v_ax2] for ax0, ax1, ax2 in T.grid(1, n, 4096): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) T.writes(T_cast[v_ax0, v_ax1, v_ax2]) @@ -221,7 +221,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T n = T.int32() data = T.match_buffer(var_data, (1, n, 4096)) T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) - # with T.block("root"): + # with T.sblock("root"): T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local") T_multiply_red_local = T.alloc_buffer((1, n), scope="local") rsqrt_shared = T.alloc_buffer((1, n), scope="shared") @@ -231,7 +231,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T for ax2_0 in T.thread_binding(512, thread="threadIdx.x"): for ax2_1 in range(1): for ax2_2 in T.vectorized(8): - with T.block("data_local"): + with T.sblock("data_local"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(n, ax0_ax1_fused) v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 * 8 + ax2_2) @@ -239,7 +239,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T T.writes(data_local[v0, v1, v2]) data_local[v0, v1, v2] = data[v0, v1, v2] for ax0 in range(8): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0) @@ -247,7 +247,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2]) T_multiply_local[v_ax0, v_ax1, v_ax2] = data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2] for ax0 in range(8): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0) @@ -256,7 +256,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T with T.init(): T_multiply_red_local[v_ax0, v_ax1] = T.float32(0) T_multiply_red_local[v_ax0, v_ax1] = T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2] - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) T.reads(T_multiply_red_local[v_ax0, v_ax1]) @@ -264,7 +264,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T rsqrt_shared[v_ax0, v_ax1] = T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)) for ax0_0 in T.thread_binding(512, thread="threadIdx.x"): for ax0_1, ax0_2 in T.grid(1, 8): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8 + ax0_2) @@ -272,7 +272,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2]) T_rms_norm_local[v_ax0, v_ax1, v_ax2] = rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2] for ax0 in T.vectorized(8): - with T.block("T_cast_local"): + with T.sblock("T_cast_local"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(n, ax0_ax1_fused) v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0) diff --git a/tests/python/dlight/test_gpu_transpose.py b/tests/python/dlight/test_gpu_transpose.py index 6aea731d5c02..bdd79e8f5efe 100644 --- a/tests/python/dlight/test_gpu_transpose.py +++ b/tests/python/dlight/test_gpu_transpose.py @@ -40,7 +40,7 @@ class Before: def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] @@ -49,7 +49,7 @@ class After: @T.prim_func def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_shared = T.alloc_buffer((T.int64(512), T.int64(4096)), scope="shared") for ax0_0_0 in T.thread_binding(T.int64(512), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): @@ -57,18 +57,18 @@ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_tr for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax0_ax1_fused_3 in T.unroll(T.int64(1)): - with T.block("rxplaceholder_shared"): + with T.sblock("rxplaceholder_shared"): v0 = T.axis.spatial(T.int64(512), ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) // T.int64(8)) v1 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(8) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) % T.int64(8)) T.reads(rxplaceholder[v0, v1]) T.writes(rxplaceholder_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 1]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 1]]}) rxplaceholder_shared[v0, v1] = rxplaceholder[v0, v1] for ax0_0_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for ax1_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax0_1_0 in range(T.int64(1)): for ax0_1_1 in range(T.int64(1)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v0 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(8) + ax0_0_1 + ax0_1_0 + ax0_1_1) v1 = T.axis.spatial(T.int64(512), ax1_0 * T.int64(16) + ax1_1) T.reads(rxplaceholder_shared[v1, v0]) @@ -87,13 +87,13 @@ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxpla T.func_attr({"tir.noalias": True}) decode = T.alloc_buffer((T.int64(4096), T.int64(4096))) for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j]) T.writes(decode[v_i, v_j]) decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(decode[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -111,18 +111,18 @@ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxpla for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax0_ax1_fused_3 in T.unroll(T.int64(8)): - with T.block("decode_shared"): + with T.sblock("decode_shared"): v0 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) // T.int64(64)) v1 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(64) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) % T.int64(64)) T.reads(rxplaceholder[v0 // T.int64(8), v1], rxplaceholder_1[v0 // T.int64(32), v1]) T.writes(decode_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 1]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 1]]}) decode_shared[v0, v1] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v0 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v0 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16))) for ax0_0_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for ax1_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax0_1_0 in range(T.int64(2)): for ax0_1_1 in T.vectorized(T.int64(4)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v0 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(64) + ax0_0_1 * T.int64(8) + ax0_1_0 * T.int64(4) + ax0_1_1) v1 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + ax1_1) T.reads(decode_shared[v1, v0]) @@ -141,13 +141,13 @@ def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.in T.func_attr({"tir.noalias": True}) decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j]) T.writes(decode_1[v_i, v_j]) decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(decode_1[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -158,7 +158,7 @@ class After: @T.prim_func def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): decode_1_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16", scope="shared") for ax0_0_0 in T.thread_binding(T.int64(52), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"): @@ -166,19 +166,19 @@ def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.in for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax0_ax1_fused_3 in T.unroll(T.int64(10)): - with T.block("decode_1_shared"): + with T.sblock("decode_1_shared"): v0 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 * T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) // T.int64(82)) v1 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(80) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 * T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) % T.int64(82)) T.where(ax0_0_0 * T.int64(80) + (((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) + ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3) % T.int64(82) < T.int64(4096) and ((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) + ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3 < T.int64(1312)) T.reads(A[v0 // T.int64(10), v1], B[v0 // T.int64(40), v1]) T.writes(decode_1_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 1]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 1]]}) decode_1_shared[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v0 // T.int64(10), v1], T.Cast("uint32", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v0 // T.int64(40), v1] for ax0_0_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for ax1_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): for ax0_1_0 in range(T.int64(3)): for ax0_1_1 in T.vectorized(T.int64(4)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v0 = T.axis.spatial(T.int64(4096), (ax0_0_0 * T.int64(8) + ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + ax0_1_1)) v1 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + ax1_1) T.where((ax0_0_0 * T.int64(8) + ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + ax0_1_1) < T.int64(4096) and ax0_0_0 * T.int64(8) + ax0_0_1 < T.int64(410) and ax0_1_0 * T.int64(4) + ax0_1_1 < T.int64(10)) diff --git a/tests/python/dlight/test_primitives.py b/tests/python/dlight/test_primitives.py index 08505055a97b..c0aab0bf0125 100644 --- a/tests/python/dlight/test_primitives.py +++ b/tests/python/dlight/test_primitives.py @@ -24,21 +24,21 @@ @T.prim_func def main(p0: T.Buffer((), "int32"), T_stack: T.Buffer((T.int64(3),), "int32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): compile_engine_const = T.alloc_buffer((), "int32") compile_engine_const_1 = T.alloc_buffer((), "int32") - with T.block("compile_engine_const"): + with T.sblock("compile_engine_const"): vi = T.axis.spatial(1, T.int64(0)) T.reads() T.writes(compile_engine_const[()]) compile_engine_const[()] = 16 - with T.block("compile_engine_const_1"): + with T.sblock("compile_engine_const_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads() T.writes(compile_engine_const_1[()]) compile_engine_const_1[()] = 20 for ax0 in range(T.int64(3)): - with T.block("T_stack"): + with T.sblock("T_stack"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(compile_engine_const[()], p0[()], compile_engine_const_1[()]) T.writes(T_stack[v_ax0]) diff --git a/tests/python/ir/test_datatype_nv_fp8.py b/tests/python/ir/test_datatype_nv_fp8.py index 0c17e844757f..9cc5ec6e6ceb 100644 --- a/tests/python/ir/test_datatype_nv_fp8.py +++ b/tests/python/ir/test_datatype_nv_fp8.py @@ -57,7 +57,7 @@ def func( A_fp32 = T.match_buffer(a_fp32, [128], dtype="float32") A_roundtrip = T.match_buffer(a_roundtrip, [128], dtype=dtype) for i in range(128): - with T.block("fp8_unary"): + with T.sblock("fp8_unary"): vi = T.axis.spatial(128, i) A_add_B[vi] = A[vi] + B[vi] A_sub_B[vi] = A[vi] - B[vi] diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py index aea8492485ed..fe6d751e5e9b 100644 --- a/tests/python/ir/test_pass_instrument.py +++ b/tests/python/ir/test_pass_instrument.py @@ -33,7 +33,7 @@ def func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_arg_info.py b/tests/python/meta_schedule/test_meta_schedule_arg_info.py index 62dcb52f7415..a8cf59318820 100644 --- a/tests/python/meta_schedule/test_meta_schedule_arg_info.py +++ b/tests/python/meta_schedule/test_meta_schedule_arg_info.py @@ -29,7 +29,7 @@ def Matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (256, 512), "float32") C = T.match_buffer(c, (128, 512), "float32") for i, j, k in T.grid(128, 256, 512): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_builder.py b/tests/python/meta_schedule/test_meta_schedule_builder.py index 6da0a089180c..2b98d6cd2255 100644 --- a/tests/python/meta_schedule/test_meta_schedule_builder.py +++ b/tests/python/meta_schedule/test_meta_schedule_builder.py @@ -49,7 +49,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -68,13 +68,13 @@ def matmul_relu( # pylint: disable=no-self-argument D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(1024, 1024): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -90,7 +90,7 @@ def batch_matmul( # pylint: disable=no-self-argument B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) for n, i, j, k in T.grid(16, 128, 128, 128): - with T.block("update"): + with T.sblock("update"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) with T.init(): C[vn, vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_cost_model.py b/tests/python/meta_schedule/test_meta_schedule_cost_model.py index dadedcf601aa..914825b64203 100644 --- a/tests/python/meta_schedule/test_meta_schedule_cost_model.py +++ b/tests/python/meta_schedule/test_meta_schedule_cost_model.py @@ -47,7 +47,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -60,7 +60,7 @@ class FullModule: def main(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads() T.writes(T_full[v_ax0, v_ax1]) diff --git a/tests/python/meta_schedule/test_meta_schedule_database.py b/tests/python/meta_schedule/test_meta_schedule_database.py index f8b2354c33bf..92f494997992 100644 --- a/tests/python/meta_schedule/test_meta_schedule_database.py +++ b/tests/python/meta_schedule/test_meta_schedule_database.py @@ -43,7 +43,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -60,13 +60,13 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(16, 16): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -76,7 +76,7 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s def _schedule_matmul(sch: Schedule): - block = sch.get_block("matmul") + block = sch.get_sblock("matmul") i, j, k = sch.get_loops(block=block) i_tiles = [1, 1, 2, 512] j_tiles = [1, 512, 1, 2] diff --git a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py index b901c3ce1372..50e278f1beae 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py @@ -38,9 +38,9 @@ def matmul( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1, i2 in T.grid(512, 512, 512): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(C[i, j], A[i, k], B[k, j]) T.writes(C[i, j]) @@ -60,9 +60,9 @@ def main(placeholder: T.Buffer((1, 16, 7, 7, 32), "float32"), placeholder_1: T.B # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): - with T.block("T_layout_trans_1"): + with T.sblock("T_layout_trans_1"): ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial(1, 0) ax2 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused // 3584) @@ -209,7 +209,7 @@ def test_cpu_matmul(): def _create_schedule(): func = matmul sch = tir.Schedule(func, debug_mask="all") - block = sch.get_block("C") + block = sch.get_sblock("C") i, j, k = sch.get_loops(block) i_o, i_i = sch.split(i, factors=[None, 16]) # outer: 32 j_o, j_i = sch.split(j, factors=[None, 8]) # outer: 64 @@ -410,14 +410,14 @@ def func(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [64, 32], dtype="float32") C = T.match_buffer(c, [64, 32], dtype="float32") for i, j in T.grid(64, 32): # type: ignore - with T.block(): + with T.sblock(): T.reads([A[i, j], B[i, j]]) # type: ignore T.writes([B[i, j], C[i, j]]) # type: ignore - with T.block("B"): + with T.sblock("B"): T.reads([A[i, j]]) # type: ignore T.writes([B[i, j]]) # type: ignore B[i, j] = A[i, j] # type: ignore - with T.block("C"): + with T.sblock("C"): T.reads([B[i, j]]) # type: ignore T.writes([C[i, j]]) # type: ignore C[i, j] = B[i, j] # type: ignore @@ -708,7 +708,7 @@ def test_empty_feature(): @T.prim_func def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads() T.writes(T_full[v_ax0, v_ax1]) @@ -730,7 +730,7 @@ def test_gpu(): def _create_schedule(): func = matmul sch = tir.Schedule(func, debug_mask="all") - c = sch.get_block("C") + c = sch.get_sblock("C") c_local = sch.cache_write(c, 0, "local") i, j, k = sch.get_loops(c) # pylint: disable=invalid-name diff --git a/tests/python/meta_schedule/test_meta_schedule_measure_callback.py b/tests/python/meta_schedule/test_meta_schedule_measure_callback.py index 0b7b22d92bb7..030e9350725a 100644 --- a/tests/python/meta_schedule/test_meta_schedule_measure_callback.py +++ b/tests/python/meta_schedule/test_meta_schedule_measure_callback.py @@ -37,7 +37,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py index a318ea35158f..89c57c0fc877 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py +++ b/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py @@ -40,7 +40,7 @@ def main( C: T.Buffer((M, N), "float16"), # type: ignore ): for i, j, k in T.grid(M, N, K): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -57,7 +57,7 @@ def main( C: T.Buffer((M, N), "float32"), # type: ignore ): for i, j, k in T.grid(M, N, K): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -98,8 +98,8 @@ def test_f16f16f16_mma_gemm(): # fmt: off mod = Gemm_F16F16F16 sch = Schedule(mod) - b0 = sch.get_block(name="C", func_name="main") - b1 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") + b1 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") b2 = sch.reindex(block=b0, buffer=("write", 0)) b3 = sch.reindex(block=b0, buffer=("read", 0)) @@ -179,7 +179,7 @@ def test_f16f16f16_mma_gemm(): v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v102) sch.enter_postproc() - b103 = sch.get_block(name="root", func_name="main") + b103 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b103, ann_key="meta_schedule.unroll_explicit") b104, b105, b106, b107, b108, b109 = sch.get_child_blocks(b103) l110, l111, l112, l113 = sch.get_loops(block=b104) @@ -188,23 +188,23 @@ def test_f16f16f16_mma_gemm(): l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b107) l132, l133, l134, l135, l136, l137, l138, l139, l140, l141 = sch.get_loops(block=b108) l142, l143, l144 = sch.get_loops(block=b109) - b145 = sch.get_block(name="C_o", func_name="main") + b145 = sch.get_sblock(name="C_o", func_name="main") l146, l147, l148, l149, l150, l151, l152, l153, l154, l155 = sch.get_loops(block=b145) b156 = sch.decompose_reduction(block=b145, loop=l149) sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize") sch.annotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize", ann_val="mma_init_m16n8k8_f16") sch.unannotate(block_or_loop=b145, ann_key="meta_schedule.auto_tensorize_init") sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize_init") - b157 = sch.get_block(name="C_o_init", func_name="main") + b157 = sch.get_sblock(name="C_o_init", func_name="main") sch.unannotate(block_or_loop=b157, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b157, tensor_intrin="mma_init_m16n8k8_f16", preserve_unit_iters=True) - b158 = sch.get_block(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", func_name="main") + b158 = sch.get_sblock(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", func_name="main") sch.unannotate(block_or_loop=b158, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b158, tensor_intrin="mma_load_m16n8k8_f16_A_shared_dyn", preserve_unit_iters=True) - b159 = sch.get_block(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", func_name="main") + b159 = sch.get_sblock(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", func_name="main") sch.unannotate(block_or_loop=b159, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b159, tensor_intrin="mma_load_m16n8k8_f16_B_shared_dyn", preserve_unit_iters=True) - b160 = sch.get_block(name="C_o_update", func_name="main") + b160 = sch.get_sblock(name="C_o_update", func_name="main") sch.unannotate(block_or_loop=b160, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b160, tensor_intrin="mma_sync_m16n8k8_f16f16f16", preserve_unit_iters=True) mod = sch.mod @@ -218,8 +218,8 @@ def test_f16f16f32_mma_gemm(): sch = Schedule(mod) # fmt: off sch = Schedule(mod) - b0 = sch.get_block(name="C", func_name="main") - b1 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") + b1 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") b2 = sch.reindex(block=b0, buffer=("write", 0)) b3 = sch.reindex(block=b0, buffer=("read", 0)) @@ -299,7 +299,7 @@ def test_f16f16f32_mma_gemm(): v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v102) sch.enter_postproc() - b103 = sch.get_block(name="root", func_name="main") + b103 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b103, ann_key="meta_schedule.unroll_explicit") b104, b105, b106, b107, b108, b109 = sch.get_child_blocks(b103) l110, l111, l112, l113 = sch.get_loops(block=b104) @@ -310,23 +310,23 @@ def test_f16f16f32_mma_gemm(): sch.annotate(block_or_loop=l132, ann_key="pragma_auto_unroll_max_step", ann_val=0) sch.annotate(block_or_loop=l132, ann_key="pragma_unroll_explicit", ann_val=1) l142, l143, l144 = sch.get_loops(block=b109) - b145 = sch.get_block(name="C_o", func_name="main") + b145 = sch.get_sblock(name="C_o", func_name="main") l146, l147, l148, l149, l150, l151, l152, l153, l154, l155 = sch.get_loops(block=b145) b156 = sch.decompose_reduction(block=b145, loop=l149) sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize") sch.annotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize", ann_val="mma_init_m16n8k8_f32") sch.unannotate(block_or_loop=b145, ann_key="meta_schedule.auto_tensorize_init") sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize_init") - b157 = sch.get_block(name="C_o_init", func_name="main") + b157 = sch.get_sblock(name="C_o_init", func_name="main") sch.unannotate(block_or_loop=b157, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b157, tensor_intrin="mma_init_m16n8k8_f32", preserve_unit_iters=True) - b158 = sch.get_block(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", func_name="main") + b158 = sch.get_sblock(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", func_name="main") sch.unannotate(block_or_loop=b158, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b158, tensor_intrin="mma_load_m16n8k8_f16_A_shared_dyn", preserve_unit_iters=True) - b159 = sch.get_block(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", func_name="main") + b159 = sch.get_sblock(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", func_name="main") sch.unannotate(block_or_loop=b159, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b159, tensor_intrin="mma_load_m16n8k8_f16_B_shared_dyn", preserve_unit_iters=True) - b160 = sch.get_block(name="C_o_update", func_name="main") + b160 = sch.get_sblock(name="C_o_update", func_name="main") sch.unannotate(block_or_loop=b160, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b160, tensor_intrin="mma_sync_m16n8k8_f16f16f32", preserve_unit_iters=True) mod = sch.mod diff --git a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py index 4147a9fbab86..a26a7381062a 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py +++ b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py @@ -32,13 +32,13 @@ def add(a: T.handle, b: T.handle) -> None: A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") # body for i, j, k in T.grid(2048, 2048, 2048): - with T.block("move"): + with T.sblock("move"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([A_cached[vi, vj, vk]]) A_cached[vi, vj, vk] = A[vi, vj, vk] for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) vk = T.axis.spatial(2048, k0 * 32 + k1) @@ -53,7 +53,7 @@ def add(a: T.handle, b: T.handle) -> None: def _sch(decision: int) -> Schedule: sch = Schedule(add, debug_mask="all") # pylint: disable=invalid-name - b0 = sch.get_block(name="move", func_name="main") + b0 = sch.get_sblock(name="move", func_name="main") l1 = sch.sample_compute_location(block=b0, decision=decision) sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True) # pylint: enable=invalid-name diff --git a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py index 728f522335bf..d7a421288c72 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py +++ b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py @@ -31,7 +31,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [512, 512]) C = T.match_buffer(c, [512, 512]) for i, j, k in T.grid(512, 512, 512): # type: ignore - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore with T.init(): C[vi, vj] = 0.0 # type: ignore @@ -45,8 +45,8 @@ def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: sch = Schedule(matmul, debug_mask="all") # pylint: disable=invalid-name d0, d1, d2 = decisions - b0 = sch.get_block(name="C", func_name="main") - root = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") + root = sch.get_sblock(name="root", func_name="main") sch.get_consumers(block=b0) b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") l2, l3, l4 = sch.get_loops(block=b0) diff --git a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py index d3a431af0687..05d428454d3e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py +++ b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py @@ -28,7 +28,7 @@ def element_wise(var_A: T.handle, var_B: T.handle) -> None: A = T.match_buffer(var_A, [512, 512], dtype="float32") B = T.match_buffer(var_B, [512, 512], dtype="float32") for i, j in T.grid(512, 512): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + 1.0 @@ -39,7 +39,7 @@ def element_wise(var_A: T.handle, var_B: T.handle) -> None: def _sch() -> Schedule: sch = Schedule(element_wise, debug_mask="all") # pylint: disable=invalid-name - b0 = sch.get_block(name="C", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") l1, l2 = sch.get_loops(block=b0) l3 = sch.fuse(l1, l2) v4 = sch.sample_categorical( diff --git a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py index c09ef3e87066..371ba076fe00 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py +++ b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py @@ -33,7 +33,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [512, 512]) C = T.match_buffer(c, [512, 512]) for i, j, k in T.grid(512, 512, 512): # type: ignore - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore with T.init(): C[vi, vj] = 0.0 # type: ignore @@ -47,7 +47,7 @@ def _sch(decisions: List[List[int]]) -> Schedule: sch = Schedule(matmul, debug_mask="all") # pylint: disable=invalid-name (d0,) = decisions - b0 = sch.get_block(name="C", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") sch.get_consumers(block=b0) b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") l2, l3, l4 = sch.get_loops(block=b0) diff --git a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py index a59a7e655b09..ff3c01d6266c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py +++ b/tests/python/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py @@ -31,7 +31,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [512, 512]) C = T.match_buffer(c, [512, 512]) for i, j, k in T.grid(512, 512, 512): # type: ignore - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore with T.init(): C[vi, vj] = 0.0 # type: ignore @@ -45,8 +45,8 @@ def _sch(decisions: List[List[int]]) -> Schedule: sch = Schedule(matmul, debug_mask="all") # pylint: disable=invalid-name d0, d1, d2 = decisions - b0 = sch.get_block(name="C", func_name="main") - root = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") + root = sch.get_sblock(name="root", func_name="main") sch.get_consumers(block=b0) b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") l2, l3, l4 = sch.get_loops(block=b0) diff --git a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py index 61888ed1a70e..b5066a383dc3 100644 --- a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py @@ -33,7 +33,7 @@ from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.target import Target -from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule import SBlockRV, Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, # fmt: off @@ -64,7 +64,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -80,13 +80,13 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @@ -101,15 +101,15 @@ def main(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer((1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") for i, j in T.grid(1024, 1024): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(1024, 1024): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 3.0 for i, j in T.grid(1024, 1024): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = C[vi, vj] * 5.0 @@ -126,14 +126,14 @@ def main(a: T.handle, d: T.handle) -> None: # with tir.block("root") B = T.alloc_buffer([1024, 1024], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(1024, i0_0 * 64 + i0_1) vj = T.axis.S(1024, i1_0 * 16 + i1_1) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = A[vi, vj] * T.float32(2) for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(1024, i0_0 * 64 + i0_1) vj = T.axis.S(1024, i1_0 * 16 + i1_1) T.reads([B[vi, vj]]) @@ -145,7 +145,7 @@ def main(a: T.handle, d: T.handle) -> None: # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -def _is_root(sch: Schedule, block: BlockRV) -> bool: +def _is_root(sch: Schedule, block: SBlockRV) -> bool: return sch.get_sref(block).parent is None @@ -160,7 +160,7 @@ class WowSoFancyScheduleRule(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: if _is_root(sch, block): return [sch] new_sch = sch.copy() @@ -177,7 +177,7 @@ class DoubleScheduleRule(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: if _is_root(sch, block): return [sch] new_sch = sch.copy() @@ -202,7 +202,7 @@ class TrinityDoubleRule(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: if _is_root(sch, block): return [sch] new_sch = sch.copy() @@ -225,7 +225,7 @@ class ReorderScheduleRule(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: if _is_root(sch, block): return [sch] new_sch = sch.copy() @@ -324,7 +324,7 @@ class RemoveBlock(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: SBlockRV) -> List[Schedule]: if _is_root(sch, block): return [sch] sch = sch.copy() @@ -337,9 +337,9 @@ def correct_trace(a, b, c, d): [ "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="A", func_name="main")', - ' b1 = sch.get_block(name="B", func_name="main")', - ' b2 = sch.get_block(name="C", func_name="main")', + ' b0 = sch.get_sblock(name="A", func_name="main")', + ' b1 = sch.get_sblock(name="B", func_name="main")', + ' b2 = sch.get_sblock(name="C", func_name="main")', " sch.compute_inline(block=b1)", " l3, l4 = sch.get_loops(block=b2)", " l5, l6 = sch.split(loop=l3, factors=" @@ -379,7 +379,7 @@ def correct_trace(a, b, c, d): tvm.tir.schedule.schedule.ScheduleError, match="ScheduleError: An error occurred in the schedule primitive 'get-block'.", ): - sch.get_block("B", "main") + sch.get_sblock("B", "main") sch_trace = sch.trace.simplified(True) assert ( str(sch_trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512]) @@ -389,7 +389,7 @@ def correct_trace(a, b, c, d): ) -def test_target_blocks_search_space(): +def test_target_sblocks_search_space(): # Test that specific blocks of trinity matmul can be targeted. def filter_fn(block, target_names) -> bool: return block.name_hint in target_names diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py b/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py index 046bd7220f01..35341bc01083 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py @@ -55,7 +55,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -69,7 +69,7 @@ def test_postproc_disallow_async_strided_mem_copy_allows(): mod = Matmul sch = tir.Schedule(mod, debug_mask="all") - matmul_block = sch.get_block("matmul") + matmul_block = sch.get_sblock("matmul") loops = sch.get_loops(matmul_block) cache_read = sch.cache_read(matmul_block, 0, "global.vtcm") @@ -89,7 +89,7 @@ def test_postproc_disallow_async_strided_mem_copy_disallows(): mod = Matmul sch = tir.Schedule(mod, debug_mask="all") - matmul_block = sch.get_block("matmul") + matmul_block = sch.get_sblock("matmul") loops = sch.get_loops(matmul_block) # Make it a strided mem copy. diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py index 5dc2500d1b2d..7f818412bcfc 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -55,7 +55,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -72,7 +72,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (1024, 1024), "float32") for i, j in T.grid(1024, 1024): for k in T.serial(0, i): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index 9bb550e79e4a..7a3cd6178744 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -59,7 +59,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: B = T.match_buffer(var_B, [512, 512], dtype="float32") C = T.match_buffer(var_C, [512, 512], dtype="float32") # body - # with T.block("root") + # with T.sblock("root") C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") @@ -69,7 +69,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: for i2_0 in T.serial(0, 1): for ax0_ax1_fused_0 in T.serial(0, 32768): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) T.reads([A[v0, v1]]) @@ -78,14 +78,14 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: for ax0_ax1_fused_0 in T.serial(0, 1024): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(0, 2): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) T.reads([B[v0, v1]]) T.writes([B_shared[v0, v1]]) B_shared[v0, v1] = B[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2) @@ -95,7 +95,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: C_local[i, j] = T.float32(0) C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] for ax0, ax1 in T.grid(32, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) T.reads([C_local[v0, v1]]) @@ -114,7 +114,7 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") @@ -127,7 +127,7 @@ def main( for ax0_ax1_fused_2 in T.thread_binding( 0, 32, thread="threadIdx.x" ): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial( 512, ( @@ -155,7 +155,7 @@ def main( 0, 32, thread="threadIdx.x" ): for ax0_ax1_fused_3 in T.vectorized(0, 2): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial( 512, ( @@ -181,7 +181,7 @@ def main( T.writes([B_shared[v0, v1]]) B_shared[v0, v1] = B[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) j = T.axis.spatial( 512, @@ -190,12 +190,12 @@ def main( k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2) T.reads([A_shared[i, k], B_shared[k, j]]) T.writes([C_local[i, j]]) - T.block_attr({"warp_execution": 1}) + T.sblock_attr({"warp_execution": 1}) with T.init(): C_local[i, j] = T.float32(0) C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] for ax0, ax1 in T.grid(32, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) v1 = T.axis.spatial( 512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1 @@ -217,7 +217,7 @@ def test_rewrite_cooperative_fetch(): sch = tir.Schedule(mod, debug_mask="all") # fmt: off # pylint: disable=line-too-long,invalid-name - b0 = sch.get_block(name="C", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") l2, l3, l4 = sch.get_loops(block=b0) v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16]) @@ -261,7 +261,7 @@ def test_rewrite_warp_execution(): sch = tir.Schedule(mod, debug_mask="all") # fmt: off # pylint: disable=line-too-long,invalid-name - b0 = sch.get_block(name="C", func_name="main") + b0 = sch.get_sblock(name="C", func_name="main") b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") l2, l3, l4 = sch.get_loops(block=b0) sch.annotate(b0, "warp_execution", 1) diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py index 8348c57c1949..0ebad18a1f79 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py @@ -82,7 +82,7 @@ def before( ) -> None: T.func_attr({"layout_free_buffers": [1]}) for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.S(16, i0 * 4 + i1) vj = T.axis.S(16, j) vk = T.axis.R(16, k0 * 4 + k1) @@ -98,12 +98,12 @@ def expected( T.func_attr({"layout_free_buffers": [1]}) B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32") for ax0, ax1 in T.grid(16, 16): - with T.block("layout_rewrite"): + with T.sblock("layout_rewrite"): i0, i1 = T.axis.remap("SS", [ax0, ax1]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1] for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.spatial(16, i0 * 4 + i1) vj = T.axis.spatial(16, j) vk = T.axis.reduce(16, k0 * 4 + k1) @@ -137,7 +137,7 @@ def before( ) -> None: T.func_attr({"layout_free_buffers": [0]}) for i, j in T.grid(16, 1): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) T.evaluate(A[vi, vj]) @@ -146,13 +146,13 @@ def expected(A: T.Buffer((16, 1), "float32")): A_global = T.alloc_buffer([16], dtype="float32") for ax0, ax1 in T.grid(16, 1): - with T.block("A_global"): + with T.sblock("A_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) A_global[v0] = A[v0, v1] for i, j in T.grid(16, 1): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) T.evaluate(A_global[vi]) @@ -165,7 +165,7 @@ def tir_matmul( ) -> None: T.func_attr({"layout_free_buffers": [1]}) for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.S(16, i0 * 4 + i1) vj = T.axis.S(16, j) vk = T.axis.R(16, k0 * 4 + k1) @@ -183,12 +183,12 @@ def rewritten_tir_matmul( T.func_attr({"layout_free_buffers": [1]}) B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32") for ax0, ax1 in T.grid(16, 16): - with T.block("layout_rewrite"): + with T.sblock("layout_rewrite"): i0, i1 = T.axis.remap("SS", [ax0, ax1]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1] for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.spatial(16, i0 * 4 + i1) vj = T.axis.spatial(16, j) vk = T.axis.reduce(16, k0 * 4 + k1) @@ -219,7 +219,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): for ax0, ax1, ax2 in T.grid(1, 30, 30): for ax3_fused in T.vectorized(64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0 = T.axis.spatial(1, ax0) i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) @@ -229,7 +229,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") for i3_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused in T.serial(57600): - with T.block("pad_temp_global"): + with T.sblock("pad_temp_global"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) @@ -238,7 +238,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), T.writes(pad_temp_global[v0, v1, v2, v3]) pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] for ax0_ax1_ax2_ax3_fused in T.serial(2304): - with T.block("p1_global"): + with T.sblock("p1_global"): v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) @@ -249,18 +249,18 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): for i3_3_fused_init in T.vectorized(2): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): nn = T.axis.spatial(1, i0_1 + i0_2_init + i0_3_init) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) T.reads() T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSRSRS"}) conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): for i3_3_fused in T.vectorized(2): - with T.block("conv2d_nhwc_update"): + with T.sblock("conv2d_nhwc_update"): nn = T.axis.spatial(1, i0_1 + i0_2 + i0_3) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) @@ -270,11 +270,11 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), rc = T.axis.reduce(64, i6_0 * 32 + i6_1) T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ry, rx, rc, ff]) T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSRSRS"}) conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ry, rx, rc, ff] for ax0, ax1, ax2 in T.grid(1, 4, 14): for ax3_fused in T.vectorized(4): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) @@ -295,16 +295,16 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): - with T.block("p1_global"): + with T.sblock("p1_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(p1[v0, v1, v2, v3]) T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) - T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc":True}) p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): for ax0, ax1, ax2 in T.grid(1, 30, 30): for ax3_fused in T.vectorized(64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0 = T.axis.spatial(1, ax0) i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) @@ -314,7 +314,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") for i3_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused in T.serial(57600): - with T.block("pad_temp_global"): + with T.sblock("pad_temp_global"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) @@ -323,7 +323,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), T.writes(pad_temp_global[v0, v1, v2, v3]) pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] for ax0_ax1_ax2_ax3_fused in T.serial(2304): - with T.block("p1_global"): + with T.sblock("p1_global"): v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) @@ -334,18 +334,18 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): for i3_3_fused_init in T.vectorized(2): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): nn = T.axis.spatial(1, i0_1 + i0_2_init + i0_3_init) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) T.reads() T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSRSRS"}) conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): for i3_3_fused in T.vectorized(2): - with T.block("conv2d_nhwc_update"): + with T.sblock("conv2d_nhwc_update"): nn = T.axis.spatial(1, i0_1 + i0_2 + i0_3) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) @@ -355,11 +355,11 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), rc = T.axis.reduce(64, i6_0 * 32 + i6_1) T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]) T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSRSRS"}) conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2] for ax0, ax1, ax2 in T.grid(1, 4, 14): for ax3_fused in T.vectorized(4): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) @@ -381,14 +381,14 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), p1_global2 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32", scope="global2") p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): - with T.block("p1_global"): + with T.sblock("p1_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(p1[v0, v1, v2, v3]) T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) - T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc":True}) p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): - with T.block("p1_global2"): + with T.sblock("p1_global2"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) T.writes(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) @@ -396,7 +396,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): for ax0, ax1, ax2 in T.grid(1, 30, 30): for ax3_fused in T.vectorized(64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0 = T.axis.spatial(1, ax0) i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) @@ -406,7 +406,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") for i3_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused in T.serial(57600): - with T.block("pad_temp_global"): + with T.sblock("pad_temp_global"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) @@ -415,7 +415,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), T.writes(pad_temp_global[v0, v1, v2, v3]) pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] for ax0_ax1_ax2_ax3_fused in T.serial(2304): - with T.block("p1_global"): + with T.sblock("p1_global"): v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) @@ -426,18 +426,18 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): for i3_3_fused_init in T.vectorized(2): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): nn = T.axis.spatial(1, i0_1 + i0_2_init + i0_3_init) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) T.reads() T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSRSRS"}) conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): for i3_3_fused in T.vectorized(2): - with T.block("conv2d_nhwc_update"): + with T.sblock("conv2d_nhwc_update"): nn = T.axis.spatial(1, i0_1 + i0_2 + i0_3) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) @@ -447,11 +447,11 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), rc = T.axis.reduce(64, i6_0 * 32 + i6_1) T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]) T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSRSRS"}) conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2] for ax0, ax1, ax2 in T.grid(1, 4, 14): for ax3_fused in T.vectorized(4): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) @@ -476,7 +476,7 @@ def test_layout_rewrite_cache_read_multiple(): target = Target("llvm") ctx = _create_context(Conv2dCacheRead, target) sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all") - sch.cache_read(sch.get_block("p1_global"), 0, "global2") + sch.cache_read(sch.get_sblock("p1_global"), 0, "global2") sch.enter_postproc() assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten) @@ -495,7 +495,7 @@ def before( for b_2_init, i_2_init, j_2_init, b_3_init, i_3_init, j_3_init in T.grid( T.int64(6), T.int64(1), T.int64(197), T.int64(1), T.int64(1), T.int64(1) ): - with T.block("T_batch_matmul_NT_init"): + with T.sblock("T_batch_matmul_NT_init"): v_b = T.axis.spatial( T.int64(12), b_3_init @@ -522,7 +522,7 @@ def before( T.int64(1), T.int64(1), ): - with T.block("T_batch_matmul_NT_update"): + with T.sblock("T_batch_matmul_NT_update"): v_b = T.axis.spatial( T.int64(12), b_3 @@ -553,18 +553,18 @@ def expected( [T.int64(2), T.int64(64), T.int64(6), T.int64(197)], dtype="int8" ) for ax0, ax1, ax2 in T.grid(T.int64(12), T.int64(197), T.int64(64)): - with T.block("p1_global"): + with T.sblock("p1_global"): v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(p1[v0, v1, v2]) T.writes(p1_global[v0 // T.int64(6), v2, v0 % T.int64(6), v1]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) p1_global[v0 // T.int64(6), v2, v0 % T.int64(6), v1] = p1[v0, v1, v2] for b_0_i_0_fused in T.parallel(T.int64(394)): for j_0, b_1, i_1, j_1 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): for b_2_init, i_2_init, j_2_init, b_3_init, i_3_init, j_3_init in T.grid( T.int64(6), T.int64(1), T.int64(197), T.int64(1), T.int64(1), T.int64(1) ): - with T.block("T_batch_matmul_NT_init"): + with T.sblock("T_batch_matmul_NT_init"): v_b = T.axis.spatial( T.int64(12), b_3_init @@ -590,7 +590,7 @@ def expected( T.int64(1), T.int64(1), ): - with T.block("T_batch_matmul_NT_update"): + with T.sblock("T_batch_matmul_NT_update"): v_b = T.axis.spatial( T.int64(12), b_3 diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index 13feaa85f671..38191748cc08 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -34,10 +34,10 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") # body - with T.block("root"): - T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) + with T.sblock("root"): + T.sblock_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): - with T.block("move"): + with T.sblock("move"): vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2) vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2) vk = T.axis.spatial(1024, k0 * 32 + k1) @@ -54,11 +54,11 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") # body - with T.block("root"): + with T.sblock("root"): for i0_j0_fused in T.parallel(0, 8192): for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8): for k1_fused in T.vectorized(0, 32): - with T.block("move"): + with T.sblock("move"): vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2) vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2) vk = T.axis.spatial(1024, k0 * 32 + k1_fused) @@ -77,7 +77,7 @@ class Fused_NN_Dense: @T.prim_func def main(placeholder: T.Buffer((64, 768), "float32"), placeholder_1: T.Buffer((768, 768), "float32"), T_matmul_NT: T.Buffer((64, 768), "float32")) -> None: for i0, i1, i2 in T.grid(64, 768, 768): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(placeholder[i, k], placeholder_1[j, k]) T.writes(T_matmul_NT[i, j]) @@ -91,14 +91,14 @@ def before_matmul_vectorize( placeholder_1: T.Buffer((768, 768), "float32"), T_matmul_NT: T.Buffer((64, 768), "float32"), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.vectorize":64}) + T.sblock_attr({"meta_schedule.vectorize":64}) T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3): for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(48, 8, 1, 16, 8, 16): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i = T.axis.spatial(64, i0_2 * 8 + i0_3) j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3) k = T.axis.reduce(768, i2_0 * 16 + i2_1) @@ -108,7 +108,7 @@ def before_matmul_vectorize( T_matmul_NT_global[i, j] = T.float32(0) T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] + placeholder[i, k] * placeholder_1[j, k] for ax0, ax1 in T.grid(64, 16): - with T.block("T_matmul_NT_global"): + with T.sblock("T_matmul_NT_global"): v0 = T.axis.spatial(64, ax0) v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1) T.reads(T_matmul_NT_global[v0, v1]) @@ -125,7 +125,7 @@ def after_matmul_vectorize( for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3): for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8): for i1_3_fused in T.vectorized(16): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i = T.axis.spatial(64, i0_2 * 8 + i0_3) j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused) k = T.axis.reduce(768, i2_0 * 16 + i2_1) @@ -136,7 +136,7 @@ def after_matmul_vectorize( T_matmul_NT_global[i, j] = T_matmul_NT_global[i, j] + placeholder[i, k] * placeholder_1[j, k] for ax0 in T.serial(64): for ax1_fused in T.vectorized(16): - with T.block("T_matmul_NT_global"): + with T.sblock("T_matmul_NT_global"): v0 = T.axis.spatial(64, ax0) v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused) T.reads(T_matmul_NT_global[v0, v1]) @@ -150,10 +150,10 @@ def before_postproc_add( rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"), ) -> None: - with T.block("root"): - T.block_attr({"meta_schedule.parallel":64, "meta_schedule.vectorize":128}) + with T.sblock("root"): + T.sblock_attr({"meta_schedule.parallel":64, "meta_schedule.vectorize":128}) for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32): - with T.block("add_compute"): + with T.sblock("add_compute"): v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1]) T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4]) T.writes(add_compute[v0, v1, v2, v3, v4]) @@ -166,10 +166,10 @@ def after_postproc_add( rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"), ) -> None: - with T.block("root"): + with T.sblock("root"): for n_c0_h_w_c1_fused_0 in T.parallel(0, 6272): for n_c0_h_w_c1_fused_1 in T.vectorized(0, 128): - with T.block("add_compute"): + with T.sblock("add_compute"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(8, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) // 100352) v2 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 100352 // 1792) @@ -210,12 +210,12 @@ def test_no_unroll_for_spatial_block(): # fmt: off @T.prim_func def layer_norm(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "float32"), C: T.Buffer((4, 4, 32), "float32"), T_layer_norm: T.Buffer((1, 4, 4, 32), "float32")): - with T.block("root"): - T.block_attr({"meta_schedule.unroll_explicit": 512}) + with T.sblock("root"): + T.sblock_attr({"meta_schedule.unroll_explicit": 512}) A_red_temp_v0 = T.alloc_buffer((1,)) A_red_temp_v1 = T.alloc_buffer((1,)) for ax0, k1, k2, k3 in T.grid(1, 4, 4, 32): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0, v_k1, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k1, k2, k3]) T.reads(A[v_ax0, v_k1, v_k2, v_k3]) T.writes(A_red_temp_v0[v_ax0], A_red_temp_v1[v_ax0]) @@ -227,7 +227,7 @@ def layer_norm(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "f A_red_temp_v0[v_ax0] = v_A_red_temp_v0 A_red_temp_v1[v_ax0] = v_A_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], A_red_temp_v0[v_ax0], A_red_temp_v1[v_ax0], B[v_ax1, v_ax2, v_ax3], C[v_ax1, v_ax2, v_ax3]) T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -235,12 +235,12 @@ def layer_norm(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "f @T.prim_func def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "float32"), C: T.Buffer((4, 4, 32), "float32"), T_layer_norm: T.Buffer((1, 4, 4, 32), "float32")): - with T.block("root"): + with T.sblock("root"): A_red_temp_v0 = T.alloc_buffer((1,)) A_red_temp_v1 = T.alloc_buffer((1,)) for ax0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): for k1, k2, k3 in T.grid(4, 4, 32): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0 = T.axis.spatial(1, 0) v_k1, v_k2, v_k3 = T.axis.remap("RRR", [k1, k2, k3]) T.reads(A[0, v_k1, v_k2, v_k3]) @@ -253,7 +253,7 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo A_red_temp_v0[0] = v_A_red_temp_v0 A_red_temp_v1[0] = v_A_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0 = T.axis.spatial(1, 0) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[0, v_ax1, v_ax2, v_ax3], A_red_temp_v0[0], A_red_temp_v1[0], B[v_ax1, v_ax2, v_ax3], C[v_ax1, v_ax2, v_ax3]) diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py index 347b773b7ed0..46667f6d9b77 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -62,25 +62,25 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: for i2_0 in T.serial(0, 1): for ax0_ax1_fused_0 in T.serial(0, 32768): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) T.reads([A[v0, v1]]) T.writes([A_shared[v0, v1]]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.sblock_attr({"meta_schedule.cooperative_fetch":1}) A_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused_0 in T.serial(0, 1024): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(0, 2): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) T.reads([B[v0, v1]]) T.writes([B_shared[v0, v1]]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) + T.sblock_attr({"meta_schedule.cooperative_fetch":2}) B_shared[v0, v1] = B[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) k = T.axis.reduce(512, i2_1 * 32 + i2_2) @@ -90,7 +90,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: C_local[i, j] = T.float32(0) C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] for ax0, ax1 in T.grid(32, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) T.reads([C_local[v0, v1]]) @@ -114,32 +114,32 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: for i2_0 in T.serial(0, 1): for ax0_ax1_fused_0 in T.serial(0, 32768): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) T.reads([A[v0, v1]]) T.writes([A_shared[v0, v1]]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) + T.sblock_attr({"meta_schedule.cooperative_fetch":1}) A_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused_0 in T.serial(0, 1024): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(0, 2): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) T.reads([B[v0, v1]]) T.writes([B_shared[v0, v1]]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) + T.sblock_attr({"meta_schedule.cooperative_fetch":2}) B_shared[v0, v1] = B[v0, v1] for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(2, 2, 16, 2): - with T.block("C_init"): + with T.sblock("C_init"): i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3_init * 16 + i0_4_init) j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3_init * 2 + i1_4_init) T.reads([]) T.writes([C_local[i, j]]) C_local[i, j] = T.float32(0) for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): - with T.block("C_update"): + with T.sblock("C_update"): i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) k = T.axis.reduce(512, i2_1 * 32 + i2_2) @@ -147,7 +147,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: T.writes([C_local[i, j]]) C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] for ax0, ax1 in T.grid(32, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) T.reads([C_local[v0, v1]]) @@ -164,7 +164,7 @@ def main(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256) for i0 in T.serial(256): for ax0, ax1_0 in T.grid(1, 8): for ax1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax1_0 * 32 + ax1_1) T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k]) @@ -174,7 +174,7 @@ def main(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256) T_softmax_maxelem_shared[i0_1] = T.max(T_softmax_maxelem_shared[i0_1], A[i0_1, k]) for ax0, ax1_0 in T.grid(1, 8): for ax1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax1_0 * 32 + ax1_1) T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2]) @@ -184,12 +184,12 @@ def main(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256) T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32") for i1_0 in T.serial(8): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_3 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3]) T.writes(T_softmax_norm[i0_3, i1]) - T.block_attr({"axis":1}) + T.sblock_attr({"axis":1}) T_softmax_norm[i0_3, i1] = T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32") / T_softmax_expsum_shared[i0_3] diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py index 313657108c62..2913af5b136c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py @@ -93,7 +93,7 @@ def main( 4, 1, ): - with T.block("conv2d_NCHWc_int8_o"): + with T.sblock("conv2d_NCHWc_int8_o"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2) oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3) @@ -108,16 +108,16 @@ def main( placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], ) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"}) with T.init(): for i4_1 in T.serial(16): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_init = T.axis.spatial(16, i4_1) T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init]) conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0 for i4_1, i9_1 in T.grid(16, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1]) T.reads( conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block], @@ -127,7 +127,7 @@ def main( ], ) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ n, oc_chunk, oh, ow, oc_block ] + T.cast( @@ -152,12 +152,12 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid( 1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1 ): for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(4, 7, 4, 4): - with T.block("conv2d_NCHWc_int8_o_init"): + with T.sblock("conv2d_NCHWc_int8_o_init"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init) oh = T.axis.spatial(56, i2_0 * 28 + i2_2_init * 4 + i2_3_init) @@ -166,7 +166,7 @@ def main( T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) for i4_1 in T.vectorized(16): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_init = T.axis.spatial(16, i4_1) T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init]) @@ -191,7 +191,7 @@ def main( i3_3, i4_0_3, ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1): - with T.block("conv2d_NCHWc_int8_o_update"): + with T.sblock("conv2d_NCHWc_int8_o_update"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2) oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3) @@ -260,27 +260,27 @@ def main( for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): for i2_0_0 in T.serial(2): for ax0_ax1_fused in T.serial(1024): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial( 128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64 ) v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(X[v0, v1]) T.writes(X_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) X_shared[v0, v1] = X[v0, v1] for ax0_ax1_fused in T.serial(4096): - with T.block("W_shared"): + with T.sblock("W_shared"): v0 = T.axis.spatial( 128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64 ) v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(W[v0, v1]) T.writes(W_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) W_shared[v0, v1] = W[v0, v1] for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1): - with T.block("compute_o"): + with T.sblock("compute_o"): i = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4) j = T.axis.spatial( 128, @@ -295,14 +295,14 @@ def main( W_shared[j, k_o * 4 : k_o * 4 + 4], ) T.writes(compute_local[i, j]) - T.block_attr({"meta_schedule.auto_tensorize": "dp4a_s8s8s32"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "dp4a_s8s8s32"}) with T.init(): - with T.block("compute_init"): + with T.sblock("compute_init"): T.reads() T.writes(compute_local[i, j]) compute_local[i, j] = 0 for i2_1 in T.serial(4): - with T.block("compute"): + with T.sblock("compute"): k = T.axis.reduce(4, i2_1) T.reads( compute_local[i, j], @@ -310,12 +310,14 @@ def main( W_shared[j, k_o * 4 + k], ) T.writes(compute_local[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr( + {"meta_schedule.tiling_structure": "SSSRRSRS"} + ) compute_local[i, j] = compute_local[i, j] + T.cast( X_shared[i, k_o * 4 + k], "int32" ) * T.cast(W_shared[j, k_o * 4 + k], "int32") for ax0, ax1 in T.grid(16, 16): - with T.block("compute_local"): + with T.sblock("compute_local"): v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0) v1 = T.axis.spatial( 128, @@ -340,7 +342,7 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") @@ -348,7 +350,7 @@ def main( for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"): for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4): - with T.block("compute_o_init"): + with T.sblock("compute_o_init"): i = T.axis.spatial( 128, i0_0_i1_0_fused // 2 * 16 + i0_3_init * 4 + i0_4_init ) @@ -361,34 +363,34 @@ def main( ) T.reads() T.writes(compute_local[i, j]) - T.block_attr({"meta_schedule.auto_tensorize": ""}) - with T.block("compute_init"): + T.sblock_attr({"meta_schedule.auto_tensorize": ""}) + with T.sblock("compute_init"): T.reads() T.writes(compute_local[i, j]) compute_local[i, j] = 0 for i2_0_0 in T.serial(2): for ax0_ax1_fused in T.serial(1024): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial( 128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64 ) v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(X[v0, v1]) T.writes(X_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) X_shared[v0, v1] = X[v0, v1] for ax0_ax1_fused in T.serial(4096): - with T.block("W_shared"): + with T.sblock("W_shared"): v0 = T.axis.spatial( 128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64 ) v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(W[v0, v1]) T.writes(W_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) W_shared[v0, v1] = W[v0, v1] for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1): - with T.block("compute_o_update"): + with T.sblock("compute_o_update"): i = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4) j = T.axis.spatial( 128, @@ -436,7 +438,7 @@ def main( dtype="int32", ) for ax0, ax1 in T.grid(16, 16): - with T.block("compute_local"): + with T.sblock("compute_local"): v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0) v1 = T.axis.spatial( 128, diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py index 719d9c2f9515..779832986d91 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -51,7 +51,7 @@ def main(var_A: T.handle, var_B: T.handle) -> None: A = T.match_buffer(var_A, [512, 512], dtype="float32") B = T.match_buffer(var_B, [512, 512], dtype="float32") for i, j in T.grid(512, 512): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + 1.0 @@ -64,7 +64,7 @@ def main(var_A: T.handle, var_B: T.handle) -> None: B = T.match_buffer(var_B, [512, 512], dtype="float32") for i_j_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) // 512) vj = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) % 512) B[vi, vj] = A[vi, vj] + 1.0 @@ -76,13 +76,13 @@ class Before_norm_bmn: def main(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer((1,), "float32")) -> None: C = T.alloc_buffer([1], dtype="float32") for i0, i1, i2 in T.grid(1, 256, 256): - with T.block("C"): + with T.sblock("C"): b, i, j = T.axis.remap("SRR", [i0, i1, i2]) with T.init(): C[b] = T.float32(0) C[b] = C[b] + A[b, i, j] * A[b, i, j] for i0 in T.serial(1): - with T.block("D"): + with T.sblock("D"): b = T.axis.S(1, i0) D[b] = T.sqrt(C[b], dtype="float32") @@ -95,7 +95,7 @@ def main(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer((1,), "float32")) -> for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"): for i1, i2 in T.grid(256, 256): - with T.block("C"): + with T.sblock("C"): b = T.axis.S(1, 0) i, j = T.axis.remap("RR", [i1, i2]) with T.init(): @@ -103,7 +103,7 @@ def main(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer((1,), "float32")) -> C[b] = C[b] + A[b, i, j] * A[b, i, j] for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"): - with T.block("D"): + with T.sblock("D"): b = T.axis.S(1, 0) D[b] = T.sqrt(C[b], dtype="float32") @@ -115,7 +115,7 @@ def main( placeholder: T.Buffer((12, 64, 64), "float32"), T_reshape: T.Buffer((64, 768), "float32") ) -> None: for i0_i1_fused_0, i0_i1_fused_1 in T.grid(1536, 32): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): ax0 = T.axis.spatial(64, (i0_i1_fused_0 * 32 + i0_i1_fused_1) // 768) ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 32 + i0_i1_fused_1) % 768) T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64]) @@ -134,7 +134,7 @@ def main( placeholder: T.Buffer((12, 64, 64), "float32"), T_reshape: T.Buffer((64, 768), "float32") ) -> None: for i0_i1_fused_0, i0_i1_fused_1 in T.grid(1536000, 32): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): ax0 = T.axis.spatial(64, (i0_i1_fused_0 * 32 + i0_i1_fused_1) // 768) ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 32 + i0_i1_fused_1) % 768) T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64]) @@ -154,7 +154,7 @@ def main( ) -> None: for i0_i1_fused_0_i0_i1_fused_1_fused_0 in T.thread_binding(48, thread="blockIdx.x"): for i0_i1_fused_0_i0_i1_fused_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): ax0 = T.axis.spatial( 64, ( @@ -187,11 +187,11 @@ def main( placeholder: T.Buffer((12, 64, 64), "float32"), T_reshape: T.Buffer((64, 768), "float32") ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i0_i1_fused_0_i0_i1_fused_1_fused_1 in T.thread_binding(256, thread="blockIdx.x"): for i0_i1_fused_0_i0_i1_fused_1_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): for i0_i1_fused_0_i0_i1_fused_1_fused_0 in T.serial(188): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): ax0 = T.axis.spatial( 64, ( @@ -242,7 +242,7 @@ def before_unrolled_loop( for i1 in T.unroll(4): for i4 in T.unroll(6): for i5 in T.unroll(6): - with T.block("inverse"): + with T.sblock("inverse"): vh, vw = T.axis.remap("SS", [i0, i1]) p = T.axis.spatial(196, i2_0 * 2 + i2_1) co = T.axis.spatial(64, i3_0 * 16 + i3_1) @@ -260,7 +260,7 @@ def after_unrolled_loop( ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32") for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(13, thread="blockIdx.x"): @@ -269,7 +269,7 @@ def after_unrolled_loop( for i1 in T.unroll(4): for i4 in T.unroll(6): for i5 in T.unroll(6): - with T.block("inverse"): + with T.sblock("inverse"): vh, vw = T.axis.remap("SS", [i0, i1]) p = T.axis.spatial( 196, diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py index 0facc9b961e9..22cf305ca351 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py @@ -228,7 +228,7 @@ def GmmCuda0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): for i1_3_init, i2_4_init in T.grid(4, 2): - with T.block("Z_init"): + with T.sblock("Z_init"): b = T.axis.spatial(1, 0) i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) @@ -239,7 +239,7 @@ def GmmCuda0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " for ax0_ax1_ax2_fused_0 in T.serial(4): for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in T.vectorized(2): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) @@ -248,7 +248,7 @@ def GmmCuda0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused_0 in T.serial(8): for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) @@ -256,7 +256,7 @@ def GmmCuda0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " T.writes(Y_shared[v0, v1, v2]) Y_shared[v0, v1, v2] = Y[v0, v1, v2] for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): - with T.block("Z_update"): + with T.sblock("Z_update"): b = T.axis.spatial(1, 0) i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) @@ -265,7 +265,7 @@ def GmmCuda0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " T.writes(Z_local[b, i, j]) Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] for ax0, ax1, ax2 in T.grid(1, 4, 2): - with T.block("Z_local"): + with T.sblock("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) @@ -282,7 +282,7 @@ def GmmCuda1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): for i1_3_init, i2_4_init in T.grid(4, 2): - with T.block("Z_init"): + with T.sblock("Z_init"): b = T.axis.spatial(1, 0) i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) @@ -293,7 +293,7 @@ def GmmCuda1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " for ax0_ax1_ax2_fused_0 in T.serial(4): for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in T.vectorized(2): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) @@ -302,7 +302,7 @@ def GmmCuda1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused_0 in T.serial(8): for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) @@ -310,12 +310,12 @@ def GmmCuda1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " T.writes(Y_shared[v0, v1, v2]) Y_shared[v0, v1, v2] = Y[v0, v1, v2] for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): - with T.block("Z_update"): + with T.sblock("Z_update"): b = T.axis.spatial(1, 0) i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) k = T.axis.reduce(128, i3_0 * 32 + i3_2) - T.block_attr({ + T.sblock_attr({ "meta_schedule.thread_extent_low_inclusive": 0, "meta_schedule.thread_extent_high_inclusive": 32, }) @@ -323,7 +323,7 @@ def GmmCuda1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " T.writes(Z_local[b, i, j]) Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] for ax0, ax1, ax2 in T.grid(1, 4, 2): - with T.block("Z_local"): + with T.sblock("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) @@ -341,7 +341,7 @@ def GmmCuda2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): for i1_3_init, i2_4_init in T.grid(4, 2): - with T.block("Z_init"): + with T.sblock("Z_init"): b = T.axis.spatial(1, 0) i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) @@ -352,7 +352,7 @@ def GmmCuda2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " for ax0_ax1_ax2_fused_0 in T.serial(4): for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in T.vectorized(2): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) @@ -361,7 +361,7 @@ def GmmCuda2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused_0 in T.serial(8): for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) @@ -369,12 +369,12 @@ def GmmCuda2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " T.writes(Y_shared[v0, v1, v2]) Y_shared[v0, v1, v2] = Y[v0, v1, v2] for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): - with T.block("Z_update"): + with T.sblock("Z_update"): b = T.axis.spatial(1, 0) i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) k = T.axis.reduce(128, i3_0 * 32 + i3_2) - T.block_attr({ + T.sblock_attr({ "meta_schedule.thread_extent_low_inclusive": 1024, "meta_schedule.thread_extent_high_inclusive": 1024, }) @@ -382,7 +382,7 @@ def GmmCuda2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), " T.writes(Z_local[b, i, j]) Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] for ax0, ax1, ax2 in T.grid(1, 4, 2): - with T.block("Z_local"): + with T.sblock("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) @@ -406,7 +406,7 @@ def GMMCUDATensorCore( s1_1 = T.int32() s1_2 = T.int32() # body - # with T.block("root") + # with T.sblock("root") Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator") X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared") Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared") @@ -416,7 +416,7 @@ def GMMCUDATensorCore( for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax1_0_3_init, ax2_0_3_init, ax1_0_4_init, ax2_0_4_init in T.grid(2, 1, 2, 4): - with T.block("Z_o_init"): + with T.sblock("Z_o_init"): v0 = T.axis.spatial(1, 0) v1_o = T.axis.spatial( 64, @@ -437,7 +437,7 @@ def GMMCUDATensorCore( v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 ] ) - T.block_attr( + T.sblock_attr( { "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, @@ -469,7 +469,7 @@ def GMMCUDATensorCore( for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial( 1024, ax0_0_ax1_0_0_ax2_0_0_fused // 16 * 256 @@ -495,13 +495,13 @@ def GMMCUDATensorCore( ) T.reads(X[v0, v1]) T.writes(X_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) X_shared[v0, v1] = X[v0, v1] for ax0_ax1_fused_0 in T.serial(8): for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial( 1024, ax3_0_0 * 32 @@ -526,11 +526,11 @@ def GMMCUDATensorCore( ) T.reads(Y[v0, v1]) T.writes(Y_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) Y_shared[v0, v1] = Y[v0, v1] for ax3_0_1 in T.serial(2): for ax0_0, ax1_0 in T.grid(4, 1): - with T.block("X_shared_wmma.matrix_a_o"): + with T.sblock("X_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial( 64, ax0_0_ax1_0_0_ax2_0_0_fused // 16 * 16 @@ -587,7 +587,7 @@ def GMMCUDATensorCore( ) ) for ax0_0, ax1_0 in T.grid(1, 4): - with T.block("Y_shared_wmma.matrix_b_o"): + with T.sblock("Y_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(64, ax3_0_0 * 2 + ax3_0_1) v1_o = T.axis.spatial( 64, ax0_0_ax1_0_0_ax2_0_0_fused % 16 * 4 + ax1_0 @@ -642,7 +642,7 @@ def GMMCUDATensorCore( for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid( 1, 2, 1, 1, 1, 2, 4 ): - with T.block("Z_o_update"): + with T.sblock("Z_o_update"): v0 = T.axis.spatial(1, 0) v1_o = T.axis.spatial( 64, @@ -674,7 +674,7 @@ def GMMCUDATensorCore( v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16 ] ) - T.block_attr( + T.sblock_attr( { "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, @@ -722,7 +722,7 @@ def GMMCUDATensorCore( ) ) for ax0_0, ax1_0 in T.grid(4, 4): - with T.block("Z_wmma.accumulator_o"): + with T.sblock("Z_wmma.accumulator_o"): v0_o = T.axis.spatial( 64, ax0_0_ax1_0_0_ax2_0_0_fused // 16 * 16 diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py b/tests/python/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py index cb4767221915..27886bd7caf9 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py @@ -49,7 +49,7 @@ def main(p0: T.Buffer((T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64 for n_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): for oc_chunk_0, oh_0, ow_0, oc_block_0_0 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(1)): for oc_chunk_1_init, oh_1_init, ow_1_init, oc_chunk_2_init, oh_2_init, ow_2_init in T.grid(T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(1), T.int64(9)): - with T.block("conv2d_NCHWc_int8_o_init"): + with T.sblock("conv2d_NCHWc_int8_o_init"): v_n = T.axis.spatial(T.int64(1), T.int64(0)) v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1_init + oc_chunk_2_init + oc_chunk_0) v_oh = T.axis.spatial(T.int64(54), oh_2_init + oh_0 * T.int64(27) + oh_1_init) @@ -58,14 +58,14 @@ def main(p0: T.Buffer((T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64 T.reads() T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]) for oc_block_1 in T.vectorized(T.int64(32)): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1) T.reads() T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init]) conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] = 0 for kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused in T.serial(T.int64(2), annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(26912)): - with T.block("p0_global.vtcm"): + with T.sblock("p0_global.vtcm"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_fused // T.int64(13456)) v2 = T.axis.spatial(T.int64(56), oh_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(13456) // T.int64(464)) @@ -75,7 +75,7 @@ def main(p0: T.Buffer((T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64 T.writes(p0_global_vtcm[v0, v1, v2, v3, v4]) p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4] for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(9216)): - with T.block("p1_global.vtcm"): + with T.sblock("p1_global.vtcm"): v0 = T.axis.spatial(T.int64(2), oc_chunk_0) v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(4608)) v2 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4608) // T.int64(1536)) @@ -87,7 +87,7 @@ def main(p0: T.Buffer((T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64 T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6]) p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6] for n_1, oc_chunk_1, oh_1, ow_1, oc_block_0_1, kh_1, kw_1, ic_outer_1, ic_f_inner_1, ic_s_inner_0_1, n_2, oc_chunk_2, oh_2, ow_2, oc_block_0_2 in T.grid(T.int64(1), T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(3), T.int64(3), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(9), T.int64(1)): - with T.block("conv2d_NCHWc_int8_o_update"): + with T.sblock("conv2d_NCHWc_int8_o_update"): v_n = T.axis.spatial(T.int64(1), T.int64(0)) v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1 + oc_chunk_2 + oc_chunk_0) v_oh = T.axis.spatial(T.int64(54), oh_2 + oh_0 * T.int64(27) + oh_1) @@ -99,11 +99,11 @@ def main(p0: T.Buffer((T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64 T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4)], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, T.int64(0) : T.int64(32), T.int64(0) : T.int64(4)]) T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]) for oc_block_1, ic_s_inner_1 in T.grid(T.int64(32), T.int64(4)): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): v_oc_block_i, v_ic_s_inner_i = T.axis.remap("SR", [oc_block_1, ic_s_inner_1]) T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i]) T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure":"SRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SRSRS"}) conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] = conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] + T.Cast("int32", p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i]) * T.Cast("int32", p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i]) #fmt on diff --git a/tests/python/meta_schedule/test_meta_schedule_runner.py b/tests/python/meta_schedule/test_meta_schedule_runner.py index 5b4f6944df91..8cc314edb57f 100644 --- a/tests/python/meta_schedule/test_meta_schedule_runner.py +++ b/tests/python/meta_schedule/test_meta_schedule_runner.py @@ -73,7 +73,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -90,13 +90,13 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(16, 16): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -110,7 +110,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, [16, 32, 32]) C = T.match_buffer(c, [16, 32, 32]) for n, i, j, k in T.grid(16, 32, 32, 32): - with T.block("update"): + with T.sblock("update"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) with T.init(): C[vn, vi, vj] = 0.0 @@ -126,7 +126,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, [32], "float32") C = T.match_buffer(c, [32], "float32") for i in range(32): - with T.block("add"): + with T.sblock("add"): vi = T.axis.S(32, i) C[vi] = A[vi] + B[vi] @@ -141,7 +141,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (4096, 4096), "float32") C = T.match_buffer(c, (4096, 4096), "float32") for i, j, k in T.grid(4096, 4096, 4096): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py index b21a4e0f7ec8..3162f927a888 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py @@ -35,7 +35,7 @@ def cpu_matmul_0( ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, i1, i2 in T.grid(4, 4, 512): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(A[i, k], B[k, j]) T.writes(C[i, j]) @@ -52,7 +52,7 @@ def cpu_matmul_1( T.func_attr({"global_symbol": "main", "tir.noalias": True}) C_rf = T.alloc_buffer([4, 4, 128], dtype="float32") for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128): - with T.block("C_rf"): + with T.sblock("C_rf"): vi2_1, i, j, vi2_0 = T.axis.remap("SSSR", [i2_1, i0, i1, i2_0]) T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j]) T.writes(C_rf[i, j, vi2_1]) @@ -62,11 +62,11 @@ def cpu_matmul_1( C_rf[i, j, vi2_1] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 128 + vi2_1, j] ) for i0, i1, i2_1 in T.grid(4, 4, 128): - with T.block("C"): + with T.sblock("C"): vi2_1, i, j = T.axis.remap("RSS", [i2_1, i0, i1]) T.reads(C_rf[i, j, vi2_1]) T.writes(C[i, j]) - T.block_attr({"meta_schedule.random_compute_producer": 1}) + T.sblock_attr({"meta_schedule.random_compute_producer": 1}) with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + C_rf[i, j, vi2_1] @@ -80,7 +80,7 @@ def cpu_matmul_2( T.func_attr({"global_symbol": "main", "tir.noalias": True}) C_rf = T.alloc_buffer([4, 4, 4], dtype="float32") for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128): - with T.block("C_rf"): + with T.sblock("C_rf"): vi2_0, i, j, vi2_1 = T.axis.remap("SSSR", [i2_0, i0, i1, i2_1]) T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j]) T.writes(C_rf[i, j, vi2_0]) @@ -90,11 +90,11 @@ def cpu_matmul_2( C_rf[i, j, vi2_0] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 128 + vi2_1, j] ) for i0, i1, i2_0 in T.grid(4, 4, 4): - with T.block("C"): + with T.sblock("C"): vi2_0, i, j = T.axis.remap("RSS", [i2_0, i0, i1]) T.reads(C_rf[i, j, vi2_0]) T.writes(C[i, j]) - T.block_attr({"meta_schedule.random_compute_producer": 1}) + T.sblock_attr({"meta_schedule.random_compute_producer": 1}) with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + C_rf[i, j, vi2_0] @@ -130,7 +130,7 @@ def argmax( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1 in T.grid(128, 128): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1) T.reads(idx[i, k], val[i, k]) @@ -153,7 +153,7 @@ def argmax_0( argmax_v1: T.Buffer(128, "float32"), ) -> None: for i0, i1 in T.grid(128, 128): - with T.block("argmax"): + with T.sblock("argmax"): i, k = T.axis.remap("SR", [i0, i1]) T.reads(idx[i, k], val[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) @@ -177,7 +177,7 @@ def argmax_1( argmax_v0_rf = T.alloc_buffer([128, 16], dtype="int32") argmax_v1_rf = T.alloc_buffer([128, 16], dtype="float32") for i0, i1_0, i1_1 in T.grid(128, 8, 16): - with T.block("argmax_rf"): + with T.sblock("argmax_rf"): vi1_1, i, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) T.reads(idx[i, vi1_0 * 16 + vi1_1], val[i, vi1_0 * 16 + vi1_1]) T.writes(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) @@ -197,11 +197,11 @@ def argmax_1( argmax_v0_rf[i, vi1_1] = v_argmax_v0_rf argmax_v1_rf[i, vi1_1] = v_argmax_v1_rf for i0, i1_1 in T.grid(128, 16): - with T.block("argmax"): + with T.sblock("argmax"): vi1_1, i = T.axis.remap("RS", [i1_1, i0]) T.reads(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) T.writes(argmax_v0[i], argmax_v1[i]) - T.block_attr({"meta_schedule.random_compute_producer": 1}) + T.sblock_attr({"meta_schedule.random_compute_producer": 1}) with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) @@ -222,11 +222,11 @@ def argmax_2( argmax_v1: T.Buffer(128, "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") argmax_v0_rf = T.alloc_buffer([128, 8], dtype="int32") argmax_v1_rf = T.alloc_buffer([128, 8], dtype="float32") for i0, i1_0, i1_1 in T.grid(128, 8, 16): - with T.block("argmax_rf"): + with T.sblock("argmax_rf"): vi1_0, i, vi1_1 = T.axis.remap("SSR", [i1_0, i0, i1_1]) T.reads(idx[i, vi1_0 * 16 + vi1_1], val[i, vi1_0 * 16 + vi1_1]) T.writes(argmax_v0_rf[i, vi1_0], argmax_v1_rf[i, vi1_0]) @@ -246,11 +246,11 @@ def argmax_2( argmax_v0_rf[i, vi1_0] = v_argmax_v0_rf argmax_v1_rf[i, vi1_0] = v_argmax_v1_rf for i0, i1_0 in T.grid(128, 8): - with T.block("argmax"): + with T.sblock("argmax"): vi1_0, i = T.axis.remap("RS", [i1_0, i0]) T.reads(argmax_v0_rf[i, vi1_0], argmax_v1_rf[i, vi1_0]) T.writes(argmax_v0[i], argmax_v1[i]) - T.block_attr({"meta_schedule.random_compute_producer": 1}) + T.sblock_attr({"meta_schedule.random_compute_producer": 1}) with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py index 332bebd79d31..44d7226a5abc 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py @@ -34,8 +34,8 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): - T.block_attr({"schedule_rule": "test_apply_custom_rule"}) + with T.sblock("matmul"): + T.sblock_attr({"schedule_rule": "test_apply_custom_rule"}) vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -43,7 +43,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: @tvm.register_global_func("meta_schedule.cpu.test_apply_custom_rule") -def sch_fn(sch: tvm.tir.Schedule, block: tvm.tir.Block) -> List[tvm.tir.Schedule]: +def sch_fn(sch: tvm.tir.Schedule, block: tvm.tir.SBlock) -> List[tvm.tir.Schedule]: raise ValueError("Intended for meta_schedule.cpu.test_apply_custom_rule") diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py index a8219ca01a68..ca510ee335cb 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py @@ -29,7 +29,7 @@ def element_wise(var_A: T.handle, var_B: T.handle) -> None: A = T.match_buffer(var_A, [512, 512], dtype="float32") B = T.match_buffer(var_B, [512, 512], dtype="float32") for i, j in T.grid(512, 512): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + 1.0 @@ -41,7 +41,7 @@ def reduction_loop_only( C: T.Buffer((), "float32"), ) -> None: for i0 in T.serial(2): - with T.block("C"): + with T.sblock("C"): k0 = T.axis.reduce(2, i0) T.reads(A[k0], B[k0]) T.writes(C[()]) @@ -56,7 +56,7 @@ def zero_dim_add( B: T.Buffer((), "float32"), C: T.Buffer((), "float32"), ) -> None: - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] @@ -68,10 +68,10 @@ def elementwise_0( B: T.Buffer((512, 512), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i_j_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) // 512) vj = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) % 512) T.reads(A[vi, vj]) @@ -106,7 +106,7 @@ def reduction_loop_only_0( for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): for i0 in T.serial(2): - with T.block("C"): + with T.sblock("C"): k0 = T.axis.reduce(2, i0) T.reads(A[k0], B[k0]) T.writes(C[()]) @@ -138,7 +138,7 @@ def zero_dim_add_0( ) -> None: for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) T.reads(A[()], B[()]) T.writes(C[()]) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py index 3f43e0133c29..6c609a084a59 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py @@ -45,29 +45,29 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32") bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): - with T.block("compute"): + with T.sblock("compute"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) with T.init(): compute_1[nn, ff, yy, xx] = T.float32(0) compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): - with T.block("bias_add"): + with T.sblock("bias_add"): i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0] for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): - with T.block("bn_mul"): + with T.sblock("bn_mul"): i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0] for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): - with T.block("bn_add"): + with T.sblock("bn_add"): i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0] for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0)) @@ -85,17 +85,17 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): - with T.block("compute"): + with T.sblock("compute"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) with T.init(): compute_1[nn, ff, yy, xx] = T.float32(0) compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) @@ -116,7 +116,7 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"): @@ -125,7 +125,7 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): - with T.block("pad_temp_shared"): + with T.sblock("pad_temp_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) @@ -133,14 +133,14 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): - with T.block("W_shared"): + with T.sblock("W_shared"): v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) v2 = T.axis.spatial(3, i5_0) v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): - with T.block("compute"): + with T.sblock("compute"): nn = T.axis.spatial(1, 0) ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) @@ -151,14 +151,14 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand compute_local[nn, ff, yy, xx] = T.float32(0) compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): - with T.block("compute_local"): + with T.sblock("compute_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3] for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) @@ -172,7 +172,7 @@ def main(X: T.Buffer((1, 512, 56, 56), "float32"), W: T.Buffer((512, 512, 3, 3), for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"): for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(8, thread="threadIdx.x"): for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): - with T.block("compute"): + with T.sblock("compute"): nn = T.axis.spatial(1, 0) ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) @@ -183,7 +183,7 @@ def main(X: T.Buffer((1, 512, 56, 56), "float32"), W: T.Buffer((512, 512, 3, 3), compute_local[nn, ff, yy, xx] = T.float32(0) compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + T.if_then_else(yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1 and xx + rx < 57, X[nn, rc, yy + ry - 1, xx + rx - 1], T.float32(0), dtype="float32") * W[ff, rc, ry, rx] for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): - with T.block("compute_local"): + with T.sblock("compute_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) @@ -199,23 +199,23 @@ def main(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256) T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) with T.init(): T_softmax_maxelem[i0_1] = T.min_value("float32") T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) T_softmax_exp[i0_2, i1_1] = T.exp(A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32") for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4, k = T.axis.remap("SR", [i0_3, i1]) with T.init(): T_softmax_expsum[i0_4] = T.float32(0) T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] for i0_5, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] @@ -227,19 +227,19 @@ def main(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256) T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) with T.init(): T_softmax_maxelem[i0_1] = T.min_value("float32") T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2, k = T.axis.remap("SR", [i0, i1]) with T.init(): T_softmax_expsum[i0_2] = T.float32(0) T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32") for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) T_softmax_norm[i0_4, i1_1] = T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] @@ -260,30 +260,30 @@ def main( T_add_1 = T.alloc_buffer([1, 384], dtype="int64") T_where = T.alloc_buffer([1, 384], dtype="int64") T_take = T.alloc_buffer([1, 384, 768], dtype="float32") - with T.block("compile_engine_const"): + with T.sblock("compile_engine_const"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const[()]) compile_engine_const[()] = T.int64(0) for i0, i1 in T.grid(1, 384): - with T.block("T_less"): + with T.sblock("T_less"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(placeholder[ax0, ax1], compile_engine_const[()]) T.writes(T_less[ax0, ax1]) T_less[ax0, ax1] = placeholder[ax0, ax1] < compile_engine_const[()] - with T.block("compile_engine_const_1"): + with T.sblock("compile_engine_const_1"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_1[()]) compile_engine_const_1[()] = T.int64(30522) for i0, i1 in T.grid(1, 384): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(placeholder[ax0, ax1], compile_engine_const_1[()]) T.writes(T_add_1[ax0, ax1]) T_add_1[ax0, ax1] = placeholder[ax0, ax1] + compile_engine_const_1[()] for i0, i1 in T.grid(1, 384): - with T.block("T_where"): + with T.sblock("T_where"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0, ax1]) T.writes(T_where[ax0, ax1]) @@ -291,7 +291,7 @@ def main( T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1], placeholder[ax0, ax1] ) for i0, i1, i2 in T.grid(1, 384, 768): - with T.block("T_take"): + with T.sblock("T_take"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads( placeholder_1[T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2], @@ -302,7 +302,7 @@ def main( T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2 ] for i0, i1, i2 in T.grid(1, 384, 768): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2]) T.writes(T_add[ax0, ax1, ax2]) @@ -316,9 +316,9 @@ def main(placeholder: T.Buffer((1, 384), "int64"), placeholder_1: T.Buffer((3052 # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1, i2 in T.grid(1, 384, 768): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(placeholder[ax0, ax1], placeholder_1[T.min(T.max(T.int64(0), placeholder[ax0, ax1]), T.int64(30521)) : T.min(T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)), T.int64(30521)) + T.int64(1), ax2], placeholder_2[ax0, ax1, ax2]) T.writes(T_add[ax0, ax1, ax2]) @@ -331,9 +331,9 @@ def main(T_full: T.Buffer((1, 12, 4096), "int64")) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1, i2 in T.grid(1, 12, 4096): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads() T.writes(T_full[ax0, ax1, ax2]) @@ -347,7 +347,7 @@ def main(p0: T.Buffer((16, 14, 14, 256), "int8"), p1: T.Buffer((1024, 1, 1, 256) # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") compile_engine_const = T.alloc_buffer([], dtype="int32") pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8") conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") @@ -359,19 +359,19 @@ def main(p0: T.Buffer((16, 14, 14, 256), "int8"), p1: T.Buffer((1024, 1, 1, 256) T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32") - with T.block("compile_engine_const"): + with T.sblock("compile_engine_const"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const[()]) compile_engine_const[()] = 59 for i0, i1, i2, i3 in T.grid(16, 14, 14, 256): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) T.writes(conv2d_nhwc[nn, yy, xx, ff]) @@ -379,55 +379,55 @@ def main(p0: T.Buffer((16, 14, 14, 256), "int8"), p1: T.Buffer((1024, 1, 1, 256) conv2d_nhwc[nn, yy, xx, ff] = 0 conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024): - with T.block("compute"): + with T.sblock("compute"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3]) T.writes(T_add_1[ax0, ax1, ax2, ax3]) T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3] for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0]) T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0] for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024): - with T.block("compute_2"): + with T.sblock("compute_2"): i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) T.writes(compute_3[i0_8, i1_8, i2_8, i3_8]) compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32") for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3]) T.writes(T_add_2[ax0, ax1, ax2, ax3]) T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3] for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024): - with T.block("compute_3"): + with T.sblock("compute_3"): i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10]) T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11]) T.writes(compute[i0_11, i1_11, i2_11, i3_11]) @@ -501,17 +501,17 @@ def test_inline_constant_tensor(): def test_conv2d_int8_inline_constant_scalars(): sch = Schedule(Conv2dInt8) - conv2d = sch.get_block("conv2d_nhwc") + conv2d = sch.get_sblock("conv2d_nhwc") sch.cache_write(conv2d, 0, "shared") with pytest.raises(tvm.tir.ScheduleError) as e: - sch.reverse_compute_inline(sch.get_block("T_add_1")) + sch.reverse_compute_inline(sch.get_sblock("T_add_1")) err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)" assert err_msg in str(e) - ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const")) - sch.reverse_compute_inline(sch.get_block("T_add_1")) + ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_sblock("compile_engine_const")) + sch.reverse_compute_inline(sch.get_sblock("T_add_1")) def test_inline_constant_scalars_skip_output_block(): @@ -521,14 +521,14 @@ def test_inline_constant_scalars_skip_output_block(): class Full: @T.prim_func def main(T_full: T.Buffer((), "float32")): - with T.block("T_full"): + with T.sblock("T_full"): vi = T.axis.spatial(1, 0) T.reads() T.writes(T_full[()]) T_full[()] = T.float32(1) sch = Schedule(Full) - sch = ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("T_full"))[0] + sch = ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_sblock("T_full"))[0] assert_structural_equal(sch.mod, Full) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 6f446ae14eda..d99793aa92ef 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -37,13 +37,13 @@ def main( T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) with T.init(): T_softmax_maxelem[i0_1] = T.min_value("float32") T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2, k = T.axis.remap("SR", [i0, i1]) with T.init(): T_softmax_expsum[i0_2] = T.float32(0) @@ -51,9 +51,9 @@ def main( A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" ) for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_4, i1_1] = ( T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] @@ -69,12 +69,12 @@ def softmax_mn_0( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem[i0_1]) @@ -82,7 +82,7 @@ def softmax_mn_0( T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2]) T.writes(T_softmax_exp[i0_2, i1_1]) @@ -90,7 +90,7 @@ def softmax_mn_0( A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32" ) for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4, k = T.axis.remap("SR", [i0_3, i1]) T.reads(T_softmax_exp[i0_4, k]) T.writes(T_softmax_expsum[i0_4]) @@ -98,11 +98,11 @@ def softmax_mn_0( T_softmax_expsum[i0_4] = T.float32(0) T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] for i0_5, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6]) T.writes(T_softmax_norm[i0_6, i1_2]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] @T.prim_func @@ -112,14 +112,14 @@ def softmax_mn_1( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0 in T.serial(256): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_1 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) @@ -132,7 +132,7 @@ def softmax_mn_1( ) for i1_0 in T.serial(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): T.where(i1_0 * 512 + i1_1 < 256) i0_2 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) @@ -142,7 +142,7 @@ def softmax_mn_1( A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32" ) for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4, k = T.axis.remap("SR", [i0_3, i1]) T.reads(T_softmax_exp[i0_4, k]) T.writes(T_softmax_expsum[i0_4]) @@ -150,11 +150,11 @@ def softmax_mn_1( T_softmax_expsum[i0_4] = T.float32(0) T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] for i0_5, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6]) T.writes(T_softmax_norm[i0_6, i1_2]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] @T.prim_func @@ -164,12 +164,12 @@ def softmax_mn_2( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem[i0_1]) @@ -177,7 +177,7 @@ def softmax_mn_2( T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2]) T.writes(T_softmax_exp[i0_2, i1_1]) @@ -187,7 +187,7 @@ def softmax_mn_2( for i0_3 in T.serial(256): for ax0, ax1_0 in T.grid(1, 32): for ax1_1 in T.thread_binding(8, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 8 + ax1_1) T.reads(T_softmax_exp[i0_4, k]) @@ -199,12 +199,12 @@ def softmax_mn_2( ) for i1_0 in T.serial(32): for i1_1_1 in T.thread_binding(8, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_5 = T.axis.spatial(256, i0_3) i1 = T.axis.spatial(256, i1_0 * 8 + i1_1_1) T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5]) T.writes(T_softmax_norm[i0_5, i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_5, i1] = ( T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5] ) @@ -216,14 +216,14 @@ def softmax_mn_3( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") for i0 in T.serial(256): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_1 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) @@ -236,7 +236,7 @@ def softmax_mn_3( ) for i1_0 in T.serial(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): T.where(i1_0 * 512 + i1_1 < 256) i0_2 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) @@ -248,7 +248,7 @@ def softmax_mn_3( for i0_3 in T.serial(256): for ax0, ax1_0 in T.grid(1, 32): for ax1_1 in T.thread_binding(8, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 8 + ax1_1) T.reads(T_softmax_exp[i0_4, k]) @@ -260,12 +260,12 @@ def softmax_mn_3( ) for i1_0 in T.serial(32): for i1_1 in T.thread_binding(8, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_5 = T.axis.spatial(256, i0_3) i1 = T.axis.spatial(256, i1_0 * 8 + i1_1) T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5]) T.writes(T_softmax_norm[i0_5, i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_5, i1] = ( T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5] ) @@ -304,7 +304,7 @@ def softmax_mn_after_inline_0( T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem[i0_1]) @@ -312,7 +312,7 @@ def softmax_mn_after_inline_0( T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2, k = T.axis.remap("SR", [i0, i1]) T.reads(A[i0_2, k], T_softmax_maxelem[i0_2]) T.writes(T_softmax_expsum[i0_2]) @@ -322,11 +322,11 @@ def softmax_mn_after_inline_0( A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" ) for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4]) T.writes(T_softmax_norm[i0_4, i1_1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_4, i1_1] = ( T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] @@ -340,7 +340,7 @@ def softmax_mn_after_inline_1( T_softmax_expsum = T.alloc_buffer([256], dtype="float32") for i0, i1_0 in T.grid(256, 4): for i1_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, i1_0 * 64 + i1_1) T.reads(A[i0_1, k]) @@ -349,7 +349,7 @@ def softmax_mn_after_inline_1( T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2, k = T.axis.remap("SR", [i0, i1]) T.reads(A[i0_2, k], T_softmax_maxelem[i0_2]) T.writes(T_softmax_expsum[i0_2]) @@ -359,11 +359,11 @@ def softmax_mn_after_inline_1( A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" ) for i0_3, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4]) T.writes(T_softmax_norm[i0_4, i1_1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_4, i1_1] = ( T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] @@ -376,7 +376,7 @@ def softmax_mn_after_inline_2( T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, k = T.axis.remap("SR", [i0, i1]) T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem[i0_1]) @@ -386,7 +386,7 @@ def softmax_mn_after_inline_2( for i0_3 in T.serial(256): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_2 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) @@ -399,7 +399,7 @@ def softmax_mn_after_inline_2( ) for i1_0 in T.serial(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): T.where(i1_0 * 512 + i1_1 < 256) i0_4 = T.axis.spatial(256, i0_3) i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1) @@ -407,7 +407,7 @@ def softmax_mn_after_inline_2( A[i0_4, i1_1_1], T_softmax_maxelem[i0_4], T_softmax_expsum_shared[i0_4] ) T.writes(T_softmax_norm[i0_4, i1_1_1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_4, i1_1_1] = ( T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum_shared[i0_4] @@ -422,7 +422,7 @@ def softmax_mn_after_inline_3( for i0_3 in T.serial(256): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_1 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) @@ -435,7 +435,7 @@ def softmax_mn_after_inline_3( ) for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_2 = T.axis.spatial(256, i0_3 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) @@ -448,7 +448,7 @@ def softmax_mn_after_inline_3( ) for i1_0 in T.serial(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): T.where(i1_0 * 512 + i1_1 < 256) i0_4 = T.axis.spatial(256, i0_3) i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1) @@ -458,7 +458,7 @@ def softmax_mn_after_inline_3( T_softmax_expsum_shared[i0_4], ) T.writes(T_softmax_norm[i0_4, i1_1_1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_4, i1_1_1] = ( T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem_shared[i0_4], dtype="float32") / T_softmax_expsum_shared[i0_4] @@ -502,10 +502,10 @@ def batch_norm_bmn_0(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "floa # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer([1], dtype="float32") for i0, i1, i2 in T.grid(1, 512, 512): - with T.block("C"): + with T.sblock("C"): b, i, j = T.axis.remap("SRR", [i0, i1, i2]) T.reads(A[b, i, j]) T.writes(C[b]) @@ -513,7 +513,7 @@ def batch_norm_bmn_0(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "floa C[b] = T.float32(0) C[b] = C[b] + A[b, i, j] * A[b, i, j] for i0 in T.serial(1): - with T.block("D"): + with T.sblock("D"): b = T.axis.spatial(1, i0) T.reads(C[b]) T.writes(D[b]) @@ -524,12 +524,12 @@ def batch_norm_bmn_1(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "floa # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C_shared = T.alloc_buffer([1], dtype="float32", scope="shared") for i0_0 in T.serial(1): for ax0, ax1_ax2_fused_0 in T.grid(1, 1024): for ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("C"): + with T.sblock("C"): b = T.axis.spatial(1, ax0) i = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) // 512) j = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) % 512) @@ -539,7 +539,7 @@ def batch_norm_bmn_1(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "floa C_shared[b] = T.float32(0) C_shared[b] = C_shared[b] + A[b, i, j] * A[b, i, j] for i0_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("D"): + with T.sblock("D"): T.where(i0_0 * 256 + i0_1 < 1) b = T.axis.spatial(1, i0_0 * 256 + i0_1) T.reads(C_shared[b]) @@ -574,7 +574,7 @@ def argmax( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1 in T.grid(128, 128): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1) T.reads(idx[i, k], val[i, k]) @@ -596,7 +596,7 @@ def argmax_32( argmax_v1: T.Buffer((1,), "float32"), ) -> None: for i0, i1 in T.grid(1, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(1, i0) k = T.axis.reduce(32, i1) T.reads(idx[i, k], val[i, k]) @@ -619,9 +619,9 @@ def argmax_0( argmax_v1: T.Buffer(128, "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i0, i1 in T.grid(128, 128): - with T.block("argmax"): + with T.sblock("argmax"): i, k = T.axis.remap("SR", [i0, i1]) T.reads(idx[i, k], val[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) @@ -643,10 +643,10 @@ def argmax_1( argmax_v1: T.Buffer(128, "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i0, i1_0 in T.grid(128, 2): for i1_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 64 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -692,9 +692,9 @@ def argmax_0( argmax_v1: T.Buffer((1,), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i0, i1 in T.grid(1, 32): - with T.block("argmax"): + with T.sblock("argmax"): i, k = T.axis.remap("SR", [i0, i1]) T.reads(idx[i, k], val[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) @@ -716,10 +716,10 @@ def argmax_1( argmax_v1: T.Buffer((1,), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i0, i1_0 in T.grid(1, 1): for i1_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(1, i0) k = T.axis.reduce(32, i1_0 * 64 + i1_1) T.where(i1_0 * 64 + i1_1 < 32) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt.py index fdc5b702594b..751e2c21efa0 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt.py @@ -38,22 +38,22 @@ def cpu_matmul_0( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C_global = T.alloc_buffer([512, 512], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 8, 8, 1): for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(16, 2, 8, 32, 32, 8): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3) j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3) k = T.axis.reduce(512, i2_0 * 32 + i2_1) T.reads(A[i, k], B[k, j]) T.writes(C_global[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C_global[i, j] = T.float32(0) C_global[i, j] = C_global[i, j] + A[i, k] * B[k, j] for ax0, ax1 in T.grid(64, 64): - with T.block("C_global"): + with T.sblock("C_global"): v0 = T.axis.spatial(512, i0_1 * 64 + ax0) v1 = T.axis.spatial(512, i1_0 * 64 + ax1) T.reads(C_global[v0, v1]) @@ -69,22 +69,22 @@ def cpu_matmul_1( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C_global = T.alloc_buffer([512, 512], dtype="float32") for i0_0, i1_0 in T.grid(1, 8): for i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(8, 1, 16, 2, 8, 32, 32, 8): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3) j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3) k = T.axis.reduce(512, i2_0 * 32 + i2_1) T.reads(A[i, k], B[k, j]) T.writes(C_global[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C_global[i, j] = T.float32(0) C_global[i, j] = C_global[i, j] + A[i, k] * B[k, j] for ax0, ax1 in T.grid(512, 64): - with T.block("C_global"): + with T.sblock("C_global"): v0 = T.axis.spatial(512, ax0) v1 = T.axis.spatial(512, i1_0 * 64 + ax1) T.reads(C_global[v0, v1]) @@ -100,17 +100,17 @@ def cpu_matmul_2( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid( 1, 8, 8, 1, 16, 2, 8, 32, 32, 8 ): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3) j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3) k = T.axis.reduce(512, i2_0 * 32 + i2_1) T.reads(A[i, k], B[k, j]) T.writes(C[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -156,23 +156,23 @@ def cpu_matmul_relu_0( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer([512, 512], dtype="float32") for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid( 256, 4, 1, 4, 64, 1, 32, 8, 2, 1 ): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3) k = T.axis.reduce(512, i2_0 * 8 + i2_1) T.reads(A[i, k], B[k, j]) T.writes(C[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] for i0, i1 in T.grid(512, 512): - with T.block("compute"): + with T.sblock("compute"): i0_4, i1_4 = T.axis.remap("SS", [i0, i1]) T.reads(C[i0_4, i1_4]) T.writes(compute[i0_4, i1_4]) @@ -187,22 +187,22 @@ def cpu_matmul_relu_1( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer([512, 512], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(256, 4, 1, 4): for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(64, 1, 32, 8, 2, 1): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3) k = T.axis.reduce(512, i2_0 * 8 + i2_1) T.reads(A[i, k], B[k, j]) T.writes(C[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] for ax0, ax1 in T.grid(2, 32): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(512, i0_0 * 2 + ax0) i1 = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + ax1) T.reads(C[i0, i1]) @@ -218,22 +218,22 @@ def cpu_matmul_relu_2( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer([512, 512], dtype="float32") for i0_0, i1_0 in T.grid(256, 4): for i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(1, 4, 64, 1, 32, 8, 2, 1): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3) k = T.axis.reduce(512, i2_0 * 8 + i2_1) T.reads(A[i, k], B[k, j]) T.writes(C[i, j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] for ax0, ax1 in T.grid(2, 128): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(512, i0_0 * 2 + ax0) i1 = T.axis.spatial(512, i1_0 * 128 + ax1) T.reads(C[i0, i1]) @@ -280,7 +280,7 @@ def cuda_matmul_0( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") @@ -289,27 +289,27 @@ def cuda_matmul_0( for i0_2_i1_2_fused in T.thread_binding(4, thread="threadIdx.x"): for i2_0 in T.serial(128): for ax0_ax1_fused in T.serial(256): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial( 512, i0_0_i1_0_fused // 16 * 64 + ax0_ax1_fused // 4 ) v1 = T.axis.spatial(512, i2_0 * 4 + ax0_ax1_fused % 4) T.reads(A[v0, v1]) T.writes(A_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) A_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in T.serial(128): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(512, i2_0 * 4 + ax0_ax1_fused // 32) v1 = T.axis.spatial( 512, i0_0_i1_0_fused % 16 * 32 + ax0_ax1_fused % 32 ) T.reads(B[v0, v1]) T.writes(B_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) B_shared[v0, v1] = B[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(2, 1, 1, 2, 16, 4): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial( 512, i0_0_i1_0_fused // 16 * 64 @@ -328,7 +328,7 @@ def cuda_matmul_0( k = T.axis.reduce(512, i2_0 * 4 + i2_1 * 2 + i2_2) T.reads(A_shared[i, k], B_shared[k, j]) T.writes(C_local[i, j]) - T.block_attr( + T.sblock_attr( { "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, @@ -339,7 +339,7 @@ def cuda_matmul_0( C_local[i, j] = T.float32(0) C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] for ax0, ax1 in T.grid(16, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial( 512, i0_0_i1_0_fused // 16 * 64 + i0_1_i1_1_fused // 2 * 16 + ax0 ) @@ -386,7 +386,7 @@ def cuda_matmul_relu_0( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer([512, 512], dtype="float32") C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") @@ -396,27 +396,27 @@ def cuda_matmul_relu_0( for i0_2_i1_2_fused in T.thread_binding(8, thread="threadIdx.x"): for i2_0 in T.serial(8): for ax0_ax1_fused in T.serial(4096): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial( 512, i0_0_i1_0_fused // 8 * 64 + ax0_ax1_fused // 64 ) v1 = T.axis.spatial(512, i2_0 * 64 + ax0_ax1_fused % 64) T.reads(A[v0, v1]) T.writes(A_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) A_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in T.serial(4096): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(512, i2_0 * 64 + ax0_ax1_fused // 64) v1 = T.axis.spatial( 512, i0_0_i1_0_fused % 8 * 64 + ax0_ax1_fused % 64 ) T.reads(B[v0, v1]) T.writes(B_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) B_shared[v0, v1] = B[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(8, 2, 1, 8, 2, 2): - with T.block("C"): + with T.sblock("C"): i = T.axis.spatial( 512, i0_0_i1_0_fused // 8 * 64 @@ -436,7 +436,7 @@ def cuda_matmul_relu_0( k = T.axis.reduce(512, i2_0 * 64 + i2_1 * 8 + i2_2) T.reads(A_shared[i, k], B_shared[k, j]) T.writes(C_local[i, j]) - T.block_attr( + T.sblock_attr( { "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, @@ -447,7 +447,7 @@ def cuda_matmul_relu_0( C_local[i, j] = T.float32(0) C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] for ax0, ax1 in T.grid(4, 2): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial( 512, i0_0_i1_0_fused // 8 * 64 @@ -466,7 +466,7 @@ def cuda_matmul_relu_0( T.writes(C[v0, v1]) C[v0, v1] = C_local[v0, v1] for i0, i1 in T.grid(512, 512): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(C[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) @@ -501,7 +501,7 @@ def sum_with_trivial_block_iter( B: T.Buffer((1, 64, 1), "float32"), ) -> None: for i0, i1, i2, i3 in T.grid(1, 64, 1, 768): - with T.block("sum"): + with T.sblock("sum"): ax0, ax1, ax2, k2 = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(A[ax0, ax1, k2]) T.writes(B[ax0, ax1, ax2]) @@ -530,7 +530,7 @@ def cpu_conv2d_nhwc( T.func_attr({"global_symbol": "main", "tir.noalias": True}) PadInput = T.alloc_buffer((1, 58, 58, 64), "float16") for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) @@ -559,7 +559,7 @@ def cpu_conv2d_nhwc( w_2, co_2, ) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2) v_h = T.axis.spatial(56, h_0 * 56 + h_1 * 4 + h_2) v_w = T.axis.spatial(56, w_0 * 28 + w_1 * 14 + w_2) @@ -572,7 +572,7 @@ def cpu_conv2d_nhwc( weight[v_rh, v_rw, v_rc, v_co], ) T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SRSRS"}) with T.init(): conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float16(0) conv2d_nhwc[v_n, v_h, v_w, v_co] = ( @@ -642,23 +642,23 @@ def cache_read_specify_consumer_0( for i_2_j_2_fused in T.thread_binding(16, thread="threadIdx.x"): for k_0 in range(2): for ax0_ax1_fused in range(131072): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(512, ax0_ax1_fused // 256) v1 = T.axis.spatial(512, k_0 * 256 + ax0_ax1_fused % 256) T.reads(A[v0, v1]) T.writes(A_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) A_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in range(65536): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(512, k_0 * 256 + ax0_ax1_fused // 256) v1 = T.axis.spatial(512, i_0_j_0_fused * 256 + ax0_ax1_fused % 256) T.reads(B[v0, v1]) T.writes(B_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) B_shared[v0, v1] = B[v0, v1] for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(64, 1, 1, 4, 1, 16): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial( 512, i_1_j_1_fused // 8 * 8 + i_2_j_2_fused // 2 + i_3 + i_4, @@ -674,7 +674,7 @@ def cache_read_specify_consumer_0( v_k = T.axis.reduce(512, k_0 * 256 + k_1 * 4 + k_2) T.reads(A_shared[v_i, v_k], B_shared[v_k, v_j]) T.writes(C_local[v_i, v_j]) - T.block_attr( + T.sblock_attr( { "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, @@ -687,7 +687,7 @@ def cache_read_specify_consumer_0( C_local[v_i, v_j] + A_shared[v_i, v_k] * B_shared[v_k, v_j] ) for ax0, ax1 in T.grid(1, 16): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial( 512, i_1_j_1_fused // 8 * 8 + i_2_j_2_fused // 2 + ax0, @@ -703,7 +703,7 @@ def cache_read_specify_consumer_0( T.writes(C[v0, v1]) C[v0, v1] = C_local[v0, v1] for ax0, ax1 in T.grid(512, 512): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(512, ax0) v_ax1 = T.axis.spatial(512, ax1) T.reads(C[v_ax0, v_ax1], A[v_ax0, v_ax1]) @@ -746,7 +746,7 @@ def pool_blocked_cache_read_write( X_global = T.alloc_buffer((1, 2, 8, 8, 8, 8, 32), "uint8") for b_0, c_o_0, h_o_0, w_o_0, h_i_0, w_i_0, c_i_0 in T.grid(1, 2, 4, 1, 8, 1, 4): for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in range(896): - with T.block("X_global"): + with T.sblock("X_global"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(2, c_o_0) v2 = T.axis.spatial(8, h_o_0 * 2) @@ -762,7 +762,7 @@ def pool_blocked_cache_read_write( for wh, ww, b_1, c_o_1, h_o_1, w_o_1, h_i_1, w_i_1, c_i_1 in T.grid( 2, 2, 1, 1, 1, 4, 1, 8, 8 ): - with T.block("pool"): + with T.sblock("pool"): v_b = T.axis.spatial(1, b_0 + b_1) v_c_o = T.axis.spatial(2, c_o_0 + c_o_1) v_h_o = T.axis.spatial(4, h_o_0 + h_o_1) @@ -783,7 +783,7 @@ def pool_blocked_cache_read_write( ] ) T.writes(pool_global[v_b, v_c_o, v_h_o, v_w_o, v_h_i, v_w_i, v_c_i]) - T.block_attr({"meta_schedule.tiling_structure": "SRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SRS"}) with T.init(): pool_global[v_b, v_c_o, v_h_o, v_w_o, v_h_i, v_w_i, v_c_i] = T.uint8(0) pool_global[v_b, v_c_o, v_h_o, v_w_o, v_h_i, v_w_i, v_c_i] = T.max( @@ -799,7 +799,7 @@ def pool_blocked_cache_read_write( ], ) for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 4, 1, 8, 8): - with T.block("pool_global"): + with T.sblock("pool_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(2, c_o_0 + ax1) v2 = T.axis.spatial(4, h_o_0 + ax2) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py index b435b92280be..4a2bc9e837cb 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -39,7 +39,7 @@ def conv2d_nchwc( ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): ( n, oc_chunk, @@ -72,11 +72,11 @@ def conv2d_nchwc( @T.prim_func def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): conv2d_NCHWc_int8_global = T.alloc_buffer((1, 16, 56, 56, 16), "int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1): for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): - with T.block("conv2d_NCHWc_int8_o"): + with T.sblock("conv2d_NCHWc_int8_o"): n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) @@ -89,23 +89,23 @@ def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": intrin}) + T.sblock_attr({"meta_schedule.auto_tensorize": intrin}) with T.init(): for i4_1 in range(16): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init]) conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init] = 0 for i4_1, i9_1 in T.grid(16, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.Cast("int32", placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i]) * T.Cast("int32", placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 2, 1, 16): - with T.block("conv2d_NCHWc_int8_global"): + with T.sblock("conv2d_NCHWc_int8_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(16, i1_0 * 2 + i1_1 + ax1) v2 = T.axis.spatial(56, i2_0 * 2 + ax2) @@ -118,11 +118,11 @@ def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place @T.prim_func def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): conv2d_NCHWc_int8_global = T.alloc_buffer((1, 16, 56, 56, 16), "int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1): for i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): - with T.block("conv2d_NCHWc_int8_o"): + with T.sblock("conv2d_NCHWc_int8_o"): n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) @@ -135,23 +135,23 @@ def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": intrin}) + T.sblock_attr({"meta_schedule.auto_tensorize": intrin}) with T.init(): for i4_1 in range(16): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init]) conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init] = 0 for i4_1, i9_1 in T.grid(16, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.Cast("int32", placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i]) * T.Cast("int32", placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 2, 1, 16): - with T.block("conv2d_NCHWc_int8_global"): + with T.sblock("conv2d_NCHWc_int8_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(16, i1_0 * 2 + ax1) v2 = T.axis.spatial(56, i2_0 * 2 + ax2) @@ -164,9 +164,9 @@ def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place @T.prim_func def x86_conv2d_nchwc_2(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): - with T.block("conv2d_NCHWc_int8_o"): + with T.sblock("conv2d_NCHWc_int8_o"): n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3) oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) @@ -179,20 +179,20 @@ def x86_conv2d_nchwc_2(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": intrin}) + T.sblock_attr({"meta_schedule.auto_tensorize": intrin}) with T.init(): for i4_1 in range(16): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 for i4_1, i9_1 in T.grid(16, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] + T.Cast("int32", placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i]) * T.Cast("int32", placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) # fmt: on decision_0 = [ @@ -307,7 +307,7 @@ def dp4a_dense_0( compute: T.Buffer((128, 128), "int32"), ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): compute_local = T.alloc_buffer((128, 128), "int32", scope="local") X_shared = T.alloc_buffer((128, 128), "int8", scope="shared") W_shared = T.alloc_buffer((128, 128), "int8", scope="shared") @@ -316,23 +316,23 @@ def dp4a_dense_0( for i_2_j_2_fused in T.thread_binding(2, thread="threadIdx.x"): for k_0_0 in range(1): for ax0_ax1_fused in range(16384): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(X[v0, v1]) T.writes(X_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) X_shared[v0, v1] = X[v0, v1] for ax0_ax1_fused in range(16384): - with T.block("W_shared"): + with T.sblock("W_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(W[v0, v1]) T.writes(W_shared[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) W_shared[v0, v1] = W[v0, v1] for k_0_1, i_3, j_3, k_0_2, i_4, j_4 in T.grid(1, 2, 4, 32, 2, 1): - with T.block("compute_o"): + with T.sblock("compute_o"): v_i = T.axis.spatial( 128, i_1_j_1_fused // 32 * 8 + i_2_j_2_fused * 4 + i_3 * 2 + i_4 ) @@ -343,14 +343,14 @@ def dp4a_dense_0( W_shared[v_j, v_k_o * 4 : v_k_o * 4 + 4], ) T.writes(compute_local[v_i, v_j]) - T.block_attr({"meta_schedule.auto_tensorize": "dp4a_s8s8s32"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "dp4a_s8s8s32"}) with T.init(): - with T.block("compute_init"): + with T.sblock("compute_init"): T.reads() T.writes(compute_local[v_i, v_j]) compute_local[v_i, v_j] = 0 for k_1 in range(4): - with T.block("compute"): + with T.sblock("compute"): v_k_i = T.axis.reduce(4, k_1) T.reads( compute_local[v_i, v_j], @@ -358,12 +358,14 @@ def dp4a_dense_0( W_shared[v_j, v_k_o * 4 + v_k_i], ) T.writes(compute_local[v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr( + {"meta_schedule.tiling_structure": "SSSRRSRS"} + ) compute_local[v_i, v_j] = compute_local[v_i, v_j] + T.Cast( "int32", X_shared[v_i, v_k_o * 4 + v_k_i] ) * T.Cast("int32", W_shared[v_j, v_k_o * 4 + v_k_i]) for ax0, ax1 in T.grid(4, 4): - with T.block("compute_local"): + with T.sblock("compute_local"): v0 = T.axis.spatial( 128, i_1_j_1_fused // 32 * 8 + i_2_j_2_fused * 4 + ax0 ) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index 6e0aa4cf8ae1..20cde6e83207 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -84,7 +84,7 @@ def test_matmul_relu(shared_scope): @T.prim_func def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) @@ -96,74 +96,74 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_0_0 in range(1): for ax0_ax1_fused in range(4096): - with T.block("A_reindex_shared"): + with T.sblock("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) A_reindex_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in range(4096): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_shared[v0, v1] = B[v0, v1] for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("A_reindex_shared_wmma.matrix_a_o"): + with T.sblock("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): + with T.sblock("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("B_reindex_shared_wmma.matrix_b_o"): + with T.sblock("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): + with T.sblock("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(2): for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): + with T.sblock("C_reindex_shared_wmma.accumulator_o"): v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) v2 = T.axis.spatial(2, ax2 + ax2_1) @@ -172,15 +172,15 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f v5_o = T.axis.spatial(1, 0) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax4, ax5 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): + with T.sblock("C_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): - with T.block("C_reindex_shared"): + with T.sblock("C_reindex_shared"): v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) v2 = T.axis.spatial(2, ax2) @@ -189,7 +189,7 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ @@ -235,7 +235,7 @@ def test_matmul_relu_with_fallback(): @T.prim_func def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared") C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator") A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") @@ -247,74 +247,74 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_0_0 in range(2): for ax0_ax1_fused in range(2048): - with T.block("A_reindex_shared"): + with T.sblock("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) A_reindex_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in range(8192): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) B_reindex_shared[v0, v1] = B[v0, v1] for ax2_0_1 in range(1): for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("A_reindex_shared_wmma.matrix_a_o"): + with T.sblock("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): + with T.sblock("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(4, 4): - with T.block("B_reindex_shared_wmma.matrix_b_o"): + with T.sblock("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): + with T.sblock("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 4, 2, 4): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2) T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(2): for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 4): - with T.block("C_reindex_shared_wmma.accumulator_o"): + with T.sblock("C_reindex_shared_wmma.accumulator_o"): v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) v1 = T.axis.spatial(2, ax0_ax1_fused) v2 = T.axis.spatial(2, ax2 + ax2_1) @@ -323,15 +323,15 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v5_o = T.axis.spatial(1, 0) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): + with T.sblock("C_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(2048): - with T.block("C_reindex_shared"): + with T.sblock("C_reindex_shared"): v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) v2 = T.axis.spatial(2, ax2) @@ -340,7 +340,7 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ @@ -392,7 +392,7 @@ def test_conv2d(shared_scope): @T.prim_func def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): PadInput = T.alloc_buffer((1, 18, 18, 32), "float16") conv2d_nhwc_reindex_shared_dyn = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) conv2d_nhwc_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") @@ -401,7 +401,7 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, PadInput_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") weight_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) @@ -411,95 +411,95 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): for ax2_0_0 in range(1): for ax0_ax1_fused in range(4608): - with T.block("PadInput_reindex_shared.dyn"): + with T.sblock("PadInput_reindex_shared.dyn"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, ax0_ax1_fused % 288) T.reads(PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32]) T.writes(PadInput_reindex_shared_dyn[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) PadInput_reindex_shared_dyn[v0, v1] = PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32] for ax0_ax1_fused in range(4608): - with T.block("weight_reindex_shared.dyn"): + with T.sblock("weight_reindex_shared.dyn"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) T.writes(weight_reindex_shared_dyn[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) weight_reindex_shared_dyn[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] for ax2_0_1 in range(18): for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("PadInput_reindex_shared.dyn_wmma.matrix_a_o"): + with T.sblock("PadInput_reindex_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) T.reads(PadInput_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("PadInput_reindex_shared.dyn_wmma.matrix_a"): + with T.sblock("PadInput_reindex_shared.dyn_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(PadInput_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("weight_reindex_shared.dyn_wmma.matrix_b_o"): + with T.sblock("weight_reindex_shared.dyn_wmma.matrix_b_o"): v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) T.reads(weight_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("weight_reindex_shared.dyn_wmma.matrix_b"): + with T.sblock("weight_reindex_shared.dyn_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(weight_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 1, 1): - with T.block("conv2d_nhwc_o"): + with T.sblock("conv2d_nhwc_o"): v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0_3 + ax0_0_4) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) T.reads(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(1): for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator_o"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3]) v4_o = T.axis.spatial(1, 0) v5_o = T.axis.spatial(1, 0) T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) T.writes(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax4, ax5 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator"): + with T.sblock("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) T.writes(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4_i, v5_i]) conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(256): - with T.block("conv2d_nhwc_reindex_shared.dyn"): + with T.sblock("conv2d_nhwc_reindex_shared.dyn"): v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2]) v3 = T.axis.spatial(1, 0) v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ @@ -575,7 +575,7 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer((128, 128)) C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope=shared_scope) C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope="wmma.accumulator") @@ -588,74 +588,74 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): for ax0_ax1_fused in range(1024): - with T.block("A_reindex_shared"): + with T.sblock("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) A_reindex_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in range(1024): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) B_reindex_shared[v0, v1] = B[v0, v1] for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("A_reindex_shared_wmma.matrix_a_o"): + with T.sblock("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): + with T.sblock("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("B_reindex_shared_wmma.matrix_b_o"): + with T.sblock("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): + with T.sblock("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 2, 2): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init]) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(2): for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 2): - with T.block("C_reindex_shared_wmma.accumulator_o"): + with T.sblock("C_reindex_shared_wmma.accumulator_o"): v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) v2 = T.axis.spatial(2, ax2 + ax2_1) @@ -664,15 +664,15 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v5_o = T.axis.spatial(1, 0) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + T.sblock_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) for ax4, ax5 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): + with T.sblock("C_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): - with T.block("C_reindex_shared"): + with T.sblock("C_reindex_shared"): v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) v2 = T.axis.spatial(2, ax2) @@ -681,10 +681,10 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5] for i0, i1 in T.grid(128, 128): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(C[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) @@ -765,74 +765,74 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_0_0 in range(1): for ax0_ax1_fused in range(4096): - with T.block("A_reindex_shared"): + with T.sblock("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0)) for ax0_ax1_fused in range(4096): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0)) for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("A_reindex_shared_wmma.matrix_a_o"): + with T.sblock("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): + with T.sblock("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("B_reindex_shared_wmma.matrix_b_o"): + with T.sblock("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): + with T.sblock("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(2): for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): + with T.sblock("C_reindex_shared_wmma.accumulator_o"): v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) v2 = T.axis.spatial(2, ax2 + ax2_1) @@ -841,15 +841,15 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 v5_o = T.axis.spatial(1, 0) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): + with T.sblock("C_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): - with T.block("C_reindex_shared"): + with T.sblock("C_reindex_shared"): v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) v2 = T.axis.spatial(2, ax2) @@ -859,7 +859,7 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on @@ -902,7 +902,7 @@ def test_conv_1x1(): @T.prim_func def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator") PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") @@ -914,53 +914,53 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax4_0_0 in range(2): for ax0_ax1_fused in range(8192): - with T.block("PadInput_reindex_shared"): + with T.sblock("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax0_ax1_fused // 32) v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) T.reads(inputs[0, v0 // 16, v0 % 16, v1]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1] for ax0_ax1_ax2_ax3_fused in range(2048): - with T.block("weight_reindex_shared"): + with T.sblock("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused // 64) v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for ax4_0_1 in range(1): for ax0_0, ax1_0 in T.grid(8, 2): - with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + with T.sblock("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax0_0) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0) T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("PadInput_reindex_shared_wmma.matrix_a"): + with T.sblock("PadInput_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4): - with T.block("weight_reindex_shared_wmma.matrix_b_o"): + with T.sblock("weight_reindex_shared_wmma.matrix_b_o"): v0_o, v1_o = T.axis.remap("SS", [ax0, ax1]) v2_o = T.axis.spatial(4, ax4_0_0 * 2 + ax2_0) v3_o = T.axis.spatial(4, ax3_0) T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("weight_reindex_shared_wmma.matrix_b"): + with T.sblock("weight_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(8, 1, 2, 1, 4): - with T.block("conv2d_nhwc_o"): + with T.sblock("conv2d_nhwc_o"): v0_o = T.axis.spatial(1, 0) v1_o = T.axis.spatial(1, 0) v2_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4) @@ -968,25 +968,25 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init]) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init] = T.float32(0) for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) for ax2 in range(8): for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 4): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(2, ax0_ax1_fused) v1_o = T.axis.spatial(1, 0) v2_o = T.axis.spatial(8, ax2 + ax2_1) @@ -995,15 +995,15 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v5_o = T.axis.spatial(1, 0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(2048): - with T.block("conv2d_nhwc_reindex_shared"): + with T.sblock("conv2d_nhwc_reindex_shared"): v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(8, ax2) @@ -1012,7 +1012,7 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on @@ -1060,7 +1060,7 @@ def test_padded_conv(): @T.prim_func def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator") PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared") @@ -1072,74 +1072,74 @@ def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buf for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): for ax2_0_0 in range(10): for ax0_ax1_fused in range(28672): - with T.block("PadInput_reindex_pad_shared"): + with T.sblock("PadInput_reindex_pad_shared"): v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16) v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16) T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3]) T.writes(PadInput_reindex_pad_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0)) for ax0_ax1_fused in range(512): - with T.block("weight_reindex_pad_shared"): + with T.sblock("weight_reindex_pad_shared"): v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32) v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1]) T.writes(weight_reindex_pad_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0)) for ax2_0_1 in range(1): for ax0_0, ax1_0 in T.grid(14, 1): - with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"): + with T.sblock("PadInput_reindex_pad_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0) v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0) T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"): + with T.sblock("PadInput_reindex_pad_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"): + with T.sblock("weight_reindex_pad_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0) v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0) T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("weight_reindex_pad_shared_wmma.matrix_b"): + with T.sblock("weight_reindex_pad_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1): - with T.block("conv2d_nhwc_o"): + with T.sblock("conv2d_nhwc_o"): v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2) T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init]) conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(14): for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 2): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) v2_o = T.axis.spatial(14, ax2 + ax2_1) @@ -1148,15 +1148,15 @@ def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buf v5_o = T.axis.spatial(1, 0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(4096): - with T.block("conv2d_nhwc_reindex_shared"): + with T.sblock("conv2d_nhwc_reindex_shared"): v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512) v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) v2 = T.axis.spatial(14, ax2) @@ -1165,7 +1165,7 @@ def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buf v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on @@ -1212,7 +1212,7 @@ def test_padded_matmul_single_padded_input(): @T.prim_func def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): C_reindex_pad_shared = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="shared") C_reindex_pad_shared_wmma_accumulator = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="wmma.accumulator") A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared") @@ -1224,74 +1224,74 @@ def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): for ax2_0_0 in range(32): for ax0_ax1_fused in range(65536): - with T.block("A_reindex_pad_shared"): + with T.sblock("A_reindex_pad_shared"): v0 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_fused // 128) v1 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_pad_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) A_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 1023, A[v0, v1], T.float16(0.0)) for ax0_ax1_fused in range(8192): - with T.block("B_reindex_shared"): + with T.sblock("B_reindex_shared"): v0 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused // 64) v1 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused % 16 * 64 + ax0_ax1_fused % 64) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_shared[v0, v1] = B[v0, v1] for ax2_0_1 in range(8): for ax0_0, ax1_0 in T.grid(8, 1): - with T.block("A_reindex_pad_shared_wmma.matrix_a_o"): + with T.sblock("A_reindex_pad_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0) v1_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax1_0) T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_pad_shared_wmma.matrix_a"): + with T.sblock("A_reindex_pad_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("B_reindex_shared_wmma.matrix_b_o"): + with T.sblock("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0) T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): + with T.sblock("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 4, 2): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0_3 * 4 + ax0_0_4) v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3 * 2 + ax1_0_4) v2_o = T.axis.reduce(256, ax2_0_0 * 8 + ax2_0_1 + ax2_0_2) T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init]) C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0.0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] = C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(8): for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 2): - with T.block("C_reindex_pad_shared_wmma.accumulator_o"): + with T.sblock("C_reindex_pad_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_fused // 2) v1_o = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_fused % 2) v2_o = T.axis.spatial(8, ax2 + ax2_1) @@ -1300,15 +1300,15 @@ def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: v5_o = T.axis.spatial(1, 0) T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): - with T.block("C_reindex_pad_shared_wmma.accumulator"): + with T.sblock("C_reindex_pad_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(4096): - with T.block("C_reindex_pad_shared"): + with T.sblock("C_reindex_pad_shared"): v0 = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024) v1 = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_ax3_ax4_ax5_fused % 1024 // 512) v2 = T.axis.spatial(8, ax2) @@ -1318,7 +1318,7 @@ def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: T.where(ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_ax3_ax4_ax5_fused // 1024 * 128 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 1023) T.reads(C_reindex_pad_shared[v0, v1, v2, v3, v4, v5]) T.writes(C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32] = C_reindex_pad_shared[v0, v1, v2, v3, v4, v5] # fmt: on @@ -1360,7 +1360,7 @@ def test_padded_matmul_no_padded_output(): @T.prim_func def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): C_reindex_shared = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="shared") C_reindex_shared_wmma_accumulator = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="wmma.accumulator") A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared") @@ -1372,74 +1372,74 @@ def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T. for ax0_0_2_ax1_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): for ax2_0_0 in range(128): for ax0_ax1_fused in range(4096): - with T.block("A_reindex_pad_shared"): + with T.sblock("A_reindex_pad_shared"): v0 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused // 16 * 256 + ax0_0_1_ax1_0_1_fused * 128 + ax0_ax1_fused // 32) v1 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused % 32) T.reads(A[v0, v1]) T.writes(A_reindex_pad_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) A_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 4095, A[v0, v1], T.float16(0.0)) for ax0_ax1_fused in range(2048): - with T.block("B_reindex_pad_shared"): + with T.sblock("B_reindex_pad_shared"): v0 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused // 64) v1 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused % 16 * 64 + ax0_ax1_fused % 64) T.reads(B[v0, v1]) T.writes(B_reindex_pad_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 4095, B[v0, v1], T.float16(0.0)) for ax2_0_1 in range(2): for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("A_reindex_pad_shared_wmma.matrix_a_o"): + with T.sblock("A_reindex_pad_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0) v1_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax1_0) T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_pad_shared_wmma.matrix_a"): + with T.sblock("A_reindex_pad_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0, ax1_0 in T.grid(1, 4): - with T.block("B_reindex_pad_shared_wmma.matrix_b_o"): + with T.sblock("B_reindex_pad_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0) T.reads(B_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_pad_shared_wmma.matrix_b"): + with T.sblock("B_reindex_pad_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 1, 4): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0_3 + ax0_0_4) v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0_3 * 4 + ax1_0_4) v2_o = T.axis.reduce(256, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0.0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) for ax2 in range(2): for ax0_ax1_fused in T.thread_binding(4, thread="threadIdx.y"): for ax2_1, ax3 in T.grid(1, 4): - with T.block("C_reindex_shared_wmma.accumulator_o"): + with T.sblock("C_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_fused) v1_o = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16) v2_o = T.axis.spatial(2, ax2 + ax2_1) @@ -1448,15 +1448,15 @@ def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T. v5_o = T.axis.spatial(1, 0) T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): + with T.sblock("C_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(4096): - with T.block("C_reindex_shared"): + with T.sblock("C_reindex_shared"): v0 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024) v1 = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16) v2 = T.axis.spatial(2, ax2) @@ -1465,7 +1465,7 @@ def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T. v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = C_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index 2a0a67d4c786..8bfe27dbfcea 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -36,7 +36,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -51,12 +51,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) - T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2}) + T.sblock_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2}) for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -90,127 +90,127 @@ def main(placeholder: T.Buffer((1, 13, 13, 3, 85), "float32"), placeholder_1: T. T_concat = T.alloc_buffer([10647, 80], dtype="float32") T_transpose = T.alloc_buffer([80, 10647], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]) T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4]) T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)] for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1): - with T.block("T_sigmoid"): + with T.sblock("T_sigmoid"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4]) T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4]) T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80): - with T.block("T_strided_slice_with_axes_1"): + with T.sblock("T_strided_slice_with_axes_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]) T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4]) T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)] for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80): - with T.block("T_sigmoid_1"): + with T.sblock("T_sigmoid_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]) T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_sigmoid[ax0, ax1, ax2, ax3, 0], T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) T_multiply[ax0, ax1, ax2, ax3, ax4] = T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] for i0, i1 in T.grid(8112, 80): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]) T.writes(T_reshape[ax0, ax1]) T_reshape[ax0, ax1] = T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80] for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1): - with T.block("T_strided_slice_with_axes_2"): + with T.sblock("T_strided_slice_with_axes_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]) T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4]) T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)] for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1): - with T.block("T_sigmoid_2"): + with T.sblock("T_sigmoid_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4]) T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80): - with T.block("T_strided_slice_with_axes_3"): + with T.sblock("T_strided_slice_with_axes_3"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]) T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4]) T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)] for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80): - with T.block("T_sigmoid_3"): + with T.sblock("T_sigmoid_3"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4]) T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]) T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_sigmoid_2[ax0, ax1, ax2, ax3, 0], T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]) T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) T_multiply_1[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_2[ax0, ax1, ax2, ax3, 0] * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] for i0, i1 in T.grid(2028, 80): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]) T.writes(T_reshape_1[ax0, ax1]) T_reshape_1[ax0, ax1] = T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80] for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1): - with T.block("T_strided_slice_with_axes_4"): + with T.sblock("T_strided_slice_with_axes_4"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]) T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4]) T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)] for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1): - with T.block("T_sigmoid_4"): + with T.sblock("T_sigmoid_4"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4]) T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4]) T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80): - with T.block("T_strided_slice_with_axes_5"): + with T.sblock("T_strided_slice_with_axes_5"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]) T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4]) T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)] for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80): - with T.block("T_sigmoid_5"): + with T.sblock("T_sigmoid_5"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4]) T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]) T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_sigmoid_4[ax0, ax1, ax2, ax3, 0], T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]) T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4]) T_multiply_2[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_4[ax0, ax1, ax2, ax3, 0] * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] for i0, i1 in T.grid(507, 80): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]) T.writes(T_reshape_2[ax0, ax1]) T_reshape_2[ax0, ax1] = T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80] for i0, i1 in T.grid(10647, 80): - with T.block("T_concat"): + with T.sblock("T_concat"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_reshape[ax0 - 2535, ax1], T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1]) T.writes(T_concat[ax0, ax1]) T_concat[ax0, ax1] = T.if_then_else(2535 <= ax0, T_reshape[ax0 - 2535, ax1], T.if_then_else(507 <= ax0, T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1], dtype="float32"), dtype="float32") for i0, i1 in T.grid(80, 10647): - with T.block("T_transpose"): + with T.sblock("T_transpose"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_concat[ax1, ax0]) T.writes(T_transpose[ax0, ax1]) T_transpose[ax0, ax1] = T_concat[ax1, ax0] for i0, i1, i2 in T.grid(1, 80, 10647): - with T.block("T_expand_dims"): + with T.sblock("T_expand_dims"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_transpose[ax1, ax2]) T.writes(T_expand_dims[ax0, ax1, ax2]) @@ -231,10 +231,10 @@ def Matmul_0( # function attr dict T.func_attr({"global_symbol": "main"}) # body - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr( + T.sblock_attr( { "meta_schedule.parallel": 512, "meta_schedule.unroll_explicit": 16, @@ -242,7 +242,7 @@ def Matmul_0( } ) for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(A[vi, vk], B[vk, vj]) T.writes(C[vi, vj]) diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py index 2e912af18a6a..f474cff02760 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py @@ -38,13 +38,13 @@ def main(a: T.handle, b: T.handle) -> None: A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") # body for i, j, k in T.grid(2048, 2048, 2048): - with T.block("move"): + with T.sblock("move"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([A_cached[vi, vj, vk]]) A_cached[vi, vj, vk] = A[vi, vj, vk] for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) vk = T.axis.spatial(2048, k0 * 32 + k1) @@ -65,11 +65,11 @@ def add_0( # function attr dict T.func_attr({"global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") for i0, j0, i1, j1, k0, i2 in T.grid(128, 64, 4, 4, 64, 4): for ax0, ax1, ax2 in T.grid(1, 8, 32): - with T.block("move"): + with T.sblock("move"): vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2 + ax0) vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + ax1) vk = T.axis.spatial(2048, k0 * 32 + ax2) @@ -77,7 +77,7 @@ def add_0( T.writes(A_cached[vi, vj, vk]) A_cached[vi, vj, vk] = A[vi, vj, vk] for j2, k1 in T.grid(8, 32): - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) vk = T.axis.spatial(2048, k0 * 32 + k1) diff --git a/tests/python/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/meta_schedule/test_meta_schedule_search_strategy.py index 04a6e187a6a7..b2a247a37043 100644 --- a/tests/python/meta_schedule/test_meta_schedule_search_strategy.py +++ b/tests/python/meta_schedule/test_meta_schedule_search_strategy.py @@ -41,7 +41,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore B = T.match_buffer(b, (32, 32), "float32") C = T.match_buffer(c, (32, 32), "float32") for i, j, k in T.grid(32, 32, 32): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 # type: ignore @@ -62,7 +62,7 @@ def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> def _schedule_matmul(sch: Schedule): - block = sch.get_block("matmul") + block = sch.get_sblock("matmul") i, j, k = sch.get_loops(block=block) i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4)) j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4)) @@ -123,7 +123,7 @@ def test_meta_schedule_replay_func( def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name def _schedule_matmul_small(sch: Schedule): - block = sch.get_block("matmul") + block = sch.get_sblock("matmul") _, j, k = sch.get_loops(block=block) _, _ = sch.split(j, sch.sample_perfect_tile(j, n=2)) _, _ = sch.split(k, sch.sample_perfect_tile(k, n=2)) diff --git a/tests/python/meta_schedule/test_meta_schedule_space_cpu.py b/tests/python/meta_schedule/test_meta_schedule_space_cpu.py index 6935f62e8b27..9ebf873e81a2 100644 --- a/tests/python/meta_schedule/test_meta_schedule_space_cpu.py +++ b/tests/python/meta_schedule/test_meta_schedule_space_cpu.py @@ -44,21 +44,21 @@ def test_cpu_c1d(): @T.prim_func def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) + T.sblock_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64}) PadInput = T.alloc_buffer((1, 258, 64), dtype="float32") conv1d_nlc_global = T.alloc_buffer((1, 128, 128), dtype="float32") for i0, i1, i2 in T.grid(1, 258, 64): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(inputs[v_i0, v_i1 - 1, v_i2]) T.writes(PadInput[v_i0, v_i1, v_i2]) PadInput[v_i0, v_i1, v_i2] = T.if_then_else(1 <= v_i1 and v_i1 < 257, inputs[v_i0, v_i1 - 1, v_i2], T.float32(0)) for n_0, l_0, co_0, n_1, l_1, co_1 in T.grid(1, 1, 2, 1, 1, 8): for rl_0, rc_0, n_2, l_2, co_2, rl_1, rc_1, n_3, l_3, co_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): - with T.block("conv1d_nlc"): + with T.sblock("conv1d_nlc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_l = T.axis.spatial(128, l_0 * 128 + l_1 * 128 + l_2 * 2 + l_3) v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + co_2 + co_3) @@ -66,12 +66,12 @@ def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 v_rc = T.axis.reduce(64, rc_0 + rc_1) T.reads(PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc], weight[v_rl, v_rc, v_co]) T.writes(conv1d_nlc_global[v_n, v_l, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv1d_nlc_global[v_n, v_l, v_co] = T.float32(0) conv1d_nlc_global[v_n, v_l, v_co] = conv1d_nlc_global[v_n, v_l, v_co] + PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc] * weight[v_rl, v_rc, v_co] for ax0, ax1, ax2 in T.grid(1, 128, 8): - with T.block("conv1d_nlc_global"): + with T.sblock("conv1d_nlc_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + ax2) T.reads(conv1d_nlc_global[v0, v1, v2]) @@ -80,16 +80,16 @@ def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 @T.prim_func def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 258, 64)) conv1d_nlc_global = T.alloc_buffer((1, 128, 128)) for n_0, l_0, co_0 in T.grid(1, 1, 2): for n_1, l_1, co_1 in T.grid(1, 1, 8): for ax0, ax1, ax2 in T.grid(1, 257, 64): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(258, ax1) v_i2 = T.axis.spatial(64, ax2) @@ -97,7 +97,7 @@ def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 T.writes(PadInput[v_i0, v_i1, v_i2]) PadInput[v_i0, v_i1, v_i2] = T.if_then_else(1 <= v_i1 and v_i1 < 257, inputs[v_i0, v_i1 - 1, v_i2], T.float32(0)) for rl_0, rc_0, n_2, l_2, co_2, rl_1, rc_1, n_3, l_3, co_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1): - with T.block("conv1d_nlc"): + with T.sblock("conv1d_nlc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_l = T.axis.spatial(128, l_0 * 128 + l_1 * 128 + l_2 * 2 + l_3) v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + co_2 + co_3) @@ -105,12 +105,12 @@ def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 v_rc = T.axis.reduce(64, rc_0 + rc_1) T.reads(PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc], weight[v_rl, v_rc, v_co]) T.writes(conv1d_nlc_global[v_n, v_l, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv1d_nlc_global[v_n, v_l, v_co] = T.float32(0) conv1d_nlc_global[v_n, v_l, v_co] = conv1d_nlc_global[v_n, v_l, v_co] + PadInput[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc] * weight[v_rl, v_rc, v_co] for ax0, ax1, ax2 in T.grid(1, 128, 64): - with T.block("conv1d_nlc_global"): + with T.sblock("conv1d_nlc_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(128, co_0 * 64 + ax2) T.reads(conv1d_nlc_global[v0, v1, v2]) @@ -121,12 +121,12 @@ def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 def c1d_2(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) for n_0, l_0, co_0, n_1, l_1, co_1, rl_0, rc_0, n_2, l_2, co_2, rl_1, rc_1, n_3, l_3, co_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1): - with T.block("conv1d_nlc"): + with T.sblock("conv1d_nlc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_l = T.axis.spatial(128, l_0 * 128 + l_1 * 128 + l_2 * 2 + l_3) v_co = T.axis.spatial(128, co_0 * 64 + co_1 * 8 + co_2 + co_3) @@ -134,7 +134,7 @@ def c1d_2(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 v_rc = T.axis.reduce(64, rc_0 + rc_1) T.reads(inputs[v_n, v_l * 2 + v_rl - 1, v_co // 128 * 64 + v_rc], weight[v_rl, v_rc, v_co]) T.writes(conv1d_nlc[v_n, v_l, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv1d_nlc[v_n, v_l, v_co] = T.float32(0) conv1d_nlc[v_n, v_l, v_co] = conv1d_nlc[v_n, v_l, v_co] + T.if_then_else(1 <= v_l * 2 + v_rl and v_l * 2 + v_rl < 257, inputs[v_n, v_l * 2 + v_rl - 1, v_co // 128 * 64 + v_rc], T.float32(0)) * weight[v_rl, v_rc, v_co] @@ -183,15 +183,15 @@ def test_cpu_c2d(): @T.prim_func def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 230, 230, 3)) conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 64)) for n_0, h_0, w_0, co_0, n_1, h_1, w_1 in T.grid(1, 7, 4, 2, 1, 1, 28): for ax0, ax1, ax2, ax3 in T.grid(1, 37, 7, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, h_0 * 32 + ax1) v_i2 = T.axis.spatial(230, w_0 * 56 + w_1 * 2 + ax2) @@ -201,7 +201,7 @@ def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for co_1 in range(8): for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(112, h_0 * 16 + h_1 * 16 + h_2 * 8 + h_3) v_w = T.axis.spatial(112, w_0 * 28 + w_1 + w_2 + w_3) @@ -211,12 +211,12 @@ def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 16, 1, 4): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, h_0 * 16 + ax1) v2 = T.axis.spatial(112, w_0 * 28 + w_1 + ax2) @@ -227,21 +227,21 @@ def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, @T.prim_func def c2d_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 230, 230, 3)) conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 64)) for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for n_0, h_0, w_0, co_0 in T.grid(1, 7, 4, 2): for n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(112, h_0 * 16 + h_1 * 16 + h_2 * 8 + h_3) v_w = T.axis.spatial(112, w_0 * 28 + w_1 + w_2 + w_3) @@ -251,12 +251,12 @@ def c2d_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 32): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, h_0 * 16 + ax1) v2 = T.axis.spatial(112, w_0 * 28 + ax2) @@ -267,14 +267,14 @@ def c2d_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, @T.prim_func def c2d_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 230, 230, 3)) for n_0, h_0 in T.grid(1, 7): for ax0, ax1, ax2, ax3 in T.grid(1, 37, 229, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, h_0 * 32 + ax1) v_i2 = T.axis.spatial(230, ax2) @@ -283,7 +283,7 @@ def c2d_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(4, 2, 1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(112, h_0 * 16 + h_1 * 16 + h_2 * 8 + h_3) v_w = T.axis.spatial(112, w_0 * 28 + w_1 + w_2 + w_3) @@ -293,7 +293,7 @@ def c2d_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc[v_n, v_h, v_w, v_co] = conv2d_nhwc[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] @@ -348,15 +348,15 @@ def test_cpu_c3d(): @T.prim_func def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 22, 230, 230, 3)) conv3d_ndhwc_global = T.alloc_buffer((1, 8, 112, 112, 64)) for n_0, d_0, h_0, w_0, co_0 in T.grid(1, 2, 4, 1, 2): for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 13, 61, 229, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(22, d_0 * 8 + ax1) v_i2 = T.axis.spatial(230, h_0 * 56 + ax2) @@ -367,7 +367,7 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 PadInput[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(3 <= v_i1 and v_i1 < 19 and 3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4], T.float32(0)) for n_1, d_1, h_1, w_1, co_1 in T.grid(1, 4, 4, 14, 1): for rd_0, rh_0, rw_0, rc_0, n_2, d_2, h_2, w_2, co_2, rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3 in T.grid(1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): - with T.block("conv3d_ndhwc"): + with T.sblock("conv3d_ndhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_d = T.axis.spatial(8, d_0 * 4 + d_1 + d_2 + d_3) v_h = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + h_2 * 7 + h_3) @@ -379,12 +379,12 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v_rc = T.axis.reduce(3, rc_0 + rc_1) T.reads(PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rd, v_rh, v_rw, v_rc, v_co]) T.writes(conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = T.float32(0) conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] + PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rd, v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 7, 8, 32): - with T.block("conv3d_ndhwc_global"): + with T.sblock("conv3d_ndhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, d_0 * 4 + d_1 + ax1) v2 = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + ax2) @@ -396,16 +396,16 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 @T.prim_func def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 22, 230, 230, 3)) conv3d_ndhwc_global = T.alloc_buffer((1, 8, 112, 112, 64)) for n_0, d_0, h_0, w_0, co_0 in T.grid(1, 2, 4, 1, 2): for n_1, d_1, h_1, w_1 in T.grid(1, 4, 4, 14): for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 7, 19, 21, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(22, d_0 * 8 + d_1 * 2 + ax1) v_i2 = T.axis.spatial(230, h_0 * 56 + h_1 * 14 + ax2) @@ -415,7 +415,7 @@ def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4]) PadInput[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(3 <= v_i1 and v_i1 < 19 and 3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4], T.float32(0)) for co_1, rd_0, rh_0, rw_0, rc_0, n_2, d_2, h_2, w_2, co_2, rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): - with T.block("conv3d_ndhwc"): + with T.sblock("conv3d_ndhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_d = T.axis.spatial(8, d_0 * 4 + d_1 + d_2 + d_3) v_h = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + h_2 * 7 + h_3) @@ -427,12 +427,12 @@ def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v_rc = T.axis.reduce(3, rc_0 + rc_1) T.reads(PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rd, v_rh, v_rw, v_rc, v_co]) T.writes(conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = T.float32(0) conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc_global[v_n, v_d, v_h, v_w, v_co] + PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rd, v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 4, 28, 112, 32): - with T.block("conv3d_ndhwc_global"): + with T.sblock("conv3d_ndhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, d_0 * 4 + ax1) v2 = T.axis.spatial(112, h_0 * 28 + ax2) @@ -444,14 +444,14 @@ def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 @T.prim_func def c3d_2(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 22, 230, 230, 3)) for n_0, d_0, h_0, w_0, co_0, n_1, d_1, h_1, w_1 in T.grid(1, 2, 4, 1, 2, 1, 4, 4, 14): for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 7, 19, 21, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(22, d_0 * 8 + d_1 * 2 + ax1) v_i2 = T.axis.spatial(230, h_0 * 56 + h_1 * 14 + ax2) @@ -461,7 +461,7 @@ def c3d_2(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4]) PadInput[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(3 <= v_i1 and v_i1 < 19 and 3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3 - 3, v_i4], T.float32(0)) for co_1, rd_0, rh_0, rw_0, rc_0, n_2, d_2, h_2, w_2, co_2, rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1): - with T.block("conv3d_ndhwc"): + with T.sblock("conv3d_ndhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_d = T.axis.spatial(8, d_0 * 4 + d_1 + d_2 + d_3) v_h = T.axis.spatial(112, h_0 * 28 + h_1 * 7 + h_2 * 7 + h_3) @@ -473,7 +473,7 @@ def c3d_2(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v_rc = T.axis.reduce(3, rc_0 + rc_1) T.reads(PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight[v_rd, v_rh, v_rw, v_rc, v_co]) T.writes(conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co] = T.float32(0) conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc[v_n, v_d, v_h, v_w, v_co] + PadInput[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight[v_rd, v_rh, v_rw, v_rc, v_co] @@ -534,15 +534,15 @@ def test_cpu_cap(): @T.prim_func def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 18, 18, 4, 4, 32)) conv2d_capsule_nhwijc_global = T.alloc_buffer((1, 8, 8, 4, 4, 32)) for n_0, h_0, w_0, cap_i_0, cap_j_0, co_0, n_1, h_1 in T.grid(1, 2, 1, 1, 1, 1, 1, 4): for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 17, 4, 4, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(18, h_0 * 8 + h_1 * 2 + ax1) v_i2 = T.axis.spatial(18, ax2) @@ -552,7 +552,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5], T.float32(0)) for w_1, cap_i_1, cap_j_1, co_1 in T.grid(4, 1, 4, 2): for rh_0, rw_0, cap_k_0, rc_0, n_2, h_2, w_2, cap_i_2, cap_j_2, co_2, rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): - with T.block("conv2d_capsule_nhwijc"): + with T.sblock("conv2d_capsule_nhwijc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(8, h_0 * 4 + h_1 + h_2 + h_3) v_w = T.axis.spatial(8, w_0 * 8 + w_1 * 2 + w_2 + w_3) @@ -565,12 +565,12 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( v_rc = T.axis.reduce(32, rc_0 * 32 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) T.writes(conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 2, 4, 1, 16): - with T.block("conv2d_capsule_nhwijc_global"): + with T.sblock("conv2d_capsule_nhwijc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, h_0 * 4 + h_1 + ax1) v2 = T.axis.spatial(8, w_1 * 2 + ax2) @@ -583,16 +583,16 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( @T.prim_func def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 18, 18, 4, 4, 32)) conv2d_capsule_nhwijc_global = T.alloc_buffer((1, 8, 8, 4, 4, 32)) for n_0, h_0, w_0, cap_i_0, cap_j_0, co_0 in T.grid(1, 2, 1, 1, 1, 1): for n_1, h_1, w_1, cap_i_1, cap_j_1, co_1 in T.grid(1, 4, 4, 1, 4, 2): for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 3, 5, 4, 4, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(18, h_0 * 8 + h_1 * 2 + ax1) v_i2 = T.axis.spatial(18, w_1 * 4 + ax2) @@ -601,7 +601,7 @@ def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5]) PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5], T.float32(0)) for rh_0, rw_0, cap_k_0, rc_0, n_2, h_2, w_2, cap_i_2, cap_j_2, co_2, rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): - with T.block("conv2d_capsule_nhwijc"): + with T.sblock("conv2d_capsule_nhwijc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(8, h_0 * 4 + h_1 + h_2 + h_3) v_w = T.axis.spatial(8, w_0 * 8 + w_1 * 2 + w_2 + w_3) @@ -614,12 +614,12 @@ def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( v_rc = T.axis.reduce(32, rc_0 * 32 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) T.writes(conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc_global[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 8, 4, 4, 32): - with T.block("conv2d_capsule_nhwijc_global"): + with T.sblock("conv2d_capsule_nhwijc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, h_0 * 4 + ax1) v2, v3, v4, v5 = T.axis.remap("SSSS", [ax2, ax3, ax4, ax5]) @@ -629,19 +629,19 @@ def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( @T.prim_func def cap_2(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 18, 18, 4, 4, 32)) for i0, i1, i2, i3, i4, i5 in T.grid(1, 18, 18, 4, 4, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3, v_i4, v_i5 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5]) PadInput[v_i0, v_i1, v_i2, v_i3, v_i4, v_i5] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3, v_i4, v_i5], T.float32(0)) for n_0, h_0, w_0, cap_i_0, cap_j_0, co_0, n_1, h_1, w_1, cap_i_1, cap_j_1, co_1, rh_0, rw_0, cap_k_0, rc_0, n_2, h_2, w_2, cap_i_2, cap_j_2, co_2, rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16): - with T.block("conv2d_capsule_nhwijc"): + with T.sblock("conv2d_capsule_nhwijc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(8, h_0 * 4 + h_1 + h_2 + h_3) v_w = T.axis.spatial(8, w_0 * 8 + w_1 * 2 + w_2 + w_3) @@ -654,7 +654,7 @@ def cap_2(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( v_rc = T.axis.reduce(32, rc_0 * 32 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) T.writes(conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] @@ -716,21 +716,21 @@ def test_cpu_dep(): @T.prim_func def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 114, 114, 32)) depth_conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 32)) for i0, i1, i2, i3 in T.grid(1, 114, 114, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 113 and 1 <= v_i2 and v_i2 < 113, placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for n_0, h_0, w_0, c_0, n_1, h_1, w_1, c_1 in T.grid(1, 1, 1, 1, 1, 4, 4, 8): for rh_0, rw_0, n_2, h_2, w_2, c_2, rh_1, rw_1, n_3, h_3, w_3, c_3 in T.grid(1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): - with T.block("depth_conv2d_nhwc"): + with T.sblock("depth_conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(112, h_0 * 112 + h_1 * 28 + h_2 * 14 + h_3) v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 28 + w_2 * 4 + w_3) @@ -739,12 +739,12 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) T.reads(PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_1[0, v_rh, v_rw, v_c]) T.writes(depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = T.float32(0) depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_1[0, v_rh, v_rw, v_c] for ax0, ax1, ax2, ax3 in T.grid(1, 28, 28, 4): - with T.block("depth_conv2d_nhwc_global"): + with T.sblock("depth_conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, h_1 * 28 + ax1) v2 = T.axis.spatial(112, w_1 * 28 + ax2) @@ -755,21 +755,21 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. @T.prim_func def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 114, 114, 32)) depth_conv2d_nhwc_global = T.alloc_buffer((1, 112, 112, 32)) for i0, i1, i2, i3 in T.grid(1, 114, 114, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 113 and 1 <= v_i2 and v_i2 < 113, placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for n_0, h_0, w_0, c_0 in T.grid(1, 1, 1, 1): for n_1, h_1, w_1, c_1, rh_0, rw_0, n_2, h_2, w_2, c_2, rh_1, rw_1, n_3, h_3, w_3, c_3 in T.grid(1, 4, 4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): - with T.block("depth_conv2d_nhwc"): + with T.sblock("depth_conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(112, h_0 * 112 + h_1 * 28 + h_2 * 14 + h_3) v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 28 + w_2 * 4 + w_3) @@ -778,12 +778,12 @@ def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) T.reads(PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_1[0, v_rh, v_rw, v_c]) T.writes(depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = T.float32(0) depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc_global[v_n, v_h, v_w, v_c] + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_1[0, v_rh, v_rw, v_c] for ax0, ax1, ax2, ax3 in T.grid(1, 112, 112, 32): - with T.block("depth_conv2d_nhwc_global"): + with T.sblock("depth_conv2d_nhwc_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(depth_conv2d_nhwc_global[v0, v1, v2, v3]) T.writes(depth_conv2d_nhwc[v0, v1, v2, v3]) @@ -791,14 +791,14 @@ def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. @T.prim_func def dep_2(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 114, 114, 32)) for n_0, h_0, w_0, c_0, n_1, h_1 in T.grid(1, 1, 1, 1, 1, 4): for ax0, ax1, ax2, ax3 in T.grid(1, 30, 114, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(114, h_1 * 28 + ax1) v_i2, v_i3 = T.axis.remap("SS", [ax2, ax3]) @@ -806,7 +806,7 @@ def dep_2(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 113 and 1 <= v_i2 and v_i2 < 113, placeholder[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for w_1, c_1, rh_0, rw_0, n_2, h_2, w_2, c_2, rh_1, rw_1, n_3, h_3, w_3, c_3 in T.grid(4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2): - with T.block("depth_conv2d_nhwc"): + with T.sblock("depth_conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(112, h_0 * 112 + h_1 * 28 + h_2 * 14 + h_3) v_w = T.axis.spatial(112, w_0 * 112 + w_1 * 28 + w_2 * 4 + w_3) @@ -815,7 +815,7 @@ def dep_2(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) T.reads(PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_1[0, v_rh, v_rw, v_c]) T.writes(depth_conv2d_nhwc[v_n, v_h, v_w, v_c]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): depth_conv2d_nhwc[v_n, v_h, v_w, v_c] = T.float32(0) depth_conv2d_nhwc[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc[v_n, v_h, v_w, v_c] + PadInput[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_1[0, v_rh, v_rw, v_c] @@ -865,15 +865,15 @@ def test_cpu_dil(): @T.prim_func def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 230, 230, 3)) conv2d_nhwc_global = T.alloc_buffer((1, 109, 109, 64)) for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1 in T.grid(1, 109, 1, 4, 1, 1, 1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 13, 229, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, h_0 * 2 + ax1) v_i2 = T.axis.spatial(230, ax2) @@ -882,7 +882,7 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(109, h_0 + h_1 + h_2 + h_3) v_w = T.axis.spatial(109, w_0 * 109 + w_1 * 109 + w_2 + w_3) @@ -892,12 +892,12 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 109, 8): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(109, h_0 + ax1) v2 = T.axis.spatial(109, ax2) @@ -908,16 +908,16 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, @T.prim_func def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 230, 230, 3)) conv2d_nhwc_global = T.alloc_buffer((1, 109, 109, 64)) for n_0, h_0, w_0, co_0 in T.grid(1, 109, 1, 4): for n_1, h_1, w_1, co_1, rh_0 in T.grid(1, 1, 1, 2, 7): for ax0, ax1, ax2, ax3 in T.grid(1, 1, 229, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, h_0 * 2 + rh_0 * 2 + ax1) v_i2 = T.axis.spatial(230, ax2) @@ -926,7 +926,7 @@ def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(109, h_0 + h_1 + h_2 + h_3) v_w = T.axis.spatial(109, w_0 * 109 + w_1 * 109 + w_2 + w_3) @@ -936,12 +936,12 @@ def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 109, 16): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(109, h_0 + ax1) v2 = T.axis.spatial(109, ax2) @@ -952,14 +952,14 @@ def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, @T.prim_func def dil_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 230, 230, 3)) for n_0, h_0 in T.grid(1, 109): for ax0, ax1, ax2, ax3 in T.grid(1, 13, 229, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, h_0 * 2 + ax1) v_i2 = T.axis.spatial(230, ax2) @@ -968,7 +968,7 @@ def dil_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, inputs[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 4, 1, 1, 1, 2, 7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(109, h_0 + h_1 + h_2 + h_3) v_w = T.axis.spatial(109, w_0 * 109 + w_1 * 109 + w_2 + w_3) @@ -978,7 +978,7 @@ def dil_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc[v_n, v_h, v_w, v_co] = conv2d_nhwc[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] @@ -1031,26 +1031,26 @@ def test_cpu_gmm(): @T.prim_func def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) Z_global = T.alloc_buffer((1, 128, 128)) for b_0, i_0, j_0, b_1, i_1, j_1 in T.grid(1, 4, 2, 1, 1, 8): for k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8): - with T.block("Z"): + with T.sblock("Z"): v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) v_i = T.axis.spatial(128, i_0 * 32 + i_1 * 32 + i_2 * 2 + i_3) v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 8 + j_3) v_k = T.axis.reduce(128, k_0 + k_1) T.reads(X[v_b, v_i, v_k], Y[v_b, v_k, v_j]) T.writes(Z_global[v_b, v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): Z_global[v_b, v_i, v_j] = T.float32(0) Z_global[v_b, v_i, v_j] = Z_global[v_b, v_i, v_j] + X[v_b, v_i, v_k] * Y[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 32, 8): - with T.block("Z_global"): + with T.sblock("Z_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, i_0 * 32 + ax1) v2 = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + ax2) @@ -1060,26 +1060,26 @@ def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "flo @T.prim_func def gmm_1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) Z_global = T.alloc_buffer((1, 128, 128)) for b_0, i_0, j_0 in T.grid(1, 4, 2): for b_1, i_1, j_1, k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8): - with T.block("Z"): + with T.sblock("Z"): v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) v_i = T.axis.spatial(128, i_0 * 32 + i_1 * 32 + i_2 * 2 + i_3) v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 8 + j_3) v_k = T.axis.reduce(128, k_0 + k_1) T.reads(X[v_b, v_i, v_k], Y[v_b, v_k, v_j]) T.writes(Z_global[v_b, v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): Z_global[v_b, v_i, v_j] = T.float32(0) Z_global[v_b, v_i, v_j] = Z_global[v_b, v_i, v_j] + X[v_b, v_i, v_k] * Y[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 32, 64): - with T.block("Z_global"): + with T.sblock("Z_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, i_0 * 32 + ax1) v2 = T.axis.spatial(128, j_0 * 64 + ax2) @@ -1089,19 +1089,19 @@ def gmm_1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "flo @T.prim_func def gmm_2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) for b_0, i_0, j_0, b_1, i_1, j_1, k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8): - with T.block("Z"): + with T.sblock("Z"): v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) v_i = T.axis.spatial(128, i_0 * 32 + i_1 * 32 + i_2 * 2 + i_3) v_j = T.axis.spatial(128, j_0 * 64 + j_1 * 8 + j_2 * 8 + j_3) v_k = T.axis.reduce(128, k_0 + k_1) T.reads(X[v_b, v_i, v_k], Y[v_b, v_k, v_j]) T.writes(Z[v_b, v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): Z[v_b, v_i, v_j] = T.float32(0) Z[v_b, v_i, v_j] = Z[v_b, v_i, v_j] + X[v_b, v_i, v_k] * Y[v_b, v_k, v_j] @@ -1142,15 +1142,15 @@ def test_cpu_grp(): @T.prim_func def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 58, 58, 64)) conv2d_nhwc_global = T.alloc_buffer((1, 28, 28, 128)) for n_0, h_0, w_0, co_0 in T.grid(1, 7, 1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 9, 57, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(58, h_0 * 8 + ax1) v_i2 = T.axis.spatial(58, ax2) @@ -1160,7 +1160,7 @@ def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for n_1, h_1, w_1, co_1 in T.grid(1, 4, 1, 1): for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(28, h_0 * 4 + h_1 + h_2 + h_3) v_w = T.axis.spatial(28, w_0 * 28 + w_1 * 28 + w_2 * 7 + w_3) @@ -1170,12 +1170,12 @@ def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, v_rc = T.axis.reduce(16, rc_0 * 2 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 28, 64): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(28, h_0 * 4 + h_1 + ax1) v2 = T.axis.spatial(28, ax2) @@ -1186,21 +1186,21 @@ def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, @T.prim_func def grp_1(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 58, 58, 64)) conv2d_nhwc_global = T.alloc_buffer((1, 28, 28, 128)) for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for n_0, h_0, w_0, co_0 in T.grid(1, 7, 1, 2): for n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 4, 1, 1, 1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(28, h_0 * 4 + h_1 + h_2 + h_3) v_w = T.axis.spatial(28, w_0 * 28 + w_1 * 28 + w_2 * 7 + w_3) @@ -1210,12 +1210,12 @@ def grp_1(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, v_rc = T.axis.reduce(16, rc_0 * 2 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_nhwc_global[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 4, 28, 64): - with T.block("conv2d_nhwc_global"): + with T.sblock("conv2d_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(28, h_0 * 4 + ax1) v2 = T.axis.spatial(28, ax2) @@ -1226,14 +1226,14 @@ def grp_1(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, @T.prim_func def grp_2(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 58, 58, 64)) for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0 in T.grid(1, 7, 1, 2, 1, 4, 1, 1, 1, 3): for ax0, ax1, ax2, ax3 in T.grid(1, 3, 55, 32): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(58, h_0 * 8 + h_1 * 2 + ax1) v_i2 = T.axis.spatial(58, rw_0 + ax2) @@ -1242,7 +1242,7 @@ def grp_2(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 57 and 1 <= v_i2 and v_i2 < 57, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(28, h_0 * 4 + h_1 + h_2 + h_3) v_w = T.axis.spatial(28, w_0 * 28 + w_1 * 28 + w_2 * 7 + w_3) @@ -1252,7 +1252,7 @@ def grp_2(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, v_rc = T.axis.reduce(16, rc_0 * 2 + rc_1) T.reads(PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc[v_n, v_h, v_w, v_co] = conv2d_nhwc[v_n, v_h, v_w, v_co] + PadInput[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight[v_rh, v_rw, v_rc, v_co] @@ -1305,21 +1305,21 @@ def test_cpu_t2d(): @T.prim_func def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 6, 6, 512)) conv2d_transpose_nhwc_global = T.alloc_buffer((1, 8, 8, 256)) for i0, i1, i2, i3 in T.grid(1, 6, 6, 512): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 5 and 1 <= v_i2 and v_i2 < 5, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1 in T.grid(1, 1, 2, 8, 1, 4, 1, 4): for rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): - with T.block("conv2d_transpose_nhwc"): + with T.sblock("conv2d_transpose_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(8, h_0 * 8 + h_1 * 2 + h_2 * 2 + h_3) v_w = T.axis.spatial(8, w_0 * 4 + w_1 * 4 + w_2 * 4 + w_3) @@ -1329,12 +1329,12 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1) T.reads(PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], weight[3 - v_rh, 3 - v_rw, v_rc, v_co]) T.writes(conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], T.float32(0)) * weight[3 - v_rh, 3 - v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 8): - with T.block("conv2d_transpose_nhwc_global"): + with T.sblock("conv2d_transpose_nhwc_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, h_1 * 2 + ax1) v2 = T.axis.spatial(8, w_0 * 4 + ax2) @@ -1345,15 +1345,15 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 @T.prim_func def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) PadInput = T.alloc_buffer((1, 6, 6, 512)) conv2d_transpose_nhwc_global = T.alloc_buffer((1, 8, 8, 256)) for n_0, h_0, w_0, co_0 in T.grid(1, 1, 2, 8): for ax0, ax1, ax2, ax3 in T.grid(1, 6, 4, 512): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1 = T.axis.remap("SS", [ax0, ax1]) v_i2 = T.axis.spatial(6, w_0 * 2 + ax2) v_i3 = T.axis.spatial(512, ax3) @@ -1361,7 +1361,7 @@ def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 5 and 1 <= v_i2 and v_i2 < 5, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float32(0)) for n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): - with T.block("conv2d_transpose_nhwc"): + with T.sblock("conv2d_transpose_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(8, h_0 * 8 + h_1 * 2 + h_2 * 2 + h_3) v_w = T.axis.spatial(8, w_0 * 4 + w_1 * 4 + w_2 * 4 + w_3) @@ -1371,12 +1371,12 @@ def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1) T.reads(PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], weight[3 - v_rh, 3 - v_rw, v_rc, v_co]) T.writes(conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc_global[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, PadInput[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], T.float32(0)) * weight[3 - v_rh, 3 - v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 8, 4, 32): - with T.block("conv2d_transpose_nhwc_global"): + with T.sblock("conv2d_transpose_nhwc_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(8, w_0 * 4 + ax2) v3 = T.axis.spatial(256, co_0 * 32 + ax3) @@ -1386,12 +1386,12 @@ def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 @T.prim_func def t2d_2(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) for n_0, h_0, w_0, co_0, n_1, h_1, w_1, co_1, rh_0, rw_0, rc_0, n_2, h_2, w_2, co_2, rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3 in T.grid(1, 1, 2, 8, 1, 4, 1, 4, 2, 2, 64, 1, 1, 1, 1, 2, 2, 8, 1, 2, 4, 8): - with T.block("conv2d_transpose_nhwc"): + with T.sblock("conv2d_transpose_nhwc"): v_n = T.axis.spatial(1, n_0 + n_1 + n_2 + n_3) v_h = T.axis.spatial(8, h_0 * 8 + h_1 * 2 + h_2 * 2 + h_3) v_w = T.axis.spatial(8, w_0 * 4 + w_1 * 4 + w_2 * 4 + w_3) @@ -1401,7 +1401,7 @@ def t2d_2(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1) T.reads(inputs[v_n, (v_h + v_rh) // 2 - 1, (v_w + v_rw) // 2 - 1, v_rc], weight[3 - v_rh, 3 - v_rw, v_rc, v_co]) T.writes(conv2d_transpose_nhwc[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): conv2d_transpose_nhwc[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_transpose_nhwc[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, T.if_then_else(1 <= (v_h + v_rh) // 2 and (v_h + v_rh) // 2 < 5 and 1 <= (v_w + v_rw) // 2 and (v_w + v_rw) // 2 < 5, inputs[v_n, (v_h + v_rh) // 2 - 1, (v_w + v_rw) // 2 - 1, v_rc], T.float32(0)), T.float32(0)) * weight[3 - v_rh, 3 - v_rw, v_rc, v_co] @@ -1455,14 +1455,14 @@ def test_cpu_nrm(): @T.prim_func def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) C = T.alloc_buffer((1,)) C_rf = T.alloc_buffer((1, 32768)) for b, i_j_fused_0, i_j_fused_1 in T.grid(1, 32768, 2): - with T.block("C_rf"): + with T.sblock("C_rf"): vi_j_fused_0, v_b, vi_j_fused_1 = T.axis.remap("SSR", [i_j_fused_0, b, i_j_fused_1]) T.reads(A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256]) T.writes(C_rf[v_b, vi_j_fused_0]) @@ -1470,7 +1470,7 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C_rf[v_b, vi_j_fused_0] = T.float32(0) C_rf[v_b, vi_j_fused_0] = C_rf[v_b, vi_j_fused_0] + A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] * A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] for b, i_j_fused_0 in T.grid(1, 32768): - with T.block("C"): + with T.sblock("C"): vi_j_fused_0, v_b = T.axis.remap("RS", [i_j_fused_0, b]) T.reads(C_rf[v_b, vi_j_fused_0]) T.writes(C[v_b]) @@ -1478,7 +1478,7 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C[v_b] = T.float32(0) C[v_b] = C[v_b] + C_rf[v_b, vi_j_fused_0] for b in range(1): - with T.block("D"): + with T.sblock("D"): v_b = T.axis.spatial(1, b) T.reads(C[v_b]) T.writes(D[v_b]) @@ -1486,14 +1486,14 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N @T.prim_func def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) C = T.alloc_buffer((1,)) C_rf = T.alloc_buffer((1, 2)) for b, i_j_fused_0, i_j_fused_1 in T.grid(1, 32768, 2): - with T.block("C_rf"): + with T.sblock("C_rf"): vi_j_fused_1, v_b, vi_j_fused_0 = T.axis.remap("SSR", [i_j_fused_1, b, i_j_fused_0]) T.reads(A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256]) T.writes(C_rf[v_b, vi_j_fused_1]) @@ -1501,7 +1501,7 @@ def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C_rf[v_b, vi_j_fused_1] = T.float32(0) C_rf[v_b, vi_j_fused_1] = C_rf[v_b, vi_j_fused_1] + A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] * A[v_b, (vi_j_fused_0 * 2 + vi_j_fused_1) // 256, (vi_j_fused_0 * 2 + vi_j_fused_1) % 256] for b, i_j_fused_1 in T.grid(1, 2): - with T.block("C"): + with T.sblock("C"): vi_j_fused_1, v_b = T.axis.remap("RS", [i_j_fused_1, b]) T.reads(C_rf[v_b, vi_j_fused_1]) T.writes(C[v_b]) @@ -1509,7 +1509,7 @@ def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C[v_b] = T.float32(0) C[v_b] = C[v_b] + C_rf[v_b, vi_j_fused_1] for b in range(1): - with T.block("D"): + with T.sblock("D"): v_b = T.axis.spatial(1, b) T.reads(C[v_b]) T.writes(D[v_b]) @@ -1517,13 +1517,13 @@ def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N @T.prim_func def nrm_2(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) C = T.alloc_buffer((1,)) for b, i, j in T.grid(1, 256, 256): - with T.block("C"): + with T.sblock("C"): v_b, v_i, v_j = T.axis.remap("SRR", [b, i, j]) T.reads(A[v_b, v_i, v_j]) T.writes(C[v_b]) @@ -1531,7 +1531,7 @@ def nrm_2(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C[v_b] = T.float32(0) C[v_b] = C[v_b] + A[v_b, v_i, v_j] * A[v_b, v_i, v_j] for b in range(1): - with T.block("D"): + with T.sblock("D"): v_b = T.axis.spatial(1, b) T.reads(C[v_b]) T.writes(D[v_b]) @@ -1568,16 +1568,16 @@ def test_cpu_sfm(): @T.prim_func def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum = T.alloc_buffer((256,)) T_softmax_expsum_rf = T.alloc_buffer((256, 16)) T_softmax_maxelem_rf = T.alloc_buffer((256, 4)) for i0, k_0, k_1 in T.grid(256, 4, 64): - with T.block("T_softmax_maxelem_rf"): + with T.sblock("T_softmax_maxelem_rf"): vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) T.reads(A[v_i0, vk_0 * 64 + vk_1]) T.writes(T_softmax_maxelem_rf[v_i0, vk_0]) @@ -1585,7 +1585,7 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf[v_i0, vk_0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_rf[v_i0, vk_0] = T.max(T_softmax_maxelem_rf[v_i0, vk_0], A[v_i0, vk_0 * 64 + vk_1]) for i0, k_0 in T.grid(256, 4): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) T.reads(T_softmax_maxelem_rf[v_i0, vk_0]) T.writes(T_softmax_maxelem[v_i0]) @@ -1593,7 +1593,7 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_0]) for i0, k_0, k_1 in T.grid(256, 16, 16): - with T.block("T_softmax_expsum_rf"): + with T.sblock("T_softmax_expsum_rf"): vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) T.reads(A[v_i0, vk_0 * 16 + vk_1], T_softmax_maxelem[v_i0]) T.writes(T_softmax_expsum_rf[v_i0, vk_0]) @@ -1602,7 +1602,7 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_rf[v_i0, vk_0] = T_softmax_expsum_rf[v_i0, vk_0] + T.exp(A[v_i0, vk_0 * 16 + vk_1] - T_softmax_maxelem[v_i0]) for i0, i1 in T.grid(256, 256): for ax0, ax1 in T.grid(16, 1): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): vk_0 = T.axis.reduce(16, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) T.reads(T_softmax_expsum_rf[v_i0, vk_0]) @@ -1610,19 +1610,19 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 with T.init(): T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_0] - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 16, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_exp = T.alloc_buffer((256, 256)) T_softmax_expsum = T.alloc_buffer((256,)) @@ -1630,7 +1630,7 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf = T.alloc_buffer((256, 64)) for i0 in range(256): for ax0, ax1, ax2 in T.grid(64, 1, 4): - with T.block("T_softmax_maxelem_rf"): + with T.sblock("T_softmax_maxelem_rf"): vk_1 = T.axis.spatial(64, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) vk_0 = T.axis.reduce(4, ax2) @@ -1641,7 +1641,7 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf[v_i0, vk_1] = T.max(T_softmax_maxelem_rf[v_i0, vk_1], A[v_i0, vk_0 * 64 + vk_1]) for i1 in range(256): for ax0, ax1 in T.grid(64, 1): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): vk_1 = T.axis.reduce(64, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) T.reads(T_softmax_maxelem_rf[v_i0, vk_1]) @@ -1649,13 +1649,13 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 with T.init(): T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_1]) - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) T.writes(T_softmax_exp[v_i0, v_i1]) T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) for i0, k_0, k_1 in T.grid(256, 16, 16): - with T.block("T_softmax_expsum_rf"): + with T.sblock("T_softmax_expsum_rf"): vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) T.reads(T_softmax_exp[v_i0, vk_0 * 16 + vk_1]) T.writes(T_softmax_expsum_rf[v_i0, vk_0]) @@ -1663,7 +1663,7 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_rf[v_i0, vk_0] = T.float32(0) T_softmax_expsum_rf[v_i0, vk_0] = T_softmax_expsum_rf[v_i0, vk_0] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] for i0, k_0 in T.grid(256, 16): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) T.reads(T_softmax_expsum_rf[v_i0, vk_0]) T.writes(T_softmax_expsum[v_i0]) @@ -1671,24 +1671,24 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_0] for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum = T.alloc_buffer((256,)) T_softmax_expsum_rf = T.alloc_buffer((256, 16)) for i0, k in T.grid(256, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_k = T.axis.remap("SR", [i0, k]) T.reads(A[v_i0, v_k]) T.writes(T_softmax_maxelem[v_i0]) @@ -1696,7 +1696,7 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for i0, k_0, k_1 in T.grid(256, 16, 16): - with T.block("T_softmax_expsum_rf"): + with T.sblock("T_softmax_expsum_rf"): vk_0, v_i0, vk_1 = T.axis.remap("SSR", [k_0, i0, k_1]) T.reads(A[v_i0, vk_0 * 16 + vk_1], T_softmax_maxelem[v_i0]) T.writes(T_softmax_expsum_rf[v_i0, vk_0]) @@ -1704,7 +1704,7 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_rf[v_i0, vk_0] = T.float32(0) T_softmax_expsum_rf[v_i0, vk_0] = T_softmax_expsum_rf[v_i0, vk_0] + T.exp(A[v_i0, vk_0 * 16 + vk_1] - T_softmax_maxelem[v_i0]) for i0, k_0 in T.grid(256, 16): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) T.reads(T_softmax_expsum_rf[v_i0, vk_0]) T.writes(T_softmax_expsum[v_i0]) @@ -1712,19 +1712,19 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_0] for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_exp = T.alloc_buffer((256, 256)) T_softmax_expsum = T.alloc_buffer((256,)) @@ -1732,7 +1732,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf = T.alloc_buffer((256, 256)) for i0, i1 in T.grid(256, 256): for ax0, ax1, ax2 in T.grid(256, 1, 1): - with T.block("T_softmax_maxelem_rf"): + with T.sblock("T_softmax_maxelem_rf"): vk_0 = T.axis.spatial(256, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) vk_1 = T.axis.reduce(1, ax2) @@ -1742,7 +1742,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf[v_i0, vk_0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_rf[v_i0, vk_0] = T.max(T_softmax_maxelem_rf[v_i0, vk_0], A[v_i0, vk_0 + vk_1]) for ax0, ax1 in T.grid(256, 1): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): vk_0 = T.axis.reduce(256, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) T.reads(T_softmax_maxelem_rf[v_i0, vk_0]) @@ -1751,7 +1751,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_0]) for ax0, ax1 in T.grid(1, 256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0 = T.axis.spatial(256, i0 + ax0) v_i1 = T.axis.spatial(256, ax1) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) @@ -1759,7 +1759,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) for ax0 in range(16): for ax0_1, ax1, ax2 in T.grid(1, 1, 16): - with T.block("T_softmax_expsum_rf"): + with T.sblock("T_softmax_expsum_rf"): vk_1 = T.axis.spatial(16, ax0 + ax0_1) v_i0 = T.axis.spatial(256, i0 + ax1) vk_0 = T.axis.reduce(16, ax2) @@ -1769,7 +1769,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_rf[v_i0, vk_1] = T.float32(0) T_softmax_expsum_rf[v_i0, vk_1] = T_softmax_expsum_rf[v_i0, vk_1] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] for ax1 in range(1): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): vk_1 = T.axis.reduce(16, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) T.reads(T_softmax_expsum_rf[v_i0, vk_1]) @@ -1777,19 +1777,19 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 with T.init(): T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_1] - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 0, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_exp = T.alloc_buffer((256, 256)) T_softmax_expsum = T.alloc_buffer((256,)) @@ -1797,7 +1797,7 @@ def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf = T.alloc_buffer((256, 1)) for i0 in range(256): for ax0, ax1, ax2 in T.grid(1, 1, 256): - with T.block("T_softmax_maxelem_rf"): + with T.sblock("T_softmax_maxelem_rf"): vk_1 = T.axis.spatial(1, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) vk_0 = T.axis.reduce(256, ax2) @@ -1807,7 +1807,7 @@ def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf[v_i0, vk_1] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_rf[v_i0, vk_1] = T.max(T_softmax_maxelem_rf[v_i0, vk_1], A[v_i0, vk_0 + vk_1]) for k_1 in range(1): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): vk_1, v_i0 = T.axis.remap("RS", [k_1, i0]) T.reads(T_softmax_maxelem_rf[v_i0, vk_1]) T.writes(T_softmax_maxelem[v_i0]) @@ -1815,13 +1815,13 @@ def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_1]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) T.writes(T_softmax_exp[v_i0, v_i1]) T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) for i0, k_0, k_1 in T.grid(256, 16, 16): - with T.block("T_softmax_expsum_rf"): + with T.sblock("T_softmax_expsum_rf"): vk_1, v_i0, vk_0 = T.axis.remap("SSR", [k_1, i0, k_0]) T.reads(T_softmax_exp[v_i0, vk_0 * 16 + vk_1]) T.writes(T_softmax_expsum_rf[v_i0, vk_1]) @@ -1829,7 +1829,7 @@ def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_rf[v_i0, vk_1] = T.float32(0) T_softmax_expsum_rf[v_i0, vk_1] = T_softmax_expsum_rf[v_i0, vk_1] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] for i0, k_1 in T.grid(256, 16): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): vk_1, v_i0 = T.axis.remap("RS", [k_1, i0]) T.reads(T_softmax_expsum_rf[v_i0, vk_1]) T.writes(T_softmax_expsum[v_i0]) @@ -1837,26 +1837,26 @@ def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_1] for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_exp = T.alloc_buffer((256, 256)) T_softmax_expsum = T.alloc_buffer((256,)) T_softmax_expsum_rf = T.alloc_buffer((256, 16)) for i0 in range(256): for ax0, ax1 in T.grid(1, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(256, i0 + ax0) v_k = T.axis.reduce(256, ax1) T.reads(A[v_i0, v_k]) @@ -1865,7 +1865,7 @@ def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for ax0, ax1 in T.grid(1, 256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0 = T.axis.spatial(256, i0 + ax0) v_i1 = T.axis.spatial(256, ax1) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) @@ -1873,7 +1873,7 @@ def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) for ax0 in range(16): for ax0_1, ax1, ax2 in T.grid(1, 1, 16): - with T.block("T_softmax_expsum_rf"): + with T.sblock("T_softmax_expsum_rf"): vk_1 = T.axis.spatial(16, ax0 + ax0_1) v_i0 = T.axis.spatial(256, i0 + ax1) vk_0 = T.axis.reduce(16, ax2) @@ -1883,7 +1883,7 @@ def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_rf[v_i0, vk_1] = T.float32(0) T_softmax_expsum_rf[v_i0, vk_1] = T_softmax_expsum_rf[v_i0, vk_1] + T_softmax_exp[v_i0, vk_0 * 16 + vk_1] for ax1 in range(1): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): vk_1 = T.axis.reduce(16, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) T.reads(T_softmax_expsum_rf[v_i0, vk_1]) @@ -1892,25 +1892,25 @@ def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_expsum_rf[v_i0, vk_1] for i1 in range(256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum = T.alloc_buffer((256,)) T_softmax_maxelem_rf = T.alloc_buffer((256, 64)) for i0 in range(256): for ax0, ax1, ax2 in T.grid(64, 1, 4): - with T.block("T_softmax_maxelem_rf"): + with T.sblock("T_softmax_maxelem_rf"): vk_0 = T.axis.spatial(64, ax0) v_i0 = T.axis.spatial(256, i0 + ax1) vk_1 = T.axis.reduce(4, ax2) @@ -1920,7 +1920,7 @@ def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf[v_i0, vk_0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_rf[v_i0, vk_0] = T.max(T_softmax_maxelem_rf[v_i0, vk_0], A[v_i0, vk_0 * 4 + vk_1]) for k_0 in range(64): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): vk_0, v_i0 = T.axis.remap("RS", [k_0, i0]) T.reads(T_softmax_maxelem_rf[v_i0, vk_0]) T.writes(T_softmax_maxelem[v_i0]) @@ -1928,7 +1928,7 @@ def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_0]) for i0, k in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_k = T.axis.remap("SR", [i0, k]) T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) T.writes(T_softmax_expsum[v_i0]) @@ -1936,24 +1936,24 @@ def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum = T.alloc_buffer((256,)) T_softmax_maxelem_rf = T.alloc_buffer((256, 4)) for i0, k_0, k_1 in T.grid(256, 64, 4): - with T.block("T_softmax_maxelem_rf"): + with T.sblock("T_softmax_maxelem_rf"): vk_1, v_i0, vk_0 = T.axis.remap("SSR", [k_1, i0, k_0]) T.reads(A[v_i0, vk_0 * 4 + vk_1]) T.writes(T_softmax_maxelem_rf[v_i0, vk_1]) @@ -1961,7 +1961,7 @@ def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_rf[v_i0, vk_1] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_rf[v_i0, vk_1] = T.max(T_softmax_maxelem_rf[v_i0, vk_1], A[v_i0, vk_0 * 4 + vk_1]) for i0, k_1 in T.grid(256, 4): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): vk_1, v_i0 = T.axis.remap("RS", [k_1, i0]) T.reads(T_softmax_maxelem_rf[v_i0, vk_1]) T.writes(T_softmax_maxelem[v_i0]) @@ -1970,7 +1970,7 @@ def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], T_softmax_maxelem_rf[v_i0, vk_1]) for i0, i1 in T.grid(256, 256): for ax0, ax1 in T.grid(1, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0 = T.axis.spatial(256, i0 + ax0) v_k = T.axis.reduce(256, ax1) T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) @@ -1978,25 +1978,25 @@ def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 with T.init(): T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_8(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_exp = T.alloc_buffer((256, 256)) T_softmax_expsum = T.alloc_buffer((256,)) for i0 in range(256): for ax0, ax1 in T.grid(1, 256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(256, i0 + ax0) v_k = T.axis.reduce(256, ax1) T.reads(A[v_i0, v_k]) @@ -2005,13 +2005,13 @@ def sfm_8(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k]) for i1 in range(256): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0]) T.writes(T_softmax_exp[v_i0, v_i1]) T_softmax_exp[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) for i0, k in T.grid(256, 256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_k = T.axis.remap("SR", [i0, k]) T.reads(T_softmax_exp[v_i0, v_k]) T.writes(T_softmax_expsum[v_i0]) @@ -2019,11 +2019,11 @@ def sfm_8(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T.float32(0) T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k] for i0, i1 in T.grid(256, 256): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] # fmt: on decision_0 = [ @@ -2129,13 +2129,13 @@ def test_cpu_cbr(): @T.prim_func def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) Conv2dOutput = T.alloc_buffer((1, 112, 112, 64)) for nn_0, yy_0, xx_0, ff_0, nn_1, yy_1, xx_1, ff_1, ry_0, rx_0, rc_0, nn_2, yy_2, xx_2, ff_2, ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3 in T.grid(1, 2, 7, 1, 1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): - with T.block("Conv2dOutput"): + with T.sblock("Conv2dOutput"): v_nn = T.axis.spatial(1, nn_0 + nn_1 + nn_2 + nn_3) v_yy = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + yy_2 * 28 + yy_3) v_xx = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + xx_2 * 2 + xx_3) @@ -2145,12 +2145,12 @@ def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(data[v_nn, v_yy * 2 + v_ry - 3, v_xx * 2 + v_rx - 3, v_rc], kernel[v_ry, v_rx, v_rc, v_ff]) T.writes(Conv2dOutput[v_nn, v_yy, v_xx, v_ff]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = T.float32(0) Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput[v_nn, v_yy, v_xx, v_ff] + T.if_then_else(3 <= v_yy * 2 + v_ry and v_yy * 2 + v_ry < 227 and 3 <= v_xx * 2 + v_rx and v_xx * 2 + v_rx < 227, data[v_nn, v_yy * 2 + v_ry - 3, v_xx * 2 + v_rx - 3, v_rc], T.float32(0)) * kernel[v_ry, v_rx, v_rc, v_ff] for i0, i1, i2, i3 in T.grid(1, 112, 112, 64): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(Conv2dOutput[v_i0, v_i1, v_i2, v_i3], bias[v_i3], bn_scale[v_i3], bn_offset[v_i3]) T.writes(compute[v_i0, v_i1, v_i2, v_i3]) @@ -2158,15 +2158,15 @@ def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 @T.prim_func def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) PaddedInput = T.alloc_buffer((1, 230, 230, 3)) Conv2dOutput = T.alloc_buffer((1, 112, 112, 64)) for nn_0, yy_0 in T.grid(1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 117, 229, 3): - with T.block("PaddedInput"): + with T.sblock("PaddedInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, yy_0 * 112 + ax1) v_i2 = T.axis.spatial(230, ax2) @@ -2176,7 +2176,7 @@ def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, data[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for xx_0, ff_0, nn_1, yy_1, xx_1, ff_1 in T.grid(7, 1, 1, 2, 2, 32): for ry_0, rx_0, rc_0, nn_2, yy_2, xx_2, ff_2, ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3 in T.grid(7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): - with T.block("Conv2dOutput"): + with T.sblock("Conv2dOutput"): v_nn = T.axis.spatial(1, nn_0 + nn_1 + nn_2 + nn_3) v_yy = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + yy_2 * 28 + yy_3) v_xx = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + xx_2 * 2 + xx_3) @@ -2186,12 +2186,12 @@ def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc], kernel[v_ry, v_rx, v_rc, v_ff]) T.writes(Conv2dOutput[v_nn, v_yy, v_xx, v_ff]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = T.float32(0) Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput[v_nn, v_yy, v_xx, v_ff] + PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc] * kernel[v_ry, v_rx, v_rc, v_ff] for ax0, ax1, ax2, ax3 in T.grid(1, 28, 8, 2): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + ax1) v_i2 = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + ax2) @@ -2202,15 +2202,15 @@ def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 @T.prim_func def cbr_2(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) PaddedInput = T.alloc_buffer((1, 230, 230, 3)) Conv2dOutput = T.alloc_buffer((1, 112, 112, 64)) for nn_0, yy_0 in T.grid(1, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 117, 229, 3): - with T.block("PaddedInput"): + with T.sblock("PaddedInput"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(230, yy_0 * 112 + ax1) v_i2 = T.axis.spatial(230, ax2) @@ -2220,7 +2220,7 @@ def cbr_2(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i1 and v_i1 < 227 and 3 <= v_i2 and v_i2 < 227, data[v_i0, v_i1 - 3, v_i2 - 3, v_i3], T.float32(0)) for xx_0, ff_0 in T.grid(7, 1): for nn_1, yy_1, xx_1, ff_1, ry_0, rx_0, rc_0, nn_2, yy_2, xx_2, ff_2, ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3 in T.grid(1, 2, 2, 32, 7, 7, 1, 1, 1, 4, 1, 1, 1, 3, 1, 28, 2, 2): - with T.block("Conv2dOutput"): + with T.sblock("Conv2dOutput"): v_nn = T.axis.spatial(1, nn_0 + nn_1 + nn_2 + nn_3) v_yy = T.axis.spatial(112, yy_0 * 56 + yy_1 * 28 + yy_2 * 28 + yy_3) v_xx = T.axis.spatial(112, xx_0 * 16 + xx_1 * 8 + xx_2 * 2 + xx_3) @@ -2230,12 +2230,12 @@ def cbr_2(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1) T.reads(PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc], kernel[v_ry, v_rx, v_rc, v_ff]) T.writes(Conv2dOutput[v_nn, v_yy, v_xx, v_ff]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = T.float32(0) Conv2dOutput[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput[v_nn, v_yy, v_xx, v_ff] + PaddedInput[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc] * kernel[v_ry, v_rx, v_rc, v_ff] for ax0, ax1, ax2, ax3 in T.grid(1, 56, 16, 64): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(1, ax0) v_i1 = T.axis.spatial(112, yy_0 * 56 + ax1) v_i2 = T.axis.spatial(112, xx_0 * 16 + ax2) @@ -2292,16 +2292,16 @@ def test_cpu_tbg(): @T.prim_func def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) query_T = T.alloc_buffer((1, 12, 128, 64)) value_T = T.alloc_buffer((1, 12, 64, 128)) C_global = T.alloc_buffer((1, 12, 128, 128)) for b_0, h_0, i_0, j_0, b_1, h_1, i_1 in T.grid(1, 1, 1, 2, 1, 6, 2): for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64): - with T.block("value_T"): + with T.sblock("value_T"): v_b = T.axis.spatial(1, ax0) v_h = T.axis.spatial(12, h_1 * 2 + ax1) v_d = T.axis.spatial(64, ax2) @@ -2310,7 +2310,7 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, T.writes(value_T[v_b, v_h, v_d, v_l]) value_T[v_b, v_h, v_d, v_l] = value[v_b, v_l, v_h, v_d] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64): - with T.block("query_T"): + with T.sblock("query_T"): v_b = T.axis.spatial(1, ax0) v_h = T.axis.spatial(12, h_1 * 2 + ax1) v_l = T.axis.spatial(128, i_1 * 64 + ax2) @@ -2320,7 +2320,7 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, query_T[v_b, v_h, v_l, v_d] = query[v_b, v_l, v_h, v_d] for j_1 in range(8): for k_0, b_2, h_2, i_2, j_2, k_1, b_3, h_3, i_3, j_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2): - with T.block("C"): + with T.sblock("C"): v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) v_h = T.axis.spatial(12, h_0 * 12 + h_1 * 2 + h_2 + h_3) v_i = T.axis.spatial(128, i_0 * 128 + i_1 * 64 + i_2 * 32 + i_3) @@ -2328,12 +2328,12 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, v_k = T.axis.reduce(64, k_0 * 64 + k_1) T.reads(query_T[v_b, v_h, v_i, v_k], value_T[v_b, v_h, v_k, v_j]) T.writes(C_global[v_b, v_h, v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C_global[v_b, v_h, v_i, v_j] = T.float32(0) C_global[v_b, v_h, v_i, v_j] = C_global[v_b, v_h, v_i, v_j] + query_T[v_b, v_h, v_i, v_k] * value_T[v_b, v_h, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8): - with T.block("C_global"): + with T.sblock("C_global"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(12, h_1 * 2 + ax1) v2 = T.axis.spatial(128, i_1 * 64 + ax2) @@ -2344,15 +2344,15 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, @T.prim_func def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 64, "meta_schedule.vectorize": 64}) query_T = T.alloc_buffer((1, 12, 128, 64)) value_T = T.alloc_buffer((1, 12, 64, 128)) C_global = T.alloc_buffer((1, 12, 128, 128)) for b, h, l, d in T.grid(1, 12, 128, 64): - with T.block("query_T"): + with T.sblock("query_T"): v_b, v_h, v_l, v_d = T.axis.remap("SSSS", [b, h, l, d]) T.reads(query[v_b, v_l, v_h, v_d]) T.writes(query_T[v_b, v_h, v_l, v_d]) @@ -2360,7 +2360,7 @@ def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, for b_0, h_0, i_0, j_0 in T.grid(1, 1, 1, 2): for b_1, h_1, i_1, j_1, k_0, b_2, h_2, i_2, j_2, k_1 in T.grid(1, 6, 2, 8, 1, 1, 2, 2, 4, 64): for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 2): - with T.block("value_T"): + with T.sblock("value_T"): v_b = T.axis.spatial(1, ax0) v_h = T.axis.spatial(12, h_1 * 2 + h_2 + ax1) v_d = T.axis.spatial(64, k_1 + ax2) @@ -2369,7 +2369,7 @@ def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, T.writes(value_T[v_b, v_h, v_d, v_l]) value_T[v_b, v_h, v_d, v_l] = value[v_b, v_l, v_h, v_d] for b_3, h_3, i_3, j_3 in T.grid(1, 1, 32, 2): - with T.block("C"): + with T.sblock("C"): v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) v_h = T.axis.spatial(12, h_0 * 12 + h_1 * 2 + h_2 + h_3) v_i = T.axis.spatial(128, i_0 * 128 + i_1 * 64 + i_2 * 32 + i_3) @@ -2377,12 +2377,12 @@ def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, v_k = T.axis.reduce(64, k_0 * 64 + k_1) T.reads(query_T[v_b, v_h, v_i, v_k], value_T[v_b, v_h, v_k, v_j]) T.writes(C_global[v_b, v_h, v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C_global[v_b, v_h, v_i, v_j] = T.float32(0) C_global[v_b, v_h, v_i, v_j] = C_global[v_b, v_h, v_i, v_j] + query_T[v_b, v_h, v_i, v_k] * value_T[v_b, v_h, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(1, 12, 128, 64): - with T.block("C_global"): + with T.sblock("C_global"): v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) v3 = T.axis.spatial(128, j_0 * 64 + ax3) T.reads(C_global[v0, v1, v2, v3]) @@ -2391,14 +2391,14 @@ def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, @T.prim_func def tbg_2(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) + T.sblock_attr({"meta_schedule.parallel": 288, "meta_schedule.unroll_explicit": 512, "meta_schedule.vectorize": 64}) value_T = T.alloc_buffer((1, 12, 64, 128)) for b_0, h_0, i_0, j_0, b_1, h_1, i_1, j_1 in T.grid(1, 1, 1, 2, 1, 6, 2, 8): for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8): - with T.block("value_T"): + with T.sblock("value_T"): v_b = T.axis.spatial(1, ax0) v_h = T.axis.spatial(12, h_1 * 2 + ax1) v_d = T.axis.spatial(64, ax2) @@ -2407,7 +2407,7 @@ def tbg_2(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, T.writes(value_T[v_b, v_h, v_d, v_l]) value_T[v_b, v_h, v_d, v_l] = value[v_b, v_l, v_h, v_d] for k_0, b_2, h_2, i_2, j_2, k_1, b_3, h_3, i_3, j_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2): - with T.block("C"): + with T.sblock("C"): v_b = T.axis.spatial(1, b_0 + b_1 + b_2 + b_3) v_h = T.axis.spatial(12, h_0 * 12 + h_1 * 2 + h_2 + h_3) v_i = T.axis.spatial(128, i_0 * 128 + i_1 * 64 + i_2 * 32 + i_3) @@ -2415,7 +2415,7 @@ def tbg_2(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, v_k = T.axis.reduce(64, k_0 * 64 + k_1) T.reads(query[v_b, v_i, v_h, v_k], value_T[v_b, v_h, v_k, v_j]) T.writes(C[v_b, v_h, v_i, v_j]) - T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSRSRS"}) with T.init(): C[v_b, v_h, v_i, v_j] = T.float32(0) C[v_b, v_h, v_i, v_j] = C[v_b, v_h, v_i, v_j] + query[v_b, v_i, v_h, v_k] * value_T[v_b, v_h, v_k, v_j] diff --git a/tests/python/meta_schedule/test_meta_schedule_space_cuda.py b/tests/python/meta_schedule/test_meta_schedule_space_cuda.py index d05ade960164..9242bcc4d321 100644 --- a/tests/python/meta_schedule/test_meta_schedule_space_cuda.py +++ b/tests/python/meta_schedule/test_meta_schedule_space_cuda.py @@ -44,10 +44,10 @@ def test_cuda_c1d(): @T.prim_func def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) conv1d_nlc_local = T.alloc_buffer((1, 128, 128), scope="local") PadInput_shared = T.alloc_buffer((1, 258, 64), scope="shared") weight_shared = T.alloc_buffer((3, 64, 128), scope="shared") @@ -56,25 +56,25 @@ def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 for n_2_l_2_co_2_fused in T.thread_binding(4, thread="threadIdx.x"): for rl_0, rc_0 in T.grid(1, 16): for ax0_ax1_ax2_fused in range(260): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(258, n_0_l_0_co_0_fused * 64 + ax0_ax1_ax2_fused // 4) v2 = T.axis.spatial(64, rc_0 * 4 + ax0_ax1_ax2_fused % 4) T.reads(inputs[v0, v1 - 1, v2]) T.writes(PadInput_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) PadInput_shared[v0, v1, v2] = T.if_then_else(1 <= v1 and v1 < 257, inputs[v0, v1 - 1, v2], T.float32(0)) for ax0_ax1_ax2_fused in range(1536): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0 = T.axis.spatial(3, ax0_ax1_ax2_fused // 512) v1 = T.axis.spatial(64, rc_0 * 4 + ax0_ax1_ax2_fused % 512 // 128) v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128) T.reads(weight[v0, v1, v2]) T.writes(weight_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2] = weight[v0, v1, v2] for rl_1, rc_1, n_3, l_3, co_3, rl_2, rc_2, n_4, l_4, co_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8): - with T.block("conv1d_nlc"): + with T.sblock("conv1d_nlc"): v_n = T.axis.spatial(1, n_3 + n_4) v_l = T.axis.spatial(128, n_0_l_0_co_0_fused * 32 + n_1_l_1_co_1_fused // 2 * 4 + l_3 * 4 + l_4) v_co = T.axis.spatial(128, n_1_l_1_co_1_fused % 2 * 64 + n_2_l_2_co_2_fused * 16 + co_3 * 8 + co_4) @@ -82,12 +82,12 @@ def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 v_rc = T.axis.reduce(64, rc_0 * 4 + rc_1 * 2 + rc_2) T.reads(PadInput_shared[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc], weight_shared[v_rl, v_rc, v_co]) T.writes(conv1d_nlc_local[v_n, v_l, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv1d_nlc_local[v_n, v_l, v_co] = T.float32(0) conv1d_nlc_local[v_n, v_l, v_co] = conv1d_nlc_local[v_n, v_l, v_co] + PadInput_shared[v_n, v_l * 2 + v_rl, v_co // 128 * 64 + v_rc] * weight_shared[v_rl, v_rc, v_co] for ax0, ax1, ax2 in T.grid(1, 4, 16): - with T.block("conv1d_nlc_local"): + with T.sblock("conv1d_nlc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, n_0_l_0_co_0_fused * 32 + n_1_l_1_co_1_fused // 2 * 4 + ax1) v2 = T.axis.spatial(128, n_1_l_1_co_1_fused % 2 * 64 + n_2_l_2_co_2_fused * 16 + ax2) @@ -122,10 +122,10 @@ def test_cuda_c2d(): @T.prim_func def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local") PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") @@ -134,27 +134,27 @@ def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, for n_2_h_2_w_2_co_2_fused in T.thread_binding(14, thread="threadIdx.x"): for rh_0, rw_0, rc_0 in T.grid(1, 1, 1): for ax0_ax1_ax2_ax3_fused in range(80379): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused // 351) v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 112 + ax0_ax1_ax2_ax3_fused % 351 // 3) v3 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 3) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(1176): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 168) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 168 // 24) v2 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 24 // 8) v3 = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 8 * 8 + ax0_ax1_ax2_ax3_fused % 8) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 7, 1, 1, 8, 4, 1, 7, 1, 3, 1, 1, 1, 2): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(112, n_2_h_2_w_2_co_2_fused * 8 + h_3 + h_4) v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 56 + n_1_h_1_w_1_co_1_fused // 4 * 4 + w_3 + w_4) @@ -164,12 +164,12 @@ def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1 * 3 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 8, 4, 2): - with T.block("conv2d_nhwc_local"): + with T.sblock("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, n_2_h_2_w_2_co_2_fused * 8 + ax1) v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 56 + n_1_h_1_w_1_co_1_fused // 4 * 4 + ax2) @@ -206,10 +206,10 @@ def test_cuda_c3d(): @T.prim_func def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) conv3d_ndhwc_local = T.alloc_buffer((1, 8, 112, 112, 64), scope="local") PadInput_shared = T.alloc_buffer((1, 22, 230, 230, 3), scope="shared") weight_shared = T.alloc_buffer((7, 7, 7, 3, 64), scope="shared") @@ -218,7 +218,7 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 for n_2_d_2_h_2_w_2_co_2_fused in T.thread_binding(392, thread="threadIdx.x"): for rd_0, rh_0, rw_0, rc_0 in T.grid(1, 1, 1, 1): for ax0_ax1_ax2_ax3_ax4_fused in range(1687959): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused // 80379) v2 = T.axis.spatial(230, ax0_ax1_ax2_ax3_ax4_fused % 80379 // 351) @@ -226,10 +226,10 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v4 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_fused % 3) T.reads(inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4]) T.writes(PadInput_shared[v0, v1, v2, v3, v4]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) PadInput_shared[v0, v1, v2, v3, v4] = T.if_then_else(3 <= v1 and v1 < 19 and 3 <= v2 and v2 < 227 and 3 <= v3 and v3 < 227, inputs[v0, v1 - 3, v2 - 3, v3 - 3, v4], T.float32(0)) for ax0_ax1_ax2_ax3_ax4_fused in range(65856): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused // 9408) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused % 9408 // 1344) v2 = T.axis.spatial(7, ax0_ax1_ax2_ax3_ax4_fused % 1344 // 192) @@ -237,10 +237,10 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v4 = T.axis.spatial(64, ax0_ax1_ax2_ax3_ax4_fused % 64) T.reads(weight[v0, v1, v2, v3, v4]) T.writes(weight_shared[v0, v1, v2, v3, v4]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2, v3, v4] = weight[v0, v1, v2, v3, v4] for rd_1, rh_1, rw_1, rc_1, n_3, d_3, h_3, w_3, co_3, rd_2, rh_2, rw_2, rc_2, n_4, d_4, h_4, w_4, co_4 in T.grid(7, 7, 1, 3, 1, 2, 2, 1, 32, 1, 1, 7, 1, 1, 1, 2, 4, 1): - with T.block("conv3d_ndhwc"): + with T.sblock("conv3d_ndhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_d = T.axis.spatial(8, n_2_d_2_h_2_w_2_co_2_fused // 98 * 2 + d_3 + d_4) v_h = T.axis.spatial(112, n_1_d_1_h_1_w_1_co_1_fused // 2 * 28 + n_2_d_2_h_2_w_2_co_2_fused % 98 // 14 * 4 + h_3 * 2 + h_4) @@ -252,12 +252,12 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 v_rc = T.axis.reduce(3, rc_0 * 3 + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rd, v_rh, v_rw, v_rc, v_co]) T.writes(conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co] = T.float32(0) conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co] = conv3d_ndhwc_local[v_n, v_d, v_h, v_w, v_co] + PadInput_shared[v_n, v_d * 2 + v_rd, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rd, v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 4, 4, 32): - with T.block("conv3d_ndhwc_local"): + with T.sblock("conv3d_ndhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, n_2_d_2_h_2_w_2_co_2_fused // 98 * 2 + ax1) v2 = T.axis.spatial(112, n_1_d_1_h_1_w_1_co_1_fused // 2 * 28 + n_2_d_2_h_2_w_2_co_2_fused % 98 // 14 * 4 + ax2) @@ -296,10 +296,10 @@ def test_cuda_cap(): @T.prim_func def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 64}) + T.sblock_attr({"meta_schedule.unroll_explicit": 64}) conv2d_capsule_nhwijc_local = T.alloc_buffer((1, 8, 8, 4, 4, 32), scope="local") PadInput_shared = T.alloc_buffer((1, 18, 18, 4, 4, 32), scope="shared") weight_shared = T.alloc_buffer((3, 3, 4, 4, 32, 32), scope="shared") @@ -308,7 +308,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( for n_2_h_2_w_2_cap_i_2_cap_j_2_co_2_fused in T.thread_binding(4, thread="threadIdx.x"): for rh_0, rw_0, cap_k_0, rc_0 in T.grid(3, 3, 2, 8): for ax0_ax1_ax2_ax3_ax4_ax5_fused in range(48): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(18, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused // 64 * 4 + rh_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) v2 = T.axis.spatial(18, T.Add(n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 64 // 8 * 2 + rw_0, 0)) @@ -317,10 +317,10 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( v5 = T.axis.spatial(32, rc_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) T.reads(inputs[v0, v1 - 1, v2 - 1, v3, v4, v5]) T.writes(PadInput_shared[v0, v1, v2, v3, v4, v5]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) PadInput_shared[v0, v1, v2, v3, v4, v5] = T.if_then_else(1 <= v1 and v1 < 17 and 1 <= v2 and v2 < 17, inputs[v0, v1 - 1, v2 - 1, v3, v4, v5], T.float32(0)) for ax0_ax1_ax2_ax3_ax4_ax5_fused in range(256): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0, v1 = T.axis.remap("SS", [rh_0, rw_0]) v2 = T.axis.spatial(4, cap_k_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused // 128) v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_ax4_ax5_fused % 128 // 32) @@ -328,10 +328,10 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( v5 = T.axis.spatial(32, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 4 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8) T.reads(weight[v0, v1, v2, v3, v4, v5]) T.writes(weight_shared[v0, v1, v2, v3, v4, v5]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3, v4, v5] = weight[v0, v1, v2, v3, v4, v5] for rh_1, rw_1, cap_k_1, rc_1, n_3, h_3, w_3, cap_i_3, cap_j_3, co_3, rh_2, rw_2, cap_k_2, rc_2, n_4, h_4, w_4, cap_i_4, cap_j_4, co_4 in T.grid(1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 8): - with T.block("conv2d_capsule_nhwijc"): + with T.sblock("conv2d_capsule_nhwijc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused // 64 * 2 + h_3 + h_4) v_w = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 64 // 8 + w_3 + w_4) @@ -344,12 +344,12 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( v_rc = T.axis.reduce(32, rc_0 * 4 + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc], weight_shared[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co]) T.writes(conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = T.float32(0) conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] = conv2d_capsule_nhwijc_local[v_n, v_h, v_w, v_cap_i, v_cap_j, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_cap_i, v_cap_k, v_rc] * weight_shared[v_rh, v_rw, v_cap_k, v_cap_j, v_rc, v_co] for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 2, 1, 1, 2, 8): - with T.block("conv2d_capsule_nhwijc_local"): + with T.sblock("conv2d_capsule_nhwijc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused // 64 * 2 + ax1) v2 = T.axis.spatial(8, n_0_h_0_w_0_cap_i_0_cap_j_0_co_0_fused % 64 // 8 + ax2) @@ -390,10 +390,10 @@ def test_cuda_dep(): @T.prim_func def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) depth_conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 32), scope="local") PadInput_shared = T.alloc_buffer((1, 114, 114, 32), scope="shared") placeholder_shared = T.alloc_buffer((1, 3, 3, 32), scope="shared") @@ -402,27 +402,27 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. for n_2_h_2_w_2_c_2_fused in T.thread_binding(14, thread="threadIdx.x"): for rh_0, rw_0 in T.grid(1, 1): for ax0_ax1_ax2_ax3_fused in range(415872): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(114, ax0_ax1_ax2_ax3_fused // 3648) v2 = T.axis.spatial(114, ax0_ax1_ax2_ax3_fused % 3648 // 32) v3 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 32) T.reads(placeholder[v0, v1 - 1, v2 - 1, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 113 and 1 <= v2 and v2 < 113, placeholder[v0, v1 - 1, v2 - 1, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(288): - with T.block("placeholder_shared"): + with T.sblock("placeholder_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 96) v2 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 96 // 32) v3 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 32) T.reads(placeholder_1[v0, v1, v2, v3]) T.writes(placeholder_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) placeholder_shared[v0, v1, v2, v3] = placeholder_1[v0, v1, v2, v3] for rh_1, rw_1, n_3, h_3, w_3, c_3, rh_2, rw_2, n_4, h_4, w_4, c_4 in T.grid(3, 1, 1, 4, 16, 8, 1, 3, 1, 7, 1, 1): - with T.block("depth_conv2d_nhwc"): + with T.sblock("depth_conv2d_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(112, n_1_h_1_w_1_c_1_fused // 2 * 28 + h_3 * 7 + h_4) v_w = T.axis.spatial(112, n_2_h_2_w_2_c_2_fused // 2 * 16 + w_3 + w_4) @@ -431,12 +431,12 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1 * 3 + rw_2) T.reads(PadInput_shared[v_n, v_h + v_rh, v_w + v_rw, v_c], placeholder_shared[0, v_rh, v_rw, v_c]) T.writes(depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c] = T.float32(0) depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c] = depth_conv2d_nhwc_local[v_n, v_h, v_w, v_c] + PadInput_shared[v_n, v_h + v_rh, v_w + v_rw, v_c] * placeholder_shared[0, v_rh, v_rw, v_c] for ax0, ax1, ax2, ax3 in T.grid(1, 28, 16, 8): - with T.block("depth_conv2d_nhwc_local"): + with T.sblock("depth_conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, n_1_h_1_w_1_c_1_fused // 2 * 28 + ax1) v2 = T.axis.spatial(112, n_2_h_2_w_2_c_2_fused // 2 * 16 + ax2) @@ -471,10 +471,10 @@ def test_cuda_dil(): @T.prim_func def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 512}) + T.sblock_attr({"meta_schedule.unroll_explicit": 512}) conv2d_nhwc_local = T.alloc_buffer((1, 109, 109, 64), scope="local") PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") @@ -483,25 +483,25 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, for n_2_h_2_w_2_co_2_fused in T.thread_binding(1, thread="threadIdx.x"): for rh_0, rw_0, rc_0 in T.grid(7, 7, 3): for ax0_ax1_ax2_ax3_fused in range(217): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(230, T.Add(n_0_h_0_w_0_co_0_fused // 2 * 2 + rh_0 * 2, 0)) v2 = T.axis.spatial(230, rw_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) v3 = T.axis.spatial(3, T.Add(rc_0, 0)) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(32): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0, v1, v2 = T.axis.remap("SSS", [rh_0, rw_0, rc_0]) v3 = T.axis.spatial(64, n_0_h_0_w_0_co_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 4): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(109, n_0_h_0_w_0_co_0_fused // 2 + h_3 + h_4) v_w = T.axis.spatial(109, n_1_h_1_w_1_co_1_fused + w_3 + w_4) @@ -511,12 +511,12 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, v_rc = T.axis.reduce(3, rc_0 + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh * 2, v_w * 2 + v_rw * 2, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 32): - with T.block("conv2d_nhwc_local"): + with T.sblock("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(109, n_0_h_0_w_0_co_0_fused // 2 + ax1) v2 = T.axis.spatial(109, n_1_h_1_w_1_co_1_fused + ax2) @@ -552,10 +552,10 @@ def test_cuda_gmm(): @T.prim_func def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) Z_local = T.alloc_buffer((1, 128, 128), scope="local") X_shared = T.alloc_buffer((1, 128, 128), scope="shared") Y_shared = T.alloc_buffer((1, 128, 128), scope="shared") @@ -564,37 +564,37 @@ def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "flo for b_2_i_2_j_2_fused in T.thread_binding(2, thread="threadIdx.x"): for k_0 in range(1): for ax0_ax1_ax2_fused in range(16384): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128) v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128) T.reads(X[v0, v1, v2]) T.writes(X_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused in range(16384): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128) v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128) T.reads(Y[v0, v1, v2]) T.writes(Y_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) Y_shared[v0, v1, v2] = Y[v0, v1, v2] for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1): - with T.block("Z"): + with T.sblock("Z"): v_b = T.axis.spatial(1, b_3 + b_4) v_i = T.axis.spatial(128, b_1_i_1_j_1_fused * 4 + i_3 * 2 + i_4) v_j = T.axis.spatial(128, b_2_i_2_j_2_fused * 64 + j_3 + j_4) v_k = T.axis.reduce(128, k_0 * 128 + k_1 * 4 + k_2) T.reads(X_shared[v_b, v_i, v_k], Y_shared[v_b, v_k, v_j]) T.writes(Z_local[v_b, v_i, v_j]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): Z_local[v_b, v_i, v_j] = T.float32(0) Z_local[v_b, v_i, v_j] = Z_local[v_b, v_i, v_j] + X_shared[v_b, v_i, v_k] * Y_shared[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 4, 64): - with T.block("Z_local"): + with T.sblock("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(128, b_1_i_1_j_1_fused * 4 + ax1) v2 = T.axis.spatial(128, b_2_i_2_j_2_fused * 64 + ax2) @@ -626,10 +626,10 @@ def test_cuda_grp(): @T.prim_func def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) conv2d_nhwc_local = T.alloc_buffer((1, 28, 28, 128), scope="local") PadInput_shared = T.alloc_buffer((1, 58, 58, 64), scope="shared") weight_shared = T.alloc_buffer((3, 3, 16, 128), scope="shared") @@ -638,26 +638,26 @@ def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, for n_2_h_2_w_2_co_2_fused in T.thread_binding(112, thread="threadIdx.x"): for rh_0, rw_0, rc_0 in T.grid(3, 3, 1): for ax0_ax1_ax2_ax3_fused in range(95040): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(58, n_0_h_0_w_0_co_0_fused * 28 + rh_0 + ax0_ax1_ax2_ax3_fused % 95040 // 3520) v2 = T.axis.spatial(58, rw_0 + ax0_ax1_ax2_ax3_fused % 3520 // 64) v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) T.reads(inputs[v0, v1 - 1, v2 - 1, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(2048): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0, v1 = T.axis.remap("SS", [rh_0, rw_0]) v2 = T.axis.spatial(16, ax0_ax1_ax2_ax3_fused // 128) v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 128) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 1, 2, 1, 2, 1, 2, 1, 1, 8, 1, 7, 4, 4): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(28, n_0_h_0_w_0_co_0_fused * 14 + h_3 * 7 + h_4) v_w = T.axis.spatial(28, n_2_h_2_w_2_co_2_fused // 16 * 4 + w_3 * 4 + w_4) @@ -667,12 +667,12 @@ def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, v_rc = T.axis.reduce(16, rc_0 * 16 + rc_1 * 8 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 32 * 16 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 14, 4, 8): - with T.block("conv2d_nhwc_local"): + with T.sblock("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(28, n_0_h_0_w_0_co_0_fused * 14 + ax1) v2 = T.axis.spatial(28, n_2_h_2_w_2_co_2_fused // 16 * 4 + ax2) @@ -708,10 +708,10 @@ def test_cuda_t2d(): @T.prim_func def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 64}) + T.sblock_attr({"meta_schedule.unroll_explicit": 64}) conv2d_transpose_nhwc_local = T.alloc_buffer((1, 8, 8, 256), scope="local") PadInput_shared = T.alloc_buffer((1, 6, 6, 512), scope="shared") weight_shared = T.alloc_buffer((4, 4, 512, 256), scope="shared") @@ -720,27 +720,27 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 for n_2_h_2_w_2_co_2_fused in T.thread_binding(1, thread="threadIdx.x"): for rh_0, rw_0, rc_0 in T.grid(4, 1, 16): for ax0_ax1_ax2_ax3_fused in range(rh_0 % 2 * 96 + 96): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (rh_0 % 2 + 1)) // 96) v2 = T.axis.spatial(6, n_0_h_0_w_0_co_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32) v3 = T.axis.spatial(512, rc_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(inputs[v0, v1 - 1, v2 - 1, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 2}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 5 and 1 <= v2 and v2 < 5, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(2048): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0 = T.axis.spatial(4, rh_0 * -1 + 3) v1 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused // 512) v2 = T.axis.spatial(512, rc_0 * 32 + ax0_ax1_ax2_ax3_fused % 512 // 16) v3 = T.axis.spatial(256, n_0_h_0_w_0_co_0_fused % 16 * 16 + ax0_ax1_ax2_ax3_fused % 16) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(1, 1, 4, 1, 2, 1, 8, 1, 4, 8, 1, 1, 2, 1): - with T.block("conv2d_transpose_nhwc"): + with T.sblock("conv2d_transpose_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused // 64 * 2 + h_3 + h_4) v_w = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused % 64 // 16 * 2 + w_3 * 2 + w_4) @@ -750,12 +750,12 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 v_rc = T.axis.reduce(512, rc_0 * 32 + rc_1 * 8 + rc_2) T.reads(PadInput_shared[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], weight_shared[3 - v_rh, 3 - v_rw, v_rc, v_co]) T.writes(conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_transpose_nhwc_local[v_n, v_h, v_w, v_co] + T.if_then_else((v_h + v_rh) % 2 == 0 and (v_w + v_rw) % 2 == 0, PadInput_shared[v_n, (v_h + v_rh) // 2, (v_w + v_rw) // 2, v_rc], T.float32(0)) * weight_shared[3 - v_rh, 3 - v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 2, 8): - with T.block("conv2d_transpose_nhwc_local"): + with T.sblock("conv2d_transpose_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused // 64 * 2 + ax1) v2 = T.axis.spatial(8, n_0_h_0_w_0_co_0_fused % 64 // 16 * 2 + ax2) @@ -792,15 +792,15 @@ def test_cuda_nrm(): @T.prim_func def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 512}) + T.sblock_attr({"meta_schedule.unroll_explicit": 512}) C = T.alloc_buffer((1,)) for b_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for b_fused_1 in T.thread_binding(1, thread="threadIdx.x"): for i, j in T.grid(256, 256): - with T.block("C"): + with T.sblock("C"): v_b = T.axis.spatial(1, 0) v_i, v_j = T.axis.remap("RR", [i, j]) T.reads(A[v_b, v_i, v_j]) @@ -810,7 +810,7 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C[v_b] = C[v_b] + A[v_b, v_i, v_j] * A[v_b, v_i, v_j] for b_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for b_fused_1 in T.thread_binding(1, thread="threadIdx.x"): - with T.block("D"): + with T.sblock("D"): v_b = T.axis.spatial(1, 0) T.reads(C[v_b]) T.writes(D[v_b]) @@ -818,15 +818,15 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N @T.prim_func def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) C_shared = T.alloc_buffer((1,), scope="shared") for b_0_fused in T.thread_binding(1, thread="blockIdx.x"): for ax0, ax1_ax2_fused_0 in T.grid(1, 512): for ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): - with T.block("C"): + with T.sblock("C"): v_b = T.axis.spatial(1, ax0) v_i = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) // 256) v_j = T.axis.reduce(256, (ax1_ax2_fused_0 * 128 + ax1_ax2_fused_1) % 256) @@ -836,7 +836,7 @@ def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N C_shared[v_b] = T.float32(0) C_shared[v_b] = C_shared[v_b] + A[v_b, v_i, v_j] * A[v_b, v_i, v_j] for b_1 in T.thread_binding(128, thread="threadIdx.x"): - with T.block("D"): + with T.sblock("D"): v_b = T.axis.spatial(1, b_1) T.where(T.Mul(0, 128) + b_1 < 1) T.reads(C_shared[v_b]) @@ -865,16 +865,16 @@ def test_cuda_sfm(): @T.prim_func def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 0}) + T.sblock_attr({"meta_schedule.unroll_explicit": 0}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum = T.alloc_buffer((256,)) for i0_fused_0 in T.thread_binding(2, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for k in range(256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(256, i0_fused_0 * 128 + i0_fused_1) v_k = T.axis.reduce(256, k) T.reads(A[v_i0, v_k]) @@ -885,7 +885,7 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): for k in range(256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0 = T.axis.spatial(256, i0_fused_0 * 256 + i0_fused_1) v_k = T.axis.reduce(256, k) T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) @@ -895,26 +895,26 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) for i0_i1_fused_0 in T.thread_binding(1024, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0 = T.axis.spatial(256, (i0_i1_fused_0 * 64 + i0_i1_fused_1) // 256) v_i1 = T.axis.spatial(256, (i0_i1_fused_0 * 64 + i0_i1_fused_1) % 256) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum = T.alloc_buffer((256,)) for i0_fused in T.thread_binding(256, thread="blockIdx.x"): for k_0 in range(64): for k_1 in T.thread_binding(4, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(256, i0_fused) v_k = T.axis.reduce(256, k_0 * 4 + k_1) T.reads(A[v_i0, v_k]) @@ -925,7 +925,7 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 for i0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): for k in range(256): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0 = T.axis.spatial(256, i0_fused_0 * 64 + i0_fused_1) v_k = T.axis.reduce(256, k) T.reads(A[v_i0, v_k], T_softmax_maxelem[v_i0]) @@ -935,26 +935,26 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) for i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0 = T.axis.spatial(256, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 256) v_i1 = T.axis.spatial(256, (i0_i1_fused_0 * 256 + i0_i1_fused_1) % 256) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 512}) + T.sblock_attr({"meta_schedule.unroll_explicit": 512}) T_softmax_maxelem = T.alloc_buffer((256,)) T_softmax_expsum_shared = T.alloc_buffer((256,), scope="shared") for i0_fused_0 in T.thread_binding(8, thread="blockIdx.x"): for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for k in range(256): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(256, i0_fused_0 * 32 + i0_fused_1) v_k = T.axis.reduce(256, k) T.reads(A[v_i0, v_k]) @@ -965,7 +965,7 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 for i0_fused in T.thread_binding(256, thread="blockIdx.x"): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0 = T.axis.spatial(256, i0_fused + ax0) v_k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) @@ -976,27 +976,27 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_shared[v_i0] = T_softmax_expsum_shared[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem[v_i0]) for i1_0 in range(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0 = T.axis.spatial(256, i0_fused) v_i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) T.where(i1_0 * 512 + i1_1 < 256) T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0], T_softmax_expsum_shared[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum_shared[v_i0] @T.prim_func def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 0}) + T.sblock_attr({"meta_schedule.unroll_explicit": 0}) T_softmax_maxelem_shared = T.alloc_buffer((256,), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((256,), scope="shared") for i0_fused in T.thread_binding(256, thread="blockIdx.x"): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(256, i0_fused + ax0) v_k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) @@ -1007,7 +1007,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_maxelem_shared[v_i0] = T.max(T_softmax_maxelem_shared[v_i0], A[v_i0, v_k]) for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0 = T.axis.spatial(256, i0_fused + ax0) v_k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) @@ -1018,13 +1018,13 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_expsum_shared[v_i0] = T_softmax_expsum_shared[v_i0] + T.exp(A[v_i0, v_k] - T_softmax_maxelem_shared[v_i0]) for i1_0 in range(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0 = T.axis.spatial(256, i0_fused) v_i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) T.where(i1_0 * 512 + i1_1 < 256) T.reads(A[v_i0, v_i1], T_softmax_maxelem_shared[v_i0], T_softmax_expsum_shared[v_i0]) T.writes(T_softmax_norm[v_i0, v_i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem_shared[v_i0]) / T_softmax_expsum_shared[v_i0] # fmt: on decision_0 = [ @@ -1064,10 +1064,10 @@ def test_cuda_cbr(): @T.prim_func def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 512}) + T.sblock_attr({"meta_schedule.unroll_explicit": 512}) Conv2dOutput_local = T.alloc_buffer((1, 112, 112, 64), scope="local") PaddedInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") kernel_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") @@ -1076,27 +1076,27 @@ def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 for nn_2_yy_2_xx_2_ff_2_fused in T.thread_binding(128, thread="threadIdx.x"): for ry_0, rx_0, rc_0 in T.grid(7, 1, 3): for ax0_ax1_ax2_ax3_fused in range(8251): - with T.block("PaddedInput_shared"): + with T.sblock("PaddedInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(230, ry_0 + ax0_ax1_ax2_ax3_fused // 37) v2 = T.axis.spatial(230, nn_0_yy_0_xx_0_ff_0_fused // 2 * 32 + ax0_ax1_ax2_ax3_fused % 37) v3 = T.axis.spatial(3, rc_0) T.reads(data[v0, v1 - 3, v2 - 3, v3]) T.writes(PaddedInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) PaddedInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, data[v0, v1 - 3, v2 - 3, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(224): - with T.block("kernel_shared"): + with T.sblock("kernel_shared"): v0 = T.axis.spatial(7, ry_0) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 32) v2 = T.axis.spatial(3, rc_0) v3 = T.axis.spatial(64, nn_0_yy_0_xx_0_ff_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(kernel[v0, v1, v2, v3]) T.writes(kernel_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 1}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 1}) kernel_shared[v0, v1, v2, v3] = kernel[v0, v1, v2, v3] for ry_1, rx_1, rc_1, nn_3, yy_3, xx_3, ff_3, ry_2, rx_2, rc_2, nn_4, yy_4, xx_4, ff_4 in T.grid(1, 1, 1, 1, 1, 1, 2, 1, 7, 1, 1, 7, 1, 8): - with T.block("Conv2dOutput"): + with T.sblock("Conv2dOutput"): v_nn = T.axis.spatial(1, nn_3 + nn_4) v_yy = T.axis.spatial(112, nn_1_yy_1_xx_1_ff_1_fused // 2 * 56 + nn_2_yy_2_xx_2_ff_2_fused // 16 * 7 + yy_3 * 7 + yy_4) v_xx = T.axis.spatial(112, nn_0_yy_0_xx_0_ff_0_fused // 2 * 16 + nn_2_yy_2_xx_2_ff_2_fused % 16 + xx_3 + xx_4) @@ -1106,12 +1106,12 @@ def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 v_rc = T.axis.reduce(3, rc_0 + rc_1 + rc_2) T.reads(PaddedInput_shared[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc], kernel_shared[v_ry, v_rx, v_rc, v_ff]) T.writes(Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff] = T.float32(0) Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff] = Conv2dOutput_local[v_nn, v_yy, v_xx, v_ff] + PaddedInput_shared[v_nn, v_yy * 2 + v_ry, v_xx * 2 + v_rx, v_rc] * kernel_shared[v_ry, v_rx, v_rc, v_ff] for ax0, ax1, ax2, ax3 in T.grid(1, 7, 1, 16): - with T.block("Conv2dOutput_local"): + with T.sblock("Conv2dOutput_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, nn_1_yy_1_xx_1_ff_1_fused // 2 * 56 + nn_2_yy_2_xx_2_ff_2_fused // 16 * 7 + ax1) v2 = T.axis.spatial(112, nn_0_yy_0_xx_0_ff_0_fused // 2 * 16 + nn_2_yy_2_xx_2_ff_2_fused % 16 + ax2) @@ -1147,10 +1147,10 @@ def test_cuda_tbg(): @T.prim_func def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) C_local = T.alloc_buffer((1, 12, 128, 128), scope="local") query_T_shared = T.alloc_buffer((1, 12, 128, 64), scope="shared") value_T_shared = T.alloc_buffer((1, 12, 64, 128), scope="shared") @@ -1159,27 +1159,27 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, for b_2_h_2_i_2_j_2_fused in T.thread_binding(32, thread="threadIdx.x"): for k_0 in range(8): for ax0_ax1_ax2_ax3_fused in range(12288): - with T.block("query_T_shared"): + with T.sblock("query_T_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 1024) v2 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 1024 // 8) v3 = T.axis.spatial(64, k_0 * 8 + ax0_ax1_ax2_ax3_fused % 8) T.reads(query[v0, v2, v1, v3]) T.writes(query_T_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) query_T_shared[v0, v1, v2, v3] = query[v0, v2, v1, v3] for ax0_ax1_ax2_ax3_fused in range(3072): - with T.block("value_T_shared"): + with T.sblock("value_T_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 256) v2 = T.axis.spatial(64, k_0 * 8 + ax0_ax1_ax2_ax3_fused % 256 // 32) v3 = T.axis.spatial(128, b_0_h_0_i_0_j_0_fused * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(value[v0, v3, v1, v2]) T.writes(value_T_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) value_T_shared[v0, v1, v2, v3] = value[v0, v3, v1, v2] for k_1, b_3, h_3, i_3, j_3, k_2, b_4, h_4, i_4, j_4 in T.grid(4, 1, 2, 1, 1, 2, 1, 1, 4, 1): - with T.block("C"): + with T.sblock("C"): v_b = T.axis.spatial(1, b_3 + b_4) v_h = T.axis.spatial(12, b_1_h_1_i_1_j_1_fused // 32 * 2 + h_3 + h_4) v_i = T.axis.spatial(128, b_1_h_1_i_1_j_1_fused % 32 // 8 * 32 + b_2_h_2_i_2_j_2_fused // 4 * 4 + i_3 * 4 + i_4) @@ -1187,12 +1187,12 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, v_k = T.axis.reduce(64, k_0 * 8 + k_1 * 2 + k_2) T.reads(query_T_shared[v_b, v_h, v_i, v_k], value_T_shared[v_b, v_h, v_k, v_j]) T.writes(C_local[v_b, v_h, v_i, v_j]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): C_local[v_b, v_h, v_i, v_j] = T.float32(0) C_local[v_b, v_h, v_i, v_j] = C_local[v_b, v_h, v_i, v_j] + query_T_shared[v_b, v_h, v_i, v_k] * value_T_shared[v_b, v_h, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 1): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(12, b_1_h_1_i_1_j_1_fused // 32 * 2 + ax1) v2 = T.axis.spatial(128, b_1_h_1_i_1_j_1_fused % 32 // 8 * 32 + b_2_h_2_i_2_j_2_fused // 4 * 4 + ax2) diff --git a/tests/python/meta_schedule/test_meta_schedule_space_cuda_async.py b/tests/python/meta_schedule/test_meta_schedule_space_cuda_async.py index 3386b0102564..f129d8c28468 100644 --- a/tests/python/meta_schedule/test_meta_schedule_space_cuda_async.py +++ b/tests/python/meta_schedule/test_meta_schedule_space_cuda_async.py @@ -45,10 +45,10 @@ def get_c2d_prim_func(stage: int): @T.prim_func def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local") PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") @@ -57,27 +57,27 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"): for rh_0, rw_0, rc_0 in T.grid(1, 1, 3): for ax0_ax1_ax2_ax3_fused in range(693): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33) v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33) v3 = T.axis.spatial(3, rc_0) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(3136): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64) v2 = T.axis.spatial(3, rc_0) v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4) v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4) @@ -87,12 +87,12 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 v_rc = T.axis.reduce(3, rc_0 + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): - with T.block("conv2d_nhwc_local"): + with T.sblock("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1) v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2) @@ -106,10 +106,10 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 @T.prim_func def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) conv2d_nhwc_local = T.alloc_buffer((1, 112, 112, 64), scope="local") PadInput_shared = T.alloc_buffer((1, 230, 230, 3), scope="shared") weight_shared = T.alloc_buffer((7, 7, 3, 64), scope="shared") @@ -118,27 +118,27 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 for n_2_h_2_w_2_co_2_fused in T.thread_binding(64, thread="threadIdx.x"): for rh_0_rw_0_rc_0_fused in T.serial(3, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}): for ax0_ax1_ax2_ax3_fused in range(693): - with T.block("PadInput_shared"): + with T.sblock("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused // 8 * 16 + ax0_ax1_ax2_ax3_fused // 33) v2 = T.axis.spatial(230, n_0_h_0_w_0_co_0_fused % 8 * 28 + ax0_ax1_ax2_ax3_fused % 33) v3 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) PadInput_shared[v0, v1, v2, v3] = T.if_then_else(3 <= v1 and v1 < 227 and 3 <= v2 and v2 < 227, inputs[v0, v1 - 3, v2 - 3, v3], T.float32(0)) for ax0_ax1_ax2_ax3_fused in range(3136): - with T.block("weight_shared"): + with T.sblock("weight_shared"): v0 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused // 448) v1 = T.axis.spatial(7, ax0_ax1_ax2_ax3_fused % 448 // 64) v2 = T.axis.spatial(3, rh_0_rw_0_rc_0_fused) v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_shared[v0, v1, v2, v3]) - T.block_attr({"meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 3}) weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for rh_1, rw_1, rc_1, n_3, h_3, w_3, co_3, rh_2, rw_2, rc_2, n_4, h_4, w_4, co_4 in T.grid(7, 1, 1, 1, 1, 14, 1, 1, 7, 1, 1, 1, 1, 1): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_n = T.axis.spatial(1, n_3 + n_4) v_h = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + h_3 + h_4) v_w = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + w_3 + w_4) @@ -148,12 +148,12 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 v_rc = T.axis.reduce(3, rh_0_rw_0_rc_0_fused + rc_1 + rc_2) T.reads(PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc], weight_shared[v_rh, v_rw, v_rc, v_co]) T.writes(conv2d_nhwc_local[v_n, v_h, v_w, v_co]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): conv2d_nhwc_local[v_n, v_h, v_w, v_co] = T.float32(0) conv2d_nhwc_local[v_n, v_h, v_w, v_co] = conv2d_nhwc_local[v_n, v_h, v_w, v_co] + PadInput_shared[v_n, v_h * 2 + v_rh, v_w * 2 + v_rw, v_co // 64 * 3 + v_rc] * weight_shared[v_rh, v_rw, v_rc, v_co] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): - with T.block("conv2d_nhwc_local"): + with T.sblock("conv2d_nhwc_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused // 8 * 8 + n_1_h_1_w_1_co_1_fused // 4 * 4 + n_2_h_2_w_2_co_2_fused // 16 + ax1) v2 = T.axis.spatial(112, n_0_h_0_w_0_co_0_fused % 8 * 14 + ax2) @@ -199,10 +199,10 @@ def get_gmm_prim_func(stage: int): @T.prim_func def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32"), Z: T.Buffer((1, 1024, 1024), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) Z_local = T.alloc_buffer((1, 1024, 1024), scope="local") X_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") Y_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") @@ -211,37 +211,37 @@ def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "f for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"): for k_0 in range(64): for ax0_ax1_ax2_fused in range(1024): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16) v2 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused % 16) T.reads(X[v0, v1, v2]) T.writes(X_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused in range(1024): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, k_0 * 16 + ax0_ax1_ax2_fused // 64) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64) T.reads(Y[v0, v1, v2]) T.writes(Y_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) Y_shared[v0, v1, v2] = Y[v0, v1, v2] for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2): - with T.block("Z"): + with T.sblock("Z"): v_b = T.axis.spatial(1, b_3 + b_4) v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4) v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4) v_k = T.axis.reduce(1024, k_0 * 16 + k_1 * 8 + k_2) T.reads(X_shared[v_b, v_i, v_k], Y_shared[v_b, v_k, v_j]) T.writes(Z_local[v_b, v_i, v_j]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): Z_local[v_b, v_i, v_j] = T.float32(0) Z_local[v_b, v_i, v_j] = Z_local[v_b, v_i, v_j] + X_shared[v_b, v_i, v_k] * Y_shared[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 1, 2): - with T.block("Z_local"): + with T.sblock("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2) @@ -254,10 +254,10 @@ def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "f @T.prim_func def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32"), Z: T.Buffer((1, 1024, 1024), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 16}) + T.sblock_attr({"meta_schedule.unroll_explicit": 16}) Z_local = T.alloc_buffer((1, 1024, 1024), scope="local") X_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") Y_shared = T.alloc_buffer((1, 1024, 1024), scope="shared") @@ -266,37 +266,37 @@ def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "f for b_2_i_2_j_2_fused in T.thread_binding(64, thread="threadIdx.x"): for k_0_fused in T.serial(64, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, stage - 2]}): for ax0_ax1_ax2_fused in range(1024): - with T.block("X_shared"): + with T.sblock("X_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + ax0_ax1_ax2_fused // 16) v2 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused % 16) T.reads(X[v0, v1, v2]) T.writes(X_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) X_shared[v0, v1, v2] = X[v0, v1, v2] for ax0_ax1_ax2_fused in range(1024): - with T.block("Y_shared"): + with T.sblock("Y_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1024, k_0_fused * 16 + ax0_ax1_ax2_fused // 64) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + ax0_ax1_ax2_fused % 64) T.reads(Y[v0, v1, v2]) T.writes(Y_shared[v0, v1, v2]) - T.block_attr({"meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"meta_schedule.cooperative_fetch": 4}) Y_shared[v0, v1, v2] = Y[v0, v1, v2] for k_1, b_3, i_3, j_3, k_2, b_4, i_4, j_4 in T.grid(2, 1, 1, 1, 8, 1, 1, 2): - with T.block("Z"): + with T.sblock("Z"): v_b = T.axis.spatial(1, b_3 + b_4) v_i = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + i_3 + i_4) v_j = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + j_3 * 2 + j_4) v_k = T.axis.reduce(1024, k_0_fused * 16 + k_1 * 8 + k_2) T.reads(X_shared[v_b, v_i, v_k], Y_shared[v_b, v_k, v_j]) T.writes(Z_local[v_b, v_i, v_j]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) with T.init(): Z_local[v_b, v_i, v_j] = T.float32(0) Z_local[v_b, v_i, v_j] = Z_local[v_b, v_i, v_j] + X_shared[v_b, v_i, v_k] * Y_shared[v_b, v_k, v_j] for ax0, ax1, ax2 in T.grid(1, 1, 2): - with T.block("Z_local"): + with T.sblock("Z_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(1024, b_0_i_0_j_0_fused // 16 * 64 + b_1_i_1_j_1_fused // 4 * 8 + b_2_i_2_j_2_fused // 8 + ax1) v2 = T.axis.spatial(1024, b_0_i_0_j_0_fused % 16 * 64 + b_1_i_1_j_1_fused % 4 * 16 + b_2_i_2_j_2_fused % 8 * 2 + ax2) diff --git a/tests/python/meta_schedule/test_meta_schedule_space_generator.py b/tests/python/meta_schedule/test_meta_schedule_space_generator.py index 9457a9a40f00..100c5255afd5 100644 --- a/tests/python/meta_schedule/test_meta_schedule_space_generator.py +++ b/tests/python/meta_schedule/test_meta_schedule_space_generator.py @@ -45,7 +45,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -56,7 +56,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: def schedule_matmul(sch: Schedule): - block = sch.get_block("matmul") + block = sch.get_sblock("matmul") i, j, k = sch.get_loops(block=block) # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) diff --git a/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py b/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py index 4cb5ec59630c..dd0215b5586e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py +++ b/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py @@ -38,7 +38,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_task_scheduler.py b/tests/python/meta_schedule/test_meta_schedule_task_scheduler.py index ab0e3f0123dd..95edc36ace9a 100644 --- a/tests/python/meta_schedule/test_meta_schedule_task_scheduler.py +++ b/tests/python/meta_schedule/test_meta_schedule_task_scheduler.py @@ -44,7 +44,7 @@ def main( # type: ignore B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 # type: ignore @@ -65,13 +65,13 @@ def main( # type: ignore D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 # type: ignore C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(1024, 1024): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(C[vi, vj], 0.0) # type: ignore @@ -89,7 +89,7 @@ def main( # type: ignore B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) for n, i, j, k in T.grid(16, 128, 128, 128): - with T.block("matmul"): + with T.sblock("matmul"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) with T.init(): C[vn, vi, vj] = 0.0 # type: ignore @@ -100,7 +100,7 @@ def main( # type: ignore def _schedule_matmul(sch: Schedule): - block = sch.get_block("matmul") + block = sch.get_sblock("matmul") i, j, k = sch.get_loops(block=block) i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) @@ -109,7 +109,7 @@ def _schedule_matmul(sch: Schedule): def _schedule_batch_matmul(sch: Schedule): - block = sch.get_block("matmul") + block = sch.get_sblock("matmul") i, j, k, t = sch.get_loops(block=block) i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2]) j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2]) diff --git a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py index 637f3093d8e1..f42c05e2790e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py @@ -38,13 +38,13 @@ def main( # function attr dict T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(p0[i, k], p1[j, k]) T.writes(T_matmul_NT[i, j]) - T.block_attr({"layout_free_placeholders": []}) + T.sblock_attr({"layout_free_placeholders": []}) with T.init(): T_matmul_NT[i, j] = T.float32(0) T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] @@ -61,25 +61,25 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body - # with T.block("root") + # with T.sblock("root") T_matmul_NT = T.alloc_buffer([128, 128], dtype="float32") compile_engine_const = T.alloc_buffer([], dtype="float32") for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(p0[i, k], p1[j, k]) T.writes(T_matmul_NT[i, j]) - T.block_attr({"layout_free_placeholders": []}) + T.sblock_attr({"layout_free_placeholders": []}) with T.init(): T_matmul_NT[i, j] = T.float32(0) T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1[j, k] - with T.block("compile_engine_const"): + with T.sblock("compile_engine_const"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const[()]) compile_engine_const[()] = T.float32(1) for i0, i1 in T.grid(128, 128): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(T_matmul_NT[ax0, ax1], compile_engine_const[()]) T.writes(T_add[ax0, ax1]) @@ -97,21 +97,21 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body - # with T.block("root") + # with T.sblock("root") T_matmul_NT_global = T.alloc_buffer([128, 128], dtype="float32") p1_global = T.alloc_buffer([2, 128, 64], dtype="float32") for ax0, ax1 in T.grid(128, 128): - with T.block("p1_global"): + with T.sblock("p1_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(p1[v0, v1]) T.writes(p1_global[v0 // 64, v1, v0 % 64]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) p1_global[v0 // 64, v1, v0 % 64] = p1[v0, v1] for i0_0_i1_0_fused_fused in T.parallel(4): for i0_1, i1_1 in T.grid(8, 1): for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 1, 2): for i1_3_fused_init in T.vectorized(64): - with T.block("T_matmul_NT_init"): + with T.sblock("T_matmul_NT_init"): i = T.axis.spatial( 128, i0_0_i1_0_fused_fused // 2 * 64 @@ -128,7 +128,7 @@ def main( ) T.reads() T.writes(T_matmul_NT_global[i, j]) - T.block_attr( + T.sblock_attr( { "layout_free_placeholders": [], "meta_schedule.tiling_structure": "SSRSRS", @@ -137,7 +137,7 @@ def main( T_matmul_NT_global[i, j] = T.float32(0) for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(128, 4, 1, 1, 2): for i1_3_fused in T.vectorized(64): - with T.block("T_matmul_NT_update"): + with T.sblock("T_matmul_NT_update"): i = T.axis.spatial( 128, i0_0_i1_0_fused_fused // 2 * 64 + i0_1 * 8 + i0_2 * 2 + i0_3 ) @@ -150,7 +150,7 @@ def main( T_matmul_NT_global[i, j], p0[i, k], p1_global[j // 64, k, j % 64] ) T.writes(T_matmul_NT_global[i, j]) - T.block_attr( + T.sblock_attr( { "layout_free_placeholders": [], "meta_schedule.tiling_structure": "SSRSRS", @@ -161,7 +161,7 @@ def main( ) for ax0 in T.serial(64): for ax1_fused in T.vectorized(64): - with T.block("T_matmul_NT_global"): + with T.sblock("T_matmul_NT_global"): v0 = T.axis.spatial(128, i0_0_i1_0_fused_fused // 2 * 64 + ax0) v1 = T.axis.spatial(128, i0_0_i1_0_fused_fused % 2 * 64 + ax1_fused) T.reads(T_matmul_NT_global[v0, v1]) @@ -176,38 +176,38 @@ def main(p0: T.Buffer((128, 128), "float32"), p1: T.Buffer((128, 128), "float32" # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body - # with T.block("root") + # with T.sblock("root") T_matmul_NT = T.alloc_buffer([128, 128], dtype="float32") p1_global = T.alloc_buffer([8, 4, 16, 32], dtype="float32") for ax0, ax1 in T.grid(128, 128): - with T.block("p1_global"): + with T.sblock("p1_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(p1[v0, v1]) T.writes(p1_global[v1 // 16, v0 // 32, v1 % 16, v0 % 32]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) p1_global[v1 // 16, v0 // 32, v1 % 16, v0 % 32] = p1[v0, v1] for i0_0_i1_0_i0_1_i1_1_fused in T.parallel(16, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): for i0_2_init, i1_2_init, i0_3_init in T.grid(4, 4, 2): for i1_3_fused_init in T.vectorized(32): - with T.block("T_matmul_NT_init"): + with T.sblock("T_matmul_NT_init"): i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused * 8 + i0_2_init * 2 + i0_3_init) j = T.axis.spatial(128, i1_2_init * 32 + i1_3_fused_init) T.reads() T.writes(T_matmul_NT[i, j]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) T_matmul_NT[i, j] = T.float32(0) for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(8, 4, 4, 16, 2): for i1_3_fused in T.vectorized(32): - with T.block("T_matmul_NT_update"): + with T.sblock("T_matmul_NT_update"): i = T.axis.spatial(128, i0_0_i1_0_i0_1_i1_1_fused * 8 + i0_2 * 2 + i0_3) j = T.axis.spatial(128, i1_2 * 32 + i1_3_fused) k = T.axis.reduce(128, i2_0 * 16 + i2_1) T.reads(T_matmul_NT[i, j], p0[i, k], p1_global[k // 16, j // 32, k % 16, j % 32]) T.writes(T_matmul_NT[i, j]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) + T.sblock_attr({"layout_free_placeholders":[], "meta_schedule.tiling_structure":"SSRSRS"}) T_matmul_NT[i, j] = T_matmul_NT[i, j] + p0[i, k] * p1_global[k // 16, j // 32, k % 16, j % 32] for i0_i1_fused in T.parallel(16384): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial(128, i0_i1_fused // 128) ax1 = T.axis.spatial(128, i0_i1_fused % 128) T.reads(T_matmul_NT[ax0, ax1]) @@ -226,7 +226,7 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body - # with T.block("root") + # with T.sblock("root") T_matmul_NT_local = T.alloc_buffer([128, 128], dtype="float32", scope="local") p0_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") p1_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") @@ -238,7 +238,7 @@ def main( for i0_1_i1_1_fused in T.thread_binding(1, thread="vthread.x"): for i0_2_i1_2_fused in T.thread_binding(128, thread="threadIdx.x"): for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(1, 4, 1, 1): - with T.block("T_matmul_NT_init"): + with T.sblock("T_matmul_NT_init"): i = T.axis.spatial( 128, i0_0_i1_0_fused // 4 * 16 @@ -255,7 +255,7 @@ def main( ) T.reads() T.writes(T_matmul_NT_local[i, j]) - T.block_attr( + T.sblock_attr( { "layout_free_placeholders": [], "meta_schedule.thread_extent_high_inclusive": 256, @@ -268,7 +268,7 @@ def main( for ax0_ax1_fused_0 in T.serial(1): for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(2): - with T.block("p0_shared"): + with T.sblock("p0_shared"): T.where( (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 2 + ax0_ax1_fused_2 @@ -300,7 +300,7 @@ def main( for ax0_ax1_fused_0 in T.serial(1): for ax0_ax1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("p1_shared"): + with T.sblock("p1_shared"): T.where( (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2 @@ -330,7 +330,7 @@ def main( T.writes(p1_shared[v0, v1]) p1_shared[v0, v1] = p1[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(1, 1, 4, 4, 1, 1): - with T.block("T_matmul_NT_update"): + with T.sblock("T_matmul_NT_update"): i = T.axis.spatial( 128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + i0_3 + i0_4, @@ -345,7 +345,7 @@ def main( k = T.axis.reduce(128, i2_0 * 4 + i2_1 * 4 + i2_2) T.reads(T_matmul_NT_local[i, j], p0_shared[i, k], p1_shared[j, k]) T.writes(T_matmul_NT_local[i, j]) - T.block_attr( + T.sblock_attr( { "layout_free_placeholders": [], "meta_schedule.thread_extent_high_inclusive": 256, @@ -357,7 +357,7 @@ def main( T_matmul_NT_local[i, j] + p0_shared[i, k] * p1_shared[j, k] ) for ax0, ax1 in T.grid(1, 4): - with T.block("T_matmul_NT_local"): + with T.sblock("T_matmul_NT_local"): v0 = T.axis.spatial( 128, i0_0_i1_0_fused // 4 * 16 + i0_2_i1_2_fused // 8 + ax0 ) @@ -376,7 +376,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") @@ -392,13 +392,13 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " T_cast_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_subtract_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) T.writes(conv2d_nhwc[nn, yy, xx, ff]) @@ -406,79 +406,79 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " conv2d_nhwc[nn, yy, xx, ff] = 0 conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_cast"): + with T.sblock("T_cast"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[ax0, ax1, ax2, ax3]) T.writes(T_cast[ax0, ax1, ax2, ax3]) T_cast[ax0, ax1, ax2, ax3] = T.cast(T_add[ax0, ax1, ax2, ax3], "int64") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_cast[ax0, ax1, ax2, ax3], p4[0, 0, 0, ax3]) T.writes(T_multiply[ax0, ax1, ax2, ax3]) T_multiply[ax0, ax1, ax2, ax3] = T_cast[ax0, ax1, ax2, ax3] * p4[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply[ax0, ax1, ax2, ax3], p5[0, 0, 0, ax3]) T.writes(T_add_1[ax0, ax1, ax2, ax3]) T_add_1[ax0, ax1, ax2, ax3] = T_multiply[ax0, ax1, ax2, ax3] + p5[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_right_shift"): + with T.sblock("T_right_shift"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3]) T.writes(T_right_shift[ax0, ax1, ax2, ax3]) T_right_shift[ax0, ax1, ax2, ax3] = T.shift_right(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3], dtype="int64") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_right_shift[ax0, ax1, ax2, ax3]) T.writes(T_cast_1[ax0, ax1, ax2, ax3]) T_cast_1[ax0, ax1, ax2, ax3] = T.cast(T_right_shift[ax0, ax1, ax2, ax3], "int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p7[()], T_cast_1[ax0, ax1, ax2, ax3]) T.writes(T_add_2[ax0, ax1, ax2, ax3]) T_add_2[ax0, ax1, ax2, ax3] = p7[()] + T_cast_1[ax0, ax1, ax2, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("compute"): + with T.sblock("compute"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add_2[i0_2, i1_2, i2_2, i3_2]) T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) compute_1[i0_2, i1_2, i2_2, i3_2] = T.max(T.min(T_add_2[i0_2, i1_2, i2_2, i3_2], 255), 0) for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) T.reads(compute_1[ax0, ax1, ax2, ax3]) T.writes(T_cast_2[ax0, ax1, ax2, ax3]) T_cast_2[ax0, ax1, ax2, ax3] = T.cast(compute_1[ax0, ax1, ax2, ax3], "uint8") for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): - with T.block("T_cast_3"): + with T.sblock("T_cast_3"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) T.reads(T_cast_2[ax0, ax1, ax2, ax3]) T.writes(T_cast_3[ax0, ax1, ax2, ax3]) T_cast_3[ax0, ax1, ax2, ax3] = T.cast(T_cast_2[ax0, ax1, ax2, ax3], "int32") for i0_5, i1_5, i2_5, i3_5 in T.grid(16, 56, 56, 256): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) T.reads(T_cast_3[ax0, ax1, ax2, ax3], p8[0]) T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) T_subtract_1[ax0, ax1, ax2, ax3] = T_cast_3[ax0, ax1, ax2, ax3] - p8[0] for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_7, i1_7, i2_7, i3_7 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) T.reads(T_subtract_1[i0_7, i1_7, i2_7, i3_7]) T.writes(compute[i0_7, i1_7, i2_7, i3_7]) @@ -492,7 +492,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") @@ -512,13 +512,13 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_cast_4 = T.alloc_buffer([16, 56, 56, 256], dtype="uint8") for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) T.writes(conv2d_nhwc[nn, yy, xx, ff]) @@ -526,103 +526,103 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " conv2d_nhwc[nn, yy, xx, ff] = 0 conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_cast"): + with T.sblock("T_cast"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[ax0, ax1, ax2, ax3]) T.writes(T_cast[ax0, ax1, ax2, ax3]) T_cast[ax0, ax1, ax2, ax3] = T.cast(T_add[ax0, ax1, ax2, ax3], "int64") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_cast[ax0, ax1, ax2, ax3], p4[0, 0, 0, ax3]) T.writes(T_multiply[ax0, ax1, ax2, ax3]) T_multiply[ax0, ax1, ax2, ax3] = T_cast[ax0, ax1, ax2, ax3] * p4[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply[ax0, ax1, ax2, ax3], p5[0, 0, 0, ax3]) T.writes(T_add_1[ax0, ax1, ax2, ax3]) T_add_1[ax0, ax1, ax2, ax3] = T_multiply[ax0, ax1, ax2, ax3] + p5[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_right_shift"): + with T.sblock("T_right_shift"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3]) T.writes(T_right_shift[ax0, ax1, ax2, ax3]) T_right_shift[ax0, ax1, ax2, ax3] = T.shift_right(T_add_1[ax0, ax1, ax2, ax3], p6[0, 0, 0, ax3], dtype="int64") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_right_shift[ax0, ax1, ax2, ax3]) T.writes(T_cast_1[ax0, ax1, ax2, ax3]) T_cast_1[ax0, ax1, ax2, ax3] = T.cast(T_right_shift[ax0, ax1, ax2, ax3], "int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p7[()], T_cast_1[ax0, ax1, ax2, ax3]) T.writes(T_add_2[ax0, ax1, ax2, ax3]) T_add_2[ax0, ax1, ax2, ax3] = p7[()] + T_cast_1[ax0, ax1, ax2, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("compute"): + with T.sblock("compute"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add_2[i0_2, i1_2, i2_2, i3_2]) T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) compute_1[i0_2, i1_2, i2_2, i3_2] = T.max(T.min(T_add_2[i0_2, i1_2, i2_2, i3_2], 255), 0) for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) T.reads(compute_1[ax0, ax1, ax2, ax3]) T.writes(T_cast_2[ax0, ax1, ax2, ax3]) T_cast_2[ax0, ax1, ax2, ax3] = T.cast(compute_1[ax0, ax1, ax2, ax3], "uint8") for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): - with T.block("T_cast_3"): + with T.sblock("T_cast_3"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) T.reads(T_cast_2[ax0, ax1, ax2, ax3]) T.writes(T_cast_3[ax0, ax1, ax2, ax3]) T_cast_3[ax0, ax1, ax2, ax3] = T.cast(T_cast_2[ax0, ax1, ax2, ax3], "int32") for i0_5, i1_5, i2_5, i3_5 in T.grid(16, 56, 56, 256): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) T.reads(T_cast_3[ax0, ax1, ax2, ax3], p8[0]) T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) T_subtract_1[ax0, ax1, ax2, ax3] = T_cast_3[ax0, ax1, ax2, ax3] - p8[0] for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_7, i1_7, i2_7, i3_7 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) T.reads(T_subtract_1[i0_7, i1_7, i2_7, i3_7]) T.writes(compute_2[i0_7, i1_7, i2_7, i3_7]) compute_2[i0_7, i1_7, i2_7, i3_7] = T.q_multiply_shift(T_subtract_1[i0_7, i1_7, i2_7, i3_7], 1098990753, 31, 1, dtype="int32") for i0_8, i1_8, i2_8, i3_8 in T.grid(16, 56, 56, 256): - with T.block("T_add_3"): + with T.sblock("T_add_3"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_8, i1_8, i2_8, i3_8]) T.reads(compute_2[ax0, ax1, ax2, ax3], p9[ax0, ax1, ax2, ax3]) T.writes(T_add_3[ax0, ax1, ax2, ax3]) T_add_3[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] + p9[ax0, ax1, ax2, ax3] for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 56, 56, 256): - with T.block("compute_2"): + with T.sblock("compute_2"): i0_10, i1_10, i2_10, i3_10 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) T.reads(T_add_3[i0_10, i1_10, i2_10, i3_10]) T.writes(compute_3[i0_10, i1_10, i2_10, i3_10]) compute_3[i0_10, i1_10, i2_10, i3_10] = T.max(T.min(T_add_3[i0_10, i1_10, i2_10, i3_10], 255), 0) for i0_11, i1_11, i2_11, i3_11 in T.grid(16, 56, 56, 256): - with T.block("T_cast_4"): + with T.sblock("T_cast_4"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_11, i1_11, i2_11, i3_11]) T.reads(compute_3[ax0, ax1, ax2, ax3]) T.writes(T_cast_4[ax0, ax1, ax2, ax3]) T_cast_4[ax0, ax1, ax2, ax3] = T.cast(compute_3[ax0, ax1, ax2, ax3], "uint8") for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): - with T.block("compute_3"): + with T.sblock("compute_3"): i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) T.reads(T_cast_4[i0_13, i1_13, i2_13, i3_13]) T.writes(compute[i0_13, i1_13, i2_13, i3_13]) @@ -634,7 +634,7 @@ class Conv2dInt8_tensorcore_scheduled: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "uint8")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") @@ -645,42 +645,42 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="vthread.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(16, thread="threadIdx.x"): for ax0_0_init, ax1_0_init, ax0_1_init, ax1_1_init, ax2_0_3_init, ax3_0_3_init, ax0_2_init, ax1_2_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1, 1, 1, 1, 1, 1, 1): - with T.block("conv2d_nhwc_o_init"): + with T.sblock("conv2d_nhwc_o_init"): v0_o = T.axis.spatial(1, ax0_0_init + ax0_1_init + ax0_2_init) v1_o = T.axis.spatial(1, ax1_0_init + ax1_1_init + ax1_2_init) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init + ax3_0_4_init) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): for ax0_ax1_fused_0 in range(16): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(16): - with T.block("pad_temp_reindex_shared"): + with T.sblock("pad_temp_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32) v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32) T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 16]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 16]]}) pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] for ax0_ax1_ax2_ax3_fused_0 in range(8): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8): - with T.block("p1_reindex_shared"): + with T.sblock("p1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32) v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32) T.reads(p1[v2, v0, v1, v3]) T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align": [[0, 2, 32, 16]]}) + T.sblock_attr({"buffer_dim_align": [[0, 2, 32, 16]]}) p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): for ax0_0_1, ax1_0_1 in T.grid(1, 2): - with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) @@ -689,7 +689,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " C = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2): - with T.block("p1_reindex_shared_wmma.matrix_b_o"): + with T.sblock("p1_reindex_shared_wmma.matrix_b_o"): v0_o, v1_o = T.axis.remap("SS", [ax0, ax1]) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax2_0) v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0) @@ -699,7 +699,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " C = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): - with T.block("conv2d_nhwc_o_update"): + with T.sblock("conv2d_nhwc_o_update"): v0_o = T.axis.spatial(1, ax0_0 + ax0_1 + ax0_2) v1_o = T.axis.spatial(1, ax1_0 + ax1_1 + ax1_2) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) @@ -707,13 +707,13 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + T.sblock_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) A = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) @@ -723,7 +723,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") for ax0, ax1_0 in T.grid(128, 2): for ax1_1 in T.thread_binding(16, thread="threadIdx.x"): - with T.block("conv2d_nhwc_reindex_shared"): + with T.sblock("conv2d_nhwc_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + ax0) v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + ax1_0 * 16 + ax1_1) T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[0, 0, 0, v1], p5[0, 0, 0, v1], p6[0, 0, 0, v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) @@ -737,7 +737,7 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") compile_engine_const = T.alloc_buffer([], dtype="float32") conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") T_add = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") @@ -759,134 +759,134 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, T_add_3 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") compute_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") T_cast_5 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") - with T.block("compile_engine_const"): + with T.sblock("compile_engine_const"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const[()]) compile_engine_const[()] = T.float32(0.94537687301635742) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 128, 7, 7, 16, 1, 1, 32, 4, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) T.reads(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) - T.block_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) + T.sblock_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) with T.init(): conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4], p2[ax0, ax1, 0, 0, ax4]) T.writes(T_add[ax0, ax1, ax2, ax3, ax4]) T_add[ax0, ax1, ax2, ax3, ax4] = conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4] + p2[ax0, ax1, 0, 0, ax4] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast"): + with T.sblock("T_cast"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast[ax0, ax1, ax2, ax3, ax4]) T_cast[ax0, ax1, ax2, ax3, ax4] = T.cast(T_add[ax0, ax1, ax2, ax3, ax4], "float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast[ax0, ax1, ax2, ax3, ax4], p3[ax0, ax1, 0, 0, ax4]) T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) T_multiply[ax0, ax1, ax2, ax3, ax4] = T_cast[ax0, ax1, ax2, ax3, ax4] * p3[ax0, ax1, 0, 0, ax4] - with T.block("compile_engine_const_1"): + with T.sblock("compile_engine_const_1"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_1[()]) compile_engine_const_1[()] = T.float32(54.5) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_multiply[ax0, ax1, ax2, ax3, ax4], compile_engine_const_1[()]) T.writes(T_add_1[ax0, ax1, ax2, ax3, ax4]) T_add_1[ax0, ax1, ax2, ax3, ax4] = T_multiply[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_1[()] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_floor"): + with T.sblock("T_floor"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_floor[ax0, ax1, ax2, ax3, ax4]) T_floor[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_floor[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_1[ax0, ax1, ax2, ax3, ax4]) T_cast_1[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor[ax0, ax1, ax2, ax3, ax4], "int32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_1[i0_1, i1_1, i2_1, i3_1, i4_1]) T.writes(compute_1[i0_1, i1_1, i2_1, i3_1, i4_1]) compute_1[i0_1, i1_1, i2_1, i3_1, i4_1] = T.max(T.min(T_cast_1[i0_1, i1_1, i2_1, i3_1, i4_1], 255), 0) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(compute_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_2[ax0, ax1, ax2, ax3, ax4]) T_cast_2[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_1[ax0, ax1, ax2, ax3, ax4], "uint8") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_3"): + with T.sblock("T_cast_3"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_3[ax0, ax1, ax2, ax3, ax4]) T_cast_3[ax0, ax1, ax2, ax3, ax4] = T.cast(T_cast_2[ax0, ax1, ax2, ax3, ax4], "float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_3[ax0, ax1, ax2, ax3, ax4], p4[0]) T.writes(T_subtract[ax0, ax1, ax2, ax3, ax4]) T_subtract[ax0, ax1, ax2, ax3, ax4] = T_cast_3[ax0, ax1, ax2, ax3, ax4] - p4[0] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(compile_engine_const[()], T_subtract[ax0, ax1, ax2, ax3, ax4]) T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) T_multiply_1[ax0, ax1, ax2, ax3, ax4] = compile_engine_const[()] * T_subtract[ax0, ax1, ax2, ax3, ax4] - with T.block("compile_engine_const_2"): + with T.sblock("compile_engine_const_2"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_2[()]) compile_engine_const_2[()] = T.float32(0.5) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_multiply_1[ax0, ax1, ax2, ax3, ax4], compile_engine_const_2[()]) T.writes(T_add_2[ax0, ax1, ax2, ax3, ax4]) T_add_2[ax0, ax1, ax2, ax3, ax4] = T_multiply_1[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_2[()] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_floor_1"): + with T.sblock("T_floor_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_floor_1[ax0, ax1, ax2, ax3, ax4]) T_floor_1[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_4"): + with T.sblock("T_cast_4"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_floor_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_4[ax0, ax1, ax2, ax3, ax4]) T_cast_4[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_1[ax0, ax1, ax2, ax3, ax4], "int32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_3"): + with T.sblock("T_add_3"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_4[ax0, ax1, ax2, ax3, ax4], p5[ax0, ax1, ax2, ax3, ax4]) T.writes(T_add_3[ax0, ax1, ax2, ax3, ax4]) T_add_3[ax0, ax1, ax2, ax3, ax4] = T_cast_4[ax0, ax1, ax2, ax3, ax4] + p5[ax0, ax1, ax2, ax3, ax4] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_2, i1_2, i2_2, i3_2, i4_2 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_3[i0_2, i1_2, i2_2, i3_2, i4_2]) T.writes(compute_2[i0_2, i1_2, i2_2, i3_2, i4_2]) compute_2[i0_2, i1_2, i2_2, i3_2, i4_2] = T.max(T.min(T_add_3[i0_2, i1_2, i2_2, i3_2, i4_2], 255), 0) for i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_5"): + with T.sblock("T_cast_5"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_3, i1_3, i2_3, i3_3, i4_3]) T.reads(compute_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_5[ax0, ax1, ax2, ax3, ax4]) T_cast_5[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_2[ax0, ax1, ax2, ax3, ax4], "uint8") for i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(1, 128, 7, 7, 16): - with T.block("compute_2"): + with T.sblock("compute_2"): i0_5, i1_5, i2_5, i3_5, i4_5 = T.axis.remap("SSSSS", [i0_4, i1_4, i2_4, i3_4, i4_4]) T.reads(T_cast_5[i0_5, i1_5, i2_5, i3_5, i4_5]) T.writes(compute[i0_5, i1_5, i2_5, i3_5, i4_5]) @@ -900,7 +900,7 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") compile_engine_const = T.alloc_buffer([], dtype="float32") conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") T_add = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") @@ -930,180 +930,180 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, compute_1 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") T_cast_8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") compute_2 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="uint8") - with T.block("compile_engine_const"): + with T.sblock("compile_engine_const"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const[()]) compile_engine_const[()] = T.float32(0.95489668846130371) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 128, 7, 7, 16, 1, 1, 32, 4, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) T.reads(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) - T.block_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) + T.sblock_attr({"schedule_rule":"meta_schedule.conv2d_NCHWc_int8", "workload":["conv2d_NCHWc_int8.x86", ["TENSOR", [1, 32, 7, 7, 16], "uint8"], ["TENSOR", [128, 32, 1, 1, 4, 16, 4], "int8"], [1, 1], [0, 0, 0, 0], [1, 1], "NCHW16c", "NCHW16c", "int32"]}) with T.init(): conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4], p2[ax0, ax1, 0, 0, ax4]) T.writes(T_add[ax0, ax1, ax2, ax3, ax4]) T_add[ax0, ax1, ax2, ax3, ax4] = conv2d_NCHWc_int8[ax0, ax1, ax2, ax3, ax4] + p2[ax0, ax1, 0, 0, ax4] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast"): + with T.sblock("T_cast"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_1[ax0, ax1, ax2, ax3, ax4]) T_cast_1[ax0, ax1, ax2, ax3, ax4] = T.cast(T_add[ax0, ax1, ax2, ax3, ax4], "float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_1[ax0, ax1, ax2, ax3, ax4], p3[ax0, ax1, 0, 0, ax4]) T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) T_multiply[ax0, ax1, ax2, ax3, ax4] = T_cast_1[ax0, ax1, ax2, ax3, ax4] * p3[ax0, ax1, 0, 0, ax4] - with T.block("compile_engine_const_1"): + with T.sblock("compile_engine_const_1"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_1[()]) compile_engine_const_1[()] = T.float32(65.5) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_multiply[ax0, ax1, ax2, ax3, ax4], compile_engine_const_1[()]) T.writes(T_add_1[ax0, ax1, ax2, ax3, ax4]) T_add_1[ax0, ax1, ax2, ax3, ax4] = T_multiply[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_1[()] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_floor"): + with T.sblock("T_floor"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_floor[ax0, ax1, ax2, ax3, ax4]) T_floor[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_floor[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_2[ax0, ax1, ax2, ax3, ax4]) T_cast_2[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor[ax0, ax1, ax2, ax3, ax4], "int32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_2[i0_1, i1_1, i2_1, i3_1, i4_1]) T.writes(compute[i0_1, i1_1, i2_1, i3_1, i4_1]) compute[i0_1, i1_1, i2_1, i3_1, i4_1] = T.max(T.min(T_cast_2[i0_1, i1_1, i2_1, i3_1, i4_1], 255), 0) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(compute[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_3[ax0, ax1, ax2, ax3, ax4]) T_cast_3[ax0, ax1, ax2, ax3, ax4] = T.cast(compute[ax0, ax1, ax2, ax3, ax4], "uint8") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_3"): + with T.sblock("T_cast_3"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_3[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_4[ax0, ax1, ax2, ax3, ax4]) T_cast_4[ax0, ax1, ax2, ax3, ax4] = T.cast(T_cast_3[ax0, ax1, ax2, ax3, ax4], "float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_4[ax0, ax1, ax2, ax3, ax4], p4[0]) T.writes(T_subtract[ax0, ax1, ax2, ax3, ax4]) T_subtract[ax0, ax1, ax2, ax3, ax4] = T_cast_4[ax0, ax1, ax2, ax3, ax4] - p4[0] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(compile_engine_const[()], T_subtract[ax0, ax1, ax2, ax3, ax4]) T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) T_multiply_1[ax0, ax1, ax2, ax3, ax4] = compile_engine_const[()] * T_subtract[ax0, ax1, ax2, ax3, ax4] - with T.block("compile_engine_const_2"): + with T.sblock("compile_engine_const_2"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_2[()]) compile_engine_const_2[()] = T.float32(0.5) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_multiply_1[ax0, ax1, ax2, ax3, ax4], compile_engine_const_2[()]) T.writes(T_add_2[ax0, ax1, ax2, ax3, ax4]) T_add_2[ax0, ax1, ax2, ax3, ax4] = T_multiply_1[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_2[()] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_floor_1"): + with T.sblock("T_floor_1"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_floor_1[ax0, ax1, ax2, ax3, ax4]) T_floor_1[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_4"): + with T.sblock("T_cast_4"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_floor_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_5[ax0, ax1, ax2, ax3, ax4]) T_cast_5[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_1[ax0, ax1, ax2, ax3, ax4], "int32") - with T.block("compile_engine_const_3"): + with T.sblock("compile_engine_const_3"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_3[()]) compile_engine_const_3[()] = T.float32(0.71245479583740234) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_5"): + with T.sblock("T_cast_5"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(p5[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_6[ax0, ax1, ax2, ax3, ax4]) T_cast_6[ax0, ax1, ax2, ax3, ax4] = T.cast(p5[ax0, ax1, ax2, ax3, ax4], "float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(compile_engine_const_3[()], T_cast_6[ax0, ax1, ax2, ax3, ax4]) T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4]) T_multiply_2[ax0, ax1, ax2, ax3, ax4] = compile_engine_const_3[()] * T_cast_6[ax0, ax1, ax2, ax3, ax4] - with T.block("compile_engine_const_4"): + with T.sblock("compile_engine_const_4"): vi = T.axis.spatial(1, 0) T.reads() T.writes(compile_engine_const_4[()]) compile_engine_const_4[()] = T.float32(0.5) for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_3"): + with T.sblock("T_add_3"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_multiply_2[ax0, ax1, ax2, ax3, ax4], compile_engine_const_4[()]) T.writes(T_add_3[ax0, ax1, ax2, ax3, ax4]) T_add_3[ax0, ax1, ax2, ax3, ax4] = T_multiply_2[ax0, ax1, ax2, ax3, ax4] + compile_engine_const_4[()] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_floor_2"): + with T.sblock("T_floor_2"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_3[ax0, ax1, ax2, ax3, ax4]) T.writes(T_floor_2[ax0, ax1, ax2, ax3, ax4]) T_floor_2[ax0, ax1, ax2, ax3, ax4] = T.floor(T_add_3[ax0, ax1, ax2, ax3, ax4], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_6"): + with T.sblock("T_cast_6"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_floor_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_7[ax0, ax1, ax2, ax3, ax4]) T_cast_7[ax0, ax1, ax2, ax3, ax4] = T.cast(T_floor_2[ax0, ax1, ax2, ax3, ax4], "int32") for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("T_add_4"): + with T.sblock("T_add_4"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_cast_5[ax0, ax1, ax2, ax3, ax4], T_cast_7[ax0, ax1, ax2, ax3, ax4]) T.writes(T_add_4[ax0, ax1, ax2, ax3, ax4]) T_add_4[ax0, ax1, ax2, ax3, ax4] = T_cast_5[ax0, ax1, ax2, ax3, ax4] + T_cast_7[ax0, ax1, ax2, ax3, ax4] for i0, i1, i2, i3, i4 in T.grid(1, 128, 7, 7, 16): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_2, i1_2, i2_2, i3_2, i4_2 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(T_add_4[i0_2, i1_2, i2_2, i3_2, i4_2]) T.writes(compute_1[i0_2, i1_2, i2_2, i3_2, i4_2]) compute_1[i0_2, i1_2, i2_2, i3_2, i4_2] = T.max(T.min(T_add_4[i0_2, i1_2, i2_2, i3_2, i4_2], 255), 0) for i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_7"): + with T.sblock("T_cast_7"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_3, i1_3, i2_3, i3_3, i4_3]) T.reads(compute_1[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast_8[ax0, ax1, ax2, ax3, ax4]) T_cast_8[ax0, ax1, ax2, ax3, ax4] = T.cast(compute_1[ax0, ax1, ax2, ax3, ax4], "uint8") for i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(1, 128, 7, 7, 16): - with T.block("compute_2"): + with T.sblock("compute_2"): i0_5, i1_5, i2_5, i3_5, i4_5 = T.axis.remap("SSSSS", [i0_4, i1_4, i2_4, i3_4, i4_4]) T.reads(T_cast_8[i0_5, i1_5, i2_5, i3_5, i4_5]) T.writes(compute_2[i0_5, i1_5, i2_5, i3_5, i4_5]) compute_2[i0_5, i1_5, i2_5, i3_5, i4_5] = T.max(T.min(T_cast_8[i0_5, i1_5, i2_5, i3_5, i4_5], T.uint8(255)), T.uint8(0)) for i0_6, i1_6, i2_6, i3_6, i4_6 in T.grid(1, 128, 7, 7, 16): - with T.block("T_cast_8"): + with T.sblock("T_cast_8"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0_6, i1_6, i2_6, i3_6, i4_6]) T.reads(compute_2[ax0, ax1, ax2, ax3, ax4]) T.writes(T_cast[ax0, ax1, ax2, ax3, ax4]) @@ -1118,12 +1118,12 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32") for i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused in T.parallel(128, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): for i2_1, i3_1, i4_0_1 in T.grid(7, 1, 1): for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i4_0_2_init, i0_3_init, i1_3_init, i2_3_init, i3_3_init, i4_0_3_init in T.grid(1, 1, 1, 1, 1, 1, 1, 1, 7, 1): - with T.block("conv2d_NCHWc_int8_o_init"): + with T.sblock("conv2d_NCHWc_int8_o_init"): n = T.axis.spatial(1, i0_2_init + i0_3_init) oc_chunk = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused + i1_2_init + i1_3_init) oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init) @@ -1132,13 +1132,13 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) for i4_1 in T.vectorized(16): - with T.block("conv2d_NCHWc_int8_init"): + with T.sblock("conv2d_NCHWc_int8_init"): oc_block_i_init = T.axis.spatial(16, i4_1) T.reads() T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1): - with T.block("conv2d_NCHWc_int8_o_update"): + with T.sblock("conv2d_NCHWc_int8_o_update"): n = T.axis.spatial(1, i0_2 + i0_3) oc_chunk = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused + i1_2 + i1_3) oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3) @@ -1162,7 +1162,7 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7): for ax4_fused in T.vectorized(16): - with T.block("T_cast_8"): + with T.sblock("T_cast_8"): ax0_1 = T.axis.spatial(1, ax0) ax1_1 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused + ax1) ax2_1 = T.axis.spatial(7, i2_1 + ax2) @@ -1181,7 +1181,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), # function attr dict T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") data_pad = T.alloc_buffer([1, 58, 58, 64], dtype="float32") input_tile = T.alloc_buffer([6, 6, 196, 64], dtype="float32") B = T.alloc_buffer([6, 6], dtype="float32") @@ -1192,74 +1192,74 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), conv2d_winograd = T.alloc_buffer([1, 56, 56, 64], dtype="float32") T_add = T.alloc_buffer([1, 56, 56, 64], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): - with T.block("data_pad"): + with T.sblock("data_pad"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) - T.block_attr({"schedule_rule":"None"}) + T.sblock_attr({"schedule_rule":"None"}) data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") for i0, i1, i2, i3 in T.grid(6, 6, 196, 64): - with T.block("input_tile"): + with T.sblock("input_tile"): eps, nu, p, ci = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci]) T.writes(input_tile[eps, nu, p, ci]) - T.block_attr({"schedule_rule":"None"}) + T.sblock_attr({"schedule_rule":"None"}) input_tile[eps, nu, p, ci] = data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci] for i0, i1 in T.grid(6, 6): - with T.block("B"): + with T.sblock("B"): i, j = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(B[i, j]) - T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + T.sblock_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 196, 64, 6, 6): - with T.block("data_pack"): + with T.sblock("data_pack"): eps, nu, p, ci, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(input_tile[r_a, r_b, p, ci], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) T.writes(data_pack[eps, nu, p, ci]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) with T.init(): data_pack[eps, nu, p, ci] = T.float32(0) data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu] for i0, i1, i2, i3, i4 in T.grid(6, 6, 196, 64, 64): - with T.block("bgemm"): + with T.sblock("bgemm"): eps, nu, p, co, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) T.reads(data_pack[eps, nu, p, ci], p1[eps, nu, co, ci]) T.writes(bgemm[eps, nu, p, co]) - T.block_attr({"layout_free_placeholders":[]}) + T.sblock_attr({"layout_free_placeholders":[]}) with T.init(): bgemm[eps, nu, p, co] = T.float32(0) bgemm[eps, nu, p, co] = bgemm[eps, nu, p, co] + data_pack[eps, nu, p, ci] * p1[eps, nu, co, ci] for i0, i1 in T.grid(6, 4): - with T.block("A"): + with T.sblock("A"): i, j = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(A[i, j]) - T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + T.sblock_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 196, 64, 6, 6): - with T.block("inverse"): + with T.sblock("inverse"): vh, vw, p, co, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(bgemm[r_a, r_b, p, co], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) T.writes(inverse[vh, vw, p, co]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) with T.init(): inverse[vh, vw, p, co] = T.float32(0) inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("conv2d_winograd"): + with T.sblock("conv2d_winograd"): n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co]) T.writes(conv2d_winograd[n, h, w, co]) conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[ax0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("T_relu"): + with T.sblock("T_relu"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[ax0, ax1, ax2, ax3]) T.writes(T_relu[ax0, ax1, ax2, ax3]) @@ -1273,7 +1273,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body - # with T.block("root") + # with T.sblock("root") data_pad = T.alloc_buffer([1, 58, 58, 64], dtype="float32") input_tile = T.alloc_buffer([6, 6, 196, 64], dtype="float32") B = T.alloc_buffer([6, 6], dtype="float32") @@ -1285,80 +1285,80 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), T_add = T.alloc_buffer([1, 56, 56, 64], dtype="float32") T_add_1 = T.alloc_buffer([1, 56, 56, 64], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): - with T.block("data_pad"): + with T.sblock("data_pad"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) - T.block_attr({"schedule_rule":"None"}) + T.sblock_attr({"schedule_rule":"None"}) data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, p0[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32") for i0, i1, i2, i3 in T.grid(6, 6, 196, 64): - with T.block("input_tile"): + with T.sblock("input_tile"): eps, nu, p, ci = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci]) T.writes(input_tile[eps, nu, p, ci]) - T.block_attr({"schedule_rule":"None"}) + T.sblock_attr({"schedule_rule":"None"}) input_tile[eps, nu, p, ci] = data_pad[p // 196, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu, ci] for i0, i1 in T.grid(6, 6): - with T.block("B"): + with T.sblock("B"): i, j = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(B[i, j]) - T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + T.sblock_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 196, 64, 6, 6): - with T.block("data_pack"): + with T.sblock("data_pack"): eps, nu, p, ci, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(input_tile[r_a, r_b, p, ci], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) T.writes(data_pack[eps, nu, p, ci]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) with T.init(): data_pack[eps, nu, p, ci] = T.float32(0) data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu] for i0, i1, i2, i3, i4 in T.grid(6, 6, 196, 64, 64): - with T.block("bgemm"): + with T.sblock("bgemm"): eps, nu, p, co, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) T.reads(data_pack[eps, nu, p, ci], p1[eps, nu, co, ci]) T.writes(bgemm[eps, nu, p, co]) - T.block_attr({"layout_free_placeholders":[]}) + T.sblock_attr({"layout_free_placeholders":[]}) with T.init(): bgemm[eps, nu, p, co] = T.float32(0) bgemm[eps, nu, p, co] = bgemm[eps, nu, p, co] + data_pack[eps, nu, p, ci] * p1[eps, nu, co, ci] for i0, i1 in T.grid(6, 4): - with T.block("A"): + with T.sblock("A"): i, j = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(A[i, j]) - T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) + T.sblock_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 196, 64, 6, 6): - with T.block("inverse"): + with T.sblock("inverse"): vh, vw, p, co, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(bgemm[r_a, r_b, p, co], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) T.writes(inverse[vh, vw, p, co]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) with T.init(): inverse[vh, vw, p, co] = T.float32(0) inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("conv2d_winograd"): + with T.sblock("conv2d_winograd"): n, h, w, co = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co]) T.writes(conv2d_winograd[n, h, w, co]) conv2d_winograd[n, h, w, co] = inverse[h % 4, w % 4, n * 196 + h // 4 * 14 + w // 4, co] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[ax0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[ax0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[ax0, ax1, ax2, ax3], p3[ax0, ax1, ax2, ax3]) T.writes(T_add_1[ax0, ax1, ax2, ax3]) T_add_1[ax0, ax1, ax2, ax3] = T_add[ax0, ax1, ax2, ax3] + p3[ax0, ax1, ax2, ax3] for i0, i1, i2, i3 in T.grid(1, 56, 56, 64): - with T.block("T_relu"): + with T.sblock("T_relu"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add_1[ax0, ax1, ax2, ax3]) T.writes(T_relu[ax0, ax1, ax2, ax3]) @@ -1372,7 +1372,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body - # with T.block("root") + # with T.sblock("root") input_tile_local = T.alloc_buffer([6, 6, 196, 64], dtype="float32", scope="local") data_pack = T.alloc_buffer([6, 6, 196, 64], dtype="float32") bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32") @@ -1383,53 +1383,53 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(98, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(128, thread="threadIdx.x"): for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): - with T.block("input_tile"): + with T.sblock("input_tile"): eps, nu = T.axis.remap("SS", [ax0, ax1]) p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8 + ax2) ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8 + ax3) T.reads(p0[p // 196, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1, ci]) T.writes(input_tile_local[eps, nu, p, ci]) - T.block_attr({"schedule_rule":"None"}) + T.sblock_attr({"schedule_rule":"None"}) input_tile_local[eps, nu, p, ci] = T.if_then_else(1 <= p % 196 // 14 * 4 + eps and p % 196 // 14 * 4 + eps < 57 and 1 <= p % 14 * 4 + nu and p % 14 * 4 + nu < 57, p0[p // 196, p % 196 // 14 * 4 + eps - 1, p % 14 * 4 + nu - 1, ci], T.float32(0), dtype="float32") for i0 in T.unroll(6): for i1 in T.unroll(6): - with T.block("data_pack_init"): + with T.sblock("data_pack_init"): eps, nu = T.axis.remap("SS", [i0, i1]) p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8) ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) T.reads() T.writes(data_pack[eps, nu, p, ci]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) data_pack[eps, nu, p, ci] = T.float32(0) for i4 in T.unroll(6): for i5 in T.unroll(6): - with T.block("data_pack_update"): + with T.sblock("data_pack_update"): eps, nu = T.axis.remap("SS", [i0, i1]) p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) // 896 * 14 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 112 // 8) ci = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 896 // 112 * 8 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 128 + i2_0_i3_0_i2_1_i3_1_fused_1) % 8) r_a, r_b = T.axis.remap("RR", [i4, i5]) T.reads(data_pack[eps, nu, p, ci], input_tile_local[r_a, r_b, p, ci]) T.writes(data_pack[eps, nu, p, ci]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_data_pack.cuda"}) data_pack[eps, nu, p, ci] = data_pack[eps, nu, p, ci] + input_tile_local[r_a, r_b, p, ci] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_b % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_b % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_b % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_b % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_b % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_b % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_b % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_b % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_b % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_b % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_b % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_b % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_b % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_b % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_b % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_b % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_b % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_b % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(168, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(4, thread="vthread.x"): for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(48, thread="threadIdx.x"): for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(1, 1, 14, 1, 1, 1, 1, 1): - with T.block("bgemm_init"): + with T.sblock("bgemm_init"): eps = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3_init + i0_4_init) nu = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3_init + i1_4_init) p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3_init + i2_4_init) co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 16 + i3_3_init + i3_4_init) T.reads() T.writes(bgemm_local[eps, nu, p, co]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T.sblock_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) bgemm_local[eps, nu, p, co] = T.float32(0) for i4_0 in T.serial(2): for ax0_ax1_ax2_ax3_fused_0 in T.serial(28): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(48, thread="threadIdx.x"): for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): - with T.block("data_pack_shared"): + with T.sblock("data_pack_shared"): v0 = T.axis.spatial(6, (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 896) v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28) v2 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 896 // 32) @@ -1440,7 +1440,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), for ax0_ax1_ax2_ax3_fused_0 in T.serial(16): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(48, thread="threadIdx.x"): for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): - with T.block("p1_shared"): + with T.sblock("p1_shared"): v0 = T.axis.spatial(6, (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 512) v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28) v2 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 4 * 16 + (ax0_ax1_ax2_ax3_fused_0 * 192 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 512 // 32) @@ -1449,7 +1449,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), T.writes(p1_shared[v0, v1, v2, v3]) p1_shared[v0, v1, v2, v3] = p1[v0, v1, v2, v3] for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(2, 1, 1, 14, 1, 16, 1, 1, 1, 1): - with T.block("bgemm_update"): + with T.sblock("bgemm_update"): eps = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + i0_3 + i0_4) nu = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + i1_3 + i1_4) p = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + i2_3 + i2_4) @@ -1457,10 +1457,10 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), ci = T.axis.reduce(64, i4_0 * 32 + i4_1 * 16 + i4_2) T.reads(bgemm_local[eps, nu, p, co], data_pack_shared[eps, nu, p, ci], p1_shared[eps, nu, co, ci]) T.writes(bgemm_local[eps, nu, p, co]) - T.block_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + T.sblock_attr({"layout_free_placeholders":[], "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) bgemm_local[eps, nu, p, co] = bgemm_local[eps, nu, p, co] + data_pack_shared[eps, nu, p, ci] * p1_shared[eps, nu, co, ci] for ax0, ax1, ax2, ax3 in T.grid(1, 1, 14, 1): - with T.block("bgemm_local"): + with T.sblock("bgemm_local"): v0 = T.axis.spatial(6, i0_1_i1_1_i2_1_i3_1_fused // 2 * 3 + i0_2_i1_2_i2_2_i3_2_fused // 16 + ax0) v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 28 + ax1) v2 = T.axis.spatial(196, i0_0_i1_0_i2_0_i3_0_fused % 28 // 4 * 28 + i0_1_i1_1_i2_1_i3_1_fused % 2 * 14 + ax2) @@ -1472,18 +1472,18 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(512, thread="threadIdx.x"): for i0 in T.unroll(4): for i1 in T.unroll(4): - with T.block("inverse_init"): + with T.sblock("inverse_init"): T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1 < 12544) vh, vw = T.axis.remap("SS", [i0, i1]) p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) // 448 * 7 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 224 // 32) co = T.axis.spatial(64, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 448 // 224 * 32 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 32) T.reads() T.writes(inverse[vh, vw, p, co]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) inverse[vh, vw, p, co] = T.float32(0) for i4 in T.unroll(6): for i5 in T.unroll(6): - with T.block("inverse_update"): + with T.sblock("inverse_update"): T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1 < 12544) vh, vw = T.axis.remap("SS", [i0, i1]) p = T.axis.spatial(196, (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) // 448 * 7 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 512 + i2_0_i3_0_i2_1_i3_1_fused_1) % 224 // 32) @@ -1491,11 +1491,11 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), r_a, r_b = T.axis.remap("RR", [i4, i5]) T.reads(inverse[vh, vw, p, co], bgemm[r_a, r_b, p, co]) T.writes(inverse[vh, vw, p, co]) - T.block_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) + T.sblock_attr({"auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], "schedule_rule":"meta_schedule.winograd_inverse.cuda"}) inverse[vh, vw, p, co] = inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co] * T.Select(r_a % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_b % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_b % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_b % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_b % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_b % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_b % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_b % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_b % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_b % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_b % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_b % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_b % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_b % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) for i0_i1_i2_i3_fused_0 in T.thread_binding(1568, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":1024, "pragma_unroll_explicit":1}): for i0_i1_i2_i3_fused_1 in T.thread_binding(128, thread="threadIdx.x"): - with T.block("conv2d_winograd"): + with T.sblock("conv2d_winograd"): n = T.axis.spatial(1, 0) h = T.axis.spatial(56, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) // 3584) w = T.axis.spatial(56, (i0_i1_i2_i3_fused_0 * 128 + i0_i1_i2_i3_fused_1) % 3584 // 64) @@ -1512,7 +1512,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") @@ -1522,13 +1522,13 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " compute_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_subtract_1 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) T.writes(conv2d_nhwc[nn, yy, xx, ff]) @@ -1536,43 +1536,43 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " conv2d_nhwc[nn, yy, xx, ff] = 0 conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("compute"): + with T.sblock("compute"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) T.reads(p7[()], compute_1[ax0, ax1, ax2, ax3]) T.writes(T_add_1[ax0, ax1, ax2, ax3]) T_add_1[ax0, ax1, ax2, ax3] = p7[()] + compute_1[ax0, ax1, ax2, ax3] for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) T.reads(compute_2[ax0, ax1, ax2, ax3], p8[0]) T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p8[0] for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 56, 56, 256): - with T.block("compute_2"): + with T.sblock("compute_2"): i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) T.writes(compute[i0_8, i1_8, i2_8, i3_8]) @@ -1586,7 +1586,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") pad_temp = T.alloc_buffer([16, 56, 56, 64], dtype="int8") conv2d_nhwc = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_subtract = T.alloc_buffer([16, 56, 56, 256], dtype="int32") @@ -1599,13 +1599,13 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " compute_4 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") T_add_2 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 64): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 56, 56, 256, 1, 1, 64): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc]) T.writes(conv2d_nhwc[nn, yy, xx, ff]) @@ -1613,61 +1613,61 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " conv2d_nhwc[nn, yy, xx, ff] = 0 conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32") for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3] for i0, i1, i2, i3 in T.grid(16, 56, 56, 256): - with T.block("compute"): + with T.sblock("compute"): i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2]) T.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32") for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 56, 56, 256): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) T.reads(p7[()], compute_1[ax0, ax1, ax2, ax3]) T.writes(T_add_1[ax0, ax1, ax2, ax3]) T_add_1[ax0, ax1, ax2, ax3] = p7[()] + compute_1[ax0, ax1, ax2, ax3] for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 56, 56, 256): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5]) T.writes(compute_2[i0_5, i1_5, i2_5, i3_5]) compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0) for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 56, 56, 256): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) T.reads(compute_2[ax0, ax1, ax2, ax3], p8[0]) T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p8[0] for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 56, 56, 256): - with T.block("compute_2"): + with T.sblock("compute_2"): i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8]) T.writes(compute_3[i0_8, i1_8, i2_8, i3_8]) compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1457846997, 31, 0, dtype="int32") for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 56, 56, 256): - with T.block("compute_3"): + with T.sblock("compute_3"): i0_10, i1_10, i2_10, i3_10 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9]) T.reads(p9[i0_10, i1_10, i2_10, i3_10]) T.writes(compute_4[i0_10, i1_10, i2_10, i3_10]) compute_4[i0_10, i1_10, i2_10, i3_10] = T.q_multiply_shift(p9[i0_10, i1_10, i2_10, i3_10], 2101000910, 31, 0, dtype="int32") for i0_11, i1_11, i2_11, i3_11 in T.grid(16, 56, 56, 256): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_11, i1_11, i2_11, i3_11]) T.reads(compute_3[ax0, ax1, ax2, ax3], compute_4[ax0, ax1, ax2, ax3]) T.writes(T_add_2[ax0, ax1, ax2, ax3]) T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + compute_4[ax0, ax1, ax2, ax3] for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): - with T.block("compute_4"): + with T.sblock("compute_4"): i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) T.reads(T_add_2[i0_13, i1_13, i2_13, i3_13]) T.writes(compute[i0_13, i1_13, i2_13, i3_13]) @@ -1679,10 +1679,10 @@ class Conv2dInt8_with_predicate_scheduled: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): T.func_attr({"tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") @@ -1694,53 +1694,53 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): for ax0_ax1_fused in range(1024): - with T.block("pad_temp_reindex_shared"): + with T.sblock("pad_temp_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] for ax0_ax1_ax2_ax3_fused in range(2048): - with T.block("p1_reindex_shared"): + with T.sblock("p1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(p1[v2, v0, v1, v3]) T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): for ax0_0_1, ax1_0_1 in T.grid(1, 1): - with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax1_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): - with T.block("pad_temp_reindex_shared_wmma.matrix_a"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): - with T.block("p1_reindex_shared_wmma.matrix_b_o"): + with T.sblock("p1_reindex_shared_wmma.matrix_b_o"): v0_o, v1_o = T.axis.remap("SS", [ax0, ax1]) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax3_0) T.reads(p1_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("p1_reindex_shared_wmma.matrix_b"): + with T.sblock("p1_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads(p1_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) T.writes(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): - with T.block("conv2d_nhwc_o"): + with T.sblock("conv2d_nhwc_o"): v0_o = T.axis.spatial(1, ax0_0 + ax0_1 + ax0_2) v1_o = T.axis.spatial(1, ax1_0 + ax1_1 + ax1_2) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) @@ -1748,36 +1748,36 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): - with T.block("conv2d_nhwc_reindex_shared"): + with T.sblock("conv2d_nhwc_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) @@ -1801,8 +1801,8 @@ def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref): def test_dense_add_cpu(): def apply_anchor_trace(sch: Schedule) -> None: - b0 = sch.get_block(name="T_matmul_NT", func_name="main") - b1 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="T_matmul_NT", func_name="main") + b1 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") l2, l3, l4 = sch.get_loops(block=b0) v5, v6, v7, v8 = sch.sample_perfect_tile( @@ -1827,7 +1827,7 @@ def apply_anchor_trace(sch: Schedule) -> None: ) sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26) sch.enter_postproc() - b27 = sch.get_block(name="root", func_name="main") + b27 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel") sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize") sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit") @@ -1842,10 +1842,10 @@ def apply_anchor_trace(sch: Schedule) -> None: sch.parallel(loop=l45) l46 = sch.fuse(l44, preserve_unit_iters=True) sch.vectorize(loop=l46) - b47 = sch.get_block(name="T_matmul_NT", func_name="main") + b47 = sch.get_sblock(name="T_matmul_NT", func_name="main") l48, l49, l50, l51, l52, l53, l54, l55, l56 = sch.get_loops(block=b47) b57 = sch.decompose_reduction(block=b47, loop=l51) - b58 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b58 = sch.get_sblock(name="T_matmul_NT_update", func_name="main") b59 = sch.cache_read(block=b58, read_buffer_index=2, storage_scope="global") sch.transform_layout( block=b58, @@ -1873,8 +1873,8 @@ def apply_anchor_trace(sch: Schedule) -> None: def test_dense_add_cpu_no_write_cache(): def apply_trace(sch): - b0 = sch.get_block(name="T_matmul_NT", func_name="main") - b1 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="T_matmul_NT", func_name="main") + b1 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") l2, l3, l4 = sch.get_loops(block=b0) v5, v6, v7, v8 = sch.sample_perfect_tile( @@ -1897,7 +1897,7 @@ def apply_trace(sch): ) sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25) sch.enter_postproc() - b26 = sch.get_block(name="root", func_name="main") + b26 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.parallel") sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.vectorize") sch.unannotate(block_or_loop=b26, ann_key="meta_schedule.unroll_explicit") @@ -1909,10 +1909,10 @@ def apply_trace(sch): sch.vectorize(loop=l39) sch.annotate(block_or_loop=l38, ann_key="pragma_auto_unroll_max_step", ann_val=16) sch.annotate(block_or_loop=l38, ann_key="pragma_unroll_explicit", ann_val=1) - b40 = sch.get_block(name="T_matmul_NT", func_name="main") + b40 = sch.get_sblock(name="T_matmul_NT", func_name="main") l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b40) b48 = sch.decompose_reduction(block=b40, loop=l42) - b49 = sch.get_block(name="T_matmul_NT_update", func_name="main") + b49 = sch.get_sblock(name="T_matmul_NT_update", func_name="main") b50 = sch.cache_read(block=b49, read_buffer_index=2, storage_scope="global") sch.transform_layout( block=b49, @@ -1941,8 +1941,8 @@ def apply_trace(sch): def test_dense_add_gpu(): def apply_anchor_trace(sch: Schedule) -> None: - b0 = sch.get_block(name="T_matmul_NT", func_name="main") - b1 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="T_matmul_NT", func_name="main") + b1 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") l2, l3, l4 = sch.get_loops(block=b0) v5, v6, v7, v8, v9 = sch.sample_perfect_tile( @@ -2019,7 +2019,7 @@ def apply_anchor_trace(sch: Schedule) -> None: l67, l68, l69 = sch.split(loop=l66, factors=[None, 128, 4], preserve_unit_iters=True) sch.vectorize(loop=l69) sch.bind(loop=l68, thread_axis="threadIdx.x") - b70 = sch.get_block(name="root", func_name="main") + b70 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b70, ann_key="meta_schedule.unroll_explicit") b71, b72, b73, b74 = sch.get_child_blocks(b70) l75, l76, l77, l78, l79, l80, l81 = sch.get_loops(block=b71) @@ -2034,7 +2034,7 @@ def apply_anchor_trace(sch: Schedule) -> None: l99, l100, l101, l102, l103 = sch.get_loops(block=b74) sch.annotate(block_or_loop=l99, ann_key="pragma_auto_unroll_max_step", ann_val=64) sch.annotate(block_or_loop=l99, ann_key="pragma_unroll_explicit", ann_val=1) - b104 = sch.get_block(name="T_matmul_NT", func_name="main") + b104 = sch.get_sblock(name="T_matmul_NT", func_name="main") l105, l106, l107, l108, l109, l110, l111, l112, l113, l114 = sch.get_loops(block=b104) b115 = sch.decompose_reduction(block=b104, loop=l108) @@ -2043,22 +2043,22 @@ def apply_anchor_trace(sch: Schedule) -> None: def test_conv2d_int8_tensorcore(): def apply_trace(sch): - b0 = sch.get_block(name="pad_temp", func_name="main") - b1 = sch.get_block(name="conv2d_nhwc", func_name="main") - b2 = sch.get_block(name="T_subtract", func_name="main") - b3 = sch.get_block(name="T_add", func_name="main") - b4 = sch.get_block(name="T_cast", func_name="main") - b5 = sch.get_block(name="T_multiply", func_name="main") - b6 = sch.get_block(name="T_add_1", func_name="main") - b7 = sch.get_block(name="T_right_shift", func_name="main") - b8 = sch.get_block(name="T_cast_1", func_name="main") - b9 = sch.get_block(name="T_add_2", func_name="main") - b10 = sch.get_block(name="compute", func_name="main") - b11 = sch.get_block(name="T_cast_2", func_name="main") - b12 = sch.get_block(name="T_cast_3", func_name="main") - b13 = sch.get_block(name="T_subtract_1", func_name="main") - b14 = sch.get_block(name="compute_1", func_name="main") - b15 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="pad_temp", func_name="main") + b1 = sch.get_sblock(name="conv2d_nhwc", func_name="main") + b2 = sch.get_sblock(name="T_subtract", func_name="main") + b3 = sch.get_sblock(name="T_add", func_name="main") + b4 = sch.get_sblock(name="T_cast", func_name="main") + b5 = sch.get_sblock(name="T_multiply", func_name="main") + b6 = sch.get_sblock(name="T_add_1", func_name="main") + b7 = sch.get_sblock(name="T_right_shift", func_name="main") + b8 = sch.get_sblock(name="T_cast_1", func_name="main") + b9 = sch.get_sblock(name="T_add_2", func_name="main") + b10 = sch.get_sblock(name="compute", func_name="main") + b11 = sch.get_sblock(name="T_cast_2", func_name="main") + b12 = sch.get_sblock(name="T_cast_3", func_name="main") + b13 = sch.get_sblock(name="T_subtract_1", func_name="main") + b14 = sch.get_sblock(name="compute_1", func_name="main") + b15 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") b16 = sch.reindex(block=b1, buffer=("write", 0)) b17 = sch.reindex(block=b1, buffer=("read", 0)) @@ -2383,7 +2383,7 @@ def apply_trace(sch): l217, l218, l219 = sch.split(loop=l216, factors=[None, 16, 8], preserve_unit_iters=True) sch.vectorize(loop=l219) sch.bind(loop=l218, thread_axis="threadIdx.x") - b220 = sch.get_block(name="root", func_name="main") + b220 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b220, ann_key="meta_schedule.unroll_explicit") b221, b222, b223, b224, b225, b226, b227 = sch.get_child_blocks(b220) l228, l229, l230, l231, l232, l233, l234, l235, l236 = sch.get_loops(block=b221) @@ -2438,7 +2438,7 @@ def apply_trace(sch): l291, l292, l293, l294, l295 = sch.get_loops(block=b227) sch.annotate(block_or_loop=l291, ann_key="pragma_auto_unroll_max_step", ann_val=512) sch.annotate(block_or_loop=l291, ann_key="pragma_unroll_explicit", ann_val=1) - b296 = sch.get_block(name="conv2d_nhwc_o", func_name="main") + b296 = sch.get_sblock(name="conv2d_nhwc_o", func_name="main") ( l297, l298, @@ -2466,19 +2466,21 @@ def apply_trace(sch): ) sch.unannotate(block_or_loop=b296, ann_key="meta_schedule.auto_tensorize_init") sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize_init") - b314 = sch.get_block(name="conv2d_nhwc_o_init", func_name="main") + b314 = sch.get_sblock(name="conv2d_nhwc_o_init", func_name="main") sch.unannotate(block_or_loop=b314, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b314, tensor_intrin="wmma_fill_16x16x16_s32") - b315 = sch.get_block(name="pad_temp_reindex_shared_wmma.matrix_a_o", func_name="main") + b315 = sch.get_sblock(name="pad_temp_reindex_shared_wmma.matrix_a_o", func_name="main") sch.unannotate(block_or_loop=b315, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b315, tensor_intrin="wmma_load_16x16x16_s8_a_shared") - b316 = sch.get_block(name="p1_reindex_shared_wmma.matrix_b_o", func_name="main") + b316 = sch.get_sblock(name="p1_reindex_shared_wmma.matrix_b_o", func_name="main") sch.unannotate(block_or_loop=b316, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b316, tensor_intrin="wmma_load_16x16x16_s8_b_trans_shared") - b317 = sch.get_block(name="conv2d_nhwc_o_update", func_name="main") + b317 = sch.get_sblock(name="conv2d_nhwc_o_update", func_name="main") sch.unannotate(block_or_loop=b317, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b317, tensor_intrin="wmma_sync_16x16x16_s8s8s32_trans") - b318 = sch.get_block(name="conv2d_nhwc_reindex_shared_wmma.accumulator_o", func_name="main") + b318 = sch.get_sblock( + name="conv2d_nhwc_reindex_shared_wmma.accumulator_o", func_name="main" + ) sch.unannotate(block_or_loop=b318, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b318, tensor_intrin="wmma_store_16x16x16_s32_shared") @@ -2487,28 +2489,28 @@ def apply_trace(sch): def test_conv2d_int8_vnni(): def apply_trace(sch): - b0 = sch.get_block(name="compile_engine_const", func_name="main") - b1 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") - b2 = sch.get_block(name="T_add", func_name="main") - b3 = sch.get_block(name="T_cast", func_name="main") - b4 = sch.get_block(name="T_multiply", func_name="main") - b5 = sch.get_block(name="compile_engine_const_1", func_name="main") - b6 = sch.get_block(name="T_add_1", func_name="main") - b7 = sch.get_block(name="T_floor", func_name="main") - b8 = sch.get_block(name="T_cast_1", func_name="main") - b9 = sch.get_block(name="compute", func_name="main") - b10 = sch.get_block(name="T_cast_2", func_name="main") - b11 = sch.get_block(name="T_cast_3", func_name="main") - b12 = sch.get_block(name="T_subtract", func_name="main") - b13 = sch.get_block(name="T_multiply_1", func_name="main") - b14 = sch.get_block(name="compile_engine_const_2", func_name="main") - b15 = sch.get_block(name="T_add_2", func_name="main") - b16 = sch.get_block(name="T_floor_1", func_name="main") - b17 = sch.get_block(name="T_cast_4", func_name="main") - b18 = sch.get_block(name="T_add_3", func_name="main") - b19 = sch.get_block(name="compute_1", func_name="main") - b20 = sch.get_block(name="T_cast_5", func_name="main") - b21 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="compile_engine_const", func_name="main") + b1 = sch.get_sblock(name="conv2d_NCHWc_int8", func_name="main") + b2 = sch.get_sblock(name="T_add", func_name="main") + b3 = sch.get_sblock(name="T_cast", func_name="main") + b4 = sch.get_sblock(name="T_multiply", func_name="main") + b5 = sch.get_sblock(name="compile_engine_const_1", func_name="main") + b6 = sch.get_sblock(name="T_add_1", func_name="main") + b7 = sch.get_sblock(name="T_floor", func_name="main") + b8 = sch.get_sblock(name="T_cast_1", func_name="main") + b9 = sch.get_sblock(name="compute", func_name="main") + b10 = sch.get_sblock(name="T_cast_2", func_name="main") + b11 = sch.get_sblock(name="T_cast_3", func_name="main") + b12 = sch.get_sblock(name="T_subtract", func_name="main") + b13 = sch.get_sblock(name="T_multiply_1", func_name="main") + b14 = sch.get_sblock(name="compile_engine_const_2", func_name="main") + b15 = sch.get_sblock(name="T_add_2", func_name="main") + b16 = sch.get_sblock(name="T_floor_1", func_name="main") + b17 = sch.get_sblock(name="T_cast_4", func_name="main") + b18 = sch.get_sblock(name="T_add_3", func_name="main") + b19 = sch.get_sblock(name="compute_1", func_name="main") + b20 = sch.get_sblock(name="T_cast_5", func_name="main") + b21 = sch.get_sblock(name="root", func_name="main") sch.compute_inline(block=b20) sch.compute_inline(block=b19) sch.compute_inline(block=b18) @@ -2627,7 +2629,7 @@ def apply_trace(sch): ) sch.annotate(block_or_loop=b21, ann_key="meta_schedule.unroll_explicit", ann_val=v120) sch.enter_postproc() - b121 = sch.get_block(name="root", func_name="main") + b121 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.parallel") sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.vectorize") sch.unannotate(block_or_loop=b121, ann_key="meta_schedule.unroll_explicit") @@ -2673,7 +2675,7 @@ def apply_trace(sch): sch.vectorize(loop=l164) sch.annotate(block_or_loop=l155, ann_key="pragma_auto_unroll_max_step", ann_val=64) sch.annotate(block_or_loop=l155, ann_key="pragma_unroll_explicit", ann_val=1) - b165 = sch.get_block(name="conv2d_NCHWc_int8_o", func_name="main") + b165 = sch.get_sblock(name="conv2d_NCHWc_int8_o", func_name="main") ( l166, l167, @@ -2703,12 +2705,12 @@ def apply_trace(sch): b190 = sch.decompose_reduction(block=b165, loop=l170) sch.unannotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize") sch.annotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize", ann_val="") - b191 = sch.get_block(name="conv2d_NCHWc_int8_o_init", func_name="main") + b191 = sch.get_sblock(name="conv2d_NCHWc_int8_o_init", func_name="main") sch.unannotate(block_or_loop=b191, ann_key="meta_schedule.auto_tensorize") (b192,) = sch.get_child_blocks(b191) (l193,) = sch.get_loops(block=b192) sch.vectorize(loop=l193) - b194 = sch.get_block(name="conv2d_NCHWc_int8_o_update", func_name="main") + b194 = sch.get_sblock(name="conv2d_NCHWc_int8_o_update", func_name="main") sch.unannotate(block_or_loop=b194, ann_key="meta_schedule.auto_tensorize") sch.tensorize(block_or_loop=b194, tensor_intrin=VNNI_INTRIN) @@ -2724,15 +2726,15 @@ def apply_trace(sch): def test_winograd_gpu(): def apply_trace(sch): - b0 = sch.get_block(name="B", func_name="main") - b1 = sch.get_block(name="data_pack", func_name="main") - b2 = sch.get_block(name="bgemm", func_name="main") - b3 = sch.get_block(name="A", func_name="main") - b4 = sch.get_block(name="inverse", func_name="main") - b5 = sch.get_block(name="conv2d_winograd", func_name="main") - b6 = sch.get_block(name="T_add", func_name="main") - b7 = sch.get_block(name="T_relu", func_name="main") - b8 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="B", func_name="main") + b1 = sch.get_sblock(name="data_pack", func_name="main") + b2 = sch.get_sblock(name="bgemm", func_name="main") + b3 = sch.get_sblock(name="A", func_name="main") + b4 = sch.get_sblock(name="inverse", func_name="main") + b5 = sch.get_sblock(name="conv2d_winograd", func_name="main") + b6 = sch.get_sblock(name="T_add", func_name="main") + b7 = sch.get_sblock(name="T_relu", func_name="main") + b8 = sch.get_sblock(name="root", func_name="main") sch.compute_inline(block=b0) (b9,) = sch.get_producers(block=b1) (b10,) = sch.get_producers(block=b9) @@ -2927,7 +2929,7 @@ def apply_trace(sch): l162, l163, l164 = sch.split(loop=l161, factors=[None, 48, 4], preserve_unit_iters=True) sch.vectorize(loop=l164) sch.bind(loop=l163, thread_axis="threadIdx.x") - b165 = sch.get_block(name="root", func_name="main") + b165 = sch.get_sblock(name="root", func_name="main") sch.unannotate(block_or_loop=b165, ann_key="meta_schedule.unroll_explicit") b166, b167, b168, b169, b170, b171, b172, b173 = sch.get_child_blocks(b165) l174, l175, l176, l177, l178, l179 = sch.get_loops(block=b166) @@ -2969,10 +2971,10 @@ def apply_trace(sch): l227, l228 = sch.get_loops(block=b173) sch.annotate(block_or_loop=l227, ann_key="pragma_auto_unroll_max_step", ann_val=1024) sch.annotate(block_or_loop=l227, ann_key="pragma_unroll_explicit", ann_val=1) - b229 = sch.get_block(name="data_pack", func_name="main") + b229 = sch.get_sblock(name="data_pack", func_name="main") l230, l231, l232, l233, l234, l235 = sch.get_loops(block=b229) b236 = sch.decompose_reduction(block=b229, loop=l234) - b237 = sch.get_block(name="bgemm", func_name="main") + b237 = sch.get_sblock(name="bgemm", func_name="main") ( l238, l239, @@ -2990,7 +2992,7 @@ def apply_trace(sch): l251, ) = sch.get_loops(block=b237) b252 = sch.decompose_reduction(block=b237, loop=l241) - b253 = sch.get_block(name="inverse", func_name="main") + b253 = sch.get_sblock(name="inverse", func_name="main") l254, l255, l256, l257, l258, l259 = sch.get_loops(block=b253) b260 = sch.decompose_reduction(block=b253, loop=l258) @@ -3018,16 +3020,16 @@ def test_inline_order(): # such cases. def apply_trace(sch: Schedule) -> None: - b0 = sch.get_block(name="pad_temp", func_name="main") - b1 = sch.get_block(name="conv2d_nhwc", func_name="main") - b2 = sch.get_block(name="T_subtract", func_name="main") - b3 = sch.get_block(name="T_add", func_name="main") - b4 = sch.get_block(name="compute", func_name="main") - b5 = sch.get_block(name="T_add_1", func_name="main") - b6 = sch.get_block(name="compute_1", func_name="main") - b7 = sch.get_block(name="T_subtract_1", func_name="main") - b8 = sch.get_block(name="compute_2", func_name="main") - b9 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="pad_temp", func_name="main") + b1 = sch.get_sblock(name="conv2d_nhwc", func_name="main") + b2 = sch.get_sblock(name="T_subtract", func_name="main") + b3 = sch.get_sblock(name="T_add", func_name="main") + b4 = sch.get_sblock(name="compute", func_name="main") + b5 = sch.get_sblock(name="T_add_1", func_name="main") + b6 = sch.get_sblock(name="compute_1", func_name="main") + b7 = sch.get_sblock(name="T_subtract_1", func_name="main") + b8 = sch.get_sblock(name="compute_2", func_name="main") + b9 = sch.get_sblock(name="root", func_name="main") sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") b10 = sch.reindex(block=b1, buffer=("write", 0)) b11 = sch.reindex(block=b1, buffer=("read", 0)) diff --git a/tests/python/meta_schedule/test_meta_schedule_tune_context.py b/tests/python/meta_schedule/test_meta_schedule_tune_context.py index 69b38c82a11f..2e8aa22b2e50 100644 --- a/tests/python/meta_schedule/test_meta_schedule_tune_context.py +++ b/tests/python/meta_schedule/test_meta_schedule_tune_context.py @@ -37,7 +37,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/meta_schedule/test_meta_schedule_tune_tir.py b/tests/python/meta_schedule/test_meta_schedule_tune_tir.py index 0d349752537a..84aee8cd9aa0 100644 --- a/tests/python/meta_schedule/test_meta_schedule_tune_tir.py +++ b/tests/python/meta_schedule/test_meta_schedule_tune_tir.py @@ -28,7 +28,7 @@ from tvm.meta_schedule.testing.local_rpc import LocalRPC from tvm.script import tir as T from tvm.target import Target -from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule import SBlockRV, Schedule logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -40,7 +40,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -53,11 +53,11 @@ def two_step(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j in T.grid(1024, 1024): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(1024, 1024): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 3.0 @@ -152,7 +152,7 @@ class RemoveBlock(ms.schedule_rule.PyScheduleRule): def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: pass - def apply(self, sch: Schedule, block: BlockRV): + def apply(self, sch: Schedule, block: SBlockRV): if sch.get(block).name_hint == "root": return [sch] sch = sch.copy() diff --git a/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py index b461f39dd744..6f85c8b2ac76 100644 --- a/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py +++ b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py @@ -46,11 +46,11 @@ def max_pool2d_opencl( (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) ): - with T.block("pool_max"): + with T.sblock("pool_max"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] ) @@ -64,7 +64,7 @@ def max_pool2d_opencl( ] ) T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( -340282346638528859811704183484516925440.0 @@ -87,9 +87,9 @@ def te_layout_transform( (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) T.reads(x[v_self, v_i0, v_i1, v_i2]) T.writes( @@ -110,11 +110,11 @@ def te_layout_transform2( (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for self, i0, i1, i2, i3 in T.grid( T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) ): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) @@ -171,11 +171,11 @@ def max_pool2d_opencl( (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) ): - with T.block("pool_max"): + with T.sblock("pool_max"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] ) @@ -189,7 +189,7 @@ def max_pool2d_opencl( ] ) T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( -340282346638528859811704183484516925440.0 @@ -212,9 +212,9 @@ def te_layout_transform( (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) T.reads(x[v_self, v_i0, v_i1, v_i2]) T.writes( @@ -235,11 +235,11 @@ def te_layout_transform2( (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for self, i0, i1, i2, i3 in T.grid( T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) ): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py index a1308098216c..c6f755fd7d59 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py @@ -39,37 +39,37 @@ def gelu1( T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_multiply_1 = T.alloc_buffer((T.int64(128), T.int64(64))) compute = T.alloc_buffer((T.int64(128), T.int64(64))) T_multiply_2 = T.alloc_buffer((T.int64(128), T.int64(64))) T_add = T.alloc_buffer((T.int64(128), T.int64(64))) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(T.int64(128), T.int64(64)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -82,9 +82,9 @@ def matmul1( matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul_1[v_i0, v_i1]) @@ -99,9 +99,9 @@ def matmul2( matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul_1[v_i0, v_i1]) @@ -198,37 +198,37 @@ def gelu1( T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_multiply_1 = T.alloc_buffer((T.int64(128), T.int64(64))) compute = T.alloc_buffer((T.int64(128), T.int64(64))) T_multiply_2 = T.alloc_buffer((T.int64(128), T.int64(64))) T_add = T.alloc_buffer((T.int64(128), T.int64(64))) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(T.int64(128), T.int64(64)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -241,9 +241,9 @@ def matmul11( matmul: T.Buffer((T.int64(64), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(64), T.int64(128), T.int64(64)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul[v_i0, v_i1]) @@ -258,9 +258,9 @@ def matmul2( matmul: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul[v_i0, v_i1]) @@ -275,15 +275,15 @@ def split11( T_split_1: T.Buffer((64, 64), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax1, ax2 in T.grid(64, 64): - with T.block("T_split"): + with T.sblock("T_split"): v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T.reads(A[v_ax1, v_ax2]) T.writes(T_split[v_ax1, v_ax2]) T_split[v_ax1, v_ax2] = A[v_ax1, v_ax2] for ax1, ax2 in T.grid(64, 64): - with T.block("T_split_1"): + with T.sblock("T_split_1"): v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T.reads(A[v_ax1 + 64, v_ax2]) T.writes(T_split_1[v_ax1, v_ax2]) diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py index bd9fe4eb09d4..cea00ab75981 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py @@ -39,37 +39,37 @@ def gelu( T_multiply: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_multiply_1 = T.alloc_buffer((T.int64(128), T.int64(128))) compute = T.alloc_buffer((T.int64(128), T.int64(128))) T_multiply_2 = T.alloc_buffer((T.int64(128), T.int64(128))) T_add = T.alloc_buffer((T.int64(128), T.int64(128))) for ax0, ax1 in T.grid(T.int64(128), T.int64(128)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(T.int64(128), T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(128), T.int64(128)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) for ax0, ax1 in T.grid(T.int64(128), T.int64(128)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(128), T.int64(128)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -82,9 +82,9 @@ def matmul( matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(128)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul_1[v_i0, v_i1]) @@ -128,37 +128,37 @@ def gelu1( T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_multiply_1 = T.alloc_buffer((T.int64(128), T.int64(64))) compute = T.alloc_buffer((T.int64(128), T.int64(64))) T_multiply_2 = T.alloc_buffer((T.int64(128), T.int64(64))) T_add = T.alloc_buffer((T.int64(128), T.int64(64))) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(T.int64(128), T.int64(64)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(128), T.int64(64)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -171,9 +171,9 @@ def matmul1( matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul_1[v_i0, v_i1]) @@ -188,9 +188,9 @@ def matmul2( matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(A[v_i0, v_k], B[v_k, v_i1]) T.writes(matmul_1[v_i0, v_i1]) @@ -245,9 +245,9 @@ def add( T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax0, v_ax1, v_ax2]) T.writes(T_add[v_ax0, v_ax1, v_ax2]) @@ -260,9 +260,9 @@ def divide( T_divide: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -277,9 +277,9 @@ def matmul( matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], B[v_k, v_i2]) T.writes(matmul[v_i0, v_i1, v_i2]) @@ -296,11 +296,11 @@ def matmul1( matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(256), T.int64(128) ): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -318,11 +318,11 @@ def matmul2( matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(128), T.int64(256) ): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -340,9 +340,9 @@ def maximum( T_maximum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_maximum"): + with T.sblock("T_maximum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_maximum[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -357,9 +357,9 @@ def minimum( T_minimum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_minimum"): + with T.sblock("T_minimum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]) T.writes(T_minimum[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -373,9 +373,9 @@ def reshape( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -398,9 +398,9 @@ def reshape1( T_reshape: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( A[ @@ -424,9 +424,9 @@ def reshape2( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -448,9 +448,9 @@ def reshape3( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( A[ @@ -475,10 +475,10 @@ def rms_norm( rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -488,7 +488,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -512,9 +512,9 @@ def rotary_embedding( rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[256 + v_i1 - 256, v_i3], @@ -538,14 +538,14 @@ def softmax( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(256)), "float16") T_softmax_exp = T.alloc_buffer( (T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16" ) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(256)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) @@ -555,7 +555,7 @@ def softmax( T_softmax_maxelem[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_k] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) @@ -563,7 +563,7 @@ def softmax( A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2] ) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) @@ -573,13 +573,13 @@ def softmax( T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2] ) T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) + T.sblock_attr({"axis": 3}) T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = ( T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] ) @@ -590,9 +590,9 @@ def transpose( T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -604,9 +604,9 @@ def transpose1( T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -618,9 +618,9 @@ def transpose2( T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(128), T.int64(256)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax3, v_ax2]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -632,9 +632,9 @@ def transpose3( T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -856,9 +856,9 @@ def add( T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax0, v_ax1, v_ax2]) T.writes(T_add[v_ax0, v_ax1, v_ax2]) @@ -871,9 +871,9 @@ def divide1( T_divide: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -888,11 +888,11 @@ def matmul11( matmul: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(16), T.int64(256), T.int64(256), T.int64(128) ): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -910,11 +910,11 @@ def matmul21( matmul: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(16), T.int64(256), T.int64(128), T.int64(256) ): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) @@ -932,9 +932,9 @@ def matmul3( matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(2048)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(2048), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], B[v_k, v_i2]) T.writes(matmul[v_i0, v_i1, v_i2]) @@ -951,9 +951,9 @@ def matmul4( matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(2048)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], B[v_k, v_i2]) T.writes(matmul[v_i0, v_i1, v_i2]) @@ -970,9 +970,9 @@ def maximum1( T_maximum: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_maximum"): + with T.sblock("T_maximum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_maximum[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -987,9 +987,9 @@ def minimum1( T_minimum: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_minimum"): + with T.sblock("T_minimum"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]) T.writes(T_minimum[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1003,9 +1003,9 @@ def reshape11( T_reshape: T.Buffer((T.int64(256), T.int64(16), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(16), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( A[ @@ -1029,9 +1029,9 @@ def reshape21( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(16), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -1053,9 +1053,9 @@ def reshape31( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(2048)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(2048)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( A[ @@ -1079,9 +1079,9 @@ def reshape4( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(16), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -1105,10 +1105,10 @@ def rms_norm( rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -1118,7 +1118,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -1142,9 +1142,9 @@ def rotary_embedding( rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[256 + v_i1 - 256, v_i3], @@ -1168,9 +1168,9 @@ def rotary_embedding1( rotary: T.Buffer((T.int64(1), 256, T.int64(16), T.int64(128)), "float16"), ): T.func_attr({"global_symbol": "rotary_embedding", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(16), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[256 + v_i1 - 256, v_i3], @@ -1194,14 +1194,14 @@ def softmax1( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(16), T.int64(256)), "float16") T_softmax_exp = T.alloc_buffer( (T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16" ) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(16), T.int64(256)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) @@ -1211,7 +1211,7 @@ def softmax1( T_softmax_maxelem[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_k] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) @@ -1219,7 +1219,7 @@ def softmax1( A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2] ) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) @@ -1229,13 +1229,13 @@ def softmax1( T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2] ) T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) + T.sblock_attr({"axis": 3}) T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = ( T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] ) @@ -1246,9 +1246,9 @@ def transpose11( T_transpose: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1260,9 +1260,9 @@ def transpose21( T_transpose: T.Buffer((T.int64(1), T.int64(16), T.int64(128), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(128), T.int64(256)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax3, v_ax2]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1274,9 +1274,9 @@ def transpose31( T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(16), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1288,9 +1288,9 @@ def transpose4( T_transpose: T.Buffer((T.int64(4096), T.int64(2048)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(4096), T.int64(2048)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -1302,9 +1302,9 @@ def transpose5( T_transpose: T.Buffer((T.int64(2048), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(2048), T.int64(4096)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index 1c41d9367661..7f77ee4be946 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -96,13 +96,13 @@ def split1(var_A: T.handle, var_T_split: T.handle, var_T_split_1: T.handle): A = T.match_buffer(var_A, (128, 128), "float32") T_split = T.match_buffer(var_T_split, (64, 128), "float32") T_split_1 = T.match_buffer(var_T_split_1, (64, 128), "float32") - # with T.block("root"): + # with T.sblock("root"): for ax1, ax2 in T.grid(64, 128): - with T.block("T_split"): + with T.sblock("T_split"): v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T_split[v_ax1, v_ax2] = A[v_ax1, v_ax2] for ax1, ax2 in T.grid(64, 128): - with T.block("T_split_1"): + with T.sblock("T_split_1"): v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T_split_1[v_ax1, v_ax2] = A[v_ax1 + 64, v_ax2] @@ -142,15 +142,15 @@ def split1( T_split_1: T.Buffer((64, 128), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax1, ax2 in T.grid(64, 128): - with T.block("T_split"): + with T.sblock("T_split"): v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T.reads(A[v_ax1, v_ax2]) T.writes(T_split[v_ax1, v_ax2]) T_split[v_ax1, v_ax2] = A[v_ax1, v_ax2] for ax1, ax2 in T.grid(64, 128): - with T.block("T_split_1"): + with T.sblock("T_split_1"): v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T.reads(A[v_ax1 + 64, v_ax2]) T.writes(T_split_1[v_ax1, v_ax2]) @@ -393,10 +393,10 @@ def rms_norm( A = T.match_buffer(var_A, (T.int64(1), 256, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), 256, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -406,7 +406,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -435,9 +435,9 @@ def rotary_embedding( rotary = T.match_buffer( var_rotary, (T.int64(1), 256, T.int64(32), T.int64(128)), "float16" ) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[256 + v_i1 - 256, v_i3], @@ -603,10 +603,10 @@ def rms_norm( rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -616,7 +616,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -640,9 +640,9 @@ def rotary_embedding( rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[256 + v_i1 - 256, v_i3], @@ -819,9 +819,9 @@ def add( T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T.reads(A[T.int64(0), v_ax1, v_ax2], B[T.int64(0), v_ax1, v_ax2]) @@ -837,9 +837,9 @@ def divide( T_divide: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), v_ax1, v_ax2, v_ax3]) @@ -855,9 +855,9 @@ def matmul( matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(4096)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_k = T.axis.remap("SSR", [i1, i2, k]) T.reads(A[T.int64(0), v_i1, v_k], B[v_k, v_i2]) @@ -875,11 +875,11 @@ def matmul1( matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(256), T.int64(128) ): - with T.block("matmul"): + with T.sblock("matmul"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSR", [i1, i2, i3, k]) T.reads(A[T.int64(0), v_i1, v_i2, v_k], B[T.int64(0), v_i1, v_k, v_i3]) @@ -898,11 +898,11 @@ def matmul2( matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(128), T.int64(256) ): - with T.block("matmul"): + with T.sblock("matmul"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSR", [i1, i2, i3, k]) T.reads(A[T.int64(0), v_i1, v_i2, v_k], B[T.int64(0), v_i1, v_k, v_i3]) @@ -921,9 +921,9 @@ def maximum( T_maximum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_maximum"): + with T.sblock("T_maximum"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[T.int64(0), v_ax1, v_ax2, v_ax3], B[T.int64(0), v_ax1, v_ax2, v_ax3]) @@ -939,9 +939,9 @@ def minimum( T_minimum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_minimum"): + with T.sblock("T_minimum"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads( @@ -958,9 +958,9 @@ def reshape( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[T.int64(0), v_ax1, v_ax2 * T.int64(128) + v_ax3]) @@ -975,9 +975,9 @@ def reshape1( T_reshape: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) @@ -989,9 +989,9 @@ def reshape2( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[v_ax1, v_ax2, v_ax3]) @@ -1004,9 +1004,9 @@ def reshape3( T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2 = T.axis.remap("SS", [ax1, ax2]) T.reads(A[T.int64(0), v_ax1, v_ax2 // T.int64(128), v_ax2 % T.int64(128)]) @@ -1022,10 +1022,10 @@ def rms_norm( rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) v_i, v_k = T.axis.remap("SR", [i, k]) T.reads(A[T.int64(0), v_i, v_k]) @@ -1036,7 +1036,7 @@ def rms_norm( "float32", A[T.int64(0), v_i, v_k] ) * T.Cast("float32", A[T.int64(0), v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) v_i, v_k = T.axis.remap("SS", [i, k]) T.reads(B[v_k], A[T.int64(0), v_i, v_k], Ared_temp[T.int64(0), v_i]) @@ -1061,9 +1061,9 @@ def rotary_embedding( rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_i3 = T.axis.remap("SSS", [i1, i2, i3]) T.reads( @@ -1093,14 +1093,14 @@ def softmax( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(256)), "float16") T_softmax_exp = T.alloc_buffer( (T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16" ) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(256)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_k = T.axis.remap("SSR", [i1, i2, k]) T.reads(A[T.int64(0), v_i1, v_i2, v_k]) @@ -1111,7 +1111,7 @@ def softmax( T_softmax_maxelem[T.int64(0), v_i1, v_i2], A[T.int64(0), v_i1, v_i2, v_k] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_i3 = T.axis.remap("SSS", [i1, i2, i3]) T.reads( @@ -1122,7 +1122,7 @@ def softmax( A[T.int64(0), v_i1, v_i2, v_i3] - T_softmax_maxelem[T.int64(0), v_i1, v_i2] ) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_k = T.axis.remap("SSR", [i1, i2, k]) T.reads(T_softmax_exp[T.int64(0), v_i1, v_i2, v_k]) @@ -1134,7 +1134,7 @@ def softmax( + T_softmax_exp[T.int64(0), v_i1, v_i2, v_k] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) v_i1, v_i2, v_i3 = T.axis.remap("SSS", [i1, i2, i3]) T.reads( @@ -1142,7 +1142,7 @@ def softmax( T_softmax_expsum[T.int64(0), v_i1, v_i2], ) T.writes(T_softmax_norm[T.int64(0), v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) + T.sblock_attr({"axis": 3}) T_softmax_norm[T.int64(0), v_i1, v_i2, v_i3] = ( T_softmax_exp[T.int64(0), v_i1, v_i2, v_i3] / T_softmax_expsum[T.int64(0), v_i1, v_i2] @@ -1154,9 +1154,9 @@ def transpose( T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -1168,9 +1168,9 @@ def transpose1( T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[T.int64(0), v_ax2, v_ax1, v_ax3]) @@ -1185,9 +1185,9 @@ def transpose2( T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(128), T.int64(256)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[T.int64(0), v_ax1, v_ax3, v_ax2]) @@ -1202,9 +1202,9 @@ def transpose3( T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1, v_ax2, v_ax3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(A[T.int64(0), v_ax2, v_ax1, v_ax3]) @@ -1602,10 +1602,10 @@ def rms_norm( n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), n)) for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -1615,7 +1615,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -1645,9 +1645,9 @@ def rotary_embedding( rotary = T.match_buffer( var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16" ) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[m + v_i1 - n, v_i3], @@ -1810,10 +1810,10 @@ def rms_norm( n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), n)) for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -1823,7 +1823,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm_1[v_bsz, v_i, v_k]) @@ -1853,9 +1853,9 @@ def rotary_embedding( rotary = T.match_buffer( var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16" ) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( B[m + v_i1 - n, v_i3], diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_parser.py b/tests/python/relax/distributed/test_distributed_tvmscript_parser.py index a1079288af29..f8f3cec6fd8e 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_parser.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_parser.py @@ -63,7 +63,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -126,7 +126,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -166,7 +166,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index d11774d73797..9d15cabe64f1 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -91,7 +91,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -138,9 +138,9 @@ class Module: @T.prim_func def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(""): + with T.sblock(""): vi, vj = T.axis.remap("SS", [i, j]) T.reads(x[vi, vj]) T.writes(y[vi, vj]) diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 2845622bbe3e..0b4a015c2d01 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -539,7 +539,7 @@ def reshape( T_reshape: T.Buffer((8, 3), "float32"), ): for i0, i1 in T.grid(8, 3): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads( rxplaceholder[ @@ -568,7 +568,7 @@ def reshape_scheduled( ): for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3) ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3) T.reads( @@ -598,7 +598,7 @@ def expand_dims( ): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] ) @@ -620,7 +620,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16" ) for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -646,7 +646,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): A = T.match_buffer(var_A, (T.int64(1), n), "int32") T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") for ax0 in range(n): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(n, ax0) T.reads(A[T.int64(0), v_ax0 % n]) T.writes(T_reshape[v_ax0]) @@ -663,7 +663,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): A = T.match_buffer(var_A, (n, T.int64(4096)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16") for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) @@ -684,7 +684,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16" ) for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -710,9 +710,9 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( A[ @@ -741,12 +741,12 @@ def reshape_raggedness( B: T.Buffer((100, 12, 64), "float32"), ): for b in T.serial(8): - with T.block("block0"): + with T.sblock("block0"): vb = T.axis.spatial(8, b) for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]): for h in T.serial(12): for f in T.serial(64): - with T.block("block1"): + with T.sblock("block1"): vi, vh, vf = T.axis.remap("SSS", [i, h, f]) B[src_indptr[vb] + vi, vh, vf] = A[ src_indptr[vb] + vi, vh * 64 + vf @@ -760,11 +760,11 @@ def test_reshape_pattern_reject_seqstmt(): def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): C = T.alloc_buffer((128, 128), "float32") for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) C[vi0, vi1] = A[vi0, vi1] for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) B[vi0, vi1] = C[vi0, vi1] + T.float32(1) @@ -772,11 +772,11 @@ def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32") def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): C = T.alloc_buffer((128, 128), "float32") for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) C[vi0, vi1] = A[vi0, vi1] for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) B[vi0, vi1] = C[vi0, vi1] @@ -788,7 +788,7 @@ def test_reshape_pattern_reject_reduction(): @T.prim_func def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")): for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi0] = T.float32(0) @@ -801,7 +801,7 @@ def test_reshape_pattern_reject_reduction(): @T.prim_func def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")): for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi0] = T.float32(0) diff --git a/tests/python/relax/test_analysis_detect_recursion.py b/tests/python/relax/test_analysis_detect_recursion.py index b4c7adc84456..320226083a08 100644 --- a/tests/python/relax/test_analysis_detect_recursion.py +++ b/tests/python/relax/test_analysis_detect_recursion.py @@ -424,11 +424,11 @@ class CallPrimFunc: def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): C = T.alloc_buffer((128, 128), "float32") for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) C[vi0, vi1] = A[vi0, vi1] for i0, i1 in T.grid(4, 4): - with T.block("identity"): + with T.sblock("identity"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) B[vi0, vi1] = C[vi0, vi1] diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py b/tests/python/relax/test_analysis_suggest_layout_transforms.py index 03eaef0267b6..fd24e782903d 100644 --- a/tests/python/relax/test_analysis_suggest_layout_transforms.py +++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py @@ -25,9 +25,9 @@ def apply_transformations(func, suggested_transfoms, print_transformation=False): sch = tir.Schedule(func) for block, per_block_transformations in suggested_transfoms.items(): - blockrv = sch.get_block(block.name_hint) + blockrv = sch.get_sblock(block.name_hint) for obj, index_map in per_block_transformations.items(): - if isinstance(obj, tir.Block): + if isinstance(obj, tir.SBlock): block_name = obj.name_hint if print_transformation: print("Block transformation: ", block_name, " :: ", index_map) @@ -48,12 +48,12 @@ def nested_block( relu: T.Buffer((32, 64, 224, 224), "float32"), ): for i, j in T.grid(32, 64): - with T.block("outer"): + with T.sblock("outer"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(arg[v_i, v_j, 0:224, 0:224]) T.writes(relu[v_i, v_j, 0:224, 0:224]) for k, l in T.grid(224, 224): - with T.block("inner"): + with T.sblock("inner"): v_k, v_l = T.axis.remap("SS", [k, l]) T.reads(arg[v_i, v_j, v_k, v_l]) T.writes(relu[v_i, v_j, v_k, v_l]) @@ -73,7 +73,7 @@ def elemwise( relu: T.Buffer((32, 64, 224, 224), "float32"), ): for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(arg[v_i0, v_i1, v_i2, v_i3]) T.writes(relu[v_i0, v_i1, v_i2, v_i3]) @@ -97,7 +97,7 @@ def elemwise( relu: T.Buffer((32, 64, 224, 224), "float32"), ): for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(arg[v_i0, v_i1, v_i2, v_i3]) T.writes(relu[v_i0, v_i1, v_i2, v_i3]) @@ -116,7 +116,7 @@ def before( output: T.Buffer((32, 64), "float32"), ): for ax0, ax1 in T.grid(32, 64): - with T.block("compute"): + with T.sblock("compute"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -135,7 +135,7 @@ def before( output: T.Buffer((32 * 64, 10), "float32"), ): for ax0, ax1, ax2 in T.grid(32, 64, 10): - with T.block("compute"): + with T.sblock("compute"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(arg[v_ax0, v_ax1]) T.writes(output[v_ax0 * v_ax1, v_ax2]) @@ -154,7 +154,7 @@ def before( output: T.Buffer((16), "float32"), ): for ax0, ax1 in T.grid(4, 4): - with T.block("flatten"): + with T.sblock("flatten"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg[v_ax0, v_ax1]) T.writes(output[v_ax0 * 4 + v_ax1]) @@ -173,7 +173,7 @@ def before( output: T.Buffer((4, 8), "float32"), ): for ax0, ax1, ax2 in T.grid(4, 8, 4): - with T.block("compute"): + with T.sblock("compute"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(arg[v_ax1, v_ax2]) T.writes(output[v_ax0, v_ax1]) @@ -185,7 +185,7 @@ def expected( output: T.Buffer((32), "float32"), ): for ax0, ax2 in T.grid(32, 4): - with T.block("compute"): + with T.sblock("compute"): v_ax0, v_ax2 = T.axis.remap("SS", [ax0, ax2]) T.reads(arg[v_ax0 % 8, v_ax2]) T.writes(output[v_ax0]) @@ -205,7 +205,7 @@ def elemwise( relu: T.Buffer((32, 64, 224, 224), "float32"), ): for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(arg[v_i0, v_i1, v_i2, v_i3]) T.writes(relu[v_i0, v_i1, v_i2, v_i3]) @@ -226,7 +226,7 @@ def before( sum: T.Buffer((32, 64), "float32"), ): for ax0, k2, ax1, k3 in T.grid(32, 224, 64, 224): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_k2, v_ax1, v_k3 = T.axis.remap("SRSR", [ax0, k2, ax1, k3]) T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) T.writes(sum[v_ax0, v_ax1]) @@ -240,7 +240,7 @@ def expected( sum: T.Buffer((32, 16, 4), "float32"), ): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 16, 224, 4): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v0, v1, v2, v3, v4 = T.axis.remap("SRSRS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg[v0, v1, v2, v3, v4]) T.writes(sum[v0, v2, v4]) @@ -265,7 +265,7 @@ def before(arg: T.handle, relu: T.handle): Arg = T.match_buffer(arg, (N, C, H, W)) Relu = T.match_buffer(relu, (N, C, H, W)) for i0, i1, i2, i3 in T.grid(N, C, H, W): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(Arg[v_i0, v_i1, v_i2, v_i3]) T.writes(Relu[v_i0, v_i1, v_i2, v_i3]) @@ -279,9 +279,9 @@ def expected(arg: T.handle, relu: T.handle): W = T.int64() Arg = T.match_buffer(arg, (N, H, W, C)) Relu = T.match_buffer(relu, (N, H, W, C)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): - with T.block("compute"): + with T.sblock("compute"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(Arg[v0, v1, v2, v3]) T.writes(Relu[v0, v1, v2, v3]) @@ -301,7 +301,7 @@ def before( relu: T.Buffer((32, 64, 224, 224), "float32"), ): for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(arg[v_i0, v_i1, v_i2, v_i3]) T.writes(relu[v_i0, v_i1, v_i2, v_i3]) @@ -313,7 +313,7 @@ def expected( relu: T.Buffer((32, 224, 224, 64), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): - with T.block("compute"): + with T.sblock("compute"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v0, v1, v2, v3]) T.writes(relu[v0, v1, v2, v3]) @@ -333,7 +333,7 @@ def before( pool_max: T.Buffer((32, 64, 111, 223), "float32"), ): for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(32, 64, 111, 223, 2, 2): - with T.block("pool_max"): + with T.sblock("pool_max"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap( "SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1] ) @@ -346,7 +346,7 @@ def before( ] ) T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e38) pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.max( @@ -364,13 +364,13 @@ def expected( arg: T.Buffer((32, 224, 224, 64), "float32"), pool_max: T.Buffer((32, 111, 223, 64), "float32"), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 111, 223, 64, 2, 2): - with T.block("pool_max"): + with T.sblock("pool_max"): v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) T.reads(arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3]) T.writes(pool_max[v0, v1, v2, v3]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) pool_max[v0, v1, v2, v3] = T.max( @@ -399,13 +399,13 @@ def before( ), ): for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(32, 4, 110, 220, 16, 5, 5): - with T.block("pool_max"): + with T.sblock("pool_max"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] ) T.reads(arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4]) T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(-3.4028234663852886e38) pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( @@ -419,11 +419,11 @@ def expected( pool_max: T.Buffer((32, 110, 220, 64), "float32"), ): for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 110, 220, 64, 5, 5): - with T.block("pool_max"): + with T.sblock("pool_max"): v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) T.reads(arg[v0, v1 * 2 + v4, v2 + v5, v3]) T.writes(pool_max[v0, v1, v2, v3]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) pool_max[v0, v1, v2, v3] = T.max( @@ -446,7 +446,7 @@ def before( sum: T.Buffer((32, 64), "float32"), ): for ax0, ax1, k2, k3 in T.grid(32, 64, 224, 224): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) T.writes(sum[v_ax0, v_ax1]) @@ -460,7 +460,7 @@ def expected( sum: T.Buffer((32, 4, 16), "float32"), ): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 4, 224, 224, 16): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v0, v1, v2, v3, v4 = T.axis.remap("SSRRS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg[v0, v1, v2, v3, v4]) T.writes(sum[v0, v1, v4]) @@ -483,7 +483,7 @@ def before( resize: T.Buffer((32, 64, 202, 246), "float32"), ): for i0, i1, i2, i3 in T.grid(32, 64, 202, 246): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(arg[v_i0, v_i1, 0:224, 0:224]) T.writes(resize[v_i0, v_i1, v_i2, v_i3]) @@ -523,9 +523,9 @@ def expected( arg: T.Buffer((32, 64, 224, 224), "float32"), resize: T.Buffer((32, 202, 246, 64), "float32"), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(32, 202, 246, 64): - with T.block("resize"): + with T.sblock("resize"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v0, v3, 0:224, 0:224]) T.writes(resize[v0, v1, v2, v3]) @@ -574,7 +574,7 @@ def before( T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 64, 10, 8): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( arg[ @@ -597,9 +597,9 @@ def expected( arg: T.Buffer((32, 224, 224, 16, 4), "float32"), T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 10, 8, 16, 4): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg[v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4]) T.writes(T_strided_slice_with_axes[v0, v1, v2, v3, v4]) @@ -622,9 +622,9 @@ def before( T_add: T.Buffer((32, 64, 224, 224), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(32, 64, 224, 224): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( arg0[v_ax0, v_ax1, v_ax2, v_ax3], @@ -642,9 +642,9 @@ def expected( T_add: T.Buffer((32, 224, 224, 16, 4), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4): - with T.block("T_add"): + with T.sblock("T_add"): v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg0[v0, v1, v2, v3, v4], arg1[v1, v2, v3, v4]) T.writes(T_add[v0, v1, v2, v3, v4]) @@ -664,7 +664,7 @@ def before( T_transpose: T.Buffer((32, 224, 224, 64), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v_ax0, v_ax3, v_ax1, v_ax2]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -676,7 +676,7 @@ def expected( T_transpose: T.Buffer((32, 224, 64, 224), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 224, 64, 224): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v0, v2, v3, v1]) T.writes(T_transpose[v0, v1, v2, v3]) @@ -696,7 +696,7 @@ def before( PadInput: T.Buffer((32, 64, 230, 230), "float32"), ): for i0, i1, i2, i3 in T.grid(32, 64, 230, 230): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) @@ -712,7 +712,7 @@ def expected( PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"), ): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 230, 230, 16, 4): - with T.block("PadInput"): + with T.sblock("PadInput"): v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg[v0, v1 - 2, v2 - 2, v3, v4]) T.writes(PadInput[v0, v1, v2, v3, v4]) @@ -737,13 +737,13 @@ def before( split1: T.Buffer((32, 32, 224, 224), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -756,13 +756,13 @@ def expected( split1: T.Buffer((32, 224, 224, 32), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v0, v1, v2, v3]) T.writes(split0[v0, v1, v2, v3]) split0[v0, v1, v2, v3] = arg[v0, v1, v2, v3] for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v0, v1, v2, v3 + 32]) T.writes(split1[v0, v1, v2, v3]) @@ -785,13 +785,13 @@ def before( split1: T.Buffer((32, 32, 224, 224), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -803,15 +803,15 @@ def expected( split0: T.Buffer((32, 224, 224, 8, 4), "float32"), split1: T.Buffer((32, 224, 224, 8, 4), "float32"), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg[v0, v1, v2, v3, v4]) T.writes(split0[v0, v1, v2, v3, v4]) split0[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3, v4] for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(arg[v0, v1, v2, v3 + 8, v4]) T.writes(split1[v0, v1, v2, v3, v4]) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index fd7c70c6148c..54b860a07c59 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -715,7 +715,7 @@ def main(A: R.Tensor([16], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -740,7 +740,7 @@ def main(A: R.Tensor([4, 4], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -764,7 +764,7 @@ def main(A: R.Tensor([16], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -789,7 +789,7 @@ def main(A: R.Tensor([32], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -813,7 +813,7 @@ def main(A: R.Tensor([16], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -839,7 +839,7 @@ def main(A: R.Tensor([16], "float32")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -865,7 +865,7 @@ def main(A: R.Tensor([16], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] + T.float16(1.0) @@ -898,7 +898,7 @@ def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): B = T.match_buffer(B_handle, [M, N], dtype="float16") for i, j in T.grid(M, N): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi * N + vj] @@ -931,7 +931,7 @@ def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): B = T.match_buffer(B_handle, [M, N], dtype="float16") for i, j in T.grid(M, N): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi * N + vj] @@ -966,7 +966,7 @@ def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): B = T.match_buffer(B_handle, [M, N], dtype="float16") for i, j in T.grid(M, N): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi * N + vj] @@ -1004,7 +1004,7 @@ def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle): B = T.match_buffer(B_handle, [16, M, N], dtype="float16") for i, j, k in T.grid(16, M, N): - with T.block("compute"): + with T.sblock("compute"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi * N * M + vj * N + vk] @@ -1037,7 +1037,7 @@ def flatten(A_handle: T.handle, B_handle: T.handle): B = T.match_buffer(B_handle, [M * N], dtype="float16") for i in T.grid(M * N): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi // N, vi % N] @@ -1075,7 +1075,7 @@ def flatten(A_handle: T.handle, B_handle: T.handle): B = T.match_buffer(B_handle, [M * N], dtype="float16") for i in T.grid(M * N): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi // N, vi % N] @@ -1114,7 +1114,7 @@ def flatten(A_handle: T.handle, B_handle: T.handle): B = T.match_buffer(B_handle, [M * N], dtype="float16") for i in T.grid(M * N): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) B[vi] = A[vi // N, vi % N] @@ -1139,7 +1139,7 @@ def main(A: R.Tensor([16], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) A[vi] = A[vi] + T.float16(1.0) @@ -1164,7 +1164,7 @@ def main(A: R.Tensor([16], "float16")): @T.prim_func def add_one(A: T.Buffer(16, "float16")): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) A[vi] = A[vi] + T.float16(1.0) @@ -1196,12 +1196,12 @@ def add_one( C: T.Buffer(16, "float16"), ): for i in range(32): - with T.block("inplace_B"): + with T.sblock("inplace_B"): vi = T.axis.remap("S", [i]) B[vi] = B[vi] + T.float16(1.0) for i in range(16): - with T.block("output_C"): + with T.sblock("output_C"): vi = T.axis.remap("S", [i]) C[vi] = A[vi] + T.float16(1.0) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 96ebbfc2ef32..d1378a174d67 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -445,7 +445,7 @@ def addone(A_handle: T.handle, B_handle: T.handle) -> None: B = T.match_buffer(B_handle, (m, n), "float32") T.func_attr(({"global_symbol": "addone"})) for i, j in T.grid(m, n): - with T.block("addone"): + with T.sblock("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.int32(1) diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py index fb36f877758b..f06a6a66b93e 100644 --- a/tests/python/relax/test_backend_dispatch_sampling.py +++ b/tests/python/relax/test_backend_dispatch_sampling.py @@ -52,9 +52,9 @@ def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): usample = T.match_buffer(B, (out_batch, 1)) sample_indices = T.match_buffer(C, (out_batch, 1), "int64") output_index = T.match_buffer(D, (out_batch, 1), "int64") - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(out_batch, vocab_size): - with T.block("T_get_sample_index"): + with T.sblock("T_get_sample_index"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) if usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] or v_ax1 + T.int64(1) == vocab_size: if v_ax1 == T.int64(0): @@ -92,7 +92,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1)) row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int64") token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int64") - # with T.block("root"): + # with T.sblock("root"): aggregate = T.alloc_buffer((), scope="local") sample_id_local = T.alloc_buffer((), "int64", scope="local") step_iter = T.alloc_buffer((), "int32", scope="local") @@ -104,7 +104,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl aggregate[()] = T.Cast("float32", 0) step_iter[()] = 0 while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < T.Cast("int64", (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512))): - with T.block(""): + with T.sblock(""): T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()]) T.writes(sample_id_local[()], aggregate[()]) prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local") @@ -119,7 +119,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0)) prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0), prob_local, T.Cast("float32", 0)) valid[v] = prob_local > T.float32(0) and idx < vocab_size - with T.block(""): + with T.sblock(""): T.reads(prob_gt_threshold[T.int64(0):T.int64(4)]) T.writes(step_aggregate[()]) local_sum = T.alloc_buffer((), scope="local") @@ -150,7 +150,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)] for v in T.unroll(T.int64(4)): greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07) - with T.block(""): + with T.sblock(""): T.reads(greater_than_u[T.int64(0):T.int64(4)]) T.writes(mask[T.int64(0):T.int64(4)]) shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared") @@ -162,7 +162,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl for v in T.unroll(T.int64(4)): mask[v] = mask[v] and valid[v] indices[v] = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v - with T.block(""): + with T.sblock(""): T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)]) T.writes(sample_id_local[()]) local_sum = T.alloc_buffer((), "int64", scope="local") diff --git a/tests/python/relax/test_blockbuilder_emit_te.py b/tests/python/relax/test_blockbuilder_emit_te.py index 5b8e7e9ba20c..ef3502315913 100644 --- a/tests/python/relax/test_blockbuilder_emit_te.py +++ b/tests/python/relax/test_blockbuilder_emit_te.py @@ -50,7 +50,7 @@ def te_func( ): T.func_attr({"tir.noalias": True}) for i in range(T.int64(10)): - with T.block("B"): + with T.sblock("B"): v_i = T.axis.spatial(T.int64(10), i) T.writes(B[v_i]) B[v_i] = A[v_i + m] @@ -101,7 +101,7 @@ def te_slice( T.func_attr({"tir.noalias": True}) for i in range(A.shape[1]): - with T.block("slice"): + with T.sblock("slice"): vi = T.axis.remap("S", [i]) Output[vi] = A[row_index, vi] diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index c645dce96bd4..bbc0d0897f01 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -1263,9 +1263,9 @@ def decode( decode_1: T.Buffer((T.int64(64), T.int64(128)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j in T.grid(T.int64(64), T.int64(128)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(A[v_i, v_j // T.int64(2)], B[v_j]) T.writes(decode_1[v_i, v_j]) @@ -1296,11 +1296,11 @@ def encode( compute: T.Buffer((T.int64(128),), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): max_abs_value = T.alloc_buffer((T.int64(128),), "float16") scale = T.alloc_buffer((T.int64(128),)) for i, k in T.grid(T.int64(128), T.int64(64)): - with T.block("max_abs_value"): + with T.sblock("max_abs_value"): v_i, v_k = T.axis.remap("SR", [i, k]) T.reads(A[v_i, v_k]) T.writes(max_abs_value[v_i]) @@ -1308,7 +1308,7 @@ def encode( max_abs_value[v_i] = T.float16(-65504) max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k])) for i in range(T.int64(128)): - with T.block("scale"): + with T.sblock("scale"): v_i = T.axis.spatial(T.int64(128), i) T.reads(max_abs_value[v_i]) T.writes(scale[v_i]) @@ -1316,7 +1316,7 @@ def encode( T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001) ) * T.float32(0.125) for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)): - with T.block("w_gathered"): + with T.sblock("w_gathered"): v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k]) T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * T.int64(2) + v_k]) T.writes(w_gathered[v_j, v_i]) @@ -1351,7 +1351,7 @@ def encode( ), ) for i0 in range(T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(T.int64(128), i0) T.reads(scale[v_i0]) T.writes(compute[v_i0]) @@ -1520,9 +1520,9 @@ def decode( decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j in T.grid(T.int64(64), T.int64(64)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(A[v_i, v_j], B[v_j]) T.writes(decode_1[v_i, v_j]) @@ -1535,11 +1535,11 @@ def encode( compute: T.Buffer((T.int64(64),), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): max_abs_value = T.alloc_buffer((T.int64(64),), "float16") scale = T.alloc_buffer((T.int64(64),)) for i, k in T.grid(T.int64(64), T.int64(64)): - with T.block("max_abs_value"): + with T.sblock("max_abs_value"): v_i, v_k = T.axis.remap("SR", [i, k]) T.reads(A[v_i, v_k]) T.writes(max_abs_value[v_i]) @@ -1547,7 +1547,7 @@ def encode( max_abs_value[v_i] = T.float16(-65504) max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k])) for i in range(T.int64(64)): - with T.block("scale"): + with T.sblock("scale"): v_i = T.axis.spatial(T.int64(64), i) T.reads(max_abs_value[v_i]) T.writes(scale[v_i]) @@ -1555,7 +1555,7 @@ def encode( T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001) ) * T.float32(0.0078125) for j, i in T.grid(T.int64(64), T.int64(64)): - with T.block("w_gathered"): + with T.sblock("w_gathered"): v_j, v_i = T.axis.remap("SS", [j, i]) T.reads(A[v_i, v_j], scale[v_i]) T.writes(w_gathered[v_j, v_i]) @@ -1570,7 +1570,7 @@ def encode( ), ) for i0 in range(T.int64(64)): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(T.int64(64), i0) T.reads(scale[v_i0]) T.writes(compute[v_i0]) @@ -1666,10 +1666,10 @@ def rms_norm( rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1))) for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("Ared_temp"): + with T.sblock("Ared_temp"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(Ared_temp[v_bsz, v_i]) @@ -1679,7 +1679,7 @@ def rms_norm( "float32", A[v_bsz, v_i, v_k] ) * T.Cast("float32", A[v_bsz, v_i, v_k]) for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("rms_norm"): + with T.sblock("rms_norm"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) T.writes(rms_norm[v_bsz, v_i, v_k]) @@ -1799,9 +1799,9 @@ def decode( decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j in T.grid(T.int64(64), T.int64(64)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(A[v_i, v_j], B[v_j]) T.writes(decode_1[v_i, v_j]) @@ -1814,11 +1814,11 @@ def encode( compute: T.Buffer((T.int64(64),), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): max_abs_value = T.alloc_buffer((T.int64(64),), "float16") scale = T.alloc_buffer((T.int64(64),)) for i, k in T.grid(T.int64(64), T.int64(64)): - with T.block("max_abs_value"): + with T.sblock("max_abs_value"): v_i, v_k = T.axis.remap("SR", [i, k]) T.reads(A[v_i, v_k]) T.writes(max_abs_value[v_i]) @@ -1826,7 +1826,7 @@ def encode( max_abs_value[v_i] = T.float16(-65504) max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k])) for i in range(T.int64(64)): - with T.block("scale"): + with T.sblock("scale"): v_i = T.axis.spatial(T.int64(64), i) T.reads(max_abs_value[v_i]) T.writes(scale[v_i]) @@ -1834,7 +1834,7 @@ def encode( T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001) ) * T.float32(0.0078125) for j, i in T.grid(T.int64(64), T.int64(64)): - with T.block("w_gathered"): + with T.sblock("w_gathered"): v_j, v_i = T.axis.remap("SS", [j, i]) T.reads(A[v_i, v_j], scale[v_i]) T.writes(w_gathered[v_j, v_i]) @@ -1849,7 +1849,7 @@ def encode( ), ) for i0 in range(T.int64(64)): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(T.int64(64), i0) T.reads(scale[v_i0]) T.writes(compute[v_i0]) @@ -1933,7 +1933,7 @@ def decode( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("decode"): + with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(A[v_i, v_j], B[v_i // T.int64(64), v_j]) T.writes(decode_1[v_i, v_j]) @@ -1966,7 +1966,7 @@ def encode( ) ) for i, j, k in T.grid(T.int64(2), T.int64(128), T.int64(64)): - with T.block("max_abs_value"): + with T.sblock("max_abs_value"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_j, v_i * T.int64(64) + v_k]) T.writes(max_abs_value[v_i, v_j]) @@ -1976,7 +1976,7 @@ def encode( max_abs_value[v_i, v_j], T.fabs(A[v_j, v_i * T.int64(64) + v_k]) ) for i, j in T.grid(T.int64(2), T.int64(128)): - with T.block("scale"): + with T.sblock("scale"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(max_abs_value[v_i, v_j]) T.writes(scale[v_i, v_j]) @@ -1984,7 +1984,7 @@ def encode( T.Cast("float32", max_abs_value[v_i, v_j]), T.float32(0.0001) ) * T.float32(0.0078125) for j, i in T.grid(T.int64(128), T.int64(128)): - with T.block("w_gathered"): + with T.sblock("w_gathered"): v_j, v_i = T.axis.remap("SS", [j, i]) T.reads(A[v_i, v_j], scale[v_j // T.int64(64), v_i]) T.writes(w_gathered[v_j, v_i]) @@ -2001,7 +2001,7 @@ def encode( ), ) for i0, i1 in T.grid(T.int64(2), T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(scale[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 00805152b499..d8961900cdd4 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -176,7 +176,7 @@ def tir_id(x: T.handle, y: T.handle) -> None: B = T.match_buffer(y, (m, n), "int32") for i, j in T.grid(m, n): - with T.block("id"): + with T.sblock("id"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -190,7 +190,7 @@ def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, n), "int32") for i, j in T.grid(m, n): - with T.block("id"): + with T.sblock("id"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] C[vi, vj] = A[vi, vj] @@ -379,7 +379,7 @@ def expected_add( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -397,13 +397,13 @@ def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) compute = T.alloc_buffer((T.int64(2), T.int64(3))) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.sigmoid(A[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], compute[v_ax0, v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -447,7 +447,7 @@ def add_inplace( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -460,7 +460,7 @@ def multiply_inplace( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -473,7 +473,7 @@ def subtract_inplace( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(1), T.int64(3)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) T.writes(B[v_ax0, v_ax1]) @@ -565,7 +565,7 @@ def add_inplace(var_A: T.handle, var_B: T.handle): A = T.match_buffer(var_A, (a, b)) B = T.match_buffer(var_B, (a, b)) for ax0, ax1 in T.grid(a, b): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -578,7 +578,7 @@ def subtract_inplace(var_A: T.handle, var_B: T.handle): A = T.match_buffer(var_A, (a, b)) B = T.match_buffer(var_B, (a, b)) for ax0, ax1 in T.grid(a, b): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) T.writes(B[v_ax0, v_ax1]) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 90e2948a320c..ee21c14c6f44 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -40,7 +40,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (32, 32)) for i0, j0, k0 in T.grid(32, 32, 32): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): C[i, j] = 0.0 @@ -52,7 +52,7 @@ def tir_relu(x: T.handle, y: T.handle): A = T.match_buffer(x, (32, 32)) B = T.match_buffer(y, (32, 32)) for i, j in T.grid(32, 32): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.max(A[vi, vj], 0.0) @@ -61,7 +61,7 @@ def tir_zeros(x: T.handle, n: T.int64): T.func_attr({"global_symbol": "tir_zeros"}) A = T.match_buffer(x, [n]) for i in range(n): - with T.block(): + with T.sblock(): vi = T.axis.remap("S", [i]) A[vi] = 1.0 diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py index 22d14a2cfd2b..ff964c4a6804 100644 --- a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -51,11 +51,11 @@ def add( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_A_assumptions"): + with T.sblock("buffer_A_assumptions"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -75,7 +75,7 @@ def add( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_B_assumptions"): + with T.sblock("buffer_B_assumptions"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -95,7 +95,7 @@ def add( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("compute"): + with T.sblock("compute"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -154,11 +154,11 @@ def add( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_A_assumptions"): + with T.sblock("buffer_A_assumptions"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] @@ -175,7 +175,7 @@ def add( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_B_assumptions"): + with T.sblock("buffer_B_assumptions"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] @@ -193,7 +193,7 @@ def add( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) ): for axis5_1_axis6_fused in T.vectorized(T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( "SSSS", [axis1, axis2, axis3, axis4] @@ -252,11 +252,11 @@ def sub( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_A_assumptions"): + with T.sblock("buffer_A_assumptions"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -276,7 +276,7 @@ def sub( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_B_assumptions"): + with T.sblock("buffer_B_assumptions"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -296,7 +296,7 @@ def sub( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("compute"): + with T.sblock("compute"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -355,11 +355,11 @@ def sub( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_A_assumptions"): + with T.sblock("buffer_A_assumptions"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] @@ -376,7 +376,7 @@ def sub( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_B_assumptions"): + with T.sblock("buffer_B_assumptions"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] @@ -394,7 +394,7 @@ def sub( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) ): for axis5_1_axis6_fused in T.vectorized(T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( "SSSS", [axis1, axis2, axis3, axis4] @@ -453,11 +453,11 @@ def mul( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_A_assumptions"): + with T.sblock("buffer_A_assumptions"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -477,7 +477,7 @@ def mul( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_B_assumptions"): + with T.sblock("buffer_B_assumptions"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -497,7 +497,7 @@ def mul( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("compute"): + with T.sblock("compute"): v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) @@ -556,11 +556,11 @@ def mul( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_A_assumptions"): + with T.sblock("buffer_A_assumptions"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] @@ -577,7 +577,7 @@ def mul( for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) ): - with T.block("buffer_B_assumptions"): + with T.sblock("buffer_B_assumptions"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] @@ -595,7 +595,7 @@ def mul( T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) ): for axis5_1_axis6_fused in T.vectorized(T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( "SSSS", [axis1, axis2, axis3, axis4] diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py index 85424df2f602..e1c7fe3ee077 100644 --- a/tests/python/relax/test_frontend_common.py +++ b/tests/python/relax/test_frontend_common.py @@ -75,7 +75,7 @@ def pad( ): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(x[v_i0, v_i1, v_i2, v_i3]) T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) @@ -115,7 +115,7 @@ def replicate_pad( ): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): - with T.block("ReplicatePadInput"): + with T.sblock("ReplicatePadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( x[ @@ -176,7 +176,7 @@ def mirror_pad( ): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): - with T.block("MirrorPadInput"): + with T.sblock("MirrorPadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(x[v_i0, v_i1, T.int64(0) : T.int64(4), T.int64(0) : T.int64(4)]) T.writes(MirrorPadInput[v_i0, v_i1, v_i2, v_i3]) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index b3eac1d42709..24c6ff34ba78 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -57,11 +57,11 @@ def main( # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) # body - # with T.block("root") + # with T.sblock("root") matmul = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1]) T.writes(matmul[v_i0, v_i1]) @@ -69,13 +69,13 @@ def main( matmul[v_i0, v_i1] = T.float32(0) matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1] for ax0, ax1 in T.grid(T.int64(10), T.int64(10)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1] for i0, i1 in T.grid(T.int64(10), T.int64(10)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_add[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) @@ -85,9 +85,9 @@ def main( workload = db.commit_workload(Input1_ir) sch = tir.Schedule(Input1_ir, debug_mask="all") - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") + b0 = sch.get_sblock(name="matmul", func_name="main") + b1 = sch.get_sblock(name="T_add", func_name="main") + b2 = sch.get_sblock(name="root", func_name="main") sch.compute_inline(block=b1) sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") l3, l4, l5 = sch.get_loops(block=b0) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 28c11f6dfaf5..f0e4eefa1cd1 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -594,9 +594,9 @@ class Expected: @T.prim_func(private=True) def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(10), T.int64(10)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(10), T.int64(10)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) @@ -733,7 +733,7 @@ def inplace_take( pos = T.match_buffer(var_pos, (seq_len,), "int32") embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype) for ax0, ax1 in T.grid(seq_len, hidden_size): - with T.block("T_take"): + with T.sblock("T_take"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(weight[pos[v0], v1], pos[v0]) T.writes(embeddings[v0, v1]) @@ -766,7 +766,7 @@ def inplace_take( pos = T.match_buffer(var_pos, (seq_len,), "int32") embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype) for ax0, ax1 in T.grid(seq_len, hidden_size): - with T.block("T_take"): + with T.sblock("T_take"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(weight[pos[v0], v1], pos[v0]) T.writes(embeddings[v0, v1]) @@ -1032,9 +1032,9 @@ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: usample = T.match_buffer(D, (out_batch, 1)) sample_indices = T.match_buffer(E, (out_batch, 1), "int64") output_index = T.match_buffer(F, (out_batch, 1), "int64") - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(out_batch, vocab_size): - with T.block("T_get_index_from_sorted"): + with T.sblock("T_get_index_from_sorted"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))]) T.writes(output_index[v_ax0, 0]) @@ -1052,9 +1052,9 @@ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): top_p = T.match_buffer(B, (batch, 1)) top_k = T.match_buffer(C, (batch, 1), "int64") renorm_prob = T.match_buffer(D, (batch, 1)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(batch, vocab_size): - with T.block("T_get_renorm_prob"): + with T.sblock("T_get_renorm_prob"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0]) T.writes(renorm_prob[v_ax0, 0]) @@ -1153,9 +1153,9 @@ class Expected: @T.prim_func(private=True) def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"), filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j in T.grid(T.int64(2), T.int64(3)): - with T.block("filter_with_top_p_top_k"): + with T.sblock("filter_with_top_p_top_k"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(B[v_i, T.int64(0)], A[v_i, v_j]) T.writes(filter_with_top_p_top_k[v_i, v_j]) @@ -1169,9 +1169,9 @@ def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.h top_p = T.match_buffer(C, (batch, 1)) top_k = T.match_buffer(D, (batch, 1), "int64") cutoff = T.match_buffer(E, (batch, 1)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(batch, vocab_size): - with T.block("T_get_renorm_prob"): + with T.sblock("T_get_renorm_prob"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0], sorted_prob[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))]) T.writes(cutoff[v_ax0, 0]) diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py b/tests/python/relax/test_meta_schedule_relax_integration.py index 6f3cdfa9a0de..096a53998da6 100644 --- a/tests/python/relax/test_meta_schedule_relax_integration.py +++ b/tests/python/relax/test_meta_schedule_relax_integration.py @@ -57,17 +57,17 @@ class Module: @T.prim_func(private=True) def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32") fused_constant = T.allocate_const([-171701247, -1719837685, 1801664104, -634316588, 920159370, -132073802, 2142531563, 1465185701, -1505608067, 1737948828, 1581089391, -1986167320, -1449581822, 35714587, 496324563, -1430879015, -1615680873, 1198514997, 1494683955, 1567376558, 1319924884, -380548171, 296785437, -1546305981, -398644701, -2004794585, -1850413687, 2072643657, 847950121, -544212073, -199532669, -343273682, 953721562, -1930209358, 1573600108, -577689853], "int32", [3, 3, 4, 1]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)): - with T.block("PaddedInput"): + with T.sblock("PaddedInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3]) PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0) for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)): - with T.block("DepthwiseConv2d"): + with T.sblock("DepthwiseConv2d"): v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj]) fused_constant_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant) T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant_1[v_di, v_dj, v_c, T.int64(0)]) @@ -79,17 +79,17 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64( @T.prim_func(private=True) def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): PaddedInput0 = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32") fused_constant0 = T.allocate_const([2042349344, -2076067063, 1528163722, -1156452837, -2097172051, 1137787079, -601389657, 1907495997, 987801941, 1073738593, -1410339796, -689755358, 90351522, -44886952, -1914103775, -691553659, -1288505112, -1376578817, -2067933148, -1413101824, 1261422027, -156976862, -1185734459, 1608778622, -664209483, 1907479806, 1838595152, 464942526, 877953160, 415131837, -2010736511, 1218242769, -1440127632, 112931, 521745784, -1931145893], "int32", [3, 3, 4, 1]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)): - with T.block("PaddedInput"): + with T.sblock("PaddedInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) T.writes(PaddedInput0[v_i0, v_i1, v_i2, v_i3]) PaddedInput0[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0) for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)): - with T.block("DepthwiseConv2d"): + with T.sblock("DepthwiseConv2d"): v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj]) fused_constant0_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant0) T.reads(PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]) @@ -101,19 +101,19 @@ def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int6 @T.prim_func(private=True) def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32") DepthwiseConv2d = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32") fused_nn_conv2d_constant = T.allocate_const([1, 1, 1, 1], "int32", [1, 1, 1, 4]) fused_constant_2 = T.allocate_const([687940110, -910571705, -901609800, -500525928, 506872399, 1070176297, -305936110, 1625439784, -1565626954, -1705688881, -866370805, -1750740826, 300497007, -626864803, 390295545, 222549121, 319224543, -2003064970, 657992492, 2014175448, 653278589, -768810984, -294555581, -1197167662, 1703154671, -1540759805, -568817430, -1729755444, -275458074, 2078945571, 1683298006, -1029327874, 1315093181, 159010501, 875694807, -223655381], "int32", [3, 3, 4, 1]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)): - with T.block("PaddedInput"): + with T.sblock("PaddedInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3]) PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0) for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)): - with T.block("DepthwiseConv2d"): + with T.sblock("DepthwiseConv2d"): v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj]) fused_constant_2_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant_2) T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)]) @@ -122,7 +122,7 @@ def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64 DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0 DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) fused_nn_conv2d_constant_1 = T.Buffer((1, 1, 1, 4), "int32", data=fused_nn_conv2d_constant) T.reads(DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3], fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3]) diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py index e0992d276fbb..f2d8617bb438 100644 --- a/tests/python/relax/test_op_index.py +++ b/tests/python/relax/test_op_index.py @@ -1003,7 +1003,7 @@ def strided_slice( ): T.func_attr({"tir.noalias": True}) for iters in T.grid(*B.shape): - with T.block("T_dynamic_strided_slice"): + with T.sblock("T_dynamic_strided_slice"): i, j = T.axis.remap("SS", iters) B[i, j] = A[i + index, j] @@ -1030,9 +1030,9 @@ class expected: def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64): T.func_attr({"tir.noalias": True}) T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16))) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)): - with T.block("T_dynamic_strided_slice_with_axes"): + with T.sblock("T_dynamic_strided_slice_with_axes"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0 + index, v_ax1]) T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1]) diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index 9d05690f38b1..150796044588 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -32,7 +32,7 @@ def identity_tir(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [54, 96]) for i, j in T.grid(54, 96): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py index 85e70b48add3..21656a7d72a0 100644 --- a/tests/python/relax/test_optimize_layout_transform.py +++ b/tests/python/relax/test_optimize_layout_transform.py @@ -47,9 +47,9 @@ def relax_add_replacement( output: T.Buffer((4, 4), "float32"), ): T.func_attr({"operator_name": "relax.add"}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -101,9 +101,9 @@ def relax_add_replacement( output: T.Buffer((4, 4), "float32"), ): T.func_attr({"operator_name": "relax.add"}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -149,9 +149,9 @@ def relax_add_replacement( output: T.Buffer((4, 4), "float32"), ): T.func_attr({"operator_name": "relax.add"}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -216,9 +216,9 @@ def relax_add_replacement( output: T.Buffer((4, 4), "float32"), ): T.func_attr({"operator_name": "relax.add"}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -272,9 +272,9 @@ def relax_relu_replacement( arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32") ): T.func_attr({"operator_name": "relax.relu"}) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -287,9 +287,9 @@ def remove_pad(var_input: T.handle, var_output: T.handle): input = T.match_buffer(var_input, (p0,)) i0 = T.int64() output = T.match_buffer(var_output, (i0,)) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(i0): - with T.block("output"): + with T.sblock("output"): v_ax0 = T.axis.spatial(i0, ax0) T.reads(input[v_ax0]) T.writes(output[v_ax0]) @@ -349,9 +349,9 @@ def relax_relu_replacement( arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32") ): T.func_attr({"operator_name": "relax.relu"}) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -364,9 +364,9 @@ def remove_pad(var_input: T.handle, var_output: T.handle): input = T.match_buffer(var_input, (p0,)) i0 = T.int64() output = T.match_buffer(var_output, (i0,)) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(i0): - with T.block("output"): + with T.sblock("output"): v_ax0 = T.axis.spatial(i0, ax0) T.reads(input[v_ax0]) T.writes(output[v_ax0]) diff --git a/tests/python/relax/test_pytorch_integration.py b/tests/python/relax/test_pytorch_integration.py index 6839906e7a28..88d4e615cfd8 100644 --- a/tests/python/relax/test_pytorch_integration.py +++ b/tests/python/relax/test_pytorch_integration.py @@ -71,7 +71,7 @@ def matmul( C = T.match_buffer(var_C, (n, 20), "float32") for i, j, k in T.grid(n, 20, 16): - with T.block("block"): + with T.sblock("block"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py index 515c6ee648ff..c34c957ce7e9 100644 --- a/tests/python/relax/test_runtime_builtin_rnn_state.py +++ b/tests/python/relax/test_runtime_builtin_rnn_state.py @@ -203,7 +203,7 @@ def _rnn_state_get( for i in range(batch_size): for s in T.grid(*shape): - with T.block("copy"): + with T.sblock("copy"): vi, *vs = T.axis.remap("S" * (len(shape) + 1), [i, *s]) seq_id: T.int32 = seq_slot_ids[vi] history_id: T.int32 = history_slot_ids[vi] @@ -238,7 +238,7 @@ def _rnn_state_set( for i in range(batch_size): for s in T.grid(*shape): - with T.block("copy"): + with T.sblock("copy"): vi, *vs = T.axis.remap("S" * (len(shape) + 1), [i, *s]) seq_id: T.int32 = seq_slot_ids[vi] history_id: T.int32 = (history_slot_ids[vi] + 1) % T.cast( diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py index 4061da3a9c2e..42ce4ea4f5e6 100644 --- a/tests/python/relax/test_tir_call_source_kernel.py +++ b/tests/python/relax/test_tir_call_source_kernel.py @@ -45,7 +45,7 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None x = T.match_buffer(x_handle, (m,), "float32") y = T.match_buffer(y_handle, (m,), "float32") output = T.match_buffer(output_handle, (m,), "float32") - with T.block("root"): + with T.sblock("root"): T.reads(x[0:m], y[0:m]) T.writes(output[0:m]) BLOCK_SIZE = T.meta_var(64) @@ -75,7 +75,7 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle): x = T.match_buffer(x_handle, (m,)) y = T.match_buffer(y_handle, (m,)) output = T.match_buffer(output_handle, (m,)) - with T.block("root"): + with T.sblock("root"): T.reads(x[0:m], y[0:m]) T.writes(output[0:m]) T.call_packed( diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 78a52c66723e..1d2cd34445e8 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -280,7 +280,7 @@ def zeros(A: T.Buffer((2, 3), "int32")): # just overwrites A with 0s T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.writes(A[ax0, ax1]) A[ax0, ax1] = T.int32(0) @@ -298,7 +298,7 @@ class Expected: def zeros(A: T.Buffer((2, 3), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.writes(A[ax0, ax1]) A[ax0, ax1] = T.int32(0) @@ -324,7 +324,7 @@ def copy( # copies the contents of C into A and B T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(C[ax0, ax1]) T.writes(A[ax0, ax1], B[ax0, ax1]) @@ -353,7 +353,7 @@ def copy( # copies the contents of C into A and B T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(C[ax0, ax1]) T.writes(A[ax0, ax1], B[ax0, ax1]) @@ -387,7 +387,7 @@ def copy( # copies the contents of C into A, out1, and out2 T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(C[ax0, ax1]) T.writes(A[ax0, ax1], out1[ax0, ax1], out2[ax0, ax1]) @@ -426,7 +426,7 @@ def copy( ): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(C[ax0, ax1]) T.writes(A[ax0, ax1], out1[ax0, ax1], out2[ax0, ax1]) diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py index 70559ab369c3..e1465c1a5d7b 100644 --- a/tests/python/relax/test_transform_alter_op_impl.py +++ b/tests/python/relax/test_transform_alter_op_impl.py @@ -52,7 +52,7 @@ class Before: def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.add"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0], arg1[v_ax0]) T.writes(output[v_ax0]) @@ -71,7 +71,7 @@ class Expected: def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): T.func_attr({"operator_name": "relax.add"}) for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -91,7 +91,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" @T.prim_func(private=True) def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -115,7 +115,7 @@ class Before: def mul_by_2(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.mul_by_2"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -134,7 +134,7 @@ class Expected: def relax_mul_by_2_replacement(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.mul_by_2"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -152,7 +152,7 @@ def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32" def add_x_x(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.mul_by_2"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -175,7 +175,7 @@ class Before: def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.some_op"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0], arg1[v_ax0]) T.writes(output0[v_ax0], output1[v_ax0]) @@ -195,7 +195,7 @@ class Expected: def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): T.func_attr({"operator_name": "relax.some_op"}) for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1]) @@ -219,7 +219,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" @T.prim_func(private=True) def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1]) @@ -245,7 +245,7 @@ class Before: def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.some_op"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0], arg1[v_ax0]) T.writes(output0[v_ax0], output1[v_ax0]) @@ -265,7 +265,7 @@ class Expected: def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): T.func_attr({"operator_name": "relax.some_op"}) for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1]) @@ -289,7 +289,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" @T.prim_func(private=True) def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1]) @@ -325,7 +325,7 @@ def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32") def relu(arg0: T.Buffer((14,), "float32"), output: T.Buffer((14,), "float32")): T.func_attr({"operator_name": "relax.relu"}) for ax0 in T.grid(14): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.remap("S", [ax0]) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -365,9 +365,9 @@ def relax_relu_replacement( arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32") ): T.func_attr({"operator_name": "relax.relu"}) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -380,9 +380,9 @@ def remove_pad(var_input: T.handle, var_output: T.handle): input = T.match_buffer(var_input, (p0,)) i0 = T.int64() output = T.match_buffer(var_output, (i0,)) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(i0): - with T.block("output"): + with T.sblock("output"): v_ax0 = T.axis.spatial(i0, ax0) T.reads(input[v_ax0]) T.writes(output[v_ax0]) @@ -391,7 +391,7 @@ def remove_pad(var_input: T.handle, var_output: T.handle): @T.prim_func(private=True) def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): for ax0 in T.grid(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.remap("S", [ax0]) T.reads(arg0[v_ax0]) T.writes(output[v_ax0]) @@ -417,7 +417,7 @@ class Before: def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.add"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0], arg1[v_ax0]) T.writes(output[v_ax0]) @@ -437,9 +437,9 @@ class Expected: @T.prim_func(private=True) def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): T.func_attr({"operator_name": "relax.add"}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -463,7 +463,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" @T.prim_func(private=True) def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) T.writes(output[v_ax0, v_ax1]) @@ -489,7 +489,7 @@ def reshape( ): T.func_attr({"operator_name": "relax.reshape"}) for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( A[ @@ -525,7 +525,7 @@ def relax_reshape_replacement( ): T.func_attr({"operator_name": "relax.reshape"}) for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) @@ -560,7 +560,7 @@ def reshape_new( T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"), ): for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) @@ -587,7 +587,7 @@ class Before: def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")): T.func_attr({"operator_name": "relax.some_op"}) for ax0 in range(16): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(16, ax0) T.reads(arg0[v_ax0], arg1[v_ax0]) T.writes(output0[v_ax0], output1[v_ax0]) @@ -607,7 +607,7 @@ class Expected: def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): T.func_attr({"operator_name": "relax.some_op"}) for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1] @@ -629,7 +629,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" @T.prim_func(private=True) def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")): for ax0, ax1 in T.grid(4, 4): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1] diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py index 97e0c9d52411..85ebac722bc0 100644 --- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -48,7 +48,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -81,7 +81,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, k), "float32") for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -103,7 +103,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -124,7 +124,7 @@ def sum(x: T.handle, y: T.handle) -> None: B = T.match_buffer(y, (16,)) for i, j in T.grid(16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj = T.axis.remap("SR", [i, j]) with T.init(): B[vi] = 0.0 @@ -145,7 +145,7 @@ def elemwise(x: T.handle, y: T.handle) -> None: B = T.match_buffer(y, (16, 16)) for i, j in T.grid(16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + 1.0 @@ -164,7 +164,7 @@ def broadcast(x: T.handle, y: T.handle) -> None: B = T.match_buffer(y, (16, 16, 16, 16)) for i0, j0, i1, j1 in T.grid(16, 16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1]) B[vi0, vj0, vi1, vj1] = A[vj0, vj1] @@ -183,7 +183,7 @@ def injective(x: T.handle, y: T.handle) -> None: B = T.match_buffer(y, (16, 16)) for i, j in T.grid(16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4] @@ -204,9 +204,9 @@ def tir_bias_add( # function attr dict T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1 in T.grid(1, 1000): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(A[ax0, ax1], B[ax1]) T.writes(C[ax0, ax1]) @@ -228,7 +228,7 @@ def add_with_unit_dim_len_broadcast( ) -> None: T.func_attr({"global_symbol": "add5", "tir.noalias": True}) for i0, i1, i2, i3 in T.grid(1, 64, 112, 112): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0]) T.writes(C[ax0, ax1, ax2, ax3]) @@ -250,7 +250,7 @@ def add_zero_dim( ) -> None: T.func_attr({"global_symbol": "add8", "tir.noalias": True}) for i0 in T.serial(128): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial(128, i0) T.reads(A[ax0], B[()]) T.writes(C[ax0]) @@ -272,10 +272,10 @@ def max_pool2d( # function attr dict T.func_attr({"global_symbol": "max_pool2d", "T.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 64, 114, 114): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1]) T.writes(pad_temp_1[ax0, ax1, ax2, ax3]) @@ -286,7 +286,7 @@ def max_pool2d( dtype="float32", ) for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3): - with T.block("tensor"): + with T.sblock("tensor"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads( tensor_1[ax0, ax1, ax2, ax3], @@ -316,12 +316,12 @@ def softmax( # function attr dict T.func_attr({"global_symbol": "softmax", "T.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32") T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32") T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32") for i0_7, i1_3 in T.grid(16, 16): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_8, k = T.axis.remap("SR", [i0_7, i1_3]) T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]) T.writes(T_softmax_maxelem_1[i0_8]) @@ -331,7 +331,7 @@ def softmax( T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k] ) for i0_9, i1_4 in T.grid(16, 16): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4]) T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10]) T.writes(T_softmax_exp_1[i0_10, i1_5]) @@ -339,7 +339,7 @@ def softmax( rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32" ) for i0_11, i1_6 in T.grid(16, 16): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_12, k = T.axis.remap("SR", [i0_11, i1_6]) T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k]) T.writes(T_softmax_expsum_1[i0_12]) @@ -349,11 +349,11 @@ def softmax( T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k] ) for i0_13, i1_7 in T.grid(16, 16): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7]) T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14]) T.writes(T_softmax_norm_1[i0_14, i1_8]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm_1[i0_14, i1_8] = ( T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14] ) @@ -371,7 +371,7 @@ def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")): rxplaceholder = T.match_buffer( var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1 ) - with T.block("cumsum_generic"): + with T.sblock("cumsum_generic"): T.reads(rxplaceholder[0:10, 0:16]) T.writes(out_buf[0:160]) for fused in T.parallel(1): @@ -400,7 +400,7 @@ def sum_sqsum( sqsum: T.Buffer((32,), "float32"), ): for ax0, k0 in T.grid(32, 64): - with T.block("block"): + with T.sblock("block"): v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0]) T.reads(A[v_ax0, v_k0]) T.writes(vsum[v_ax0], sqsum[v_ax0]) @@ -423,7 +423,7 @@ class Module: @T.prim_func def no_buffer_stores(A: T.Buffer((32, 64), "float32"), vsum: T.Buffer((32,), "float32")): for ax0, k0 in T.grid(32, 64): - with T.block("block"): + with T.sblock("block"): v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0]) T.reads(A[v_ax0, v_k0]) T.writes(vsum[v_ax0]) diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py index 46f7c8aa87be..d4f01247f698 100644 --- a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py +++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py @@ -34,7 +34,7 @@ def matmul( C: T.Buffer((T.int64(32), T.int64(32)), "float32"), ): for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -58,7 +58,7 @@ def matmul1( ): T.func_attr({"layout_free_buffers": [1]}) for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -88,7 +88,7 @@ def matmul( C: T.Buffer((T.int64(32), T.int64(32)), "float32"), ): for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -116,7 +116,7 @@ def matmul1( ): T.func_attr({"layout_free_buffers": [1]}) for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -148,7 +148,7 @@ def matmul( C: T.Buffer((T.int64(32), T.int64(32)), "float32"), ): for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -185,7 +185,7 @@ def matmul1( ): T.func_attr({"layout_free_buffers": [1]}) for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -226,7 +226,7 @@ def matmul( C: T.Buffer((T.int64(32), T.int64(32)), "float32"), ): for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -263,7 +263,7 @@ def matmul1( ): T.func_attr({"layout_free_buffers": [1]}) for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] @@ -276,7 +276,7 @@ def matmul2( ): T.func_attr({"layout_free_buffers": [0]}) for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): - with T.block("C"): + with T.sblock("C"): with T.init(): C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index 39f6d061f721..07f4a8b2417d 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -37,7 +37,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -64,7 +64,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index c46701d33a85..b7f353049c96 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -36,7 +36,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: B = T.match_buffer(y, (16, 16)) C = T.match_buffer(z, (16, 16)) for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.S(16, i0 * 4 + i1) vj = T.axis.S(16, j) vk = T.axis.R(16, k0 * 4 + k1) diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index 5b12480e253c..bedbab4aaadc 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -402,7 +402,7 @@ def product( C: T.Buffer([16, 16], "int32"), ): for iters in T.grid(*A.shape): - with T.block("compute"): + with T.sblock("compute"): i, j = T.axis.remap("SS", iters) C[i, j] = A[i, j] * B[i, j] @@ -413,7 +413,7 @@ def sum( C: T.Buffer([16, 16], "int32"), ): for iters in T.grid(*A.shape): - with T.block("compute"): + with T.sblock("compute"): i, j = T.axis.remap("SS", iters) C[i, j] = A[i, j] + B[i, j] diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 0ddf985ec4ba..afaa1f107d6c 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -166,7 +166,7 @@ def tir_add( z: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] @@ -203,7 +203,7 @@ def tir_add( z: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] @@ -244,7 +244,7 @@ def tir_add( z: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] @@ -292,7 +292,7 @@ def tir_matmul( y = T.match_buffer(y_handle, (n, k), "float32") z = T.match_buffer(z_handle, (m, k), "float32") for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): z[vi, vj] = 0.0 @@ -329,7 +329,7 @@ def unused_func( ) -> None: T.func_attr({"global_symbol": "tir_unused"}) for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] @@ -375,7 +375,7 @@ def tir_add_tensors( z: T.Buffer((16, 16), "float32"), ): for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = InputModule.tir_add_float32(x[vi, vj], y[vi, vj]) @@ -401,7 +401,7 @@ def unused_func1( ) -> None: T.func_attr({"global_symbol": "tir_unused"}) for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] diff --git a/tests/python/relax/test_transform_few_shot_tuning.py b/tests/python/relax/test_transform_few_shot_tuning.py index e769c911a3f0..f9861bb7d57d 100644 --- a/tests/python/relax/test_transform_few_shot_tuning.py +++ b/tests/python/relax/test_transform_few_shot_tuning.py @@ -40,9 +40,9 @@ def matmul( C: T.Buffer((32, 32), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): - with T.block("C"): + with T.sblock("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) @@ -55,12 +55,12 @@ class Softmax: @T.prim_func def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32"), T_softmax_norm: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(8), T.int64(3456)), "float32") T_softmax_exp = T.alloc_buffer((T.int64(8), T.int64(3456), T.int64(3456)), "float32") T_softmax_expsum = T.alloc_buffer((T.int64(8), T.int64(3456)), "float32") for i0, i1, k in T.grid(T.int64(8), T.int64(3456), T.int64(3456)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(rxplaceholder[v_i0, v_i1, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1]) @@ -68,13 +68,13 @@ def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504) T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], rxplaceholder[v_i0, v_i1, v_k]) for i0, i1, i2 in T.grid(T.int64(8), T.int64(3456), T.int64(3456)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(rxplaceholder[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) for i0, i1, k in T.grid(T.int64(8), T.int64(3456), T.int64(3456)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_k]) T.writes(T_softmax_expsum[v_i0, v_i1]) @@ -82,11 +82,11 @@ def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), T.int64(3456)), T_softmax_expsum[v_i0, v_i1] = T.float16(0) T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] for i0, i1, i2 in T.grid(T.int64(8), T.int64(3456), T.int64(3456)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] @tvm.script.ir_module @@ -94,7 +94,7 @@ class Fused_Variance_Cast1: @T.prim_func def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) T_subtract = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(34560))) @@ -102,7 +102,7 @@ def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), co T_multiply_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(34560)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(lv3[v_ax0, v_ax1, v_k2]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2]) @@ -110,25 +110,25 @@ def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), co rxplaceholder_red[v_ax0, v_ax1, v_ax2] = T.float32(0) rxplaceholder_red[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0, v_ax1, v_ax2] + lv3[v_ax0, v_ax1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2]) T.writes(T_divide[v_ax0, v_ax1, v_ax2]) T_divide[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0, v_ax1, v_ax2] * T.float32(2.8935185185185186e-05) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(34560)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv3[v_ax0, v_ax1, v_ax2], T_divide[v_ax0, v_ax1, T.int64(0)]) T.writes(T_subtract[v_ax0, v_ax1, v_ax2]) T_subtract[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] - T_divide[v_ax0, v_ax1, T.int64(0)] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(34560)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_subtract[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_subtract[v_ax0, v_ax1, v_ax2] * T_subtract[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(34560)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(T_multiply[v_ax0, v_ax1, v_k2]) T.writes(T_multiply_red[v_ax0, v_ax1, v_ax2]) @@ -136,13 +136,13 @@ def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), co T_multiply_red[v_ax0, v_ax1, v_ax2] = T.float32(0) T_multiply_red[v_ax0, v_ax1, v_ax2] = T_multiply_red[v_ax0, v_ax1, v_ax2] + T_multiply[v_ax0, v_ax1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_red[v_ax0, v_ax1, v_ax2]) T.writes(T_divide_1[v_ax0, v_ax1, v_ax2]) T_divide_1[v_ax0, v_ax1, v_ax2] = T_multiply_red[v_ax0, v_ax1, v_ax2] * T.float32(2.8935185185185186e-05) for i0, i1, i2 in T.grid(T.int64(1), T.int64(32), T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_divide_1[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) @@ -153,11 +153,11 @@ class Fuse_Mean_Cast1: @T.prim_func def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(34560)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(lv[v_ax0, v_ax1, v_k2]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2]) @@ -165,13 +165,13 @@ def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), "float32"), com rxplaceholder_red[v_ax0, v_ax1, v_ax2] = T.float32(0) rxplaceholder_red[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0, v_ax1, v_ax2] + lv[v_ax0, v_ax1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2]) T.writes(T_divide[v_ax0, v_ax1, v_ax2]) T_divide[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0, v_ax1, v_ax2] * T.float32(2.8935185185185186e-05) for i0, i1, i2 in T.grid(T.int64(1), T.int64(32), T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_divide[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) @@ -182,7 +182,7 @@ class Module: @T.prim_func def main(lv26: T.Buffer((T.int64(1), T.int64(3456), T.int64(2560)), "float16"), T_multiply: T.Buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_strided_slice_with_axes = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") T_divide = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") @@ -201,109 +201,109 @@ def main(lv26: T.Buffer((T.int64(1), T.int64(3456), T.int64(2560)), "float16"), T_strided_slice_with_axes_1 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") T_multiply_7 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), "float16") for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv26[v_ax0, v_ax1, v_ax2 + T.int64(1280)]) T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2]) T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] = lv26[v_ax0, v_ax1, v_ax2 + T.int64(1280)] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2]) T.writes(T_divide[v_ax0, v_ax1, v_ax2]) T_divide[v_ax0, v_ax1, v_ax2] = T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] * T.float16(0.70718232044198892) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_divide[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) T_multiply_1[v_ax0, v_ax1, v_ax2] = T_divide[v_ax0, v_ax1, v_ax2] * T.float16(1.4140625) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_2[v_ax0, v_ax1, v_ax2]) T_multiply_2[v_ax0, v_ax1, v_ax2] = T_multiply_1[v_ax0, v_ax1, v_ax2] * T.float16(0.70710678118654757) for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply_2[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) compute[v_i0, v_i1, v_i2] = T.Cast("float32", T_multiply_2[v_i0, v_i1, v_i2]) for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("compute_1"): + with T.sblock("compute_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(compute[v_i0, v_i1, v_i2]) T.writes(compute_1[v_i0, v_i1, v_i2]) compute_1[v_i0, v_i1, v_i2] = T.erf(compute[v_i0, v_i1, v_i2]) for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("compute_2"): + with T.sblock("compute_2"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(compute_1[v_i0, v_i1, v_i2]) T.writes(compute_2[v_i0, v_i1, v_i2]) compute_2[v_i0, v_i1, v_i2] = T.Cast("float16", compute_1[v_i0, v_i1, v_i2]) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_1_1"): + with T.sblock("T_multiply_1_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(compute_2[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_3[v_ax0, v_ax1, v_ax2]) T_multiply_3[v_ax0, v_ax1, v_ax2] = compute_2[v_ax0, v_ax1, v_ax2] * T.float16(0.5) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_3[v_ax0, v_ax1, v_ax2]) T.writes(T_add[v_ax0, v_ax1, v_ax2]) T_add[v_ax0, v_ax1, v_ax2] = T.float16(0.5) + T_multiply_3[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2]) T_multiply_4[v_ax0, v_ax1, v_ax2] = T_multiply_1[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_3"): + with T.sblock("T_multiply_3"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_4[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2]) T_multiply_5[v_ax0, v_ax1, v_ax2] = T_multiply_4[v_ax0, v_ax1, v_ax2] * T.float16(1.4140625) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_5[v_ax0, v_ax1, v_ax2], T_divide[v_ax0, v_ax1, v_ax2]) T.writes(T_divide_1[v_ax0, v_ax1, v_ax2]) T_divide_1[v_ax0, v_ax1, v_ax2] = T_multiply_5[v_ax0, v_ax1, v_ax2] / T_divide[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_add_1"): + with T.sblock("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_divide_1[v_ax0, v_ax1, v_ax2]) T.writes(T_add_1[v_ax0, v_ax1, v_ax2]) T_add_1[v_ax0, v_ax1, v_ax2] = T_divide_1[v_ax0, v_ax1, v_ax2] + T.float16(-1) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_add_2"): + with T.sblock("T_add_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_1[v_ax0, v_ax1, v_ax2]) T.writes(T_add_2[v_ax0, v_ax1, v_ax2]) T_add_2[v_ax0, v_ax1, v_ax2] = T_add_1[v_ax0, v_ax1, v_ax2] + T.float16(1) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_4"): + with T.sblock("T_multiply_4"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2], T_add_2[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_6[v_ax0, v_ax1, v_ax2]) T_multiply_6[v_ax0, v_ax1, v_ax2] = T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] * T_add_2[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_strided_slice_with_axes_1"): + with T.sblock("T_strided_slice_with_axes_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(lv26[v_ax0, v_ax1, v_ax2]) T.writes(T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2]) T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2] = lv26[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_5"): + with T.sblock("T_multiply_5"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_6[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_7[v_ax0, v_ax1, v_ax2]) T_multiply_7[v_ax0, v_ax1, v_ax2] = T_multiply_6[v_ax0, v_ax1, v_ax2] * T.float16(0.5) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)): - with T.block("T_multiply_6"): + with T.sblock("T_multiply_6"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2], T_multiply_7[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index c62a01768eec..e3eb7f636760 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -61,7 +61,7 @@ class Module: @T.prim_func def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: for i, j in T.grid(16, 16): - with T.block("addone"): + with T.sblock("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.float32(1) @@ -91,7 +91,7 @@ class Module: @T.prim_func def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")) -> None: for i, j in T.grid(3, 2): - with T.block("transpose"): + with T.sblock("transpose"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] @@ -120,7 +120,7 @@ class Module: @T.prim_func def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")) -> None: for i, j in T.grid(2, 2): - with T.block("addone"): + with T.sblock("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.float32(1) @@ -151,7 +151,7 @@ class Module: @T.prim_func def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: for i, j in T.grid(16, 16): - with T.block("identity"): + with T.sblock("identity"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -186,7 +186,7 @@ def addone(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (n, m)) B = T.match_buffer(b, (n, m)) for i, j in T.grid(n, m): - with T.block("addone"): + with T.sblock("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.float32(1) @@ -197,7 +197,7 @@ def sub( C: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): - with T.block("sub"): + with T.sblock("sub"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] - B[vi, vj] @@ -248,7 +248,7 @@ class Module: @T.prim_func def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: for i, j in T.grid(16, 16): - with T.block("addone"): + with T.sblock("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.int32(1) diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 16a076c41cb4..9b312f289041 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -889,7 +889,7 @@ def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(64), T.int64(64)], dtype="float32") rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(64), T.int64(64)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(A[ax0, ax1, k2, k3]) T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) @@ -901,7 +901,7 @@ def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], gamma[ax2, ax3], beta[ax2, ax3]) T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) @@ -910,7 +910,7 @@ def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), @T.prim_func(private=True) def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): - with T.block("relu"): + with T.sblock("relu"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[v_i0, v_i1, v_i2, v_i3]) T.writes(B[v_i0, v_i1, v_i2, v_i3]) @@ -921,11 +921,11 @@ class Expected: @T.prim_func(private=True) def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), gamma: T.Buffer((T.int64(64), T.int64(64)), "float32"), beta: T.Buffer((T.int64(64), T.int64(64)), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): T.func_attr({"op_pattern": 4}) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(64), T.int64(64))) rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(64), T.int64(64))) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(A[ax0, ax1, k2, k3]) T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) @@ -937,7 +937,7 @@ def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], gamma[ax2, ax3], beta[ax2, ax3]) T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) @@ -946,9 +946,9 @@ def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), @T.prim_func(private=True) def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): T.func_attr({"op_pattern": 0}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): - with T.block("relu"): + with T.sblock("relu"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[v_i0, v_i1, v_i2, v_i3]) T.writes(B[v_i0, v_i1, v_i2, v_i3]) @@ -1009,7 +1009,7 @@ class Expected: def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): T.func_attr({"op_pattern": 0, "tir.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_1[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1019,7 +1019,7 @@ def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64( def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320)), "float32")): T.func_attr({"op_pattern": 0, "tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(320)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1]) T.writes(T_add[v_ax0, v_ax1]) @@ -1029,7 +1029,7 @@ def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), rxplace def add2(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): T.func_attr({"op_pattern": 0, "tir.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)]) T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1040,13 +1040,13 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int T.func_attr({"op_pattern": 4, "tir.noalias": True}) pad_temp = T.alloc_buffer((T.int64(2), T.int64(320), T.int64(66), T.int64(66))) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(320), T.int64(66), T.int64(66)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)]) T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3]) pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(65) and T.int64(1) <= v_i3 and v_i3 < T.int64(65), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0)) for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64), T.int64(320), T.int64(3), T.int64(3)): - with T.block("conv2d_nchw"): + with T.sblock("conv2d_nchw"): v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx]) T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], rxplaceholder_1[v_ff, v_rc, v_ry, v_rx]) T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx]) @@ -1058,7 +1058,7 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1280), T.int64(320)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(320)), "float32")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) for i0, i1, k in T.grid(T.int64(2), T.int64(320), T.int64(1280)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) T.writes(matmul[v_i0, v_i1]) @@ -1070,7 +1070,7 @@ def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), "float32"), rxpl def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float32")): T.func_attr({"op_pattern": 2, "tir.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(320), T.int64(1), T.int64(1)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[(v_ax1 + v_ax2 + v_ax3) % T.int64(320)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1080,7 +1080,7 @@ def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"), T_reshape: T.Bu def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(320), T.int64(1), T.int64(1)), "float32")): T.func_attr({"op_pattern": 2, "tir.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(1), T.int64(1)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[((v_ax1 + v_ax2 + v_ax3) // T.int64(320) + v_ax0) % T.int64(2), (v_ax1 + v_ax2 + v_ax3) % T.int64(320)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1090,7 +1090,7 @@ def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), T_r def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), "float32"), T_transpose: T.Buffer((T.int64(1280), T.int64(320)), "float32")): T.func_attr({"op_pattern": 2, "tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(1280), T.int64(320)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -1163,9 +1163,9 @@ class Expected: @T.prim_func(private=True) def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")): T.func_attr({"op_pattern": 0, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(1), T.int64(128)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1]) T.writes(T_add[v_ax0, v_ax1]) @@ -1174,9 +1174,9 @@ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceh @T.prim_func(private=True) def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")): T.func_attr({"op_pattern": 0, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(1), T.int64(10)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1]) T.writes(T_add[v_ax0, v_ax1]) @@ -1185,9 +1185,9 @@ def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceh @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) T.writes(matmul_1[v_i0, v_i1]) @@ -1198,9 +1198,9 @@ def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxpla @T.prim_func(private=True) def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(128)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) T.writes(matmul[v_i0, v_i1]) @@ -1211,9 +1211,9 @@ def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxpl @T.prim_func(private=True) def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")): T.func_attr({"op_pattern": 0, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(1), T.int64(128)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) @@ -1222,9 +1222,9 @@ def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute @T.prim_func(private=True) def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")): T.func_attr({"op_pattern": 2, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(784), T.int64(128)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -1233,9 +1233,9 @@ def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), @T.prim_func(private=True) def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")): T.func_attr({"op_pattern": 2, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(128), T.int64(10)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) @@ -1516,7 +1516,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[()]) T.writes(Out[v_ax0, v_ax1]) @@ -1526,7 +1526,7 @@ def add( def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(A[v_i0, v_i1]) @@ -1536,7 +1536,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -1578,7 +1578,7 @@ def add( ): T.func_attr({"tir.noalias": True, "op_pattern": 0}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1], B[()]) T.writes(Out[v_ax0, v_ax1]) @@ -1588,7 +1588,7 @@ def add( def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True, "op_pattern": 0}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(A[v_i0, v_i1]) @@ -1598,7 +1598,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True, "op_pattern": 0}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(A[v_ax0, v_ax1]) @@ -1653,9 +1653,9 @@ class Before: @T.prim_func(private=True) def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(16), T.int64(16)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(lv[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) @@ -1664,9 +1664,9 @@ def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer( @T.prim_func(private=True) def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)): - with T.block("T_matmul"): + with T.sblock("T_matmul"): v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k]) T.reads(x[v_ax0, v_k], lv2[v_k, v_ax1]) T.writes(T_matmul[v_ax0, v_ax1]) diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 2219c01ccb1e..fd955c4311f7 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -705,7 +705,7 @@ def relu( out: T.Buffer((1, 64, 56, 56), "float32"), ): for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): - with T.block("root"): + with T.sblock("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) out[i, j, k, l] = T.max(data[i, j, k, l], 0.0) @@ -732,9 +732,9 @@ def relu( data: T.Buffer((1, 64, 56, 56), "float32"), out: T.Buffer((1, 64, 56, 56), "float32"), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): - with T.block("root"): + with T.sblock("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(data[i, j, k, l]) T.writes(out[i, j, k, l]) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index a67bc63f9bf2..dcc18c664489 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -639,19 +639,19 @@ def fused_add1_exp1_squeeze1( T_add = T.alloc_buffer((T.int64(20), T.int64(10))) compute = T.alloc_buffer((T.int64(20), T.int64(10))) for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1], p0[()]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(20), T.int64(10)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_add[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_squeeze[v_ax0, v_ax1]) @@ -667,19 +667,19 @@ def fused_add_exp_squeeze( T_add = T.alloc_buffer((T.int64(10), T.int64(20))) compute = T.alloc_buffer((T.int64(10), T.int64(20))) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1], p0[()]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_add[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_squeeze[v_ax0, v_ax1]) @@ -751,7 +751,7 @@ def dynamic_tir_kernel(a: T.handle, b: T.handle): B = T.match_buffer(b, [m, n], "float32") for iters in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): i, j = T.axis.remap("SS", iters) B[i, j] = A[i, j] * i + j @@ -783,12 +783,12 @@ def fused_function( T.func_attr({"tir.noalias": True}) Y = T.alloc_buffer(X.shape, "float32") for iters in T.grid(*X.shape): - with T.block("compute_Y"): + with T.sblock("compute_Y"): i, j = T.axis.remap("SS", iters) Y[i, j] = X[i, j] * i + j for iters in T.grid(*X.shape): - with T.block("compute_Z"): + with T.sblock("compute_Z"): i, j = T.axis.remap("SS", iters) Z[i, j] = Y[i, j] * i + j @@ -821,7 +821,7 @@ def dynamic_tir_kernel(a: T.handle, b: T.handle, c: T.handle, d: T.handle): D = T.match_buffer(d, [m * n], "float32") for i, j in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) D[vi * 32 + vj] = A[vi * 32 + vj] * B[vi] + C[vj] @@ -867,12 +867,12 @@ def fused_function( T.func_attr({"tir.noalias": True}) Y = T.alloc_buffer((T.int64(512),)) for i, j in T.grid(T.int64(16), T.int64(32)): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) Y[vi * 32 + vj] = X[vi * 32 + vj] * B[vi] + C[vj] for i, j in T.grid(T.int64(16), T.int64(32)): - with T.block("compute_1"): + with T.sblock("compute_1"): vi, vj = T.axis.remap("SS", [i, j]) Z[vi * 32 + vj] = Y[vi * 32 + vj] * B[vi] + C[vj] @@ -961,7 +961,7 @@ def foo( m: T.int64, ): for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) rotary[v0, v1, v2, v3] = Y[m + v1 - 1, v3] * X[v0, v1, v2, v3] @@ -1009,13 +1009,13 @@ def fused( T.func_attr({"tir.noalias": True}) T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T_add[v_ax0, v_ax1, v_ax2, v_ax3] = ( X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3] ) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3] @@ -1053,7 +1053,7 @@ def concatenate( ): T.func_attr({"op_pattern": 2, "tir.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)): - with T.block("T_concat"): + with T.sblock("T_concat"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3], @@ -1073,7 +1073,7 @@ def transpose2( ): T.func_attr({"op_pattern": 2, "tir.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1124,7 +1124,7 @@ def fused_concatenate_transpose2( (T.int64(2), T.int64(4), T.int64(64), T.int64(64)) ) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)): - with T.block("T_concat"): + with T.sblock("T_concat"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3]) T.writes(T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1134,7 +1134,7 @@ def fused_concatenate_transpose2( inp_0[v_ax0, v_ax1, v_ax2, v_ax3], ) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1, v_ax2]) T.writes(T_transpose_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1202,11 +1202,11 @@ def fused_transpose_matmul( var_T_matmul_intermediate = T.match_buffer(p_output0, (n - T.int64(1), T.int64(3))) var_T_transpose_intermediate = T.alloc_buffer((T.int64(4), T.int64(3))) for ax0, ax1 in T.grid(T.int64(4), T.int64(3)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) var_T_transpose_intermediate[v_ax0, v_ax1] = x[v_ax1, v_ax0] for ax0, ax1, k in T.grid(n - T.int64(1), T.int64(3), T.int64(4)): - with T.block("T_matmul"): + with T.sblock("T_matmul"): v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k]) with T.init(): var_T_matmul_intermediate[v_ax0, v_ax1] = T.float32(0) @@ -1245,9 +1245,9 @@ def reshape( T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"), ): T.func_attr({"op_pattern": 2, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[ @@ -1310,9 +1310,9 @@ def fused_reshape( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( lv_0[ @@ -1364,7 +1364,7 @@ def add( Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) Out[vi, vj] = A[vi, vj] + T.float16(1.0) @@ -1374,7 +1374,7 @@ def add1( Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) Out[vi, vj] = A[vi, vj] + T.float16(2.0) @@ -1412,13 +1412,13 @@ def fused_func( T.func_attr({"tir.noalias": True}) Out_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(input_embeds[vi, vj]) T.writes(Out_intermediate[vi, vj]) Out_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1) for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add_1"): + with T.sblock("add_1"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(Out_intermediate[vi, vj]) T.writes(Out_intermediate_1[vi, vj]) @@ -1477,7 +1477,7 @@ def foo( ) for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) rotary[v0, v1, v2, v3] = Y[m + v1 - 1, v3] * X[v0, v1, v2, v3] @@ -1538,13 +1538,13 @@ def fused( for ax0, ax1, ax2, ax3 in T.grid( T.int64(1), sequence_length, T.int64(32), T.int64(128) ): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T_add[v_ax0, v_ax1, v_ax2, v_ax3] = ( X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3] ) for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)): - with T.block("rotary"): + with T.sblock("rotary"): v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3] @@ -1585,7 +1585,7 @@ def sum_1d( X = T.match_buffer(X_handle, [num_elements], "float32") for i in range(num_elements): - with T.block("sum"): + with T.sblock("sum"): vi = T.axis.remap("R", [i]) with T.init(): Y[0] = 0.0 @@ -1626,7 +1626,7 @@ def fused( T.func_attr({"tir.noalias": True}) for i in range(T.int64(64)): - with T.block("sum"): + with T.sblock("sum"): vi = T.axis.remap("R", [i]) with T.init(): Y[0] = 0.0 @@ -1660,7 +1660,7 @@ def sum_1d( X = T.match_buffer(X_handle, [num_elements], "float32") for i in range(num_elements): - with T.block("sum"): + with T.sblock("sum"): vi = T.axis.remap("R", [i]) with T.init(): Sum[0] = 0.0 @@ -1673,7 +1673,7 @@ def sum_scalar( Sum: T.Buffer([T.int64(1)], "float32"), ): for i in range(T.int64(1)): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.remap("S", [i]) Sum[vi] = X[vi] + Y[vi] @@ -1728,21 +1728,21 @@ def fused( YSum = T.alloc_buffer([T.int64(1)], "float32") for i in range(T.int64(64)): - with T.block("XSum"): + with T.sblock("XSum"): vi = T.axis.remap("R", [i]) with T.init(): XSum[0] = 0.0 XSum[0] = XSum[0] + X[vi] for i in range(T.int64(16)): - with T.block("YSum"): + with T.sblock("YSum"): vi = T.axis.remap("R", [i]) with T.init(): YSum[0] = 0.0 YSum[0] = YSum[0] + Y[vi] for i in range(T.int64(1)): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.remap("S", [i]) Out[vi] = XSum[vi] + YSum[vi] @@ -1784,7 +1784,7 @@ def sum_1d( X = T.match_buffer(X_handle, [num_elements], "float32") for i in range(num_elements): - with T.block("sum"): + with T.sblock("sum"): vi = T.axis.remap("R", [i]) with T.init(): Y[0] = 0.0 @@ -1826,7 +1826,7 @@ def fused( T.func_attr({"tir.noalias": True}) for i in range(T.int64(64)): - with T.block("sum"): + with T.sblock("sum"): vi = T.axis.remap("R", [i]) with T.init(): Y[0] = 0.0 @@ -1854,7 +1854,7 @@ def add( Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) Out[vi, vj] = A[vi, vj] + T.float16(1.0) @@ -1865,7 +1865,7 @@ def take( T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), ): for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): - with T.block("T_take"): + with T.sblock("T_take"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] @@ -1908,11 +1908,11 @@ def fused_func( T.func_attr({"tir.noalias": True}) Out_handle_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) Out_handle_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1) for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): - with T.block("T_take"): + with T.sblock("T_take"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T_take[v_ax0, v_ax1] = Out_handle_intermediate[input_ids[v_ax0], v_ax1] @@ -1945,7 +1945,7 @@ def add_inplace( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) # T.reads(A[v_ax0, v_ax1], B[()]) # T.writes(A[v_ax0, v_ax1]) @@ -1955,7 +1955,7 @@ def add_inplace( def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) # T.reads(A[v_i0, v_i1]) # T.writes(A[v_i0, v_i1]) @@ -1965,7 +1965,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) # T.reads(A[v_ax0, v_ax1]) # T.writes(A[v_ax0, v_ax1]) @@ -2024,15 +2024,15 @@ def fused_add_exp_squeeze( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) x[v_i0, v_i1] = T.exp(x[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) x[v_ax0, v_ax1] = x[v_ax0, v_ax1] @@ -2068,7 +2068,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] @@ -2076,7 +2076,7 @@ def add( def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) @@ -2084,7 +2084,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) A[v_ax0, v_ax1] = A[v_ax0, v_ax1] @@ -2137,15 +2137,15 @@ def fused_add_exp_squeeze( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(10), T.int64(20)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1] @@ -2178,7 +2178,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] @@ -2229,15 +2229,15 @@ def fused_sums( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] @@ -2300,7 +2300,7 @@ def add( Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) Out[vi, vj] = A[vi, vj] + T.float16(1.0) @@ -2311,7 +2311,7 @@ def take( T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), ): for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): - with T.block("T_take"): + with T.sblock("T_take"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] @@ -2328,7 +2328,7 @@ def add(a: T.handle, b: T.handle, c: T.handle): C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1]) for iters in T.grid(T.int64(16), T.int64(32)): - with T.block("compute"): + with T.sblock("compute"): i, j = T.axis.remap("SS", iters) C[i, j] = A[i, j] + B[i, j] @@ -2373,12 +2373,12 @@ def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle): C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1]) Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1]) for iters in T.grid(*X.shape): - with T.block("compute_Y"): + with T.sblock("compute_Y"): i, j = T.axis.remap("SS", iters) Temp[i, j] = X[i, j] + Y[i, j] for iters in T.grid(*X.shape): - with T.block("compute_Z"): + with T.sblock("compute_Z"): i, j = T.axis.remap("SS", iters) C[i, j] = Temp[i, j] + Z[i, j] @@ -2411,7 +2411,7 @@ def mul(a: T.handle, b: T.handle, c: T.handle): C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1]) for iters in T.grid(T.int64(16), T.int64(32)): - with T.block("compute"): + with T.sblock("compute"): i, j = T.axis.remap("SS", iters) C[i, j] = A[i, j] * B[i, j] @@ -2451,7 +2451,7 @@ class Before: def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): T.func_attr({"tir.noalias": True}) for i in range(10): - with T.block("compute1"): + with T.sblock("compute1"): vi = T.axis.spatial(10, i) y[vi] = x[vi] + T.float32(1.0) @@ -2459,7 +2459,7 @@ def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): def mul1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): T.func_attr({"tir.noalias": True}) for i in range(10): - with T.block("compute1"): + with T.sblock("compute1"): vi = T.axis.spatial(10, i) y[vi] = x[vi] * T.float32(2.0) @@ -2488,18 +2488,18 @@ def fused_add_mul(p_x: T.handle, p_output0: T.handle): T.func_attr({"tir.noalias": True}) x = T.match_buffer(p_x, (T.int64(10),)) y_intermediate_1 = T.match_buffer(p_output0, (T.int64(10),), elem_offset=T.int32(0)) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() y_intermediate = T.alloc_buffer((T.int64(10),), elem_offset=T.int32(0)) for i in range(10): - with T.block("compute1"): + with T.sblock("compute1"): vi = T.axis.spatial(10, i) T.reads(x[vi]) T.writes(y_intermediate[vi]) y_intermediate[vi] = x[vi] + T.float32(1.0) for i in range(10): - with T.block("compute2"): + with T.sblock("compute2"): vi = T.axis.spatial(10, i) T.reads(y_intermediate[vi]) T.writes(y_intermediate_1[vi]) diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py index ad98fd229963..c538ca5c8cee 100644 --- a/tests/python/relax/test_transform_fuse_transpose_matmul.py +++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py @@ -48,9 +48,9 @@ def NT_matmul( NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(x[v_i0, v_k], w[v_i1, v_k]) T.writes(NT_matmul[v_i0, v_i1]) @@ -103,9 +103,9 @@ def NT_matmul( NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(x[v_i0, v_k], w[v_i1, v_k]) T.writes(NT_matmul[v_i0, v_i1]) diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py index 47c41ca108f9..1686e26a690a 100644 --- a/tests/python/relax/test_transform_gradient.py +++ b/tests/python/relax/test_transform_gradient.py @@ -1210,7 +1210,7 @@ def sum( ): T.func_attr({"tir.noalias": True}) for k0, k1 in T.grid(T.int64(3), T.int64(3)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) T.reads(rxplaceholder[v_k0, v_k1]) T.writes(rxplaceholder_red[()]) diff --git a/tests/python/relax/test_transform_gradient_te_register.py b/tests/python/relax/test_transform_gradient_te_register.py index 55f33f30198e..5a208c3bbebb 100644 --- a/tests/python/relax/test_transform_gradient_te_register.py +++ b/tests/python/relax/test_transform_gradient_te_register.py @@ -62,9 +62,9 @@ class Expected: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_1: T.Buffer((T.int64(5), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mul"): + with T.sblock("f_mul"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i0, v_i1]) T.writes(f_mul_1[v_i0, v_i1]) @@ -73,15 +73,15 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64 @T.prim_func(private=True) def f_mul_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), C: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_grad_1: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_grad_2: T.Buffer((T.int64(5), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mul_grad_1"): + with T.sblock("f_mul_grad_1"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(C[v_i0, v_i1], A[v_i0, v_i1]) T.writes(f_mul_grad_1[v_i0, v_i1]) f_mul_grad_1[v_i0, v_i1] = C[v_i0, v_i1] * A[v_i0, v_i1] for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mul_grad_2"): + with T.sblock("f_mul_grad_2"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(B[v_i0, v_i1], A[v_i0, v_i1]) T.writes(f_mul_grad_2[v_i0, v_i1]) @@ -149,9 +149,9 @@ class Before: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_1: T.Buffer((T.int64(5), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mul"): + with T.sblock("f_mul"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i0, v_i1]) T.writes(f_mul_1[v_i0, v_i1]) @@ -178,9 +178,9 @@ class Expected: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T.int64(5), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mul2"): + with T.sblock("f_mul2"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(f_mul2[v_i0, v_i1]) @@ -189,9 +189,9 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T. @T.prim_func(private=True) def f_mulk_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mulk_grad_1: T.Buffer((T.int64(5), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mulk_grad"): + with T.sblock("f_mulk_grad"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(f_mulk_grad_1[v_i0, v_i1]) @@ -258,9 +258,9 @@ class Before: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T.int64(5), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): - with T.block("f_mul2"): + with T.sblock("f_mul2"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(f_mul2[v_i0, v_i1]) @@ -291,9 +291,9 @@ def f_mul(var_A: T.handle, var_B: T.handle, var_f_mul: T.handle): A = T.match_buffer(var_A, (n, n)) B = T.match_buffer(var_B, (n, n)) f_mul_1 = T.match_buffer(var_f_mul, (n, n)) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(n, n): - with T.block("f_mul"): + with T.sblock("f_mul"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i0, v_i1]) T.writes(f_mul_1[v_i0, v_i1]) @@ -308,15 +308,15 @@ def f_mul_grad(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_f_mul_grad C = T.match_buffer(var_C, (n, n)) f_mul_grad_1 = T.match_buffer(var_f_mul_grad_1, (n, n)) f_mul_grad_2 = T.match_buffer(var_f_mul_grad_2, (n, n)) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(n, n): - with T.block("f_mul_grad_1"): + with T.sblock("f_mul_grad_1"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(C[v_i0, v_i1], A[v_i0, v_i1]) T.writes(f_mul_grad_1[v_i0, v_i1]) f_mul_grad_1[v_i0, v_i1] = C[v_i0, v_i1] * A[v_i0, v_i1] for i0, i1 in T.grid(n, n): - with T.block("f_mul_grad_2"): + with T.sblock("f_mul_grad_2"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(B[v_i0, v_i1], A[v_i0, v_i1]) T.writes(f_mul_grad_2[v_i0, v_i1]) diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index f30afdae849c..95e7a4d5dddb 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -333,7 +333,7 @@ def sub( C: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): - with T.block("sub"): + with T.sblock("sub"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] - B[vi, vj] diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index ae0521a0e2f8..902657b99af2 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -33,7 +33,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -69,9 +69,9 @@ class Expected: def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -114,7 +114,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -151,9 +151,9 @@ class Expected: def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -199,7 +199,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -236,9 +236,9 @@ class Expected: def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -288,7 +288,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -325,9 +325,9 @@ class Expected: def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -445,7 +445,7 @@ def slice_buffer( slice_index: T.int64, ): for i in T.grid(16): - with T.block("slice_buffer"): + with T.sblock("slice_buffer"): vi = T.axis.remap("S", [i]) Output[vi] = Input[slice_index, vi] @@ -483,7 +483,7 @@ def slice_buffer( slice_index: T.int64, ): for i in T.grid(16): - with T.block("slice_buffer"): + with T.sblock("slice_buffer"): vi = T.axis.remap("S", [i]) Output[vi] = Input[slice_index, vi] @@ -500,7 +500,7 @@ def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle): w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32") out = T.match_buffer(var_out, (16, ic, 3, 3), "float32") for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -540,7 +540,7 @@ def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle): w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32") out = T.match_buffer(var_out, (16, ic, 3, 3), "float32") for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -581,7 +581,7 @@ def test_output_with_use_site(): class Module: @T.prim_func def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")): - with T.block("block"): + with T.sblock("block"): T.reads(x[()]) T.writes(y[()]) y[()] = x[()] @@ -603,7 +603,7 @@ def main_transform_params( class Expected: @T.prim_func def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")): - with T.block("block"): + with T.sblock("block"): T.reads(x[()]) T.writes(y[()]) y[()] = x[()] diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py index e1f5c54b1bca..230ade222031 100644 --- a/tests/python/relax/test_transform_legalize_ops.py +++ b/tests/python/relax/test_transform_legalize_ops.py @@ -49,7 +49,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def add(rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) T.writes(T_add[ax0, ax1, ax2, ax3]) @@ -77,7 +77,7 @@ def mul2(x: R.Tensor((3, 3), "float32")): @T.prim_func(private=True) def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_id[v_ax0, v_ax1]) @@ -102,7 +102,7 @@ def mul2(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float3 @T.prim_func(private=True) def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_id[v_ax0, v_ax1]) @@ -112,7 +112,7 @@ def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: def multiply(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -195,9 +195,9 @@ def multiply( T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -219,9 +219,9 @@ def multiply( T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -241,9 +241,9 @@ def equal( T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): - with T.block("T_equal"): + with T.sblock("T_equal"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_equal[v_ax0, v_ax1]) @@ -402,7 +402,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for iters in T.grid(T.int64(32), T.int64(32)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", iters) C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] @@ -427,7 +427,7 @@ def add_llvm( ): T.func_attr({"target": T.target("llvm"), "tir.noalias": True}) for iters in T.grid(T.int64(32), T.int64(32)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", iters) C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py index 7b9405782433..970eae5b3577 100644 --- a/tests/python/relax/test_transform_legalize_ops_binary.py +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -45,7 +45,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_add[ax0, ax1, ax2, ax3]) @@ -76,7 +76,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_add[ax0, ax1]) @@ -107,7 +107,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_add[ax0, ax1]) @@ -153,7 +153,7 @@ def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_add: T rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_add = T.match_buffer(var_T_add, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_add[ax0, ax1, ax2, ax3]) @@ -194,7 +194,7 @@ def add( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] + rhs @@ -222,7 +222,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_divide[ax0, ax1, ax2, ax3]) @@ -253,7 +253,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_divide[ax0, ax1]) @@ -284,7 +284,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_divide[ax0, ax1]) @@ -330,7 +330,7 @@ def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_div rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_divide = T.match_buffer(var_T_divide, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_divide[ax0, ax1, ax2, ax3]) @@ -371,7 +371,7 @@ def divide( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] / rhs @@ -399,7 +399,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def floor_divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_floor_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_floor_divide"): + with T.sblock("T_floor_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_floor_divide[ax0, ax1, ax2, ax3]) @@ -430,7 +430,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_floor_divide"): + with T.sblock("T_floor_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_floor_divide[ax0, ax1]) @@ -461,7 +461,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_floor_divide"): + with T.sblock("T_floor_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_floor_divide[ax0, ax1]) @@ -507,7 +507,7 @@ def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_floor_divide = T.match_buffer(var_T_floor_divide, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_floor_divide"): + with T.sblock("T_floor_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_floor_divide[ax0, ax1, ax2, ax3]) @@ -548,7 +548,7 @@ def floor_divide( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_floordiv"): + with T.sblock("T_floordiv"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = T.floor(lhs[vi, vj, vk] / rhs) @@ -576,7 +576,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def multiply(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_multiply: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_multiply[ax0, ax1, ax2, ax3]) @@ -622,7 +622,7 @@ def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_m rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_multiply = T.match_buffer(var_T_multiply, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_multiply[ax0, ax1, ax2, ax3]) @@ -663,7 +663,7 @@ def multiply( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] * rhs @@ -685,9 +685,9 @@ class Expected: @T.prim_func(private=True) def power(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_power: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_power"): + with T.sblock("T_power"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)]) T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -729,9 +729,9 @@ def power(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_powe b = T.int64() rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, T.int64(1))) T_power = T.match_buffer(var_T_power, (a, b, c, d)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d): - with T.block("T_power"): + with T.sblock("T_power"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)]) T.writes(T_power[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -781,7 +781,7 @@ def power( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_power"): + with T.sblock("T_power"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = T.pow(lhs[vi, vj, vk], rhs) @@ -809,7 +809,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def subtract(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_subtract: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) @@ -855,7 +855,7 @@ def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_s rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_subtract = T.match_buffer(var_T_subtract, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) @@ -896,7 +896,7 @@ def subtract( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] - rhs @@ -927,7 +927,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_equal"): + with T.sblock("T_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_equal[ax0, ax1, ax2, ax3]) @@ -958,7 +958,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_equal"): + with T.sblock("T_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_equal[ax0, ax1]) @@ -989,7 +989,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_equal"): + with T.sblock("T_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_equal[ax0, ax1]) @@ -1035,7 +1035,7 @@ def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_equa rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_equal = T.match_buffer(var_T_equal, [a, b, c, d], dtype="bool") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_equal"): + with T.sblock("T_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_equal[ax0, ax1, ax2, ax3]) @@ -1076,7 +1076,7 @@ def equal( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] == rhs @@ -1104,7 +1104,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def greater(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_greater"): + with T.sblock("T_greater"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) T.writes(T_greater[ax0, ax1, ax2, ax3]) @@ -1135,7 +1135,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_greater"): + with T.sblock("T_greater"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_greater[ax0, ax1]) @@ -1166,7 +1166,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_greater"): + with T.sblock("T_greater"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_greater[ax0, ax1]) @@ -1212,7 +1212,7 @@ def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_gr rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_greater = T.match_buffer(var_T_greater, [a, b, c, d], dtype="bool") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_greater"): + with T.sblock("T_greater"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) T.writes(T_greater[ax0, ax1, ax2, ax3]) @@ -1253,7 +1253,7 @@ def greater( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = rhs < lhs[vi, vj, vk] @@ -1281,7 +1281,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def greater_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_greater_equal"): + with T.sblock("T_greater_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) T.writes(T_greater_equal[ax0, ax1, ax2, ax3]) @@ -1327,7 +1327,7 @@ def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, va rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_greater_equal = T.match_buffer(var_T_greater_equal, [a, b, c, d], dtype="bool") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_greater_equal"): + with T.sblock("T_greater_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) T.writes(T_greater_equal[ax0, ax1, ax2, ax3]) @@ -1368,7 +1368,7 @@ def greater_equal( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = rhs <= lhs[vi, vj, vk] @@ -1396,7 +1396,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def less(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_less"): + with T.sblock("T_less"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_less[ax0, ax1, ax2, ax3]) @@ -1442,7 +1442,7 @@ def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less: rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_less = T.match_buffer(var_T_less, [a, b, c, d], dtype="bool") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_less"): + with T.sblock("T_less"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_less[ax0, ax1, ax2, ax3]) @@ -1483,7 +1483,7 @@ def less( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] < rhs @@ -1511,7 +1511,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def less_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_less_equal"): + with T.sblock("T_less_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_less_equal[ax0, ax1, ax2, ax3]) @@ -1542,7 +1542,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_less_equal"): + with T.sblock("T_less_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_less_equal[ax0, ax1]) @@ -1573,7 +1573,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_less_equal"): + with T.sblock("T_less_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_less_equal[ax0, ax1]) @@ -1619,7 +1619,7 @@ def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_less_equal = T.match_buffer(var_T_less_equal, [a, b, c, d], dtype="bool") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_less_equal"): + with T.sblock("T_less_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_less_equal[ax0, ax1, ax2, ax3]) @@ -1660,7 +1660,7 @@ def less_equal( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] <= rhs @@ -1688,7 +1688,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def not_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_not_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_not_equal"): + with T.sblock("T_not_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_not_equal[ax0, ax1, ax2, ax3]) @@ -1734,7 +1734,7 @@ def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_not_equal = T.match_buffer(var_T_not_equal, [a, b, c, d], dtype="bool") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_not_equal"): + with T.sblock("T_not_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_not_equal[ax0, ax1, ax2, ax3]) @@ -1775,7 +1775,7 @@ def not_equal( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = lhs[vi, vj, vk] != rhs @@ -1804,7 +1804,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def maximum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_maximum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_maximum"): + with T.sblock("T_maximum"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_maximum[ax0, ax1, ax2, ax3]) @@ -1835,7 +1835,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_maximum"): + with T.sblock("T_maximum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_maximum[ax0, ax1]) @@ -1866,7 +1866,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_maximum"): + with T.sblock("T_maximum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_maximum[ax0, ax1]) @@ -1912,7 +1912,7 @@ def maximum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_ma rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_maximum = T.match_buffer(var_T_maximum, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_maximum"): + with T.sblock("T_maximum"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_maximum[ax0, ax1, ax2, ax3]) @@ -1953,7 +1953,7 @@ def maximum( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = T.max(lhs[vi, vj, vk], rhs) @@ -1982,7 +1982,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") def minimum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_minimum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_minimum"): + with T.sblock("T_minimum"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_minimum[ax0, ax1, ax2, ax3]) @@ -2013,7 +2013,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_minimum"): + with T.sblock("T_minimum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_minimum[ax0, ax1]) @@ -2044,7 +2044,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_minimum"): + with T.sblock("T_minimum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_minimum[ax0, ax1]) @@ -2090,7 +2090,7 @@ def minimum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_mi rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_minimum = T.match_buffer(var_T_minimum, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_minimum"): + with T.sblock("T_minimum"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_minimum[ax0, ax1, ax2, ax3]) @@ -2131,7 +2131,7 @@ def minimum( ): T.func_attr({"tir.noalias": True}) for i, j, k in T.grid(*lhs.shape): - with T.block("T_add"): + with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) output[vi, vj, vk] = T.min(lhs[vi, vj, vk], rhs) diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 23cc2c767a10..3e763b0b63f9 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -110,9 +110,9 @@ class Expected: @T.prim_func(private=True) def reshape(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_reshape: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(10), T.int64(2), T.int64(5)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[((v_ax1 * T.int64(5) + v_ax2) // T.int64(10) + v_ax0) % T.int64(10), (v_ax1 * T.int64(5) + v_ax2) % T.int64(10)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) @@ -121,9 +121,9 @@ def reshape(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_reshape: T.Buf @T.prim_func(private=True) def transpose(A: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(10), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(10), T.int64(5)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v_ax1, v_ax0, v_ax2]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2]) diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index ca5ba0f43751..c240de579e5a 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -44,7 +44,7 @@ def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "int32"): def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -75,7 +75,7 @@ def main() -> R.Tensor((2, 3), "int32"): def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -106,7 +106,7 @@ def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -144,7 +144,7 @@ def full(rxplaceholder: T.Buffer((), "int32"), var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -175,7 +175,7 @@ def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(( def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -206,7 +206,7 @@ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -237,7 +237,7 @@ def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(( def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float64")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -275,7 +275,7 @@ def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -306,7 +306,7 @@ def main() -> R.Tensor((2, 3), "float32"): def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -344,7 +344,7 @@ def ones(var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -375,7 +375,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -413,7 +413,7 @@ def ones(var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -444,7 +444,7 @@ def main() -> R.Tensor((2, 3), "float32"): def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -482,7 +482,7 @@ def zeros(var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -513,7 +513,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -551,7 +551,7 @@ def zeros(var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(T_full[ax0, ax1]) @@ -607,7 +607,7 @@ def arange(var_T_arange: T.handle, n: T.int64): T.func_attr({"tir.noalias": True}) T_arange = T.match_buffer(var_T_arange, (n // T.int64(2),), "int64") for ax0 in range(n // T.int64(2)): - with T.block("T_arange"): + with T.sblock("T_arange"): v_ax0 = T.axis.spatial(n // T.int64(2), ax0) T_arange[v_ax0] = v_ax0 * T.int64(2) + T.int64(1) # fmt: on @@ -636,7 +636,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): def tril(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): - with T.block("trilu"): + with T.sblock("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(trilu[i0_1, i1_1, i2_1]) @@ -678,7 +678,7 @@ def tril(var_rxplaceholder: T.handle, var_trilu: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") for i0, i1, i2 in T.grid(m, n, k): - with T.block("trilu"): + with T.sblock("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(trilu[i0_1, i1_1, i2_1]) @@ -709,7 +709,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): def triu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): - with T.block("trilu"): + with T.sblock("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(trilu[i0_1, i1_1, i2_1]) @@ -751,7 +751,7 @@ def triu(var_rxplaceholder: T.handle, var_trilu: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") for i0, i1, i2 in T.grid(m, n, k): - with T.block("trilu"): + with T.sblock("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(trilu[i0_1, i1_1, i2_1]) @@ -785,7 +785,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "int32"): def cast(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(compute[i0_1, i1_1, i2_1]) @@ -845,7 +845,7 @@ def cast(var_rxplaceholder: T.handle, var_compute: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) diff --git a/tests/python/relax/test_transform_legalize_ops_distributed.py b/tests/python/relax/test_transform_legalize_ops_distributed.py index b727fc330ab1..fbec6f2b07a2 100644 --- a/tests/python/relax/test_transform_legalize_ops_distributed.py +++ b/tests/python/relax/test_transform_legalize_ops_distributed.py @@ -38,9 +38,9 @@ class Expected: @T.prim_func(private=True) def strided_slice(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), redistribute_replica_to_shard: T.Buffer((T.int64(10), T.int64(5)), "float32"), worker_id: T.int64): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(10), T.int64(5)): - with T.block("redistribute_replica_to_shard"): + with T.sblock("redistribute_replica_to_shard"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, worker_id * T.int64(5) + v_i1]) T.writes(redistribute_replica_to_shard[v_i0, v_i1]) diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index 44469acdc1c0..d326f751a80d 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -34,25 +34,25 @@ class Expected: @T.prim_func(private=True) def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), rxplaceholder_3: T.Buffer((T.int64(4),), "float32"), pred_grad: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): all_weights = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) T_broadcast_to = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) all_weights_red = T.alloc_buffer(()) T_divide = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("all_weights"): + with T.sblock("all_weights"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder_3[rxplaceholder_2[v_i0, v_i1, v_i2]], rxplaceholder_2[v_i0, v_i1, v_i2]) T.writes(all_weights[v_i0, v_i1, v_i2]) all_weights[v_i0, v_i1, v_i2] = rxplaceholder_3[rxplaceholder_2[v_i0, v_i1, v_i2]] for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder[()]) T.writes(T_broadcast_to[v_ax0, v_ax1, v_ax2]) T_broadcast_to[v_ax0, v_ax1, v_ax2] = rxplaceholder[()] for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("all_weights_red"): + with T.sblock("all_weights_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(all_weights[v_k0, v_k1, v_k2]) T.writes(all_weights_red[()]) @@ -60,13 +60,13 @@ def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T all_weights_red[()] = T.float32(0) all_weights_red[()] = all_weights_red[()] + all_weights[v_k0, v_k1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_broadcast_to[v_ax0, v_ax1, v_ax2], all_weights_red[()]) T.writes(T_divide[v_ax0, v_ax1, v_ax2]) T_divide[v_ax0, v_ax1, v_ax2] = T_broadcast_to[v_ax0, v_ax1, v_ax2] / all_weights_red[()] for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("pred_grad"): + with T.sblock("pred_grad"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_2[v_i0, v_i2, v_i3], all_weights[v_i0, v_i2, v_i3], T_divide[v_i0, v_i2, v_i3]) T.writes(pred_grad[v_i0, v_i1, v_i2, v_i3]) @@ -97,32 +97,32 @@ class Expected: @T.prim_func(private=True) def te_nll_loss_backward_no_weight(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), pred_grad: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_full = T.alloc_buffer((T.int64(3),)) all_weights = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) T_broadcast_to = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) all_weights_red = T.alloc_buffer(()) T_divide = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) for ax0 in range(T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads() T.writes(T_full[v_ax0]) T_full[v_ax0] = T.float32(1) for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("all_weights"): + with T.sblock("all_weights"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_full[rxplaceholder_2[v_i0, v_i1, v_i2]], rxplaceholder_2[v_i0, v_i1, v_i2]) T.writes(all_weights[v_i0, v_i1, v_i2]) all_weights[v_i0, v_i1, v_i2] = T_full[rxplaceholder_2[v_i0, v_i1, v_i2]] for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder[()]) T.writes(T_broadcast_to[v_ax0, v_ax1, v_ax2]) T_broadcast_to[v_ax0, v_ax1, v_ax2] = rxplaceholder[()] for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("all_weights_red"): + with T.sblock("all_weights_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(all_weights[v_k0, v_k1, v_k2]) T.writes(all_weights_red[()]) @@ -130,13 +130,13 @@ def te_nll_loss_backward_no_weight(rxplaceholder: T.Buffer((), "float32"), rxpla all_weights_red[()] = T.float32(0) all_weights_red[()] = all_weights_red[()] + all_weights[v_k0, v_k1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_broadcast_to[v_ax0, v_ax1, v_ax2], all_weights_red[()]) T.writes(T_divide[v_ax0, v_ax1, v_ax2]) T_divide[v_ax0, v_ax1, v_ax2] = T_broadcast_to[v_ax0, v_ax1, v_ax2] / all_weights_red[()] for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("pred_grad"): + with T.sblock("pred_grad"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_2[v_i0, v_i2, v_i3], all_weights[v_i0, v_i2, v_i3], T_divide[v_i0, v_i2, v_i3]) T.writes(pred_grad[v_i0, v_i1, v_i2, v_i3]) @@ -173,27 +173,27 @@ def main(output_grad: R.Tensor((), dtype="float32"), predictions: R.Tensor((4,), @T.prim_func(private=True) def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((), "int64"), rxplaceholder_3: T.Buffer((T.int64(4),), "float32"), pred_grad: T.Buffer((T.int64(4),), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): all_weights = T.alloc_buffer(()) T_broadcast_to = T.alloc_buffer(()) T_divide = T.alloc_buffer(()) - with T.block("all_weights"): + with T.sblock("all_weights"): vi = T.axis.spatial(T.int64(1), T.int64(0)) T.reads(rxplaceholder_3[rxplaceholder_2[()]], rxplaceholder_2[()]) T.writes(all_weights[()]) all_weights[()] = rxplaceholder_3[rxplaceholder_2[()]] - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): vi = T.axis.spatial(1, T.int64(0)) T.reads(rxplaceholder[()]) T.writes(T_broadcast_to[()]) T_broadcast_to[()] = rxplaceholder[()] - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_broadcast_to[()], all_weights[()]) T.writes(T_divide[()]) T_divide[()] = T_broadcast_to[()] / all_weights[()] for i in range(T.int64(4)): - with T.block("pred_grad"): + with T.sblock("pred_grad"): v_i = T.axis.spatial(T.int64(4), i) T.reads(rxplaceholder_2[()], all_weights[()], T_divide[()]) T.writes(pred_grad[v_i]) @@ -218,18 +218,18 @@ class Expected: @T.prim_func(private=True) def max_pool2d_backward(A: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), B: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(15), T.int64(13))) maxpool_grad_argmax_v0 = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "int64") maxpool_grad_argmax_v1 = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(3), T.int64(2), T.int64(15), T.int64(13)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(B[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 - T.int64(1)]) T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3]) pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(2) <= v_ax2 and v_ax2 < T.int64(12) and T.int64(1) <= v_ax3 and v_ax3 < T.int64(11), B[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38)) for ax0, ax1, ax2, ax3, dh, dw in T.grid(T.int64(3), T.int64(2), T.int64(6), T.int64(5), T.int64(5), T.int64(5)): - with T.block("maxpool_grad_argmax"): + with T.sblock("maxpool_grad_argmax"): v_ax0, v_ax1, v_ax2, v_ax3, v_dh, v_dw = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, dh, dw]) T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw]) T.writes(maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3], maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -241,7 +241,7 @@ def max_pool2d_backward(A: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64 maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_maxpool_grad_argmax_v0 maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_maxpool_grad_argmax_v1 for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)): - with T.block("T_pool_grad"): + with T.sblock("T_pool_grad"): v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww]) T.reads(maxpool_grad_argmax_v0[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww], A[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww]) T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -274,9 +274,9 @@ class Expected: @T.prim_func(private=True) def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)): - with T.block("T_pool_grad"): + with T.sblock("T_pool_grad"): v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww]) T.reads(rxplaceholder[v_ax0, v_ax1, T.Div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh, T.Div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww]) T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -312,7 +312,7 @@ def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, va rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(5)), offset_factor=1) rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(3), T.int64(4), T.int64(5)), offset_factor=1) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(2),), "int32", offset_factor=1) - with T.block("take_backward"): + with T.sblock("take_backward"): for i in range(T.int64(60)): out_buf[i // T.int64(5) // T.int64(4), i // T.int64(5) % T.int64(4), i % T.int64(5)] = T.float32(0) for parallel, serial in T.grid(T.int64(15), T.int64(2)): @@ -351,7 +351,7 @@ def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, va rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (m, n), offset_factor=1) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (i,), "int32", offset_factor=1) out_buf = T.match_buffer(var_take_backward, (m, n)) - with T.block("take_backward"): + with T.sblock("take_backward"): for i_1 in range(m * n): out_buf[i_1 // n, i_1 % n] = T.float32(0) for parallel, serial in T.grid(m, i): diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py index 7c06ed46b64f..bda0010dcf13 100644 --- a/tests/python/relax/test_transform_legalize_ops_image.py +++ b/tests/python/relax/test_transform_legalize_ops_image.py @@ -41,7 +41,7 @@ def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor((2, 16, 16, 3), "floa def resize2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(8), T.int64(8), T.int64(3)), "float32"), resize: T.Buffer((T.int64(2), T.int64(16), T.int64(16), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(16), T.int64(3)): - with T.block("resize"): + with T.sblock("resize"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1]) T.writes(resize[i0_1, i1_1, i2_1, i3_1]) @@ -88,7 +88,7 @@ def resize2d(var_rxplaceholder: T.handle, var_resize: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w, T.int64(16)], dtype="float32") resize = T.match_buffer(var_resize, [n, c, oh, ow, T.int64(16)], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(n, c, oh, ow, T.int64(16)): - with T.block("resize"): + with T.sblock("resize"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(rxplaceholder[i0_1, i1_1, T.int64(0) : T.max(h, T.int64(1)), T.int64(0) : T.max(w, T.int64(1)), i4_1]) T.writes(resize[i0_1, i1_1, i2_1, i3_1, i4_1]) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 44419e51e7dc..dc43f84f07d0 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -46,7 +46,7 @@ def main(x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((4,), "int64")) -> def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "int64"), T_take: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): - with T.block("T_take"): + with T.sblock("T_take"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, rxplaceholder_1[ax1], ax2], rxplaceholder_1[ax1]) T.writes(T_take[ax0, ax1, ax2]) @@ -77,7 +77,7 @@ def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) -> R.Tensor( def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), index: T.int64, T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i2 in T.grid(T.int64(2), T.int64(4)): - with T.block("T_take"): + with T.sblock("T_take"): ax0, ax2 = T.axis.remap("SS", [i0, i2]) T.reads(rxplaceholder[ax0, index, ax2]) T.writes(T_take[ax0, ax2]) @@ -108,7 +108,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4), "float32"): def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i2 in T.grid(T.int64(2), T.int64(4)): - with T.block("T_take"): + with T.sblock("T_take"): ax0, ax2 = T.axis.remap("SS", [i0, i2]) T.reads(rxplaceholder[ax0, T.int64(0), ax2]) T.writes(T_take[ax0, ax2]) @@ -149,7 +149,7 @@ def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_take: rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [i], dtype="int64") T_take = T.match_buffer(var_T_take, [m, i], dtype="float32") for i0, i1 in T.grid(m, i): - with T.block("T_take"): + with T.sblock("T_take"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, rxplaceholder_1[ax1]], rxplaceholder_1[ax1]) T.writes(T_take[ax0, ax1]) @@ -184,7 +184,7 @@ def take(x_handle: T.handle, T_take: T.Buffer((T.int64(2), T.int64(4)), "float32 T.func_attr({"tir.noalias": True}) for i0, i2 in T.grid(T.int64(2), T.int64(4)): - with T.block("T_take"): + with T.sblock("T_take"): ax0, ax2 = T.axis.remap("SS", [i0, i2]) T.reads(rxplaceholder[ax0, n-1, ax2]) T.writes(T_take[ax0, ax2]) @@ -215,7 +215,7 @@ def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4, 9, 10, 3) def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(4), T.int64(9), T.int64(10), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(9), T.int64(10), T.int64(3)): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)]) T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3]) @@ -245,9 +245,9 @@ def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")): @T.prim_func(private=True) def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(7), T.int64(9), T.int64(10), T.int64(2)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(7), T.int64(9), T.int64(10), T.int64(2)): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0 + T.int64(1), v_ax1, v_ax2, v_ax3 + T.int64(2)]) T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -276,9 +276,9 @@ def strided_slice(var_A: T.handle, var_T_dynamic_strided_slice_with_axes: T.hand m, n = T.int64(), T.int64() A = T.match_buffer(var_A, (m, n)) T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.int64(3), n)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), n): - with T.block("T_dynamic_strided_slice_with_axes"): + with T.sblock("T_dynamic_strided_slice_with_axes"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0 * T.int64(3) + T.int64(1), v_ax1]) T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1]) @@ -322,7 +322,7 @@ def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T. rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") for i0, i1 in T.grid(T.int64(3), n): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1]) T.writes(T_strided_slice_with_axes[ax0, ax1]) @@ -358,7 +358,7 @@ def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T. rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") for i0, i1 in T.grid(T.int64(3), n): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1]) T.writes(T_strided_slice_with_axes[ax0, ax1]) @@ -390,7 +390,7 @@ def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T. rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") for i0, i1 in T.grid(T.int64(3), n): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1]) T.writes(T_strided_slice_with_axes[ax0, ax1]) @@ -422,9 +422,9 @@ def dynamic_strided_slice( T_strided_slice_dynamic = T.match_buffer( var_T_strided_slice_dynamic, (s, s_1, s_2, s_3) ) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(s, s_1, s_2, s_3): - with T.block("T_strided_slice_dynamic"): + with T.sblock("T_strided_slice_dynamic"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( rxplaceholder[ @@ -463,9 +463,9 @@ def shape_func( T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),), "int64"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i in range(T.int64(4)): - with T.block("T_shape_func_strided_slice_dynamic"): + with T.sblock("T_shape_func_strided_slice_dynamic"): v_i = T.axis.spatial(T.int64(4), i) T.reads( rxplaceholder_3[v_i], rxplaceholder_1[v_i], rxplaceholder_2[v_i] @@ -710,9 +710,9 @@ def dynamic_strided_slice( rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n)) s, s_1 = T.int64(), T.int64() T_strided_slice_dynamic = T.match_buffer(var_T_strided_slice_dynamic, (s, s_1)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(s, s_1): - with T.block("T_strided_slice_dynamic"): + with T.sblock("T_strided_slice_dynamic"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads( rxplaceholder_3[ @@ -743,9 +743,9 @@ def shape_func( T.func_attr({"tir.noalias": True}) n = T.int64() rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n)) - # with T.block("root"): + # with T.sblock("root"): for i in range(T.int64(2)): - with T.block("T_shape_func_strided_slice_dynamic"): + with T.sblock("T_shape_func_strided_slice_dynamic"): v_i = T.axis.spatial(T.int64(2), i) T.reads(rxplaceholder_2[v_i], rxplaceholder[v_i], rxplaceholder_1[v_i]) T.writes(T_shape_func_strided_slice_dynamic[v_i]) @@ -905,7 +905,7 @@ def main(x: R.Tensor((4,), "float32"), y: R.Tensor((2, 3, 4, 5), "float32")) -> def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(5), T.int64(4)): - with T.block("matmul"): + with T.sblock("matmul"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[k], rxplaceholder_1[i0_1, i1_1, k, i2_1]) T.writes(matmul[i0_1, i1_1, i2_1]) @@ -938,7 +938,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((5,), "float32")) -> def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer(T.int64(5), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("matmul"): + with T.sblock("matmul"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, k], rxplaceholder_1[k]) T.writes(matmul[i0_1, i1_1, i2_1]) @@ -971,7 +971,7 @@ def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.Tensor def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "float32"), matmul: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) for i0 in T.serial(T.int64(4)): - with T.block("matmul"): + with T.sblock("matmul"): k = T.axis.reduce(T.int64(4), i0) T.reads(rxplaceholder[k], rxplaceholder_1[k]) T.writes(matmul[()]) @@ -1004,7 +1004,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float16"), y: R.Tensor((6, 2, 3, 5, 7), "flo def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(5), T.int64(7)), "float16"), matmul: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7), T.int64(5)): - with T.block("matmul"): + with T.sblock("matmul"): i0_1, i1_1, i2_1, i3_1, i4_1, k = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[i1_1, i2_1, i3_1, k], rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1]) T.writes(matmul[i0_1, i1_1, i2_1, i3_1, i4_1]) @@ -1056,7 +1056,7 @@ def matmul(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmu rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, T.int64(1), c, k, n], dtype="float32") matmul = T.match_buffer(var_matmul, [a, b, c, m, n], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(a, b, c, m, n, k): - with T.block("matmul"): + with T.sblock("matmul"): i0_1, i1_1, i2_1, i3_1, i4_1, k_1 = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[i1_1, T.int64(0), i3_1, k_1], rxplaceholder_1[i0_1, T.int64(0), i2_1, k_1, i4_1]) T.writes(matmul[i0_1, i1_1, i2_1, i3_1, i4_1]) @@ -1083,9 +1083,9 @@ class Expected: @T.prim_func(private=True) def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(7)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(7)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(7), T.int64(5)): - with T.block("matmul"): + with T.sblock("matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) @@ -1131,7 +1131,7 @@ def einsum( ): T.func_attr({"tir.noalias": True}) for ax0, ax1, j in T.grid(T.int64(2), T.int64(4), T.int64(3)): - with T.block("T_einsum"): + with T.sblock("T_einsum"): v_ax0, v_ax1, v_j = T.axis.remap("SSR", [ax0, ax1, j]) T.reads(rxplaceholder[v_ax0, v_j], rxplaceholder_1[v_j, v_ax1]) T.writes(T_einsum[v_ax0, v_ax1]) @@ -1183,7 +1183,7 @@ def einsum( rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (b, c)) T_einsum = T.match_buffer(var_T_einsum, (a, c)) for ax0, ax1, j in T.grid(a, c, b): - with T.block("T_einsum"): + with T.sblock("T_einsum"): v_ax0, v_ax1, v_j = T.axis.remap("SSR", [ax0, ax1, j]) T.reads(rxplaceholder[v_ax0, v_j], rxplaceholder_1[v_j, v_ax1]) T.writes(T_einsum[v_ax0, v_ax1]) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 17e0160c6183..ec5186f4f0a5 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -44,7 +44,7 @@ def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32") def broadcast_to(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3)), "float32"), T_broadcast_to: T.Buffer((T.int64(4), T.int64(2), T.int64(5), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(2), T.int64(5), T.int64(3)): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax1, T.int64(0), ax3]) T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) @@ -89,7 +89,7 @@ def broadcast_to(var_rxplaceholder: T.handle, var_T_broadcast_to: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), d], dtype="float32") T_broadcast_to = T.match_buffer(var_T_broadcast_to, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax1, T.int64(0), ax3]) T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) @@ -120,7 +120,7 @@ def main(x1: R.Tensor((1, 2, 3), "float32"), x2: R.Tensor((1, 3, 3), "float32"), def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(3), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(1), T.int64(4), T.int64(3)), "float32"), T_concat: T.Buffer((T.int64(1), T.int64(9), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(1), T.int64(9), T.int64(3)): - with T.block("T_concat"): + with T.sblock("T_concat"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2]) T.writes(T_concat[ax0, ax1, ax2]) @@ -153,7 +153,7 @@ def main(t: R.Tuple(R.Tensor((3, 4), "float32"), R.Tensor((3, 5), "float32"))) - def concatenate(rxplaceholder: T.Buffer((T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(5)), "float32"), T_concat: T.Buffer((T.int64(3), T.int64(9)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(3), T.int64(9)): - with T.block("T_concat"): + with T.sblock("T_concat"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder_1[ax0, ax1 - T.int64(4)], rxplaceholder[ax0, ax1]) T.writes(T_concat[ax0, ax1]) @@ -203,7 +203,7 @@ def concatenate(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_ rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [a, b2], dtype="float32") T_concat = T.match_buffer(var_T_concat, [a, b0 + b1 + b2], dtype="float32") for i0, i1 in T.grid(a, b0 + b1 + b2): - with T.block("T_concat"): + with T.sblock("T_concat"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder_2[ax0, ax1 - b0 - b1], rxplaceholder_1[ax0, ax1 - b0], rxplaceholder[ax0, ax1]) T.writes(T_concat[ax0, ax1]) @@ -234,7 +234,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1) def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), expand_dims: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap("SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7]) T.reads(rxplaceholder[i0_1, i4_1, i6_1]) T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) @@ -276,7 +276,7 @@ def expand_dims(var_rxplaceholder: T.handle, var_expand_dims: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") expand_dims = T.match_buffer(var_expand_dims, [a, T.int64(1), b, T.int64(1), c, T.int64(1)], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), b, T.int64(1), c, T.int64(1)): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[i0_1, i2_1, i4_1]) T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) @@ -307,7 +307,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((24,), "float32"): def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(24), "float32")): T.func_attr({"tir.noalias": True}) for i0 in T.serial(T.int64(24)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0 = T.axis.spatial(T.int64(24), i0) T.reads(rxplaceholder[ax0 % T.int64(24) // T.int64(12), ax0 % T.int64(12) // T.int64(4), ax0 % T.int64(4)]) T.writes(T_reshape[ax0]) @@ -338,7 +338,7 @@ def main(x: R.Tensor((), "float32")) -> R.Tensor((1,), "float32"): def reshape(rxplaceholder: T.Buffer((), "float32"), T_reshape: T.Buffer(T.int64(1), "float32")): T.func_attr({"tir.noalias": True}) for i0 in T.serial(T.int64(1)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0 = T.axis.spatial(T.int64(1), i0) T.reads(rxplaceholder[()]) T.writes(T_reshape[ax0]) @@ -380,7 +380,7 @@ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") T_reshape = T.match_buffer(var_T_reshape, [a * b * c], dtype="float32") for i0 in T.serial(a * b * c): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0 = T.axis.spatial(a * b * c, i0) T.reads(rxplaceholder[ax0 // c // b % a, ax0 // c % b, ax0 % c]) T.writes(T_reshape[ax0]) @@ -411,7 +411,7 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float3 def transpose(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(4), T.int64(3), T.int64(1)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(4), T.int64(3), T.int64(1)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) T.writes(T_transpose[ax0, ax1, ax2, ax3]) @@ -456,7 +456,7 @@ def transpose(var_rxplaceholder: T.handle, var_T_transpose: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") T_transpose = T.match_buffer(var_T_transpose, [b, d, c, a], dtype="float32") for i0, i1, i2, i3 in T.grid(b, d, c, a): - with T.block("T_transpose"): + with T.sblock("T_transpose"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) T.writes(T_transpose[ax0, ax1, ax2, ax3]) @@ -487,7 +487,7 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): def reshape(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(8), T.int64(3)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[T.int64(0), (ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), (ax0 * T.int64(3) + ax1) % T.int64(4)]) T.writes(T_reshape[ax0, ax1]) @@ -516,9 +516,9 @@ def reshape( T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(8), T.int64(3)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads( rxplaceholder[ @@ -575,7 +575,7 @@ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") T_reshape = T.match_buffer(var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32") for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[(ax0 * b * T.int64(2) + ax1) // b % a, (ax0 * b * T.int64(2) + ax1) % b]) T.writes(T_reshape[ax0, ax1]) @@ -617,7 +617,7 @@ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32" ) for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads( rxplaceholder[ @@ -653,9 +653,9 @@ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(10), b)) T_reshape = T.match_buffer(var_T_reshape, (T.int64(5), b * T.int64(2))) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(5), b * T.int64(2)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads( rxplaceholder[ @@ -728,7 +728,7 @@ def reshape( N = T.int64() T_reshape = T.match_buffer(var_T_reshape, [M,N], "float32") for i,j in T.grid(M,N): - with T.block("T_reshape"): + with T.sblock("T_reshape"): vi,vj = T.axis.remap('SS',[i,j]) T.reads(rxplaceholder[(vi*N + vj) % 16]) T.writes(T_reshape[vi,vj]) @@ -758,19 +758,19 @@ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 3, 4), "fl def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_split_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_2: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): - with T.block("T_split"): + with T.sblock("T_split"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, ax2]) T.writes(T_split[ax0, ax1, ax2]) T_split[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): - with T.block("T_split_1"): + with T.sblock("T_split_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1 + T.int64(3), ax2]) T.writes(T_split_1[ax0, ax1, ax2]) T_split_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(3), ax2] for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): - with T.block("T_split_2"): + with T.sblock("T_split_2"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1 + T.int64(7), ax2]) T.writes(T_split_2[ax0, ax1, ax2]) @@ -801,19 +801,19 @@ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "fl def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, ax2]) T.writes(T_split_sections[ax0, ax1, ax2]) T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2]) T.writes(T_split_sections_1[ax0, ax1, ax2]) T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(4), ax2] for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)): - with T.block("T_split_sections_2"): + with T.sblock("T_split_sections_2"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2]) T.writes(T_split_sections_2[ax0, ax1, ax2]) @@ -845,13 +845,13 @@ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "fl def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, ax2]) T.writes(T_split_sections[ax0, ax1, ax2]) T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) T.writes(T_split_sections_1[ax0, ax1, ax2]) @@ -891,19 +891,19 @@ def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_spl T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32") T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2)], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("T_split_sections"): + with T.sblock("T_split_sections"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) T.writes(T_split_sections[ax0, ax1]) T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] for i0, i1 in T.grid(m, n): - with T.block("T_split_sections_1"): + with T.sblock("T_split_sections_1"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1 + n]) T.writes(T_split_sections_1[ax0, ax1]) T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n] for i0, i1 in T.grid(m, n): - with T.block("T_split_sections_2"): + with T.sblock("T_split_sections_2"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) T.writes(T_split_sections_2[ax0, ax1]) @@ -934,7 +934,7 @@ def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), " def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(1), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(4)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, T.int64(0), ax1, ax2, T.int64(0), ax3]) T.writes(T_squeeze[ax0, ax1, ax2, ax3]) @@ -965,7 +965,7 @@ def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) : def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2]) T.writes(T_squeeze[ax0, ax1, ax2]) @@ -1004,7 +1004,7 @@ def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, T.int64(1), b, T.int64(1)], dtype="float32") T_squeeze = T.match_buffer(var_T_squeeze, [a, b, T.int64(1)], dtype="float32") for i0, i1, i2 in T.grid(a, b, T.int64(1)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, T.int64(0), ax1, ax2]) T.writes(T_squeeze[ax0, ax1, ax2]) @@ -1035,7 +1035,7 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Te def collapse_sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, k0 = T.axis.remap("SSR", [i0, i1, i2]) T.reads(rxplaceholder[k0, ax1]) T.writes(rxplaceholder_red[ax0, ax1]) @@ -1071,7 +1071,7 @@ def main( def collapse_sum(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), T.int64(3)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2]) T.reads(rxplaceholder[v_k0, v_ax0, v_k2]) T.writes(rxplaceholder_red[v_ax0, v_ax1]) @@ -1103,9 +1103,9 @@ def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((6, 2, 3), dtype=" @T.prim_func(private=True) def repeat(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_repeat: T.Buffer((T.int64(6), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(6), T.int64(2), T.int64(3)): - with T.block("T_repeat"): + with T.sblock("T_repeat"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2]) T.writes(T_repeat[v_ax0, v_ax1, v_ax2]) @@ -1140,10 +1140,10 @@ def repeat( T_repeat: T.Buffer((T.int64(36),), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_reshape = T.alloc_buffer((T.int64(18),)) for ax0 in range(T.int64(18)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(18), ax0) T.reads( rxplaceholder[ @@ -1159,7 +1159,7 @@ def repeat( v_ax0 % T.int64(3), ] for ax0 in range(T.int64(36)): - with T.block("T_repeat"): + with T.sblock("T_repeat"): v_ax0 = T.axis.spatial(T.int64(36), ax0) T.reads(T_reshape[v_ax0 // T.int64(2)]) T.writes(T_repeat[v_ax0]) @@ -1189,9 +1189,9 @@ def repeat(var_rxplaceholder: T.handle, var_T_repeat: T.handle): c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c)) T_repeat = T.match_buffer(var_T_repeat, (T.int64(2) * a, b, c)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(a * T.int64(2), b, c): - with T.block("T_repeat"): + with T.sblock("T_repeat"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2]) T.writes(T_repeat[v_ax0, v_ax1, v_ax2]) @@ -1224,9 +1224,9 @@ class Expected: @T.prim_func(private=True) def tile(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_tile: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(9)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(9)): - with T.block("T_tile"): + with T.sblock("T_tile"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax1 % T.int64(3), v_ax2 % T.int64(2), v_ax3 % T.int64(3)]) T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1261,9 +1261,9 @@ def tile(var_rxplaceholder: T.handle, var_T_tile: T.handle): c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c)) T_tile = T.match_buffer(var_T_tile, (T.int64(2), a, b * T.int64(2), c * T.int64(3))) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), a, b * T.int64(2), c * T.int64(3)): - with T.block("T_tile"): + with T.sblock("T_tile"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax1 % a, v_ax2 % b, v_ax3 % c]) T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1305,7 +1305,7 @@ def flip( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_reverse_sequence"): + with T.sblock("T_reverse_sequence"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[T.int64(1) - v_ax0, v_ax1]) T.writes(T_reverse_sequence[v_ax0, v_ax1]) @@ -1347,7 +1347,7 @@ def flip(var_rxplaceholder: T.handle, var_T_reverse_sequence: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b)) T_reverse_sequence = T.match_buffer(var_T_reverse_sequence, (a, b)) for ax0, ax1 in T.grid(a, b): - with T.block("T_reverse_sequence"): + with T.sblock("T_reverse_sequence"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, b - v_ax1 - T.int64(1)]) T.writes(T_reverse_sequence[v_ax0, v_ax1]) @@ -1388,7 +1388,7 @@ def scatter_elements( rxplaceholder_2 = T.match_buffer( var_rxplaceholder_2, (T.int64(2), T.int64(2)), offset_factor=1 ) - with T.block("scatter_elements_generic"): + with T.sblock("scatter_elements_generic"): for i in T.parallel(T.int64(16)): out_buf[i // T.int64(4), i % T.int64(4)] = rxplaceholder[ i // T.int64(4), i % T.int64(4) @@ -1483,7 +1483,7 @@ def scatter_elements( ) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (m, n), offset_factor=1) out_buf = T.match_buffer(var_scatter_elements_generic, (a, b)) - with T.block("scatter_elements_generic"): + with T.sblock("scatter_elements_generic"): for i in T.parallel(a * b): out_buf[i // b, i % b] = rxplaceholder[i // b, i % b] for fused in T.parallel(m): @@ -1565,9 +1565,9 @@ class Expected: @T.prim_func(private=True) def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(21), T.int64(30)), "float32"), te_layout_transform_1: T.Buffer((T.int64(10), T.int64(30), T.int64(7), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2 in T.grid(T.int64(10), T.int64(21), T.int64(30)): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(A[v_i0, v_i1, v_i2]) T.writes(te_layout_transform_1[v_i0, v_i2, v_i1 // T.int64(3), v_i1 % T.int64(3)]) @@ -1602,9 +1602,9 @@ class Expected: @T.prim_func(private=True) def te_layout_transform_with_pad(A: T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "float32"), te_layout_transform_with_pad_1: T.Buffer((T.int64(10), T.int64(30), T.int64(7), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30), T.int64(7), T.int64(3)): - with T.block("te_layout_transform_with_pad"): + with T.sblock("te_layout_transform_with_pad"): v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS", [axis0, axis1, axis2, axis3]) T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3, v_axis1]) T.writes(te_layout_transform_with_pad_1[v_axis0, v_axis1, v_axis2, v_axis3]) @@ -1642,9 +1642,9 @@ def te_layout_transform_with_pad(var_A: T.handle, var_te_layout_transform_with_p a, b, c = T.int64(), T.int64(), T.int64() A = T.match_buffer(var_A, (a, b, c)) te_layout_transform_with_pad_1 = T.match_buffer(var_te_layout_transform_with_pad, (a, c, (b - b % T.int64(-3)) // T.int64(3), T.int64(3))) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3 in T.grid(a, c, (b - b % T.int64(-3)) // T.int64(3), T.int64(3)): - with T.block("te_layout_transform_with_pad_with_pad"): + with T.sblock("te_layout_transform_with_pad_with_pad"): v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS", [axis0, axis1, axis2, axis3]) T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3, v_axis1]) T.writes(te_layout_transform_with_pad_1[v_axis0, v_axis1, v_axis2, v_axis3]) @@ -1684,9 +1684,9 @@ class Expected: def te_layout_transform_with_pad_axis_separator(A: T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "float32"), var_te_layout_transform_with_pad_axis_separator: T.handle): T.func_attr({"tir.noalias": True}) te_layout_transform_with_pad_axis_separator_1 = T.match_buffer(var_te_layout_transform_with_pad_axis_separator, (T.int64(10), T.int64(30), T.int64(7), T.int64(3)), axis_separators=[3]) - # with T.block("root"): + # with T.sblock("root"): for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30), T.int64(7), T.int64(3)): - with T.block("te_layout_transform_with_pad_axis_separator"): + with T.sblock("te_layout_transform_with_pad_axis_separator"): v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS", [axis0, axis1, axis2, axis3]) T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3, v_axis1]) T.writes(te_layout_transform_with_pad_axis_separator_1[v_axis0, v_axis1, v_axis2, v_axis3]) @@ -1762,7 +1762,7 @@ def te_layout_transform( ): T.func_attr({"tir.noalias": True}) for i in range(T.int64(16)): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): vi = T.axis.spatial(T.int64(16), i) te_layout_transform[vi // T.int64(4), vi % T.int64(4)] = A[vi] @@ -1804,19 +1804,19 @@ def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, indices = T.match_buffer(var_indices, (T.int64(4), T.int64(1)), "int64") updates = T.match_buffer(var_updates, (T.int64(4),), offset_factor=1) out_buf = T.match_buffer(var_scatter_nd_generic, (T.int64(8),)) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() T_transpose = T.alloc_buffer((T.int64(1), T.int64(4)), "int64") for ax0 in range(T.int64(1)): for ax1 in range(T.int64(4)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(T.int64(4), ax1) T.reads(indices[v_ax1, v_ax0]) T.writes(T_transpose[v_ax0, v_ax1]) T_transpose[v_ax0, v_ax1] = indices[v_ax1, v_ax0] - with T.block("scatter_nd_generic"): + with T.sblock("scatter_nd_generic"): T.reads() T.writes() for i in range(T.int64(8)): diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index e81e1bab2af4..5d3390b8108e 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -48,13 +48,13 @@ def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), B: T T.func_attr({"tir.noalias": True}) pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(A[v_i0, v_i1, v_i2 - T.int64(1)]) T.writes(pad_temp[v_i0, v_i1, v_i2]) pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), A[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0)) for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(16), T.int64(3)): - with T.block("group_conv1d_ncw"): + with T.sblock("group_conv1d_ncw"): v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) T.reads(pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], B[v_ff, v_rc, v_ry]) T.writes(group_conv1d_ncw[v_nn, v_ff, v_yy]) @@ -86,16 +86,16 @@ def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype= @T.prim_func(private=True) def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(28)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[v_i0, v_i1, v_i2]) T.writes(pad_temp[v_i0, v_i1, v_i2]) pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(3), T.int64(3)): - with T.block("conv1d_ncw"): + with T.sblock("conv1d_ncw"): v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry]) T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) @@ -127,16 +127,16 @@ def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), d @T.prim_func(private=True) def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer((T.int64(2), T.int64(28), T.int64(128))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(28), T.int64(128)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[v_i0, v_i1, v_i2]) T.writes(pad_temp[v_i0, v_i1, v_i2]) pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] for nn, yy, ff, ry, rc in T.grid(T.int64(2), T.int64(26), T.int64(64), T.int64(3), T.int64(128)): - with T.block("conv1d_nwc"): + with T.sblock("conv1d_nwc"): v_nn, v_yy, v_ff, v_ry, v_rc = T.axis.remap("SSSRR", [nn, yy, ff, ry, rc]) T.reads(pad_temp[v_nn, v_yy + v_ry, v_rc], rxplaceholder_1[v_ff, v_rc, v_ry]) T.writes(conv1d_nwc[v_nn, v_yy, v_ff]) @@ -182,16 +182,16 @@ def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1 f, kw = T.int64(), T.int64() rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw)) conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w + T.int64(1) - kw)) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer((n, c, w)) for i0, i1, i2 in T.grid(n, c, w): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[v_i0, v_i1, v_i2]) T.writes(pad_temp[v_i0, v_i1, v_i2]) pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] for nn, ff, yy, rc, ry in T.grid(n, f, w + T.int64(1) - kw, c, kw): - with T.block("conv1d_ncw"): + with T.sblock("conv1d_ncw"): v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry]) T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) @@ -222,19 +222,19 @@ def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58))) kernel = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(55)): - with T.block("data_dilate"): + with T.sblock("data_dilate"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) data_dilate[v_i0, v_i1, v_i2] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0), x[v_i0, v_i1, v_i2 // T.int64(2)], T.float32(0.0)) for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(58)): - with T.block("data_pad"): + with T.sblock("data_pad"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) data_pad[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56), data_dilate[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0.0)) for o, i, w_1 in T.grid(T.int64(16), T.int64(128), T.int64(3)): - with T.block("kernel"): + with T.sblock("kernel"): v_o, v_i, v_w = T.axis.remap("SSS", [o, i, w_1]) kernel[v_o, v_i, v_w] = w[v_i, v_o, T.int64(2) - v_w] for b, c, w_1, dc, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(16), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_b, v_c, v_w, v_dc, v_dw = T.axis.remap("SSSRR", [b, c, w_1, dc, dw]) with T.init(): compute[v_b, v_c, v_w] = T.float32(0.0) @@ -272,13 +272,13 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int T.func_attr({"tir.noalias": True}) pad_temp = T.alloc_buffer([T.int64(2), T.int64(128), T.int64(30), T.int64(30)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(30), T.int64(30)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(29) and T.int64(1) <= i3_1 and i3_1 < T.int64(29), rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)], T.float32(0), dtype="float32") for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(13), T.int64(16), T.int64(3), T.int64(3)): - with T.block("group_conv2d_nchw"): + with T.sblock("group_conv2d_nchw"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * T.int64(2) + rx * T.int64(2)], rxplaceholder_1[ff, rc, ry, rx]) T.writes(group_conv2d_nchw[nn, ff, yy, xx]) @@ -312,13 +312,13 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64 T.func_attr({"tir.noalias": True}) pad_temp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26), T.int64(3), T.int64(3), T.int64(3)): - with T.block("conv2d_nchw"): + with T.sblock("conv2d_nchw"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx]) T.writes(conv2d_nchw[nn, ff, yy, xx]) @@ -352,13 +352,13 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(28), T.int6 T.func_attr({"tir.noalias": True}) pad_temp = T.alloc_buffer([T.int64(2), T.int64(28), T.int64(28), T.int64(128)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(28), T.int64(28), T.int64(128)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(26), T.int64(26), T.int64(64), T.int64(3), T.int64(3), T.int64(128)): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, yy + ry, xx + rx, rc], rxplaceholder_1[ff, rc, ry, rx]) T.writes(conv2d_nhwc[nn, yy, xx, ff]) @@ -414,13 +414,13 @@ def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2 conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h + T.int64(1) - kh, w + T.int64(1) - kw], dtype="float32") pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32") for i0, i1, i2, i3 in T.grid(n, c, h, w): - with T.block("pad_temp"): + with T.sblock("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] for i0, i1, i2, i3, i4, i5, i6 in T.grid(n, f, h + T.int64(1) - kh, w + T.int64(1) - kw, c, kh, kw): - with T.block("conv2d_nchw"): + with T.sblock("conv2d_nchw"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx]) T.writes(conv2d_nchw[nn, ff, yy, xx]) @@ -452,30 +452,30 @@ def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w: R.Tensor((128, 16, 3 @T.prim_func(private=True) def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(16), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56), T.int64(84)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55), T.int64(82))) data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58), T.int64(86))) kernel_transform = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3), T.int64(3))) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(55), T.int64(82)): - with T.block("data_dilate"): + with T.sblock("data_dilate"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)]) T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3]) data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)], T.float32(0)) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(58), T.int64(86)): - with T.block("data_pad"): + with T.sblock("data_pad"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)]) T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56) and T.int64(1) <= v_i3 and v_i3 < T.int64(83), data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0)) for i, o, h, w in T.grid(T.int64(16), T.int64(128), T.int64(3), T.int64(3)): - with T.block("kernel_transform"): + with T.sblock("kernel_transform"): v_i, v_o, v_h, v_w = T.axis.remap("SSSS", [i, o, h, w]) T.reads(rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w]) T.writes(kernel_transform[v_i, v_o, v_h, v_w]) kernel_transform[v_i, v_o, v_h, v_w] = rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w] for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(84), T.int64(16), T.int64(3), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw]) T.reads(data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw]) T.writes(compute[v_b, v_c, v_h, v_w]) @@ -507,30 +507,30 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 4, 3, 3), @T.prim_func(private=True) def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(30), T.int64(30)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): data_dilate = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) data_pad = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(32), T.int64(32))) kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3))) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("data_dilate"): + with T.sblock("data_dilate"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3]) T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3]) data_dilate[v_i0, v_i1, v_i2, v_i3] = rxplaceholder[v_i0, v_i1, v_i2, v_i3] for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(32)): - with T.block("data_pad"): + with T.sblock("data_pad"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)]) T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(30) and T.int64(2) <= v_i3 and v_i3 < T.int64(30), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0)) for o, i, h, w in T.grid(T.int64(4), T.int64(3), T.int64(3), T.int64(3)): - with T.block("kernel_transform"): + with T.sblock("kernel_transform"): v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w]) T.reads(rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w]) T.writes(kernel_transform[v_o, v_i, v_h, v_w]) kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w] for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(4), T.int64(30), T.int64(30), T.int64(3), T.int64(3), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw]) T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw]) T.writes(compute[v_b, v_c, v_h, v_w]) @@ -579,30 +579,30 @@ def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, kw = T.int64() rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kh, kw)) compute = T.match_buffer(var_compute, (n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3))) - # with T.block("root"): + # with T.sblock("root"): data_dilate = T.alloc_buffer((n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2))) data_pad = T.alloc_buffer((n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4))) kernel_transform = T.alloc_buffer((c, c, kh, kw)) for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2)): - with T.block("data_dilate"): + with T.sblock("data_dilate"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)]) T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3]) data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(3) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)], T.float32(0)) for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)): - with T.block("data_pad"): + with T.sblock("data_pad"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw]) T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh <= v_i2 + T.int64(1) and v_i2 + T.int64(3)< h * T.int64(3) + kh and kw <= v_i3 + T.int64(1) and v_i3 + T.int64(3) < w * T.int64(3) + kw , data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw], T.float32(0)) for o, i, h_1, w_1 in T.grid(c, c, kh, kw): - with T.block("kernel_transform"): + with T.sblock("kernel_transform"): v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h_1, w_1]) T.reads(rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)]) T.writes(kernel_transform[v_o, v_i, v_h, v_w]) kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)] for b, c_1, h_1, w_1, dc, dh, dw in T.grid(n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3), c, kh, kw): - with T.block("compute"): + with T.sblock("compute"): v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c_1, h_1, w_1, dc, dh, dw]) T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw]) T.writes(compute[v_b, v_c, v_h, v_w]) @@ -636,17 +636,17 @@ def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.func_attr({"tir.noalias": True}) pad_temp = T.alloc_buffer([T.int64(4), T.int64(114), T.int64(114), T.int64(6)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3]) T.writes(pad_temp[ax0, ax1, ax2, ax3]) pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax1 and ax1 < T.int64(113) and T.int64(1) <= ax2 and ax2 < T.int64(113), rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3], T.float32(-3.4028234663852886e+38), dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)): - with T.block("pool_max"): + with T.sblock("pool_max"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3]) T.writes(pool_max[ax0, ax1, ax2, ax3]) - T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule":"meta_schedule.pool_max"}) with T.init(): pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38) pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3]) @@ -676,11 +676,11 @@ def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 1 def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): - with T.block("pool_max"): + with T.sblock("pool_max"): ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4]) T.writes(pool_max[ax0, ax1, ax2, ax3, ax4]) - T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule":"meta_schedule.pool_max"}) with T.init(): pool_max[ax0, ax1, ax2, ax3, ax4] = T.float32(-3.4028234663852886e+38) pool_max[ax0, ax1, ax2, ax3, ax4] = T.max(pool_max[ax0, ax1, ax2, ax3, ax4], rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4]) @@ -711,17 +711,17 @@ def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T. T.func_attr({"tir.noalias": True}) pad_temp = T.alloc_buffer([T.int64(4), T.int64(6), T.int64(116), T.int64(116)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)]) T.writes(pad_temp[ax0, ax1, ax2, ax3]) pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax2 and ax2 < T.int64(113) and T.int64(1) <= ax3 and ax3 < T.int64(113), rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38), dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)): - with T.block("pool_max"): + with T.sblock("pool_max"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1]) T.writes(pool_max[ax0, ax1, ax2, ax3]) - T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule":"meta_schedule.pool_max"}) with T.init(): pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38) pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1]) @@ -767,17 +767,17 @@ class Expected: @T.prim_func(private=True) def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer((T.int64(4), T.int64(114), T.int64(114), T.int64(6))) pool_sum = T.alloc_buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - T.int64(1), v_ax3]) T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3]) pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax1 and v_ax1 < T.int64(113) and T.int64(1) <= v_ax2 and v_ax2 < T.int64(113), rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - T.int64(1), v_ax3], T.float32(0)) for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)): - with T.block("pool_sum"): + with T.sblock("pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) T.reads(pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * T.int64(2) + v_rv1, v_ax3]) T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -785,11 +785,11 @@ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * T.int64(2) + v_rv1, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6)): - with T.block("pool_avg"): + with T.sblock("pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) - T.block_attr({"schedule_rule": "meta_schedule.pool_avg"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_avg"}) pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax1 * T.int64(2) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax1 * T.int64(2), T.int64(0)) - v_ax1 * T.int64(2)) * (T.min(v_ax2 * T.int64(2) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(2), T.int64(0)) - v_ax2 * T.int64(2)), T.int64(1))) @R.function @@ -816,10 +816,10 @@ class Expected: @T.prim_func(private=True) def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pool_sum = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16))) for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): - with T.block("pool_sum"): + with T.sblock("pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap("SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + v_rv1, v_ax4]) T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) @@ -827,11 +827,11 @@ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T. pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(0) pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + v_rv1, v_ax4] for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)): - with T.block("pool_avg"): + with T.sblock("pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T.block_attr({"schedule_rule": "meta_schedule.pool_avg"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_avg"}) pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", T.max((T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1) - T.max(T.int64(0) - v_ax2, T.int64(0))) * (T.min(T.int64(2), T.int64(111) - v_ax3) + T.int64(1) - T.max(T.int64(0) - v_ax3, T.int64(0))), T.int64(1))) @R.function def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) -> R.Tensor((4, 4, 110, 110, 16), dtype="float32"): @@ -857,17 +857,17 @@ class Expected: @T.prim_func(private=True) def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(116), T.int64(116))) pool_sum = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1)]) T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3]) pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(1) <= v_ax2 and v_ax2 < T.int64(113) and T.int64(1) <= v_ax3 and v_ax3 < T.int64(113), rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1)], T.float32(0)) for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)): - with T.block("pool_sum"): + with T.sblock("pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 * T.int64(3) + v_rv1]) T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -875,11 +875,11 @@ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T. pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 * T.int64(3) + v_rv1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38)): - with T.block("pool_avg"): + with T.sblock("pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) - T.block_attr({"schedule_rule": "meta_schedule.pool_avg"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_avg"}) pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax2 * T.int64(3) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(3), T.int64(0)) - v_ax2 * T.int64(3)) * (T.min(v_ax3 * T.int64(3) + T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax3 * T.int64(3), T.int64(0)) - v_ax3 * T.int64(3)), T.int64(1))) @R.function @@ -936,7 +936,7 @@ def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64 T.func_attr({"tir.noalias": True}) adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)], dtype="float32") for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16), T.int64(7), T.int64(7)): - with T.block("adaptive_pool_sum"): + with T.sblock("adaptive_pool_sum"): ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads(rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4]) T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) @@ -944,11 +944,11 @@ def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64 adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = T.float32(0) adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] + rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4] for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)): - with T.block("adaptive_pool_avg"): + with T.sblock("adaptive_pool_avg"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4]) - T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) + T.sblock_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] / T.float32(49.0) # fmt: on @@ -977,7 +977,7 @@ def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(16), T.int6 T.func_attr({"tir.noalias": True}) adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(16), T.int64(7), T.int64(7)], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7), T.int64(1), T.int64(1)): - with T.block("adaptive_pool_sum"): + with T.sblock("adaptive_pool_sum"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1]) T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3]) @@ -985,11 +985,11 @@ def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(16), T.int6 adaptive_pool_sum[ax0, ax1, ax2, ax3] = T.float32(0) adaptive_pool_sum[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] + rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1] for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7)): - with T.block("adaptive_pool_avg"): + with T.sblock("adaptive_pool_avg"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3]) T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3]) - T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) + T.sblock_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] # fmt: on @@ -1036,7 +1036,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) @@ -1075,7 +1075,7 @@ def relu(var_rxplaceholder: T.handle, var_compute: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) @@ -1107,7 +1107,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): def leaky_relu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(x[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) @@ -1145,7 +1145,7 @@ def leaky_relu(var_x: T.handle, var_compute: T.handle): x = T.match_buffer(var_x, (m, n)) compute = T.match_buffer(var_compute, (m, n)) for i0, i1 in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(x[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) @@ -1175,16 +1175,16 @@ def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32" @T.prim_func(private=True) def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): slope_broadcasted = T.alloc_buffer((T.int64(3),)) for c in range(T.int64(3)): - with T.block("slope_broadcasted"): + with T.sblock("slope_broadcasted"): v_c = T.axis.spatial(T.int64(3), c) T.reads(y[T.int64(0)]) T.writes(slope_broadcasted[v_c]) slope_broadcasted[v_c] = y[T.int64(0)] for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1]) T.writes(compute[v_i0, v_i1]) @@ -1219,16 +1219,16 @@ def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T m = T.int64() x = T.match_buffer(var_x, (m, T.int64(7))) compute = T.match_buffer(var_compute, (m, T.int64(7))) - # with T.block("root"): + # with T.sblock("root"): slope_broadcasted = T.alloc_buffer((T.int64(7),)) for c in range(T.int64(7)): - with T.block("slope_broadcasted"): + with T.sblock("slope_broadcasted"): v_c = T.axis.spatial(T.int64(7), c) T.reads(y[T.int64(0)]) T.writes(slope_broadcasted[v_c]) slope_broadcasted[v_c] = y[T.int64(0)] for i0, i1 in T.grid(m, T.int64(7)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1]) T.writes(compute[v_i0, v_i1]) @@ -1263,31 +1263,31 @@ def gelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer( T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3))) T_add = T.alloc_buffer((T.int64(2), T.int64(3))) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -1329,31 +1329,31 @@ def gelu(var_x: T.handle, var_T_multiply: T.handle): T_multiply_2 = T.alloc_buffer((m, n)) T_add = T.alloc_buffer((m, n)) for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_1[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) for ax0, ax1 in T.grid(m, n): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -1392,55 +1392,55 @@ def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Bu compute = T.alloc_buffer((T.int64(2), T.int64(3))) T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3))) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_3[v_ax0, v_ax1]) T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_3"): + with T.sblock("T_multiply_3"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1]) T.writes(T_multiply_4[v_ax0, v_ax1]) T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_4[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_4"): + with T.sblock("T_multiply_4"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply_5[v_ax0, v_ax1]) T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1] for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_5[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add_1"): + with T.sblock("T_add_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_add_1[v_ax0, v_ax1]) T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_5"): + with T.sblock("T_multiply_5"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -1476,7 +1476,7 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle): m, n = T.int64(), T.int64() A = T.match_buffer(var_A, (m, n)) T_multiply = T.match_buffer(var_T_multiply, (m, n)) - # with T.block("root"): + # with T.sblock("root"): T_multiply_1 = T.alloc_buffer((m, n)) T_multiply_2 = T.alloc_buffer((m, n)) T_multiply_3 = T.alloc_buffer((m, n)) @@ -1486,55 +1486,55 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle): compute = T.alloc_buffer((m, n)) T_add_1 = T.alloc_buffer((m, n)) for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_1[v_ax0, v_ax1]) T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_2[v_ax0, v_ax1]) T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_multiply_3[v_ax0, v_ax1]) T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_3"): + with T.sblock("T_multiply_3"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1]) T.writes(T_multiply_4[v_ax0, v_ax1]) T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_4[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_4"): + with T.sblock("T_multiply_4"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) T.writes(T_multiply_5[v_ax0, v_ax1]) T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1] for i0, i1 in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(T_multiply_5[v_i0, v_i1]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1]) for ax0, ax1 in T.grid(m, n): - with T.block("T_add_1"): + with T.sblock("T_add_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(compute[v_ax0, v_ax1]) T.writes(T_add_1[v_ax0, v_ax1]) T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1] for ax0, ax1 in T.grid(m, n): - with T.block("T_multiply_5"): + with T.sblock("T_multiply_5"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) @@ -1566,13 +1566,13 @@ def silu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multipl T.func_attr({"tir.noalias": True}) compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1]) T.writes(T_multiply[ax0, ax1]) @@ -1612,13 +1612,13 @@ def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") compute = T.alloc_buffer([m, n], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) for i0, i1 in T.grid(m, n): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1]) T.writes(T_multiply[ax0, ax1]) @@ -1652,7 +1652,7 @@ def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int6 T_softmax_exp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(16), T.int64(32)], dtype="float32") T_softmax_expsum = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1]) @@ -1660,13 +1660,13 @@ def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int6 T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1]) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): i0_2, i1_2, i2_2, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_2, i1_2, i2_2, i3_1], T_softmax_maxelem[i0_2, i1_2, i3_1]) T.writes(T_softmax_exp[i0_2, i1_2, i2_2, i3_1]) T_softmax_exp[i0_2, i1_2, i2_2, i3_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_2, i3_1] - T_softmax_maxelem[i0_2, i1_2, i3_1], dtype="float32") for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4, i1_4, i2_4, k = T.axis.remap("SSSR", [i0_3, i1_3, i2_3, i3]) T.reads(T_softmax_exp[i0_4, i1_4, k, i2_4]) T.writes(T_softmax_expsum[i0_4, i1_4, i2_4]) @@ -1674,11 +1674,11 @@ def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int6 T_softmax_expsum[i0_4, i1_4, i2_4] = T.float32(0) T_softmax_expsum[i0_4, i1_4, i2_4] = T_softmax_expsum[i0_4, i1_4, i2_4] + T_softmax_exp[i0_4, i1_4, k, i2_4] for i0_5, i1_5, i2_5, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_6, i1_6, i2_6, i3_2 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3]) T.reads(T_softmax_exp[i0_6, i1_6, i2_6, i3_2], T_softmax_expsum[i0_6, i1_6, i3_2]) T.writes(T_softmax_norm[i0_6, i1_6, i2_6, i3_2]) - T.block_attr({"axis":2}) + T.sblock_attr({"axis":2}) T_softmax_norm[i0_6, i1_6, i2_6, i3_2] = T_softmax_exp[i0_6, i1_6, i2_6, i3_2] / T_softmax_expsum[i0_6, i1_6, i3_2] # fmt: on @@ -1720,7 +1720,7 @@ def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): T_softmax_exp = T.alloc_buffer([a, b, c], dtype="float32") T_softmax_expsum = T.alloc_buffer([a, b], dtype="float32") for i0, i1, i2 in T.grid(a, b, c): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, k]) T.writes(T_softmax_maxelem[i0_1, i1_1]) @@ -1728,13 +1728,13 @@ def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): T_softmax_maxelem[i0_1, i1_1] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[i0_1, i1_1] = T.max(T_softmax_maxelem[i0_1, i1_1], rxplaceholder[i0_1, i1_1, k]) for i0, i1, i2 in T.grid(a, b, c): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): i0_2, i1_2, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_2, i1_2, i2_1], T_softmax_maxelem[i0_2, i1_2]) T.writes(T_softmax_exp[i0_2, i1_2, i2_1]) T_softmax_exp[i0_2, i1_2, i2_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_1] - T_softmax_maxelem[i0_2, i1_2], dtype="float32") for i0_3, i1_3, i2 in T.grid(a, b, c): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_4, i1_4, k = T.axis.remap("SSR", [i0_3, i1_3, i2]) T.reads(T_softmax_exp[i0_4, i1_4, k]) T.writes(T_softmax_expsum[i0_4, i1_4]) @@ -1742,11 +1742,11 @@ def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): T_softmax_expsum[i0_4, i1_4] = T.float32(0) T_softmax_expsum[i0_4, i1_4] = T_softmax_expsum[i0_4, i1_4] + T_softmax_exp[i0_4, i1_4, k] for i0_5, i1_5, i2 in T.grid(a, b, c): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_6, i1_6, i2_2 = T.axis.remap("SSS", [i0_5, i1_5, i2]) T.reads(T_softmax_exp[i0_6, i1_6, i2_2], T_softmax_expsum[i0_6, i1_6]) T.writes(T_softmax_norm[i0_6, i1_6, i2_2]) - T.block_attr({"axis":2}) + T.sblock_attr({"axis":2}) T_softmax_norm[i0_6, i1_6, i2_2] = T_softmax_exp[i0_6, i1_6, i2_2] / T_softmax_expsum[i0_6, i1_6] # fmt: on @@ -1776,7 +1776,7 @@ def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T. T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") compute_1 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1]) @@ -1784,7 +1784,7 @@ def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T. T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e38) T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1]) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): - with T.block("compute"): + with T.sblock("compute"): i0_2, i1_2, i2_2, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_2, i1_2, k, i2_2], T_softmax_maxelem[i0_2, i1_2, i2_2]) T.writes(compute_1[i0_2, i1_2, i2_2]) @@ -1792,11 +1792,11 @@ def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T. compute_1[i0_2, i1_2, i2_2] = T.float32(0) compute_1[i0_2, i1_2, i2_2] = compute_1[i0_2, i1_2, i2_2] + T.exp(rxplaceholder[i0_2, i1_2, k, i2_2] - T_softmax_maxelem[i0_2, i1_2, i2_2], dtype="float32") for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_4, i1_4, i2_4, i3_1 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3]) T.reads(rxplaceholder[i0_4, i1_4, i2_4, i3_1], T_softmax_maxelem[i0_4, i1_4, i3_1], compute_1[i0_4, i1_4, i3_1]) T.writes(compute[i0_4, i1_4, i2_4, i3_1]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) compute[i0_4, i1_4, i2_4, i3_1] = (rxplaceholder[i0_4, i1_4, i2_4, i3_1] - T_softmax_maxelem[i0_4, i1_4, i3_1] - T.log(compute_1[i0_4, i1_4, i3_1], dtype="float32")) # fmt: on @@ -1838,7 +1838,7 @@ def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") compute_1 = T.alloc_buffer([a, b], dtype="float32") for i0, i1, k in T.grid(a, b, c): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(rxplaceholder[v_i0, v_i1, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1]) @@ -1846,7 +1846,7 @@ def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e38) T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], rxplaceholder[v_i0, v_i1, v_k]) for i0, i1, k in T.grid(a, b, c): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) T.reads(rxplaceholder[v_i0, v_i1, v_k], T_softmax_maxelem[v_i0, v_i1]) T.writes(compute_1[v_i0, v_i1]) @@ -1854,11 +1854,11 @@ def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): compute_1[v_i0, v_i1] = T.float32(0) compute_1[v_i0, v_i1] = compute_1[v_i0, v_i1] + T.exp(rxplaceholder[v_i0, v_i1, v_k] - T_softmax_maxelem[v_i0, v_i1], dtype="float32") for i0, i1, i2 in T.grid(a, b, c): - with T.block("compute_1"): + with T.sblock("compute_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1], compute_1[v_i0, v_i1],) T.writes(compute[v_i0, v_i1, v_i2]) - T.block_attr({"axis": 2}) + T.sblock_attr({"axis": 2}) compute[v_i0, v_i1, v_i2] = (rxplaceholder[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1] - T.log(compute_1[v_i0, v_i1], dtype="float32")) # fmt: on @@ -1888,20 +1888,20 @@ def cross_entropy_with_logits(x: T.Buffer((T.int64(3),), "float32"), y: T.Buffer T_multiply_1 = T.alloc_buffer((T.int64(3),)) T_multiply_red = T.alloc_buffer(()) for ax0 in range(T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(x[v_ax0], y[v_ax0]) T.writes(T_multiply_1[v_ax0]) T_multiply_1[v_ax0] = x[v_ax0] * y[v_ax0] for k0 in range(T.int64(3)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_k0 = T.axis.reduce(T.int64(3), k0) T.reads(T_multiply_1[v_k0]) T.writes(T_multiply_red[()]) with T.init(): T_multiply_red[()] = T.float32(0.0) T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[v_k0] - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_multiply[()]) @@ -1935,25 +1935,25 @@ def cross_entropy_with_logits(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply_red = T.alloc_buffer(()) T_multiply_1 = T.alloc_buffer(()) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * y[v_ax0, v_ax1] for k0, k1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) T.reads(T_multiply[v_k0, v_k1]) T.writes(T_multiply_red[()]) with T.init(): T_multiply_red[()] = T.float32(0.0) T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1] - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_multiply_1[()]) T_multiply_1[()] = T_multiply_red[()] * T.float32(-1.0) - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_1[()]) T.writes(T_divide[()]) @@ -1992,25 +1992,25 @@ def cross_entropy_with_logits(var_x: T.handle, var_y: T.handle, T_divide: T.Buff T_multiply_red = T.alloc_buffer(()) T_multiply_1 = T.alloc_buffer(()) for ax0, ax1 in T.grid(n, m): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * y[v_ax0, v_ax1] for k0, k1 in T.grid(n, m): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) T.reads(T_multiply[v_k0, v_k1]) T.writes(T_multiply_red[()]) with T.init(): T_multiply_red[()] = T.float32(0.0) T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1] - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_multiply_1[()]) T_multiply_1[()] = T_multiply_red[()] * T.float32(-1.0) - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_1[()]) T.writes(T_divide[()]) @@ -2043,7 +2043,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),)) T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),)) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() x_red = T.alloc_buffer((T.int64(3),)) @@ -2070,7 +2070,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for k0 in range(T.int64(2)): for k2 in range(T.int64(28)): for k3 in range(T.int64(28)): - with T.block("x_red"): + with T.sblock("x_red"): v_ax0 = T.axis.spatial(T.int64(3), ax0) v_k0 = T.axis.reduce(T.int64(2), k0) v_k2 = T.axis.reduce(T.int64(28), k2) @@ -2081,7 +2081,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov x_red[v_ax0] = T.float32(0.0) x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] for ax0 in range(T.int64(3)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(x_red[v_ax0]) T.writes(T_divide[v_ax0]) @@ -2090,7 +2090,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2102,7 +2102,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2114,7 +2114,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2126,7 +2126,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_subtract_2"): + with T.sblock("T_subtract_2"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2138,7 +2138,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2150,7 +2150,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for k0 in range(T.int64(2)): for k2 in range(T.int64(28)): for k3 in range(T.int64(28)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0 = T.axis.spatial(T.int64(3), ax0) v_k0 = T.axis.reduce(T.int64(2), k0) v_k2 = T.axis.reduce(T.int64(28), k2) @@ -2161,7 +2161,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T_multiply_red[v_ax0] = T.float32(0.0) T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] for ax0 in range(T.int64(3)): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(T_multiply_red[v_ax0]) T.writes(T_divide_1[v_ax0]) @@ -2170,7 +2170,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2182,7 +2182,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2194,7 +2194,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for i1 in range(T.int64(3)): for i2 in range(T.int64(1)): for i3 in range(T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(T.int64(1), i0) v_i1 = T.axis.spatial(T.int64(3), i1) v_i2 = T.axis.spatial(T.int64(1), i2) @@ -2206,7 +2206,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_divide_2"): + with T.sblock("T_divide_2"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2218,7 +2218,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2230,7 +2230,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2242,7 +2242,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2254,7 +2254,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(T.int64(3)): for ax2 in range(T.int64(28)): for ax3 in range(T.int64(28)): - with T.block("T_add_1"): + with T.sblock("T_add_1"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(28), ax2) @@ -2263,37 +2263,37 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] for ax0 in range(T.int64(3)): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(moving_mean[v_ax0]) T.writes(T_multiply_2[v_ax0]) T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] for ax0 in range(T.int64(3)): - with T.block("T_multiply_3"): + with T.sblock("T_multiply_3"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(T_divide[v_ax0]) T.writes(T_multiply_3[v_ax0]) T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] for ax0 in range(T.int64(3)): - with T.block("T_add_2"): + with T.sblock("T_add_2"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) T.writes(T_add_1[v_ax0]) T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] for ax0 in range(T.int64(3)): - with T.block("T_multiply_4"): + with T.sblock("T_multiply_4"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(moving_var[v_ax0]) T.writes(T_multiply_4[v_ax0]) T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] for ax0 in range(T.int64(3)): - with T.block("T_multiply_5"): + with T.sblock("T_multiply_5"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(T_divide_1[v_ax0]) T.writes(T_multiply_5[v_ax0]) T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] for ax0 in range(T.int64(3)): - with T.block("T_add_3"): + with T.sblock("T_add_3"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(T_multiply_4[v_ax0], T_multiply_5[v_ax0]) T.writes(T_add_2[v_ax0]) @@ -2337,7 +2337,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T_add = T.match_buffer(var_T_add, (n, h, w, c)) T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() x_red = T.alloc_buffer((h,)) @@ -2364,7 +2364,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for k0 in range(n): for k2 in range(w): for k3 in range(c): - with T.block("x_red"): + with T.sblock("x_red"): v_ax0 = T.axis.spatial(h, ax0) v_k0 = T.axis.reduce(n, k0) v_k2 = T.axis.reduce(w, k2) @@ -2375,7 +2375,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov x_red[v_ax0] = T.float32(0.0) x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] for ax0 in range(h): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0 = T.axis.spatial(h, ax0) T.reads(x_red[v_ax0]) T.writes(T_divide[v_ax0]) @@ -2384,7 +2384,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2396,7 +2396,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2408,7 +2408,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_subtract_1"): + with T.sblock("T_subtract_1"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2420,7 +2420,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_subtract_2"): + with T.sblock("T_subtract_2"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2432,7 +2432,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2444,7 +2444,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for k0 in range(n): for k2 in range(w): for k3 in range(c): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0 = T.axis.spatial(h, ax0) v_k0 = T.axis.reduce(n, k0) v_k2 = T.axis.reduce(w, k2) @@ -2455,7 +2455,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T_multiply_red[v_ax0] = T.float32(0.0) T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] for ax0 in range(h): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): v_ax0 = T.axis.spatial(h, ax0) T.reads(T_multiply_red[v_ax0]) T.writes(T_divide_1[v_ax0]) @@ -2464,7 +2464,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2476,7 +2476,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2488,7 +2488,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for i1 in range(h): for i2 in range(T.int64(1)): for i3 in range(T.int64(1)): - with T.block("compute"): + with T.sblock("compute"): v_i0 = T.axis.spatial(T.int64(1), i0) v_i1 = T.axis.spatial(h, i1) v_i2 = T.axis.spatial(T.int64(1), i2) @@ -2500,7 +2500,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_divide_2"): + with T.sblock("T_divide_2"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2512,7 +2512,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2524,7 +2524,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_multiply_1"): + with T.sblock("T_multiply_1"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2536,7 +2536,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(T.int64(1)): for ax3 in range(T.int64(1)): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(T.int64(1), ax2) @@ -2548,7 +2548,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov for ax1 in range(h): for ax2 in range(w): for ax3 in range(c): - with T.block("T_add_1"): + with T.sblock("T_add_1"): v_ax0 = T.axis.spatial(n, ax0) v_ax1 = T.axis.spatial(h, ax1) v_ax2 = T.axis.spatial(w, ax2) @@ -2557,37 +2557,37 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] for ax0 in range(c): - with T.block("T_multiply_2"): + with T.sblock("T_multiply_2"): v_ax0 = T.axis.spatial(c, ax0) T.reads(moving_mean[v_ax0]) T.writes(T_multiply_2[v_ax0]) T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] for ax0 in range(h): - with T.block("T_multiply_3"): + with T.sblock("T_multiply_3"): v_ax0 = T.axis.spatial(h, ax0) T.reads(T_divide[v_ax0]) T.writes(T_multiply_3[v_ax0]) T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] for ax0 in range(T.max(c, h)): - with T.block("T_add_2"): + with T.sblock("T_add_2"): v_ax0 = T.axis.spatial(T.max(c, h), ax0) T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) T.writes(T_add_1[v_ax0]) T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] for ax0 in range(c): - with T.block("T_multiply_4"): + with T.sblock("T_multiply_4"): v_ax0 = T.axis.spatial(c, ax0) T.reads(moving_var[v_ax0]) T.writes(T_multiply_4[v_ax0]) T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] for ax0 in range(h): - with T.block("T_multiply_5"): + with T.sblock("T_multiply_5"): v_ax0 = T.axis.spatial(h, ax0) T.reads(T_divide_1[v_ax0]) T.writes(T_multiply_5[v_ax0]) T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] for ax0 in range(T.max(c, h)): - with T.block("T_add_3"): + with T.sblock("T_add_3"): v_ax0 = T.axis.spatial(T.max(c, h), ax0) T.reads(T_multiply_4[v_ax0], T_multiply_5[v_ax0]) T.writes(T_add_2[v_ax0]) @@ -2629,7 +2629,7 @@ def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.in rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1, k2, k3]) T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) @@ -2641,7 +2641,7 @@ def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.in rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3]) T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) @@ -2669,11 +2669,11 @@ class LayerNorm_1D_Expected: @T.prim_func(private=True) def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): x_red_temp_v0 = T.alloc_buffer(()) x_red_temp_v1 = T.alloc_buffer(()) for k0 in range(T.int64(3)): - with T.block("x_red_temp"): + with T.sblock("x_red_temp"): v_k0 = T.axis.reduce(T.int64(3), k0) T.reads(x[v_k0]) T.writes(x_red_temp_v0[()], x_red_temp_v1[()]) @@ -2685,7 +2685,7 @@ def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffe x_red_temp_v0[()] = v_x_red_temp_v0 x_red_temp_v1[()] = v_x_red_temp_v1 for ax0 in range(T.int64(3)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0]) T.writes(T_layer_norm[v_ax0]) @@ -2723,7 +2723,7 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4), T.int64(5)), "float16") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4), T.int64(5)), "float16") T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(3))) @@ -2732,7 +2732,7 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r for ax1 in range(T.int64(3)): for k2 in range(T.int64(4)): for k3 in range(T.int64(5)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_k2 = T.axis.reduce(T.int64(4), k2) @@ -2750,7 +2750,7 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r for ax1 in range(T.int64(3)): for ax2 in range(T.int64(4)): for ax3 in range(T.int64(5)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0 = T.axis.spatial(T.int64(2), ax0) v_ax1 = T.axis.spatial(T.int64(3), ax1) v_ax2 = T.axis.spatial(T.int64(4), ax2) @@ -2803,7 +2803,7 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r rxplaceholder_red_temp_v0 = T.alloc_buffer([n], dtype="float32") rxplaceholder_red_temp_v1 = T.alloc_buffer([n], dtype="float32") for i0, i1, i2 in T.grid(n, s, f): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2]) T.reads(rxplaceholder[ax0, k1, k2]) T.writes(rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0]) @@ -2815,7 +2815,7 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r rxplaceholder_red_temp_v0[ax0] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[ax0] = v_rxplaceholder_red_temp_v1 for i0, i1, i2 in T.grid(n, s, f): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2]) T.writes(T_layer_norm[ax0, ax1, ax2]) @@ -2846,13 +2846,13 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2))) T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) @@ -2864,25 +2864,25 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) T.writes(T_reshape_2[v_ax0, v_ax1]) T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) T.writes(T_reshape_3[v_ax0, v_ax1]) T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_group_norm"): + with T.sblock("T_group_norm"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.float32(40) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -2916,7 +2916,7 @@ def main(x: R.Tensor((2, 4, 4, 5), dtype="float16"), gamma: R.Tensor((4,), dtype @T.prim_func(private=True) def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(4),), "float16"), rxplaceholder_2: T.Buffer((T.int64(4),), "float16"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16") T_cast = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2))) @@ -2925,19 +2925,19 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2)), "float16") T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16") for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_cast"): + with T.sblock("T_cast"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float32", T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) T.reads(T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4]) T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) @@ -2949,25 +2949,25 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) T.writes(T_reshape_2[v_ax0, v_ax1]) T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) T.writes(T_reshape_3[v_ax0, v_ax1]) T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): - with T.block("T_group_norm"): + with T.sblock("T_group_norm"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.float32(40) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) + T.float32(1.0000000000000001e-05))) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -3003,7 +3003,7 @@ def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4) * c,)) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4) * c,)) T_reshape = T.match_buffer(var_T_reshape, (n, T.int64(4) * c, h, w)) - # with T.block("root"): + # with T.sblock("root"): T_reshape_1 = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) rxplaceholder_red_temp_v0 = T.alloc_buffer((n, T.int64(4))) rxplaceholder_red_temp_v1 = T.alloc_buffer((n, T.int64(4))) @@ -3011,13 +3011,13 @@ def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r T_reshape_3 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) T_group_norm = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w]) T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w] for ax0, ax1, k2, k3, k4 in T.grid(n, T.int64(4), c, h, w): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) @@ -3029,25 +3029,25 @@ def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 for ax0, ax1 in T.grid(T.int64(4), c): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) T.writes(T_reshape_2[v_ax0, v_ax1]) T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))] for ax0, ax1 in T.grid(T.int64(4), c): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) T.writes(T_reshape_3[v_ax0, v_ax1]) T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))] for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): - with T.block("T_group_norm"): + with T.sblock("T_group_norm"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_group_norm[(((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -3080,7 +3080,7 @@ class Expected: @T.prim_func(private=True) def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) @@ -3088,19 +3088,19 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_cast"): + with T.sblock("T_cast"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3]) T.writes(T_multiply_red[v_ax0, v_ax1]) @@ -3108,25 +3108,25 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] / T.float32(20) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(B[v_ax0, v_ax1]) T.writes(T_cast_2[v_ax0, v_ax1]) T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -3156,7 +3156,7 @@ class Expected: @T.prim_func(private=True) def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), B: T.Buffer((T.int64(4), T.int64(5)), "float16"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) @@ -3164,19 +3164,19 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_cast"): + with T.sblock("T_cast"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_ax3]) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3]) T.writes(T_multiply_red[v_ax0, v_ax1]) @@ -3184,25 +3184,25 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] / T.float32(20) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(B[v_ax0, v_ax1]) T.writes(T_cast_2[v_ax0, v_ax1]) T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1]) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -3239,7 +3239,7 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): A = T.match_buffer(var_A, (n, s, f)) B = T.match_buffer(var_B, (s, f)) T_cast = T.match_buffer(var_T_cast, (n, s, f)) - # with T.block("root"): + # with T.sblock("root"): T_cast_1 = T.alloc_buffer((n, s, f)) T_multiply = T.alloc_buffer((n, s, f)) T_multiply_red = T.alloc_buffer((n,)) @@ -3247,19 +3247,19 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): T_cast_2 = T.alloc_buffer((s, f)) T_rms_norm = T.alloc_buffer((n, s, f)) for ax0, ax1, ax2 in T.grid(n, s, f): - with T.block("T_cast"): + with T.sblock("T_cast"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v_ax0, v_ax1, v_ax2]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2]) T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(n, s, f): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_cast_1[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2] for ax0, k1, k2 in T.grid(n, s, f): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2]) T.reads(T_multiply[v_ax0, v_k1, v_k2]) T.writes(T_multiply_red[v_ax0]) @@ -3267,25 +3267,25 @@ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): T_multiply_red[v_ax0] = T.float32(0) T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2] for ax0 in range(n): - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0 = T.axis.spatial(n, ax0) T.reads(T_multiply_red[v_ax0]) T.writes(rsqrt[v_ax0]) rsqrt[v_ax0] = T.rsqrt(T_multiply_red[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(s, f): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(B[v_ax0, v_ax1]) T.writes(T_cast_2[v_ax0, v_ax1]) T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2 in T.grid(n, s, f): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rsqrt[v_ax0], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, v_ax2]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(n, s, f): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2]) T.writes(T_cast[v_ax0, v_ax1, v_ax2]) @@ -3318,7 +3318,7 @@ class Expected: @T.prim_func(private=True) def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) @@ -3326,19 +3326,19 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_cast"): + with T.sblock("T_cast"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3]) T.writes(T_multiply_red[v_ax0, v_ax1]) @@ -3346,25 +3346,25 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): - with T.block("rsqrt"): + with T.sblock("rsqrt"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] / T.float32(20) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): - with T.block("T_cast_1"): + with T.sblock("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(B[v_ax0, v_ax1]) T.writes(T_cast_2[v_ax0, v_ax1]) T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_rms_norm"): + with T.sblock("T_rms_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], T_cast_2[v_ax2, v_ax3]) T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_cast_2"): + with T.sblock("T_cast_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -3395,7 +3395,7 @@ class Expected: @T.prim_func(private=True) def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), k: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), v: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), bias: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8))) T_reshape = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) T_transpose_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(8))) @@ -3417,70 +3417,70 @@ def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) T_batch_matmul_NN = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(16))) T_reshape_5 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(16))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): - with T.block("T_transpose"): + with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(q[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q[v_ax0, v_ax2, v_ax1, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(8)): - with T.block("T_transpose_1"): + with T.sblock("T_transpose_1"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(k[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3]) T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k[v_ax0, v_ax2, v_ax1, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)): - with T.block("T_reshape_1"): + with T.sblock("T_reshape_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]) T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2]) T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)] for b, i, j, k_1 in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)): - with T.block("T_batch_matmul_NT"): + with T.sblock("T_batch_matmul_NT"): v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k_1]) T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k]) T.writes(T_batch_matmul_NT[v_b, v_i, v_j]) - T.block_attr({"layout_free_placeholders": [T_reshape_1]}) + T.sblock_attr({"layout_free_placeholders": [T_reshape_1]}) with T.init(): T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0.0) T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] * T.float32(0.10000000000000001) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): - with T.block("T_reshape_2"): + with T.sblock("T_reshape_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)]) T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], bias[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + bias[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("T_reshape_3"): + with T.sblock("T_reshape_3"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]) T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2]) T_reshape_3[v_ax0, v_ax1, v_ax2] = T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)] for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("trilu"): + with T.sblock("trilu"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_reshape_3[v_i0, v_i1, v_i2]) T.writes(trilu[v_i0, v_i1, v_i2]) trilu[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, T_reshape_3[v_i0, v_i1, v_i2], T.float32(0.0)) for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), T.int64(1), T.int64(8)): - with T.block("trilu_red"): + with T.sblock("trilu_red"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(trilu[v_ax0, v_ax1, v_k2]) T.writes(trilu_red[v_ax0, v_ax1, v_ax2]) @@ -3488,25 +3488,25 @@ def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) trilu_red[v_ax0, v_ax1, v_ax2] = T.float32(-340282346638528859811704183484516925440.0) trilu_red[v_ax0, v_ax1, v_ax2] = T.max(trilu_red[v_ax0, v_ax1, v_ax2], trilu[v_ax0, v_ax1, v_k2]) for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(trilu[v_ax0, v_ax1, v_ax2], trilu_red[v_ax0, v_ax1, T.int64(0)]) T.writes(T_subtract[v_ax0, v_ax1, v_ax2]) T_subtract[v_ax0, v_ax1, v_ax2] = trilu[v_ax0, v_ax1, v_ax2] - trilu_red[v_ax0, v_ax1, T.int64(0)] for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_subtract[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) compute[v_i0, v_i1, v_i2] = T.exp(T_subtract[v_i0, v_i1, v_i2]) for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("trilu_1"): + with T.sblock("trilu_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(compute[v_i0, v_i1, v_i2]) T.writes(trilu_1[v_i0, v_i1, v_i2]) trilu_1[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, compute[v_i0, v_i1, v_i2], T.float32(0.0)) for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), T.int64(1), T.int64(8)): - with T.block("trilu_red_1"): + with T.sblock("trilu_red_1"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(trilu_1[v_ax0, v_ax1, v_k2]) T.writes(trilu_red_1[v_ax0, v_ax1, v_ax2]) @@ -3514,40 +3514,40 @@ def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) trilu_red_1[v_ax0, v_ax1, v_ax2] = T.float32(0.0) trilu_red_1[v_ax0, v_ax1, v_ax2] = trilu_red_1[v_ax0, v_ax1, v_ax2] + trilu_1[v_ax0, v_ax1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(trilu_1[v_ax0, v_ax1, v_ax2], trilu_red_1[v_ax0, v_ax1, T.int64(0)]) T.writes(T_divide[v_ax0, v_ax1, v_ax2]) T_divide[v_ax0, v_ax1, v_ax2] = trilu_1[v_ax0, v_ax1, v_ax2] / trilu_red_1[v_ax0, v_ax1, T.int64(0)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(16)): - with T.block("T_transpose_2"): + with T.sblock("T_transpose_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(v[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3]) T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v[v_ax0, v_ax2, v_ax1, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)): - with T.block("T_reshape_4"): + with T.sblock("T_reshape_4"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]) T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2]) T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)] for b, i, j, k_1 in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)): - with T.block("T_batch_matmul_NN"): + with T.sblock("T_batch_matmul_NN"): v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k_1]) T.reads(T_divide[v_b, v_i, v_k], T_reshape_4[v_b, v_k, v_j]) T.writes(T_batch_matmul_NN[v_b, v_i, v_j]) - T.block_attr({"layout_free_placeholders": [T_reshape_4]}) + T.sblock_attr({"layout_free_placeholders": [T_reshape_4]}) with T.init(): T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0.0) T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_divide[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(16)): - with T.block("T_reshape_5"): + with T.sblock("T_reshape_5"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)]) T.writes(T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(16), T.int64(32), T.int64(16)): - with T.block("T_transpose_3"): + with T.sblock("T_transpose_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -3621,19 +3621,19 @@ def nll_loss( # function attr dict T.func_attr({"tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") nll_loss = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32") nll_loss_red = T.alloc_buffer([], dtype="float32") nll_loss_1 = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32") nll_loss_red_1 = T.alloc_buffer([], dtype="float32") for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss"): + with T.sblock("nll_loss"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(targets[v_ax0, v_ax1, v_ax2], predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss[v_ax0, v_ax1, v_ax2]) nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - predictions[v_ax0, targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss_red"): + with T.sblock("nll_loss_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(nll_loss[v_k0, v_k1, v_k2]) T.writes(nll_loss_red[()]) @@ -3641,20 +3641,20 @@ def nll_loss( nll_loss_red[()] = T.float32(0) nll_loss_red[()] = nll_loss_red[()] + nll_loss[v_k0, v_k1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss_1"): + with T.sblock("nll_loss_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(targets[v_ax0, v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2]) nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, v_ax1, v_ax2] != T.int64(-1), weights[targets[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss_red_1"): + with T.sblock("nll_loss_red_1"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(nll_loss_1[v_k0, v_k1, v_k2]) T.writes(nll_loss_red_1[()]) with T.init(): nll_loss_red_1[()] = T.float32(0) nll_loss_red_1[()] = nll_loss_red_1[()] + nll_loss_1[v_k0, v_k1, v_k2] - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(nll_loss_red[()], nll_loss_red_1[()]) T.writes(output[()]) @@ -3686,26 +3686,26 @@ def nll_loss_without_weight(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.i # function attr dict T.func_attr({"tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") T_full = T.alloc_buffer([T.int64(3)], dtype="float32") nll_loss = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32") nll_loss_red = T.alloc_buffer([], dtype="float32") nll_loss_1 = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32") nll_loss_red_1 = T.alloc_buffer([], dtype="float32") for ax0 in T.serial(T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads() T.writes(T_full[v_ax0]) T_full[v_ax0] = T.float32(1) for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss"): + with T.sblock("nll_loss"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss[v_ax0, v_ax1, v_ax2]) nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss_red"): + with T.sblock("nll_loss_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(nll_loss[v_k0, v_k1, v_k2]) T.writes(nll_loss_red[()]) @@ -3713,20 +3713,20 @@ def nll_loss_without_weight(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.i nll_loss_red[()] = T.float32(0) nll_loss_red[()] = nll_loss_red[()] + nll_loss[v_k0, v_k1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss_1"): + with T.sblock("nll_loss_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]]) T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2]) nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), T_full[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0)) for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("nll_loss_red_1"): + with T.sblock("nll_loss_red_1"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(nll_loss_1[v_k0, v_k1, v_k2]) T.writes(nll_loss_red_1[()]) with T.init(): nll_loss_red_1[()] = T.float32(0) nll_loss_red_1[()] = nll_loss_red_1[()] + nll_loss_1[v_k0, v_k1, v_k2] - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(nll_loss_red[()], nll_loss_red_1[()]) T.writes(T_divide[()]) @@ -3760,20 +3760,20 @@ def nll_loss(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((), "int64"), C = T.int64() rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (C,)) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_1, (C,)) - # with T.block("root"): + # with T.sblock("root"): nll_loss = T.alloc_buffer(()) nll_loss_1 = T.alloc_buffer(()) - with T.block("nll_loss"): + with T.sblock("nll_loss"): vi = T.axis.spatial(T.int64(1), T.int64(0)) T.reads(rxplaceholder[()], rxplaceholder_1[rxplaceholder[()]], rxplaceholder_2[rxplaceholder[()]]) T.writes(nll_loss[()]) nll_loss[()] = T.Select(rxplaceholder[()] != T.int64(1), (T.float32(0) - rxplaceholder_1[rxplaceholder[()]]) * rxplaceholder_2[rxplaceholder[()]], T.float32(0)) - with T.block("nll_loss_1"): + with T.sblock("nll_loss_1"): vi = T.axis.spatial(T.int64(1), T.int64(0)) T.reads(rxplaceholder[()], rxplaceholder_2[rxplaceholder[()]]) T.writes(nll_loss_1[()]) nll_loss_1[()] = T.Select(rxplaceholder[()] != T.int64(1), rxplaceholder_2[rxplaceholder[()]], T.float32(0)) - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(nll_loss[()], nll_loss_1[()]) T.writes(T_divide[()]) @@ -3813,19 +3813,19 @@ def nll_loss(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxp rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N, d1, d2], dtype="int64") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [C], dtype="float32") # body - # with T.block("root") + # with T.sblock("root") nll_loss = T.alloc_buffer([N, d1, d2], dtype="float32") nll_loss_red = T.alloc_buffer([], dtype="float32") nll_loss_1 = T.alloc_buffer([N, d1, d2], dtype="float32") nll_loss_red_1 = T.alloc_buffer([], dtype="float32") for ax0, ax1, ax2 in T.grid(N, d1, d2): - with T.block("nll_loss"): + with T.sblock("nll_loss"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2],rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]],) T.writes(nll_loss[v_ax0, v_ax1, v_ax2]) nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0),) for k0, k1, k2 in T.grid(N, d1, d2): - with T.block("nll_loss_red"): + with T.sblock("nll_loss_red"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(nll_loss[v_k0, v_k1, v_k2]) T.writes(nll_loss_red[()]) @@ -3833,20 +3833,20 @@ def nll_loss(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxp nll_loss_red[()] = T.float32(0) nll_loss_red[()] = nll_loss_red[()] + nll_loss[v_k0, v_k1, v_k2] for ax0, ax1, ax2 in T.grid(N, d1, d2): - with T.block("nll_loss_1"): + with T.sblock("nll_loss_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]],) T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2]) nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0),) for k0, k1, k2 in T.grid(N, d1, d2): - with T.block("nll_loss_red_1"): + with T.sblock("nll_loss_red_1"): v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2]) T.reads(nll_loss_1[v_k0, v_k1, v_k2]) T.writes(nll_loss_red_1[()]) with T.init(): nll_loss_red_1[()] = T.float32(0) nll_loss_red_1[()] = nll_loss_red_1[()] + nll_loss_1[v_k0, v_k1, v_k2] - with T.block("T_divide"): + with T.sblock("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(nll_loss_red[()], nll_loss_red_1[()]) T.writes(T_divide[()]) @@ -3879,9 +3879,9 @@ def pad( PadInput: T.Buffer((T.int64(2), T.int64(130), T.int64(30)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2 in T.grid(T.int64(2), T.int64(130), T.int64(30)): - with T.block("PadInput"): + with T.sblock("PadInput"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)]) T.writes(PadInput[v_i0, v_i1, v_i2]) diff --git a/tests/python/relax/test_transform_legalize_ops_qdq.py b/tests/python/relax/test_transform_legalize_ops_qdq.py index 09706c637ef7..fdcfa8de80ee 100644 --- a/tests/python/relax/test_transform_legalize_ops_qdq.py +++ b/tests/python/relax/test_transform_legalize_ops_qdq.py @@ -43,9 +43,9 @@ def quantize( quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("quantized"): + with T.sblock("quantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i0], C[v_i0]) T.writes(quantized[v_i0, v_i1]) @@ -97,9 +97,9 @@ def quantize( quantized: T.Buffer((T.int64(2), T.int64(4)), "uint8"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("quantized"): + with T.sblock("quantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i0], C[v_i0]) T.writes(quantized[v_i0, v_i1]) @@ -151,9 +151,9 @@ def quantize(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_quantized: T B = T.match_buffer(var_B, (n,)) C = T.match_buffer(var_C, (n,), "int8") quantized = T.match_buffer(var_quantized, (T.int64(4), n), "int8") - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(4), n): - with T.block("quantized"): + with T.sblock("quantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i1], C[v_i1]) T.writes(quantized[v_i0, v_i1]) @@ -202,9 +202,9 @@ def quantize( quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("quantized"): + with T.sblock("quantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(quantized[v_i0, v_i1]) @@ -252,9 +252,9 @@ def quantize( quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("quantized"): + with T.sblock("quantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], B[v_i0], C[v_i0]) T.writes(quantized[v_i0, v_i1]) @@ -301,9 +301,9 @@ def quantize( quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("quantized"): + with T.sblock("quantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(quantized[v_i0, v_i1]) @@ -349,9 +349,9 @@ def dequantize( dequantized: T.Buffer((T.int64(2), T.int64(4)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("dequantized"): + with T.sblock("dequantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], C[v_i0], B[v_i0]) T.writes(dequantized[v_i0, v_i1]) @@ -393,9 +393,9 @@ def dequantize( dequantized: T.Buffer((T.int64(2), T.int64(4)), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("dequantized"): + with T.sblock("dequantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(dequantized[v_i0, v_i1]) @@ -437,9 +437,9 @@ def dequantize( B = T.match_buffer(var_B, (n,)) C = T.match_buffer(var_C, (n,), "int8") dequantized = T.match_buffer(var_dequantized, (T.int64(2), n)) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), n): - with T.block("dequantized"): + with T.sblock("dequantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], C[v_i1], B[v_i1]) T.writes(dequantized[v_i0, v_i1]) @@ -486,9 +486,9 @@ def dequantize( dequantized: T.Buffer((T.int64(2), T.int64(4)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("dequantized"): + with T.sblock("dequantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1], C[v_i0], B[v_i0]) T.writes(dequantized[v_i0, v_i1]) @@ -540,9 +540,9 @@ def dequantize( dequantized: T.Buffer((T.int64(2), T.int64(4)), "float16"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): - with T.block("dequantized"): + with T.sblock("dequantized"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(A[v_i0, v_i1]) T.writes(dequantized[v_i0, v_i1]) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index b28451da1b18..9ddefd96daf7 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -45,7 +45,7 @@ def main(condition: R.Tensor((3, 2, 1), "bool"), x: R.Tensor((2, 3), "float32"), def where(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(1)), "bool"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(1)), "float32"), T_where: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_where"): + with T.sblock("T_where"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) T.writes(T_where[ax0, ax1, ax2]) @@ -89,7 +89,7 @@ def where(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplac rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [b, T.int64(1)], dtype="float32") T_where = T.match_buffer(var_T_where, [a, b, c], dtype="float32") for i0, i1, i2 in T.grid(a, b, c): - with T.block("T_where"): + with T.sblock("T_where"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) T.writes(T_where[ax0, ax1, ax2]) @@ -122,7 +122,7 @@ def argmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64( rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5)), "int64") rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) for ax0, ax1, ax2, k1 in T.grid(T.int64(2), T.int64(4), T.int64(5), T.int64(3)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0, v_ax1, v_ax2, v_k1 = T.axis.remap("SSSR", [ax0, ax1, ax2, k1]) T.reads(rxplaceholder[v_ax0, v_k1, v_ax1, v_ax2]) T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2]) @@ -134,7 +134,7 @@ def argmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64( rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_rxplaceholder_red_temp_v1 for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2]) @@ -176,11 +176,11 @@ def argmax(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c, d)) rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, (a, T.int64(1), c, d), "int64") - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red_temp_v0 = T.alloc_buffer((a, T.int64(1), c, d), "int64") rxplaceholder_red_temp_v1 = T.alloc_buffer((a, T.int64(1), c, d)) for ax0, ax1, ax2, ax3, k1 in T.grid(a, T.int64(1), c, d, b): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0, v_ax1, v_ax2, v_ax3, v_k1 = T.axis.remap("SSSSR", [ax0, ax1, ax2, ax3, k1]) T.reads(rxplaceholder[v_ax0, v_k1, v_ax2, v_ax3]) T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -192,7 +192,7 @@ def argmax(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(a, T.int64(1), c, d): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -220,7 +220,7 @@ def argmin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64( rxplaceholder_red_temp_v0 = T.alloc_buffer((), "int64") rxplaceholder_red_temp_v1 = T.alloc_buffer(()) for k0, k1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_k0, v_k1, v_k2, v_k3 = T.axis.remap("RRRR", [k0, k1, k2, k3]) T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3]) T.writes(rxplaceholder_red_temp_v0[()], rxplaceholder_red_temp_v1[()]) @@ -231,7 +231,7 @@ def argmin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64( v_rxplaceholder_red_temp_v1: T.float32 = T.Select(rxplaceholder_red_temp_v1[()] < rxplaceholder[v_k0, v_k1, v_k2, v_k3], rxplaceholder_red_temp_v1[()], rxplaceholder[v_k0, v_k1, v_k2, v_k3]) rxplaceholder_red_temp_v0[()] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[()] = v_rxplaceholder_red_temp_v1 - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): vi = T.axis.spatial(1, T.int64(0)) T.reads(rxplaceholder_red_temp_v0[()]) T.writes(rxplaceholder_red[()]) @@ -269,7 +269,7 @@ def argmin(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "int64") rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) for ax0, ax1, ax2, ax3, k0, k1, k2, k3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): - with T.block("rxplaceholder_red_temp"): + with T.sblock("rxplaceholder_red_temp"): v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k1, v_k2, v_k3 = T.axis.remap("SSSSRRRR", [ax0, ax1, ax2, ax3, k0, k1, k2, k3]) T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3]) T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -281,7 +281,7 @@ def argmin(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v0 rxplaceholder_red_temp_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_rxplaceholder_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red_temp_v0[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -320,7 +320,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 5), "float32"): def max(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(5), T.int64(3), T.int64(4)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, k1, k2, ax1]) T.writes(rxplaceholder_red[ax0, ax1]) @@ -363,7 +363,7 @@ def max(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, d, b, c): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, k1, k2, ax1]) T.writes(rxplaceholder_red[ax0, ax1]) @@ -396,7 +396,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 1, 1, 5), "float3 def min(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(5)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(5), T.int64(3), T.int64(4)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[ax0, k1, k2, ax3]) T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -439,7 +439,7 @@ def min(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, T.int64(1), T.int64(1), d], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), T.int64(1), d, b, c): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[ax0, k1, k2, ax3]) T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -472,7 +472,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): def sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[k0, k1, k2, k3]) T.writes(rxplaceholder_red[()]) @@ -510,7 +510,7 @@ def sum(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((), "float32")) d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[k0, k1, k2, k3]) T.writes(rxplaceholder_red[()]) @@ -543,7 +543,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 1, 1, 1), "float3 def prod(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) T.reads(rxplaceholder[k0, k1, k2, k3]) T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -581,7 +581,7 @@ def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) T.reads(rxplaceholder[k0, k1, k2, k3]) T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -615,7 +615,7 @@ def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5) T.func_attr({"tir.noalias": True}) rxplaceholder_red = T.alloc_buffer([T.int64(3), T.int64(4)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(3), T.int64(4), T.int64(2), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, k0, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[k0, ax0, ax1, k3]) T.writes(rxplaceholder_red[ax0, ax1]) @@ -623,7 +623,7 @@ def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5) rxplaceholder_red[ax0, ax1] = T.float32(0) rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax0, ax1, k3] for i0, i1 in T.grid(T.int64(3), T.int64(4)): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder_red[ax0, ax1]) T.writes(T_divide[ax0, ax1]) @@ -665,7 +665,7 @@ def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle): T_divide = T.match_buffer(var_T_divide, [b, c], dtype="float32") rxplaceholder_red = T.alloc_buffer([b, c], dtype="float32") for i0, i1, i2, i3 in T.grid(b, c, a, d): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, k0, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[k0, ax0, ax1, k3]) T.writes(rxplaceholder_red[ax0, ax1]) @@ -673,7 +673,7 @@ def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle): rxplaceholder_red[ax0, ax1] = T.float32(0) rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax0, ax1, k3] for i0, i1 in T.grid(b, c): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder_red[ax0, ax1]) T.writes(T_divide[ax0, ax1]) @@ -704,19 +704,19 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, def median(var_x: T.handle, T_squeeze: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "float32"), T_squeeze_1: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "int64")): T.func_attr({"tir.noalias": True}) data_buf = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), align=8) - # with T.block("root"): + # with T.sblock("root"): T_full = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64") out_buf = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "int64", align=8) T_gather = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_gather_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5))) T_gather_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64") for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads() T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) T_full[v_ax0, v_ax1, v_ax2, v_ax3] = 0 - with T.block("argsort_cpu"): + with T.sblock("argsort_cpu"): T.reads() T.writes() T.call_packed("tvm.contrib.sort.argsort", T.tvm_stack_make_array(data_buf.data, @@ -727,31 +727,31 @@ def median(var_x: T.handle, T_squeeze: T.Buffer((T.int64(3), T.int64(4), T.int64 0, 4, T.int64(0), T.int64(0)), 0, T.bool(True)) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_gather"): + with T.sblock("T_gather"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], out_buf[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_gather[v_ax0, v_ax1, v_ax2, v_ax3]) T_gather[v_ax0, v_ax1, v_ax2, v_ax3] = data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_gather_1"): + with T.sblock("T_gather_1"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_squeeze"): + with T.sblock("T_squeeze"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2]) T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) T_squeeze[v_ax0, v_ax1, v_ax2] = T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_gather_2"): + with T.sblock("T_gather_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3]) T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3] = out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_squeeze_1"): + with T.sblock("T_squeeze_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_gather_2[T.int64(0), v_ax0, v_ax1, v_ax2]) T.writes(T_squeeze_1[v_ax0, v_ax1, v_ax2]) @@ -776,7 +776,7 @@ class Expected: @T.prim_func(private=True) def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), compute: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) T_divide = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) @@ -784,7 +784,7 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)) T_multiply_red = T.alloc_buffer(()) T_divide_1 = T.alloc_buffer(()) for ax0, ax1, ax2, ax3, k0, k1, k2, k3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k1, v_k2, v_k3 = T.axis.remap("SSSSRRRR", [ax0, ax1, ax2, ax3, k0, k1, k2, k3]) T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -792,37 +792,37 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)) rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder[v_k0, v_k1, v_k2, v_k3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(120.0) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_divide[T.int64(0), T.int64(0), T.int64(0), T.int64(0)]) T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_divide[T.int64(0), T.int64(0), T.int64(0), T.int64(0)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] for k0, k1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_k0, v_k1, v_k2, v_k3 = T.axis.remap("RRRR", [k0, k1, k2, k3]) T.reads(T_multiply[v_k0, v_k1, v_k2, v_k3]) T.writes(T_multiply_red[()]) with T.init(): T_multiply_red[()] = T.float32(0) T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1, v_k2, v_k3] - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_divide_1[()]) T_divide_1[()] = T_multiply_red[()] / T.float32(120.0) - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_divide_1[()]) T.writes(compute[()]) @@ -855,7 +855,7 @@ def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) a, b, c, d = T.int64(), T.int64(), T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c, d)) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) T_divide = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) T_subtract = T.alloc_buffer((a, b, c, d)) @@ -863,7 +863,7 @@ def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): T_multiply_red = T.alloc_buffer(()) T_divide_1 = T.alloc_buffer(()) for ax0, ax1, ax2, ax3, k0, k1, k2, k3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k1, v_k2, v_k3 = T.axis.remap("SSSSRRRR", [ax0, ax1, ax2, ax3, k0, k1, k2, k3]) T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -871,37 +871,37 @@ def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder[v_k0, v_k1, v_k2, v_k3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", a * b * c * d) for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_divide[T.int64(0), T.int64(0), T.int64(0), T.int64(0)]) T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_divide[T.int64(0), T.int64(0), T.int64(0), T.int64(0)] for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] for k0, k1, k2, k3 in T.grid(a, b, c, d): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_k0, v_k1, v_k2, v_k3 = T.axis.remap("RRRR", [k0, k1, k2, k3]) T.reads(T_multiply[v_k0, v_k1, v_k2, v_k3]) T.writes(T_multiply_red[()]) with T.init(): T_multiply_red[()] = T.float32(0) T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1, v_k2, v_k3] - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_divide_1[()]) T_divide_1[()] = T_multiply_red[()] / T.Cast("float32", a * b * c * d) - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_divide_1[()]) T.writes(compute[()]) @@ -947,7 +947,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 T_multiply = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") T_multiply_red = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[k0, ax1, ax2, k3]) T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -955,25 +955,25 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(0) rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[k0, ax1, ax2, k3] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) T.writes(T_divide_1[ax0, ax1, ax2, ax3]) T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] / T.float32(10.0) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide_1[T.int64(0), ax1, ax2, T.int64(0)]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide_1[T.int64(0), ax1, ax2, T.int64(0)] for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3]) T.writes(T_multiply[ax0, ax1, ax2, ax3]) T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(T_multiply[k0, ax1, ax2, k3]) T.writes(T_multiply_red[ax0, ax1, ax2, ax3]) @@ -981,7 +981,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 T_multiply_red[ax0, ax1, ax2, ax3] = T.float32(0) T_multiply_red[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[k0, ax1, ax2, k3] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) T.writes(T_divide[ax0, ax1, ax2, ax3]) @@ -1027,7 +1027,7 @@ def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): T_multiply = T.alloc_buffer([a, b, c, d], dtype="float32") T_multiply_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), b, c, T.int64(1), a, d): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[k0, ax1, ax2, k3]) T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -1035,25 +1035,25 @@ def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(0) rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[k0, ax1, ax2, k3] for i0, i1, i2, i3 in T.grid(T.int64(1), b, c, T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) T.writes(T_divide_1[ax0, ax1, ax2, ax3]) T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] / T.Cast("float32", a * d) for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_subtract"): + with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide_1[T.int64(0), ax1, ax2, T.int64(0)]) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide_1[T.int64(0), ax1, ax2, T.int64(0)] for i0, i1, i2, i3 in T.grid(a, b, c, d): - with T.block("T_multiply"): + with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract[ax0, ax1, ax2, ax3]) T.writes(T_multiply[ax0, ax1, ax2, ax3]) T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), b, c, T.int64(1), a, d): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(T_multiply[k0, ax1, ax2, k3]) T.writes(T_multiply_red[ax0, ax1, ax2, ax3]) @@ -1061,7 +1061,7 @@ def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): T_multiply_red[ax0, ax1, ax2, ax3] = T.float32(0) T_multiply_red[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[k0, ax1, ax2, k3] for i0, i1, i2, i3 in T.grid(T.int64(1), b, c, T.int64(1)): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) T.writes(T_divide[ax0, ax1, ax2, ax3]) @@ -1086,14 +1086,14 @@ class Expected: @T.prim_func(private=True) def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(3), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1))) T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1))) T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply_red = T.alloc_buffer((T.int64(3), T.int64(4))) for ax0, ax1, ax2, ax3, k0, k3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): - with T.block("rxplaceholder_red"): + with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k3 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, k0, k3]) T.reads(rxplaceholder[v_k0, v_ax1, v_ax2, v_k3]) T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) @@ -1101,25 +1101,25 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder[v_k0, v_ax1, v_ax2, v_k3] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): - with T.block("T_divide"): + with T.sblock("T_divide"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3]) T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(10) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_subtract"): + with T.sblock("T_subtract"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_divide_1[T.int64(0), v_ax1, v_ax2, T.int64(0)]) T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_divide_1[T.int64(0), v_ax1, v_ax2, T.int64(0)] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.block("T_multiply"): + with T.sblock("T_multiply"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, k0, k3 in T.grid(T.int64(3), T.int64(4), T.int64(2), T.int64(5)): - with T.block("T_multiply_red"): + with T.sblock("T_multiply_red"): v_ax0, v_ax1, v_k0, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k0, k3]) T.reads(T_multiply[v_k0, v_ax0, v_ax1, v_k3]) T.writes(T_multiply_red[v_ax0, v_ax1]) @@ -1127,7 +1127,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 T_multiply_red[v_ax0, v_ax1] = T.float32(0) T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_k0, v_ax0, v_ax1, v_k3] for ax0, ax1 in T.grid(T.int64(3), T.int64(4)): - with T.block("T_divide_1"): + with T.sblock("T_divide_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(T_divide[v_ax0, v_ax1]) diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 066282ae15b1..c187f1d97661 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -35,7 +35,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ) -> None: for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) out[o, i, h, w] = w1[i, o, h, w] @@ -102,7 +102,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -175,7 +175,7 @@ def transform_layout_IOHW_to_OIHW( w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") ): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): - with T.block("layout_transform"): + with T.sblock("layout_transform"): o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w1[i, o, h, w]) T.writes(out[o, i, h, w]) @@ -1440,7 +1440,7 @@ def zeros(var_T_full: T.handle): n = T.int64() T_full = T.match_buffer(var_T_full, (n, n)) for ax0, ax1 in T.grid(n, n): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads() T.writes(T_full[v_ax0, v_ax1]) @@ -1465,9 +1465,9 @@ def zeros(var_T_full: T.handle): T.func_attr({"tir.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (n, n)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(n, n): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads() T.writes(T_full[v_ax0, v_ax1]) @@ -1532,7 +1532,7 @@ def slice( ): T.func_attr({"tir.noalias": True}) for j in range(16): - with T.block("T_full"): + with T.sblock("T_full"): vj = T.axis.remap("S", [j]) Output_Slice[vj] = Input_2d[slice_index, vj] @@ -1586,7 +1586,7 @@ def slice( ): T.func_attr({"tir.noalias": True}) for j in range(16): - with T.block("T_full"): + with T.sblock("T_full"): vj = T.axis.remap("S", [j]) Output_Slice[vj] = Input_2d[slice_index, vj] diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index 1282e1374f56..78ac386fd7d3 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -1142,7 +1142,7 @@ def relu( ): T.func_attr({"tir.noalias": True}) for i in range(T.int64(10)): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) Output[vi] = T.max(Input[vi], T.float32(0)) @@ -1194,7 +1194,7 @@ def relu( ): T.func_attr({"tir.noalias": True}) for i in range(T.int64(10)): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) Output[vi] = T.max(Input[vi], T.float32(0)) diff --git a/tests/python/relax/test_transform_meta_schedule_apply_database.py b/tests/python/relax/test_transform_meta_schedule_apply_database.py index 129901b97035..07a96eef2ea8 100644 --- a/tests/python/relax/test_transform_meta_schedule_apply_database.py +++ b/tests/python/relax/test_transform_meta_schedule_apply_database.py @@ -32,7 +32,7 @@ class RecordModule: def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.serial(2): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial(2, i) B[vi] = A[vi] @@ -42,7 +42,7 @@ class BlockRenamedModule: def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.serial(2): - with T.block("renamed_block"): + with T.sblock("renamed_block"): vi = T.axis.spatial(2, i) B[vi] = A[vi] @@ -58,13 +58,13 @@ def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): } ) for i in T.serial(2): - with T.block("renamed_block"): + with T.sblock("renamed_block"): vi = T.axis.spatial(2, i) B[vi] = A[vi] def create_trace(mod: tvm.IRModule): sch = tir.Schedule(mod) - _ = sch.get_block("block") + _ = sch.get_sblock("block") return sch.trace db = ms.database.create(kind="memory") diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index 3d290c0ae8c6..7e7e43462acb 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -42,7 +42,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (32, 32)) for i0, j0, k0 in T.grid(32, 32, 32): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): C[i, j] = 0.0 @@ -54,7 +54,7 @@ def tir_relu(x: T.handle, y: T.handle): A = T.match_buffer(x, (32, 32)) B = T.match_buffer(y, (32, 32)) for i, j in T.grid(32, 32): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.max(A[vi, vj], 0.0) @@ -157,11 +157,11 @@ def tir_matmul( C: T.Buffer((32, 32), "float32"), ): T.func_attr({"global_symbol": "tir_matmul", "tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_j0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_j0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): for k0 in range(32): - with T.block(""): + with T.sblock(""): i = T.axis.spatial(32, (i0_j0_fused_0 * 1024 + i0_j0_fused_1) // 32) j = T.axis.spatial(32, (i0_j0_fused_0 * 1024 + i0_j0_fused_1) % 32) k = T.axis.reduce(32, k0) @@ -174,10 +174,10 @@ def tir_matmul( @T.prim_func def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")): T.func_attr({"global_symbol": "tir_relu", "tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): vi = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32) vj = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32) T.reads(A[vi, vj]) diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 5dced084ebab..5afa9a34be68 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -46,7 +46,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4)) i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4)) compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32") @@ -87,10 +87,10 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) # body - # with T.block("root") + # with T.sblock("root") for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4)) i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4)) T.reads(rxplaceholder[i0, i1]) @@ -157,10 +157,10 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) # body - # with T.block("root") + # with T.sblock("root") for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4)) i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4)) T.reads(rxplaceholder[i0, i1]) @@ -198,10 +198,10 @@ class Expected: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): T.func_attr({"global_symbol": "exp", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4)) i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4)) T.reads(rxplaceholder[i0, i1]) @@ -266,7 +266,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4)) i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4)) compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32") @@ -299,10 +299,10 @@ class Expected: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): T.func_attr({"global_symbol": "exp", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4)) i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4)) T.reads(rxplaceholder[i0, i1]) @@ -404,7 +404,7 @@ def fused_conv2d_relu( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): pad_temp = T.alloc_buffer( (T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16" ) @@ -412,7 +412,7 @@ def fused_conv2d_relu( (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16" ) for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), T.int64(34), T.int64(16)): - with T.block("pad_temp"): + with T.sblock("pad_temp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3]) @@ -433,7 +433,7 @@ def fused_conv2d_relu( T.int64(3), T.int64(16), ): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap( "SSSSRRR", [nn, yy, xx, ff, ry, rx, rc] ) @@ -450,7 +450,7 @@ def fused_conv2d_relu( * weight1[v_ff, v_ry, v_rx, v_rc] ) for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2, v_i3]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -466,11 +466,11 @@ def layer_norm( T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), ): T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): A_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32))) A_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32))) for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, ax1, ax2, k3]) T.reads(A[v_ax0, v_ax1, v_ax2, v_k3]) T.writes(A_red_temp_v0[v_ax0, v_ax1, v_ax2], A_red_temp_v1[v_ax0, v_ax1, v_ax2]) @@ -486,7 +486,7 @@ def layer_norm( A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0 A_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( A[v_ax0, v_ax1, v_ax2, v_ax3], @@ -766,7 +766,7 @@ def add_one(x_handle: T.handle, y_handle: T.handle): x = T.match_buffer(x_handle, (m,), "float32") y = T.match_buffer(y_handle, (m,), "float32") for i in range(m): - with T.block("add"): + with T.sblock("add"): vi = T.axis.remap("S", [i]) y[vi] = x[vi] + T.float32(1) @@ -801,9 +801,9 @@ def add_one(x_handle: T.handle, y_handle: T.handle): m = T.int64() x = T.match_buffer(x_handle, (m,)) y = T.match_buffer(y_handle, (m,)) - # with T.block("root"): + # with T.sblock("root"): for i in range(m): - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(m, i) T.reads(x[vi]) T.writes(y[vi]) diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index 677ba41a209b..3bc551a0b209 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -30,7 +30,7 @@ def reshape( T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), ): for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( rxplaceholder[ @@ -54,7 +54,7 @@ def expand_dims( for i0, i1, i2, i3, i4 in T.grid( T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) ): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(rxplaceholder[i0_1, i2_1, i4_1]) T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) @@ -81,7 +81,7 @@ def reshape( T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), ): for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( rxplaceholder[ @@ -105,7 +105,7 @@ def expand_dims( for i0, i1, i2, i3, i4 in T.grid( T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) ): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(rxplaceholder[i0_1, i2_1, i4_1]) T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) @@ -140,7 +140,7 @@ def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), " for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(2), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720)) v_ax1 = T.axis.spatial(T.int64(4096), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320)) v_ax2 = T.axis.spatial(T.int64(5), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64)) @@ -160,7 +160,7 @@ def expand_dims( for i0, i1, i2, i3, i4, i5 in T.grid( T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64) ): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1]) T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) @@ -183,9 +183,9 @@ def main( class Expected: @T.prim_func def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32"), expand_dims_1: T.Buffer((T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)), "float32")): - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1), T.int64(4096), T.int64(1), T.int64(5), T.int64(64)): - with T.block("expand_dims"): + with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[i0_1, i2_1, i4_1, i5_1]) T.writes(expand_dims_1[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) @@ -193,11 +193,11 @@ def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), @T.prim_func def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4096), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4096), T.int64(5), T.int64(64)), "float32")): - # with T.block("root"): + # with T.sblock("root"): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(10)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(2), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) // T.int64(1310720)) v_ax1 = T.axis.spatial(T.int64(4096), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(1310720) // T.int64(320)) v_ax2 = T.axis.spatial(T.int64(5), (ax0_ax1_ax2_ax3_fused_0 * T.int64(262144) + ax0_ax1_ax2_ax3_fused_1 * T.int64(1024) + ax0_ax1_ax2_ax3_fused_2) % T.int64(320) // T.int64(64)) @@ -230,10 +230,10 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): n = T.int32() A = T.match_buffer(var_A, (n, 16, 128), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16") - # with T.block("root"): + # with T.sblock("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v0 = T.axis.spatial( n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048 ) @@ -272,10 +272,10 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): n = T.int32() A = T.match_buffer(var_A, (n, 16, 128), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16") - # with T.block("root"): + # with T.sblock("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 2, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v0 = T.axis.spatial( n, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) // 2048 ) @@ -319,7 +319,7 @@ def reshape( T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), ): for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads( rxplaceholder[ @@ -358,9 +358,9 @@ def fused_reshape5( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( lv2_0[ @@ -419,9 +419,9 @@ def fused_reshape5( ), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads( lv2_0[ @@ -482,7 +482,7 @@ def strided_slice( ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)): - with T.block("T_strided_slice"): + with T.sblock("T_strided_slice"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_strided_slice[v_ax0, v_ax1]) @@ -494,7 +494,7 @@ def add_one( T_add_one: T.buffer((T.int64(1), T.int64(1000)), "int32"), ): for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)): - with T.block("T_add_one"): + with T.sblock("T_add_one"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v_ax0, v_ax1]) T.writes(T_add_one[v_ax0, v_ax1]) @@ -555,9 +555,9 @@ def add( T_add: T.Buffer((T.int64(1),), "float32"), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(T.int64(1)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(T.int64(1), ax0) T.reads(A[v_ax0], B[v_ax0]) T.writes(T_add[v_ax0]) @@ -566,9 +566,9 @@ def add( @T.prim_func(private=True) def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")): T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0 in range(T.int64(1)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): v_ax0 = T.axis.spatial(T.int64(1), ax0) T.reads(A[()]) T.writes(T_reshape[v_ax0]) @@ -621,7 +621,7 @@ def add( T.func_attr({"tir.noalias": True}) for iters in T.grid(T.int64(64), T.int64(4)): - with T.block("T_add"): + with T.sblock("T_add"): i, j = T.axis.remap("SS", iters) z[i, j] = y1[i, j] + y2[i, j] @@ -686,7 +686,7 @@ def add( # T.func_attr({"tir.noalias": True}) # for iters in T.grid(T.int64(64), T.int64(4)): -# with T.block("T_add"): +# with T.sblock("T_add"): # i, j = T.axis.remap("SS", iters) # z[i, j] = y1[i, j] + y2[i, j] @@ -750,7 +750,7 @@ def add( T.func_attr({"tir.noalias": True}) for iters in T.grid(N * 4, T.int64(4)): - with T.block("T_add"): + with T.sblock("T_add"): i, j = T.axis.remap("SS", iters) z[i, j] = y1[i, j] + y2[i, j] diff --git a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py index d92570025fce..d2b3c00955e8 100644 --- a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py +++ b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py @@ -97,11 +97,11 @@ def max_pool2d_opencl( (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) ): - with T.block("pool_max"): + with T.sblock("pool_max"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] ) @@ -115,7 +115,7 @@ def max_pool2d_opencl( ] ) T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + T.sblock_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( -340282346638528859811704183484516925440.0 @@ -138,9 +138,9 @@ def te_layout_transform( (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) T.reads(x[v_self, v_i0, v_i1, v_i2]) T.writes( @@ -161,11 +161,11 @@ def te_layout_transform2( (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" ), ): - # with T.block("root"): + # with T.sblock("root"): for self, i0, i1, i2, i3 in T.grid( T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) ): - with T.block("te_layout_transform"): + with T.sblock("te_layout_transform"): v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py index 5d45db524177..7c7b877a46c2 100644 --- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -34,12 +34,12 @@ def tir_func( T.func_attr({"layout_free_buffers": [1]}) W_rewrite = T.alloc_buffer((4, 4, 56, 56)) for i, j in T.grid(224, 224): - with T.block("W_rewrite"): + with T.sblock("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) vj = T.axis.spatial(224, j0 * 56 + j1) Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] @@ -67,7 +67,7 @@ def tir_func_prepacked( Out: T.Buffer((224, 224), "float32"), ): for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) vj = T.axis.spatial(224, j0 * 56 + j1) Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] @@ -78,7 +78,7 @@ def tir_func_weight_prepack( W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), ): for i, j in T.grid(224, 224): - with T.block("W_rewrite"): + with T.sblock("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] @@ -117,17 +117,17 @@ def tir_func( W1_rewrite = T.alloc_buffer((4, 4, 56, 56)) W2_rewrite = T.alloc_buffer((4, 4, 56, 56)) for i, j in T.grid(224, 224): - with T.block("W1_rewrite"): + with T.sblock("W1_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj] for i, j in T.grid(224, 224): - with T.block("W2_rewrite"): + with T.sblock("W2_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj] for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) vj = T.axis.spatial(224, j0 * 56 + j1) Out[vi, vj] = ( @@ -161,7 +161,7 @@ def tir_func_prepacked( Out: T.Buffer((224, 224), "float32"), ): for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) vj = T.axis.spatial(224, j0 * 56 + j1) Out[vi, vj] = ( @@ -178,11 +178,11 @@ def tir_func_weight_prepack( W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"), ): for i, j in T.grid(224, 224): - with T.block("W1_rewrite"): + with T.sblock("W1_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj] for i, j in T.grid(224, 224): - with T.block("W2_rewrite"): + with T.sblock("W2_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj] @@ -228,12 +228,12 @@ def tir_func( T.func_attr({"layout_free_buffers": [1], "tir.noalias": True}) W_rewrite = T.alloc_buffer((4, 4, 56, 56)) for i, j in T.grid(224, 224): - with T.block("W_rewrite"): + with T.sblock("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) vj = T.axis.spatial(224, j0 * 56 + j1) Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] @@ -262,7 +262,7 @@ def tir_func_prepacked( ): T.func_attr({"tir.noalias": True}) for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): - with T.block("Out"): + with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) vj = T.axis.spatial(224, j0 * 56 + j1) Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] @@ -274,7 +274,7 @@ def tir_func_weight_prepack( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(224, 224): - with T.block("W_rewrite"): + with T.sblock("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index 2a23890d7f62..d9a015e73935 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -1054,7 +1054,7 @@ def tir_identity( Output: T.Buffer(64, "float16"), ): for i in range(64): - with T.block("copy"): + with T.sblock("copy"): vi = T.axis.remap("S", [i]) Output[vi] = Input[vi] diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4ba7a92dc754..e01f214d8b18 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -126,7 +126,7 @@ class TestWellCallTIR: def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: T.func_attr(({"global_symbol": "tir_addone"})) for i, j in T.grid(16, 16): - with T.block("tir_addone"): + with T.sblock("tir_addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.int32(1) @@ -198,7 +198,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -227,7 +227,7 @@ def plus_one( ): T.func_attr({"some_attr": "foo", "another_attr": True, "tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -289,7 +289,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -339,7 +339,7 @@ def tir_func( ): T.func_attr({"tir.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) y[vi, vj] = x[vi, vj] + 1.0 @@ -980,7 +980,7 @@ def copy(var_x: T.handle, var_y: T.handle, n: T.int64): X = T.match_buffer(var_x, (n * 2,), dtype="float32") Y = T.match_buffer(var_y, (n * 2,), dtype="float32") for i in T.grid(n * 2): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) Y[vi] = X[vi] @@ -996,7 +996,7 @@ def identity_tir(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [54, 96]) for i, j in T.grid(54, 96): - with T.block("compute"): + with T.sblock("compute"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -1027,7 +1027,7 @@ def copy( # copies the contents of B into A and out1 T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(B[ax0, ax1]) T.writes(A[ax0, ax1], out1[ax0, ax1]) @@ -1080,7 +1080,7 @@ def copy( # copies the contents of B into A and out1 T.func_attr({"tir.noalias": True}) for iters in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): i, j = T.axis.remap("SS", iters) A[i, j] = B[i, j] out1[i, j] = B[i, j] @@ -1131,7 +1131,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block(): + with T.sblock(): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 diff --git a/tests/python/relax/test_tvmscript_pyfunc.py b/tests/python/relax/test_tvmscript_pyfunc.py index 7b3c4052fa93..2901d86308c1 100644 --- a/tests/python/relax/test_tvmscript_pyfunc.py +++ b/tests/python/relax/test_tvmscript_pyfunc.py @@ -65,7 +65,7 @@ def simple_tir_func( B = T.match_buffer(var_B, (n,), "float32") for i in T.grid(n): - with T.block("copy"): + with T.sblock("copy"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] diff --git a/tests/python/relax/test_vm_alloc_storage_with_scope.py b/tests/python/relax/test_vm_alloc_storage_with_scope.py index 3839ae123406..d9355cd6648e 100644 --- a/tests/python/relax/test_vm_alloc_storage_with_scope.py +++ b/tests/python/relax/test_vm_alloc_storage_with_scope.py @@ -37,7 +37,7 @@ def add( T.func_attr({"operator_name": "relax.add"}) for ax0 in range(2): for ax1 in range(2): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0 = T.axis.spatial(2, ax0) v_ax1 = T.axis.spatial(2, ax1) T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index efd2f7ecbf59..8d54dcab1cd2 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -196,7 +196,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -237,7 +237,7 @@ def copy( # copies the contents of C into A, B, and out1 T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_zeros"): + with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(C[ax0, ax1]) T.writes(A[ax0, ax1], B[ax0, ax1], out1[ax0, ax1]) @@ -291,7 +291,7 @@ def inplace_add(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")): # sums A and B, storing the result in A T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(A[ax0, ax1], B[ax0, ax1]) T.writes(A[ax0, ax1]) @@ -483,7 +483,7 @@ def test_vm_emit_te_constant_param_gpu(exec_mode): mod = bb.get() sch = tvm.tir.Schedule(mod, debug_mask="all") - loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) + loops = sch.get_loops(sch.get_sblock(name="T_add", func_name="add")) sch.bind(loops[0], "threadIdx.x") exec = relax.build(sch.mod, "cuda", exec_mode=exec_mode) @@ -746,7 +746,7 @@ def main(x: R.Tensor((2, 3), dtype="float32")): @T.prim_func def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): for i0, i1 in T.grid(2, 3): - with T.block("block"): + with T.sblock("block"): vi0, vi1 = T.axis.remap("SS", [i0, i1]) B[vi0, vi1] = A[vi0, vi1] @@ -773,7 +773,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C = T.match_buffer(z, (m, k)) for i, j, k in T.grid(m, k, n): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -948,7 +948,7 @@ def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): C = T.match_buffer(z, (m, n)) for i, j in T.grid(m, n): - with T.block("mul"): + with T.sblock("mul"): vi = T.axis.spatial(m, i) vj = T.axis.spatial(n, j) with T.init(): diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 9633244c67fb..a359a8ae86a2 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -366,7 +366,7 @@ class TestKillObject: def full(T_full: T.Buffer((T.int64(4),), "float32")): T.func_attr({"global_symbol": "full", "tir.noalias": True}) for ax0 in range(T.int64(4)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0 = T.axis.spatial(T.int64(4), ax0) T.reads() T.writes(T_full[v_ax0]) @@ -376,7 +376,7 @@ def full(T_full: T.Buffer((T.int64(4),), "float32")): def full1(T_full: T.Buffer((T.int64(4),), "float32")): T.func_attr({"global_symbol": "full1", "tir.noalias": True}) for ax0 in range(T.int64(4)): - with T.block("T_full"): + with T.sblock("T_full"): v_ax0 = T.axis.spatial(T.int64(4), ax0) T.reads() T.writes(T_full[v_ax0]) diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index d04fd6bdab1b..4fff42499bee 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -51,10 +51,10 @@ def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl @T.prim_func def add(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): T.func_attr({"global_symbol": "add"}) - with T.block("root"): + with T.sblock("root"): for i in T.thread_binding(16, thread="threadIdx.x"): for j in range(16): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.float32(1) diff --git a/tests/python/relax/texture/test_texture_nd.py b/tests/python/relax/texture/test_texture_nd.py index 520f309c6e41..ab2a7bcd27ec 100644 --- a/tests/python/relax/texture/test_texture_nd.py +++ b/tests/python/relax/texture/test_texture_nd.py @@ -120,12 +120,12 @@ class TextureCopy: def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)): T.func_attr({"global_symbol": "main"}) for li, lj in T.grid(M, N): - with T.block("Copy"): + with T.sblock("Copy"): i, j = T.axis.remap("SS", [li, lj]) B[i, j] = A[i, j] def schedule_texture_read(sch: tir.Schedule): - B_blk = sch.get_block("Copy") + B_blk = sch.get_sblock("Copy") Ai_block = sch.cache_read(B_blk, 0, "global.texture") sch.transform_layout(Ai_block, ("write", 0), lambda i, j: (i, j // lanes, j % lanes)) diff --git a/tests/python/runtime/test_evaluator_with_preproc.py b/tests/python/runtime/test_evaluator_with_preproc.py index 208d584e99a5..5570e87620f0 100644 --- a/tests/python/runtime/test_evaluator_with_preproc.py +++ b/tests/python/runtime/test_evaluator_with_preproc.py @@ -29,7 +29,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -41,7 +41,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def test_time_evalutor_with_preproc(f_preproc: str): mod = tvm.IRModule.from_expr(matmul.with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) - blk = sch.get_block("matmul") + blk = sch.get_sblock("matmul") i, j, k = sch.get_loops(blk) sch.bind(i, "blockIdx.x") sch.bind(j, "threadIdx.x") diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 627ebbb7d62c..95c25ae734ed 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -316,7 +316,7 @@ def check_remote_link_cl(remote): s = tvm.tir.Schedule(mod) - x = s.get_loops(s.get_block("B")) + x = s.get_loops(s.get_sblock("B")) xo, xi = s.split(x, factors=[None, 32]) s.bind(xo, "blockIdx.x") s.bind(xi, "threadIdx.x") diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 426272584bb5..09ebe020fefc 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -29,8 +29,8 @@ def test_unique_name_complete_block(): C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main") func = te.create_prim_func([A, C]) s = tir.Schedule(func, debug_mask="all") - assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef) - assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_sblock("main")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_sblock("main_1")), tir.schedule.StmtSRef) def test_unique_name_reduction_block(): @@ -41,8 +41,8 @@ def test_unique_name_reduction_block(): C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum") func = te.create_prim_func([A, C]) s = tir.Schedule(func, debug_mask="all") - assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef) - assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_sblock("sum")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_sblock("sum_1")), tir.schedule.StmtSRef) def _check_workload(te_workload, tir_workload, index_dtype_override=None, do_simplify=False): @@ -73,7 +73,7 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i0, j0, k0 in T.grid(128, 128, 128): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): C[i, j] = 0.0 @@ -88,7 +88,7 @@ def tir_matmul_int64( ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, j0, k0 in T.grid(T.int64(128), T.int64(128), T.int64(128)): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): C[i, j] = 0.0 @@ -118,11 +118,11 @@ def tir_element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i0, j0 in T.grid(128, 128): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i0, j0]) B[i, j] = A[i, j] * 2.0 for i0, j0 in T.grid(128, 128): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i0, j0]) C[i, j] = B[i, j] + 1.0 @@ -171,7 +171,7 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: Apad = T.alloc_buffer([16, 16, 16, 16]) for n, c, y, x in T.grid(16, 16, 16, 16): - with T.block("Apad"): + with T.sblock("Apad"): nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x]) Apad[nn, cc, yy, xx] = T.if_then_else( 1 <= yy and yy < 15 and 1 <= xx and xx < 15, @@ -180,7 +180,7 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: dtype="float32", ) for n, f, y, x, kc, ky, kx in T.grid(16, 32, 14, 14, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [n, f, y, x, kc, ky, kx]) with T.init(): B[nn, ff, yy, xx] = 0.0 @@ -211,10 +211,10 @@ def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> B1 = T.match_buffer(b1, (m, n)) for i0, i1 in T.grid(m, n): - with T.block("B.v0"): + with T.sblock("B.v0"): i, j = T.axis.remap("SS", [i0, i1]) B0[i, j] = A0[i, j] + 2.0 - with T.block("B.v1"): + with T.sblock("B.v1"): i, j = T.axis.remap("SS", [i0, i1]) B1[i, j] = A1[i, j] * 3.0 @@ -247,7 +247,7 @@ def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128), elem_offset=off2) C = T.match_buffer(c, (128, 128), elem_offset=off3) # body - with T.block("C"): + with T.sblock("C"): T.reads() T.writes() T.evaluate( @@ -307,7 +307,7 @@ def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i0, j0, k0 in T.grid(128, 128, 128): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): C[i, j] = 0.0 @@ -407,14 +407,14 @@ def expected_layout_attr( T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) C = T.alloc_buffer([128, 128], dtype="float32") for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): x, y, k = T.axis.remap("SSR", [i0, i1, i2]) with T.init(): C[x, y] = T.float32(0) C[x, y] = C[x, y] + A[x, k] * B[y, k] for i0, i1 in T.grid(128, 128): - with T.block("D"): - T.block_attr({"layout_free_placeholders": [C]}) + with T.sblock("D"): + T.sblock_attr({"layout_free_placeholders": [C]}) x, y = T.axis.remap("SS", [i0, i1]) D[x, y] = C[x, y] + T.float32(1) @@ -428,7 +428,7 @@ def expected_layout_attr_int64( T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) C = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32") for x, y, k in T.grid(T.int64(128), T.int64(128), T.int64(128)): - with T.block("C"): + with T.sblock("C"): v_x, v_y, v_k = T.axis.remap("SSR", [x, y, k]) T.reads(A[v_x, v_k], B[v_y, v_k]) T.writes(C[v_x, v_y]) @@ -436,8 +436,8 @@ def expected_layout_attr_int64( C[v_x, v_y] = T.float32(0) C[v_x, v_y] = C[v_x, v_y] + A[v_x, v_k] * B[v_y, v_k] for x, y in T.grid(T.int64(128), T.int64(128)): - with T.block("D"): - T.block_attr({"layout_free_placeholders": [C]}) + with T.sblock("D"): + T.sblock_attr({"layout_free_placeholders": [C]}) v_x, v_y = T.axis.remap("SS", [x, y]) T.reads(C[v_x, v_y]) T.writes(D[v_x, v_y]) @@ -502,7 +502,7 @@ def tir_argmax_idx_val( argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32") argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="float32") for i0, i1 in T.grid(m, n): - with T.block("argmax"): + with T.sblock("argmax"): i, k = T.axis.remap("SR", [i0, i1]) T.reads(val[i, k], idx[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) @@ -549,7 +549,7 @@ def tir_argmax_val_idx( argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32") argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32") for i0, i1 in T.grid(m, n): - with T.block("argmax"): + with T.sblock("argmax"): i, k = T.axis.remap("SR", [i0, i1]) T.reads(val[i, k], idx[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) @@ -595,10 +595,10 @@ def expected( c: T.Buffer((), "int32"), ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block("c"): + with T.sblock("c"): vi = T.axis.spatial(1, 0) T.reads(a[()], b[()]) T.writes(c[()]) @@ -621,7 +621,7 @@ def tir_reshape( ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, i1 in T.grid(T.int64(4), T.int64(2)): - with T.block("T_reshape"): + with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads( A[ @@ -666,7 +666,7 @@ def tir_resize2d_symbolic( ow = T.int64() resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[v_i0, v_i1, T.int64(0) : T.int64(128), T.int64(0) : T.int64(128)]) T.writes(resize[v_i0, v_i1, v_i2, v_i3]) @@ -728,7 +728,7 @@ def tir_extern(var_A: T.handle, var_B: T.handle, var_P: T.handle, var_C: T.handl B = T.match_buffer(var_B, [128, 128], dtype="float32", offset_factor=1) P = T.match_buffer(var_P, [1], dtype="float32", offset_factor=1) C = T.match_buffer(var_C, [128, 128], dtype="float32", offset_factor=1) - with T.block("C"): + with T.sblock("C"): T.reads() T.writes() T.call_extern("myfunc", A.data, B.data, C.data, P[0], dtype="") @@ -751,9 +751,9 @@ def tir_slice_with_var_input(var_tensor: T.handle, idx: T.int64, var_slice: T.ha m, n = T.int64(), T.int64() tensor = T.match_buffer(var_tensor, (m, n)) slice = T.match_buffer(var_slice, (idx, n)) - # with T.block("root"): + # with T.sblock("root"): for i, j in T.grid(idx, n): - with T.block("slice"): + with T.sblock("slice"): v_i = T.axis.spatial(idx, i) v_j = T.axis.spatial(n, j) T.reads(tensor[v_i, v_j]) @@ -775,7 +775,7 @@ def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): b = T.match_buffer(var_b, (5,)) sum_red = T.match_buffer(var_sum_red, (5,)) for i, ax in T.grid(5, 5): - with T.block("sum_red"): + with T.sblock("sum_red"): v_i, v_ax = T.axis.remap("SR", [i, ax]) T.reads(b[v_i], a[v_i, v_ax]) T.writes(sum_red[v_i]) @@ -810,7 +810,7 @@ def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): b = T.match_buffer(var_b, (5,)) sum_red = T.match_buffer(var_sum_red, (5,)) for i, ax in T.grid(5, 5): - with T.block("sum_red"): + with T.sblock("sum_red"): v_i = T.axis.spatial(5, i) v_ax = T.axis.reduce(5, ax) T.reads(a[v_i, 0:5]) @@ -848,12 +848,12 @@ def tir_workload( # fmt: off adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30)) for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): - with T.block("adaptive_pool_sum_1"): + with T.sblock("adaptive_pool_sum_1"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) for rv0, rv1 in T.grid((v_ax2 % 3 * 4 + 16) // 12 + 1, (v_ax3 % 3 * 10 + 40) // 30 + 1): - with T.block("adaptive_pool_sum"): + with T.sblock("adaptive_pool_sum"): v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0) v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1) v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2) @@ -865,11 +865,11 @@ def tir_workload( adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0) adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1] for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): - with T.block("adaptive_pool_avg"): + with T.sblock("adaptive_pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) - T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) + T.sblock_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1)) # fmt: on @@ -897,12 +897,12 @@ def tir_workload( ): T.func_attr({"tir.noalias": True, "global_symbol": "main"}) for i0, i1, i2 in T.grid(8, 8, 8): - with T.block("compute_2"): + with T.sblock("compute_2"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(x[v_i0, v_i1, v_i2, 0:v_i1, 0 : v_i1 - 1]) T.writes(compute[v_i0, v_i1, v_i2]) for rv in range(v_i1): - with T.block("compute_1"): + with T.sblock("compute_1"): v_i0_1 = T.axis.spatial((v_i0, v_i0 + 1), v_i0) v_i1_1 = T.axis.spatial((v_i1, v_i1 + 1), v_i1) v_i2_1 = T.axis.spatial((v_i2, v_i2 + 1), v_i2) @@ -912,7 +912,7 @@ def tir_workload( with T.init(): compute[v_i0_1, v_i1_1, v_i2_1] = T.float32(0.0) for rv_1 in range(v_rv): - with T.block("compute"): + with T.sblock("compute"): v_i0_2 = T.axis.spatial((v_i0_1, v_i0_1 + 1), v_i0_1) v_i1_2 = T.axis.spatial((v_i1_1, v_i1_1 + 1), v_i1_1) v_i2_2 = T.axis.spatial((v_i2_1, v_i2_1 + 1), v_i2_1) diff --git a/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py b/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py index cb3a663c0379..1e45a92f9a01 100644 --- a/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py +++ b/tests/python/tir-analysis/test_tir_analysis_calculate_allocated_memory.py @@ -29,7 +29,7 @@ class Module: @T.prim_func def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")): for i in T.serial(128): - with T.block("C"): + with T.sblock("C"): c[i] = a[i] * T.int8(2) @@ -37,10 +37,10 @@ def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")): def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")): B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm") for i in T.serial(128): - with T.block("B"): + with T.sblock("B"): B[i] = a[i] * T.int8(2) for i in T.serial(128): - with T.block("C"): + with T.sblock("C"): c[i] = B[i] * T.int8(3) # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks @@ -54,7 +54,7 @@ def test_scale_by(primFunc, size): """Test calculate allocated bytes per scope""" mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main")) sch = tir.Schedule(mod, debug_mask="all") - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") (flat,) = sch.get_loops(block_c) cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm") sch.compute_at(cache_block, flat) @@ -78,21 +78,21 @@ def matmul_mix_scope(a: T.handle, b: T.handle, c: T.handle) -> None: C_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global") for i, j in T.grid(128, 128): - with T.block("A.allocated"): + with T.sblock("A.allocated"): A_allocated[i, j] = A[i, j] for i, j in T.grid(128, 128): - with T.block("B.allocated"): + with T.sblock("B.allocated"): B_allocated[i, j] = B[i, j] for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C_allocated[vi, vj] = 0.0 C_allocated[vi, vj] = C[vi, vj] + A_allocated[vi, vk] * B_allocated[vj, vk] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): C[i, j] = C_allocated[i, j] @@ -114,7 +114,7 @@ def test_matmul_mix_scope(scope, size): def test_full_mod_calculator(): def apply_schedule(sch, func_name): sch.work_on(func_name) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") sch.cache_read(block_c, 0, storage_scope="global.vtcm") sch = tvm.tir.Schedule(Module, debug_mask="all") diff --git a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py index b3ce7efd0593..9a06e610f285 100644 --- a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py @@ -26,11 +26,11 @@ def buffer_load_store_func(a: T.handle, b: T.handle) -> None: C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") for ii, jj in T.grid(128, 128): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [ii, jj]) A[i, j] = T.float32(0) for i0, j0, k0 in T.grid(32, 32, 32): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): for ii, jj in T.grid(4, 4): @@ -49,7 +49,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [16, 16], "float32") C = T.match_buffer(c, [16, 16], "float32") - with T.block(): + with T.sblock(): T.reads([]) T.writes(B[0:16, 0:16]) A = T.decl_buffer([256], "float32") @@ -62,7 +62,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")) for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] @@ -78,14 +78,14 @@ def match_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") for i, j in T.grid(8, 8): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) for ii, jj in T.grid(16, 16): - with T.block("AAA"): + with T.sblock("AAA"): vii, vjj = T.axis.remap("SS", [ii, jj]) AA = T.match_buffer(A[vii, vjj], ()) AA[()] = 1.0 @@ -99,7 +99,7 @@ def global_buffer_with_blockidx( ) -> None: for i0 in T.thread_binding(0, 1, thread="blockIdx.x"): for i1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("copy"): + with T.sblock("copy"): i, j = T.axis.remap("SS", [i0, i1]) T.reads(a[i, j]) T.writes(b[i, j]) diff --git a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py index 1fa013399e12..bc027a5da201 100644 --- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py +++ b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py @@ -29,14 +29,14 @@ def func() -> None: B = T.alloc_buffer((128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block(): + with T.sblock(): # Need add read/write region manually to avoid triggering block access region detector T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) T.writes([A[0:12, 0:12]]) for i, j in T.grid(8, 8): A[i, j] = B[0, 0] + C[0, 0] for i, j in T.grid(2, 2): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) @@ -47,14 +47,14 @@ def func() -> None: @T.prim_func def match_buffer_func() -> None: - with T.block("root"): + with T.sblock("root"): A = T.alloc_buffer((128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector for i, j in T.grid(8, 8): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -64,7 +64,7 @@ def match_buffer_func() -> None: B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8) ) for ii, jj in T.grid(16, 16): - with T.block("AAA"): + with T.sblock("AAA"): vii, vjj = T.axis.remap("SS", [ii, jj]) T.reads([]) T.writes(AA[vii, vjj]) @@ -76,18 +76,18 @@ def match_buffer_func() -> None: @T.prim_func def opaque_block_func() -> None: - with T.block("root"): + with T.sblock("root"): A = T.alloc_buffer((16, 16), "float32") B = T.alloc_buffer((16, 16), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes([B[i, 0:16]]) for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -98,7 +98,7 @@ def opaque_access_func() -> None: A = T.alloc_buffer([1024]) B = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block(): + with T.sblock(): v = T.axis.S(8, i) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([B[v * 128 : v * 128 + 128]]) @@ -112,7 +112,7 @@ def opaque_access_with_tvm_access_ptr_func() -> None: A = T.alloc_buffer([1024]) B = T.alloc_buffer([1024]) C = T.alloc_buffer([1024]) - with T.block("opaque"): + with T.sblock("opaque"): T.reads(A[0:1024], C[0:1024]) T.writes(B[0:1024], C[0:1024]) T.evaluate(A.access_ptr("r")) @@ -124,7 +124,7 @@ def opaque_access_with_tvm_access_ptr_func() -> None: def access_in_if_then_else_func() -> None: A = T.alloc_buffer([8]) B = T.alloc_buffer([8]) - with T.block(): + with T.sblock(): T.reads([A[0:5]]) T.writes([B[0:8]]) for i in T.serial(0, 8): @@ -135,7 +135,7 @@ def access_in_if_then_else_func() -> None: def access_in_branch_func() -> None: A = T.alloc_buffer([8]) B = T.alloc_buffer([8]) - with T.block(): + with T.sblock(): T.reads([A[0:7]]) T.writes([B[0:8]]) for i in T.serial(0, 8): @@ -151,7 +151,7 @@ def gemm() -> None: B = T.alloc_buffer([16, 16], "float32") C = T.alloc_buffer([16, 16], "float32") for i, j, k, ii, jj in T.grid(4, 4, 16, 4, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) vk = T.axis.R(16, k) @@ -169,14 +169,14 @@ def decomposed_gemm() -> None: C = T.alloc_buffer([16, 16], "float32") for i, j in T.grid(4, 4): for ii, jj in T.grid(4, 4): - with T.block("init"): + with T.sblock("init"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) T.reads([]) T.writes(C[vi, vj]) C[vi, vj] = 0 for k, ii, jj in T.grid(16, 4, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) vk = T.axis.R(16, k) @@ -191,14 +191,14 @@ def access_of_padding_pattern() -> None: X_pad = T.alloc_buffer([32, 32]) Y = T.alloc_buffer([28, 28]) for i, j in T.grid(32, 32): - with T.block("padding"): + with T.sblock("padding"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([X[vi - 2, vj - 2]]) T.writes([X_pad[vi, vj]]) X_pad[vi, vj] = T.if_then_else( 2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32" ) - with T.block("padding_reverse"): + with T.sblock("padding_reverse"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([X_pad[vi, vj]]) T.writes([Y[vi - 2, vj - 2]]) @@ -210,7 +210,7 @@ def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.reads, ret[0]) tvm.ir.assert_structural_equal(block.writes, ret[1]) @@ -225,12 +225,12 @@ def test_opaque_block(): buffer_var_map = {buf.data: buf for buf in alloc_buffers} block0 = opaque_block_func.body.block.body.body.block - ret = tir.analysis.get_block_access_region(block0, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block0, buffer_var_map) tvm.ir.assert_structural_equal(block0.reads, ret[0]) tvm.ir.assert_structural_equal(block0.writes, ret[1]) block1 = block0.body.body.block - ret = tir.analysis.get_block_access_region(block1, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block1, buffer_var_map) tvm.ir.assert_structural_equal(block1.reads, ret[0]) tvm.ir.assert_structural_equal(block1.writes, ret[1]) @@ -240,8 +240,8 @@ def test_opaque_access(): alloc_buffers = opaque_access_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) - ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map) with pytest.raises(ValueError): tvm.ir.assert_structural_equal(ret0[0], ret1[0]) with pytest.raises(ValueError): @@ -253,8 +253,8 @@ def test_opaque_access_with_tvm_access_ptr(): alloc_buffers = opaque_access_with_tvm_access_ptr_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) - ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.reads, ret0[0]) tvm.ir.assert_structural_equal(block.writes, ret0[1]) with pytest.raises(ValueError): @@ -271,13 +271,13 @@ def test_match_buffer(): buffer_var_map = {buf.data: buf for buf in alloc_buffers} # Check block - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.writes, ret[1]) # B is opaque access tvm.ir.assert_structural_equal(block.reads, ret[2]) # Check inner block AAA without updating buffer_var_map - ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block_inner, buffer_var_map) # Since AA is not in the buffer_var_map, region of AA will not be collected. tvm.ir.assert_structural_equal([], ret[1]) @@ -286,7 +286,7 @@ def test_match_buffer(): target_buffer = match_buffer.buffer buffer_var_map[target_buffer.data] = target_buffer - ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block_inner, buffer_var_map) tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) @@ -295,8 +295,8 @@ def test_access_in_if_then_else_func(): block = access_in_if_then_else_func.body.block.body.block alloc_buffers = access_in_if_then_else_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) - ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(ret0[0], ret1[0]) tvm.ir.assert_structural_equal(ret0[1], ret1[1]) @@ -305,15 +305,15 @@ def test_access_in_branch_func(): block = access_in_branch_func.body.block.body.block alloc_buffers = access_in_branch_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) - ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(ret0[0], ret1[0]) tvm.ir.assert_structural_equal(ret0[1], ret1[1]) def test_access_of_padding_pattern(): s = tvm.tir.schedule.Schedule(access_of_padding_pattern) - alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers + alloc_buffers = s.get_sref(s.get_sblock("root")).stmt.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} def do_compare_buffer_region(region, expect): @@ -324,10 +324,10 @@ def do_compare_buffer_region(region, expect): analyzer.can_prove_equal(observed_range.extent, expected_range.extent) def do_check_block(block_name): - block = s.get_sref(s.get_block(block_name)).stmt + block = s.get_sref(s.get_sblock(block_name)).stmt expect_reads = block.reads expect_writes = block.writes - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) for i, read in enumerate(ret[0]): do_compare_buffer_region(read, expect_reads[i]) for i, write in enumerate(ret[1]): @@ -341,7 +341,7 @@ def test_access_of_reduction(): block = gemm.body.block.body.body.body.body.body.body.block alloc_buffers = gemm.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.reads, ret[0]) tvm.ir.assert_structural_equal(block.writes, ret[1]) @@ -352,7 +352,7 @@ def test_access_of_decompose_reduction(): alloc_buffers = decomposed_gemm.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} for block in [init, update]: - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.reads, ret[0]) tvm.ir.assert_structural_equal(block.writes, ret[1]) @@ -366,7 +366,7 @@ def func( output: T.Buffer((16, 16), "float32"), ): for i, s in T.grid(16, 16): - with T.block("copy"): + with T.sblock("copy"): vi, vs = T.axis.remap("SS", [i, s]) T.reads( seq_slot_ids[vi], @@ -380,7 +380,7 @@ def func( block = func.body.block.body.body.body.block buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()} - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.reads, ret[0]) tvm.ir.assert_structural_equal(block.writes, ret[1]) @@ -393,7 +393,7 @@ def func( C: T.Buffer((16, 16), "float32"), ): for i, s in T.grid(16, 16): - with T.block("copy"): + with T.sblock("copy"): vi, vs = T.axis.remap("SS", [i, s]) T.reads(A[vi, vs], B[vi, vs]) T.writes(C[vi, vs]) @@ -406,7 +406,7 @@ def func( block = func.body.block.body.body.body.block buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()} - ret = tir.analysis.get_block_access_region(block, buffer_var_map) + ret = tir.analysis.get_sblock_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.reads, ret[0]) tvm.ir.assert_structural_equal(block.writes, ret[1]) diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index f6e1d2eade24..1297cec61fe4 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -30,11 +30,11 @@ def element_wise( ): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): # It's a opaque block , so it can use outside variables C[i, j] = B[i, j] * 2.0 @@ -49,7 +49,7 @@ def element_wise( B: T.Buffer((128, 128), "float32"), ): for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) # we cannot use `i` since it's defined outside the block B[vi, vj] = A[i, vj] * 2.0 @@ -280,7 +280,7 @@ class mod: @T.prim_func def func(A: T.Buffer([256, 256], "float32")): for iters in T.grid(16, 16, 16, 16): - with T.block("compute"): + with T.sblock("compute"): tile_i, tile_j, i, j = T.axis.remap("SSSS", iters) B = T.match_buffer( A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16], @@ -299,7 +299,7 @@ class mod: @T.prim_func def func(A: T.Buffer([256, 256], "int32")): for iters in T.grid(16, 16, 16, 16): - with T.block("compute"): + with T.sblock("compute"): tile_i, tile_j, i, j = T.axis.remap("SSSS", iters) elem_offset = T.int32() diff --git a/tests/python/tir-base/test_slice_tir.py b/tests/python/tir-base/test_slice_tir.py index fea2ce480e48..6fb6ced458fc 100644 --- a/tests/python/tir-base/test_slice_tir.py +++ b/tests/python/tir-base/test_slice_tir.py @@ -125,7 +125,7 @@ class TestAnnotateAndSliceTIR(tvm.testing.CompareBeforeAfter): # def main(A: T.Buffer((1,), "int8"): # #A = T.match_buffer(a, (1,), "int8") # A[0] = 0 - # with T.block("block_foo"): # optional: give this block a name, perhaps for testing? + # with T.sblock("block_foo"): # optional: give this block a name, perhaps for testing? # # NOTE: nice to have: human control over name used for the generated callee # T.annotate("extract_as_subroutine", "add_one") # A[0] += 1 @@ -137,7 +137,7 @@ class TestAnnotateAndSliceTIR(tvm.testing.CompareBeforeAfter): # def main(): # A = T.buffer[[1], "int8"] # A[0] = 0 - # with T.block("block_foo"): + # with T.sblock("block_foo"): # call_tir(add_one, A) # # @T.prim_func @@ -160,7 +160,7 @@ class TestLowerCallTir(tvm.testing.CompareBeforeAfter): # def main(): # A = T.buffer[[1], "int8"] # A[0] = 0 - # with T.block(): + # with T.sblock(): # call_tir(add_one, A) # # @T.prim_func @@ -173,7 +173,7 @@ class TestLowerCallTir(tvm.testing.CompareBeforeAfter): # def main(): # A = T.buffer[[1], "int8"] # A[0] = 0 - # with T.block(): + # with T.sblock(): # # TODO: figure out the right TVMScript thing to do here # call_packed(add_one, A) # not sure about this function / interface # @@ -200,7 +200,7 @@ class TestPrimfuncSlicingEndToEnd(tvm.testing.CompareBeforeAfter): # def main(): # A = T.buffer[[1], "int8"] # A[0] = 0 - # with T.block(): # optional: give this block a name, perhaps for testing? + # with T.sblock(): # optional: give this block a name, perhaps for testing? # # NOTE: nice to have: human control over name used for the generated callee # T.annotate("extract_as_subroutine", "add_one") # A[0] += 1 diff --git a/tests/python/tir-base/test_tir_block_dependence_info.py b/tests/python/tir-base/test_tir_block_dependence_info.py index 57370416a727..ead4e0a396d8 100644 --- a/tests/python/tir-base/test_tir_block_dependence_info.py +++ b/tests/python/tir-base/test_tir_block_dependence_info.py @@ -24,7 +24,7 @@ from tvm import tir from tvm.ir import IRModule from tvm.script import tir as T -from tvm.tir import PrimFunc, BlockDependenceInfo +from tvm.tir import PrimFunc, SBlockDependenceInfo from tvm.tir.stmt_functor import post_order_visit from tvm.tir.block_scope import DepKind @@ -37,15 +37,15 @@ def elementwise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -57,10 +57,10 @@ def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -71,11 +71,11 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -83,14 +83,14 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: enable=no-member,invalid-name,unused-variable -def get_blocks(func: PrimFunc): +def get_sblocks(func: PrimFunc): blocks = {} def update_blocks(node): - if isinstance(node, tvm.tir.Block): + if isinstance(node, tvm.tir.SBlock): blocks[node.name_hint] = node - # post_order_visit(func.body, lambda node: blocks[node.name_hint] = node if isinstance(node, tvm.tir.Block) else None) + # post_order_visit(func.body, lambda node: blocks[node.name_hint] = node if isinstance(node, tvm.tir.SBlock) else None) post_order_visit(func.body, update_blocks) return blocks @@ -98,7 +98,7 @@ def update_blocks(node): def _verify_dependence(dependence_info, src_block, dst_block, kind): src_sref = dependence_info.get_sref(src_block) dst_sref = dependence_info.get_sref(dst_block) - scope = dependence_info.get_block_scope(src_sref.parent) + scope = dependence_info.get_sblock_scope(src_sref.parent) def _find_dependence(deps): for dep in deps: @@ -128,22 +128,22 @@ def _get_dependency_kind_name(dep_kind): def test_RAW_dependences(): func = elementwise - dependence_info = BlockDependenceInfo(func) - blocks = get_blocks(func) + dependence_info = SBlockDependenceInfo(func) + blocks = get_sblocks(func) _verify_dependence(dependence_info, blocks["B"], blocks["C"], DepKind.RAW) def test_WAR_dependences(): func = war_dependency - dependence_info = BlockDependenceInfo(func) - blocks = get_blocks(func) + dependence_info = SBlockDependenceInfo(func) + blocks = get_sblocks(func) _verify_dependence(dependence_info, blocks["C"], blocks["B"], DepKind.WAR) def test_RAW_and_WAW_dependences(): func = matmul - dependence_info = BlockDependenceInfo(func) - blocks = get_blocks(func) + dependence_info = SBlockDependenceInfo(func) + blocks = get_sblocks(func) _verify_dependence(dependence_info, blocks["init"], blocks["update"], DepKind.RAW) _verify_dependence(dependence_info, blocks["init"], blocks["update"], DepKind.WAW) diff --git a/tests/python/tir-base/test_tir_host_func.py b/tests/python/tir-base/test_tir_host_func.py index 39284c97252b..01ac7d1865e8 100644 --- a/tests/python/tir-base/test_tir_host_func.py +++ b/tests/python/tir-base/test_tir_host_func.py @@ -38,9 +38,9 @@ def main( "tir.noalias": True, } ) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(729, 729, 729): - with T.block("C"): + with T.sblock("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 1e8c88e08e65..ca093343da8d 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -276,7 +276,7 @@ def clz_np(x, dtype): # Apply scheduling primitives if target is Vulkan if target.kind.name == "vulkan": - block = sch.get_block("B") + block = sch.get_sblock("B") loop = sch.get_loops(block)[0] bx, tx = sch.split(loop, factors=[None, 64]) sch.bind(bx, "blockIdx.x") diff --git a/tests/python/tir-base/test_tir_ptx_cp_async.py b/tests/python/tir-base/test_tir_ptx_cp_async.py index 9e0e18c30781..6a4c07e98b1a 100644 --- a/tests/python/tir-base/test_tir_ptx_cp_async.py +++ b/tests/python/tir-base/test_tir_ptx_cp_async.py @@ -28,7 +28,7 @@ def ptx_cp_async(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "floa tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], "float16", scope="shared") T.reads(A[0:32, 0:128]) T.writes(B[0:32, 0:128]) @@ -71,7 +71,7 @@ def ptx_cp_async_barrier( tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], "float16", scope="shared") T.reads(A[0:32, 0:128]) @@ -116,7 +116,7 @@ def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], "float16", scope="shared") T.reads(A[0:32, 0:128]) diff --git a/tests/python/tir-base/test_tir_ptx_ldmatrix.py b/tests/python/tir-base/test_tir_ptx_ldmatrix.py index 8d4ed399b2e8..368a91b1fef7 100644 --- a/tests/python/tir-base/test_tir_ptx_ldmatrix.py +++ b/tests/python/tir-base/test_tir_ptx_ldmatrix.py @@ -30,7 +30,7 @@ def ptx_ldmatrix( tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([16, 16], "float16", scope="shared") A_local = T.alloc_buffer([8], "float16", scope="local") diff --git a/tests/python/tir-base/test_tir_renew_defs.py b/tests/python/tir-base/test_tir_renew_defs.py index 7fe8d7c679fa..efb5d81bbbf6 100644 --- a/tests/python/tir-base/test_tir_renew_defs.py +++ b/tests/python/tir-base/test_tir_renew_defs.py @@ -20,7 +20,7 @@ from tvm.script import tir as T from tvm.tir.buffer import Buffer from tvm.tir.function import PrimFunc -from tvm.tir.stmt import Block +from tvm.tir.stmt import SBlock def _check_func_signature_remap(lhs: PrimFunc, rhs: PrimFunc): @@ -35,7 +35,7 @@ def _check_buffer_decl(lhs: Buffer, rhs: Buffer): assert lhs.data != rhs.data -def _check_block_signature_remap(lhs: Block, rhs: Block): +def _check_block_signature_remap(lhs: SBlock, rhs: SBlock): assert lhs != rhs for x, y in zip(lhs.iter_vars, rhs.iter_vars): assert x != y @@ -55,7 +55,7 @@ def elementwise(A: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") # i, j should be remapped for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): # vi, vj should be remapped vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) @@ -75,10 +75,10 @@ def elementwise(A: T.Buffer((128, 128), "float32")): assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var # check inner block - def _get_block(f): + def _get_sblock(f): return f.body.block.body.body.body.block - _check_block_signature_remap(_get_block(f1), _get_block(f2)) + _check_block_signature_remap(_get_sblock(f1), _get_sblock(f2)) def test_match_buffer(): @@ -87,7 +87,7 @@ def test_match_buffer(): @T.prim_func(check_well_formed=False) # A and B should be remapped def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): - with T.block("root"): + with T.sblock("root"): s = T.int32() e = T.int32() # A0 should be remapped @@ -100,7 +100,7 @@ def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128) elem_offset=e, ) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A0[vi, vj] * 2.0 @@ -112,11 +112,11 @@ def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128) _check_block_signature_remap(f1.body.block, f2.body.block) assert f1.body.block.body.loop_var != f2.body.block.body.loop_var - def _get_block(f): + def _get_sblock(f): return f.body.block - block1 = _get_block(f1) - block2 = _get_block(f2) + block1 = _get_sblock(f1) + block2 = _get_sblock(f2) _check_block_signature_remap(block1, block2) matched_buffer1 = block1.match_buffers[0].buffer @@ -176,7 +176,7 @@ def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (m * 2,)) B = T.match_buffer(b, (m, 2)) for i, j in T.grid(m, 2): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi * 2 + vj] @@ -194,7 +194,7 @@ def take( T_take: T.Buffer((1, 4096), "float16"), ): for ax0, ax1 in T.grid(1, 4096): - with T.block("T_take"): + with T.sblock("T_take"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[B[v_ax0], v_ax1], B[v_ax0]) T.writes(T_take[v_ax0, v_ax1]) diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index cead775e97cd..5fadbe1064dd 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -31,7 +31,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None: C = T.match_buffer(c, [m, m]) for i, j, k in T.grid(m, m, n): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -45,7 +45,7 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -60,7 +60,7 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [m, m]) for i, j, k in T.grid(m, m, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -78,7 +78,7 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [m, m]) for i, j, k in T.grid(m, m, x * 8): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -95,12 +95,12 @@ def element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((m, n), "float32") for i, j in T.grid(m, n): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(m, n): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -112,12 +112,12 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 64), "float32") for i, j in T.grid(128, 64): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 64): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -130,12 +130,12 @@ def element_wise_128_n(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, n), "float32") for i, j in T.grid(128, n): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, n): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -146,7 +146,7 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32, q: T. B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) for i, j in T.grid(m, n): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -157,7 +157,7 @@ def mem_copy_16_16_8_4(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -168,7 +168,7 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3 B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) for i, j in T.grid(m, n): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] @@ -226,7 +226,7 @@ def before(a: T.handle, b: T.handle): A = T.match_buffer(a, [n // 8, 8], "int32") B = T.match_buffer(b, [n], "int32") for i in range(n - 1): - with T.block(): + with T.sblock(): vi = T.axis.S(n - 1, i) B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 @@ -235,7 +235,7 @@ def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, [2, 8], "int32") B = T.match_buffer(b, [16], "int32") for i in range(15): - with T.block(): + with T.sblock(): vi = T.axis.S(15, i) B[vi] = A[vi // 8, vi % 8] + 714 diff --git a/tests/python/tir-base/test_tir_te_extern_primfunc.py b/tests/python/tir-base/test_tir_te_extern_primfunc.py index 1408597fa22e..80b3ef887ca1 100644 --- a/tests/python/tir-base/test_tir_te_extern_primfunc.py +++ b/tests/python/tir-base/test_tir_te_extern_primfunc.py @@ -36,11 +36,11 @@ def func_1(A: T.Buffer((16,), "float32"), C: T.Buffer((1,), "float32")): 0, 16, ): - with T.block(): + with T.sblock(): B = T.alloc_buffer((1,), dtype="float32") - with T.block(): + with T.sblock(): B[0] = A[i] * T.float32(2) - with T.block(): + with T.sblock(): C[0] = C[0] + A[i] + B[0] + T.float32(1) A[i] = B[0] + T.float32(1) @@ -65,11 +65,11 @@ def func_2( 0, 16, ): - with T.block(): + with T.sblock(): B = T.alloc_buffer((1,), dtype="float32") - with T.block(): + with T.sblock(): B[0] = A[i] * T.float32(2) - with T.block(): + with T.sblock(): C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0] A[i] = B[0] + T.float32(1) + D[1] @@ -99,11 +99,11 @@ def func_3( 0, 16, ): - with T.block(): + with T.sblock(): B = T.alloc_buffer((1,), dtype="float32") - with T.block(): + with T.sblock(): B[0] = A[i] * T.float32(2) - with T.block(): + with T.sblock(): E[i] = A[i] F[i] = E[i] + 1.0 C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0] @@ -141,11 +141,11 @@ def func_4( 0, 16, ): - with T.block(): + with T.sblock(): B = T.alloc_buffer((1,), dtype="float32") - with T.block(): + with T.sblock(): B[0] = A[i] * T.float32(2) - with T.block(): + with T.sblock(): E[i] = A[i] F[i] = E[i] + 1.0 C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0] diff --git a/tests/python/tir-base/test_tir_texture_scope.py b/tests/python/tir-base/test_tir_texture_scope.py index 4b759bb0477d..cc4d7bff56dd 100644 --- a/tests/python/tir-base/test_tir_texture_scope.py +++ b/tests/python/tir-base/test_tir_texture_scope.py @@ -36,13 +36,13 @@ def main(a: T.handle, b: T.handle) -> None: for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"): for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"): for k in T.serial(4): - with T.block("B"): + with T.sblock("B"): vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k]) B[vb, vt, vk] = A[vb, vt, vk] + T.float32(1) for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"): for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"): for k in T.serial(4): - with T.block("C"): + with T.sblock("C"): vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k]) C[vb, vt, vk] = B[vb, vt, vk] * T.float32(2) @@ -52,8 +52,8 @@ def schedule_block(block): _, _, inner = sch.get_loops(block) sch.vectorize(inner) - schedule_block(sch.get_block("B")) - schedule_block(sch.get_block("C")) + schedule_block(sch.get_sblock("B")) + schedule_block(sch.get_sblock("C")) target = tvm.target.Target("opencl") mod = tvm.compile(sch.mod["main"], target=target) diff --git a/tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py b/tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py index 80944dc21da6..0cc07531b07b 100644 --- a/tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py +++ b/tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py @@ -34,7 +34,7 @@ def indirect_mem_access(a: T.handle, idx_a: T.handle, b: T.handle, idx_b: T.hand IB = T.match_buffer(idx_b, [10], dtype="int32") for i in range(10): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(10, i) T.reads(A[IA[vi]], IA[vi]) T.writes(B[IB[vi]], IB[vi]) @@ -49,7 +49,7 @@ def indirect_mem_access_hide_ia(a: T.handle, idx_a: T.handle, b: T.handle, idx_b IB = T.match_buffer(idx_b, [10], dtype="int32") for i in range(10): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(10, i) T.reads(A[IA[vi]]) T.writes(B[IB[vi]], IB[vi]) @@ -64,7 +64,7 @@ def indirect_mem_access_hide_ib(a: T.handle, idx_a: T.handle, b: T.handle, idx_b IB = T.match_buffer(idx_b, [10], dtype="int32") for i in range(10): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(10, i) T.reads(A[IA[vi]], IA[vi]) T.writes(B[IB[vi]]) @@ -73,7 +73,7 @@ def indirect_mem_access_hide_ib(a: T.handle, idx_a: T.handle, b: T.handle, idx_b def test_hide_buffer_access_read(): sch = tir.Schedule(indirect_mem_access, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") sch.unsafe_hide_buffer_access(block_b, "read", [1]) assert_structural_equal_ignore_global_symbol(indirect_mem_access_hide_ia, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=indirect_mem_access) @@ -81,7 +81,7 @@ def test_hide_buffer_access_read(): def test_hide_buffer_access_write(): sch = tir.Schedule(indirect_mem_access, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") sch.unsafe_hide_buffer_access(block_b, "write", [1]) assert_structural_equal_ignore_global_symbol(indirect_mem_access_hide_ib, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=indirect_mem_access) @@ -89,14 +89,14 @@ def test_hide_buffer_access_write(): def test_hide_buffer_access_fail_buffer_type(): sch = tir.Schedule(indirect_mem_access, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") with pytest.raises(tvm.error.TVMError): sch.unsafe_hide_buffer_access(block_b, "opaque", [0]) def test_hide_buffer_access_fail_buffer_index(): sch = tir.Schedule(indirect_mem_access, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") with pytest.raises(tvm.error.TVMError): sch.unsafe_hide_buffer_access(block_b, "read", [2]) diff --git a/tests/python/tir-schedule/test_tir_schedule_analysis.py b/tests/python/tir-schedule/test_tir_schedule_analysis.py index cc87818db428..afcf082c2404 100644 --- a/tests/python/tir-schedule/test_tir_schedule_analysis.py +++ b/tests/python/tir-schedule/test_tir_schedule_analysis.py @@ -164,11 +164,11 @@ def main( compute: T.Buffer((1024, 1024), "int32"), ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() for i0, i1, i2 in T.grid(1024, 1024, 1024): - with T.block("compute"): + with T.sblock("compute"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) T.writes(compute[i, j]) @@ -189,7 +189,7 @@ def main( ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): ( n, oc_chunk, @@ -235,7 +235,7 @@ def callback(node): def test_get_tensorize_loop_mapping_dense_16x4(): s = Schedule(DenseTIRModule) - block = s.get_block("compute") + block = s.get_sblock("compute") info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) @@ -253,7 +253,7 @@ def test_get_tensorize_loop_mapping_dense_16x4(): def test_get_tensorize_loop_mapping_conv2d_nchwc_16x4(): s = Schedule(Conv2dNCHWcTIRModule) - block = s.get_block("conv2d_NCHWc_int8") + block = s.get_sblock("conv2d_NCHWc_int8") info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) @@ -277,11 +277,11 @@ def matmul_16x16x16xf16f16f16_desc( B: T.Buffer((16, 16), "float16", align=64, offset_factor=1), C: T.Buffer((16, 16), "float16", align=64, offset_factor=1), ) -> None: - with T.block("root"): + with T.sblock("root"): T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) T.writes(C[0:16, 0:16]) for i, j, k in T.grid(16, 16, 16): - with T.block("update"): + with T.sblock("update"): vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] @@ -294,7 +294,7 @@ def matmul_16x16x16xf16f16f16_desc( ) s = Schedule(matmul) - block = s.get_block("C") + block = s.get_sblock("C") i0, i1, i2 = s.get_loops(block) desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) @@ -326,7 +326,7 @@ def test_get_tensorize_loop_mapping_padding_matmul(): ) ) s = Schedule(matmul) - block = s.get_block("C") + block = s.get_sblock("C") desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f16_INTRIN).desc info = get_tensorize_loop_mapping(s, block, desc, allow_padding=True) @@ -341,7 +341,7 @@ def test_get_tensorize_loop_mapping_padding_matmul(): def check_index_map(workload, block_name, intrin_name, expected_index_map): s = Schedule(workload) - block = s.get_block(block_name) + block = s.get_sblock(block_name) desc_func = TensorIntrin.get(intrin_name).desc info = get_auto_tensorize_mapping_info(s, block, desc_func) if expected_index_map is None: @@ -413,16 +413,16 @@ def two_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 sch = tvm.tir.Schedule(two_elementwise) - block_rv = sch.get_block("C") + block_rv = sch.get_sblock("C") assert is_output_block(sch, block_rv) @@ -431,13 +431,13 @@ def test_empty_grid(): def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")): act = T.alloc_buffer((1, 8, 8), "int32") for z2, y2, x2 in T.grid(1, 8, 8): - with T.block("b0"): + with T.sblock("b0"): az, ay, ax = T.axis.remap("SSS", [z2, y2, x2]) T.writes(act[az, ay, ax]) act[az, ay, az] = T.int32(0) # Empty grid: for z1, y1, x1 in T.grid(0, 8, 8): - with T.block("b1"): + with T.sblock("b1"): az, ay, ax = T.axis.remap("SSS", [z1, y1, x1]) T.reads(act[az + 1, ay, ax]) T.writes(out[az, ay, ax]) @@ -445,7 +445,7 @@ def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")): # The block below is not needed to show the bug, but the 'out' # buffer would be undefined without it. for z2, y2, x2 in T.grid(1, 8, 8): - with T.block("b2"): + with T.sblock("b2"): az, ay, ax = T.axis.remap("SSS", [z2, y2, x2]) T.writes(out[az, ay, ax]) out[az, ay, az] = T.int32(0) diff --git a/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py index b8f7c11a92b3..ca1536a15301 100644 --- a/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py +++ b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py @@ -30,11 +30,11 @@ def test_annotate_read_buffer_access(): def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -42,19 +42,19 @@ def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32" def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi - 1 : vi - 1 + 2, vj - 1 : vj - 1 + 2]) T.writes(B[vi, vj]) - T.block_attr({"explicit_read_region": [T.int32(0)]}) + T.sblock_attr({"explicit_read_region": [T.int32(0)]}) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 sch = tir.Schedule(before, debug_mask="all") - block = sch.get_block("B") + block = sch.get_sblock("B") sch.annotate_buffer_access( block, 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)) ) @@ -67,11 +67,11 @@ def test_annotate_write_buffer_access(): def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -79,19 +79,19 @@ def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32" def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi : vi + 2, vj : vj + 2]) - T.block_attr({"explicit_write_region": [T.int32(0)]}) + T.sblock_attr({"explicit_write_region": [T.int32(0)]}) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 sch = tir.Schedule(before, debug_mask="all") - block = sch.get_block("B") + block = sch.get_sblock("B") sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2))) assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) verify_trace_roundtrip(sch=sch, mod=before) @@ -102,7 +102,7 @@ def test_annotate_buffer_access_for_resize(): @T.prim_func def resize_before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")): for i0, i1, i2, i3 in T.grid(1, 1, 16, 16): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(x[v_i0, v_i1, 0:32, 0:32]) T.writes(resize[v_i0, v_i1, v_i2, v_i3]) @@ -111,15 +111,15 @@ def resize_before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1 @T.prim_func def resize_expected(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")): for i0, i1, i2, i3 in T.grid(1, 1, 16, 16): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3]) T.writes(resize[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"explicit_read_region": [T.int32(0)]}) + T.sblock_attr({"explicit_read_region": [T.int32(0)]}) resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)])) # fmt: on sch = tir.Schedule(resize_before, debug_mask="all") - block = sch.get_block("resize") + block = sch.get_sblock("resize") sch.annotate_buffer_access( block, 0, @@ -140,13 +140,13 @@ def test_annotate_buffer_access_read_and_write(): def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -156,23 +156,23 @@ def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32" def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi - 1 : vi + 2, vj - 1 : vj + 2]) T.writes(B[vi : vi + 2, vj : vj + 2]) - T.block_attr( + T.sblock_attr( {"explicit_read_region": [T.int32(0)], "explicit_write_region": [T.int32(0)]} ) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + 1.0 sch = tir.Schedule(before, debug_mask="all") - block = sch.get_block("B") + block = sch.get_sblock("B") sch.annotate_buffer_access( block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2)) @@ -189,13 +189,13 @@ def test_double_annotate_buffer_access_read(): def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -205,21 +205,21 @@ def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32" def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi - 2 : vi + 3, vj - 2 : vj + 3]) T.writes(B[vi, vj]) - T.block_attr({"explicit_read_region": [T.int32(0)]}) + T.sblock_attr({"explicit_read_region": [T.int32(0)]}) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + 1.0 sch = tir.Schedule(before, debug_mask="all") - block = sch.get_block("B") + block = sch.get_sblock("B") sch.annotate_buffer_access( block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2)) @@ -239,11 +239,11 @@ def test_annotate_buffer_access_with_compute_at_for_resize(): def before(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): x_global = T.alloc_buffer([1, 3, 200, 200], dtype="float32") for ax0, ax1, ax2, ax3 in T.grid(1, 3, 200, 200): - with T.block("cache"): + with T.sblock("cache"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] for i0, i1, i2, i3 in T.grid(1, 3, 100, 100): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(v_i2 * 2 + 0.5)), T.Cast("int32", T.floor(v_i3 * 2 + 0.5))] @@ -252,7 +252,7 @@ def after(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100) x_global = T.alloc_buffer((1, 3, 200, 200)) for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10): for ax0, ax1 in T.grid(24, 24): - with T.block("cache"): + with T.sblock("cache"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(3, i1) v2 = T.axis.spatial(200, i2_0 * 20 - 3 + ax0) @@ -262,13 +262,13 @@ def after(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100) T.writes(x_global[v0, v1, v2, v3]) x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] for i2_1, i3_1 in T.grid(10, 10): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1) v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1) T.reads(x_global[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 - 3 + 6, v_i3 * 2 - 3:v_i3 * 2 - 3 + 6]) T.writes(y[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"explicit_read_region": [T.int32(0)]}) + T.sblock_attr({"explicit_read_region": [T.int32(0)]}) y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))] @T.prim_func @@ -276,14 +276,14 @@ def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32" x_global = T.alloc_buffer((1, 3, 200, 200)) for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10): for ax0, ax1 in T.grid(200, 200): - with T.block("cache"): + with T.sblock("cache"): v0 = T.axis.spatial(1, 0) v1, v2, v3 = T.axis.remap("SSS", [i1, ax0, ax1]) T.reads(x[v0, v1, v2, v3]) T.writes(x_global[v0, v1, v2, v3]) x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] for i2_1, i3_1 in T.grid(10, 10): - with T.block("resize"): + with T.sblock("resize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1) v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1) @@ -294,8 +294,8 @@ def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32" # Schedule with annotate_buffer_access sch = tir.Schedule(before, debug_mask="all") - block = sch.get_block("resize") - cache_block = sch.get_block("cache") + block = sch.get_sblock("resize") + cache_block = sch.get_sblock("cache") # Annotate buffer access sch.annotate_buffer_access( @@ -316,8 +316,8 @@ def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32" # Schedule without annotate_buffer_access sch_without_annotate = tir.Schedule(before, debug_mask="all") - block_without_annotate = sch_without_annotate.get_block("resize") - cache_block_without_annotate = sch_without_annotate.get_block("cache") + block_without_annotate = sch_without_annotate.get_sblock("resize") + cache_block_without_annotate = sch_without_annotate.get_sblock("cache") h, w = sch_without_annotate.get_loops(block_without_annotate)[-2:] ho, hi = sch_without_annotate.split(h, factors=[10, 10]) diff --git a/tests/python/tir-schedule/test_tir_schedule_block_scope.py b/tests/python/tir-schedule/test_tir_schedule_block_scope.py index 375b6c07c2bb..578bf6bcb590 100644 --- a/tests/python/tir-schedule/test_tir_schedule_block_scope.py +++ b/tests/python/tir-schedule/test_tir_schedule_block_scope.py @@ -34,11 +34,11 @@ def elementwise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -49,11 +49,11 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -65,10 +65,10 @@ def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -78,32 +78,32 @@ def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=invalid-name -def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef: +def _get_sblock(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef: result = None def f_visit(node): nonlocal result - if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint: + if isinstance(node, tvm.tir.SBlock) and node.name_hint == name_hint: result = node func = s.mod["main"] post_order_visit(func.body, f_visit) - assert result is not None and isinstance(result, tvm.tir.Block) + assert result is not None and isinstance(result, tvm.tir.SBlock) return s.get_sref(result) def test_elementwise_dependency(): s = tir.ScheduleState(elementwise, debug_mask="all") - root = _get_block(s, "root") - block_b = _get_block(s, "B") - block_c = _get_block(s, "C") + root = _get_sblock(s, "root") + block_b = _get_sblock(s, "B") + block_c = _get_sblock(s, "C") # Check get_deps_by_src - (dep,) = s.get_block_scope(root).get_deps_by_src(block_b) + (dep,) = s.get_sblock_scope(root).get_deps_by_src(block_b) assert dep.src.same_as(block_b) assert dep.dst.same_as(block_c) assert dep.kind == DepKind.RAW # Check get_deps_by_dst - (dep,) = s.get_block_scope(root).get_deps_by_dst(block_c) + (dep,) = s.get_sblock_scope(root).get_deps_by_dst(block_c) assert dep.src.same_as(block_b) assert dep.dst.same_as(block_c) assert dep.kind == DepKind.RAW @@ -111,11 +111,11 @@ def test_elementwise_dependency(): def test_matmul_dependency(): s = tir.ScheduleState(matmul, debug_mask="all") - root = _get_block(s, "root") - init = _get_block(s, "init") - update = _get_block(s, "update") + root = _get_sblock(s, "root") + init = _get_sblock(s, "init") + update = _get_sblock(s, "update") # Check get_deps_by_src - p0, p1 = s.get_block_scope(root).get_deps_by_src(init) + p0, p1 = s.get_sblock_scope(root).get_deps_by_src(init) assert p0.src.same_as(init) assert p0.dst.same_as(update) assert p1.src.same_as(init) @@ -124,7 +124,7 @@ def test_matmul_dependency(): p0.kind == DepKind.WAW and p1.kind == DepKind.RAW ) # Check get_deps_by_dst - p0, p1 = s.get_block_scope(root).get_deps_by_dst(update) + p0, p1 = s.get_sblock_scope(root).get_deps_by_dst(update) assert p0.src.same_as(init) assert p0.dst.same_as(update) assert p1.src.same_as(init) @@ -136,16 +136,16 @@ def test_matmul_dependency(): def test_war_dependency(): s = tir.ScheduleState(war_dependency, debug_mask="all") - root = _get_block(s, "root") - block_c = _get_block(s, "C") - block_b = _get_block(s, "B") + root = _get_sblock(s, "root") + block_c = _get_sblock(s, "C") + block_b = _get_sblock(s, "B") # Check get_deps_by_src - (dep,) = s.get_block_scope(root).get_deps_by_src(block_c) + (dep,) = s.get_sblock_scope(root).get_deps_by_src(block_c) assert dep.src.same_as(block_c) assert dep.dst.same_as(block_b) assert dep.kind == DepKind.WAR # Check get_deps_by_dst - (dep,) = s.get_block_scope(root).get_deps_by_dst(block_b) + (dep,) = s.get_sblock_scope(root).get_deps_by_dst(block_b) assert dep.src.same_as(block_c) assert dep.dst.same_as(block_b) assert dep.kind == DepKind.WAR diff --git a/tests/python/tir-schedule/test_tir_schedule_blockize.py b/tests/python/tir-schedule/test_tir_schedule_blockize.py index 631df7a82dc3..bcb12ee9fdf2 100644 --- a/tests/python/tir-schedule/test_tir_schedule_blockize.py +++ b/tests/python/tir-schedule/test_tir_schedule_blockize.py @@ -28,7 +28,7 @@ @T.prim_func def single_elementwise(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -42,17 +42,17 @@ def after_blockize_outer( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), ) -> None: - with T.block("blockized_B"): + with T.sblock("blockized_B"): vio = T.axis.spatial(1, 0) vjo = T.axis.spatial(1, 0) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 func = single_elementwise s = tir.Schedule(func, debug_mask="all") - x, _ = s.get_loops(s.get_block("B")) + x, _ = s.get_loops(s.get_sblock("B")) s.blockize(x) tvm.ir.assert_structural_equal( s.mod["main"], after_blockize_outer.with_attr("global_symbol", "single_elementwise") @@ -67,17 +67,17 @@ def after_blockize_inner( B: T.Buffer((128, 128), "float32"), ) -> None: for i in T.serial(128): - with T.block("blockized_B"): + with T.sblock("blockized_B"): vi = T.axis.spatial(128, i) vjo = T.axis.spatial(1, 0) for j in T.serial(128): - with T.block("B"): + with T.sblock("B"): vj = T.axis.remap("S", [j]) B[vi, vj] = A[vi, vj] * 2.0 func = single_elementwise s = tir.Schedule(func, debug_mask="all") - _, y = s.get_loops(s.get_block("B")) + _, y = s.get_loops(s.get_sblock("B")) s.blockize(y) tvm.ir.assert_structural_equal( s.mod["main"], after_blockize_inner.with_attr("global_symbol", "single_elementwise") @@ -93,18 +93,18 @@ def before_blockize_rca( ) -> None: B = T.alloc_buffer([128, 128], dtype="float32") for i, j in T.grid(8, 8): - with T.block("B_o"): + with T.sblock("B_o"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) for i_1, j_1 in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i]) T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i]) B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0 for ax0, ax1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i * 16 + ax0) vj = T.axis.spatial(128, j * 16 + ax1) T.reads(B[vi, vj]) @@ -118,22 +118,22 @@ def after_blockize_rca( ) -> None: B = T.alloc_buffer([128, 128], dtype="float32") for i, j in T.grid(8, 8): - with T.block("B_o"): + with T.sblock("B_o"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) for i_1, j_1 in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i]) T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i]) B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0 - with T.block("C_o"): + with T.sblock("C_o"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) for ax0, ax1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi_i, vj_i = T.axis.remap("SS", [ax0, ax1]) T.reads(B[vi * 16 + vi_i, vj * 16 + vj_i]) T.writes(C[vi * 16 + vi_i, vj * 16 + vj_i]) @@ -141,7 +141,7 @@ def after_blockize_rca( func = before_blockize_rca s = tir.Schedule(func, debug_mask="all") - _, _, x, _ = s.get_loops(s.get_block("C")) + _, _, x, _ = s.get_loops(s.get_sblock("C")) s.blockize(x) tvm.ir.assert_structural_equal( s.mod["main"], after_blockize_rca.with_attr("global_symbol", "before_blockize_rca") @@ -156,22 +156,22 @@ def before_blockize_compute_at( C: T.Buffer((128, 128), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") B = T.alloc_buffer([128, 128], dtype="float32") for i_0, j_0 in T.grid(8, 8): for ax0, ax1 in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i_0 * 16 + ax0) vj = T.axis.spatial(128, j_0 * 16 + ax1) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * 2.0 - with T.block("C_o"): + with T.sblock("C_o"): vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) for i_1, j_1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) @@ -186,24 +186,24 @@ def after_blockize_compute_at( ) -> None: B = T.alloc_buffer([128, 128], dtype="float32") for i_0, j_0 in T.grid(8, 8): - with T.block("B_o"): + with T.sblock("B_o"): vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) T.reads(A[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) T.writes(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) for ax0, ax1 in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi_i, vj_i = T.axis.remap("SS", [ax0, ax1]) T.reads(A[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) T.writes(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = ( A[vi_o * 16 + vi_i, vj_o * 16 + vj_i] * 2.0 ) - with T.block("C_o"): + with T.sblock("C_o"): vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) for i_1, j_1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) @@ -213,7 +213,7 @@ def after_blockize_compute_at( func = before_blockize_compute_at s = tir.Schedule(func, debug_mask="all") - _, _, x, _ = s.get_loops(s.get_block("B")) + _, _, x, _ = s.get_loops(s.get_sblock("B")) s.blockize(x) tvm.ir.assert_structural_equal( s.mod["main"], @@ -226,7 +226,7 @@ def test_blockize_init_loops(): @T.prim_func def rowsum(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) -> None: for k, i in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vk, vi = T.axis.remap("RS", [k, i]) with T.init(): B[vi] = 0.0 @@ -237,21 +237,21 @@ def after_rowsum_blockize( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"), ) -> None: - with T.block("blockized_B"): + with T.sblock("blockized_B"): vko = T.axis.R(1, 0) vio = T.axis.S(1, 0) with T.init(): for i1 in T.serial(0, 128): - with T.block("B_init"): + with T.sblock("B_init"): vi_init = T.axis.S(128, i1) B[vi_init] = T.float32(0) for i0, i1_1 in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vk, vi = T.axis.remap("RS", [i0, i1_1]) B[vi] = B[vi] + A[vi, vk] s = tir.Schedule(rowsum, debug_mask="all") - k, _ = s.get_loops(s.get_block("B")) + k, _ = s.get_loops(s.get_sblock("B")) s.blockize(k) tvm.ir.assert_structural_equal( s.mod["main"], after_rowsum_blockize.with_attr("global_symbol", "rowsum") @@ -267,7 +267,7 @@ def single_elementwise_int64( B: T.Buffer((T.int64(16), T.int64(128)), "float32"), ) -> None: for i0, j0, i1, j1 in T.grid(T.int64(1), T.int64(8), T.int64(16), T.int64(16)): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(T.int64(16), i0 * T.int64(16) + i1) vj = T.axis.S(T.int64(128), j0 * T.int64(16) + j1) B[vi, vj] = A[vi, vj] + 1.0 @@ -278,11 +278,11 @@ def after_single_elementwise_int64_blockize( B: T.Buffer((T.int64(16), T.int64(128)), "float32"), ) -> None: for i0, j0 in T.grid(T.int64(1), T.int64(8)): - with T.block("B_o"): + with T.sblock("B_o"): vi_o = T.axis.spatial(T.int64(1), T.int64(0)) vj_o = T.axis.spatial(T.int64(8), j0) for i1, j1 in T.grid(T.int64(16), T.int64(16)): - with T.block("B"): + with T.sblock("B"): vi_i, vj_i = T.axis.remap("SS", [i1, j1]) B[vi_i, vj_o * T.int64(16) + vj_i] = A[ vi_i, vj_o * T.int64(16) + vj_i @@ -294,18 +294,18 @@ def after_single_elementwise_int64_blockize_preserve_unit_iters( B: T.Buffer((T.int64(16), T.int64(128)), "float32"), ) -> None: for i0, j0 in T.grid(T.int64(1), T.int64(8)): - with T.block("B_o"): + with T.sblock("B_o"): vi_o = T.axis.spatial(T.int64(1), i0) vj_o = T.axis.spatial(T.int64(8), j0) for i1, j1 in T.grid(T.int64(16), T.int64(16)): - with T.block("B"): + with T.sblock("B"): vi_i, vj_i = T.axis.remap("SS", [i1, j1]) B[vi_i, vj_o * T.int64(16) + vj_i] = A[ vi_i, vj_o * T.int64(16) + vj_i ] + T.float32(1) s = tir.Schedule(single_elementwise_int64, debug_mask="all") - _, _, i1, _ = s.get_loops(s.get_block("B")) + _, _, i1, _ = s.get_loops(s.get_sblock("B")) s.blockize(i1, preserve_unit_iters=preserve_unit_iters) expected = ( after_single_elementwise_int64_blockize_preserve_unit_iters @@ -323,14 +323,14 @@ def test_blockize_blocks(): def blocks_func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")) -> None: for m in T.serial(6): for i, j in T.grid(3, 1): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 64): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj + 64]) T.writes(B[vi, vj + 64]) @@ -341,26 +341,26 @@ def after_blocks_blockize( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32") ) -> None: for m in range(6): - with T.block("outer_B_C_"): + with T.sblock("outer_B_C_"): vi_o = T.axis.spatial(1, 0) vj_o = T.axis.spatial(1, 0) T.reads(A[0:128, 0:128]) T.writes(B[0:128, 0:128]) for i, j in T.grid(3, 1): - with T.block("B"): + with T.sblock("B"): vi_i = T.axis.spatial(3, i) T.reads(A[vi_i, 0]) T.writes(B[vi_i, 0]) B[vi_i, 0] = A[vi_i, 0] * T.float32(2) for i, j in T.grid(128, 64): - with T.block("C"): + with T.sblock("C"): vi_i, vj_i = T.axis.remap("SS", [i, j]) T.reads(A[vi_i, vj_i + 64]) T.writes(B[vi_i, vj_i + 64]) B[vi_i, vj_i + 64] = A[vi_i, vj_i + 64] * T.float32(3) s = tir.Schedule(blocks_func, debug_mask="all") - blocks = [s.get_block("B"), s.get_block("C")] + blocks = [s.get_sblock("B"), s.get_sblock("C")] s.blockize(blocks, preserve_unit_iters=False) expected = after_blocks_blockize tvm.ir.assert_structural_equal( diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_index.py b/tests/python/tir-schedule/test_tir_schedule_cache_index.py index 5ef39958823b..7341156e98f6 100644 --- a/tests/python/tir-schedule/test_tir_schedule_cache_index.py +++ b/tests/python/tir-schedule/test_tir_schedule_cache_index.py @@ -34,7 +34,7 @@ def resize(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (1, 3, 40, 40)) B = T.match_buffer(b, (1, 3, 80, 80)) for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): - with T.block("A"): + with T.sblock("A"): n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3]) B[n, c, vi, vj] = A[n, c, vi // 4 + vj // 4, vj // 2] @@ -46,19 +46,19 @@ def resize_cache_index( index_var_0 = T.alloc_buffer([80, 80], dtype="int32", strides=[1]) index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1]) for ax0, ax1 in T.grid(80, 80): - with T.block("index_0"): + with T.sblock("index_0"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads() T.writes(index_var_0[v0, v1]) index_var_0[v0, v1] = v0 // 4 + v1 // 4 for ax0 in T.serial(80): - with T.block("index_1"): + with T.sblock("index_1"): v0 = T.axis.spatial(80, ax0) T.reads() T.writes(index_var_1[v0]) index_var_1[v0] = v0 // 2 for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): - with T.block("A"): + with T.sblock("A"): n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(A[n, c, vi // 4 + vj // 4, vj // 2]) T.writes(B[n, c, vi, vj]) @@ -70,7 +70,7 @@ def bilinear_resize( x: T.Buffer((1, 3, 40, 40), "float16"), resize: T.Buffer((1, 3, 80, 80), "float16") ): for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): - with T.block("resize"): + with T.sblock("resize"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(x[i0_1, i1_1, 0:40, 0:40]) T.writes(resize[i0_1, i1_1, i2_1, i3_1]) @@ -342,7 +342,7 @@ def cached_bilinear_resize( index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1]) index_var_2 = T.alloc_buffer([80], dtype="int32", strides=[1]) for ax0 in T.serial(80): - with T.block("index_0"): + with T.sblock("index_0"): v0 = T.axis.spatial(80, ax0) T.reads() T.writes(index_var_0[v0]) @@ -362,7 +362,7 @@ def cached_bilinear_resize( ) ) for ax0 in T.serial(80): - with T.block("index_1"): + with T.sblock("index_1"): v0 = T.axis.spatial(80, ax0) T.reads() T.writes(index_var_1[v0]) @@ -374,7 +374,7 @@ def cached_bilinear_resize( ), ) for ax0 in T.serial(80): - with T.block("index_2"): + with T.sblock("index_2"): v0 = T.axis.spatial(80, ax0) T.reads() T.writes(index_var_2[v0]) @@ -386,7 +386,7 @@ def cached_bilinear_resize( ), ) for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): - with T.block("resize"): + with T.sblock("resize"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(x[i0_1, i1_1, 0:40, 0:40]) T.writes(resize[i0_1, i1_1, i2_1, i3_1]) @@ -454,7 +454,7 @@ def cached_bilinear_resize( def test_basic_cache_index(): sch = tvm.tir.Schedule(resize, debug_mask="all") - block = sch.get_block("A") + block = sch.get_sblock("A") sch.cache_index(block, "global") tvm.ir.assert_structural_equal( resize_cache_index, sch.mod["main"].with_attr("global_symbol", "resize_cache_index") @@ -464,7 +464,7 @@ def test_basic_cache_index(): def test_resize_bilinear_cache_index(): sch = tvm.tir.Schedule(bilinear_resize, debug_mask="all") - block = sch.get_block("resize") + block = sch.get_sblock("resize") sch.cache_index(block, "global", 4) tvm.ir.assert_structural_equal( sch.mod["main"], cached_bilinear_resize.with_attr("global_symbol", "bilinear_resize") diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py index 1fda0f432108..509f5651f9a5 100644 --- a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py +++ b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py @@ -38,11 +38,11 @@ def elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -53,11 +53,11 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((T.int64(128), T.int64(128))) C = T.match_buffer(c, (T.int64(128), T.int64(128))) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -69,19 +69,19 @@ def elementwise_reindex_cache_read( B = T.alloc_buffer((128, 128)) B_shared = T.alloc_buffer((128, 64, 2), scope="shared") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("B_shared"): + with T.sblock("B_shared"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(B_shared[vj, vi // 2, vi % 2]) B_shared[vj, vi // 2, vi % 2] = B[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B_shared[vj, vi // 2, vi % 2]) T.writes(C[vi, vj]) @@ -95,19 +95,19 @@ def elementwise_reindex_cache_write( B = T.alloc_buffer((128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B_shared[vj, vi]) B_shared[vj, vi] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("B_shared"): + with T.sblock("B_shared"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B_shared[vj, vi]) T.writes(B[vi, vj]) B[vi, vj] = B_shared[vj, vi] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -119,12 +119,12 @@ def reduce(A: T.Buffer((128, 128, 128, 128), "float32"), C: T.Buffer((128, 128), B = T.alloc_buffer((128, 128, 128), dtype="float32") for i, j, k in T.grid(128, 128, 128): for l in range(128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) with T.init(): B[vi, vj, vk] = T.float32(0) B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl] - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -139,19 +139,19 @@ def reduce_reindex_cache_write_0( B_shared = T.alloc_buffer((128, 128, 128), scope="shared") for i, j, k in T.grid(128, 128, 128): for l in range(128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) T.reads(A[vi, vj, vk, vl]) T.writes(B_shared[vj, vi, vk]) with T.init(): B_shared[vj, vi, vk] = T.float32(0) B_shared[vj, vi, vk] = B_shared[vj, vi, vk] + A[vi, vj, vk, vl] - with T.block("B_shared"): + with T.sblock("B_shared"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads(B_shared[vj, vi, vk]) T.writes(B[vi, vj, vk]) B[vi, vj, vk] = B_shared[vj, vi, vk] - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(B[vi, vj, vk]) T.writes(C[vi, vj]) @@ -169,19 +169,19 @@ def reduce_reindex_cache_write_1( C_shared = T.alloc_buffer((128, 128), scope="shared") for i, j, k in T.grid(128, 128, 128): for l in range(128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) T.reads(A[vi, vj, vk, vl]) T.writes(B_shared[vj, vi, vk]) with T.init(): B_shared[vj, vi, vk] = T.float32(0) B_shared[vj, vi, vk] = B_shared[vj, vi, vk] + A[vi, vj, vk, vl] - with T.block("B_shared"): + with T.sblock("B_shared"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads(B_shared[vj, vi, vk]) T.writes(B[vi, vj, vk]) B[vi, vj, vk] = B_shared[vj, vi, vk] - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(B[vi, vj, vk]) T.writes(C_shared[vj, vi]) @@ -189,7 +189,7 @@ def reduce_reindex_cache_write_1( C_shared[vj, vi] = T.float32(0) C_shared[vj, vi] = C_shared[vj, vi] + B[vi, vj, vk] for i, j in T.grid(128, 128): - with T.block("C_shared"): + with T.sblock("C_shared"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(C_shared[vj, vi]) T.writes(C[vi, vj]) @@ -203,22 +203,22 @@ def func_nested_seq(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 2.0 for i, j in T.grid(8, 8): for x, y in T.grid(16, 16): - with T.block("B0"): + with T.sblock("B0"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) B[vi, vj] = 1.0 for x, y in T.grid(16, 16): - with T.block("B1"): + with T.sblock("B1"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) B[vi, vj] = A[vi, vj] + B[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 @@ -230,20 +230,20 @@ def access_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i0, j0 in T.grid(8, 8): - with T.block("scope"): + with T.sblock("scope"): i, j = T.axis.remap("SS", [i0, j0]) for x, y in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) A[vi, vj] = 1.0 for x, y in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) B[vi, vj] = A[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 @@ -256,13 +256,13 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (128, 128), dtype="float16") for i, j in T.grid(128, 128): - with T.block("load_store"): + with T.sblock("load_store"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D[vi, vj]) D[vi, vj] = A[vi, vj] for i, j in T.grid(8, 8): - with T.block("opaque"): + with T.sblock("opaque"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -287,7 +287,7 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: ) ) for i, j in T.grid(8, 8): - with T.block("match_buffer"): + with T.sblock("match_buffer"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -340,15 +340,15 @@ def func_multi_consumer() -> None: C = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) C[vi] = A[vi] @@ -361,25 +361,25 @@ def reindex_cache_read_multi_consumer() -> None: A_shared = T.alloc_buffer((4, 32), scope="shared") for i in range(8): for j in range(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.spatial(128, i * 16 + j) T.reads() T.writes(A[vi]) A[vi] = T.float32(1) for j in range(16): - with T.block("A_shared"): + with T.sblock("A_shared"): vi = T.axis.spatial(128, i * 16 + j) T.reads(A[vi]) T.writes(A_shared[vi // 32, vi % 32]) A_shared[vi // 32, vi % 32] = A[vi] for j in range(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i * 16 + j) T.reads(A_shared[vi // 32, vi % 32]) T.writes(B[vi]) B[vi] = A_shared[vi // 32, vi % 32] + T.float32(1) for i in range(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) T.reads(A[vi]) T.writes(C[vi]) @@ -391,15 +391,15 @@ def func_multi_producer() -> None: A = T.alloc_buffer((128)) B = T.alloc_buffer((128)) for i in range(128): - with T.block("A0"): + with T.sblock("A0"): vi = T.axis.S(128, i) A[vi] = 1.0 for i in range(128): - with T.block("A1"): + with T.sblock("A1"): vi = T.axis.S(128, i) A[vi] = 2.0 for i in range(128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) B[vi] = A[vi] @@ -409,12 +409,12 @@ def func_with_block_predicate() -> None: A = T.alloc_buffer((120)) B = T.alloc_buffer((120)) for i, j in T.grid(16, 8): - with T.block("producer"): + with T.sblock("producer"): T.where(i * 8 + j < 120) ax = T.axis.S(120, i * 8 + j) A[ax] = 0.0 for i, j in T.grid(16, 8): - with T.block("consumer"): + with T.sblock("consumer"): T.where(i * 8 + j < 120) ax = T.axis.S(120, i * 8 + j) B[ax] = A[ax] + 1.0 @@ -424,16 +424,16 @@ def func_with_block_predicate() -> None: def inplace_func(data_io: T.Buffer((64), "int32")): data_1d = T.alloc_buffer([64], dtype="int32") for i0 in T.serial(64): - with T.block("copy_in"): + with T.sblock("copy_in"): v0 = T.axis.remap("S", [i0]) data_1d[v0] = data_io[v0] for i0 in T.serial(1): - with T.block("ext_call"): + with T.sblock("ext_call"): T.reads(data_1d[:64]) T.writes(data_1d[:64]) T.evaluate(T.call_extern("call_impl", data_1d.data, dtype="")) for i0 in T.serial(64): - with T.block("copy_out"): + with T.sblock("copy_out"): v0 = T.axis.remap("S", [i0]) data_io[v0] = data_1d[v0] @@ -441,7 +441,7 @@ def inplace_func(data_io: T.Buffer((64), "int32")): @T.prim_func def inplace_call(data_io: T.Buffer((64), "int32")): for i0 in T.serial(1): - with T.block("ext_call"): + with T.sblock("ext_call"): T.reads(data_io[:64]) T.writes(data_io[:64]) T.evaluate(T.call_extern("call_impl", data_io.data, dtype="")) @@ -454,34 +454,34 @@ def cache_read_nested_seq_target( A = T.alloc_buffer([128, 128], dtype="float32") A_global = T.alloc_buffer([128, 128], dtype="float32") for i, j in T.grid(128, 128): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) T.reads() T.writes(A[vi, vj]) A[vi, vj] = T.float32(2) for i, j in T.grid(8, 8): for x, y in T.grid(16, 16): - with T.block("B0"): + with T.sblock("B0"): vi = T.axis.spatial(128, i * 16 + x) vj = T.axis.spatial(128, j * 16 + y) T.reads() T.writes(B[vi, vj]) B[vi, vj] = T.float32(1) for x, y in T.grid(16, 16): - with T.block("B1"): + with T.sblock("B1"): vi = T.axis.spatial(128, i * 16 + x) vj = T.axis.spatial(128, j * 16 + y) T.reads(A[vi, vj], B[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] + B[vi, vj] for ax0, ax1 in T.grid(128, 128): - with T.block("A_global"): + with T.sblock("A_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v0, v1]) T.writes(A_global[v0, v1]) A_global[v0, v1] = A[v0, v1] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(C[vi, vj]) @@ -494,7 +494,7 @@ def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle): B = T.match_buffer(var_B, T.int64(1), dtype="int32") C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32") for ax0, ax1 in T.grid(T.int64(1), T.int64(512)): - with T.block("C"): + with T.sblock("C"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[B[v_ax0], v_ax1], B[v_ax0]) T.writes(C[v_ax0, v_ax1]) @@ -512,19 +512,19 @@ def cache_read_elementwise(a: T.handle, c: T.handle) -> None: A_global = T.alloc_buffer((128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("A_global"): + with T.sblock("A_global"): vi, vj = T.axis.remap("SS", [i, j]) A_global[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_global[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("B_local"): + with T.sblock("B_local"): vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = B[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B_local[vi, vj] + 1.0 @@ -537,30 +537,30 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: A_global = T.alloc_buffer((128, 128)) for i0, j0 in T.grid(8, 8): - with T.block("scope"): + with T.sblock("scope"): i, j = T.axis.remap("SS", [i0, j0]) A_local = T.alloc_buffer((16, 16), scope="local") for x, y in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) A[vi, vj] = 1.0 for x, y in T.grid(16, 16): - with T.block("A_local"): + with T.sblock("A_local"): vi = T.axis.S(16, x) vj = T.axis.S(16, y) A_local[vi, vj] = A[i * 16 + vi, j * 16 + vj] for x, y in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) B[vi, vj] = A_local[vi - i * 16, vj - j * 16] + 1.0 for i, j in T.grid(128, 128): - with T.block("A_global"): + with T.sblock("A_global"): vi, vj = T.axis.remap("SS", [i, j]) A_global[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A_global[vi, vj] * 2.0 @@ -574,17 +574,17 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) A_global = T.alloc_buffer((128, 128), dtype="float16") for i, j in T.grid(128, 128): - with T.block("A_global"): + with T.sblock("A_global"): vi, vj = T.axis.remap("SS", [i, j]) A_global[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block("load_store"): + with T.sblock("load_store"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(D[vi, vj]) D[vi, vj] = A_global[vi, vj] for i, j in T.grid(8, 8): - with T.block("opaque"): + with T.sblock("opaque"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -609,7 +609,7 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) ) ) for i, j in T.grid(8, 8): - with T.block("match_buffer"): + with T.sblock("match_buffer"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -663,20 +663,20 @@ def cache_read_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + j) A_global[vi] = A[vi] for j in T.grid(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + j) B[vi] = A_global[vi] + 1.0 for i in T.grid(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) C[vi] = A_global[vi] @@ -689,20 +689,20 @@ def cache_read_multi_consumer_target() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i) A_global[vi] = A[vi] for i in T.grid(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) C[vi] = A_global[vi] @@ -715,19 +715,19 @@ def continuous_cache_read(a: T.handle, c: T.handle) -> None: B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("B_shared"): + with T.sblock("B_shared"): vi, vj = T.axis.remap("SS", [i, j]) B_shared[vi, vj] = B[vi, vj] for i, j in T.grid(128, 128): - with T.block("B_local"): + with T.sblock("B_local"): vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = B_shared[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B_local[vi, vj] + 1.0 @@ -738,16 +738,16 @@ def block_predicate_cache_read() -> None: B = T.alloc_buffer([120], dtype="float32") A_shared = T.alloc_buffer([120], dtype="float32", scope="shared") for i, j in T.grid(16, 8): - with T.block("producer"): + with T.sblock("producer"): ax = T.axis.spatial(120, i * 8 + j) T.where(i * 8 + j < 120) A[ax] = T.float32(0) for ax0 in T.serial(120): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(120, ax0) A_shared[v0] = A[v0] for i, j in T.grid(16, 8): - with T.block("consumer"): + with T.sblock("consumer"): ax = T.axis.spatial(120, i * 8 + j) T.where(i * 8 + j < 120) B[ax] = A_shared[ax] + T.float32(1) @@ -760,19 +760,19 @@ def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None: B = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32") A_global = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32") for ax0, ax1 in T.grid(T.int64(128), T.int64(128)): - with T.block("A_global"): + with T.sblock("A_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v0, v1]) T.writes(A_global[v0, v1]) A_global[v0, v1] = A[v0, v1] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A_global[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A_global[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -784,24 +784,24 @@ def cache_read_inplace(data_io: T.Buffer(64, "int32")) -> None: data_1d = T.alloc_buffer([64], dtype="int32") data_io_local = T.alloc_buffer([64], dtype="int32", scope="local") for ax0 in T.serial(64): - with T.block("data_io_local"): + with T.sblock("data_io_local"): v0 = T.axis.spatial(64, ax0) T.reads(data_io[v0]) T.writes(data_io_local[v0]) data_io_local[v0] = data_io[v0] for i0 in T.serial(64): - with T.block("copy_in"): + with T.sblock("copy_in"): v0 = T.axis.spatial(64, i0) T.reads(data_io_local[v0]) T.writes(data_1d[v0]) data_1d[v0] = data_io_local[v0] for i0 in T.serial(1): - with T.block("ext_call"): + with T.sblock("ext_call"): T.reads(data_1d[0:64]) T.writes(data_1d[0:64]) T.evaluate(T.call_extern("call_impl", data_1d.data, dtype="")) for i0 in T.serial(64): - with T.block("copy_out"): + with T.sblock("copy_out"): v0 = T.axis.spatial(64, i0) T.reads(data_1d[v0]) T.writes(data_io[v0]) @@ -814,30 +814,30 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None: data_io_global = T.alloc_buffer([64], dtype="int32") data_io_global_1 = T.alloc_buffer([64], dtype="int32") for ax0 in T.serial(64): - with T.block("data_io_global"): + with T.sblock("data_io_global"): v0 = T.axis.spatial(64, ax0) T.reads(data_io[v0]) T.writes(data_io_global[v0]) data_io_global[v0] = data_io[v0] for i0 in T.serial(1): for ax0 in T.serial(64): - with T.block("data_io_local"): + with T.sblock("data_io_local"): v0 = T.axis.spatial(64, ax0) T.reads(data_io_global[v0]) T.writes(data_io_local[v0]) data_io_local[v0] = data_io_global[v0] - with T.block("ext_call"): + with T.sblock("ext_call"): T.reads(data_io_local[0:64]) T.writes(data_io_local[0:64]) T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype="")) for ax0 in T.serial(64): - with T.block("data_io_local"): + with T.sblock("data_io_local"): v0 = T.axis.spatial(64, ax0) T.reads(data_io_local[v0]) T.writes(data_io_global_1[v0]) data_io_global_1[v0] = data_io_local[v0] for ax0 in T.serial(64): - with T.block("data_io_global"): + with T.sblock("data_io_global"): v0 = T.axis.spatial(64, ax0) T.reads(data_io_global_1[v0]) T.writes(data_io[v0]) @@ -851,13 +851,13 @@ def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.h C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32") B_global = T.alloc_buffer((T.int64(1),), "int32") for ax0 in range(T.int64(1)): - with T.block("B_global"): + with T.sblock("B_global"): v0 = T.axis.spatial(T.int64(1), ax0) T.reads(B[v0]) T.writes(B_global[v0]) B_global[v0] = B[v0] for ax0, ax1 in T.grid(T.int64(1), T.int64(512)): - with T.block("C"): + with T.sblock("C"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0]) T.writes(C[v_ax0, v_ax1]) @@ -875,19 +875,19 @@ def cache_write_elementwise(a: T.handle, c: T.handle) -> None: B_global = T.alloc_buffer((128, 128), scope="local") C_local = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B_global"): + with T.sblock("B_global"): vi, vj = T.axis.remap("SS", [i, j]) B_global[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_global[vi, vj] for i, j in T.grid(128, 128): - with T.block("C_local"): + with T.sblock("C_local"): vi, vj = T.axis.remap("SS", [i, j]) C_local[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = C_local[vi, vj] @@ -900,36 +900,36 @@ def cache_write_under_scope(b: T.handle, c: T.handle) -> None: A_global = T.alloc_buffer((128, 128)) for i0, j0 in T.grid(8, 8): - with T.block("scope"): + with T.sblock("scope"): i, j = T.axis.remap("SS", [i0, j0]) A_local = T.alloc_buffer((16, 16), scope="local") B_global = T.alloc_buffer((16, 16)) for x, y in T.grid(16, 16): - with T.block("A_local"): + with T.sblock("A_local"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) A_local[vi - i * 16, vj - j * 16] = 1.0 for x, y in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(16, x) vj = T.axis.S(16, y) A_global[i * 16 + vi, j * 16 + vj] = A_local[vi, vj] for x, y in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) B_global[vi - i * 16, vj - j * 16] = A_global[vi, vj] + 1.0 for x, y in T.grid(16, 16): - with T.block("B_global"): + with T.sblock("B_global"): vi = T.axis.S(16, x) vj = T.axis.S(16, y) B[i * 16 + vi, j * 16 + vj] = B_global[vi, vj] for i, j in T.grid(128, 128): - with T.block("A_global"): + with T.sblock("A_global"): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = A_global[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 @@ -945,13 +945,13 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle C_global = T.alloc_buffer((128, 128), dtype="float16") for i, j in T.grid(128, 128): - with T.block("load_store"): + with T.sblock("load_store"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(D_global[vi, vj]) D_global[vi, vj] = A[vi, vj] for i, j in T.grid(8, 8): - with T.block("opaque"): + with T.sblock("opaque"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -976,7 +976,7 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle ) ) for i, j in T.grid(8, 8): - with T.block("match_buffer"): + with T.sblock("match_buffer"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) @@ -1022,15 +1022,15 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle ) for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = D_global[vi, vj] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_global[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = C_global[vi, vj] @@ -1043,20 +1043,20 @@ def cache_write_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block("A_global"): + with T.sblock("A_global"): vi = T.axis.S(128, i * 16 + j) A_global[vi] = 1.0 for j in T.grid(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(128, i * 16 + j) A[vi] = A_global[vi] for j in T.grid(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) C[vi] = A[vi] @@ -1069,19 +1069,19 @@ def cache_write_multi_consumer_B_consume_cache(): A_global = T.alloc_buffer([128], dtype="float32") for i in T.serial(8): for j in T.serial(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.spatial(128, i * 16 + j) A_global[vi] = 1.0 for j in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i * 16 + j) B[vi] = A_global[vi] + 1.0 for ax0 in T.serial(128): - with T.block("A_global"): + with T.sblock("A_global"): v0 = T.axis.spatial(128, ax0) A[v0] = A_global[v0] for i in T.serial(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) C[vi] = A[vi] @@ -1094,19 +1094,19 @@ def cache_write_multi_consumer_C_consume_cache(): A_global = T.alloc_buffer([128], dtype="float32") for i in T.serial(8): for j in T.serial(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.spatial(128, i * 16 + j) A_global[vi] = T.float32(1) for ax0 in T.serial(16): - with T.block("A_global"): + with T.sblock("A_global"): v0 = T.axis.spatial(128, i * 16 + ax0) A[v0] = A_global[v0] for j in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i * 16 + j) B[vi] = A[vi] + T.float32(1) for i in T.serial(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) C[vi] = A_global[vi] @@ -1119,19 +1119,19 @@ def cache_write_multi_consumer_all_consume_cache(): A_global = T.alloc_buffer([128], dtype="float32") for i in T.serial(8): for j in T.serial(16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.spatial(128, i * 16 + j) A_global[vi] = T.float32(1) for j in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i * 16 + j) B[vi] = A_global[vi] + T.float32(1) for i in T.serial(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i) C[vi] = A_global[vi] for ax0 in T.serial(128): - with T.block("A_global"): + with T.sblock("A_global"): v0 = T.axis.spatial(128, ax0) A[v0] = A_global[v0] @@ -1144,19 +1144,19 @@ def continuous_cache_write(a: T.handle, c: T.handle) -> None: B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_shared[vi, vj] = B_local[vi, vj] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_shared[vi, vj] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -1167,16 +1167,16 @@ def block_predicate_cache_write_intermediate_buf() -> None: B = T.alloc_buffer([120], dtype="float32") A_shared = T.alloc_buffer([120], dtype="float32", scope="shared") for i, j in T.grid(16, 8): - with T.block("producer"): + with T.sblock("producer"): ax = T.axis.spatial(120, i * 8 + j) T.where(i * 8 + j < 120) A_shared[ax] = T.float32(0) for ax0 in T.serial(120): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(120, ax0) A[v0] = A_shared[v0] for i, j in T.grid(16, 8): - with T.block("consumer"): + with T.sblock("consumer"): ax = T.axis.spatial(120, i * 8 + j) T.where(i * 8 + j < 120) B[ax] = A[ax] + 1.0 @@ -1188,17 +1188,17 @@ def block_predicate_cache_write_output_buf() -> None: B = T.alloc_buffer([120], dtype="float32") B_shared = T.alloc_buffer([120], dtype="float32", scope="shared") for i, j in T.grid(16, 8): - with T.block("producer"): + with T.sblock("producer"): ax = T.axis.spatial(120, i * 8 + j) T.where(i * 8 + j < 120) A[ax] = T.float32(0) for i, j in T.grid(16, 8): - with T.block("consumer"): + with T.sblock("consumer"): ax = T.axis.spatial(120, i * 8 + j) T.where(i * 8 + j < 120) B_shared[ax] = A[ax] + T.float32(1) for ax0 in T.serial(120): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial(120, ax0) B[v0] = B_shared[v0] @@ -1209,7 +1209,7 @@ def symbolic_matmul_blocked(var_A: T.handle, var_B: T.handle, var_C: T.handle, n B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): - with T.block("matmul_o"): + with T.sblock("matmul_o"): v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads( A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], @@ -1217,7 +1217,7 @@ def symbolic_matmul_blocked(var_A: T.handle, var_B: T.handle, var_C: T.handle, n ) T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) for i0_1, i1_1, k in T.grid(32, 32, 4): - with T.block("matmul"): + with T.sblock("matmul"): v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) @@ -1237,7 +1237,7 @@ def symbolic_matmul_blocked_cache_read( B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): - with T.block("matmul_o"): + with T.sblock("matmul_o"): v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads( A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], @@ -1246,14 +1246,14 @@ def symbolic_matmul_blocked_cache_read( T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) A_shared = T.alloc_buffer((32, 4), scope="shared") for ax0, ax1 in T.grid(32, 4): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(32, ax0) v1 = T.axis.spatial(4, ax1) T.reads(A[v_i0_o * 32 + v0, v1]) T.writes(A_shared[v0, v1]) A_shared[v0, v1] = A[v_i0_o * 32 + v0, v1] for i0_1, i1_1, k in T.grid(32, 32, 4): - with T.block("matmul"): + with T.sblock("matmul"): v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) T.reads(A_shared[v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) @@ -1273,7 +1273,7 @@ def symbolic_matmul_blocked_cache_write( B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): - with T.block("matmul_o"): + with T.sblock("matmul_o"): v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads( A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], @@ -1282,7 +1282,7 @@ def symbolic_matmul_blocked_cache_write( T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) C_pad_local = T.alloc_buffer((32, 32), scope="local") for i0_1, i1_1, k in T.grid(32, 32, 4): - with T.block("matmul"): + with T.sblock("matmul"): v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) T.writes(C_pad_local[v_i0_i, v_i1_i]) @@ -1293,7 +1293,7 @@ def symbolic_matmul_blocked_cache_write( + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] ) for ax0, ax1 in T.grid(32, 32): - with T.block("C_pad_local"): + with T.sblock("C_pad_local"): v0 = T.axis.spatial(32, ax0) v1 = T.axis.spatial(32, ax1) T.reads(C_pad_local[v0, v1]) @@ -1308,26 +1308,26 @@ def symbolic_matmul_blocked_cache_write( def test_cache_read_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_b = sch.get_sblock("B") + block_c = sch.get_sblock("C") if use_block_name: cached_a = sch.cache_read("B", "A", "global") cached_b = sch.cache_read("C", "B", "local") else: cached_a = sch.cache_read(block_b, 0, "global") cached_b = sch.cache_read(block_c, 0, "local") - assert sch.get(cached_a) == sch.get(sch.get_block("A_global")) - assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) - assert sch.get(block_b) == sch.get(sch.get_block("B")) - assert sch.get(block_c) == sch.get(sch.get_block("C")) + assert sch.get(cached_a) == sch.get(sch.get_sblock("A_global")) + assert sch.get(cached_b) == sch.get(sch.get_sblock("B_local")) + assert sch.get(block_b) == sch.get(sch.get_sblock("B")) + assert sch.get(block_c) == sch.get(sch.get_sblock("C")) assert_structural_equal_ignore_global_symbol(cache_read_elementwise, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) def test_cache_read_under_scope(use_block_name): sch = tir.Schedule(access_under_scope, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = "C" if use_block_name else sch.get_block("C") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_read(block_b, 0, "local") sch.cache_read(block_c, 0, "global") assert_structural_equal_ignore_global_symbol(cache_read_under_scope, sch.mod["main"]) @@ -1336,7 +1336,7 @@ def test_cache_read_under_scope(use_block_name): def test_cache_read_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") - block = "load_store" if use_block_name else sch.get_block("load_store") + block = "load_store" if use_block_name else sch.get_sblock("load_store") sch.cache_read(block, 0, "global") assert_structural_equal_ignore_global_symbol(cache_read_opaque_access, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) @@ -1344,23 +1344,23 @@ def test_cache_read_opaque_access(use_block_name): def test_cache_read_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") sch.cache_read(block_b, 0, "global") assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Test that specific consumer block targeting works. sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = "C" if use_block_name else sch.get_block("C") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c]) assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer_target, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) # Also test setting multiple consumers yields same result as unspecified. sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = "C" if use_block_name else sch.get_block("C") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_read(block_b, 0, "global", consumer_blocks=[block_b, block_c]) assert_structural_equal_ignore_global_symbol(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) @@ -1368,7 +1368,7 @@ def test_cache_read_location(use_block_name): def test_continuous_cache_read(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_read(block_c, 0, "shared") sch.cache_read(block_c, 0, "local") assert_structural_equal_ignore_global_symbol(continuous_cache_read, sch.mod["main"]) @@ -1377,7 +1377,7 @@ def test_continuous_cache_read(use_block_name): def test_cache_read_with_block_predicate(use_block_name): sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = "consumer" if use_block_name else sch.get_block("consumer") + block = "consumer" if use_block_name else sch.get_sblock("consumer") sch.cache_read(block, 0, "shared") assert_structural_equal_ignore_global_symbol(block_predicate_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) @@ -1385,7 +1385,7 @@ def test_cache_read_with_block_predicate(use_block_name): def test_cache_read_non_int32_shape(use_block_name): sch = tir.Schedule(elementwise_shape_int64, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") sch.cache_read(block_b, 0, "global") assert_structural_equal_ignore_global_symbol(cache_read_shape_int64, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64) @@ -1393,7 +1393,7 @@ def test_cache_read_non_int32_shape(use_block_name): def test_cache_read_nested_buffer_access(use_block_name): sch = tir.Schedule(nested_buffer_access, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_read(block_c, 1, "global") assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=nested_buffer_access) @@ -1401,21 +1401,21 @@ def test_cache_read_nested_buffer_access(use_block_name): def test_cache_read_fail_multi_producer(use_block_name): sch = tir.Schedule(func_multi_producer, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 0, "global") def test_cache_read_fail_index_out_of_bound(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 1, "global") def test_cache_read_fail_invalid_storage_scope(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 0, "test_scope") @@ -1426,7 +1426,7 @@ def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")): B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) B_buf = T.decl_buffer((8), dtype="float32", data=B) for i in range(8): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(8, i) C[vi] = A[vi] + B_buf[vi] @@ -1437,20 +1437,20 @@ def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")): B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) B_buf = T.decl_buffer((8), data=B) for ax0 in range(8): - with T.block("A_global"): + with T.sblock("A_global"): v0 = T.axis.spatial(8, ax0) A_global[v0] = A[v0] for ax0 in range(8): - with T.block("B_buf_global"): + with T.sblock("B_buf_global"): v0 = T.axis.spatial(8, ax0) B_buf_global[v0] = B_buf[v0] for i in range(8): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(8, i) C[vi] = A_global[vi] + B_buf_global[vi] sch = tir.Schedule(before) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") sch.cache_read(block_c, 1, "global") sch.cache_read(block_c, 0, "global") @@ -1462,7 +1462,7 @@ def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")): def test_inplace_cache_read(): sch = tvm.tir.Schedule(inplace_func, debug_mask="all") - block = sch.get_block("copy_in") + block = sch.get_sblock("copy_in") sch.cache_read(block, 0, "local", [block]) assert_structural_equal_ignore_global_symbol(cache_read_inplace, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=inplace_func) @@ -1472,7 +1472,7 @@ def test_cache_inplace(): # cache_inplace could introduce WAR, which is expected but stage pipeline property changes debug_mask = tvm.tir.schedule.state.ScheduleDebugMask.VERIFY_SREF_TREE sch = tvm.tir.Schedule(inplace_call, debug_mask=debug_mask) - block = sch.get_block("ext_call") + block = sch.get_sblock("ext_call") blocks = sch.cache_inplace(block, 0, "local") block = sch.cache_read(blocks[0], 0, "global", [blocks[0]]) block = sch.cache_write(blocks[1], 0, "global") @@ -1483,7 +1483,7 @@ def test_cache_inplace(): def test_cache_read_nested_seq(use_block_name): sch = tir.Schedule(func_nested_seq, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_read(block_c, 0, "global", consumer_blocks=[block_c]) assert_structural_equal_ignore_global_symbol(cache_read_nested_seq_target, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_nested_seq) @@ -1494,23 +1494,23 @@ def test_cache_read_nested_seq(use_block_name): def test_cache_write_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_b = sch.get_sblock("B") + block_c = sch.get_sblock("C") cached_b = sch.cache_write("B" if use_block_name else block_b, 0, "local") cached_c = sch.cache_write("C" if use_block_name else block_c, 0, "global") - assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) - assert sch.get(cached_c) == sch.get(sch.get_block("C_global")) - assert sch.get(block_b) == sch.get(sch.get_block("B")) - assert sch.get(block_c) == sch.get(sch.get_block("C")) + assert sch.get(cached_b) == sch.get(sch.get_sblock("B_local")) + assert sch.get(cached_c) == sch.get(sch.get_sblock("C_global")) + assert sch.get(block_b) == sch.get(sch.get_sblock("B")) + assert sch.get(block_c) == sch.get(sch.get_sblock("C")) assert_structural_equal_ignore_global_symbol(cache_write_elementwise, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) def test_cache_write_under_scope(use_block_name): sch = tir.Schedule(access_under_scope, debug_mask="all") - block_a = "A" if use_block_name else sch.get_block("A") - block_b = "B" if use_block_name else sch.get_block("B") - block_scope = sch.get_block("scope") + block_a = "A" if use_block_name else sch.get_sblock("A") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_scope = sch.get_sblock("scope") sch.cache_write(block_a, 0, "local") sch.cache_write(block_b, 0, "global") sch.cache_write(block_scope, 0, "global") @@ -1520,9 +1520,9 @@ def test_cache_write_under_scope(use_block_name): def test_cache_write_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") - block_store = "load_store" if use_block_name else sch.get_block("load_store") - block_opaque = "opaque" if use_block_name else sch.get_block("opaque") - block_match_buffer = "match_buffer" if use_block_name else sch.get_block("match_buffer") + block_store = "load_store" if use_block_name else sch.get_sblock("load_store") + block_opaque = "opaque" if use_block_name else sch.get_sblock("opaque") + block_match_buffer = "match_buffer" if use_block_name else sch.get_sblock("match_buffer") sch.cache_write(block_store, 0, "global") sch.cache_write(block_opaque, 0, "global") sch.cache_write(block_match_buffer, 0, "global") @@ -1532,7 +1532,7 @@ def test_cache_write_opaque_access(use_block_name): def test_cache_write_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_a = "A" if use_block_name else sch.get_block("A") + block_a = "A" if use_block_name else sch.get_sblock("A") sch.cache_write(block_a, 0, "global") assert_structural_equal_ignore_global_symbol(cache_write_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) @@ -1540,8 +1540,8 @@ def test_cache_write_location(use_block_name): # Test that specific consumer block targeting works. # B read cache buffer and C read original output buffer sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_a = "A" if use_block_name else sch.get_block("A") - block_b = "B" if use_block_name else sch.get_block("B") + block_a = "A" if use_block_name else sch.get_sblock("A") + block_b = "B" if use_block_name else sch.get_sblock("B") sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b]) assert_structural_equal_ignore_global_symbol( cache_write_multi_consumer_B_consume_cache, sch.mod["main"] @@ -1551,8 +1551,8 @@ def test_cache_write_location(use_block_name): # Test that specific consumer block targeting works. # B read original output buffer and C read cache buffer sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_a = "A" if use_block_name else sch.get_block("A") - block_c = "C" if use_block_name else sch.get_block("C") + block_a = "A" if use_block_name else sch.get_sblock("A") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_write(block_a, 0, "global", consumer_blocks=[block_c]) assert_structural_equal_ignore_global_symbol( cache_write_multi_consumer_C_consume_cache, sch.mod["main"] @@ -1562,9 +1562,9 @@ def test_cache_write_location(use_block_name): # Test that specific consumer block targeting works. # B and C read cache buffer sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_a = "A" if use_block_name else sch.get_block("A") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = "C" if use_block_name else sch.get_block("C") + block_a = "A" if use_block_name else sch.get_sblock("A") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b, block_c]) assert_structural_equal_ignore_global_symbol( cache_write_multi_consumer_all_consume_cache, sch.mod["main"] @@ -1574,7 +1574,7 @@ def test_cache_write_location(use_block_name): def test_continuous_cache_write(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") sch.cache_write(block_b, 0, "shared") sch.cache_write(block_b, 0, "local") assert_structural_equal_ignore_global_symbol(continuous_cache_write, sch.mod["main"]) @@ -1584,7 +1584,7 @@ def test_continuous_cache_write(use_block_name): def test_cache_write_with_block_predicate(use_block_name): # cache write for intermediate buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = "producer" if use_block_name else sch.get_block("producer") + block = "producer" if use_block_name else sch.get_sblock("producer") sch.cache_write(block, 0, "shared") assert_structural_equal_ignore_global_symbol( block_predicate_cache_write_intermediate_buf, sch.mod["main"] @@ -1592,7 +1592,7 @@ def test_cache_write_with_block_predicate(use_block_name): verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) # cache write for external buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = "consumer" if use_block_name else sch.get_block("consumer") + block = "consumer" if use_block_name else sch.get_sblock("consumer") sch.cache_write(block, 0, "shared") assert_structural_equal_ignore_global_symbol( block_predicate_cache_write_output_buf, sch.mod["main"] @@ -1602,8 +1602,8 @@ def test_cache_write_with_block_predicate(use_block_name): def test_cache_write_fail_multi_producer(use_block_name): sch = tir.Schedule(func_multi_producer, debug_mask="all") - block_a0 = "A0" if use_block_name else sch.get_block("A0") - block_a1 = "A1" if use_block_name else sch.get_block("A1") + block_a0 = "A0" if use_block_name else sch.get_sblock("A0") + block_a1 = "A1" if use_block_name else sch.get_sblock("A1") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_a0, 0, "global") with pytest.raises(tvm.tir.ScheduleError): @@ -1612,14 +1612,14 @@ def test_cache_write_fail_multi_producer(use_block_name): def test_cache_write_fail_index_out_of_bound(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_b, 1, "global") def test_cache_write_fail_invalid_storage_scope(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_b, 0, "test_scope") @@ -1641,13 +1641,13 @@ def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16" const2_buf = apply_decl_buffer([8], dtype="float32", data=const2) for i, j in T.grid(128, 128): for x in range(8): - with T.block("B"): + with T.sblock("B"): vi, vj, vx = T.axis.remap("SSS", [i, j, x]) T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -1663,33 +1663,33 @@ def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float1 const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) const2_buf = apply_decl_buffer([8], dtype="float32", data=const2) for ax0, ax1 in T.grid(128, 128): - with T.block("A_global"): + with T.sblock("A_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(A[v0, v1]) T.writes(A_global[v0, v1]) A_global[v0, v1] = A[v0, v1] for i, j, x in T.grid(128, 128, 8): - with T.block("B"): + with T.sblock("B"): vi, vj, vx = T.axis.remap("SSS", [i, j, x]) T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx]) T.writes(B[vi, vj]) B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx] for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C_global[vi, vj]) C_global[vi, vj] = B[vi, vj] + T.float32(1) for ax0, ax1 in T.grid(128, 128): - with T.block("C_global"): + with T.sblock("C_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(C_global[v0, v1]) T.writes(C[v0, v1]) C[v0, v1] = C_global[v0, v1] sch = tir.Schedule(before) - block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_b = sch.get_sblock("B") + block_c = sch.get_sblock("C") sch.cache_read(block_b, 0, "global") sch.cache_write(block_c, 0, "global") @@ -1765,7 +1765,7 @@ def test_reindex_cache_write_fail_not_single_point(): def test_symbolic_matmul_blocked_cache_read(use_block_name): sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all") - block = "matmul" if use_block_name else sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_sblock("matmul") sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared") assert_structural_equal_ignore_global_symbol( sch.mod["main"], symbolic_matmul_blocked_cache_read @@ -1775,7 +1775,7 @@ def test_symbolic_matmul_blocked_cache_read(use_block_name): def test_symbolic_matmul_blocked_cache_write(use_block_name): sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all") - block = "matmul" if use_block_name else sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_sblock("matmul") sch.cache_write(block=block, write_buffer_index=0, storage_scope="local") assert_structural_equal_ignore_global_symbol( sch.mod["main"], symbolic_matmul_blocked_cache_write diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_at.py b/tests/python/tir-schedule/test_tir_schedule_compute_at.py index aa03dc2ba0e5..54712b5c11d2 100644 --- a/tests/python/tir-schedule/test_tir_schedule_compute_at.py +++ b/tests/python/tir-schedule/test_tir_schedule_compute_at.py @@ -34,11 +34,11 @@ def two_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -50,12 +50,12 @@ def two_elementwise_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for ax0, ax1 in T.grid(1, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i + ax0) vj = T.axis.S(128, ax1) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -66,11 +66,11 @@ def blockized_1(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(8, 8): - with T.block("C_outer"): + with T.sblock("C_outer"): vi_o, vj_o = T.axis.remap("SS", [i, j]) T.reads([B[ vi_o * 16 : vi_o * 16 + 16, @@ -81,7 +81,7 @@ def blockized_1(a: T.handle, c: T.handle) -> None: vj_o * 16 : vj_o * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block("C_inner"): + with T.sblock("C_inner"): vi = T.axis.S(128, vi_o * 16 + i_i) vj = T.axis.S(128, vj_o * 16 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @@ -94,11 +94,11 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i0_0, i1_0 in T.grid(8, 8): for ax0, ax1 in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i0_0 * 16 + ax0) vj = T.axis.S(128, i1_0 * 16 + ax1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block("C_outer"): + with T.sblock("C_outer"): vi_o, vj_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads([B[ vi_o * 16 : vi_o * 16 + 16, @@ -109,7 +109,7 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: vj_o * 16 : vj_o * 16 + 16 ]]) for i0_1, i1_1 in T.grid(16, 16): - with T.block("C_inner"): + with T.sblock("C_inner"): vi = T.axis.S(128, vi_o * 16 + i0_1) vj = T.axis.S(128, vj_o * 16 + i1_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -121,7 +121,7 @@ def blockized_2(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block("B_outer"): + with T.sblock("B_outer"): vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, @@ -132,12 +132,12 @@ def blockized_2(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block("B_inner"): + with T.sblock("B_inner"): vi = T.axis.S(128, vio * 16 + i_i) vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_o, j_o, i_i, j_i in T.grid(4, 4, 32, 32): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i_o * 32 + i_i) vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @@ -149,7 +149,7 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block("B_outer"): + with T.sblock("B_outer"): vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, @@ -160,12 +160,12 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block("B_inner"): + with T.sblock("B_inner"): vi = T.axis.S(128, vio * 16 + i_i) vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for ax0, ax1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i_o * 16 + ax0) vj = T.axis.S(128, j_o * 16 + ax1) T.reads([B[vi, vj]]) @@ -180,7 +180,7 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(4, 4): for ax0, ax1 in T.grid(2, 2): - with T.block("blockized_B"): + with T.sblock("blockized_B"): vio = T.axis.S(8, i_o * 2 + ax0) vjo = T.axis.S(8, j_o * 2 + ax1) T.reads([A[ @@ -192,12 +192,12 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16, ]]) for i_i, j_i in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, vio * 16 + i_i) vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_i, j_i in T.grid(32, 32): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i_o * 32 + i_i) vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @@ -213,23 +213,23 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") for i, j in T.grid(2048, 2048): - with T.block("A_shared"): + with T.sblock("A_shared"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared"): + with T.sblock("B_shared"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared[v0, v1] = B[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared_local[v0, v1] = B_shared[v0, v1] for i, j, k in T.grid(2048, 2048, 2048): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C_local[vi, vj] = 0.0 @@ -241,7 +241,7 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0_4 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) v1_4 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0_4, v1_4] = C_local[v0_4, v1_4] @@ -258,19 +258,19 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") for i, j in T.grid(2048, 2048): - with T.block("A_shared"): + with T.sblock("A_shared"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared"): + with T.sblock("B_shared"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared[v0, v1] = B[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): @@ -280,7 +280,7 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j, k in T.grid(4, 4, 2048): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k) @@ -288,7 +288,7 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -305,19 +305,19 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") for i, j in T.grid(2048, 2048): - with T.block("A_shared"): + with T.sblock("A_shared"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared"): + with T.sblock("B_shared"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared[v0, v1] = B[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): @@ -329,7 +329,7 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for _, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k_0 * 8 + k_1) @@ -337,7 +337,7 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -354,15 +354,15 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") for i, j in T.grid(2048, 2048): - with T.block("A_shared"): + with T.sblock("A_shared"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared"): + with T.sblock("B_shared"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared[v0, v1] = B[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): @@ -374,12 +374,12 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0 = T.axis.S(2048, k_0 * 8 + k_1 + i) v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k_0 * 8 + k_1) @@ -387,7 +387,7 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -404,11 +404,11 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") for i, j in T.grid(2048, 2048): - with T.block("A_shared"): + with T.sblock("A_shared"): v0, v1 = T.axis.remap("SS", [i, j]) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(2048, 2048): - with T.block("B_shared"): + with T.sblock("B_shared"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): @@ -420,17 +420,17 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k0 in T.serial(0, 256): for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0 = T.axis.S(2048, k0 * 8 + k1 + i) v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0 = T.axis.S(2048, k0 * 8 + k1 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -438,7 +438,7 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -455,7 +455,7 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") for i, j in T.grid(2048, 2048): - with T.block("B_shared"): + with T.sblock("B_shared"): v0, v1 = T.axis.remap("SS", [i, j]) B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): @@ -466,23 +466,23 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.S(2048, k0 * 8 + i) v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0 = T.axis.S(2048, k0 * 8 + k1 + i) v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0 = T.axis.S(2048, k0 * 8 + k1 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -490,7 +490,7 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -514,28 +514,28 @@ def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.S(2048, k0 * 8 + i) v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(8, 64): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.S(2048, k0 * 8 + i) v1 = T.axis.S(2048, bx * 64 + j) B_shared[v0, v1] = B[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block("A_shared_local"): + with T.sblock("A_shared_local"): v0 = T.axis.S(2048, k0 * 8 + k1 + i) v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block("B_shared_local"): + with T.sblock("B_shared_local"): v0 = T.axis.S(2048, k0 * 8 + k1 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -543,7 +543,7 @@ def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -555,12 +555,12 @@ def tiled(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i_0 * 16 + i_1) vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -572,12 +572,12 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i_0 * 16 + i_1) vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 for j_1 in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i_0 * 16 + i_1) vj = T.axis.S(128, j_0 * 16 + j_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -589,12 +589,12 @@ def tiled_trivial_binding(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([1, 128, 128], "float32") C = T.match_buffer(c, [1, 128, 128], "float32") for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i_0 * 16 + i_1) vj = T.axis.S(128, j_0 * 16 + j_1) B[0, vi, vj] = A[0, vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[0, vi, vj] = B[0, vi, vj] + 1.0 @@ -606,12 +606,12 @@ def tiled_trivial_binding_after_reverse_compute_at(a: T.handle, c: T.handle) -> C = T.match_buffer(c, [1, 128, 128], "float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i_0 * 16 + i_1) vj = T.axis.S(128, j_0 * 16 + j_1) B[0, vi, vj] = A[0, vi, vj] * 2.0 for j_1 in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i_0 * 16 + i_1) vj = T.axis.S(128, j_0 * 16 + j_1) C[0, vi, vj] = B[0, vi, vj] + 1.0 @@ -625,14 +625,14 @@ def factorized(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block("B_rf"): + with T.sblock("B_rf"): vi = T.axis.S(16, i_o * 4 + i_i) vj, vk = T.axis.remap("SR", [j, k]) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for i, k in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 @@ -647,7 +647,7 @@ def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block("B_rf"): + with T.sblock("B_rf"): vi = T.axis.S(16, i_o * 4 + i_i) vj = T.axis.S(16, j) vk = T.axis.R(16, k) @@ -655,7 +655,7 @@ def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for k in T.serial(0, 4): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(16, j) vk = T.axis.R(16, i_o * 4 + k) with T.init(): @@ -669,14 +669,14 @@ def not_all_compact_data_flow(a: T.handle, c: T.handle): B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] for i, j in T.grid(128, 64): - with T.block("C_1"): + with T.sblock("C_1"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj * 2] = B[vi, vj * 2] + 1.0 - with T.block("C_2"): + with T.sblock("C_2"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj * 2 + 1] = B[vi, vj * 2 + 1] * 2.0 @@ -688,14 +688,14 @@ def not_all_compact_data_flow_after_compute_at(a: T.handle, c: T.handle): C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 64): for t in range(2): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) vj = T.axis.S(128, j * 2 + t) B[vi, vj] = A[vi, vj] - with T.block("C_1"): + with T.sblock("C_1"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj * 2] = B[vi, vj * 2] + 1.0 - with T.block("C_2"): + with T.sblock("C_2"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj * 2 + 1] = B[vi, vj * 2 + 1] * 2.0 @@ -707,17 +707,17 @@ def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for j in range(0, 64): - with T.block("B_0"): + with T.sblock("B_0"): vi = T.axis.S(128, i) vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 64): - with T.block("B_1"): + with T.sblock("B_1"): vi = T.axis.S(128, i) vj = T.axis.S(128, j + 64) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -729,15 +729,15 @@ def fail_all_consumers_under_loop(a: T.handle, c: T.handle, d: T.handle) -> None C = T.match_buffer(c, (128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + 1.0 @@ -749,15 +749,15 @@ def fail_all_producers_under_loop(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer((128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + C[vi, vj] @@ -768,11 +768,11 @@ def read_out_of_bound(a: T.handle, c:T.handle) -> None: B = T.alloc_buffer([16], "float32") C = T.match_buffer(c, [16], "float32") for i in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): v = T.axis.S(16, i) B[v] = A[v] for j in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): v = T.axis.S(16, j) T.reads(B[v : v + 2]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -785,11 +785,11 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16], "float32") for j in T.serial(0, 16): for i in T.serial(0, 2): - with T.block("B"): + with T.sblock("B"): v = T.axis.S(16, j + i) T.where(j + i < 16) B[v] = A[v] - with T.block("C"): + with T.sblock("C"): v = T.axis.S(16, j) T.reads([B[v : v + 2]]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -799,13 +799,13 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: def multi_reduction(A: T.Buffer((16, 16), "float32"), C: T.Buffer((), "float32")): B = T.alloc_buffer((16, ), dtype="float32") for i, k in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] += A[vi, vk] for k in T.grid(16): - with T.block("C"): + with T.sblock("C"): vk = T.axis.remap("R", [k]) with T.init(): C[()] = 0.0 @@ -820,12 +820,12 @@ def multi_reduction_after_compute_at( B = T.alloc_buffer((16, ), dtype="float32") for k in T.grid(16): for kk in T.grid(16): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [k, kk]) with T.init(): B[vi] = 0.0 B[vi] += A[vi, vk] - with T.block("C"): + with T.sblock("C"): vk = T.axis.remap("R", [k]) with T.init(): C[()] = 0.0 @@ -838,11 +838,11 @@ def tiled_pooling_read_cache(a: T.handle, b: T.handle) -> None: Y = T.match_buffer(b, [224, 224], dtype="float32") cache = T.alloc_buffer([224, 224], dtype="float32") for hh, ww in T.grid(224, 224): - with T.block("cache"): + with T.sblock("cache"): h, w = T.axis.remap("SS", [hh, ww]) cache[h, w] = X[h, w] for hh_0, ww_0, hh_1, ww_1, khh, kww in T.grid(28, 28, 8, 8, 3, 3): - with T.block("compute"): + with T.sblock("compute"): h = T.axis.spatial(224, hh_0 * 8 + hh_1) w = T.axis.spatial(224, ww_0 * 8 + ww_1) kh, kw = T.axis.remap("RR", [khh, kww]) @@ -862,13 +862,13 @@ def tiled_pooling_read_cache_after_compute_at(a: T.handle, b: T.handle) -> None: cache = T.alloc_buffer([224, 224], dtype="float32") for hh_0, ww_0 in T.grid(28, 28): for ax0, ax1 in T.grid(10, 10): - with T.block("cache"): + with T.sblock("cache"): h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0) w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1) T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225) cache[h, w] = X[h, w] for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): - with T.block("compute"): + with T.sblock("compute"): h = T.axis.spatial(224, hh_0 * 8 + hh_1) w = T.axis.spatial(224, ww_0 * 8 + ww_1) kh, kw = T.axis.remap("RR", [khh, kww]) @@ -887,11 +887,11 @@ def non_uniform_tiled_conv(x: T.Buffer((1, 3, 100, 100), "float32"), y: T.Buffer((1, 16, 98, 98), "float32")) -> None: x_global = T.alloc_buffer([1, 3, 100, 100], dtype="float32") for ax0, ax1, ax2, ax3 in T.grid(1, 3, 100, 100): - with T.block("cache"): + with T.sblock("cache"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] for h_o, w_o, n, c_o, h_i, w_i, c_i, kh, kw in T.grid(7, 7, 1, 16, 15, 15, 3, 3, 3): - with T.block("compute"): + with T.sblock("compute"): nn = T.axis.spatial(1, 0) cc = T.axis.spatial(16, c_o) hh = T.axis.spatial(98, h_o * 15 + h_i) @@ -910,7 +910,7 @@ def non_uniform_tiled_conv_after_compute_at(x: T.Buffer((1, 3, 100, 100), "float x_global = T.alloc_buffer([1, 3, 100, 100], dtype="float32") for h_o, w_o in T.grid(7, 7): for ax0, ax1, ax2 in T.grid(3, 17, 17): - with T.block("cache"): + with T.sblock("cache"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(3, ax0) v2 = T.axis.spatial(100, h_o * 15 + ax1) @@ -918,7 +918,7 @@ def non_uniform_tiled_conv_after_compute_at(x: T.Buffer((1, 3, 100, 100), "float T.where(h_o * 15 + ax1 < 100 and w_o * 15 + ax2 < 100) x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] for n, c_o, h_i, w_i, c_i, kh, kw in T.grid(1, 16, 15, 15, 3, 3, 3): - with T.block("compute"): + with T.sblock("compute"): nn = T.axis.spatial(1, 0) cc = T.axis.spatial(16, c_o) hh = T.axis.spatial(98, h_o * 15 + h_i) @@ -937,15 +937,15 @@ def concat_two_elemwise(x: T.Buffer((16,), "float32"), T_add_1 = T.alloc_buffer([16], dtype="float32") T_add_2 = T.alloc_buffer([8], dtype="float32") for i in T.serial(16): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax = T.axis.spatial(16, i) T_add_1[ax] = x[ax] + T.float32(1) for i in T.serial(8): - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax = T.axis.spatial(8, i) T_add_2[ax] = y[ax] + T.float32(2) for i in T.serial(24): - with T.block("T_concat"): + with T.sblock("T_concat"): ax = T.axis.spatial(24, i) T_concat[ax] = T.if_then_else(16 <= ax, T_add_2[ax - 16], T_add_1[ax], dtype="float32") @@ -956,15 +956,15 @@ def concat_two_elemwise_after_compute_at(x: T.Buffer((16,), "float32"), T_add_1 = T.alloc_buffer([16], dtype="float32") T_add_2 = T.alloc_buffer([8], dtype="float32") for i in T.serial(24): - with T.block("T_add_1"): + with T.sblock("T_add_1"): ax = T.axis.spatial(16, i) T.where(i < 16) T_add_1[ax] = x[ax] + T.float32(1) - with T.block("T_add_2"): + with T.sblock("T_add_2"): ax = T.axis.spatial(8, i - 16) T.where(16 <= i) T_add_2[ax] = y[ax] + T.float32(2) - with T.block("T_concat"): + with T.sblock("T_concat"): ax = T.axis.spatial(24, i) T_concat[ax] = T.if_then_else(16 <= ax, T_add_2[ax - 16], T_add_1[ax], dtype="float32") @@ -974,11 +974,11 @@ def floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None: Y = T.match_buffer(b, [256]) temp = T.alloc_buffer([16, 16]) for i, j in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): v_i, v_j = T.axis.remap("SS", [i, j]) temp[v_i, v_j] = X[v_j, v_i] + 1.0 for i in T.serial(0, 256): - with T.block("B"): + with T.sblock("B"): v_i = T.axis.remap("S", [i]) Y[v_i] = temp[v_i // 16, v_i % 16] @@ -989,11 +989,11 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han temp = T.alloc_buffer([16, 16], dtype="float32") for i in T.serial(0, 16): for j in T.serial(0, 16): - with T.block("A"): + with T.sblock("A"): v_i, v_j = T.axis.remap("SS", [i, j]) temp[v_i, v_j] = X[v_j, v_i] + T.float32(1) for ax0 in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): v_i = T.axis.spatial(256, i * 16 + ax0) Y[v_i] = temp[v_i // 16, v_i % 16] @@ -1002,16 +1002,16 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han def recursive_floordiv_floormod(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), C: T.Buffer((3, 512, 512), "float32")) -> None: T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((1, 128, 16, 8, 2, 32, 2), "float32") for axis1, axis2, axis3, axis4, axis5, axis6, axis7 in T.grid(1, 128, 16, 8, 2, 32, 2): - with T.block("In"): + with T.sblock("In"): v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7 = T.axis.remap("SSSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6, axis7]) T.reads(A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32]) T.writes(B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7]) B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7] = A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32] + 3 for ax1, ax2, ax3 in T.grid(3, 512, 512): - with T.block("Out"): + with T.sblock("Out"): v1, v2, v3 = T.axis.remap("SSS", [ax1, ax2, ax3]) T.reads(B[v1 // 8, v2 // 4, v3 // 32, v1, v2 % 4 // 2, v3 % 32, v2 % 2]) T.writes(C[v1, v2, v3]) @@ -1021,17 +1021,17 @@ def recursive_floordiv_floormod(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), @T.prim_func def recursive_floordiv_floormod_after_reverse_compute_at(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), C: T.Buffer((3, 512, 512), "float32")) -> None: T.func_attr({"tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): B = T.alloc_buffer((1, 128, 16, 8, 2, 32, 2)) for axis1, axis2, axis3 in T.grid(1, 128, 16): for axis4, axis5, axis6, axis7 in T.grid(8, 2, 32, 2): - with T.block("In"): + with T.sblock("In"): v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7 = T.axis.remap("SSSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6, axis7]) T.reads(A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32]) T.writes(B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7]) B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7] = A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32] + T.float32(3) for ax0, ax1, ax2 in T.grid(3, 4, 32): - with T.block("Out"): + with T.sblock("Out"): v1 = T.axis.spatial(3, ax0) v2 = T.axis.spatial(512, axis2 * 4 + ax1) v3 = T.axis.spatial(512, axis3 * 32 + ax2) @@ -1044,11 +1044,11 @@ def recursive_floordiv_floormod_after_reverse_compute_at(A: T.Buffer((16, 64, 1, def tiled_repeat_op(x: T.Buffer((4,), "float32"), T_repeat: T.Buffer((64,), "float32")) -> None: T_add = T.alloc_buffer([4], dtype="float32") for i0 in T.serial(4): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial(4, i0) T_add[ax0] = x[ax0] + 1.0 for i0_0, i0_1 in T.grid(8, 8): - with T.block("T_repeat"): + with T.sblock("T_repeat"): ax0 = T.axis.spatial(64, i0_0 * 8 + i0_1) T_repeat[ax0] = T_add[ax0 // 16] @@ -1056,11 +1056,11 @@ def tiled_repeat_op(x: T.Buffer((4,), "float32"), T_repeat: T.Buffer((64,), "flo def tiled_repeat_op_after_compute_at(x: T.Buffer((4,), "float32"), T_repeat: T.Buffer((64,), "float32")) -> None: T_add = T.alloc_buffer([4], dtype="float32") for i0_0 in T.serial(8): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial(4, i0_0 // 2) T_add[ax0] = x[ax0] + T.float32(1) for i0_1 in T.serial(8): - with T.block("T_repeat"): + with T.sblock("T_repeat"): ax0 = T.axis.spatial(64, i0_0 * 8 + i0_1) T_repeat[ax0] = T_add[ax0 // 16] @@ -1068,12 +1068,12 @@ def tiled_repeat_op_after_compute_at(x: T.Buffer((4,), "float32"), T_repeat: T.B def static_bound(A: T.Buffer((32, 1), "float32"), C: T.Buffer((32, 1), "float32")) -> None: B = T.alloc_buffer((32, 1), "float32") for i, j in T.grid(32, 1): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(32, i) vj = T.axis.spatial(1, j) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(32, 32): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(32, i) vj = T.axis.spatial(1, j) T.where(j < 1) @@ -1084,12 +1084,12 @@ def static_bound_after_compute_at(A: T.Buffer((32, 1), "float32"), C: T.Buffer(( B = T.alloc_buffer((32, 1), "float32") for i in range(32): for ax0, ax1 in T.grid(1, 1): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(32, i + ax0) vj = T.axis.spatial(1, ax1) B[vi, vj] = A[vi, vj] * 2.0 for j in range(32): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(32, i) vj = T.axis.spatial(1, j) T.where(j < 1) @@ -1102,8 +1102,8 @@ def static_bound_after_compute_at(A: T.Buffer((32, 1), "float32"), C: T.Buffer(( def test_compute_at_two_elementwise(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") - loop, _ = sch.get_loops("C" if use_block_name else sch.get_block("C")) + block = "B" if use_block_name else sch.get_sblock("B") + loop, _ = sch.get_loops("C" if use_block_name else sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(two_elementwise_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) @@ -1111,8 +1111,8 @@ def test_compute_at_two_elementwise(use_block_name): def test_compute_at_blockized_1(use_block_name): sch = tir.Schedule(blockized_1, debug_mask="all") - block = sch.get_block("B") - _, loop = sch.get_loops(sch.get_block("C_outer")) + block = sch.get_sblock("B") + _, loop = sch.get_loops(sch.get_sblock("C_outer")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(blockized_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=blockized_1) @@ -1120,8 +1120,8 @@ def test_compute_at_blockized_1(use_block_name): def test_compute_at_blockized_2(use_block_name): sch = tir.Schedule(blockized_2, debug_mask="all") - block = sch.get_block("B_outer") - _, loop, _, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B_outer") + _, loop, _, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(blockized_2_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=blockized_2) @@ -1129,8 +1129,8 @@ def test_compute_at_blockized_2(use_block_name): def test_compute_at_cuda_matmul_0(use_block_name): sch = tir.Schedule(cuda_matmul_0, debug_mask="all") - block = sch.get_block("C") - _, _, _, _, _, loop, _, _ = sch.get_loops(sch.get_block("C_local")) + block = sch.get_sblock("C") + _, _, _, _, _, loop, _, _ = sch.get_loops(sch.get_sblock("C_local")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(cuda_matmul_0_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_0) @@ -1138,8 +1138,8 @@ def test_compute_at_cuda_matmul_0(use_block_name): def test_compute_at_cuda_matmul_1(use_block_name): sch = tir.Schedule(cuda_matmul_1, debug_mask="all") - block = sch.get_block("A_shared_local") - _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("A_shared_local") + _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(cuda_matmul_2, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_1) @@ -1147,8 +1147,8 @@ def test_compute_at_cuda_matmul_1(use_block_name): def test_compute_at_cuda_matmul_2(use_block_name): sch = tir.Schedule(cuda_matmul_2, debug_mask="all") - block = sch.get_block("B_shared_local") - _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B_shared_local") + _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(cuda_matmul_3, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_2) @@ -1156,8 +1156,8 @@ def test_compute_at_cuda_matmul_2(use_block_name): def test_compute_at_cuda_matmul_3(use_block_name): sch = tir.Schedule(cuda_matmul_3, debug_mask="all") - block = sch.get_block("A_shared") - _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("A_shared") + _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(cuda_matmul_4, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_3) @@ -1165,8 +1165,8 @@ def test_compute_at_cuda_matmul_3(use_block_name): def test_compute_at_cuda_matmul_4(use_block_name): sch = tir.Schedule(cuda_matmul_4, debug_mask="all") - block = sch.get_block("B_shared") - _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B_shared") + _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(cuda_matmul_5, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=cuda_matmul_4) @@ -1174,8 +1174,8 @@ def test_compute_at_cuda_matmul_4(use_block_name): def test_compute_at_reduction_block(use_block_name): sch = tir.Schedule(multi_reduction, debug_mask="all") - block = sch.get_block("B") - (loop,) = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B") + (loop,) = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=False) assert_structural_equal_ignore_global_symbol(multi_reduction_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=multi_reduction) @@ -1183,9 +1183,9 @@ def test_compute_at_reduction_block(use_block_name): def test_compute_at_tiled_pooling_read_cache(use_block_name): sch = tir.Schedule(tiled_pooling_read_cache, debug_mask="all") - compute = sch.get_block("compute") + compute = sch.get_sblock("compute") _, w_o, _, _, _, _ = sch.get_loops(compute) - cache = sch.get_block("cache") + cache = sch.get_sblock("cache") sch.compute_at(cache, w_o) assert_structural_equal_ignore_global_symbol( tiled_pooling_read_cache_after_compute_at, sch.mod["main"] @@ -1195,8 +1195,8 @@ def test_compute_at_tiled_pooling_read_cache(use_block_name): def test_compute_at_non_uniform_tiled_conv(use_block_name): sch = tir.Schedule(non_uniform_tiled_conv, debug_mask="all") - compute = sch.get_block("compute") - sch.compute_at(sch.get_block("cache"), sch.get_loops(compute)[1]) + compute = sch.get_sblock("compute") + sch.compute_at(sch.get_sblock("cache"), sch.get_loops(compute)[1]) assert_structural_equal_ignore_global_symbol( non_uniform_tiled_conv_after_compute_at, sch.mod["main"] ) @@ -1205,9 +1205,9 @@ def test_compute_at_non_uniform_tiled_conv(use_block_name): def test_compute_at_concat(use_block_name): sch = tir.Schedule(concat_two_elemwise, debug_mask="all") - concat = sch.get_block("T_concat") - add1 = sch.get_block("T_add_1") - add2 = sch.get_block("T_add_2") + concat = sch.get_sblock("T_concat") + add1 = sch.get_sblock("T_add_1") + add2 = sch.get_sblock("T_add_2") axis = sch.get_loops(concat)[0] sch.compute_at(add1, axis) sch.compute_at(add2, axis) @@ -1219,8 +1219,8 @@ def test_compute_at_concat(use_block_name): def test_compute_at_tiled_repeat_op(use_block_name): sch = tir.Schedule(tiled_repeat_op, debug_mask="all") - outer_ax, _ = sch.get_loops(sch.get_block("T_repeat")) - sch.compute_at(sch.get_block("T_add"), outer_ax) + outer_ax, _ = sch.get_loops(sch.get_sblock("T_repeat")) + sch.compute_at(sch.get_sblock("T_add"), outer_ax) assert_structural_equal_ignore_global_symbol(tiled_repeat_op_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op) @@ -1230,11 +1230,11 @@ def test_compute_at_rev_iter(): def before(X: T.Buffer((10, 10), "float32"), Z: T.Buffer((10, 10), "float32")): Y = T.alloc_buffer([10, 10], "float32") for i, j in T.grid(10, 10): - with T.block("b0"): + with T.sblock("b0"): vi, vj = T.axis.remap("SS", [i, j]) Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0 for i, j in T.grid(10, 10): - with T.block("b1"): + with T.sblock("b1"): vi, vj = T.axis.remap("SS", [i, j]) Z[vi, vj] = Y[vj, vi] + 2.0 @@ -1243,26 +1243,26 @@ def after(X: T.Buffer((10, 10), "float32"), Z: T.Buffer((10, 10), "float32")): Y = T.alloc_buffer([10, 10], "float32") for i in range(10): for j in range(10): - with T.block("b0"): + with T.sblock("b0"): vi = T.axis.spatial(10, j) vj = T.axis.spatial(10, 9 - i) Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0 for j in range(10): - with T.block("b1"): + with T.sblock("b1"): vi, vj = T.axis.remap("SS", [i, j]) Z[vi, vj] = Y[vj, vi] + 2.0 sch = tir.Schedule(before, debug_mask="all") - axis = sch.get_loops(sch.get_block("b1"))[0] - sch.compute_at(sch.get_block("b0"), axis) + axis = sch.get_loops(sch.get_sblock("b1"))[0] + sch.compute_at(sch.get_sblock("b0"), axis) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=before) def test_reverse_compute_at_tiled(use_block_name): sch = tir.Schedule(tiled, debug_mask="all") - block = sch.get_block("C") - _, _, loop, _ = sch.get_loops(sch.get_block("B")) + block = sch.get_sblock("C") + _, _, loop, _ = sch.get_loops(sch.get_sblock("B")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) assert_structural_equal_ignore_global_symbol(tiled_after_reverse_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=tiled) @@ -1270,8 +1270,8 @@ def test_reverse_compute_at_tiled(use_block_name): def test_reverse_compute_at_tiled_trivial_binding(use_block_name): sch = tir.Schedule(tiled_trivial_binding, debug_mask="all") - block = sch.get_block("C") - _, _, loop, _ = sch.get_loops(sch.get_block("B")) + block = sch.get_sblock("C") + _, _, loop, _ = sch.get_loops(sch.get_sblock("B")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) assert_structural_equal_ignore_global_symbol( tiled_trivial_binding_after_reverse_compute_at, sch.mod["main"] @@ -1281,8 +1281,8 @@ def test_reverse_compute_at_tiled_trivial_binding(use_block_name): def test_reverse_compute_at_blockized_2(use_block_name): sch = tir.Schedule(blockized_2, debug_mask="all") - block = sch.get_block("C") - _, loop = sch.get_loops(sch.get_block("B_outer")) + block = sch.get_sblock("C") + _, loop = sch.get_loops(sch.get_sblock("B_outer")) sch.reverse_compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol( blockized_2_after_reverse_compute_at, sch.mod["main"] @@ -1292,8 +1292,8 @@ def test_reverse_compute_at_blockized_2(use_block_name): def test_reverse_compute_at_factorized(use_block_name): sch = tir.Schedule(factorized, debug_mask="all") - block = sch.get_block("B") - _, loop, _, _ = sch.get_loops(sch.get_block("B_rf")) + block = sch.get_sblock("B") + _, loop, _, _ = sch.get_loops(sch.get_sblock("B_rf")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) assert_structural_equal_ignore_global_symbol( factorized_after_reverse_compute_at, sch.mod["main"] @@ -1303,8 +1303,8 @@ def test_reverse_compute_at_factorized(use_block_name): def test_reverse_compute_at_floordiv_and_floormod_indices(use_block_name): sch = tir.Schedule(floordiv_and_floormod_indices, debug_mask="all") - A = sch.get_block("A") - B = sch.get_block("B") + A = sch.get_sblock("A") + B = sch.get_sblock("B") sch.reverse_compute_at(B, sch.get_loops(A)[0]) assert_structural_equal_ignore_global_symbol( floordiv_and_floormod_indices_after_reverse_compute_at, sch.mod["main"] @@ -1314,7 +1314,7 @@ def test_reverse_compute_at_floordiv_and_floormod_indices(use_block_name): def test_reverse_compute_at_floordiv_and_floormod_recursive(use_block_name): sch = tir.Schedule(recursive_floordiv_floormod, debug_mask="all") - write_block = sch.get_block("Out") + write_block = sch.get_sblock("Out") sch.reverse_compute_at(write_block, sch.get_loops("In")[2]) assert_structural_equal_ignore_global_symbol( recursive_floordiv_floormod_after_reverse_compute_at, sch.mod["main"] @@ -1324,8 +1324,8 @@ def test_reverse_compute_at_floordiv_and_floormod_recursive(use_block_name): def test_read_out_of_bound(use_block_name): sch = tir.Schedule(read_out_of_bound, debug_mask="all") - block = sch.get_block("B") - (loop,) = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B") + (loop,) = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop) assert_structural_equal_ignore_global_symbol( read_out_of_bound_after_compute_at, sch.mod["main"] @@ -1335,8 +1335,8 @@ def test_read_out_of_bound(use_block_name): def test_compact_dataflow(use_block_name): sch = tir.Schedule(not_all_compact_data_flow, debug_mask="all") - block = sch.get_block("B") - _, loop = sch.get_loops(sch.get_block("C_1")) + block = sch.get_sblock("B") + _, loop = sch.get_loops(sch.get_sblock("C_1")) sch.compute_at(block, loop) assert_structural_equal_ignore_global_symbol( not_all_compact_data_flow_after_compute_at, sch.mod["main"] @@ -1346,8 +1346,8 @@ def test_compact_dataflow(use_block_name): def test_compute_at_simplify_static_bound(use_block_name): sch = tir.Schedule(static_bound, debug_mask="all") - block = sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B") + loop, _ = sch.get_loops(sch.get_sblock("C")) sch.compute_at(block, loop, preserve_unit_loops=True) assert_structural_equal_ignore_global_symbol(static_bound_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=static_bound) @@ -1361,7 +1361,7 @@ def main(x: T.handle, y: T.handle, n: T.int64): X = T.match_buffer(x, (T.int64(8), n * 32), "float32") Y = T.match_buffer(y, (T.int64(8), n * 32), "float32") for i, k in T.grid(T.int64(8), n * 32): - with T.block("Y"): + with T.sblock("Y"): vi, vk = T.axis.remap("SS", [i, k]) Y[vi, vk] = X[vi, vk] @@ -1375,19 +1375,19 @@ def main(x: T.handle, y: T.handle, n: T.int64): for i, k_0 in T.grid(T.int64(8), n): for ax0 in range(T.int64(32)): - with T.block("X_global"): + with T.sblock("X_global"): v0 = T.axis.spatial(T.int64(8), i) v1 = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) + ax0) X_global[v0, v1] = X[v0, v1] for k_1 in range(T.int64(32)): - with T.block("Y"): + with T.sblock("Y"): vi = T.axis.spatial(T.int64(8), i) vk = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) + k_1) Y[vi, vk] = X_global[vi, vk] sch = tir.Schedule(Before, debug_mask="all") - block = sch.get_block("Y") - i, k = sch.get_loops(sch.get_block("Y")) + block = sch.get_sblock("Y") + i, k = sch.get_loops(sch.get_sblock("Y")) ko, ki = sch.split(k, [None, 32]) XX = sch.cache_read(block, 0, "global") sch.compute_at(XX, ko) @@ -1401,11 +1401,11 @@ def grouped_channel_bias( ): B = T.alloc_buffer([45], dtype="float32", scope="") for i in T.grid(45): - with T.block("init"): + with T.sblock("init"): vi = T.axis.remap("S", [i]) B[vi] = vi for c_o, h, w, c_i in T.grid(2, 8, 8, 360): - with T.block("compute"): + with T.sblock("compute"): hh, ww = T.axis.remap("SS", [h, w]) cc = T.axis.spatial(720, c_o * 360 + c_i) Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] @@ -1417,18 +1417,18 @@ def grouped_channel_bias_non_perfect_tiled( B = T.alloc_buffer([45], dtype="float32") for c_o in range(2): for ax0 in range(23): - with T.block("init"): + with T.sblock("init"): vi = T.axis.spatial(45, c_o * 22 + ax0) B[vi] = vi for h, w, c_i in T.grid(8, 8, 360): - with T.block("compute"): + with T.sblock("compute"): hh, ww = T.axis.remap("SS", [h, w]) cc = T.axis.spatial(720, c_o * 360 + c_i) Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] sch = tir.Schedule(grouped_channel_bias, debug_mask="all") - loop = sch.get_loops(sch.get_block("compute"))[0] - sch.compute_at(sch.get_block("init"), loop) + loop = sch.get_loops(sch.get_sblock("compute"))[0] + sch.compute_at(sch.get_sblock("init"), loop) assert_structural_equal_ignore_global_symbol( sch.mod["main"], grouped_channel_bias_non_perfect_tiled ) @@ -1436,48 +1436,48 @@ def grouped_channel_bias_non_perfect_tiled( def test_fail_subtree_complete_block(use_block_name): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") - block = sch.get_block("B_0") - loop, _ = sch.get_loops(sch.get_block("C")) + block = sch.get_sblock("B_0") + loop, _ = sch.get_loops(sch.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError, match="complete block"): sch.compute_at(block, loop) def test_fail_not_in_same_scope(use_block_name): sch = tir.Schedule(blockized_1, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("C_inner")) + block = "B" if use_block_name else sch.get_sblock("B") + loop, _ = sch.get_loops(sch.get_sblock("C_inner")) with pytest.raises(tvm.tir.ScheduleError, match="same block scope"): sch.compute_at(block, loop) def test_fail_loop_is_ancestor_of_block(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("B")) + block = "B" if use_block_name else sch.get_sblock("B") + loop, _ = sch.get_loops(sch.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError, match="ancestor of block"): sch.compute_at(block, loop) def test_fail_output_block(use_block_name): sch = tir.Schedule(tiled, debug_mask="all") - block = "C" if use_block_name else sch.get_block("C") - loop, _, _, _ = sch.get_loops(sch.get_block("B")) + block = "C" if use_block_name else sch.get_sblock("C") + loop, _, _, _ = sch.get_loops(sch.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError, match="output block"): sch.compute_at(block, loop) def test_fail_all_consumers_under_loop(use_block_name): sch = tir.Schedule(fail_all_consumers_under_loop, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("C")) + block = "B" if use_block_name else sch.get_sblock("B") + loop, _ = sch.get_loops(sch.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError, match="requires all the consumer"): sch.compute_at(block, loop) def test_fail_all_producers_under_loop(use_block_name): sch = tir.Schedule(fail_all_producers_under_loop, debug_mask="all") - block = "D" if use_block_name else sch.get_block("D") - loop, _ = sch.get_loops(sch.get_block("C")) + block = "D" if use_block_name else sch.get_sblock("D") + loop, _ = sch.get_loops(sch.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError, match="requires all the producer"): sch.reverse_compute_at(block, loop) @@ -1494,8 +1494,8 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") - block_d = "D" if use_block_name else sch.get_block("D") + block_c = "C" if use_block_name else sch.get_sblock("C") + block_d = "D" if use_block_name else sch.get_sblock("D") i, _ = sch.get_loops(block_d) sch.compute_at(block_c, i) verify_trace_roundtrip(sch=sch, mod=mod) @@ -1511,7 +1511,7 @@ def multi_producers_conv( pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8") wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8") for i0, i1, i2, i3 in T.grid(1, 3, 230, 230): - with T.block("pad"): + with T.sblock("pad"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3]) T.writes(pad[i0_1, i1_1, i2_1, i3_1]) @@ -1523,13 +1523,13 @@ def multi_producers_conv( ) for i0 in T.serial(1): for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7): - with T.block("wbuf"): + with T.sblock("wbuf"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w[v0, v1, v2, v3]) T.writes(wbuf[v0, v1, v2, v3]) wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3] for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7): - with T.block("conv"): + with T.sblock("conv"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap( "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6] ) @@ -1551,7 +1551,7 @@ def multi_producers_after_compute_at( wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8") for i0 in T.serial(1): for ax0, ax1, ax2 in T.grid(3, 229, 229): - with T.block("pad"): + with T.sblock("pad"): i0_1 = T.axis.spatial(1, 0) i1_1 = T.axis.spatial(3, ax0) i2_1 = T.axis.spatial(230, ax1) @@ -1565,13 +1565,13 @@ def multi_producers_after_compute_at( dtype="int8", ) for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7): - with T.block("wbuf"): + with T.sblock("wbuf"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(w[v0, v1, v2, v3]) T.writes(wbuf[v0, v1, v2, v3]) wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3] for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7): - with T.block("conv"): + with T.sblock("conv"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap( "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6] ) @@ -1584,7 +1584,7 @@ def multi_producers_after_compute_at( ) * T.cast(wbuf[ff, rc, ry, rx], "int32") sch = tir.Schedule(multi_producers_conv, debug_mask="all") - block_c = sch.get_block("pad") + block_c = sch.get_sblock("pad") axis = sch.get_loops("conv")[0] sch.compute_at(block_c, axis, index=-2) assert_structural_equal_ignore_global_symbol(multi_producers_after_compute_at, sch.mod["main"]) @@ -1597,21 +1597,21 @@ def main(A: T.Buffer((128, 128), "float32"), D: T.Buffer((128, 128), "float32")) C = T.alloc_buffer([128, 128], dtype="float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + j_1) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for ax0 in T.serial(16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + ax0) T.reads(B[vi, vj]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(D[vi, vj]) @@ -1625,21 +1625,21 @@ def main_reverse_compute_at( C = T.alloc_buffer([128, 128], dtype="float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + j_1) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for ax0 in T.serial(16): - with T.block("D"): + with T.sblock("D"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + ax0) T.reads(B[vi, vj]) T.writes(D[vi, vj]) D[vi, vj] = B[vi, vj] + T.float32(1) for ax0 in T.serial(16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + ax0) T.reads(B[vi, vj]) @@ -1647,7 +1647,7 @@ def main_reverse_compute_at( C[vi, vj] = B[vi, vj] + T.float32(1) sch = tir.Schedule(main, debug_mask="all") - block_c = sch.get_block("D") + block_c = sch.get_sblock("D") axis = sch.get_loops("B")[2] sch.reverse_compute_at(block_c, axis, index=1) assert_structural_equal_ignore_global_symbol(main_reverse_compute_at, sch.mod["main"]) @@ -1659,14 +1659,14 @@ def main(A: T.Buffer((128, 128), "float32"), D: T.Buffer((1, 2, 1), "float32")) B = T.alloc_buffer([128, 128], dtype="float32") for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)): for j_1 in T.serial(T.int64(16)): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1) vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(1)): - with T.block("D"): + with T.sblock("D"): v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(B[v0, v1]) T.writes(D[v0, v1, v2]) @@ -1679,14 +1679,14 @@ def main_reverse_compute_at( B = T.alloc_buffer([128, 128], dtype="float32") for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)): for j_1 in T.serial(T.int64(16)): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1) vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(16), T.int64(1)): - with T.block("D"): + with T.sblock("D"): T.where( i_0 * T.int64(16) + i_1 < T.int64(1) and j_0 * T.int64(16) + ax1 < T.int64(2) @@ -1699,7 +1699,7 @@ def main_reverse_compute_at( D[v0, v1, v2] = B[v0, v1] + T.float32(1) sch = tir.Schedule(main, debug_mask="all") - block_d = sch.get_block("D") + block_d = sch.get_sblock("D") axis = sch.get_loops("B")[2] sch.reverse_compute_at(block_d, axis, preserve_unit_loops=True, index=1) assert_structural_equal_ignore_global_symbol(main_reverse_compute_at, sch.mod["main"]) @@ -1710,11 +1710,11 @@ def test_reverse_compute_at_layout_trans(): def before(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), "float32")): B = T.alloc_buffer((1, 3, 5, 5, 16)) for i0, i1, i2, i3, i4 in T.grid(1, 3, 5, 5, 16): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + T.float32(1) for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 6, 5, 5, 8): - with T.block("T_layout_trans"): + with T.sblock("T_layout_trans"): v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[ v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8 + v_ax4) % 16 @@ -1725,11 +1725,11 @@ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), B = T.alloc_buffer((1, 3, 5, 5, 16)) for i0, i1 in T.grid(1, 3): for i2, i3, i4 in T.grid(5, 5, 16): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + T.float32(1) for ax0, ax1, ax2, ax3 in T.grid(2, 5, 5, 8): - with T.block("T_layout_trans"): + with T.sblock("T_layout_trans"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(6, i1 * 2 + ax0) v_ax2, v_ax3, v_ax4 = T.axis.remap("SSS", [ax1, ax2, ax3]) @@ -1738,7 +1738,7 @@ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), ] sch = tir.Schedule(before, debug_mask="all") - trans = sch.get_block("T_layout_trans") + trans = sch.get_sblock("T_layout_trans") axis = sch.get_loops("compute")[1] sch.reverse_compute_at(trans, axis) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) @@ -1761,12 +1761,12 @@ def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) offset = apply_decl_buffer([4], data=offset_ptr) for i in range(4): - with T.block("compute_B"): + with T.sblock("compute_B"): vi = T.axis.remap("S", [i]) B[vi] = 10.0 * vi + offset[vi] for i, j in T.grid(4, 256): - with T.block("compute_C"): + with T.sblock("compute_C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi] + 100.0 * vj @@ -1777,22 +1777,22 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")) offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) offset = apply_decl_buffer([4], data=offset_ptr) for i in range(4): - with T.block("compute_B"): + with T.sblock("compute_B"): vi = T.axis.remap("S", [i]) B[vi] = 10.0 * vi + offset[vi] for j in range(256): - with T.block("compute_C"): + with T.sblock("compute_C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi] + 100.0 * vj sch = tir.Schedule(before, debug_mask="all") if use_reverse_compute_at: - block = sch.get_block("compute_C") + block = sch.get_sblock("compute_C") axis = sch.get_loops("compute_B")[0] sch.reverse_compute_at(block, axis) else: - block = sch.get_block("compute_B") + block = sch.get_sblock("compute_B") axis = sch.get_loops("compute_C")[0] sch.compute_at(block, axis) @@ -1817,12 +1817,12 @@ def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) offset = apply_decl_buffer([4], data=offset_ptr) for i in range(4): - with T.block("compute_B"): + with T.sblock("compute_B"): vi = T.axis.remap("S", [i]) B[vi] = 10.0 * vi + offset[vi] for i, j in T.grid(4, 256): - with T.block("compute_C"): + with T.sblock("compute_C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi] + 100.0 * vj @@ -1831,12 +1831,12 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")) offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) offset = apply_decl_buffer([4], data=offset_ptr) for i, j in T.grid(4, 256): - with T.block("compute_C"): + with T.sblock("compute_C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = (10.0 * vi + offset[vi]) + 100.0 * vj sch = tir.Schedule(before, debug_mask="all") - block = sch.get_block("compute_B") + block = sch.get_sblock("compute_B") sch.compute_inline(block) after = sch.mod["main"] @@ -1852,10 +1852,10 @@ def before(a: T.handle, b: T.handle, c: T.handle): A = T.match_buffer(a, (32, 1, 128)) B = T.match_buffer(b, (32, n, 128)) C = T.match_buffer(c, (32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): C_rf = T.alloc_buffer((128, 32, 1, n)) for ax0_ax1_fused, ax2_fused_1, ax2_fused_0 in T.grid(n * 32, 128, 1): - with T.block("NT_matmul_rf"): + with T.sblock("NT_matmul_rf"): vax2_fused_1 = T.axis.spatial(128, ax2_fused_1) v0 = T.axis.spatial(32, ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1866,7 +1866,7 @@ def before(a: T.handle, b: T.handle, c: T.handle): C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0) C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1] for ax0_ax1_fused, ax2_fused_1 in T.grid(n * 32, 128): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_1 = T.axis.reduce(128, ax2_fused_1) v0 = T.axis.spatial(32, ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1881,11 +1881,11 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): n = T.int32() B = T.match_buffer(b, (32, n, 128)) C = T.match_buffer(c, (32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): C_rf = T.alloc_buffer((128, 32, 1, n)) for ax0_ax1_fused in range(n * 32): for ax2_fused_1, ax2_fused_0 in T.grid(128, 1): - with T.block("NT_matmul_rf"): + with T.sblock("NT_matmul_rf"): vax2_fused_1 = T.axis.spatial(128, ax2_fused_1) v0 = T.axis.spatial(32, ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1896,7 +1896,7 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0) C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1] for ax0, ax1, ax2 in T.grid(128, 1, 1): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_1 = T.axis.reduce(128, ax0) v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax1) v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax2) @@ -1907,8 +1907,8 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1] # fmt: on sch = tir.Schedule(before.with_attr("global_symbol", "main"), debug_mask="all") - block = sch.get_block("NT_matmul") - loop, _, _ = sch.get_loops(sch.get_block("NT_matmul_rf")) + block = sch.get_sblock("NT_matmul") + loop, _, _ = sch.get_loops(sch.get_sblock("NT_matmul_rf")) sch.reverse_compute_at(block, loop, preserve_unit_loops=True) tvm.ir.assert_structural_equal( sch.mod["main"], expected.with_attr("global_symbol", "main"), True @@ -1924,11 +1924,11 @@ def before(): Concat = T.alloc_buffer((1, 101, 28, 64), "float32") Slice = T.alloc_buffer((1, 87, 28, 64), "float32") for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 64): - with T.block("compute"): + with T.sblock("compute"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) X[v_ax0, v_ax1, v_ax2, v_ax3] = 1.0 for ax0, ax1, ax2, ax3 in T.grid(1, 101, 28, 64): - with T.block("T_concat"): + with T.sblock("T_concat"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else( 85 <= v_ax1, @@ -1940,7 +1940,7 @@ def before(): ), ) for ax0, ax1, ax2, ax3 in T.grid(1, 87, 28, 64): - with T.block("T_strided_slice"): + with T.sblock("T_strided_slice"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3] @@ -1953,13 +1953,13 @@ def expect(): Slice = T.alloc_buffer((1, 87, 28, 64)) for ax0 in range(1): for ax0_1, ax1, ax2 in T.grid(2, 28, 64): - with T.block("compute"): + with T.sblock("compute"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(16, ax0_1) v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2]) X[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(1) for ax0_1, ax1, ax2 in T.grid(87, 28, 64): - with T.block("T_concat"): + with T.sblock("T_concat"): v_ax0 = T.axis.spatial(1, 0) v_ax1 = T.axis.spatial(101, ax0_1) v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2]) @@ -1973,14 +1973,14 @@ def expect(): ), ) for ax1, ax2, ax3 in T.grid(87, 28, 64): - with T.block("T_strided_slice"): + with T.sblock("T_strided_slice"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3] sch = tir.Schedule(before, debug_mask="all") - blk1 = sch.get_block("compute") - blk2 = sch.get_block("T_concat") - blk3 = sch.get_block("T_strided_slice") + blk1 = sch.get_sblock("compute") + blk2 = sch.get_sblock("T_concat") + blk3 = sch.get_sblock("T_strided_slice") loop = sch.get_loops(blk3)[0] sch.compute_at(blk2, loop) sch.compute_at(blk1, loop) diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py index 87015bfcca4c..26b2e06d5672 100644 --- a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py +++ b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py @@ -35,11 +35,11 @@ def elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -51,15 +51,15 @@ def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) - C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers @@ -70,11 +70,11 @@ def elementwise_multi_consumer_inlined(a: T.handle, c: T.handle, d: T.handle) -> C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 for i, j in T.grid(128, 128): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] @@ -85,11 +85,11 @@ def elementwise_standalone(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] + 1.0 @@ -99,7 +99,7 @@ def elementwise_standalone_dce(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] + 1.0 @@ -111,11 +111,11 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -125,7 +125,7 @@ def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -137,12 +137,12 @@ def fail_multi_reader_writer(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer((128, 128)) D = T.match_buffer(d, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 C[vi, vj] = A[vi, vj] + 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + C[vi, vj] @@ -153,11 +153,11 @@ def elementwise_multi_reverse_loads(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 @@ -167,7 +167,7 @@ def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 @@ -178,11 +178,11 @@ def elementwise_reverse_affine_load( ) -> None: B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j, k, l in T.grid(8, 32, 8, 8): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk, vl] = B[ ((((vi * 32) + vj) * 8 + vk) * 8 + vl) // 128, @@ -195,7 +195,7 @@ def elementwise_reverse_affine_load_inlined( A: T.Buffer((128, 128), "float32"), C: T.Buffer((8, 32, 8, 8), "float32") ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[ (vj + vi * 128) // 2048, @@ -215,11 +215,11 @@ def elementwise_reverse_affine_load_unit_iter( ) -> None: C = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 for i, j, k, l in T.grid(1, 8, 16, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) D[vi, vj, vk, vl] = C[vj * 16 + vk, vl] + B[vj, vk, vi] @@ -231,7 +231,7 @@ def elementwise_reverse_affine_load_unit_iter_inlined( D: T.Buffer((1, 8, 16, 128), "float32"), ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0] @@ -244,11 +244,11 @@ def elementwise_reverse_affine_load_unit_iter_simplified( ) -> None: C = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 for i, j, k in T.grid(8, 16, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) D[0, vi, vj, vk] = C[vi * 16 + vj, vk] + B[vi, vj, 0] @@ -260,7 +260,7 @@ def elementwise_reverse_affine_load_unit_iter_simplified_inlined( D: T.Buffer((1, 8, 16, 128), "float32"), ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0] @@ -272,15 +272,15 @@ def elementwise_reverse_affine_chain( B = T.alloc_buffer((128, 128)) C = T.alloc_buffer((8, 16, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j, k in T.grid(8, 16, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[vi * 16 + vj, vk] + 1.0 for i, j, k, l in T.grid(1, 8, 16, 128): - with T.block("D"): + with T.sblock("D"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) D[vi, vj, vk, vl] = C[vj, vk, vl] @@ -290,7 +290,7 @@ def elementwise_reverse_affine_chain_inlined( A: T.Buffer((128, 128), "float32"), D: T.Buffer((1, 8, 16, 128), "float32") ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + 1.0 @@ -302,11 +302,11 @@ def elementwise_multi_reverse_affine_load( ) -> None: B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j, k in T.grid(8, 16, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[vi * 16 + vj, vk] + B[vi * 16 + vj, vk] @@ -317,7 +317,7 @@ def elementwise_multi_reverse_affine_load_inlined( C: T.Buffer((8, 16, 128), "float32"), ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + A[vi, vj] * 2.0 @@ -328,11 +328,11 @@ def elementwise_reverse_non_affine_load( ) -> None: B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j, k in T.grid(8, 16, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[vi * 16 + vj, vi * 16 + vj] @@ -343,11 +343,11 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) @@ -361,11 +361,11 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[0:128, 0:128]) T.writes(C[0:128, 0:128]) @@ -380,11 +380,11 @@ def buffer_matched(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) C[vi, vj] = Bb[0, 0] + 1.0 @@ -396,11 +396,11 @@ def elementwise_predicate(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.where(B[i, j] < 10.0) C[vi, vj] = B[vi, vj] + 1.0 @@ -411,7 +411,7 @@ def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.where(A[i, j] * 2.0 < 10.0) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -423,11 +423,11 @@ def elementwise_multi_loads(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] @@ -437,7 +437,7 @@ def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 @@ -448,18 +448,18 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [1024]) A_cache = T.alloc_buffer([1024]) BB = T.alloc_buffer([1024]) - with T.block("opaque"): + with T.sblock("opaque"): # annotated opaque partial access T.reads(A[0:512]) T.writes(A_cache[0:512]) T.evaluate(A.access_ptr("r", extent=512)) T.evaluate(A_cache.access_ptr("w", extent=512)) for i in range(512): - with T.block("BB"): + with T.sblock("BB"): vi = T.axis.remap("S", [i]) BB[vi] = A_cache[vi] * 2.0 for i in range(512): - with T.block("B"): + with T.sblock("B"): vi = T.axis.remap("S", [i]) B[vi] = BB[vi] + 1.0 @@ -469,14 +469,14 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024], dtype="float32") B = T.match_buffer(b, [1024], dtype="float32") A_cache = T.alloc_buffer([1024], dtype="float32") - with T.block("opaque"): + with T.sblock("opaque"): # annotated opaque partial access should be kept T.reads(A[0:512]) T.writes([A_cache[0:512]]) T.evaluate(A.access_ptr("r", extent=512)) T.evaluate(A_cache.access_ptr("w", extent=512)) for i in T.serial(0, 512): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(512, i) T.reads([A_cache[vi]]) T.writes([B[vi]]) @@ -490,7 +490,7 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None compute = T.match_buffer(var_compute, [512, 512], dtype="float32") C = T.alloc_buffer([512, 512], dtype="float32") for i0, i1, i2 in T.grid(512, 512, 512): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads([C[i, j], A[i, k], B[k, j]]) T.writes([C[i, j]]) @@ -498,7 +498,7 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None C[i, j] = T.float32(0) C[i, j] = C[i, j] + A[i, k] * B[k, j] for i0, i1 in T.grid(512, 512): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads([C[i0_1, i1_1]]) T.writes([compute[i0_1, i1_1]]) @@ -511,11 +511,11 @@ def elementwise_output(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -527,7 +527,7 @@ def inline_block_with_init( ) -> None: B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(1, 512, 1, 1, 49, 1): - with T.block("tensor_rf"): + with T.sblock("tensor_rf"): vi4 = T.axis.spatial(49, i4) ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial(512, i1) @@ -546,7 +546,7 @@ def inline_block_with_init( ) for i0, i1 in T.grid(1, 512): for ax0, ax1, ax2, ax3, ax4 in T.grid(49, 1, 1, 1, 1): - with T.block("tensor"): + with T.sblock("tensor"): vi4, ax0_1 = T.axis.remap("RS", [ax0, ax1]) ax1_1 = T.axis.spatial(512, i1 + ax2) ax2_1, ax3_1 = T.axis.remap("SS", [ax3, ax4]) @@ -565,13 +565,13 @@ def exp_exp_opaque_access_with_tvm_access_ptr( ) -> None: compute_1 = T.alloc_buffer([16], dtype="float16") for i0 in T.serial(16): - with T.block("compute"): + with T.sblock("compute"): i0_1 = T.axis.spatial(16, i0) T.reads(x[i0_1]) T.writes(compute_1[i0_1]) compute_1[i0_1] = T.exp(x[i0_1], dtype="float16") for i0 in T.serial(16): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_2 = T.axis.spatial(16, i0) T.reads(lookup_table[0:1024], compute_1[i0_2]) T.writes(compute[i0_2]) @@ -589,7 +589,7 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined( compute: T.Buffer((16,), "float16"), ) -> None: for i0 in T.serial(16): - with T.block("compute_1"): + with T.sblock("compute_1"): i0_1 = T.axis.spatial(16, i0) # Do not put the opaque access to new write region when opaque access # wrapped with a tvm_access_ptr and the access mask set to "read only" @@ -608,11 +608,11 @@ def elementwise_overcomputed_producer( ) -> None: B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(127, 127): - with T.block("C"): + with T.sblock("C"): cvi, cvj = T.axis.remap("SS", [i, j]) C[cvi, cvj] = B[cvi, cvj] + 1.0 @@ -622,7 +622,7 @@ def elementwise_overcomputed_producer_reverse_inlined( A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.where(i < 127 and j < 127) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -634,12 +634,12 @@ def elementwise_overcomputed_producer_simplify_predicate( ) -> None: B = T.alloc_buffer((128, 128)) for i in T.grid(16384): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i // 128) vj = T.axis.spatial(128, i % 128) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(127, 127): - with T.block("C"): + with T.sblock("C"): cvi, cvj = T.axis.remap("SS", [i, j]) C[cvi, cvj] = B[cvi, cvj] + 1.0 @@ -649,7 +649,7 @@ def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined( A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") ) -> None: for i in T.grid(16384): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i // 128) vj = T.axis.spatial(128, i % 128) T.where(i < 16255 and i % 128 < 127) @@ -662,11 +662,11 @@ def elementwise_overcomputed_producer_injective_load( ) -> None: B = T.alloc_buffer((8, 8, 16, 16)) for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) B[vi, vj, vm, vn] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 for i, j in T.grid(127, 127): - with T.block("C"): + with T.sblock("C"): cvi, cvj = T.axis.remap("SS", [i, j]) C[cvi, cvj] = B[cvi // 16, cvj // 16, cvi % 16, cvj % 16] + 1.0 @@ -676,7 +676,7 @@ def elementwise_overcomputed_producer_injective_load_reverse_inlined( A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") ) -> None: for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127) C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 @@ -688,11 +688,11 @@ def elementwise_producer_not_cover_consumer( ) -> None: B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(256, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32") @@ -703,13 +703,13 @@ def elementwise_producer_is_reduction( ) -> None: B = T.alloc_buffer((128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SR", [i, j]) with T.init(): B[vi] = T.float32(0) B[vi] = B[vi] + A[vi, vj] for i in T.grid(128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.remap("S", [i]) D[vi] = B[vi] + 1.0 @@ -720,12 +720,12 @@ def elementwise_predicate_producer(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((127, 128)) C = T.match_buffer(c, (127, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.where(i < 127) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(127, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -735,7 +735,7 @@ def elementwise_predicate_producer_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (127, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): T.where(i < 127) vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) @@ -751,10 +751,10 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) + T.sblock_attr({"meta_schedule.unroll_explicit":1024}) compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") @@ -767,54 +767,54 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): for ax0_ax1_fused in T.serial(1024): - with T.block("pad_temp_reindex_shared"): + with T.sblock("pad_temp_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4}) + T.sblock_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4}) pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] for ax0_ax1_ax2_ax3_fused in T.serial(2048): - with T.block("p1_reindex_shared"): + with T.sblock("p1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(p1[v2, v0, v1, v3]) T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3}) + T.sblock_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3}) p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): for ax0_0_1, ax1_0_1 in T.grid(1, 1): - with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): - with T.block("pad_temp_reindex_shared_wmma.matrix_a"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): - with T.block("p1_reindex_shared_wmma.matrix_b_o"): + with T.sblock("p1_reindex_shared_wmma.matrix_b_o"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("p1_reindex_shared_wmma.matrix_b"): + with T.sblock("p1_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): - with T.block("conv2d_nhwc_o"): + with T.sblock("conv2d_nhwc_o"): v0 = T.axis.reduce(1, 0) v1 = T.axis.reduce(1, 0) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) @@ -822,36 +822,36 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32") for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): - with T.block("conv2d_nhwc_reindex_shared"): + with T.sblock("conv2d_nhwc_reindex_shared"): T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) @@ -859,7 +859,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " T.writes(compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, False, True, dtype="int32"), 255), 0) - p8[0], 1457846997, 31, 0, dtype="int32") for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): - with T.block("compute_4"): + with T.sblock("compute_4"): i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) T.reads(compute_3[i0_13, i1_13, i2_13, i3_13], p9[i0_13, i1_13, i2_13, i3_13]) T.writes(compute[i0_13, i1_13, i2_13, i3_13]) @@ -870,10 +870,10 @@ class Conv2dInt8_TensorCore_with_predicate_after: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit": 1024}) + T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") @@ -885,54 +885,54 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): for ax0_ax1_fused in range(1024): - with T.block("pad_temp_reindex_shared"): + with T.sblock("pad_temp_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] for ax0_ax1_ax2_ax3_fused in range(2048): - with T.block("p1_reindex_shared"): + with T.sblock("p1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(p1[v2, v0, v1, v3]) T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) + T.sblock_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): for ax0_0_1, ax1_0_1 in T.grid(1, 1): - with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_a_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): - with T.block("pad_temp_reindex_shared_wmma.matrix_a"): + with T.sblock("pad_temp_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): - with T.block("p1_reindex_shared_wmma.matrix_b_o"): + with T.sblock("p1_reindex_shared_wmma.matrix_b_o"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) T.reads(p1_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("p1_reindex_shared_wmma.matrix_b"): + with T.sblock("p1_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): - with T.block("conv2d_nhwc_o"): + with T.sblock("conv2d_nhwc_o"): v0 = T.axis.reduce(1, 0) v1 = T.axis.reduce(1, 0) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) @@ -940,36 +940,36 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): + with T.sblock("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) + T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): - with T.block("conv2d_nhwc_reindex_shared"): + with T.sblock("conv2d_nhwc_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) @@ -985,8 +985,8 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " def test_compute_inline_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = sch.get_block("C") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = sch.get_sblock("C") sch.compute_inline(block_b) assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" @@ -995,8 +995,8 @@ def test_compute_inline_elementwise(use_block_name): def test_compute_inline_under_loop(use_block_name): sch = tir.Schedule(elementwise_under_loop, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = sch.get_block("C") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = sch.get_sblock("C") sch.compute_inline(block_b) assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" @@ -1005,8 +1005,8 @@ def test_compute_inline_under_loop(use_block_name): def test_compute_inline_as_dce(use_block_name): sch = tir.Schedule(elementwise_standalone, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = sch.get_block("C") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = sch.get_sblock("C") sch.compute_inline(block_b) assert_structural_equal_ignore_global_symbol(elementwise_standalone_dce, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" @@ -1015,9 +1015,9 @@ def test_compute_inline_as_dce(use_block_name): def test_compute_inline_multi_consumer(use_block_name): sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_b = "B" if use_block_name else sch.get_sblock("B") + block_c = sch.get_sblock("C") + block_d = sch.get_sblock("D") sch.compute_inline(block_b) assert_structural_equal_ignore_global_symbol( elementwise_multi_consumer_inlined, sch.mod["main"] @@ -1029,15 +1029,15 @@ def test_compute_inline_multi_consumer(use_block_name): def test_compute_inline_fail_multi_writer(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_reverse_compute_inline_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") - block_c = "C" if use_block_name else sch.get_block("C") + block_b = sch.get_sblock("B") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" @@ -1046,8 +1046,8 @@ def test_reverse_compute_inline_elementwise(use_block_name): def test_reverse_compute_inline_under_loop(use_block_name): sch = tir.Schedule(elementwise_under_loop, debug_mask="all") - block_b = sch.get_block("B") - block_c = "C" if use_block_name else sch.get_block("C") + block_b = sch.get_sblock("B") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" @@ -1056,28 +1056,28 @@ def test_reverse_compute_inline_under_loop(use_block_name): def test_reverse_compute_inline_fail_as_dce(use_block_name): sch = tir.Schedule(elementwise_standalone, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_b) def test_reverse_compute_inline_fail_multi_producer(use_block_name): sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") - block_d = "D" if use_block_name else sch.get_block("D") + block_d = "D" if use_block_name else sch.get_sblock("D") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_d) def test_reverse_compute_inline_fail_multi_reader(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_reverse_compute_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol( elementwise_multi_reverse_loads_inlined, sch.mod["main"] @@ -1087,7 +1087,7 @@ def test_reverse_compute_multi_reverse_loads(use_block_name): def test_reverse_compute_inline_affine_load(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol( elementwise_reverse_affine_load_inlined, sch.mod["main"] @@ -1097,7 +1097,7 @@ def test_reverse_compute_inline_affine_load(use_block_name): def test_reverse_compute_inline_multi_affine_load(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol( elementwise_multi_reverse_affine_load_inlined, sch.mod["main"] @@ -1107,7 +1107,7 @@ def test_reverse_compute_inline_multi_affine_load(use_block_name): def test_reverse_compute_inline_affine_load_unit_iter(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol( elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"] @@ -1117,7 +1117,7 @@ def test_reverse_compute_inline_affine_load_unit_iter(use_block_name): def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter_simplified, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(block_c) assert_structural_equal_ignore_global_symbol( elementwise_reverse_affine_load_unit_iter_simplified_inlined, sch.mod["main"] @@ -1128,8 +1128,8 @@ def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name) @pytest.mark.parametrize("reverse_order", [True, False]) def test_reverse_compute_inline_affine_chain(use_block_name, reverse_order): sch = tir.Schedule(elementwise_reverse_affine_chain, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") - block_d = "D" if use_block_name else sch.get_block("D") + block_c = "C" if use_block_name else sch.get_sblock("C") + block_d = "D" if use_block_name else sch.get_sblock("D") if reverse_order: sch.reverse_compute_inline(block_d) sch.reverse_compute_inline(block_c) @@ -1144,58 +1144,58 @@ def test_reverse_compute_inline_affine_chain(use_block_name, reverse_order): def test_reverse_compute_fail_non_affine_load(use_block_name): sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_reverse_compute_fail_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") - block_c = "C" if use_block_name else sch.get_block("C") + block_c = "C" if use_block_name else sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_opaque_access_load(use_block_name): sch = tir.Schedule(opaque_access_load, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_opaque_access_store(use_block_name): sch = tir.Schedule(opaque_access_store, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_buffer_matched(use_block_name): sch = tir.Schedule(buffer_matched, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_output_block(use_block_name): sch = tir.Schedule(matmul_relu, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block) sch = tir.Schedule(elementwise_output, debug_mask="all") - block = sch.get_block("B") + block = sch.get_sblock("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block) - block = sch.get_block("C") + block = sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block) def test_compute_inline_predicate(use_block_name): sch = tir.Schedule(elementwise_predicate, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") sch.compute_inline(block_b) assert_structural_equal_ignore_global_symbol(elementwise_predicate_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) @@ -1203,7 +1203,7 @@ def test_compute_inline_predicate(use_block_name): def test_compute_inline_multi_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") - block_b = "B" if use_block_name else sch.get_block("B") + block_b = "B" if use_block_name else sch.get_sblock("B") sch.compute_inline(block_b) assert_structural_equal_ignore_global_symbol(elementwise_multi_loads_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) @@ -1212,7 +1212,7 @@ def test_compute_inline_multi_loads(use_block_name): def test_compute_inline_with_opaque_access(use_block_name): """Test not rewrite opaque reads/writes after irrelavant compute inline""" sch = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all") - BB = "BB" if use_block_name else sch.get_block("BB") + BB = "BB" if use_block_name else sch.get_sblock("BB") sch.compute_inline(BB) assert_structural_equal_ignore_global_symbol( access_opaque_ptr_then_elemwise_inline, sch.mod["main"] @@ -1221,7 +1221,7 @@ def test_compute_inline_with_opaque_access(use_block_name): def test_inline_block_with_init(): sch = tir.Schedule(inline_block_with_init, debug_mask="all") - block = sch.get_block(name="tensor_rf", func_name="main") + block = sch.get_sblock(name="tensor_rf", func_name="main") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block=block) @@ -1229,7 +1229,7 @@ def test_inline_block_with_init(): def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): """Test opaque access with tvm_access_ptr after compute inline""" sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all") - compute = "compute" if use_block_name else sch.get_block("compute") + compute = "compute" if use_block_name else sch.get_sblock("compute") sch.compute_inline(compute) assert_structural_equal_ignore_global_symbol( exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"] @@ -1239,7 +1239,7 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): def test_reverse_compute_inline_overcomputed_producer(use_block_name): """Test reverse compute inline overcomputed producer""" sch = tir.Schedule(elementwise_overcomputed_producer, debug_mask="all") - compute = "C" if use_block_name else sch.get_block("C") + compute = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(compute) assert_structural_equal_ignore_global_symbol( elementwise_overcomputed_producer_reverse_inlined, sch.mod["main"] @@ -1249,7 +1249,7 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name): def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name): """Test reverse compute inline overcomputed producer where the predicate should be simplified""" sch = tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all") - compute = "C" if use_block_name else sch.get_block("C") + compute = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(compute) assert_structural_equal_ignore_global_symbol( elementwise_overcomputed_producer_simplify_predicate_reverse_inlined, sch.mod["main"] @@ -1259,7 +1259,7 @@ def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_blo def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_name): """Test reverse compute inline overcomputed producer with injective buffer load""" sch = tir.Schedule(elementwise_overcomputed_producer_injective_load, debug_mask="all") - compute = "C" if use_block_name else sch.get_block("C") + compute = "C" if use_block_name else sch.get_sblock("C") sch.reverse_compute_inline(compute) assert_structural_equal_ignore_global_symbol( elementwise_overcomputed_producer_injective_load_reverse_inlined, sch.mod["main"] @@ -1271,7 +1271,7 @@ def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name its producer """ sch = tir.Schedule(elementwise_producer_not_cover_consumer, debug_mask="all") - compute = "C" if use_block_name else sch.get_block("C") + compute = "C" if use_block_name else sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(compute) @@ -1282,7 +1282,7 @@ def test_reverse_compute_inline_producer_predicate_allowed(): """ sch = tir.Schedule(elementwise_predicate_producer, debug_mask="all") - sch.reverse_compute_inline(sch.get_block("C")) + sch.reverse_compute_inline(sch.get_sblock("C")) assert_structural_equal_ignore_global_symbol( elementwise_predicate_producer_inlined, sch.mod["main"] ) @@ -1294,7 +1294,7 @@ def test_reverse_compute_inline_producer_predicate_disallowed(): """ sch = tir.Schedule(Conv2dInt8_TensorCore_with_predicate_before, debug_mask="all") - sch.reverse_compute_inline(sch.get_block("compute_4")) + sch.reverse_compute_inline(sch.get_sblock("compute_4")) assert_structural_equal_ignore_global_symbol( Conv2dInt8_TensorCore_with_predicate_after["main"], sch.mod["main"] ) @@ -1304,7 +1304,7 @@ def test_reverse_compute_inline_producer_is_reduction(): """Test reverse comput inline when producer is reduction""" sch = tir.Schedule(elementwise_producer_is_reduction, debug_mask="all") with pytest.raises(tvm.tir.ScheduleError): - sch.reverse_compute_inline(sch.get_block("C")) + sch.reverse_compute_inline(sch.get_sblock("C")) def test_compute_inline_softmax(): @@ -1320,7 +1320,7 @@ def before(p_lv44: T.handle, p_output0: T.handle): T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv44[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) @@ -1328,13 +1328,13 @@ def before(p_lv44: T.handle, p_output0: T.handle): T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): + with T.sblock("T_softmax_exp"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) @@ -1342,14 +1342,14 @@ def before(p_lv44: T.handle, p_output0: T.handle): T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) + T.sblock_attr({"axis": 3}) var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -1361,12 +1361,12 @@ def after(p_lv44: T.handle, p_output0: T.handle): n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): + # with T.sblock("root"): T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv44[v_i0, v_i1, v_i2, v_k]) T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) @@ -1374,7 +1374,7 @@ def after(p_lv44: T.handle, p_output0: T.handle): T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(lv44[v_i0, v_i1, v_i2, v_k], T_softmax_maxelem[v_i0, v_i1, v_i2]) T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) @@ -1382,14 +1382,14 @@ def after(p_lv44: T.handle, p_output0: T.handle): T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T.exp(lv44[v_i0, v_i1, v_i2, v_k] - T_softmax_maxelem[v_i0, v_i1, v_i2]) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1, v_i2]) T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) + T.sblock_attr({"axis": 3}) var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) @@ -1397,7 +1397,7 @@ def after(p_lv44: T.handle, p_output0: T.handle): # fmt: on sch = tir.Schedule(before) - sch.compute_inline(sch.get_block("T_softmax_exp")) + sch.compute_inline(sch.get_sblock("T_softmax_exp")) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) @@ -1415,7 +1415,7 @@ def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(10)): for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused + ax1) v_k2 = T.axis.reduce(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) @@ -1430,7 +1430,7 @@ def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 for ax2_0 in range(T.int64(10)): for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_ax2 = T.axis.spatial(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) @@ -1438,7 +1438,7 @@ def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): + with T.sblock("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) @@ -1450,13 +1450,13 @@ def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): + # with T.sblock("root"): A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), scope="shared") A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), scope="shared") for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(10)): for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("A_red_temp"): + with T.sblock("A_red_temp"): v_ax0 = T.axis.spatial(T.int64(1), ax0) v_ax1 = T.axis.spatial(n, ax0_ax1_fused + ax1) v_k2 = T.axis.reduce(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) @@ -1471,7 +1471,7 @@ def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 for ax2_0 in range(T.int64(10)): for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1 = T.axis.spatial(n, ax0_ax1_fused) v_ax2 = T.axis.spatial(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) @@ -1481,7 +1481,7 @@ def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: # fmt: on sch = tir.Schedule(before) - sch.reverse_compute_inline(sch.get_block("compute")) + sch.reverse_compute_inline(sch.get_sblock("compute")) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) @@ -1493,11 +1493,11 @@ def before( ): T_add = T.alloc_buffer((1, 16, 7, 7)) for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T_add[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1) for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7): - with T.block("T_strided_slice_with_axes"): + with T.sblock("T_strided_slice_with_axes"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = T_add[ v_ax0, v_ax1, v_ax2, v_ax3 @@ -1510,22 +1510,22 @@ def after( ): T_strided_slice_with_axes_global = T.alloc_buffer((1, 12, 7, 7)) for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.where(ax1 < 12) T_strided_slice_with_axes_global[v_ax0, v_ax1, v_ax2, v_ax3] = x[ v_ax0, v_ax1, v_ax2, v_ax3 ] + T.float32(1) for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7): - with T.block("T_strided_slice_with_axes_global"): + with T.sblock("T_strided_slice_with_axes_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T_strided_slice_with_axes[v0, v1, v2, v3] = T_strided_slice_with_axes_global[ v0, v1, v2, v3 ] sch = tir.Schedule(before) - sch.reverse_compute_inline(sch.get_block("T_strided_slice_with_axes")) - sch.cache_write(sch.get_block("T_add"), 0, "global") + sch.reverse_compute_inline(sch.get_sblock("T_strided_slice_with_axes")) + sch.cache_write(sch.get_sblock("T_add"), 0, "global") assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) @@ -1538,7 +1538,7 @@ def before( ): T_batch_matmul_NN = T.alloc_buffer((T.int64(6), T.int64(1), T.int64(64))) for ax0, ax1 in T.grid(T.int64(6), T.int64(64)): - with T.block("bmm"): + with T.sblock("bmm"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1]) T.writes(T_batch_matmul_NN[v0, T.int64(0), v1]) @@ -1549,7 +1549,7 @@ def before( + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1] ) for ax0, ax1 in T.grid(T.int64(6), T.int64(64)): - with T.block("transpose"): + with T.sblock("transpose"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_batch_matmul_NN[v0, T.int64(0), v1]) T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1]) @@ -1562,7 +1562,7 @@ def after( T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"), ): for ax0, ax1 in T.grid(T.int64(6), T.int64(64)): - with T.block("bmm"): + with T.sblock("bmm"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1]) T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1]) @@ -1574,7 +1574,7 @@ def after( ) sch = tir.Schedule(before) - sch.reverse_compute_inline(sch.get_block("transpose")) + sch.reverse_compute_inline(sch.get_sblock("transpose")) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) diff --git a/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py b/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py index 882a5b72cefa..d81367cfb5b0 100644 --- a/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py +++ b/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py @@ -49,7 +49,7 @@ def before_decompose( y: T.Buffer((T.int64(1), T.int64(140), T.int64(128)), "int32"), ): for b, i, j in T.grid(T.int64(1), T.int64(140), T.int64(128)): - with T.block("block"): + with T.sblock("block"): vb, vi, vj = T.axis.remap("SSS", [b, i, j]) y[vb, vi, vj] = T.if_then_else(vi < T.int64(128), x[vb, vi, vj], 0) @@ -58,17 +58,17 @@ def after_decompose( x: T.Buffer((T.int64(1), T.int64(128), T.int64(128)), "int32"), y: T.Buffer((T.int64(1), T.int64(140), T.int64(128)), "int32"), ): - # with T.block("root"): + # with T.sblock("root"): for b, i in T.grid(T.int64(1), T.int64(140)): for j in range(T.int64(128)): - with T.block("block_pad_const"): + with T.sblock("block_pad_const"): vb = T.axis.spatial(T.int64(1), T.int64(0)) vi, vj = T.axis.remap("SS", [i, j]) T.reads() T.writes(y[vb, vi, vj]) y[vb, vi, vj] = 0 for j in range(T.int64(128)): - with T.block("block"): + with T.sblock("block"): vb = T.axis.spatial(T.int64(1), T.int64(0)) vi = T.axis.spatial(T.int64(128), i) vj = T.axis.spatial(T.int64(128), j) @@ -78,7 +78,7 @@ def after_decompose( y[vb, vi, vj] = x[vb, vi, vj] sch = tir.Schedule(before_decompose, debug_mask="all") - block = sch.get_block("block") + block = sch.get_sblock("block") sch.decompose_padding(block, sch.get_loops(block)[2]) check_decompose_padding(before_decompose, sch.mod["main"], after_decompose, check_run=False) @@ -87,27 +87,27 @@ def test_1d_decompose_padding(): @T.prim_func def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): for i in range(140): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32") @T.prim_func def after_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")): for i in T.serial(140): - with T.block("block_pad_const"): + with T.sblock("block_pad_const"): vi = T.axis.spatial(140, i) T.reads() T.writes(y[vi]) y[vi] = 0 for i in T.serial(128): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial(128, i) T.reads(x[vi]) T.writes(y[vi + 6]) y[vi + 6] = x[vi] sch = tir.Schedule(before_decompose, debug_mask="all") - block = sch.get_block("block") + block = sch.get_sblock("block") sch.decompose_padding(block, sch.get_loops(block)[0]) check_decompose_padding(before_decompose, sch.mod["main"], after_decompose, check_run=False) @@ -118,7 +118,7 @@ def sum_pool_2d( ): pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8") for i0, i1, i2, i3 in T.grid(1, 16, 231, 231): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else( 3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228, @@ -127,7 +127,7 @@ def sum_pool_2d( dtype="int8", ) for i0, i1, i2, i3, i4, i5 in T.grid(1, 16, 225, 225, 7, 7): - with T.block("tensor"): + with T.sblock("tensor"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): tensor[ax0, ax1, ax2, ax3] = T.int8(0) @@ -145,15 +145,15 @@ def pooling_decompose_0( ): pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8") for i0, i1, i2, i3 in T.grid(1, 16, 231, 231): - with T.block("pad_temp_pad_const"): + with T.sblock("pad_temp_pad_const"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) pad_temp[ax0, ax1, ax2, ax3] = T.int8(0) for i0, i1, i2, i3 in T.grid(1, 16, 225, 225): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) pad_temp[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3] for i0, i1, i2, i3, i4, i5 in T.grid(1, 16, 225, 225, 7, 7): - with T.block("tensor"): + with T.sblock("tensor"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): tensor[ax0, ax1, ax2, ax3] = T.int8(0) @@ -162,7 +162,7 @@ def pooling_decompose_0( ) sch = tir.Schedule(sum_pool_2d, debug_mask="all") - pad = sch.get_block("pad_temp") + pad = sch.get_sblock("pad_temp") sch.decompose_padding(pad, sch.get_loops(pad)[0]) check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_0, check_run=True) @@ -177,7 +177,7 @@ def pooling_decompose_1( pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8") for i0, i2_0, i3_0 in T.grid(1, 3, 3): for ax0, ax1, ax2 in T.grid(16, 81, 81): - with T.block("pad_temp_pad_const"): + with T.sblock("pad_temp_pad_const"): ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(16, ax0) ax2_1 = T.axis.spatial(231, i2_0 * 75 + ax1) @@ -186,7 +186,7 @@ def pooling_decompose_1( T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3]) pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0) for ax0, ax1, ax2 in T.grid(16, 81, 81): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0_2 = T.axis.spatial(1, 0) ax1_2 = T.axis.spatial(16, ax0) ax2_2 = T.axis.spatial(225, i2_0 * 75 + ax1 - 3) @@ -201,7 +201,7 @@ def pooling_decompose_1( T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3]) pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3] for i1, i2_1, i3_1, i4, i5 in T.grid(16, 75, 75, 7, 7): - with T.block("tensor"): + with T.sblock("tensor"): ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1]) ax2_3 = T.axis.spatial(225, i2_0 * 75 + i2_1) ax3 = T.axis.spatial(225, i3_0 * 75 + i3_1) @@ -216,13 +216,13 @@ def pooling_decompose_1( ) sch = tir.Schedule(sum_pool_2d, debug_mask="all") - block = sch.get_block("tensor") - pad = sch.get_block("pad_temp") + block = sch.get_sblock("tensor") + pad = sch.get_sblock("pad_temp") n, c, h, w, kh, kw = sch.get_loops(block) ho, hi = sch.split(h, [3, 75]) wo, wi = sch.split(w, [3, 75]) sch.reorder(n, ho, wo, c, hi, wi, kh, kw) - sch.compute_at(sch.get_block("pad_temp"), wo) + sch.compute_at(sch.get_sblock("pad_temp"), wo) sch.decompose_padding(pad, sch.get_loops(pad)[3]) check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_1, check_run=True) @@ -236,7 +236,7 @@ def pooling_decompose_2( ) -> None: pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8") for i0, i2_0, i3_0, ax0, ax1, ax2 in T.grid(1, 3, 3, 16, 81, 81): - with T.block("pad_temp_pad_const"): + with T.sblock("pad_temp_pad_const"): ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(16, ax0) ax2_1 = T.axis.spatial(231, i2_0 * 75 + ax1) @@ -246,7 +246,7 @@ def pooling_decompose_2( pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0) for i0, i2_0, i3_0 in T.grid(1, 3, 3): for ax0, ax1, ax2 in T.grid(16, 81, 81): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0_2 = T.axis.spatial(1, 0) ax1_2 = T.axis.spatial(16, ax0) ax2_2 = T.axis.spatial(225, i2_0 * 75 + ax1 - 3) @@ -261,7 +261,7 @@ def pooling_decompose_2( T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3]) pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3] for i1, i2_1, i3_1, i4, i5 in T.grid(16, 75, 75, 7, 7): - with T.block("tensor"): + with T.sblock("tensor"): ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1]) ax2_3 = T.axis.spatial(225, i2_0 * 75 + i2_1) ax3 = T.axis.spatial(225, i3_0 * 75 + i3_1) @@ -276,13 +276,13 @@ def pooling_decompose_2( ) sch = tir.Schedule(sum_pool_2d, debug_mask="all") - block = sch.get_block("tensor") - pad = sch.get_block("pad_temp") + block = sch.get_sblock("tensor") + pad = sch.get_sblock("pad_temp") n, c, h, w, kh, kw = sch.get_loops(block) ho, hi = sch.split(h, [3, 75]) wo, wi = sch.split(w, [3, 75]) sch.reorder(n, ho, wo, c, hi, wi, kh, kw) - sch.compute_at(sch.get_block("pad_temp"), wo) + sch.compute_at(sch.get_sblock("pad_temp"), wo) sch.decompose_padding(pad, sch.get_loops(pad)[0]) check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_2, check_run=True) @@ -297,7 +297,7 @@ def pooling_decompose_3( pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8") for i0, i2_0, i3_0 in T.grid(1, 3, 3): for ax0, ax1, ax2 in T.grid(16, 86, 86): - with T.block("pad_temp_pad_const"): + with T.sblock("pad_temp_pad_const"): ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(16, ax0) ax2_1 = T.axis.spatial(231, i2_0 * 80 + ax1) @@ -307,7 +307,7 @@ def pooling_decompose_3( T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3]) pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0) for ax0, ax1, ax2 in T.grid(16, 86, 86): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0_2 = T.axis.spatial(1, 0) ax1_2 = T.axis.spatial(16, ax0) ax2_2 = T.axis.spatial(225, i2_0 * 80 + ax1 - 3) @@ -324,7 +324,7 @@ def pooling_decompose_3( T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3]) pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3] for i1, i2_1, i3_1, i4, i5 in T.grid(16, 80, 80, 7, 7): - with T.block("tensor"): + with T.sblock("tensor"): ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1]) ax2_3 = T.axis.spatial(225, i2_0 * 80 + i2_1) ax3 = T.axis.spatial(225, i3_0 * 80 + i3_1) @@ -340,13 +340,13 @@ def pooling_decompose_3( ) sch = tir.Schedule(sum_pool_2d, debug_mask="all") - block = sch.get_block("tensor") - pad = sch.get_block("pad_temp") + block = sch.get_sblock("tensor") + pad = sch.get_sblock("pad_temp") n, c, h, w, kh, kw = sch.get_loops(block) ho, hi = sch.split(h, [None, 80]) wo, wi = sch.split(w, [None, 80]) sch.reorder(n, ho, wo, c, hi, wi, kh, kw) - sch.compute_at(sch.get_block("pad_temp"), wo) + sch.compute_at(sch.get_sblock("pad_temp"), wo) sch.decompose_padding(pad, sch.get_loops(pad)[3]) check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True) @@ -360,7 +360,7 @@ def pad_op( y: T.Buffer((1, 16, 231, 231), dtype="int8"), ): for i0, i1, i2, i3 in T.grid(1, 16, 231, 231): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) y[ax0, ax1, ax2, ax3] = T.if_then_else( 3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228, @@ -375,18 +375,18 @@ def pad_op_after( ): for i0, i1 in T.grid(1, 16): for i2, i3 in T.grid(231, 231): - with T.block("pad_temp_pad_const"): + with T.sblock("pad_temp_pad_const"): ax0 = T.axis.spatial(1, 0) ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3]) y[ax0, ax1, ax2, ax3] = T.int8(0) for i2, i3 in T.grid(225, 225): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0 = T.axis.spatial(1, 0) ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3]) y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3] sch = tir.Schedule(pad_op, debug_mask="all") - pad = sch.get_block("pad_temp") + pad = sch.get_sblock("pad_temp") _, _, h, _ = sch.get_loops(pad) sch.decompose_padding(pad, h) check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True) @@ -400,7 +400,7 @@ def trivial_pad( x: T.Buffer((1, 16, 225, 225), "int8"), y: T.Buffer([1, 16, 225, 225], dtype="int8") ): for i0, i1, i2, i3 in T.grid(1, 16, 225, 225): - with T.block("pad_temp"): + with T.sblock("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) y[ax0, ax1, ax2, ax3] = T.if_then_else( 0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225, @@ -410,7 +410,7 @@ def trivial_pad( ) sch = tir.Schedule(trivial_pad, debug_mask="all") - pad = sch.get_block("pad_temp") + pad = sch.get_sblock("pad_temp") _, _, h, _ = sch.get_loops(pad) assert not sch.can_decompose_padding(pad, h) diff --git a/tests/python/tir-schedule/test_tir_schedule_error.py b/tests/python/tir-schedule/test_tir_schedule_error.py index 6409b7c09514..929ef12566d7 100644 --- a/tests/python/tir-schedule/test_tir_schedule_error.py +++ b/tests/python/tir-schedule/test_tir_schedule_error.py @@ -31,11 +31,11 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -45,7 +45,7 @@ def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32): T.func_attr({"tir.noalias": True}) A = T.match_buffer(var_A, (1, seq_len * 8), "int32") B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8) - with T.block("exclusive_scan"): + with T.sblock("exclusive_scan"): T.reads() T.writes() s8: T.int32 = seq_len * 8 @@ -65,7 +65,7 @@ def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32): def test_tir_schedule_error_detail(): sch = tir.Schedule(matmul, debug_mask="all", error_render_level="detail") with pytest.raises(tir.ScheduleError) as excinfo: - sch.get_block("wrong_name") + sch.get_sblock("wrong_name") (msg,) = excinfo.value.args assert "Cannot find a block with the name: wrong_name" in msg @@ -73,7 +73,7 @@ def test_tir_schedule_error_detail(): def test_tir_schedule_error_fast(): sch = tir.Schedule(matmul, debug_mask="all", error_render_level="fast") with pytest.raises(tir.ScheduleError) as excinfo: - sch.get_block("wrong_name") + sch.get_sblock("wrong_name") (msg,) = excinfo.value.args assert "Cannot find a block with the specified name" in msg @@ -81,7 +81,7 @@ def test_tir_schedule_error_fast(): def test_tir_schedule_error_none(): sch = tir.Schedule(matmul, debug_mask="all", error_render_level="none") with pytest.raises(tir.ScheduleError) as excinfo: - sch.get_block("wrong_name") + sch.get_sblock("wrong_name") (msg,) = excinfo.value.args assert "(not rendered)" in msg diff --git a/tests/python/tir-schedule/test_tir_schedule_for_kind.py b/tests/python/tir-schedule/test_tir_schedule_for_kind.py index 7ae406445530..24d58f9e2f4f 100644 --- a/tests/python/tir-schedule/test_tir_schedule_for_kind.py +++ b/tests/python/tir-schedule/test_tir_schedule_for_kind.py @@ -35,7 +35,7 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -46,7 +46,7 @@ def element_wise_parallelized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.parallel(0, 128): for i1 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -57,7 +57,7 @@ def element_wise_i_bound(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.thread_binding(0, 128, thread="threadIdx.x"): for i1 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -69,11 +69,11 @@ def element_wise_compute_at_split(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o, j1i in T.grid(32, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -86,12 +86,12 @@ def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.serial(0, 32): for j1i in T.vectorized(0, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -102,7 +102,7 @@ def element_wise_split_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) for i, j_0, j_1 in T.grid(128, 13, 10): - with T.block("B"): + with T.sblock("B"): T.where(j_0 * 10 + j_1 < 128) vi = T.axis.S(128, i) vj = T.axis.S(128, j_0 * 10 + j_1) @@ -116,7 +116,7 @@ def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for j_0 in T.parallel(0, 13): for j_1 in T.serial(0, 10): - with T.block("B"): + with T.sblock("B"): T.where(j_0 * 10 + j_1 < 128) vi = T.axis.S(128, i) vj = T.axis.S(128, j_0 * 10 + j_1) @@ -129,7 +129,7 @@ def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128]) for i in T.vectorized(0, 128): for j_0, j_1 in T.grid(13, 10): - with T.block("B"): + with T.sblock("B"): T.where(j_0 * 10 + j_1 < 128) vi = T.axis.S(128, i) vj = T.axis.S(128, j_0 * 10 + j_1) @@ -143,12 +143,12 @@ def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.thread_binding(0, 32, thread="threadIdx.x"): for j1i in T.serial(0, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -161,7 +161,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -174,7 +174,7 @@ def rowsum(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 @@ -187,7 +187,7 @@ def rowsum_unrolled(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.unroll(0, 128): for i1 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 @@ -200,7 +200,7 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): @@ -214,7 +214,7 @@ def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vk] = 0.0 @@ -227,7 +227,7 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.serial(0, 128): for i1 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 @@ -238,7 +238,7 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: def opaque_block(a: T.handle) -> None: A = T.match_buffer(a, (16,)) for i in T.serial(0, 15): - with T.block("opaque"): + with T.sblock("opaque"): A[i + 1] = A[i + 1] + A[i] @@ -247,16 +247,16 @@ def block_inside_init(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128], dtype="float32") B = T.match_buffer(b, [128, 128], dtype="float32") for i in T.serial(0, 128): - with T.block("outer"): + with T.sblock("outer"): vi = T.axis.S(128, i) with T.init(): for j in T.serial(0, 128): - with T.block("init"): + with T.sblock("init"): vj = T.axis.S(128, j) B[vi, vj] = 0.0 for k in T.serial(0, 128): for j in T.serial(0, 128): - with T.block("inner"): + with T.sblock("inner"): vj, vk = T.axis.remap("SR", [j, k]) B[vi, vj] = B[vi, vj] + A[vi, vj, vk] @@ -266,16 +266,16 @@ def thread_bound_block_inside_init(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128], dtype="float32") B = T.match_buffer(b, [128, 128], dtype="float32") for i in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("outer"): + with T.sblock("outer"): vi = T.axis.S(128, i) with T.init(): for j in T.serial(0, 128): - with T.block("init"): + with T.sblock("init"): vj = T.axis.S(128, j) B[vi, vj] = 0.0 for k in T.serial(0, 128): for j in T.serial(0, 128): - with T.block("inner"): + with T.sblock("inner"): vj, vk = T.axis.remap("SR", [j, k]) B[vi, vj] = B[vi, vj] + A[vi, vj, vk] @@ -289,18 +289,18 @@ def decomposed_gemm( local = T.alloc_buffer((16, 16), "float32") for i, j in T.grid(4, 4): for ii, jj in T.grid(4, 4): - with T.block("init"): + with T.sblock("init"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) local[vi, vj] = 0 for k, ii, jj in T.grid(16, 4, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) vk = T.axis.R(16, k) local[vi, vj] += A[vi, vk] * B[vj, vk] for ii, jj in T.grid(4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) C[vi, vj] = local[vi, vj] @@ -315,19 +315,19 @@ def decomposed_gemm_after_vectorize( local = T.alloc_buffer((16, 16), "float32") for i, j in T.grid(4, 4): for ii, jj in T.grid(4, 4): - with T.block("init"): + with T.sblock("init"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) local[vi, vj] = 0 for k, ii, jj in T.grid(16, 4, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) vk = T.axis.R(16, k) local[vi, vj] += A[vi, vk] * B[vj, vk] for ii in range(4): for jj in T.vectorized(4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(16, i * 4 + ii) vj = T.axis.S(16, j * 4 + jj) C[vi, vj] = local[vi, vj] @@ -338,10 +338,10 @@ def nested_block_bind( A: T.Buffer((16, 16, 16, 16), "float32"), B: T.Buffer((16, 16, 16), "float32") ): for i, j in T.grid(16, 16): - with T.block("outer"): + with T.sblock("outer"): vi, vj = T.axis.remap("SS", [i, j]) for k, l in T.grid(16, 16): - with T.block("inner"): + with T.sblock("inner"): vk, vl = T.axis.remap("SR", [k, l]) with T.init(): B[vi, vj, vk] = 0.0 @@ -354,11 +354,11 @@ def thread_bound_nested_block( ) -> None: for i in T.serial(16): for j in T.thread_binding(16, thread="blockIdx.x"): - with T.block("outer"): + with T.sblock("outer"): vi, vj = T.axis.remap("SS", [i, j]) for k in T.serial(16): for l in T.thread_binding(16, thread="threadIdx.x"): - with T.block("inner"): + with T.sblock("inner"): vk, vl = T.axis.remap("SR", [k, l]) with T.init(): B[vi, vj, vk] = T.float32(0) @@ -370,16 +370,16 @@ def nested_block_bind_after_cache_read( A: T.Buffer((16, 16), "float32"), B: T.Buffer((16,), "float32") ) -> None: for i in T.serial(16): - with T.block("outer"): + with T.sblock("outer"): vi = T.axis.spatial(16, i) A_shared = T.alloc_buffer([1, 16], dtype="float32", scope="shared") for ax0, ax1 in T.grid(1, 16): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(16, vi + ax0) v1 = T.axis.spatial(16, ax1) A_shared[v0, v1] = A[v0, v1] for j in T.serial(16): - with T.block("inner"): + with T.sblock("inner"): vj = T.axis.reduce(16, j) with T.init(): B[vi] = T.float32(0) @@ -391,16 +391,16 @@ def thread_bound_nested_block_after_cache_read( A: T.Buffer((16, 16), "float32"), B: T.Buffer((16,), "float32") ) -> None: for i in T.thread_binding(16, thread="blockIdx.x"): - with T.block("outer"): + with T.sblock("outer"): vi = T.axis.spatial(16, i) A_shared = T.alloc_buffer([1, 16], dtype="float32", scope="shared") for ax0, ax1 in T.grid(1, 16): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial(16, vi + ax0) v1 = T.axis.spatial(16, ax1) A_shared[v0, v1] = A[v0, v1] for j in T.thread_binding(16, thread="threadIdx.x"): - with T.block("inner"): + with T.sblock("inner"): vj = T.axis.reduce(16, j) with T.init(): B[vi] = T.float32(0) @@ -417,14 +417,14 @@ def decomposed_gemm_parallelize_init( for i, j in T.grid(4, 4): for ii in T.serial(4): for jj in T.vectorized(4): - with T.block("init"): + with T.sblock("init"): vi = T.axis.spatial(16, i * 4 + ii) vj = T.axis.spatial(16, j * 4 + jj) T.reads() T.writes(local[vi, vj]) local[vi, vj] = 0 for k, ii, jj in T.grid(16, 4, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.spatial(16, i * 4 + ii) vj = T.axis.spatial(16, j * 4 + jj) vk = T.axis.reduce(16, k) @@ -432,7 +432,7 @@ def decomposed_gemm_parallelize_init( T.writes(local[vi, vj]) local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk] for ii, jj in T.grid(4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(16, i * 4 + ii) vj = T.axis.spatial(16, j * 4 + jj) T.reads(local[vi, vj]) @@ -443,12 +443,12 @@ def decomposed_gemm_parallelize_init( @T.prim_func def scatter_compute(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): for i in T.grid(8): - with T.block("first_half"): + with T.sblock("first_half"): vi = T.axis.spatial(16, 8 + i) B[vi] = A[vi - 8] for i in T.grid(8): - with T.block("last_half"): + with T.sblock("last_half"): vi = T.axis.spatial(16, i) B[vi] = A[vi + 8] @@ -458,15 +458,15 @@ def scatter_compute_parallelize( A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32") ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i in T.parallel(8): - with T.block("first_half"): + with T.sblock("first_half"): vi = T.axis.spatial(16, 8 + i) T.reads(A[vi - 8]) T.writes(B[vi]) B[vi] = A[vi - 8] for i in T.parallel(8): - with T.block("last_half"): + with T.sblock("last_half"): vi = T.axis.spatial(16, i) T.reads(A[vi + 8]) T.writes(B[vi]) @@ -478,7 +478,7 @@ def scatter_compute_parallelize( def test_parallel(): s = tir.Schedule(element_wise, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) s.parallel(i) assert_structural_equal_ignore_global_symbol(s.mod["main"], element_wise_parallelized) verify_trace_roundtrip(s, mod=element_wise) @@ -486,7 +486,7 @@ def test_parallel(): def test_parallel_predicate(): s = tir.Schedule(element_wise_split_predicate, debug_mask="all") - _, j, _ = s.get_loops(s.get_block("B")) + _, j, _ = s.get_loops(s.get_sblock("B")) s.parallel(j) assert_structural_equal_ignore_global_symbol( s.mod["main"], element_wise_split_predicate_parallelized @@ -496,28 +496,28 @@ def test_parallel_predicate(): def test_parallel_reduction_block_iter(): s = tir.Schedule(matmul, debug_mask="all") - _, _, k = s.get_loops(s.get_block("C")) + _, _, k = s.get_loops(s.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError): s.parallel(k) def test_parallel_not_quasi_affine(): s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.parallel(i) def test_parallel_not_compact_data_flow(): s = tir.Schedule(rowsum_not_compact_data_flow, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.parallel(i) def test_vectorize(): s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") - _, _, j1i = s.get_loops(s.get_block("C")) + _, _, j1i = s.get_loops(s.get_sblock("C")) s.vectorize(j1i) assert_structural_equal_ignore_global_symbol( s.mod["main"], element_wise_compute_at_split_vectorized @@ -527,7 +527,7 @@ def test_vectorize(): def test_vectorize_predicate(): s = tir.Schedule(element_wise_split_predicate, debug_mask="all") - i, _, _ = s.get_loops(s.get_block("B")) + i, _, _ = s.get_loops(s.get_sblock("B")) s.vectorize(i) assert_structural_equal_ignore_global_symbol( s.mod["main"], element_wise_split_predicate_vectorized @@ -537,14 +537,14 @@ def test_vectorize_predicate(): def test_vectorize_opaque_block(): s = tir.Schedule(opaque_block, debug_mask="all") - (i,) = s.get_loops(s.get_block("opaque")) + (i,) = s.get_loops(s.get_sblock("opaque")) with pytest.raises(tvm.tir.ScheduleError): s.vectorize(i) def test_unroll(): s = tir.Schedule(rowsum, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) s.unroll(i) assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_unrolled) verify_trace_roundtrip(s, mod=rowsum) @@ -552,7 +552,7 @@ def test_unroll(): def test_unroll_after_bind(): s = tir.Schedule(rowsum, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) s.bind(i, "blockIdx.x") s.unroll(i) assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_unrolled) @@ -561,7 +561,7 @@ def test_unroll_after_bind(): def test_bind1(): s = tir.Schedule(element_wise, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) s.bind(i, "threadIdx.x") assert_structural_equal_ignore_global_symbol(s.mod["main"], element_wise_i_bound) verify_trace_roundtrip(s, mod=element_wise) @@ -569,8 +569,8 @@ def test_bind1(): def test_bind2(): s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") - _, j0 = s.get_loops(s.get_block("B")) - _, j1o, _ = s.get_loops(s.get_block("C")) + _, j0 = s.get_loops(s.get_sblock("B")) + _, j1o, _ = s.get_loops(s.get_sblock("C")) s.bind(j0, "threadIdx.x") s.bind(j1o, "threadIdx.x") assert_structural_equal_ignore_global_symbol( @@ -581,7 +581,7 @@ def test_bind2(): def test_bind_cross_thread_reduction(): s = tir.Schedule(rowsum, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) s.bind(k, "threadIdx.x") assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_cross_thread_reduction) verify_trace_roundtrip(s, mod=rowsum) @@ -589,14 +589,14 @@ def test_bind_cross_thread_reduction(): def test_bind_not_cross_thread_reduction(): s = tir.Schedule(rowsum, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.bind(k, "blockIdx.x") def test_bind_after_bind(): s = tir.Schedule(element_wise, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) s.bind(i, "blockIdx.x") s.bind(i, "threadIdx.x") assert_structural_equal_ignore_global_symbol(s.mod["main"], element_wise_i_bound) @@ -605,7 +605,7 @@ def test_bind_after_bind(): def test_block_inside_init(): s = tir.Schedule(block_inside_init, debug_mask="all") - (i,) = s.get_loops(s.get_block("outer")) + (i,) = s.get_loops(s.get_sblock("outer")) s.bind(i, "threadIdx.x") assert_structural_equal_ignore_global_symbol(s.mod["main"], thread_bound_block_inside_init) verify_trace_roundtrip(s, mod=block_inside_init) @@ -613,7 +613,7 @@ def test_block_inside_init(): def test_vectorize_after_decompose(): s = tir.Schedule(decomposed_gemm, debug_mask="all") - jj = s.get_loops(s.get_block("C"))[-1] + jj = s.get_loops(s.get_sblock("C"))[-1] s.vectorize(jj) assert_structural_equal_ignore_global_symbol(s.mod["main"], decomposed_gemm_after_vectorize) verify_trace_roundtrip(s, mod=decomposed_gemm) @@ -621,8 +621,8 @@ def test_vectorize_after_decompose(): def test_nested_block_bind(): s = tir.Schedule(nested_block_bind) - block_outer = s.get_block("outer") - block_inner = s.get_block("inner") + block_outer = s.get_sblock("outer") + block_inner = s.get_sblock("inner") _, j = s.get_loops(block_outer) _, l = s.get_loops(block_inner) s.bind(l, "threadIdx.x") @@ -633,8 +633,8 @@ def test_nested_block_bind(): def test_nexted_block_bind_after_cache_read(): s = tir.Schedule(nested_block_bind_after_cache_read) - block_outer = s.get_block("outer") - block_inner = s.get_block("inner") + block_outer = s.get_sblock("outer") + block_inner = s.get_sblock("inner") (i,) = s.get_loops(block_outer) (j,) = s.get_loops(block_inner) s.bind(i, "blockIdx.x") @@ -647,8 +647,8 @@ def test_nexted_block_bind_after_cache_read(): def test_vectorize_init(): s = tir.Schedule(decomposed_gemm, debug_mask="all") - init_blk = s.get_block("init") - upd_blk = s.get_block("update") + init_blk = s.get_sblock("init") + upd_blk = s.get_sblock("update") _, _, ii_0, jj_0 = s.get_loops(init_blk) _, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk) s.vectorize(jj_0) @@ -658,8 +658,8 @@ def test_vectorize_init(): def test_scatter_parallelize(): s = tir.Schedule(scatter_compute, debug_mask="all") - first = s.get_block("first_half") - last = s.get_block("last_half") + first = s.get_sblock("first_half") + last = s.get_sblock("last_half") (i_0,) = s.get_loops(first) (i_1,) = s.get_loops(last) s.parallel(i_0) @@ -675,7 +675,7 @@ def before( B: T.Buffer((T.int64(128), T.int64(128))), ) -> None: for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -686,12 +686,12 @@ def expected( ) -> None: for i0 in T.thread_binding(T.int64(128), thread="threadIdx.x"): for i1 in range(T.int64(128)): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 s = tir.Schedule(before, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) s.bind(i, "threadIdx.x") assert_structural_equal_ignore_global_symbol(s.mod["main"], expected) verify_trace_roundtrip(s, mod=before) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index dc89f9df56a7..7210237f834d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -38,13 +38,13 @@ def matmul_bias_before( ) -> None: temp = T.alloc_buffer((16, 16), dtype="int32") for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): + with T.sblock("multiply"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.int32(0) temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = temp[vi, vj] + C[vi, vj] @@ -58,7 +58,7 @@ def matmul_bias_expected( ) -> None: temp = T.alloc_buffer((16, 16), dtype="int32") for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): + with T.sblock("multiply"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -76,13 +76,13 @@ def matmul_bias_fp32_before( ) -> None: temp = T.alloc_buffer((32, 32), dtype="float32") for i, j, k in T.grid(32, 32, 32): - with T.block("multiply"): + with T.sblock("multiply"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.float32(0) temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(32, 32): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = temp[vi, vj] + C[vi, vj] @@ -96,7 +96,7 @@ def matmul_bias_fp32_expected( ) -> None: temp = T.alloc_buffer((32, 32), dtype="float32") for i, j, k in T.grid(32, 32, 32): - with T.block("multiply"): + with T.sblock("multiply"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -115,17 +115,17 @@ def matmul_bias_multiple_epilogue_before( ) -> None: temp = T.alloc_buffer((16, 16), dtype="int32") for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): + with T.sblock("multiply"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.int32(0) temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") for i, j in T.grid(16, 16): - with T.block("add"): + with T.sblock("add"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = temp[vi, vj] + C[vi, vj] for i, j in T.grid(16, 16): - with T.block("add2"): + with T.sblock("add2"): vi, vj = T.axis.remap("SS", [i, j]) E[vi, vj] = temp[vi, vj] + C[vi, vj] @@ -140,7 +140,7 @@ def matmul_bias_multiple_epilogue_expected( ) -> None: temp = T.alloc_buffer((16, 16), dtype="int32") for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): + with T.sblock("multiply"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -148,7 +148,7 @@ def matmul_bias_multiple_epilogue_expected( D[vi, vj] = C[vi, vj] D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") for i, j in T.grid(16, 16): - with T.block("add2"): + with T.sblock("add2"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(temp[vi, vj], C[vi, vj]) T.writes(E[vi, vj]) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py index 6b3338b9a164..d5ad24bbd46d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py @@ -40,14 +40,14 @@ def matmul_clipping_before( """Original function with separate reduction and clipping epilogue blocks.""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.float32(0) temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(16, 16): - with T.block("clipping"): + with T.sblock("clipping"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper) @@ -63,7 +63,7 @@ def matmul_clipping_expected( """Expected function after fusion (Clipping).""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -91,18 +91,18 @@ def matmul_clipping_before_per_iteration( lower = T.float32(-5.0) upper = T.float32(5.0) for i, j in T.grid(16, 16): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) temp[vi, vj] = T.min(T.max(T.float32(0), lower), upper) # Clip init for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # Per-iteration clipping temp[vi, vj] = T.min(T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper) for i, j in T.grid(16, 16): - with T.block("copy"): + with T.sblock("copy"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = temp[vi, vj] @@ -163,19 +163,19 @@ def matmul_clipping_multiple_epilogue_before( """Original function with separate reduction and multiple epilogue blocks (one with clipping, one without).""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.float32(0) temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(16, 16): - with T.block("clipping"): + with T.sblock("clipping"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper) for i, j in T.grid(16, 16): - with T.block("copy"): + with T.sblock("copy"): vi, vj = T.axis.remap("SS", [i, j]) E[vi, vj] = temp[vi, vj] @@ -192,7 +192,7 @@ def matmul_clipping_multiple_epilogue_expected( """Expected function after fusion (Clipping) with multiple epilogue blocks.""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -200,7 +200,7 @@ def matmul_clipping_multiple_epilogue_expected( D[vi, vj] = T.min(T.max(T.float32(0), lower), upper) D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper) for i, j in T.grid(16, 16): - with T.block("copy"): + with T.sblock("copy"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(temp[vi, vj]) T.writes(E[vi, vj]) @@ -250,14 +250,14 @@ def test_func( ) -> None: temp = T.alloc_buffer((8, 8), dtype="float32") for i, j, k in T.grid(8, 8, 8): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.float32(0) temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(8, 8): - with T.block("clipping"): + with T.sblock("clipping"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = pattern_func(temp[vi, vj], T.float32(lower), T.float32(upper)) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py index 66e5e52e43db..e2dbfe24bb60 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py @@ -39,14 +39,14 @@ def matmul_bias_relu_before( """Original function with separate reduction and epilogue blocks (Bias + ReLU).""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.float32(0) temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(16, 16): - with T.block("bias_relu"): + with T.sblock("bias_relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0)) @@ -61,18 +61,18 @@ def matmul_bias_relu_before_per_iteration( """Original function with per-iteration ReLU (same semantics as fused).""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j in T.grid(16, 16): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) temp[vi, vj] = T.max(C[vi, vj], T.float32(0)) # ReLU on bias for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # Per-iteration ReLU temp[vi, vj] = T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0)) for i, j in T.grid(16, 16): - with T.block("copy"): + with T.sblock("copy"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = temp[vi, vj] @@ -87,7 +87,7 @@ def matmul_bias_relu_expected( """Expected function after fusion (Bias + ReLU).""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -163,19 +163,19 @@ def matmul_bias_relu_multiple_epilogue_before( """Original function with separate reduction and multiple epilogue blocks (one with ReLU, one without).""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.float32(0) temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(16, 16): - with T.block("bias_relu"): + with T.sblock("bias_relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0)) for i, j in T.grid(16, 16): - with T.block("bias"): + with T.sblock("bias"): vi, vj = T.axis.remap("SS", [i, j]) E[vi, vj] = temp[vi, vj] + C[vi, vj] @@ -191,7 +191,7 @@ def matmul_bias_relu_multiple_epilogue_expected( """Expected function after fusion (Bias + ReLU) with multiple epilogue blocks.""" temp = T.alloc_buffer((16, 16), dtype="float32") for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) @@ -199,7 +199,7 @@ def matmul_bias_relu_multiple_epilogue_expected( D[vi, vj] = T.max(C[vi, vj], T.float32(0)) D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0)) for i, j in T.grid(16, 16): - with T.block("bias"): + with T.sblock("bias"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(temp[vi, vj], C[vi, vj]) T.writes(E[vi, vj]) diff --git a/tests/python/tir-schedule/test_tir_schedule_instruction.py b/tests/python/tir-schedule/test_tir_schedule_instruction.py index 1aa78ee68c5b..9e9fb195b0aa 100644 --- a/tests/python/tir-schedule/test_tir_schedule_instruction.py +++ b/tests/python/tir-schedule/test_tir_schedule_instruction.py @@ -20,7 +20,7 @@ import pytest import tvm.testing -from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV +from tvm.tir.schedule import SBlockRV, Instruction, InstructionKind, LoopRV def test_inst_kind_get(): @@ -30,7 +30,7 @@ def test_inst_kind_get(): def test_inst_construct_1(): - block = BlockRV() + block = SBlockRV() loop0 = LoopRV() loop1 = LoopRV() inst = Instruction( @@ -50,7 +50,7 @@ def test_inst_construct_1(): def test_inst_construct_2(): - block = BlockRV() + block = SBlockRV() inst = Instruction( kind=InstructionKind.get("ComputeInline"), inputs=[block], diff --git a/tests/python/tir-schedule/test_tir_schedule_merge.py b/tests/python/tir-schedule/test_tir_schedule_merge.py index b3e72943bf6e..66900dd4f65f 100644 --- a/tests/python/tir-schedule/test_tir_schedule_merge.py +++ b/tests/python/tir-schedule/test_tir_schedule_merge.py @@ -35,20 +35,20 @@ def elementwise(a: T.handle, c: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (64, 64)) B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + j_1) T.reads(B[vi, vj]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) for i_0, j_0, i_1, j_1 in T.grid(8, 8, 8, 8): - with T.block("D"): + with T.sblock("D"): vi = T.axis.spatial(64, i_0 * 8 + i_1) vj = T.axis.spatial(64, j_0 * 8 + j_1) T.reads(B[vi, vj]) @@ -63,21 +63,21 @@ def elementwise_merged(a: T.handle, c: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (64, 64)) B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for i_0_m in range(8): for j_0, i_1, j_1 in T.grid(8, 16, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i_0_m * 16 + i_1) vj = T.axis.spatial(128, j_0 * 16 + j_1) T.reads(B[vi, vj]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) for j_0, i_1, j_1 in T.grid(8, 8, 8): - with T.block("D"): + with T.sblock("D"): vi = T.axis.spatial(64, i_0_m * 8 + i_1) vj = T.axis.spatial(64, j_0 * 8 + j_1) T.reads(B[vi, vj]) @@ -92,21 +92,21 @@ def elementwise_merged2(a: T.handle, c: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (64, 64)) B = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = A[vi, vj] * T.float32(2) for i_0_m, j_0_m in T.grid(8, 8): for i_1, j_1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(128, i_0_m * 16 + i_1) vj = T.axis.spatial(128, j_0_m * 16 + j_1) T.reads(B[vi, vj]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) for i_1, j_1 in T.grid(8, 8): - with T.block("D"): + with T.sblock("D"): vi = T.axis.spatial(64, i_0_m * 8 + i_1) vj = T.axis.spatial(64, j_0_m * 8 + j_1) T.reads(B[vi, vj]) @@ -116,8 +116,8 @@ def elementwise_merged2(a: T.handle, c: T.handle, d: T.handle) -> None: def test_merge(): sch = tir.Schedule(elementwise, debug_mask="all") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_c = sch.get_sblock("C") + block_d = sch.get_sblock("D") i = sch.get_loops(block_c)[0] j = sch.get_loops(block_d)[0] sch.merge(i, j) @@ -127,8 +127,8 @@ def test_merge(): def test_merge2(): sch = tir.Schedule(elementwise, debug_mask="all") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_c = sch.get_sblock("C") + block_d = sch.get_sblock("D") i = sch.get_loops(block_c)[1] j = sch.get_loops(block_d)[1] sch.merge(i, j) @@ -145,23 +145,23 @@ def elementwise_with_seq(a: T.handle, c: T.handle) -> None: D = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("D"): + with T.sblock("D"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) D[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = B[vi, vj, vk] * 2.0 sch = tir.Schedule(elementwise_with_seq, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, b = sch.get_loops(block_b) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") _, _, c = sch.get_loops(block_c) with pytest.raises(tvm.tir.ScheduleError): sch.merge(b, c) @@ -175,19 +175,19 @@ def elementwise_loops_not_start_with_zero(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(1, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 sch = tir.Schedule(elementwise_loops_not_start_with_zero, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, b = sch.get_loops(block_b) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") _, _, c = sch.get_loops(block_c) with pytest.raises(tvm.tir.ScheduleError): sch.merge(b, c) @@ -201,19 +201,19 @@ def elementwise_loops_not_same_extent(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((64, 128, 128)) for i, j in T.grid(64, 128): for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 sch = tir.Schedule(elementwise_loops_not_same_extent, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, b = sch.get_loops(block_b) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") _, _, c = sch.get_loops(block_c) with pytest.raises(tvm.tir.ScheduleError): sch.merge(b, c) @@ -227,19 +227,19 @@ def elementwise_not_same_level(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 sch = tir.Schedule(elementwise_not_same_level, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, b, _ = sch.get_loops(block_b) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") _, _, c = sch.get_loops(block_c) with pytest.raises(tvm.tir.ScheduleError): sch.merge(b, c) @@ -251,22 +251,22 @@ def elementwise_with_different_scope(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) C = T.match_buffer(c, (128, 128, 128)) B = T.alloc_buffer((128, 128, 128)) - with T.block("A"): + with T.sblock("A"): for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 sch = tir.Schedule(elementwise_with_different_scope, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, b = sch.get_loops(block_b) - block_c = sch.get_block("C") + block_c = sch.get_sblock("C") _, _, c = sch.get_loops(block_c) with pytest.raises(tvm.tir.ScheduleError): sch.merge(b, c) diff --git a/tests/python/tir-schedule/test_tir_schedule_pad_einsum.py b/tests/python/tir-schedule/test_tir_schedule_pad_einsum.py index c2a8c3c51652..a7226381e260 100644 --- a/tests/python/tir-schedule/test_tir_schedule_pad_einsum.py +++ b/tests/python/tir-schedule/test_tir_schedule_pad_einsum.py @@ -38,21 +38,21 @@ def matmul_before( B_shared = T.alloc_buffer((127, 127), "float32", scope="shared") C_shared = T.alloc_buffer((128, 127), "float32", scope="shared") for i0, i1 in T.grid(128, 127): - with T.block("A"): + with T.sblock("A"): i, j = T.axis.remap("SS", [i0, i1]) A_shared[i, j] = A[i, j] for i0, i1 in T.grid(127, 127): - with T.block("B"): + with T.sblock("B"): i, j = T.axis.remap("SS", [i0, i1]) B_shared[i, j] = B[i, j] for i0, i1, i2 in T.grid(128, 127, 127): - with T.block("C_shared"): + with T.sblock("C_shared"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) with T.init(): C_shared[i, j] = T.float32(0) C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j] for i0, i1 in T.grid(128, 127): - with T.block("C"): + with T.sblock("C"): i, j = T.axis.remap("SS", [i0, i1]) C[i, j] = C_shared[i, j] @@ -67,13 +67,13 @@ def matmul_expected( B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared") C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared") for i0, i1 in T.grid(128, 128): - with T.block("A"): + with T.sblock("A"): i, j = T.axis.remap("SS", [i0, i1]) T.reads(A[i, j]) T.writes(A_shared_padded[i, j]) A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0), dtype="float32") for i0, i1 in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): i, j = T.axis.remap("SS", [i0, i1]) T.reads(B[i, j]) T.writes(B_shared_padded[i, j]) @@ -81,7 +81,7 @@ def matmul_expected( i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32" ) for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("C_shared"): + with T.sblock("C_shared"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(A_shared_padded[i, k], B_shared_padded[k, j]) T.writes(C_shared_padded[i, j]) @@ -91,7 +91,7 @@ def matmul_expected( C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j] ) for i0, i1 in T.grid(128, 127): - with T.block("C"): + with T.sblock("C"): i, j = T.axis.remap("SS", [i0, i1]) T.reads(C_shared_padded[i, j]) T.writes(C[i, j]) @@ -115,7 +115,7 @@ def matmul_before( B = T.match_buffer(b, (n, 128), "float32") C = T.match_buffer(c, (128, n), "float32") for i0, i1, i2 in T.grid(128, n, 128): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) with T.init(): C[i, j] = T.float32(0) @@ -134,11 +134,11 @@ def matmul_after( B_pad = T.alloc_buffer(((n + 31) // 32 * 32, 128)) C_pad = T.alloc_buffer((128, (n + 31) // 32 * 32)) for i0, i1 in T.grid((n + 31) // 32 * 32, 128): - with T.block("B_pad"): + with T.sblock("B_pad"): v0, v1 = T.axis.remap("SS", [i0, i1]) B_pad[v0, v1] = T.if_then_else(v0 < n, B[v0, v1], T.float32(0)) for i0, i1, i2 in T.grid(128, (n + 31) // 32 * 32, 128): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(A[i, k], B_pad[j, k]) T.writes(C_pad[i, j]) @@ -146,12 +146,12 @@ def matmul_after( C_pad[i, j] = T.float32(0) C_pad[i, j] = C_pad[i, j] + A[i, k] * B_pad[j, k] for i0, i1 in T.grid(128, n): - with T.block("C_pad"): + with T.sblock("C_pad"): v0, v1 = T.axis.remap("SS", [i0, i1]) C[v0, v1] = C_pad[v0, v1] sch = tir.Schedule(matmul_before, debug_mask="all") - C = sch.get_block("C") + C = sch.get_sblock("C") sch.pad_einsum(C, [32, 32, 32]) assert_structural_equal_ignore_global_symbol(matmul_after, sch.mod["main"]) verify_trace_roundtrip(sch, mod=matmul_before) @@ -173,7 +173,7 @@ def before( D = T.match_buffer(d, (1, n, 11008)) C = T.alloc_buffer((1, n, 11008)) for i0, i1, i2, k in T.grid(1, n, 11008, 4096): - with T.block("C"): + with T.sblock("C"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(C[v_i0, v_i1, v_i2]) @@ -181,7 +181,7 @@ def before( C[v_i0, v_i1, v_i2] = T.float32(0) C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] for ax0, ax1, ax2 in T.grid(1, n, 11008): - with T.block("D"): + with T.sblock("D"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) D[v_ax0, v_ax1, v_ax2] = M[v_ax0, v_ax1, v_ax2] * C[v_ax0, v_ax1, v_ax2] @@ -193,16 +193,16 @@ def after(a: T.handle, b: T.handle, m: T.handle, d: T.handle): B = T.match_buffer(b, (11008, 4096)) M = T.match_buffer(m, (1, n, 11008)) D = T.match_buffer(d, (1, n, 11008)) - # with T.block("root"): + # with T.sblock("root"): C = T.alloc_buffer((1, n, 11008)) A_pad = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096)) C_pad = T.alloc_buffer((1, (n + 31) // 32 * 32, 11008)) for i0, i1, i2 in T.grid(1, (n + 31) // 32 * 32, 4096): - with T.block("A_pad"): + with T.sblock("A_pad"): v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) A_pad[v0, v1, v2] = T.if_then_else(v1 < n, A[v0, v1, v2], T.float32(0)) for i0, i1, i2, k in T.grid(1, (n + 31) // 32 * 32, 11008, 4096): - with T.block("C"): + with T.sblock("C"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(A_pad[v_i0, v_i1, v_k], B[v_i2, v_k]) T.writes(C_pad[v_i0, v_i1, v_i2]) @@ -212,16 +212,16 @@ def after(a: T.handle, b: T.handle, m: T.handle, d: T.handle): C_pad[v_i0, v_i1, v_i2] + A_pad[v_i0, v_i1, v_k] * B[v_i2, v_k] ) for i0, i1, i2 in T.grid(1, n, 11008): - with T.block("C_pad"): + with T.sblock("C_pad"): v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) C[v0, v1, v2] = C_pad[v0, v1, v2] for ax0, ax1, ax2 in T.grid(1, n, 11008): - with T.block("D"): + with T.sblock("D"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) D[v_ax0, v_ax1, v_ax2] = M[v_ax0, v_ax1, v_ax2] * C[v_ax0, v_ax1, v_ax2] sch = tir.Schedule(before, debug_mask="all") - C = sch.get_block("C") + C = sch.get_sblock("C") sch.pad_einsum(C, [1, 32, 32, 32]) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch, mod=before) @@ -241,7 +241,7 @@ def before( R = T.match_buffer(r, (1, n, 4096), "float32") S = T.alloc_buffer((1, n), "float32") for bsz, i, k in T.grid(1, n, 4096): - with T.block("S"): + with T.sblock("S"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A[v_bsz, v_i, v_k]) T.writes(S[v_bsz, v_i]) @@ -249,7 +249,7 @@ def before( S[v_bsz, v_i] = T.float32(0) S[v_bsz, v_i] = S[v_bsz, v_i] + A[v_bsz, v_i, v_k] * A[v_bsz, v_i, v_k] for bsz, i, k in T.grid(1, n, 4096): - with T.block("R"): + with T.sblock("R"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) R[v_bsz, v_i, v_k] = W[v_k] * ( A[v_bsz, v_i, v_k] @@ -267,11 +267,11 @@ def after(a: T.handle, w: T.handle, r: T.handle): A_pad = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096)) S_pad = T.alloc_buffer((1, (n + 31) // 32 * 32)) for i0, i1, i2 in T.grid(1, (n + 31) // 32 * 32, 4096): - with T.block("A_pad"): + with T.sblock("A_pad"): v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) A_pad[v0, v1, v2] = T.if_then_else(v1 < n, A[v0, v1, v2], T.float32(0)) for bsz, i, k in T.grid(1, (n + 31) // 32 * 32, 4096): - with T.block("S"): + with T.sblock("S"): v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) T.reads(A_pad[v_bsz, v_i, v_k]) T.writes(S_pad[v_bsz, v_i]) @@ -281,11 +281,11 @@ def after(a: T.handle, w: T.handle, r: T.handle): S_pad[v_bsz, v_i] + A_pad[v_bsz, v_i, v_k] * A_pad[v_bsz, v_i, v_k] ) for i0, i1 in T.grid(1, n): - with T.block("S_pad"): + with T.sblock("S_pad"): v0, v1 = T.axis.remap("SS", [i0, i1]) S[v0, v1] = S_pad[v0, v1] for bsz, i, k in T.grid(1, n, 4096): - with T.block("R"): + with T.sblock("R"): v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) R[v_bsz, v_i, v_k] = W[v_k] * ( A[v_bsz, v_i, v_k] @@ -293,7 +293,7 @@ def after(a: T.handle, w: T.handle, r: T.handle): ) sch = tir.Schedule(before, debug_mask="all") - C = sch.get_block("S") + C = sch.get_sblock("S") sch.pad_einsum(C, [1, 32, 1]) assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) verify_trace_roundtrip(sch, mod=before) diff --git a/tests/python/tir-schedule/test_tir_schedule_partition.py b/tests/python/tir-schedule/test_tir_schedule_partition.py index 08595843e3aa..fbcdd2ece58f 100644 --- a/tests/python/tir-schedule/test_tir_schedule_partition.py +++ b/tests/python/tir-schedule/test_tir_schedule_partition.py @@ -34,7 +34,7 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -44,7 +44,7 @@ def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k in T.grid(128, 128, n): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -55,7 +55,7 @@ def elementwise_with_anno(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128, annotations={"useless_annotation": True}): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -68,7 +68,7 @@ def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -80,10 +80,10 @@ def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("opaque"): + with T.sblock("opaque"): T.reads([A[i, j, k]]) T.writes([B[i, j, k]]) - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -94,33 +94,33 @@ def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: def elementwise_partition_with_opaque_block(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block("opaque_i_common"): + with T.sblock("opaque_i_common"): T.reads() T.writes() - with T.block("opaque_i0_partition"): + with T.sblock("opaque_i0_partition"): T.reads() T.writes() for i0, j, k in T.grid(112, 128, 128): - with T.block("opaque_i0"): + with T.sblock("opaque_i0"): T.reads(A[i0, j, k]) T.writes(B[i0, j, k]) - with T.block("B_i0"): + with T.sblock("B_i0"): vi, vj, vk = T.axis.remap("SSS", [i0, j, k]) T.reads(A[0:112, 0:128, 0:128]) T.writes(B[0:112, 0:128, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("opaque_i1_partition"): + with T.sblock("opaque_i1_partition"): T.reads() T.writes() for i1 in range(112, 128): for j, k in T.grid(128, 128): - with T.block("opaque_i1"): + with T.sblock("opaque_i1"): T.reads(A[i1, j, k]) T.writes(B[i1, j, k]) - with T.block("B_i1"): + with T.sblock("B_i1"): vi, vj, vk = T.axis.remap("SSS", [i1, j, k]) T.reads(A[112:128, 0:128, 0:128]) T.writes(B[112:128, 0:128, 0:128]) @@ -131,74 +131,74 @@ def elementwise_partition_with_opaque_block(a: T.handle, b: T.handle) -> None: def elementwise_loop_partition_case0(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block("B_i_common"): + with T.sblock("B_i_common"): T.reads() T.writes() - with T.block("B_i0_partition"): + with T.sblock("B_i0_partition"): T.reads() T.writes() for i0 in range(2): - with T.block("B_i0_j_common"): + with T.sblock("B_i0_j_common"): T.reads() T.writes() - with T.block("B_i0_j0_partition"): + with T.sblock("B_i0_j0_partition"): T.reads() T.writes() for j0, k in T.grid(4, 128): - with T.block("B_i0_j0"): + with T.sblock("B_i0_j0"): vi, vj, vk = T.axis.remap("SSS", [i0, j0, k]) T.reads(A[0:2, 0:4, 0:128]) T.writes(B[0:2, 0:4, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i0_j1_partition"): + with T.sblock("B_i0_j1_partition"): T.reads() T.writes() for j1 in range(4, 36): for k in range(128): - with T.block("B_i0_j1"): + with T.sblock("B_i0_j1"): vi, vj, vk = T.axis.remap("SSS", [i0, j1, k]) T.reads(A[0:2, 4:36, 0:128]) T.writes(B[0:2, 4:36, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i0_j2_partition"): + with T.sblock("B_i0_j2_partition"): T.reads() T.writes() for j2 in range(36, 128): for k in range(128): - with T.block("B_i0_j2"): + with T.sblock("B_i0_j2"): vi, vj, vk = T.axis.remap("SSS", [i0, j2, k]) T.reads(A[0:2, 36:128, 0:128]) T.writes(B[0:2, 36:128, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i1_partition"): + with T.sblock("B_i1_partition"): T.reads() T.writes() for i1 in range(2, 3): for j, k in T.grid(128, 128): - with T.block("B_i1"): + with T.sblock("B_i1"): vi, vj, vk = T.axis.remap("SSS", [i1, j, k]) T.reads(A[2, 0:128, 0:128]) T.writes(B[2, 0:128, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i2_partition"): + with T.sblock("B_i2_partition"): T.reads() T.writes() for i2 in range(3, 67): for j, k in T.grid(128, 128): - with T.block("B_i2"): + with T.sblock("B_i2"): vi, vj, vk = T.axis.remap("SSS", [i2, j, k]) T.reads(A[3:67, 0:128, 0:128]) T.writes(B[3:67, 0:128, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i3_partition"): + with T.sblock("B_i3_partition"): T.reads() T.writes() for i3 in range(67, 128): for j, k in T.grid(128, 128): - with T.block("B_i3"): + with T.sblock("B_i3"): vi, vj, vk = T.axis.remap("SSS", [i3, j, k]) T.reads(A[67:128, 0:128, 0:128]) T.writes(B[67:128, 0:128, 0:128]) @@ -209,62 +209,62 @@ def elementwise_loop_partition_case0(a: T.handle, b: T.handle) -> None: def elementwise_loop_partition_case1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block("B_i_common"): + with T.sblock("B_i_common"): T.reads() T.writes() - with T.block("B_i0_partition"): + with T.sblock("B_i0_partition"): T.reads() T.writes() for i0, j, k in T.grid(63, 128, 128): - with T.block("B_i0"): + with T.sblock("B_i0"): vi, vj, vk = T.axis.remap("SSS", [i0, j, k]) T.reads(A[0:63, 0:128, 0:128]) T.writes(B[0:63, 0:128, 0:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i1_partition"): + with T.sblock("B_i1_partition"): T.reads() T.writes() for i1 in range(63, 64): for j in range(128): - with T.block("B_i1_k_common"): + with T.sblock("B_i1_k_common"): T.reads() T.writes() - with T.block("B_i1_k0_partition"): + with T.sblock("B_i1_k0_partition"): T.reads() T.writes() for k0 in range(1): - with T.block("B_i1_k0"): + with T.sblock("B_i1_k0"): vi, vj, vk = T.axis.remap("SSS", [i1, j, k0]) T.reads(A[63, 0:128, 0]) T.writes(B[63, 0:128, 0]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i1_k1_partition"): + with T.sblock("B_i1_k1_partition"): T.reads() T.writes() for k1 in range(1, 65): - with T.block("B_i1_k1"): + with T.sblock("B_i1_k1"): vi, vj, vk = T.axis.remap("SSS", [i1, j, k1]) T.reads(A[63, 0:128, 1:65]) T.writes(B[63, 0:128, 1:65]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i1_k2_partition"): + with T.sblock("B_i1_k2_partition"): T.reads() T.writes() for k2 in range(65, 128): - with T.block("B_i1_k2"): + with T.sblock("B_i1_k2"): vi, vj, vk = T.axis.remap("SSS", [i1, j, k2]) T.reads(A[63, 0:128, 65:128]) T.writes(B[63, 0:128, 65:128]) B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) - with T.block("B_i2_partition"): + with T.sblock("B_i2_partition"): T.reads() T.writes() for i2 in range(64, 128): for j, k in T.grid(128, 128): - with T.block("B_i2"): + with T.sblock("B_i2"): vi, vj, vk = T.axis.remap("SSS", [i2, j, k]) T.reads(A[64:128, 0:128, 0:128]) T.writes(B[64:128, 0:128, 0:128]) @@ -276,13 +276,13 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for i, j in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) A[vi, vj] = 1 for i, j in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) @@ -294,45 +294,45 @@ def opaque_access_loop_partition(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16)) B = T.match_buffer(b, (16, 16)) for i in range(16): - with T.block("A_j_common"): + with T.sblock("A_j_common"): T.reads() T.writes() - with T.block("A_j0_partition"): + with T.sblock("A_j0_partition"): T.reads() T.writes() for j0 in range(12): - with T.block("A_j0"): + with T.sblock("A_j0"): vi, vj = T.axis.remap("SS", [i, j0]) T.reads() T.writes(A[0:16, 0:12]) A[vi, vj] = T.float32(1) - with T.block("A_j1_partition"): + with T.sblock("A_j1_partition"): T.reads() T.writes() for j1 in range(12, 16): - with T.block("A_j1"): + with T.sblock("A_j1"): vi, vj = T.axis.remap("SS", [i, j1]) T.reads() T.writes(A[0:16, 12:16]) A[vi, vj] = T.float32(1) for i in range(16): - with T.block("B_j_common"): + with T.sblock("B_j_common"): T.reads() T.writes() - with T.block("B_j0_partition"): + with T.sblock("B_j0_partition"): T.reads() T.writes() for j0 in range(12): - with T.block("B_j0"): + with T.sblock("B_j0"): vi, vj = T.axis.remap("SS", [i, j0]) T.reads() T.writes(B[0:16, 0:16]) T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj) - with T.block("B_j1_partition"): + with T.sblock("B_j1_partition"): T.reads() T.writes() for j1 in range(12, 16): - with T.block("B_j1"): + with T.sblock("B_j1"): vi, vj = T.axis.remap("SS", [i, j1]) T.reads() T.writes(B[0:16, 0:16]) @@ -344,11 +344,11 @@ def opaque_access_loop_partition(a: T.handle, b: T.handle) -> None: def test_loop_partition(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.loop_partition(i, factors=[2, 1, 64]) - block_b_partition = sch.get_block("B_i0") + block_b_partition = sch.get_sblock("B_i0") i, j, k = sch.get_loops(block_b_partition) loops = sch.loop_partition(j, factors=[4, 32]) @@ -358,11 +358,11 @@ def test_loop_partition(): def test_partition_with_inferred_factor(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.loop_partition(i, factors=[None, 1, 64]) - block_b_partition = sch.get_block("B_i1") + block_b_partition = sch.get_sblock("B_i1") i, j, k = sch.get_loops(block_b_partition) sch.loop_partition(k, factors=[1, 64, None]) @@ -372,7 +372,7 @@ def test_partition_with_inferred_factor(): def test_partition_with_opaque_block(): sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") - block_opaque = sch.get_block("opaque") + block_opaque = sch.get_sblock("opaque") i, _, _ = sch.get_loops(block_opaque) sch.loop_partition(i, factors=[None, 16]) assert_structural_equal_ignore_global_symbol( @@ -383,10 +383,10 @@ def test_partition_with_opaque_block(): def test_partition_with_opaque_access(): sch = tir.Schedule(opaque_access, debug_mask="all") - block_a = sch.get_block("A") + block_a = sch.get_sblock("A") _, j = sch.get_loops(block_a) sch.loop_partition(j, factors=[None, 4]) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j = sch.get_loops(block_b) sch.loop_partition(j, factors=[None, 4]) assert_structural_equal_ignore_global_symbol(opaque_access_loop_partition, sch.mod["main"]) @@ -402,7 +402,7 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - (i,) = sch.get_loops(sch.get_block("B")) + (i,) = sch.get_loops(sch.get_sblock("B")) sch.loop_partition( i, factors=[ @@ -414,7 +414,7 @@ def _create_prim_func(): def test_partition_fail_symbolic(): sch = tir.Schedule(elementwise_symbolic, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.loop_partition(k, factors=[10, None]) @@ -422,7 +422,7 @@ def test_partition_fail_symbolic(): def test_partition_fail_out_of_bound(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.loop_partition(i, factors=[1000, 2, 3]) @@ -430,7 +430,7 @@ def test_partition_fail_out_of_bound(): def test_partition_with_non_positive_factors(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.loop_partition(i, factors=[-2, -64]) @@ -442,7 +442,7 @@ def test_partition_with_non_positive_factors(): def test_partition_fail_with_annotation(): sch = tir.Schedule(elementwise_with_anno, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.loop_partition(k, factors=[None, 10]) @@ -450,7 +450,7 @@ def test_partition_fail_with_annotation(): def test_partition_fail_with_thread_binding(): sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.loop_partition(k, factors=[None, 10]) diff --git a/tests/python/tir-schedule/test_tir_schedule_read_write_at.py b/tests/python/tir-schedule/test_tir_schedule_read_write_at.py index 930b873a51b7..4b0c823586ba 100644 --- a/tests/python/tir-schedule/test_tir_schedule_read_write_at.py +++ b/tests/python/tir-schedule/test_tir_schedule_read_write_at.py @@ -44,7 +44,7 @@ def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disab for k0 in T.serial(0, 256): for k1 in T.unroll(0, 8): for _, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -68,17 +68,17 @@ def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: for ty in T.thread_binding(0, 8, thread="threadIdx.y"): for tx in T.thread_binding(0, 8, thread="threadIdx.x"): for k0 in T.serial(0, 256): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.S(32, by) v1 = T.axis.S(256, k0) T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(64, 8): A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] for k1 in T.unroll(0, 8): for v_, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -103,25 +103,25 @@ def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: for ty in T.thread_binding(0, 8, thread="threadIdx.y"): for tx in T.thread_binding(0, 8, thread="threadIdx.x"): for k0 in T.serial(0, 256): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.S(32, by) v1 = T.axis.S(256, k0) T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(64, 8): A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.S(256, k0) v1 = T.axis.S(32, bx) T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(8, 64): B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] for k1 in T.unroll(0, 8): for v_, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -146,25 +146,25 @@ def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: for ty in T.thread_binding(0, 8, thread="threadIdx.y"): for tx in T.thread_binding(0, 8, thread="threadIdx.x"): for k0 in T.serial(0, 256): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.S(32, by) v1 = T.axis.S(256, k0) T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(64, 8): A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.S(256, k0) v1 = T.axis.S(32, bx) T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(8, 64): B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] for k1 in T.unroll(0, 8): for v_, i, j in T.grid(1, 4, 4): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) vk = T.axis.R(2048, k0 * 8 + k1) @@ -173,12 +173,12 @@ def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: with T.init(): C_shared[vi, vj] = T.float32(0) C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] - with T.block("C_shared"): + with T.sblock("C_shared"): v0 = T.axis.S(32, by) v1 = T.axis.S(32, bx) T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(64, 64): C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1] @@ -189,7 +189,7 @@ def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: def test_read_at_global_to_shared_a(): sch = tir.Schedule(cuda_matmul, debug_mask="all") - block = sch.get_block("C") + block = sch.get_sblock("C") # pylint: disable=invalid-name _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) # pylint: enable=invalid-name @@ -200,7 +200,7 @@ def test_read_at_global_to_shared_a(): def test_read_at_global_to_shared_ab(): sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") - block = sch.get_block("C") + block = sch.get_sblock("C") # pylint: disable=invalid-name _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) # pylint: enable=invalid-name @@ -211,7 +211,7 @@ def test_read_at_global_to_shared_ab(): def test_read_at_local_to_shared_c(): sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") - block = sch.get_block("C") + block = sch.get_sblock("C") # pylint: disable=invalid-name _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) # pylint: enable=invalid-name diff --git a/tests/python/tir-schedule/test_tir_schedule_reduction.py b/tests/python/tir-schedule/test_tir_schedule_reduction.py index 4ed3c6178fce..2e7552e0565d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_reduction.py +++ b/tests/python/tir-schedule/test_tir_schedule_reduction.py @@ -35,15 +35,15 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4]) A = T.match_buffer(a, [32, 4, 128]) for i0, i2_0 in T.grid(32, 16): - with T.block("blockized_B"): + with T.sblock("blockized_B"): io, ko = T.axis.remap("SR", [i0, i2_0]) with T.init(): for i1 in T.serial(0, 4): - with T.block("B_init"): + with T.sblock("B_init"): ii_init = T.axis.S(4, i1) B[io, ii_init] = 0.0 for i1_1, i2_1 in T.grid(4, 8): - with T.block("B"): + with T.sblock("B"): ii = T.axis.S(4, i1_1) k = T.axis.R(128, ko * 8 + i2_1) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -55,7 +55,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -69,12 +69,12 @@ def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -85,17 +85,17 @@ def matmul_decompose1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4], elem_offset=0, align=64, offset_factor=1) for i0 in T.serial(0, 32): - with T.block("blockized_B_init"): + with T.sblock("blockized_B_init"): io = T.axis.S(32, i0) for i1 in T.serial(0, 4): - with T.block("B_init"): + with T.sblock("B_init"): ii = T.axis.S(4, i1) B[io, ii] = T.float32(0) for i0, i2_o in T.grid(32, 16): - with T.block("blockized_B_update"): + with T.sblock("blockized_B_update"): io, ko = T.axis.remap("SR", [i0, i2_o]) for i1, i2_i in T.grid(4, 8): - with T.block("B"): + with T.sblock("B"): ii = T.axis.S(4, i1) k = T.axis.R(128, ko * 8 + i2_i) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -108,11 +108,11 @@ def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1) for i0, i1 in T.grid(128, 128): - with T.block("update_init"): + with T.sblock("update_init"): vi_init, vj_init = T.axis.remap("SS", [i0, i1]) C[vi_init, vj_init] = T.float32(0) for i2 in T.serial(0, 128): - with T.block("update_update"): + with T.sblock("update_update"): vi, vj, vk = T.axis.remap("SSR", [i0, i1, i2]) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) @@ -124,7 +124,7 @@ def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, k, j in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -137,17 +137,17 @@ def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128], elem_offset=0, align=64, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1) # body - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) for i0_0 in T.serial(0, 16): for i0_1_init, i1_init in T.grid(8, 128): - with T.block("update_init"): + with T.sblock("update_init"): vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init) vj_init = T.axis.S(128, i1_init) C[vi_init, vj_init] = T.float32(0) for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7): - with T.block("update_update"): + with T.sblock("update_update"): T.where((((i2_0 * 7) + i2_1) < 128)) vi = T.axis.S(128, i0_0 * 8 + i0_1) vj = T.axis.S(128, i1) @@ -161,8 +161,8 @@ def matmul_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): - T.block_attr({"test_annotation": 1}) + with T.sblock("update"): + T.sblock_attr({"test_annotation": 1}) vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -176,14 +176,14 @@ def matmul_decompose_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> N C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): - T.block_attr({"test_annotation": 1}) + with T.sblock("init"): + T.sblock_attr({"test_annotation": 1}) vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for i, j, k in T.grid(128, 128, 128): - with T.block("update"): - T.block_attr({"test_annotation": 1}) + with T.sblock("update"): + T.sblock_attr({"test_annotation": 1}) vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -194,7 +194,7 @@ def colsum_with_vectorization(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32], dtype="float32") for k in T.serial(0, 128): for i in T.vectorized(0, 32): - with T.block("B"): + with T.sblock("B"): vk, vi = T.axis.remap("RS", [k, i]) with T.init(): B[vi] = T.float32(0) @@ -206,12 +206,12 @@ def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 32], dtype="float32") B = T.match_buffer(b, [32], dtype="float32") for i in T.vectorized(0, 32): - with T.block("B_init"): + with T.sblock("B_init"): vi = T.axis.S(32, i) B[vi] = T.float32(0) for k in T.serial(0, 128): for i in T.vectorized(0, 32): - with T.block("B"): + with T.sblock("B"): vk, vi = T.axis.remap("RS", [k, i]) B[vi] = B[vi] + A[vk, vi] @@ -223,7 +223,7 @@ def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None: def test_reduction_decompose0(use_block_name): s = tir.Schedule(matmul, debug_mask="all") - C = "update" if use_block_name else s.get_block("update") + C = "update" if use_block_name else s.get_sblock("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, i) assert_structural_equal_ignore_global_symbol(matmul_decompose0, s.mod["main"]) @@ -232,7 +232,7 @@ def test_reduction_decompose0(use_block_name): def test_reduction_decompose1(use_block_name): s = tir.Schedule(rowsum_blockized, debug_mask="all") - blockized_B = "blockized_B" if use_block_name else s.get_block("blockized_B") + blockized_B = "blockized_B" if use_block_name else s.get_sblock("blockized_B") io, ko = s.get_loops(blockized_B) s.decompose_reduction(blockized_B, io) assert_structural_equal_ignore_global_symbol(matmul_decompose1, s.mod["main"]) @@ -241,7 +241,7 @@ def test_reduction_decompose1(use_block_name): def test_reduction_decompose2(): s = tir.Schedule(matmul, debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, k) assert_structural_equal_ignore_global_symbol(matmul_decompose2, s.mod["main"]) @@ -250,7 +250,7 @@ def test_reduction_decompose2(): def test_reduction_decompose3(): s = tir.Schedule(matmul_decompose_fail3, debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") i, j, k = s.get_loops(C) with pytest.raises(tvm.tir.ScheduleError): s.decompose_reduction(C, k) @@ -258,7 +258,7 @@ def test_reduction_decompose3(): def test_reduction_decompose4(): s = tir.Schedule(matmul, debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") i, j, k = s.get_loops(C) io, ii = s.split(i, factors=[16, 8]) ko, ki = s.split(k, factors=[19, 7]) @@ -269,7 +269,7 @@ def test_reduction_decompose4(): def test_reduction_decompose_with_annotation(): s = tir.Schedule(matmul_with_annotation, debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, i) assert_structural_equal_ignore_global_symbol(matmul_decompose_with_annotation, s.mod["main"]) @@ -278,12 +278,12 @@ def test_reduction_decompose_with_annotation(): def test_reduction_decompose_with_different_for_kind(): s = tir.Schedule(colsum_with_vectorization, debug_mask="all") - B = s.get_block("B") + B = s.get_sblock("B") k, _ = s.get_loops(B) B_init = s.decompose_reduction(B, k) assert_structural_equal_ignore_global_symbol(s.mod["main"], colsum_decompose_with_vectorization) - assert s.get(B).same_as(s.get(s.get_block("B_update"))) - assert s.get(B_init).same_as(s.get(s.get_block("B_init"))) + assert s.get(B).same_as(s.get(s.get_sblock("B_update"))) + assert s.get(B_init).same_as(s.get(s.get_sblock("B_init"))) verify_trace_roundtrip(s, mod=colsum_with_vectorization) @@ -292,7 +292,7 @@ def test_decompose_reduction_ref_hash_check(): mod_bak = mod hash_before = tvm.ir.structural_hash(mod_bak) s = tir.Schedule(mod["main"], debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, k) hash_after = tvm.ir.structural_hash(mod_bak) @@ -303,49 +303,49 @@ def test_decompose_reduction_nested_block(): @T.prim_func def nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), "float32")): for i, ko in T.grid(1, 2): - with T.block("outer"): + with T.sblock("outer"): vi, vko = T.axis.remap("SR", [i, ko]) C = T.alloc_buffer((32,), dtype="float32") with T.init(): B[vi] = T.float32(0) for ki in T.serial(32): - with T.block("inner_1"): + with T.sblock("inner_1"): vki = T.axis.remap("S", [ki]) C[vki] = A[vi, vko * 32 + vki] for ki in T.serial(32): - with T.block("inner_2"): + with T.sblock("inner_2"): vki = T.axis.remap("R", [ki]) B[vi] += C[vki] @T.prim_func def decomposed_nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), "float32")): for i in range(1): - with T.block("outer_init"): + with T.sblock("outer_init"): vi = T.axis.spatial(1, i) T.reads() T.writes(B[vi]) B[vi] = T.float32(0) for ko in range(2): - with T.block("outer_update"): + with T.sblock("outer_update"): vi, vko = T.axis.remap("SR", [i, ko]) T.reads(B[vi], A[vi, vko * 32 : vko * 32 + 32]) T.writes(B[vi]) C = T.alloc_buffer((32,)) for ki in range(32): - with T.block("inner_1"): + with T.sblock("inner_1"): vki = T.axis.spatial(32, ki) T.reads(A[vi, vko * 32 + vki]) T.writes(C[vki]) C[vki] = A[vi, vko * 32 + vki] for ki in range(32): - with T.block("inner_2"): + with T.sblock("inner_2"): vki = T.axis.reduce(32, ki) T.reads(B[vi], C[vki]) T.writes(B[vi]) B[vi] = B[vi] + C[vki] sch = tir.Schedule(nested_block, debug_mask="all") - outer = sch.get_block("outer") + outer = sch.get_sblock("outer") i, ko = sch.get_loops(outer) sch.decompose_reduction(outer, ko) @@ -367,7 +367,7 @@ def func(mod): def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): for t in T.thread_binding(0, 32, thread="threadIdx.x"): for r in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi, vr = T.axis.remap("SR", [t, r]) with T.init(): B[vi] = T.float32(0) @@ -376,12 +376,12 @@ def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): @T.prim_func def expected(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): for t_init in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("B_init"): + with T.sblock("B_init"): vi = T.axis.remap("S", [t_init]) B[vi] = T.float32(0) for t in T.thread_binding(0, 32, thread="threadIdx.x"): for r in T.serial(16): - with T.block("B"): + with T.sblock("B"): vi, vr = T.axis.remap("SR", [t, r]) B[vi] += A[vi, vr] diff --git a/tests/python/tir-schedule/test_tir_schedule_reindex.py b/tests/python/tir-schedule/test_tir_schedule_reindex.py index a410c293bcb3..3d69db195cd8 100644 --- a/tests/python/tir-schedule/test_tir_schedule_reindex.py +++ b/tests/python/tir-schedule/test_tir_schedule_reindex.py @@ -32,7 +32,7 @@ def transpose_elementwise( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32") ) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] * 2.0 @@ -43,11 +43,11 @@ def transpose_elementwise_reindex_read( ) -> None: A_reindex = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("A_reindex"): + with T.sblock("A_reindex"): vi, vj = T.axis.remap("SS", [i, j]) A_reindex[vi, vj] = A[vj, vi] for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_reindex[vi, vj] * 2.0 @@ -60,7 +60,7 @@ def conv2d_nhwc( ) -> None: PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), @@ -69,7 +69,7 @@ def conv2d_nhwc( dtype="float32", ) for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) with T.init(): Conv2d_nhwc[n, h, w, co] = T.float32(0) @@ -88,7 +88,7 @@ def conv2d_nhwc_reindex_data( PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") ReindexInput = T.alloc_buffer([1, 112, 112, 7, 7, 3], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), @@ -97,11 +97,11 @@ def conv2d_nhwc_reindex_data( dtype="float32", ) for i0, i1, i2, i3, i4, i5 in T.grid(1, 112, 112, 7, 7, 3): - with T.block("ReindexInput"): + with T.sblock("ReindexInput"): n, h, w, rh, rw, rc = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) ReindexInput[n, h, w, rh, rw, rc] = PadInput[n, ((h * 2) + rh), ((w * 2) + rw), rc] for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) with T.init(): Conv2d_nhwc[n, h, w, co] = T.float32(0) @@ -120,7 +120,7 @@ def conv2d_nhwc_reindex_weight( PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") weight_reindex = T.alloc_buffer([64, 7, 7, 3], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) @@ -131,13 +131,13 @@ def conv2d_nhwc_reindex_weight( dtype="float32", ) for ax3, ax4, ax5, ax6 in T.grid(64, 7, 7, 3): - with T.block("weight_reindex"): + with T.sblock("weight_reindex"): v3, v4, v5, v6 = T.axis.remap("SSSS", [ax3, ax4, ax5, ax6]) T.reads(weight[v4, v5, v6, v3]) T.writes(weight_reindex[v3, v4, v5, v6]) weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3] for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) T.reads( PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], @@ -160,7 +160,7 @@ def matmul( C: T.Buffer((512, 512), "float32"), ) -> None: for i0, i1, i2 in T.grid(512, 512, 512): - with T.block("matmul"): + with T.sblock("matmul"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(C[i, j], A[i, k], B[k, j]) T.writes(C[i, j]) @@ -177,7 +177,7 @@ def matmul_reindex_write( ) -> None: C_reindex = T.alloc_buffer([512, 512], dtype="float32") for i0, i1, i2 in T.grid(512, 512, 512): - with T.block("matmul"): + with T.sblock("matmul"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(C_reindex[i, j], A[i, k], B[k, j]) T.writes(C_reindex[i, j]) @@ -185,7 +185,7 @@ def matmul_reindex_write( C_reindex[i, j] = T.float32(0) C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j] for i0, i1 in T.grid(512, 512): - with T.block("C_reindex"): + with T.sblock("C_reindex"): v0, v1 = T.axis.remap("SS", [i0, i1]) T.reads(C_reindex[v0, v1]) T.writes(C[v0, v1]) @@ -195,7 +195,7 @@ def matmul_reindex_write( @T.prim_func def multiple_read(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] + A[vi, vj] @@ -207,7 +207,7 @@ def mixed_dtype( T_matmul_NT: T.Buffer((T.int64(2), 1280), "float16"), ) -> None: for i0, i1, i2 in T.grid(T.int64(2), 1280, 1280): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i = T.axis.spatial(T.int64(2), i0) j, k = T.axis.remap("SR", [i1, i2]) T.reads(p0[i, k], p1[j, k]) @@ -225,7 +225,7 @@ def mixed_dtype_reindex_write( ) -> None: T_matmul_NT_reindex = T.alloc_buffer([T.int64(2), 1280], dtype="float16") for i0, i1, i2 in T.grid(T.int64(2), 1280, 1280): - with T.block("T_matmul_NT"): + with T.sblock("T_matmul_NT"): i = T.axis.spatial(T.int64(2), i0) j, k = T.axis.remap("SR", [i1, i2]) T.reads(p0[i, k], p1[j, k]) @@ -234,7 +234,7 @@ def mixed_dtype_reindex_write( T_matmul_NT_reindex[i, j] = T.float16(0) T_matmul_NT_reindex[i, j] = T_matmul_NT_reindex[i, j] + p0[i, k] * p1[j, k] for ax0, ax1 in T.grid(T.int64(2), 1280): - with T.block("T_matmul_NT_reindex"): + with T.sblock("T_matmul_NT_reindex"): v0 = T.axis.spatial(T.int64(2), ax0) v1 = T.axis.remap("S", [ax1]) T.reads(T_matmul_NT_reindex[v0, v1]) @@ -249,7 +249,7 @@ def matmul_unit_dim( C: T.Buffer((1, 1), "float32"), ) -> None: for i0, i1, i2 in T.grid(1, 1, 512): - with T.block("matmul"): + with T.sblock("matmul"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(C[i, j], A[i, k], B[k, j]) T.writes(C[i, j]) @@ -266,7 +266,7 @@ def matmul_unit_dim_reindex_write( ) -> None: C_reindex = T.alloc_buffer([1, 1], dtype="float32") for i0, i1, i2 in T.grid(1, 1, 512): - with T.block("matmul"): + with T.sblock("matmul"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(C_reindex[i, j], A[i, k], B[k, j]) T.writes(C_reindex[i, j]) @@ -274,7 +274,7 @@ def matmul_unit_dim_reindex_write( C_reindex[i, j] = T.float32(0) C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j] for i0, i1 in T.grid(1, 1): - with T.block("C_reindex"): + with T.sblock("C_reindex"): v0, v1 = T.axis.remap("SS", [i0, i1]) T.reads(C_reindex[v0, v1]) T.writes(C[v0, v1]) @@ -287,7 +287,7 @@ def matmul_unit_dim_reindex_write( def test_reindex_read_basic(use_block_name, use_buffer_name): sch = tir.Schedule(transpose_elementwise) - block = "B" if use_block_name else sch.get_block("B") + block = "B" if use_block_name else sch.get_sblock("B") buf = "A" if use_buffer_name else ("read", 0) sch.reindex(block, buf) assert_structural_equal_ignore_global_symbol( @@ -298,7 +298,7 @@ def test_reindex_read_basic(use_block_name, use_buffer_name): def test_conv2d_reindex_weight(use_block_name, use_buffer_name): sch = tir.Schedule(conv2d_nhwc) - block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_sblock("conv2d_nhwc") buf = "Weight" if use_buffer_name else ("read", 1) sch.reindex(block, buf) assert_structural_equal_ignore_global_symbol(conv2d_nhwc_reindex_weight, sch.mod["main"]) @@ -307,7 +307,7 @@ def test_conv2d_reindex_weight(use_block_name, use_buffer_name): def test_conv2d_reindex_data(use_block_name, use_buffer_name): sch = tir.Schedule(conv2d_nhwc) - block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_sblock("conv2d_nhwc") buf = "PadInput" if use_buffer_name else ("read", 0) sch.reindex(block, buf) assert_structural_equal_ignore_global_symbol(conv2d_nhwc_reindex_data, sch.mod["main"]) @@ -316,7 +316,7 @@ def test_conv2d_reindex_data(use_block_name, use_buffer_name): def test_matmul_reindex_write(use_block_name, use_buffer_name): sch = tir.Schedule(matmul) - block = "matmul" if use_block_name else sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_sblock("matmul") buf = "C" if use_buffer_name else ("write", 0) sch.reindex(block, buf) assert_structural_equal_ignore_global_symbol(matmul_reindex_write, sch.mod["main"]) @@ -325,7 +325,7 @@ def test_matmul_reindex_write(use_block_name, use_buffer_name): def test_reindex_fail_multiple_read(use_block_name, use_buffer_name): sch = tir.Schedule(multiple_read) - block = "B" if use_block_name else sch.get_block("B") + block = "B" if use_block_name else sch.get_sblock("B") buf = "A" if use_buffer_name else ("read", 0) with pytest.raises(ScheduleError): sch.reindex(block, buf) @@ -333,7 +333,7 @@ def test_reindex_fail_multiple_read(use_block_name, use_buffer_name): def test_reindex_mixed_dtype(use_block_name, use_buffer_name): sch = tir.Schedule(mixed_dtype) - block = "T_matmul_NT" if use_block_name else sch.get_block("T_matmul_NT") + block = "T_matmul_NT" if use_block_name else sch.get_sblock("T_matmul_NT") buf = "T_matmul_NT" if use_buffer_name else ("write", 0) sch.reindex(block, buf) assert_structural_equal_ignore_global_symbol(mixed_dtype_reindex_write, sch.mod["main"]) @@ -342,7 +342,7 @@ def test_reindex_mixed_dtype(use_block_name, use_buffer_name): def test_matmul_unit_dim_reindex_write(use_block_name, use_buffer_name): sch = tir.Schedule(matmul_unit_dim) - block = "matmul" if use_block_name else sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_sblock("matmul") buf = "C" if use_buffer_name else ("write", 0) sch.reindex(block, buf) assert_structural_equal_ignore_global_symbol(matmul_unit_dim_reindex_write, sch.mod["main"]) diff --git a/tests/python/tir-schedule/test_tir_schedule_reorder.py b/tests/python/tir-schedule/test_tir_schedule_reorder.py index 7ca9d35ea09d..68c6da536579 100644 --- a/tests/python/tir-schedule/test_tir_schedule_reorder.py +++ b/tests/python/tir-schedule/test_tir_schedule_reorder.py @@ -35,7 +35,7 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -45,7 +45,7 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 8): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) vl = T.axis.S(128, l * 16) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -57,7 +57,7 @@ def elementwise_dependent_loop(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128, 128)) for i in T.serial(0, 128): for j, k, l in T.grid(128, i, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -67,7 +67,7 @@ def elementwise_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -80,11 +80,11 @@ def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -94,10 +94,10 @@ def elementwise_with_loops_not_same_scope(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vk = T.axis.S(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -109,7 +109,7 @@ def elementwise_with_wrong_block_var_type(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) vk = T.axis.scan(128, k) T.reads([A[vi, vj, vk]]) @@ -122,7 +122,7 @@ def elementwise_reordered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -132,7 +132,7 @@ def elementwise_reordered2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for k, j, i, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -142,7 +142,7 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -153,13 +153,13 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for i, j in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) A[vi, vj] = 1 for i, j in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) @@ -171,13 +171,13 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for j, i in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) A[vi, vj] = 1 for j, i in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) @@ -189,7 +189,7 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: def test_reorder(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) sch.reorder(l, i) assert_structural_equal_ignore_global_symbol(elementwise_reordered, sch.mod["main"]) @@ -198,7 +198,7 @@ def test_reorder(): def test_reorder2(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) sch.reorder(k, i, l) assert_structural_equal_ignore_global_symbol(elementwise_reordered2, sch.mod["main"]) @@ -207,10 +207,10 @@ def test_reorder2(): def test_reorder_with_opaque_access(): sch = tir.Schedule(opaque_access, debug_mask="all") - block_a = sch.get_block("A") + block_a = sch.get_sblock("A") i, j = sch.get_loops(block_a) sch.reorder(j, i) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j = sch.get_loops(block_b) sch.reorder(j, i) assert_structural_equal_ignore_global_symbol(opaque_access_reorder, sch.mod["main"]) @@ -222,7 +222,7 @@ def test_reorder_overlapped_access(): def overlapped_access(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4), "float32")): # example to write first axis multiple times for v0, v1, v2 in T.grid(6, 4, 4): - with T.block("block"): + with T.sblock("block"): i = T.axis.spatial(14, v0 * 2 + v1) j = T.axis.spatial(4, v2) B[i, j] = A[i, j] + 1.0 @@ -231,13 +231,13 @@ def overlapped_access(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4), "flo def overlapped_access_reorder(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4), "float32")): # example to write first axis multiple times for v0, v2, v1 in T.grid(6, 4, 4): - with T.block("block"): + with T.sblock("block"): i = T.axis.spatial(14, v0 * 2 + v1) j = T.axis.spatial(4, v2) B[i, j] = A[i, j] + 1.0 sch = tir.Schedule(overlapped_access, debug_mask="all") - v0, v1, v2 = sch.get_loops(sch.get_block("block")) + v0, v1, v2 = sch.get_loops(sch.get_sblock("block")) sch.reorder(v0, v2, v1) assert_structural_equal_ignore_global_symbol(overlapped_access_reorder, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=overlapped_access) @@ -247,7 +247,7 @@ def test_reorder_with_partial_affineness(): @T.prim_func def non_affine_func(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4), "float32")): for v0, v1, v2 in T.grid(6, 4, 4): - with T.block("block"): + with T.sblock("block"): i = T.axis.spatial(14, v0 * v0 + v1) j = T.axis.spatial(4, v2) B[i, j] = A[i, j] + 1.0 @@ -255,13 +255,13 @@ def non_affine_func(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4), "float @T.prim_func def non_affine_func_reorder(A: T.Buffer((14, 4), "float32"), B: T.Buffer((14, 4), "float32")): for v0, v2, v1 in T.grid(6, 4, 4): - with T.block("block"): + with T.sblock("block"): i = T.axis.spatial(14, v0 * v0 + v1) j = T.axis.spatial(4, v2) B[i, j] = A[i, j] + 1.0 sch = tir.Schedule(non_affine_func, debug_mask="all") - v0, v1, v2 = sch.get_loops(sch.get_block("block")) + v0, v1, v2 = sch.get_loops(sch.get_sblock("block")) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(v0, v2, v1) @@ -277,13 +277,13 @@ def cascade_pool_ops( ) -> None: y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32") for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3): - with T.block("pool_0"): + with T.sblock("pool_0"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw]) with T.init(): y1[ax0, ax1, ax2, ax3] = 0.0 y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1] for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3): - with T.block("pool_1"): + with T.sblock("pool_1"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw]) with T.init(): y2[ax0, ax1, ax2, ax3] = 0.0 @@ -296,7 +296,7 @@ def cascade_pool_ops_tile_reordered( y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32") for n, c, h_o in T.grid(1, 16, 27): for w, h_i, kh, kw in T.grid(110, 6, 3, 3): - with T.block("pool_0"): + with T.sblock("pool_0"): ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial(16, c) ax2 = T.axis.spatial(110, h_o * 4 + h_i) @@ -307,7 +307,7 @@ def cascade_pool_ops_tile_reordered( y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1] ) for h_i, w, kh, kw in T.grid(4, 108, 3, 3): - with T.block("pool_1"): + with T.sblock("pool_1"): ax0 = T.axis.spatial(1, n) ax1 = T.axis.spatial(16, c) ax2 = T.axis.spatial(108, h_o * 4 + h_i) @@ -319,8 +319,8 @@ def cascade_pool_ops_tile_reordered( ) sch = tvm.tir.schedule.Schedule(cascade_pool_ops) - pool_0 = sch.get_block("pool_0") - pool_1 = sch.get_block("pool_1") + pool_0 = sch.get_sblock("pool_0") + pool_1 = sch.get_sblock("pool_1") _, _, h, w, _, _ = sch.get_loops(pool_1) ho, _ = sch.split(h, factors=[None, 4]) sch.compute_at(pool_0, ho) @@ -334,7 +334,7 @@ def cascade_pool_ops_tile_reordered( def test_reorder_with_predicate(): sch = tir.Schedule(elementwise_predicate, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(l, i) @@ -342,7 +342,7 @@ def test_reorder_with_predicate(): def test_reorder_fail_with_multi_appearance_loops(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(k, i, i) @@ -350,13 +350,13 @@ def test_reorder_fail_with_multi_appearance_loops(): def test_reorder_fail_with_non_single_branch_loop(): sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(k, i) sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") - block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_b = sch.get_sblock("B") + block_c = sch.get_sblock("C") i, j, k1 = sch.get_loops(block_b) _, _, k2 = sch.get_loops(block_c) with pytest.raises(tvm.tir.ScheduleError): @@ -365,8 +365,8 @@ def test_reorder_fail_with_non_single_branch_loop(): def test_reorder_fail_with_loops_not_under_same_scope(): sch = tir.Schedule(elementwise_with_loops_not_same_scope, debug_mask="all") - block_b = sch.get_block("B") - block_a = sch.get_block("A") + block_b = sch.get_sblock("B") + block_a = sch.get_sblock("A") i, j = sch.get_loops(block_a) k = sch.get_loops(block_b)[0] with pytest.raises(tvm.tir.ScheduleError): @@ -375,7 +375,7 @@ def test_reorder_fail_with_loops_not_under_same_scope(): def test_reorder_fail_with_wrong_block_var_type(): sch = tir.Schedule(elementwise_with_wrong_block_var_type, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(k, i) @@ -383,7 +383,7 @@ def test_reorder_fail_with_wrong_block_var_type(): def test_reorder_fail_with_dependent_loops(): sch = tir.Schedule(elementwise_dependent_loop, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(l, i) @@ -391,7 +391,7 @@ def test_reorder_fail_with_dependent_loops(): def test_reorder_fail_not_affine_bindings(): sch = tir.Schedule(elementwise_not_affine, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(l, i) diff --git a/tests/python/tir-schedule/test_tir_schedule_reorder_block_iter_var.py b/tests/python/tir-schedule/test_tir_schedule_reorder_block_iter_var.py index fe1a832f49cb..a8a8566bae64 100644 --- a/tests/python/tir-schedule/test_tir_schedule_reorder_block_iter_var.py +++ b/tests/python/tir-schedule/test_tir_schedule_reorder_block_iter_var.py @@ -30,7 +30,7 @@ def matmul( C: T.Buffer((128, 128), "float32"), ) -> None: for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -44,7 +44,7 @@ def matmul_after_reorder_block_iter_var( C: T.Buffer((128, 128), "float32"), ): for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vk, vj, vi = T.axis.remap("RSS", [k, j, i]) T.reads(A[vi, vk], B[vj, vk]) T.writes(C[vi, vj]) @@ -55,7 +55,7 @@ def matmul_after_reorder_block_iter_var( def test_reorder_block_iter_var(): sch = tir.Schedule(matmul, debug_mask="all") - C = sch.get_block("C") + C = sch.get_sblock("C") sch.reorder_block_iter_var(C, [2, 1, 0]) tvm.ir.assert_structural_equal( matmul_after_reorder_block_iter_var.with_attr("global_symbol", "matmul"), sch.mod["main"] @@ -65,21 +65,21 @@ def test_reorder_block_iter_var(): def test_reorder_block_iter_var_fail_not_full(): sch = tir.Schedule(matmul, debug_mask="all") - C = sch.get_block("C") + C = sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reorder_block_iter_var(C, [2, 1]) def test_reorder_block_iter_var_fail_not_within_bound(): sch = tir.Schedule(matmul, debug_mask="all") - C = sch.get_block("C") + C = sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reorder_block_iter_var(C, [-1, 3, 2]) def test_reorder_block_iter_var_fail_not_unique(): sch = tir.Schedule(matmul, debug_mask="all") - C = sch.get_block("C") + C = sch.get_sblock("C") with pytest.raises(tvm.tir.ScheduleError): sch.reorder_block_iter_var(C, [0, 0, 2]) diff --git a/tests/python/tir-schedule/test_tir_schedule_rfactor.py b/tests/python/tir-schedule/test_tir_schedule_rfactor.py index a15bd3d9137b..429444e5247d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_rfactor.py +++ b/tests/python/tir-schedule/test_tir_schedule_rfactor.py @@ -36,7 +36,7 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads([A[vi, vk], B[vj, vk]]) @@ -53,7 +53,7 @@ def transformed_matmul_with_let(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads([A[vi, vk], B[vj, vk]]) @@ -72,7 +72,7 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([4, 128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block("update_rf"): + with T.sblock("update_rf"): vi2_inner_inner = T.axis.S(4, i2_inner_inner) vi = T.axis.S(128, i0) vj = T.axis.S(128, i1) @@ -86,7 +86,7 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: ) for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): - with T.block("update"): + with T.sblock("update"): vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) with T.init(): C[vi_1, vj_1] = 0.0 @@ -101,14 +101,14 @@ def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None: C = T.alloc_buffer([256, 256]) for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(256, 256): - with T.block("D"): + with T.sblock("D"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = C[vi, vj] @@ -120,7 +120,7 @@ def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None C = T.match_buffer(c, (128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -135,12 +135,12 @@ def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.ha D = T.match_buffer(d, [128, 128]) for k, i, j in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): ck, ci, cj = T.axis.remap("RSS", [k, i, j]) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] - with T.block("D"): + with T.sblock("D"): dk, di, dj = T.axis.remap("RSS", [k, i, j]) with T.init(): D[di, dj] = 0.0 @@ -153,7 +153,7 @@ def square_sum(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16]) for b0, i0, j0 in T.grid(16, 256, 256): - with T.block("C"): + with T.sblock("C"): b, i, j = T.axis.remap("SRR", [b0, i0, j0]) with T.init(): C[b] = 0.0 @@ -167,14 +167,14 @@ def square_sum_rfactor(a: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([16, 256]) for i0, i1, i2 in T.grid(16, 256, 256): - with T.block("C_rf"): + with T.sblock("C_rf"): vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): C_rf[b, vi2] = 0.0 C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) for i0_1, i2_1 in T.grid(16, 256): - with T.block("C"): + with T.sblock("C"): vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): C[b_1] = 0.0 @@ -188,7 +188,7 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer([16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block("C"): + with T.sblock("C"): b = T.axis.S(16, i0) i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) @@ -198,7 +198,7 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): - with T.block("D"): + with T.sblock("D"): b_1 = T.axis.S(16, i0_1) T.reads([C[b_1]]) T.writes([D[b_1]]) @@ -213,7 +213,7 @@ def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: C_rf = T.alloc_buffer([1, 16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block("C_rf"): + with T.sblock("C_rf"): vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0]) i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) @@ -222,14 +222,14 @@ def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): - with T.block("C"): + with T.sblock("C"): vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] for i0_2 in T.serial(0, 16): - with T.block("D"): + with T.sblock("D"): b_2 = T.axis.S(16, i0_2) D[b_2] = T.sqrt(C[b_2], dtype="float32") @@ -241,7 +241,7 @@ def transformed_square_sum_square_root_factor_one_1(a: T.handle, d: T.handle) -> C = T.alloc_buffer([16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block("C"): + with T.sblock("C"): b = T.axis.S(16, i0) i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) @@ -249,7 +249,7 @@ def transformed_square_sum_square_root_factor_one_1(a: T.handle, d: T.handle) -> C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): - with T.block("D"): + with T.sblock("D"): b_1 = T.axis.S(16, i0_1) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -261,7 +261,7 @@ def square_sum_square_root_factor_one_1_rfactor( C = T.alloc_buffer([16], dtype="float32") C_rf = T.alloc_buffer([1, 16], dtype="float32") for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block("C_rf"): + with T.sblock("C_rf"): b = T.axis.spatial(16, i0) i = T.axis.reduce(256, i1_i2_fused_outer // 256) j = T.axis.reduce(256, i1_i2_fused_outer % 256) @@ -270,13 +270,13 @@ def square_sum_square_root_factor_one_1_rfactor( C_rf[vi1_i2_fused_inner, b] = T.float32(0) C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + A[b, i, j] * A[b, i, j] for i0, i1_i2_fused_inner in T.grid(16, 1): - with T.block("C"): + with T.sblock("C"): b, vi1_i2_fused_inner = T.axis.remap("SR", [i0, i1_i2_fused_inner]) with T.init(): C[b] = T.float32(0) C[b] = C[b] + C_rf[vi1_i2_fused_inner, b] for i0_1 in T.serial(16): - with T.block("D"): + with T.sblock("D"): b_1 = T.axis.spatial(16, i0_1) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -288,7 +288,7 @@ def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> C = T.alloc_buffer([16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536): - with T.block("C"): + with T.sblock("C"): b = T.axis.S(16, i0) i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256)) j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256)) @@ -296,7 +296,7 @@ def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): - with T.block("D"): + with T.sblock("D"): b_1 = T.axis.S(16, i0_1) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -308,7 +308,7 @@ def square_sum_square_root_factor_one_2_rfactor( C = T.alloc_buffer([16], dtype="float32") C_rf = T.alloc_buffer([16, 1], dtype="float32") for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536): - with T.block("C_rf"): + with T.sblock("C_rf"): b = T.axis.spatial(16, i0) i = T.axis.reduce(256, i1_i2_fused_inner // 256) j = T.axis.reduce(256, i1_i2_fused_inner % 256) @@ -317,13 +317,13 @@ def square_sum_square_root_factor_one_2_rfactor( C_rf[b, vi1_i2_fused_outer] = T.float32(0) C_rf[b, vi1_i2_fused_outer] = C_rf[b, vi1_i2_fused_outer] + A[b, i, j] * A[b, i, j] for i0, i1_i2_fused_outer in T.grid(16, 1): - with T.block("C"): + with T.sblock("C"): b, vi1_i2_fused_outer = T.axis.remap("SR", [i0, i1_i2_fused_outer]) with T.init(): C[b] = T.float32(0) C[b] = C[b] + C_rf[b, vi1_i2_fused_outer] for i0_1 in T.serial(16): - with T.block("D"): + with T.sblock("D"): b_1 = T.axis.spatial(16, i0_1) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -334,8 +334,8 @@ def square_sum_with_annotation(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16]) for b0, i0, j0 in T.grid(16, 256, 256): - with T.block("C"): - T.block_attr({"test_annotation": 1}) + with T.sblock("C"): + T.sblock_attr({"test_annotation": 1}) b, i, j = T.axis.remap("SRR", [b0, i0, j0]) with T.init(): C[b] = 0.0 @@ -349,16 +349,16 @@ def square_sum_with_annotation_rfactor(a: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([16, 256]) for i0, i1, i2 in T.grid(16, 256, 256): - with T.block("C_rf"): - T.block_attr({"test_annotation": 1}) + with T.sblock("C_rf"): + T.sblock_attr({"test_annotation": 1}) vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): C_rf[b, vi2] = 0.0 C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) for i0_1, i2_1 in T.grid(16, 256): - with T.block("C"): - T.block_attr({"test_annotation": 1}) + with T.sblock("C"): + T.sblock_attr({"test_annotation": 1}) vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): C[b_1] = 0.0 @@ -371,7 +371,7 @@ def element_wise(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -382,7 +382,7 @@ def rowsum(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 @@ -395,7 +395,7 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): @@ -409,7 +409,7 @@ def rowsum_not_dominant(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i, k in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi, vk] = 0.0 @@ -423,7 +423,7 @@ def rowsum_not_serial(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for k in T.parallel(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 @@ -436,7 +436,7 @@ def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 1.0 @@ -449,7 +449,7 @@ def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 @@ -462,7 +462,7 @@ def rowsum_init_not_bufferstore(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): v_init: T.float32 = T.float32(0) @@ -476,7 +476,7 @@ def rowsum_transformed(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for io, ii_ko_fused, ki in T.grid(32, 128, 4): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, io * 4 + T.floordiv(ii_ko_fused, 32)) vk = T.axis.R(128, T.floormod(ii_ko_fused, 32) * 4 + ki) with T.init(): @@ -490,7 +490,7 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, []) for k0 in range(128): - with T.block("B"): + with T.sblock("B"): k = T.axis.R(128, k0) with T.init(): B[()] = 0.0 @@ -504,12 +504,12 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: B_rf = T.alloc_buffer([128], elem_offset=T.int64(0)) for i in range(128): - with T.block("B_rf"): + with T.sblock("B_rf"): vi0 = T.axis.S(128, i) B_rf[vi0] = A[vi0] for i in range(128): - with T.block("B"): + with T.sblock("B"): vi0_1 = T.axis.R(128, i) with T.init(): B[()] = 0.0 @@ -521,7 +521,7 @@ def rowsum_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block("B"): + with T.sblock("B"): T.where(k_0 * 10 + k_1 < 128) vi = T.axis.S(128, i) vk = T.axis.R(128, k_0 * 10 + k_1) @@ -536,14 +536,14 @@ def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") B_rf = T.alloc_buffer([128, 13], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block("B_rf"): + with T.sblock("B_rf"): vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1]) T.where(k_0 * 10 + k_1 < 128) with T.init(): B_rf[vi, vk_0] = T.float32(0) B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1] for i, k_0 in T.grid(128, 13): - with T.block("B"): + with T.sblock("B"): vk_0, vi = T.axis.remap("RS", [k_0, i]) with T.init(): B[vi] = T.float32(0) @@ -561,14 +561,14 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: for i in T.serial(0, 16): for j1 in T.serial(0, 16): for k1o, k1i in T.grid(4, 4): - with T.block("C"): + with T.sblock("C"): ci, cj = T.axis.remap("SS", [i, j1]) ck = T.axis.R(16, k1o * 4 + k1i) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, cj, ck] for k2o, k2i in T.grid(4, 4): - with T.block("D"): + with T.sblock("D"): di, dj = T.axis.remap("SS", [i, j1]) dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): @@ -576,14 +576,14 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block("E"): + with T.sblock("E"): ei, ej = T.axis.remap("SS", [i, j2]) ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block("F"): + with T.sblock("F"): fi, fj = T.axis.remap("SS", [i, j2]) fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): @@ -601,7 +601,7 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: C_rf = T.alloc_buffer([16, 16, 4]) for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): - with T.block("C_rf"): + with T.sblock("C_rf"): vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i]) with T.init(): C_rf[ci, cj, vk1o] = 0.0 @@ -609,13 +609,13 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: for i_1 in T.serial(0, 16): for j1_1 in T.serial(0, 16): for k1o_1 in T.serial(0, 4): - with T.block("C"): + with T.sblock("C"): vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1]) with T.init(): C[ci_1, cj_1] = 0.0 C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] for k2o, k2i in T.grid(4, 4): - with T.block("D"): + with T.sblock("D"): di, dj = T.axis.remap("SS", [i_1, j1_1]) dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): @@ -623,14 +623,14 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block("E"): + with T.sblock("E"): ei, ej = T.axis.remap("SS", [i_1, j2]) ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block("F"): + with T.sblock("F"): fi, fj = T.axis.remap("SS", [i_1, j2]) fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): @@ -644,7 +644,7 @@ def rfactor_spatial_only( B: T.Buffer((1, 512, 1, 1), "float32"), ) -> None: for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1): - with T.block("acc"): + with T.sblock("acc"): ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial(512, i1) ax2 = T.axis.spatial(1, 0) @@ -666,10 +666,10 @@ def rfactor_spatial_only_after( B: T.Buffer((1, 512, 1, 1), "float32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32") for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1): - with T.block("acc_rf"): + with T.sblock("acc_rf"): vi4 = T.axis.spatial(49, i4) ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial(512, i1) @@ -677,7 +677,7 @@ def rfactor_spatial_only_after( ax3 = T.axis.spatial(1, 0) B_rf[ax0, ax1, ax2, ax3, vi4] = A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7] for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1): - with T.block("acc"): + with T.sblock("acc"): vi4 = T.axis.reduce(49, i4) ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial(512, i1) @@ -696,7 +696,7 @@ def argmax_split( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -718,7 +718,7 @@ def argmin_split_init_update_reordered( argmin_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmin"): + with T.sblock("argmin"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -740,7 +740,7 @@ def argmax_split_different_shape( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -762,7 +762,7 @@ def argmax_split_different_indices( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -784,7 +784,7 @@ def argmax_split_init_not_bufferstore( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -807,7 +807,7 @@ def argmax_split_init_buffer_duplicate( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -829,7 +829,7 @@ def argmax_split_letstmt_fewer_than_init( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -850,7 +850,7 @@ def argmax_split_letstmt_more_than_init( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -871,7 +871,7 @@ def argmax_split_let_body_neither_seqstmt_nor_bufferstore( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -892,7 +892,7 @@ def argmax_split_init_update_inconsistent_bufferstore_number( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -915,7 +915,7 @@ def argmax_split_body_seq_not_bufferstore( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -937,7 +937,7 @@ def argmax_split_body_bufferstore_value_not_var( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -961,7 +961,7 @@ def argmax_split_body_bufferstore_value_unbound_var( ) -> None: v_unbound = T.int32() for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -983,7 +983,7 @@ def argmax_split_one_let_var_used_multi_times( argmax_v1: T.Buffer((128,), "int32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1005,7 +1005,7 @@ def argmax_split_body_one_buffer_updated_multi_times( argmax_v1: T.Buffer((128,), "int32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1028,7 +1028,7 @@ def argmax_split_init_buffer_not_match( argmax_v1: T.Buffer((128,), "float32"), ) -> None: for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1052,7 +1052,7 @@ def argmax_split_rfactor( argmax_v0_rf = T.alloc_buffer([128, 32], dtype="int32") argmax_v1_rf = T.alloc_buffer([128, 32], dtype="float32") for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmax_rf"): + with T.sblock("argmax_rf"): vi1_1, i, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) T.reads(idx[i, vi1_0 * 32 + vi1_1], val[i, vi1_0 * 32 + vi1_1]) T.writes(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) @@ -1072,7 +1072,7 @@ def argmax_split_rfactor( argmax_v0_rf[i, vi1_1] = v_argmax_v0_rf argmax_v1_rf[i, vi1_1] = v_argmax_v1_rf for i0, i1_1 in T.grid(128, 32): - with T.block("argmax"): + with T.sblock("argmax"): vi1_1, i = T.axis.remap("RS", [i1_1, i0]) T.reads(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) T.writes(argmax_v0[i], argmax_v1[i]) @@ -1099,7 +1099,7 @@ def argmin_split_rfactor( argmin_v0_rf = T.alloc_buffer([128, 32], dtype="int32") argmin_v1_rf = T.alloc_buffer([128, 32], dtype="float32") for i0, i1_0, i1_1 in T.grid(128, 4, 32): - with T.block("argmin_rf"): + with T.sblock("argmin_rf"): vi1_1, i, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) T.reads(idx[i, vi1_0 * 32 + vi1_1], val[i, vi1_0 * 32 + vi1_1]) T.writes(argmin_v0_rf[i, vi1_1], argmin_v1_rf[i, vi1_1]) @@ -1119,7 +1119,7 @@ def argmin_split_rfactor( argmin_v0_rf[i, vi1_1] = v_argmin_v0_rf argmin_v1_rf[i, vi1_1] = v_argmin_v1_rf for i0, i1_1 in T.grid(128, 32): - with T.block("argmin"): + with T.sblock("argmin"): vi1_1, i = T.axis.remap("RS", [i1_1, i0]) T.reads(argmin_v0_rf[i, vi1_1], argmin_v1_rf[i, vi1_1]) T.writes(argmin_v0[i], argmin_v1[i]) @@ -1146,7 +1146,7 @@ def argmax_topi_rfactor( placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32") placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32") for i0, i1_0, i1_1 in T.grid(1, 4, 8): - with T.block("placeholder_red_temp_rf"): + with T.sblock("placeholder_red_temp_rf"): vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1]) T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) @@ -1168,7 +1168,7 @@ def argmax_topi_rfactor( placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf for i0, i1_1 in T.grid(1, 8): - with T.block("placeholder_red_temp"): + with T.sblock("placeholder_red_temp"): vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0]) T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) @@ -1190,7 +1190,7 @@ def argmax_topi_rfactor( placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0 placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1 for i0 in T.serial(1): - with T.block("placeholder_red"): + with T.sblock("placeholder_red"): ax0 = T.axis.spatial(1, i0) T.reads(placeholder_red_temp_v0[ax0]) T.writes(placeholder_red[ax0]) @@ -1207,7 +1207,7 @@ def argmin_topi_rfactor( placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32") placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32") for i0, i1_0, i1_1 in T.grid(1, 4, 8): - with T.block("placeholder_red_temp_rf"): + with T.sblock("placeholder_red_temp_rf"): vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1]) T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) @@ -1229,7 +1229,7 @@ def argmin_topi_rfactor( placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf for i0, i1_1 in T.grid(1, 8): - with T.block("placeholder_red_temp"): + with T.sblock("placeholder_red_temp"): vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0]) T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) @@ -1251,7 +1251,7 @@ def argmin_topi_rfactor( placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0 placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1 for i0 in T.serial(1): - with T.block("placeholder_red"): + with T.sblock("placeholder_red"): ax0 = T.axis.spatial(1, i0) T.reads(placeholder_red_temp_v0[ax0]) T.writes(placeholder_red[ax0]) @@ -1263,100 +1263,100 @@ def argmin_topi_rfactor( def test_reduction_rfactor_matmul(): s = tir.Schedule(transformed_matmul, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, 0) assert_structural_equal_ignore_global_symbol(s.mod["main"], matmul_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) - assert s.get(update).same_as(s.get(s.get_block("update"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("update_rf"))) + assert s.get(update).same_as(s.get(s.get_sblock("update"))) verify_trace_roundtrip(s, mod=transformed_matmul) def test_reduction_rfactor_matmul_with_let(): s = tir.Schedule(transformed_matmul_with_let, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, 0) assert_structural_equal_ignore_global_symbol(s.mod["main"], matmul_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) - assert s.get(update).same_as(s.get(s.get_block("update"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("update_rf"))) + assert s.get(update).same_as(s.get(s.get_sblock("update"))) verify_trace_roundtrip(s, mod=transformed_matmul_with_let) def test_reduction_rfactor_square_sum(): s = tir.Schedule(square_sum, debug_mask="all") - C = s.get_block("C") + C = s.get_sblock("C") _, _, j = s.get_loops(C) rf_block = s.rfactor(j, 1) assert_structural_equal_ignore_global_symbol(s.mod["main"], square_sum_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) - assert s.get(C).same_as(s.get(s.get_block("C"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("C_rf"))) + assert s.get(C).same_as(s.get(s.get_sblock("C"))) verify_trace_roundtrip(s, mod=square_sum) def test_reduction_rfactor_square_sum_square_root(): s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all") - C = s.get_block("C") + C = s.get_sblock("C") _, _, f_i = s.get_loops(C) rf_block = s.rfactor(f_i, 0) assert_structural_equal_ignore_global_symbol(s.mod["main"], square_sum_square_root_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) - assert s.get(C).same_as(s.get(s.get_block("C"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("C_rf"))) + assert s.get(C).same_as(s.get(s.get_sblock("C"))) verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) def test_reduction_rfactor_loop_multiple_children(): s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all") - k, _, _ = s.get_loops(s.get_block("C")) + k, _, _ = s.get_loops(s.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_stage_pipeline(): s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all") - _, _, k = s.get_loops(s.get_block("C")) + _, _, k = s.get_loops(s.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_reduction_block1(): s = tir.Schedule(element_wise, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(i, 0) def test_reduction_rfactor_not_reduction_block2(): s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_reduction_block3(): s = tir.Schedule(rowsum_not_dominant, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_serial_loop(): s = tir.Schedule(rowsum_not_serial, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_same_buffer_access(): s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all") - _, _, k = s.get_loops(s.get_block("C")) + _, _, k = s.get_loops(s.get_sblock("C")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_factor_axis_range_fail(): s = tir.Schedule(transformed_matmul, debug_mask="all") - _, _, _, _, kii = s.get_loops(s.get_block("update")) + _, _, _, _, kii = s.get_loops(s.get_sblock("update")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(kii, 3) with pytest.raises(tvm.tir.ScheduleError): @@ -1365,66 +1365,66 @@ def test_reduction_rfactor_factor_axis_range_fail(): def test_reduction_rfactor_factor_axis_range(): s = tir.Schedule(transformed_matmul, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, -3) assert_structural_equal_ignore_global_symbol(s.mod["main"], matmul_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) - assert s.get(update).same_as(s.get(s.get_block("update"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("update_rf"))) + assert s.get(update).same_as(s.get(s.get_sblock("update"))) verify_trace_roundtrip(s, mod=transformed_matmul) def test_reduction_rfactor_wrong_reduce_pattern1(): s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_wrong_reduce_pattern2(): s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_init_not_bufferstore(): s = tir.Schedule(rowsum_init_not_bufferstore, debug_mask="all") - _, k = s.get_loops(s.get_block("B")) + _, k = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_wrong_loops1(): s = tir.Schedule(rowsum, debug_mask="all") - i, _ = s.get_loops(s.get_block("B")) + i, _ = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(i, 0) def test_reduction_rfactor_wrong_loops2(): s = tir.Schedule(rowsum_transformed, debug_mask="all") - _, _, k_i = s.get_loops(s.get_block("B")) + _, _, k_i = s.get_loops(s.get_sblock("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k_i, 0) def test_reduction_rfactor_zero_dim(): s = tir.Schedule(rowsum_zero_dim, debug_mask="all") - B = s.get_block("B") + B = s.get_sblock("B") (k,) = s.get_loops(B) rf_block = s.rfactor(k, 0) assert_structural_equal_ignore_global_symbol(s.mod["main"], rowsum_zero_dim_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("B_rf"))) - assert s.get(B).same_as(s.get(s.get_block("B"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("B_rf"))) + assert s.get(B).same_as(s.get(s.get_sblock("B"))) verify_trace_roundtrip(s, mod=rowsum_zero_dim) def test_reduction_rfactor_outermost_loop_multiple_children_fail(): # pylint: disable=invalid-name s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") - _, _, k2o, k2i = s.get_loops(s.get_block("D")) - _, _, k3o, k3i = s.get_loops(s.get_block("E")) - _, _, k4o, k4i = s.get_loops(s.get_block("F")) + _, _, k2o, k2i = s.get_loops(s.get_sblock("D")) + _, _, k3o, k3i = s.get_loops(s.get_sblock("E")) + _, _, k4o, k4i = s.get_loops(s.get_sblock("F")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k2o, 0) with pytest.raises(tvm.tir.ScheduleError): @@ -1441,18 +1441,18 @@ def test_reduction_rfactor_outermost_loop_multiple_children_fail(): # pylint: d def test_reduction_rfactor_outermost_loop_multiple_children(): # pylint: disable=invalid-name s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") - C = s.get_block("C") + C = s.get_sblock("C") _, _, k1o, _ = s.get_loops(C) rf_block = s.rfactor(k1o, 2) assert_structural_equal_ignore_global_symbol(s.mod["main"], multiple_reduction_blocks_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) - assert s.get(C).same_as(s.get(s.get_block("C"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("C_rf"))) + assert s.get(C).same_as(s.get(s.get_sblock("C"))) verify_trace_roundtrip(s, mod=multiple_reduction_blocks) def test_reduction_rfactor_predicate(): # pylint: disable=invalid-name s = tir.Schedule(rowsum_predicate, debug_mask="all") - B = s.get_block("B") + B = s.get_sblock("B") _, ko, _ = s.get_loops(B) # TODO: should be a tvm.tir.ScheduleError with pytest.raises(tvm.TVMError): @@ -1461,51 +1461,51 @@ def test_reduction_rfactor_predicate(): # pylint: disable=invalid-name def test_reduction_rfactor_with_annotation(): s = tir.Schedule(square_sum_with_annotation, debug_mask="all") - C = s.get_block("C") + C = s.get_sblock("C") _, _, j = s.get_loops(C) rf_block = s.rfactor(j, 1) assert_structural_equal_ignore_global_symbol(s.mod["main"], square_sum_with_annotation_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) - assert s.get(C).same_as(s.get(s.get_block("C"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("C_rf"))) + assert s.get(C).same_as(s.get(s.get_sblock("C"))) verify_trace_roundtrip(s, mod=square_sum_with_annotation) def test_reduction_rfactor_spatial_only(): s = tir.Schedule(rfactor_spatial_only, debug_mask="all") - block = s.get_block(name="acc", func_name="main") + block = s.get_sblock(name="acc", func_name="main") _, _, _, _, loop, _ = s.get_loops(block) rf_block = s.rfactor(loop=loop, factor_axis=4) assert_structural_equal_ignore_global_symbol(s.mod["main"], rfactor_spatial_only_after) - assert s.get(rf_block).same_as(s.get(s.get_block("acc_rf"))) - assert s.get(block).same_as(s.get(s.get_block("acc"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("acc_rf"))) + assert s.get(block).same_as(s.get(s.get_sblock("acc"))) verify_trace_roundtrip(s, mod=rfactor_spatial_only) def test_reduction_rfactor_argmax(): s = tir.Schedule(argmax_split, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) rf_block = s.rfactor(ki, 1) assert_structural_equal_ignore_global_symbol(s.mod["main"], argmax_split_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("argmax_rf"))) - assert s.get(argmax).same_as(s.get(s.get_block("argmax"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("argmax_rf"))) + assert s.get(argmax).same_as(s.get(s.get_sblock("argmax"))) verify_trace_roundtrip(s, mod=argmax_split) def test_reduction_rfactor_argmin_init_update_reordeded(): s = tir.Schedule(argmin_split_init_update_reordered, debug_mask="all") - argmin = s.get_block("argmin") + argmin = s.get_sblock("argmin") _, _, ki = s.get_loops(argmin) rf_block = s.rfactor(ki, 1) assert_structural_equal_ignore_global_symbol(s.mod["main"], argmin_split_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("argmin_rf"))) - assert s.get(argmin).same_as(s.get(s.get_block("argmin"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("argmin_rf"))) + assert s.get(argmin).same_as(s.get(s.get_sblock("argmin"))) verify_trace_roundtrip(s, mod=argmin_split_init_update_reordered) def test_reduction_rfactor_argmax_reduction_buffer_different_shape(): s = tir.Schedule(argmax_split_different_shape, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1513,7 +1513,7 @@ def test_reduction_rfactor_argmax_reduction_buffer_different_shape(): def test_reduction_rfactor_argmax_different_access_indices(): s = tir.Schedule(argmax_split_different_indices, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1521,7 +1521,7 @@ def test_reduction_rfactor_argmax_different_access_indices(): def test_reduction_rfactor_argmax_init_not_bufferstore(): s = tir.Schedule(argmax_split_init_not_bufferstore, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1529,7 +1529,7 @@ def test_reduction_rfactor_argmax_init_not_bufferstore(): def test_reduction_rfactor_argmax_init_buffer_duplicate(): s = tir.Schedule(argmax_split_init_buffer_duplicate, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1537,7 +1537,7 @@ def test_reduction_rfactor_argmax_init_buffer_duplicate(): def test_reduction_rfactor_argmax_letstmt_fewer_than_init(): s = tir.Schedule(argmax_split_letstmt_fewer_than_init, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1545,7 +1545,7 @@ def test_reduction_rfactor_argmax_letstmt_fewer_than_init(): def test_reduction_rfactor_argmax_letstmt_more_than_init(): s = tir.Schedule(argmax_split_letstmt_more_than_init, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1553,7 +1553,7 @@ def test_reduction_rfactor_argmax_letstmt_more_than_init(): def test_reduction_rfactor_argmax_let_body_neither_seqstmt_nor_bufferstore(): s = tir.Schedule(argmax_split_let_body_neither_seqstmt_nor_bufferstore, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1561,7 +1561,7 @@ def test_reduction_rfactor_argmax_let_body_neither_seqstmt_nor_bufferstore(): def test_reduction_rfactor_argmax_init_update_inconsistent_bufferstore_number(): s = tir.Schedule(argmax_split_init_update_inconsistent_bufferstore_number, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1569,7 +1569,7 @@ def test_reduction_rfactor_argmax_init_update_inconsistent_bufferstore_number(): def test_reduction_rfactor_argmax_body_seq_not_bufferstore(): s = tir.Schedule(argmax_split_body_seq_not_bufferstore, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1577,7 +1577,7 @@ def test_reduction_rfactor_argmax_body_seq_not_bufferstore(): def test_reduction_rfactor_argmax_body_bufferstore_value_not_var(): s = tir.Schedule(argmax_split_body_bufferstore_value_not_var, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1586,7 +1586,7 @@ def test_reduction_rfactor_argmax_body_bufferstore_value_not_var(): @pytest.mark.xfail(reason="The input IR is not well-formed") def test_reduction_rfactor_argmax_body_bufferstore_value_unbound_var(): s = tir.Schedule(argmax_split_body_bufferstore_value_unbound_var, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1594,7 +1594,7 @@ def test_reduction_rfactor_argmax_body_bufferstore_value_unbound_var(): def test_reduction_rfactor_argmax_one_let_var_used_multi_times(): s = tir.Schedule(argmax_split_one_let_var_used_multi_times, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1602,7 +1602,7 @@ def test_reduction_rfactor_argmax_one_let_var_used_multi_times(): def test_reduction_rfactor_argmax_body_one_buffer_updated_multi_times(): s = tir.Schedule(argmax_split_body_one_buffer_updated_multi_times, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1610,7 +1610,7 @@ def test_reduction_rfactor_argmax_body_one_buffer_updated_multi_times(): def test_reduction_rfactor_argmax_init_buffer_not_match(): s = tir.Schedule(argmax_split_init_buffer_not_match, debug_mask="all") - argmax = s.get_block("argmax") + argmax = s.get_sblock("argmax") _, _, ki = s.get_loops(argmax) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(ki, 1) @@ -1621,13 +1621,13 @@ def test_reduction_rfactor_topi_argmax(): B = topi.argmax(A, axis=1) argmax_topi = te.create_prim_func([A, B]) s = tir.Schedule(argmax_topi, debug_mask="all") - argmax = s.get_block("placeholder_red_temp") + argmax = s.get_sblock("placeholder_red_temp") _, k = s.get_loops(argmax) _, ki = s.split(k, [None, 8]) rf_block = s.rfactor(ki, 1) assert_structural_equal_ignore_global_symbol(s.mod["main"], argmax_topi_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) - assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("placeholder_red_temp_rf"))) + assert s.get(argmax).same_as(s.get(s.get_sblock("placeholder_red_temp"))) verify_trace_roundtrip(s, mod=argmax_topi) @@ -1636,13 +1636,13 @@ def test_reduction_rfactor_topi_argmin(): B = topi.argmin(A, axis=1) argmin_topi = te.create_prim_func([A, B]) s = tir.Schedule(argmin_topi, debug_mask="all") - argmin = s.get_block("placeholder_red_temp") + argmin = s.get_sblock("placeholder_red_temp") _, k = s.get_loops(argmin) _, ki = s.split(k, [None, 8]) rf_block = s.rfactor(ki, 1) assert_structural_equal_ignore_global_symbol(s.mod["main"], argmin_topi_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) - assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("placeholder_red_temp_rf"))) + assert s.get(argmin).same_as(s.get(s.get_sblock("placeholder_red_temp"))) verify_trace_roundtrip(s, mod=argmin_topi) @@ -1657,7 +1657,7 @@ def before( for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid( T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4) ): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R( T.int64(128), @@ -1675,7 +1675,7 @@ def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)): - with T.block("update_rf"): + with T.sblock("update_rf"): vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer= T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer]) with T.init(): C_rf[vi2_inner_inner, vi, vj] = 0.0 @@ -1685,7 +1685,7 @@ def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), ) for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)): - with T.block("update"): + with T.sblock("update"): vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) with T.init(): C[vi_1, vj_1] = 0.0 @@ -1693,12 +1693,12 @@ def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), # fmt: on s = tir.Schedule(before, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, 0) assert_structural_equal_ignore_global_symbol(s.mod["main"], expected) - assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) - assert s.get(update).same_as(s.get(s.get_block("update"))) + assert s.get(rf_block).same_as(s.get(s.get_sblock("update_rf"))) + assert s.get(update).same_as(s.get(s.get_sblock("update"))) verify_trace_roundtrip(s, mod=before) diff --git a/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py b/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py index 6fdd830120ec..45b5f251c609 100644 --- a/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py +++ b/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py @@ -51,7 +51,7 @@ def check_rolling_buffer( def _tile_nd(s, tile, block_name): outer_indices = [] inner_indices = [] - block = s.get_block(block_name) + block = s.get_sblock(block_name) loops = s.get_loops(block) for i, size in enumerate(tile): outer, inner = s.split(loops[i], [None, size]) @@ -69,14 +69,14 @@ def before(A: T.Buffer((4, 12), "int32"), C: T.Buffer((4, 8), "int32")): for c in T.serial(4): for i in T.serial(0, 10): for k in T.serial(3): - with T.block("B"): + with T.sblock("B"): cc, vi, vk = T.axis.remap("SSR", [c, i, k]) with T.init(): B[cc, vi] = 0 B[cc, vi] = B[cc, vi] + A[cc, vi + vk] for i in T.serial(0, 8): for k in T.serial(3): - with T.block("C"): + with T.sblock("C"): cc, vi, vk = T.axis.remap("SSR", [c, i, k]) with T.init(): C[cc, vi] = 0 @@ -87,7 +87,7 @@ def expected(A: T.Buffer((4, 12), "int32"), C: T.Buffer((4, 8), "int32")): B = T.alloc_buffer([4, 6], dtype="int32") for c, i_0 in T.grid(4, 2): for ax0, ax1 in T.grid(6, 3): - with T.block("B"): + with T.sblock("B"): T.where(i_0 < 1 or 2 <= ax0) cc = T.axis.spatial(4, c) vi = T.axis.opaque(10, i_0 * 4 + ax0) @@ -98,7 +98,7 @@ def expected(A: T.Buffer((4, 12), "int32"), C: T.Buffer((4, 8), "int32")): B[cc, vi % 6] = 0 B[cc, vi % 6] = B[cc, vi % 6] + A[cc, vi + vk] for i_1, k in T.grid(4, 3): - with T.block("C"): + with T.sblock("C"): cc = T.axis.spatial(4, c) vi = T.axis.opaque(8, i_0 * 4 + i_1) vk = T.axis.reduce(3, k) @@ -109,10 +109,10 @@ def expected(A: T.Buffer((4, 12), "int32"), C: T.Buffer((4, 8), "int32")): C[cc, vi] = C[cc, vi] + B[cc, (vi + vk) % 6] sch = tir.Schedule(before, debug_mask="all") - _, i, _ = sch.get_loops(sch.get_block("C")) + _, i, _ = sch.get_loops(sch.get_sblock("C")) io, _ = sch.split(i, [2, 4]) - sch.compute_at(sch.get_block("B"), io) - sch.rolling_buffer(sch.get_block("B"), 0) + sch.compute_at(sch.get_sblock("B"), io) + sch.rolling_buffer(sch.get_sblock("B"), 0) check_rolling_buffer(sch, before, expected, check_run=True) @@ -120,13 +120,13 @@ def expected(A: T.Buffer((4, 12), "int32"), C: T.Buffer((4, 8), "int32")): def cascade_2_max_pool2d(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "int8")): B = T.alloc_buffer([1, 10, 10, 16], dtype="int8") for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): B[ax0, ax1, ax2, ax3] = T.int8(-128) B[ax0, ax1, ax2, ax3] = T.max(B[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3]) for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): C[ax0, ax1, ax2, ax3] = T.int8(-128) @@ -140,7 +140,7 @@ def cascade_3_max_pool2d_with_stride( B_0 = T.alloc_buffer([1, 22, 22, 16], dtype="int8") B_1 = T.alloc_buffer([1, 10, 10, 16], dtype="int8") for i0, i1, i2, i3, i4, i5 in T.grid(1, 22, 22, 16, 3, 3): - with T.block("B_0"): + with T.sblock("B_0"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): B_0[ax0, ax1, ax2, ax3] = T.int8(-128) @@ -148,7 +148,7 @@ def cascade_3_max_pool2d_with_stride( B_0[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3] ) for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): - with T.block("B_1"): + with T.sblock("B_1"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): B_1[ax0, ax1, ax2, ax3] = T.int8(-128) @@ -156,7 +156,7 @@ def cascade_3_max_pool2d_with_stride( B_1[ax0, ax1, ax2, ax3], B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3] ) for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) with T.init(): C[ax0, ax1, ax2, ax3] = T.int8(-128) @@ -171,7 +171,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B = T.alloc_buffer([1, 10, 6, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 1): for ax0, ax1, ax2, ax3, ax4 in T.grid(10, 6, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where(i2_0 < 1 or 2 <= ax1) ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(10, ax0) @@ -185,7 +185,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B[ax0_1, ax1_1, ax2_1 % 6, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 8, 4, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.spatial(8, i1_0 * 8 + i1_1) ax2 = T.axis.opaque(8, i2_0 * 4 + i2_1) @@ -201,8 +201,8 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") oi, _ = _tile_nd(sch, [1, 8, 4, 16], "C") - sch.compute_at(sch.get_block("B"), oi[-1]) - sch.rolling_buffer(sch.get_block("B"), 0) + sch.compute_at(sch.get_sblock("B"), oi[-1]) + sch.rolling_buffer(sch.get_sblock("B"), 0) check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) @@ -212,7 +212,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 1, 1): for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 10, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where(i1_0 < 1 or 2 <= ax0) ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) @@ -226,7 +226,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 8, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) ax2 = T.axis.spatial(8, i2_0 * 8 + i2_1) @@ -242,8 +242,8 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") io, _ = _tile_nd(sch, [1, 4, 8, 16], "C") - sch.compute_at(sch.get_block("B"), io[-1]) - sch.rolling_buffer(sch.get_block("B"), 0) + sch.compute_at(sch.get_sblock("B"), io[-1]) + sch.rolling_buffer(sch.get_sblock("B"), 0) check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) @@ -253,7 +253,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 2): for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 8, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) @@ -268,7 +268,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 8, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) ax2 = T.axis.spatial(8, i2_0 * 4 + i2_1) @@ -284,8 +284,8 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") io, _ = _tile_nd(sch, [1, 4, 4, 8], "C") - sch.compute_at(sch.get_block("B"), io[-1]) - sch.rolling_buffer(sch.get_block("B"), 0) + sch.compute_at(sch.get_sblock("B"), io[-1]) + sch.rolling_buffer(sch.get_sblock("B"), 0) check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) @@ -295,7 +295,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B = T.alloc_buffer([1, 8, 10, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): for ax0, ax1, ax2, ax3, ax4 in T.grid(8, 8, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where( i1_0 * 6 + ax0 < 10 and i2_0 * 6 + ax1 < 10 @@ -314,7 +314,7 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B[ax0_1, ax1_1 % 8, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 6, 6, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): T.where(i1_0 * 6 + i1_1 < 8 and i2_0 * 6 + i2_1 < 8) ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.opaque(8, i1_0 * 6 + i1_1) @@ -331,8 +331,8 @@ def expected(A: T.Buffer((1, 12, 12, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") io, _ = _tile_nd(sch, [1, 6, 6, 16], "C") - sch.compute_at(sch.get_block("B"), io[-1]) - sch.rolling_buffer(sch.get_block("B"), 0) + sch.compute_at(sch.get_sblock("B"), io[-1]) + sch.rolling_buffer(sch.get_sblock("B"), 0) check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) @@ -343,7 +343,7 @@ def expected(A: T.Buffer((1, 24, 24, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B_1 = T.alloc_buffer([1, 6, 10, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): for ax0, ax1, ax2, ax3, ax4 in T.grid(13, 13, 16, 3, 3): - with T.block("B_0"): + with T.sblock("B_0"): T.where((i1_0 < 1 or 5 <= ax0) and (i2_0 < 1 or 5 <= ax1)) ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.opaque(22, i1_0 * 8 + ax0) @@ -358,7 +358,7 @@ def expected(A: T.Buffer((1, 24, 24, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1], ) for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 16, 3, 3): - with T.block("B_1"): + with T.sblock("B_1"): T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) ax0_2 = T.axis.spatial(1, 0) ax1_2 = T.axis.opaque(10, i1_0 * 4 + ax0) @@ -373,7 +373,7 @@ def expected(A: T.Buffer((1, 24, 24, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 2 + rv1, ax3_2], ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0_3 = T.axis.spatial(1, i0_0 + i0_1) ax1_3 = T.axis.opaque(8, i1_0 * 4 + i1_1) ax2_3 = T.axis.spatial(8, i2_0 * 4 + i2_1) @@ -390,10 +390,10 @@ def expected(A: T.Buffer((1, 24, 24, 16), "int8"), C: T.Buffer((1, 8, 8, 16), "i sch = tir.Schedule(cascade_3_max_pool2d_with_stride, debug_mask="all") io, _ = _tile_nd(sch, [1, 4, 4, 16], "C") - sch.compute_at(sch.get_block("B_1"), io[-1]) - sch.compute_at(sch.get_block("B_0"), io[-1]) - sch.rolling_buffer(sch.get_block("B_0"), 0) - sch.rolling_buffer(sch.get_block("B_1"), 0) + sch.compute_at(sch.get_sblock("B_1"), io[-1]) + sch.compute_at(sch.get_sblock("B_0"), io[-1]) + sch.rolling_buffer(sch.get_sblock("B_0"), 0) + sch.rolling_buffer(sch.get_sblock("B_1"), 0) check_rolling_buffer(sch, cascade_3_max_pool2d_with_stride, expected, check_run=True) @@ -403,7 +403,7 @@ def before(A: T.Buffer((1, 16, 16, 16), "int8"), C: T.Buffer((1, 24, 24, 16), "i B = T.alloc_buffer([1, 14, 14, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where(i1_0 * 5 // 2 + ax0 < 14 and i2_0 * 5 // 2 + ax1 < 14) ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(14, i1_0 * 5 // 2 + ax0) @@ -418,7 +418,7 @@ def before(A: T.Buffer((1, 16, 16, 16), "int8"), C: T.Buffer((1, 24, 24, 16), "i B[ax0_1, ax1_1, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.spatial(24, i1_0 * 5 + i1_1) @@ -440,7 +440,7 @@ def expected( B = T.alloc_buffer([1, 5, 14, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): - with T.block("B"): + with T.sblock("B"): T.where( i1_0 * 5 // 2 + ax0 < 14 and i2_0 * 5 // 2 + ax1 < 14 @@ -460,7 +460,7 @@ def expected( B[ax0_1, ax1_1 % 5, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] ) for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.opaque(24, i1_0 * 5 + i1_1) @@ -476,7 +476,7 @@ def expected( ) sch = tir.Schedule(before, debug_mask="all") - sch.rolling_buffer(sch.get_block("B"), 0) + sch.rolling_buffer(sch.get_sblock("B"), 0) check_rolling_buffer(sch, before, expected, check_run=True) @@ -488,7 +488,7 @@ def func_multi_writers( B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") for i0, i1, i2, i3 in T.grid(1, 3, 3, 1): for ax0, ax1, ax2 in T.grid(6, 6, 16): - with T.block("B_writer_0"): + with T.sblock("B_writer_0"): ax0_1 = T.axis.spatial(1, i0) ax1_1 = T.axis.spatial(12, i1 * 4 + ax0) ax2_1 = T.axis.spatial(12, i2 * 4 + ax1) @@ -497,7 +497,7 @@ def func_multi_writers( B[ax0_1, ax1_1, ax2_1, ax3_1] = T.int8(-128) B[ax0_1, ax1_1, ax2_1, ax3_1] = A[ax0_1, ax1_1, ax2_1, ax3_1] + T.int8(1) for ax0, ax1, ax2 in T.grid(6, 6, 16): - with T.block("B_writer_1"): + with T.sblock("B_writer_1"): ax0_2 = T.axis.spatial(1, i0) ax1_2 = T.axis.spatial(12, i1 * 4 + ax0) ax2_2 = T.axis.spatial(12, i2 * 4 + ax1) @@ -508,7 +508,7 @@ def func_multi_writers( ax0_2, ax1_2, ax2_2, ax3_2 ] * T.int8(2) for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 4, 16, 3, 3): - with T.block("C"): + with T.sblock("C"): ax0_3 = T.axis.spatial(1, i0 + ax0) ax1_3 = T.axis.spatial(12, i1 * 4 + ax1) ax2_3 = T.axis.spatial(12, i2 * 4 + ax2) @@ -522,7 +522,7 @@ def func_multi_writers( sch = tir.Schedule(func_multi_writers, debug_mask="all") with pytest.raises(tvm.tir.ScheduleError): - sch.rolling_buffer(sch.get_block("B_writer_0"), 0) + sch.rolling_buffer(sch.get_sblock("B_writer_0"), 0) def test_fail_rolling_buffer_not_match(): @@ -533,7 +533,7 @@ def func_non_overlap( B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1): for ax0, ax1, ax2 in T.grid(4, 4, 16): - with T.block("B"): + with T.sblock("B"): ax0_1 = T.axis.spatial(1, 0) ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0) ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1) @@ -544,7 +544,7 @@ def func_non_overlap( B[ax0_1, ax1_1, ax2_1, ax3] = T.int8(-128) B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3] for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 1): - with T.block("C"): + with T.sblock("C"): ax0 = T.axis.spatial(1, i0_0 + i0_1) ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1) ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1) @@ -560,7 +560,7 @@ def func_non_overlap( sch = tir.Schedule(func_non_overlap, debug_mask="all") with pytest.raises(tvm.tir.ScheduleError): - sch.rolling_buffer(sch.get_block("B"), 0) + sch.rolling_buffer(sch.get_sblock("B"), 0) def test_fail_rolling_buffer_injection_invalid(): @@ -569,7 +569,7 @@ def test_fail_rolling_buffer_injection_invalid(): _, _ = _tile_nd(sch, [1, 4, 8, 16], "C") _, _ = _tile_nd(sch, [1, 4, 8, 16], "B") with pytest.raises(tvm.tir.ScheduleError): - sch.rolling_buffer(sch.get_block("B"), 0) + sch.rolling_buffer(sch.get_sblock("B"), 0) if __name__ == "__main__": diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index 9e86194afeb1..e348d356597e 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -34,7 +34,7 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 257, 1470)) B = T.match_buffer(b, (128, 257, 1470)) for i, j, k in T.grid(128, 257, 1470): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -47,7 +47,7 @@ def tiled_conv2d_with_padding( ) -> None: PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) @@ -81,7 +81,7 @@ def tiled_conv2d_with_padding( i2_3, i3_3, ) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): n = T.axis.spatial(1, 0) h = T.axis.spatial(112, i1_1_1 * 56 + i1_3) w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3) @@ -151,7 +151,7 @@ def test_sample_categorical_serialize(): def test_sample_perfect_tile_power_of_two(): sch = tir.Schedule(elementwise, debug_mask="all") - i, _, _ = sch.get_loops(sch.get_block("B")) + i, _, _ = sch.get_loops(sch.get_sblock("B")) factors = sch.sample_perfect_tile(i, n=4) factors = [sch.get(i) for i in factors] prod = factors[0] * factors[1] * factors[2] * factors[3] @@ -161,7 +161,7 @@ def test_sample_perfect_tile_power_of_two(): def test_sample_perfect_tile_prime(): sch = tir.Schedule(elementwise, debug_mask="all") - _, i, _ = sch.get_loops(sch.get_block("B")) + _, i, _ = sch.get_loops(sch.get_sblock("B")) factors = sch.sample_perfect_tile(i, n=4) factors = [sch.get(i) for i in factors] prod = factors[0] * factors[1] * factors[2] * factors[3] @@ -171,7 +171,7 @@ def test_sample_perfect_tile_prime(): def test_sample_perfect_tile_composite(): sch = tir.Schedule(elementwise, debug_mask="all") - _, _, i = sch.get_loops(sch.get_block("B")) + _, _, i = sch.get_loops(sch.get_sblock("B")) factors = sch.sample_perfect_tile(i, n=4) factors = [sch.get(i) for i in factors] prod = factors[0] * factors[1] * factors[2] * factors[3] @@ -188,7 +188,7 @@ def test_sample_compute_location(use_sugared_block): if use_sugared_block: pad_input = "PadInput" else: - pad_input = sch.get_block("PadInput") + pad_input = sch.get_sblock("PadInput") decision_dict = dict() for _ in range(n): _ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name @@ -204,10 +204,10 @@ def test_sample_compute_location(use_sugared_block): def test_sample_perfect_tile_after_copy(): sch = tir.Schedule(elementwise, debug_mask="all") sch_copy = sch.copy() - _, _, i = sch.get_loops(sch.get_block("B")) + _, _, i = sch.get_loops(sch.get_sblock("B")) sch.sample_perfect_tile(i, n=4) - _, _, i = sch_copy.get_loops(sch_copy.get_block("B")) + _, _, i = sch_copy.get_loops(sch_copy.get_sblock("B")) # Hangs if ForkSeed is not invoked when copying a schedule sch_copy.sample_perfect_tile(i, n=4) @@ -220,12 +220,12 @@ def workload(a: T.handle) -> None: n = T.int32() A = T.match_buffer(a, (n, 1024)) for i, j in T.grid(n, 1024): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 1.0 sch = tir.Schedule(workload, debug_mask="all") - di, si = sch.get_loops(sch.get_block("B")) + di, si = sch.get_loops(sch.get_sblock("B")) factors = sch.sample_perfect_tile(si, n=4) factors = [sch.get(i) for i in factors] diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py index 788e17e77146..a80ba1e62ebd 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py @@ -34,11 +34,11 @@ def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "fl B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -48,11 +48,11 @@ def element_wise_set_axis_separator(A: T.Buffer((128, 128), "float32"), C: T.Buf B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1]) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -62,11 +62,11 @@ def element_wise_set_axis_separator_input_buffer(A: T.Buffer(shape=(128, 128), d B = T.alloc_buffer([128, 128], dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -76,12 +76,12 @@ def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1) B_subregion0[()] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1) C[vi, vj] = B_subregion1[()] + 1.0 @@ -92,12 +92,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1]) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) B_subregion0[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0]) C[vi, vj] = B_subregion1[()] + T.float32(1) @@ -116,11 +116,11 @@ def test_set_axis_separator(argument_style): s = tir.Schedule(func, debug_mask='all') if argument_style=='set_axis_separators': - s.set_axis_separator(s.get_block("B"), ("write",0), [1]) + s.set_axis_separator(s.get_sblock("B"), ("write",0), [1]) elif argument_style=='transform_layout_named': s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) elif argument_style =='transform_layout_buffer_object': - B = s.get(s.get_block('B')).writes[0].buffer + B = s.get(s.get_sblock('B')).writes[0].buffer s.transform_layout(block='B', buffer=B, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: raise ValueError(f'Unexpected argument_style: {argument_style}') @@ -133,9 +133,9 @@ def test_set_scope_fail_on_index_out_of_bound(): func = element_wise s = tir.Schedule(func, debug_mask='all') with pytest.raises(AssertionError): - s.set_axis_separator(s.get_block("B"), ("write",1),[1]) + s.set_axis_separator(s.get_sblock("B"), ("write",1),[1]) with pytest.raises(AssertionError): - s.set_axis_separator(s.get_block("B"), ("read",-1),[1]) + s.set_axis_separator(s.get_sblock("B"), ("read",-1),[1]) def test_set_axis_separator_input_buffer(argument_style): @@ -143,11 +143,11 @@ def test_set_axis_separator_input_buffer(argument_style): s = tir.Schedule(func, debug_mask='all') if argument_style=='set_axis_separators': - s.set_axis_separator(s.get_block("B"), ("read",0), [1]) + s.set_axis_separator(s.get_sblock("B"), ("read",0), [1]) elif argument_style=='transform_layout_named': s.transform_layout(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) elif argument_style =='transform_layout_buffer_object': - A = s.get(s.get_block('B')).reads[0].buffer + A = s.get(s.get_sblock('B')).reads[0].buffer s.transform_layout(block='B', buffer=A, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: raise ValueError(f'Unexpected argument_style: {argument_style}') @@ -162,11 +162,11 @@ def test_set_axis_separator_subregion(argument_style): s = tir.Schedule(func, debug_mask='all') if argument_style=='set_axis_separators': - s.set_axis_separator(s.get_block("B"), ("write",0), [1]) + s.set_axis_separator(s.get_sblock("B"), ("write",0), [1]) elif argument_style=='transform_layout_named': s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) elif argument_style =='transform_layout_buffer_object': - B = s.get(s.get_block('B')).writes[0].buffer + B = s.get(s.get_sblock('B')).writes[0].buffer s.transform_layout(block='B', buffer=B, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: raise ValueError(f'Unexpected argument_style: {argument_style}') @@ -187,7 +187,7 @@ def before(): A = T.alloc_buffer([4,4], dtype="int32") B = T.alloc_buffer([1,1], dtype="int32") for j in T.serial(4): - with T.block('block'): + with T.sblock('block'): A[B[0,0],j] = 0 @T.prim_func @@ -195,7 +195,7 @@ def expected(): A = T.alloc_buffer([4,4], dtype="int32") B = T.alloc_buffer([1,1], dtype="int32", axis_separators=[1]) for j in T.serial(4): - with T.block('block'): + with T.sblock('block'): A[B[0,0],j] = 0 diff --git a/tests/python/tir-schedule/test_tir_schedule_set_dtype.py b/tests/python/tir-schedule/test_tir_schedule_set_dtype.py index 96441b630b05..eff91e55204e 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_dtype.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_dtype.py @@ -34,11 +34,11 @@ def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "fl B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -46,13 +46,13 @@ def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "fl def element_wise_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): B = T.alloc_buffer((128, 128), "float16") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16") for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -63,12 +63,12 @@ def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1) B_subregion0[()] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1) C[vi, vj] = B_subregion1[()] + 1.0 @@ -78,14 +78,14 @@ def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: B = T.alloc_buffer((128, 128), "float16") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(B[vi, vj]) B_subregion0 = T.match_buffer(B[vi, vj], (), "float16", offset_factor=1) B_subregion0[()] = T.cast(A[vi, vj] * 2.0, "float16") for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) T.reads(B[vi, vj]) T.writes(C[vi, vj]) @@ -98,7 +98,7 @@ def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"), C def test_set_dtype(use_block_name): func = element_wise sch = tir.Schedule(func, debug_mask="all") - sch.unsafe_set_dtype("B" if use_block_name else sch.get_block("B"), 0, "float16") + sch.unsafe_set_dtype("B" if use_block_name else sch.get_sblock("B"), 0, "float16") assert_structural_equal_ignore_global_symbol(element_wise_set_dtype, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func) @@ -106,20 +106,20 @@ def test_set_dtype_fail_on_output_buffer(use_block_name): func = element_wise sch = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - sch.unsafe_set_dtype('C' if use_block_name else sch.get_block("C"), 0, "float16") + sch.unsafe_set_dtype('C' if use_block_name else sch.get_sblock("C"), 0, "float16") def test_set_dtype_fail_on_index_out_of_bound(): func = element_wise sch = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - sch.unsafe_set_dtype(sch.get_block("B"), 1, "float64") + sch.unsafe_set_dtype(sch.get_sblock("B"), 1, "float64") with pytest.raises(tvm.tir.ScheduleError): - sch.unsafe_set_dtype(sch.get_block("B"), -1, "float64") + sch.unsafe_set_dtype(sch.get_sblock("B"), -1, "float64") def test_set_dtype_subregion(): func = element_wise_subregion_match sch = tir.Schedule(func, debug_mask='all') - sch.unsafe_set_dtype(sch.get_block("B"), 0, "float16") + sch.unsafe_set_dtype(sch.get_sblock("B"), 0, "float16") assert_structural_equal_ignore_global_symbol(element_wise_subregion_match_set_dtype, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func) diff --git a/tests/python/tir-schedule/test_tir_schedule_set_scope.py b/tests/python/tir-schedule/test_tir_schedule_set_scope.py index 991a4ca9b77f..f35d7f06f0b4 100644 --- a/tests/python/tir-schedule/test_tir_schedule_set_scope.py +++ b/tests/python/tir-schedule/test_tir_schedule_set_scope.py @@ -33,11 +33,11 @@ def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "fl B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -47,11 +47,11 @@ def element_wise_set_scope(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_shared[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B_shared[vi, vj] + T.float32(1) @@ -61,12 +61,12 @@ def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer B = T.alloc_buffer((128, 128), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1) B_subregion0[()] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1) C[vi, vj] = B_subregion1[()] + 1.0 @@ -77,12 +77,12 @@ def element_wise_subregion_match_set_scope(A: T.Buffer((128, 128), "float32"), C B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion0_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1) B_subregion0_shared[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B_subregion1_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1) C[vi, vj] = B_subregion1_shared[()] + T.float32(1) @@ -96,7 +96,7 @@ def element_wise_subregion_match_set_scope(A: T.Buffer((128, 128), "float32"), C def test_set_scope(use_block_name, use_buffer_name): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_scope('B' if use_block_name else s.get_block("B"), 'B' if use_buffer_name else 0, "shared") + s.set_scope('B' if use_block_name else s.get_sblock("B"), 'B' if use_buffer_name else 0, "shared") assert_structural_equal_ignore_global_symbol(element_wise_set_scope, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -105,29 +105,29 @@ def test_set_scope_fail_on_output_buffer(use_block_name, use_buffer_name): func = element_wise s = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - s.set_scope('C' if use_block_name else s.get_block("C"), 'C' if use_buffer_name else 0, "shared") + s.set_scope('C' if use_block_name else s.get_sblock("C"), 'C' if use_buffer_name else 0, "shared") def test_set_scope_fail_on_index_out_of_bound(): func = element_wise s = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - s.set_scope(s.get_block("B"), 1, "shared") + s.set_scope(s.get_sblock("B"), 1, "shared") with pytest.raises(tvm.tir.ScheduleError): - s.set_scope(s.get_block("B"), -1, "shared") + s.set_scope(s.get_sblock("B"), -1, "shared") def test_set_scope_fail_on_invalid_scope(): func = element_wise s = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - s.set_scope(s.get_block("B"), 0, "test_scope") + s.set_scope(s.get_sblock("B"), 0, "test_scope") def test_set_scope_subregion(): func = element_wise_subregion_match s = tir.Schedule(func, debug_mask='all') - s.set_scope(s.get_block("B"), 0, "shared") + s.set_scope(s.get_sblock("B"), 0, "shared") assert_structural_equal_ignore_global_symbol(element_wise_subregion_match_set_scope, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index f09f7417baf6..28f7467c10f2 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -34,7 +34,7 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -45,7 +45,7 @@ def elementwise_dependent_loops(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i in T.serial(0, 128): for j, k in T.grid(i, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) vj = T.axis.S(i, j) vk = T.axis.S(128, k) @@ -57,7 +57,7 @@ def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k in T.grid(128, 128, n): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -67,7 +67,7 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i_j_k_fused in T.serial(0, (n * 16384)): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, n * 128), n)) vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) @@ -81,7 +81,7 @@ def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k0, k1 in T.grid(128, 128, 10, T.floordiv((n + 9), 10)): - with T.block("B"): + with T.sblock("B"): T.where((((k0 * T.floordiv((n + 9), 10)) + k1) < n)) vi, vj = T.axis.remap("SS", [i, j]) vk = T.axis.S(n, k0 * T.floordiv(n + 9, 10) + k1) @@ -97,11 +97,11 @@ def elementwise_with_seq(a: T.handle, b: T.handle) -> None: C = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -112,7 +112,7 @@ def elementwise_with_anno(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128, annotations={"useless_annotation": True}): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -125,7 +125,7 @@ def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -138,7 +138,7 @@ def elementwise_with_starting_point(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(10, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -150,10 +150,10 @@ def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block("opaque"): + with T.sblock("opaque"): T.reads([A[i, j, k]]) T.writes([B[i, j, k]]) - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -165,7 +165,7 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for fused in T.serial(0, 2097152): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, T.floordiv(fused, 16384)) vj = T.axis.S(128, T.floordiv(T.floormod(fused, 16384), 128)) vk = T.axis.S(128, T.floormod(fused, 128)) @@ -179,7 +179,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3) vj = T.axis.S(128, j1 * 32 + j2) vk = T.axis.S(128, k1 * 8 + k2) @@ -193,7 +193,7 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3) vj = T.axis.S(128, j1 * 64 + j2 * 64 + j3) vk = T.axis.S(128, k1 * 64 + k2 * 64 + k3) @@ -207,7 +207,7 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) vj = T.axis.S(128, j0 * 129 + j1) vk = T.axis.S(128, k0 * 43 + k1) @@ -222,7 +222,7 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i_j_k_fused in T.serial(0, 2097152): - with T.block("opaque"): + with T.sblock("opaque"): T.reads( [ A[ @@ -241,7 +241,7 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: ] ] ) - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, 16384), 128)) vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) @@ -256,10 +256,10 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, j, k in T.grid(8, 16, 128, 128): - with T.block("opaque"): + with T.sblock("opaque"): T.reads([A[i0 * 16 + i1, j, k]]) T.writes([B[i0 * 16 + i1, j, k]]) - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i0 * 16 + i1) vj, vk = T.axis.remap("SS", [j, k]) T.reads([A[vi, vj, vk]]) @@ -272,13 +272,13 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for i, j in T.grid(16, 16): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) A[vi, vj] = 1 for i, j in T.grid(16, 16): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) @@ -290,14 +290,14 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16]) B = T.match_buffer(b, [16, 16]) for i_j_fused in T.serial(0, 256): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) A[vi, vj] = 1 for i_j_fused in T.serial(0, 256): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) @@ -310,14 +310,14 @@ def opaque_access_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16)) B = T.match_buffer(b, (16, 16)) for i, j0, j1 in T.grid(16, 4, 4): - with T.block("A"): + with T.sblock("A"): vi = T.axis.S(16, i) vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([A[0:16, 0:16]]) A[vi, vj] = 1 for i, j0, j1 in T.grid(16, 4, 4): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(16, i) vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) @@ -331,7 +331,7 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (127, 128)) for i in T.serial(0, 4): for j, k in T.grid(T.min(31, 126 - i * 32) + 1, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(127, i * 32 + j) vj = T.axis.S(128, k) B[vi, vj] = A[vi, vj] @@ -343,7 +343,7 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [127, 128]) for i in T.grid(4): for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S( 127, i * 32 + T.floordiv(j_k_fused, 128), @@ -359,7 +359,7 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: def test_fuse(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) assert_structural_equal_ignore_global_symbol(elementwise_fused, sch.mod["main"]) @@ -369,7 +369,7 @@ def test_fuse(): @pytest.mark.parametrize("disable_predication", [True, False]) def test_split(disable_predication): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[2, 1, 64], disable_predication=disable_predication) sch.split(j, factors=[4, 32], disable_predication=disable_predication) @@ -380,7 +380,7 @@ def test_split(disable_predication): def test_split_with_inferred_factor(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[None, 1, 64]) sch.split(j, factors=[2, None, 64]) @@ -397,7 +397,7 @@ def before(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (N, 128, M)) B = T.match_buffer(b, (N, 128, M)) for i, j, k in T.grid(N, 128, M): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -407,7 +407,7 @@ def expected(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (N, 128, M)) B = T.match_buffer(b, (N, 128, M)) for i_0, i_1, j_0, j_1, k_0, k_1 in T.grid((N + 15) // 16, 16, 4, 32, 16, (M + 15) // 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(N, i_0 * 16 + i_1) vj = T.axis.spatial(128, j_0 * 32 + j_1) vk = T.axis.spatial(M, k_0 * ((M + 15) // 16) + k_1) @@ -415,7 +415,7 @@ def expected(a: T.handle, b: T.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2.0) sch = tir.Schedule(before, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[None, 16]) sch.split(j, factors=[4, 32]) @@ -426,7 +426,7 @@ def expected(a: T.handle, b: T.handle) -> None: def test_split_with_predicate(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[1000, 2, 3]) sch.split(j, factors=[None, 129]) @@ -437,7 +437,7 @@ def test_split_with_predicate(): def test_fuse_fail_not_only_child(): sch = tir.Schedule(elementwise_with_seq, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) @@ -445,7 +445,7 @@ def test_fuse_fail_not_only_child(): def test_fuse_split_fail_with_annotation(): sch = tir.Schedule(elementwise_with_anno, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) @@ -455,7 +455,7 @@ def test_fuse_split_fail_with_annotation(): def test_fuse_split_fail_not_start_with_zero(): sch = tir.Schedule(elementwise_with_anno, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) @@ -465,7 +465,7 @@ def test_fuse_split_fail_not_start_with_zero(): def test_fuse_with_opaque_block(): sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") - block_opaque = sch.get_block("opaque") + block_opaque = sch.get_sblock("opaque") i, j, k = sch.get_loops(block_opaque) sch.fuse(i, j, k) assert_structural_equal_ignore_global_symbol( @@ -476,10 +476,10 @@ def test_fuse_with_opaque_block(): def test_fuse_with_opaque_access(): sch = tir.Schedule(opaque_access, debug_mask="all") - block_a = sch.get_block("A") + block_a = sch.get_sblock("A") i, j = sch.get_loops(block_a) sch.fuse(i, j) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j = sch.get_loops(block_b) sch.fuse(i, j) assert_structural_equal_ignore_global_symbol(opaque_access_fused, sch.mod["main"]) @@ -488,7 +488,7 @@ def test_fuse_with_opaque_access(): def test_split_with_opaque_block(): sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") - block_opaque = sch.get_block("opaque") + block_opaque = sch.get_sblock("opaque") i, _, _ = sch.get_loops(block_opaque) sch.split(i, factors=[None, 16]) assert_structural_equal_ignore_global_symbol( @@ -499,10 +499,10 @@ def test_split_with_opaque_block(): def test_split_with_opaque_access(): sch = tir.Schedule(opaque_access, debug_mask="all") - block_a = sch.get_block("A") + block_a = sch.get_sblock("A") _, j = sch.get_loops(block_a) sch.split(j, factors=[None, 4]) - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j = sch.get_loops(block_b) sch.split(j, factors=[None, 4]) assert_structural_equal_ignore_global_symbol(opaque_access_split, sch.mod["main"]) @@ -511,7 +511,7 @@ def test_split_with_opaque_access(): def test_split_with_non_positive_factors(): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.split(i, factors=[-2, -64]) @@ -523,7 +523,7 @@ def test_split_with_non_positive_factors(): def test_fuse_split_fail_with_thread_binding(): sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) @@ -533,7 +533,7 @@ def test_fuse_split_fail_with_thread_binding(): def test_fuse_symbolic(): sch = tir.Schedule(elementwise_symbolic, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) assert_structural_equal_ignore_global_symbol(elementwise_symbolic_fused, sch.mod["main"]) @@ -542,7 +542,7 @@ def test_fuse_symbolic(): def test_split_symbolic(): sch = tir.Schedule(elementwise_symbolic, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, k = sch.get_loops(block_b) sch.split(k, factors=[10, None]) assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"]) @@ -551,7 +551,7 @@ def test_split_symbolic(): def test_fuse_fail_with_dependent_loops(): sch = tir.Schedule(elementwise_dependent_loops, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, _ = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(i, j) @@ -559,7 +559,7 @@ def test_fuse_fail_with_dependent_loops(): def test_fuse_not_affine(): sch = tir.Schedule(elementwise_not_affine, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, j, k = sch.get_loops(block_b) sch.fuse(j, k) assert_structural_equal_ignore_global_symbol(elementwise_not_affine_fused, sch.mod["main"]) @@ -573,7 +573,7 @@ def zero_dim( B: T.Buffer((), "int32"), C: T.Buffer((), "int32"), ) -> None: - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] @@ -584,12 +584,12 @@ def zero_dim_added( C: T.Buffer((), "int32"), ) -> None: for u in range(1): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] sch = tir.Schedule(zero_dim, debug_mask="all") - block = sch.get_block("C") + block = sch.get_sblock("C") sch.add_unit_loop(block) assert_structural_equal_ignore_global_symbol(zero_dim_added, sch.mod["main"]) @@ -602,7 +602,7 @@ def zero_dim( C: T.Buffer((), "int32"), ) -> None: for u in range(1): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] @@ -613,12 +613,12 @@ def zero_dim_added( C: T.Buffer((), "int32"), ) -> None: for u1, u2 in T.grid(1, 1): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] sch = tir.Schedule(zero_dim, debug_mask="all") - block = sch.get_block("C") + block = sch.get_sblock("C") (loop,) = sch.get_loops(block) sch.add_unit_loop(loop) assert_structural_equal_ignore_global_symbol(zero_dim_added, sch.mod["main"]) @@ -635,7 +635,7 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - i, j = sch.get_loops(sch.get_block("B")) + i, j = sch.get_loops(sch.get_sblock("B")) sch.fuse(i, j) verify_trace_roundtrip(sch=sch, mod=mod) @@ -649,7 +649,7 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - (i,) = sch.get_loops(sch.get_block("B")) + (i,) = sch.get_loops(sch.get_sblock("B")) sch.split( i, factors=[ @@ -668,7 +668,7 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - (i,) = sch.get_loops(sch.get_block("B")) + (i,) = sch.get_loops(sch.get_sblock("B")) sch.split( i, factors=[ @@ -683,7 +683,7 @@ def _create_prim_func(): def test_split_int64_factors(): sch = tir.Schedule(elementwise_symbolic, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") _, _, k = sch.get_loops(block_b) sch.split(k, factors=[IntImm(dtype="int64", value=10), None]) assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"]) @@ -705,7 +705,7 @@ def before(a: T.handle): A = T.match_buffer(a, (num_elements,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i in T.serial(num_elements): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.remap("S", [i]) A[v_i] = 1.0 @@ -714,7 +714,7 @@ def after(a: T.handle): A = T.match_buffer(a, (num_elements,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements) A[v_i] = 1.0 @@ -741,7 +741,7 @@ def before(a: T.handle): A = T.match_buffer(a, (128,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i in T.serial(128): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.remap("S", [i]) A[v_i] = 1.0 @@ -750,7 +750,7 @@ def after(a: T.handle): A = T.match_buffer(a, (128,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1) A[v_i] = 1.0 @@ -771,7 +771,7 @@ def before(a: T.handle): A = T.match_buffer(a, (128,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i in T.serial(4 * T.vscale()): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.remap("S", [i]) A[v_i] = 1.0 @@ -780,7 +780,7 @@ def after(a: T.handle): A = T.match_buffer(a, (128,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i_0, i_1 in T.grid(T.vscale() * 2, T.vscale() * 2): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.spatial(T.vscale() * 4, i_0 * (T.vscale() * 2) + i_1) T.where(i_0 * (T.vscale() * 2) + i_1 < T.vscale() * 4) A[v_i] = 1.0 @@ -802,7 +802,7 @@ def before(a: T.handle): A = T.match_buffer(a, (128,), "float32") T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) for i in T.serial(128): - with T.block("A"): + with T.sblock("A"): v_i = T.axis.remap("S", [i]) A[v_i] = 1.0 diff --git a/tests/python/tir-schedule/test_tir_schedule_state.py b/tests/python/tir-schedule/test_tir_schedule_state.py index c023b9dbc59d..83d4ef30f1fd 100644 --- a/tests/python/tir-schedule/test_tir_schedule_state.py +++ b/tests/python/tir-schedule/test_tir_schedule_state.py @@ -34,11 +34,11 @@ def elementwise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -49,11 +49,11 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -63,25 +63,25 @@ def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") for i in range(128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) T.reads([A[0:128, 0:128]]) T.writes([B[0:128, 0:128]]) B[vi, 0] = A[vi, 0] if A[vi, 0] == 0.0: - with T.block("C"): + with T.sblock("C"): T.reads([A[0:128, 0:128]]) T.writes([B[0:128, 0:128]]) for j in range(128): - with T.block("D"): + with T.sblock("D"): vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 3.0 else: - with T.block("E"): + with T.sblock("E"): T.reads([A[0:128, 0:128]]) T.writes([B[0:128, 0:128]]) for j in range(128): - with T.block("F"): + with T.sblock("F"): vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 @@ -92,7 +92,7 @@ def block_in_opaque_block(a: T.handle, b: T.handle) -> None: def replace_ir_builder(deep_copy=False, realize=False): new_func = tvm.script.from_source(elementwise.script()) s = tir.ScheduleState(new_func, debug_mask="all") - target = tvm.tir.Block( + target = tvm.tir.SBlock( iter_vars=[], reads=[], writes=[], @@ -104,7 +104,7 @@ def replace_ir_builder(deep_copy=False, realize=False): annotations=None, ) if realize: - target = tvm.tir.BlockRealize( + target = tvm.tir.SBlockRealize( iter_values=[], predicate=True, block=target, @@ -120,7 +120,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): other_func = tvm.script.from_source(elementwise.script()) mod = IRModule(functions={"main": new_func, "other": other_func}) s = tir.ScheduleState(mod, debug_mask="all") - target = tvm.tir.Block( + target = tvm.tir.SBlock( iter_vars=[], reads=[], writes=[], @@ -132,7 +132,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): annotations=None, ) if realize: - target = tvm.tir.BlockRealize( + target = tvm.tir.SBlockRealize( iter_values=[], predicate=True, block=target, diff --git a/tests/python/tir-schedule/test_tir_schedule_state_cached_flags.py b/tests/python/tir-schedule/test_tir_schedule_state_cached_flags.py index 8120aa2aea31..420ca6a21d35 100644 --- a/tests/python/tir-schedule/test_tir_schedule_state_cached_flags.py +++ b/tests/python/tir-schedule/test_tir_schedule_state_cached_flags.py @@ -34,11 +34,11 @@ def elementwise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -49,11 +49,11 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for k in range(0, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -63,25 +63,25 @@ def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") for i in range(128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) T.reads([A[0:128, 0:128]]) T.writes([B[0:128, 0:128]]) B[vi, 0] = A[vi, 0] if A[vi, 0] == 0.0: - with T.block("C"): + with T.sblock("C"): T.reads([A[0:128, 0:128]]) T.writes([B[0:128, 0:128]]) for j in range(128): - with T.block("D"): + with T.sblock("D"): vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 3.0 else: - with T.block("E"): + with T.sblock("E"): T.reads([A[0:128, 0:128]]) T.writes([B[0:128, 0:128]]) for j in range(128): - with T.block("F"): + with T.sblock("F"): vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 @@ -92,11 +92,11 @@ def write_after_read(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -107,10 +107,10 @@ def loop_carried_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) C[vi] = T.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32") @@ -120,15 +120,15 @@ def concatenate_multi_producer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block("A_0"): + with T.sblock("A_0"): vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block("A_1"): + with T.sblock("A_1"): vi = T.axis.S(64, i + 64) A[vi] = vi + 2 for i in range(0, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 @@ -138,15 +138,15 @@ def concatenate_multi_producer_uncovered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 63): - with T.block("A_0"): + with T.sblock("A_0"): vi = T.axis.S(63, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block("A_1"): + with T.sblock("A_1"): vi = T.axis.S(64, i + 64) A[vi] = vi + 2 for i in range(0, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 @@ -157,10 +157,10 @@ def lca_at_loop(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i) C[vi] = B[vi] + 1.0 @@ -170,19 +170,19 @@ def multi_producer_consumer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block("A_0"): + with T.sblock("A_0"): vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block("A_1"): + with T.sblock("A_1"): vi = T.axis.S(64, i + 64) A[vi] = vi + 2 for i in range(0, 64): - with T.block("B_0"): + with T.sblock("B_0"): vi = T.axis.S(64, i) B[vi] = A[vi] + 2.0 for i in range(0, 64): - with T.block("B_1"): + with T.sblock("B_1"): vi = T.axis.S(64, i + 64) B[vi] = A[vi] + 3.0 @@ -193,12 +193,12 @@ def elementwise_affine_producer(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j, k, l in T.grid(16, 2, 32, 16): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i * 8 + j * 4 + k // 8) vj = T.axis.S(128, k % 8 * 16 + l) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -209,16 +209,16 @@ def elementwise_subblock(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(32, 32): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) for ii, jj in T.grid(4, 4): - with T.block("B_sub"): + with T.sblock("B_sub"): vi_i, vj_i = T.axis.remap("SS", [ii, jj]) B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -229,16 +229,16 @@ def elementwise_subblock_uncovered(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(32, 32): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) for ii, jj in T.grid(2, 2): - with T.block("B_sub"): + with T.sblock("B_sub"): vi_i, vj_i = T.axis.remap("SS", [ii, jj]) B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -250,11 +250,11 @@ def bound_to_thread(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], scope="shared") for i in T.thread_binding(0, 128, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vj, vi] = B[vj, vi] + 1.0 @@ -267,12 +267,12 @@ def equal_ranked_threads(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 16, thread="threadIdx.x"): for i_i in T.thread_binding(0, 8, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i_o * 8 + i_i) vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi = T.axis.S(128, i_o * 8 + i_i) vj = T.axis.S(128, j) C[vj, vi] = B[vj, vi] + 1.0 @@ -286,11 +286,11 @@ def warp_memory(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 @@ -303,12 +303,12 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for i_o_prime in T.thread_binding(0, 4, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): _warp_id, warp_id, lane_id, vj = T.axis.remap( "SSSS", [i_o, i_i, i_o_prime, j] ) @@ -323,7 +323,7 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: for hh_0, ww_0 in T.grid(28, 28): for ax0 in T.serial(0, 10): for ax1 in T.serial(0, 10): - with T.block("cache"): + with T.sblock("cache"): h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0) w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1) T.where( @@ -334,7 +334,7 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: ) cache[h, w] = X[h, w] for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): - with T.block("compute"): + with T.sblock("compute"): h = T.axis.spatial(224, hh_0 * 8 + hh_1) w = T.axis.spatial(224, ww_0 * 8 + ww_1) kh, kw = T.axis.remap("RR", [khh, kww]) @@ -357,11 +357,11 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: @T.prim_func def uncovered_producer_region(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): for i in range(120): - with T.block("producer"): + with T.sblock("producer"): vi = T.axis.S((0, 120), i) A[vi] = 1.0 for i in range(120): - with T.block("consumer"): + with T.sblock("consumer"): vi = T.axis.S((8, 128), i + 8) B[vi] = A[vi] @@ -371,20 +371,20 @@ def matmul_relu_padding(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 12 # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") C = T.alloc_buffer([127, 127], dtype="float32") A_reindex = T.alloc_buffer([128, 128], dtype="float16") B_reindex = T.alloc_buffer([128, 128], dtype="float16") C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") for ax0, ax1, ax2 in T.grid(128, 1, 128): - with T.block("A_reindex"): + with T.sblock("A_reindex"): v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(A[v0, v2]) T.writes(A_reindex[v0, v2]) A_reindex[v0, v2] = T.if_then_else(v0 < 127 and v2 < 127, A[v0, v2], T.float16(0), dtype="float16") for ax0, ax1, ax2 in T.grid(1, 128, 128): - with T.block("B_reindex"): + with T.sblock("B_reindex"): v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(B[v2, v1]) T.writes(B_reindex[v2, v1]) @@ -393,45 +393,45 @@ def matmul_relu_padding(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 12 for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): for ax2_0_0, ax2_0_1, ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 2, 1, 2, 2, 1, 1): - with T.block("C_o"): + with T.sblock("C_o"): v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2) T.reads(A_reindex[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): + with T.sblock("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + T.sblock_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") for ax0, ax1 in T.grid(16, 32): - with T.block("C_reindex_shared_wmma.accumulator"): + with T.sblock("C_reindex_shared_wmma.accumulator"): v0 = T.axis.spatial(128, ax0_0_2_ax1_0_2_fused // 2 * 16 + ax0) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_2_ax1_0_2_fused % 2 * 32 + ax1) T.reads(C_reindex_shared_wmma_accumulator[v0, v1]) T.writes(C_reindex_shared[v0, v1]) C_reindex_shared[v0, v1] = C_reindex_shared_wmma_accumulator[v0, v1] for ax0, ax1 in T.grid(128, 64): - with T.block("C_reindex_shared"): + with T.sblock("C_reindex_shared"): v0 = T.axis.spatial(128, ax0) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax1) T.where(ax0 < 127 and ax0_0_0_ax1_0_0_fused * 64 + ax1 < 127) T.reads(C_reindex_shared[v0, v1]) T.writes(C[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) + T.sblock_attr({"meta_schedule.cooperative_fetch":3}) C[v0, v1] = C_reindex_shared[v0, v1] for i0, i1 in T.grid(127, 127): - with T.block("compute"): + with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(C[i0_1, i1_1]) T.writes(compute[i0_1, i1_1]) @@ -444,7 +444,7 @@ def splitted_square_sum_with_predicate( ) -> None: for i0_i1_i2_i3_0_fused, ax0, ax1, ax2, ax3 in T.grid(2, 1, 1, 1, 256): for ax4_ax5_fused_0, ax4_ax5_fused_1 in T.grid(1, 256): - with T.block("B"): + with T.sblock("B"): T.where(ax4_ax5_fused_0 * 256 + ax4_ax5_fused_1 < 49) ax0_1, ax1_1, ax2_1 = T.axis.remap("SSS", [ax0, ax1, ax2]) ax3_1 = T.axis.spatial(512, i0_i1_i2_i3_0_fused * 256 + ax3) @@ -461,34 +461,34 @@ def splitted_square_sum_with_predicate( # fmt: on -def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef: +def _get_sblock(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef: result = None def f_visit(node): nonlocal result - if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint: + if isinstance(node, tvm.tir.SBlock) and node.name_hint == name_hint: result = node func = s.mod["main"] post_order_visit(func.body, f_visit) - assert result is not None and isinstance(result, tvm.tir.Block) + assert result is not None and isinstance(result, tvm.tir.SBlock) return s.get_sref(result) def test_elementwise(): s = tir.ScheduleState(elementwise, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -499,17 +499,17 @@ def test_elementwise(): def test_matmul(): s = tir.ScheduleState(matmul, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "init")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "init")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "update")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "update")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -520,27 +520,27 @@ def test_matmul(): def test_block_in_opaque_block(): s = tir.ScheduleState(block_in_opaque_block, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "E")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "E")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "F")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "F")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -551,17 +551,17 @@ def test_block_in_opaque_block(): def test_write_after_read(): s = tir.ScheduleState(write_after_read, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=False, @@ -572,17 +572,17 @@ def test_write_after_read(): def test_loop_carried_dependency(): s = tir.ScheduleState(loop_carried_dependency, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=False, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=False, @@ -593,22 +593,22 @@ def test_loop_carried_dependency(): def test_concatenate_multi_producer_covered(): # pylint: disable=invalid-name s = tir.ScheduleState(concatenate_multi_producer, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "A_0")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "A_1")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "A_1")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -619,22 +619,22 @@ def test_concatenate_multi_producer_covered(): # pylint: disable=invalid-name def test_concatenate_multi_producer_uncovered(): # pylint: disable=invalid-name s = tir.ScheduleState(concatenate_multi_producer_uncovered, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "A_0")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "A_1")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "A_1")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=False, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=False, @@ -645,17 +645,17 @@ def test_concatenate_multi_producer_uncovered(): # pylint: disable=invalid-name def test_lca_at_loop(): s = tir.ScheduleState(lca_at_loop, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -666,22 +666,22 @@ def test_lca_at_loop(): def test_multi_producer_consumer(): s = tir.ScheduleState(multi_producer_consumer, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "A_0")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "A_1")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "A_1")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B_0")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B_0")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B_1")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B_1")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -692,17 +692,17 @@ def test_multi_producer_consumer(): def test_elementwise_affine_producer(): s = tir.ScheduleState(elementwise_affine_producer, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -713,22 +713,22 @@ def test_elementwise_affine_producer(): def test_subblock(): s = tir.ScheduleState(elementwise_subblock, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B_sub")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -739,22 +739,22 @@ def test_subblock(): def test_subblock_uncovered(): s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=False, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B_sub")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=False, stage_pipeline=True, @@ -765,17 +765,17 @@ def test_subblock_uncovered(): def test_thread_binding(): s = tir.ScheduleState(bound_to_thread, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -786,17 +786,17 @@ def test_thread_binding(): def test_equal_ranked_threads(): s = tir.ScheduleState(equal_ranked_threads, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -807,17 +807,17 @@ def test_equal_ranked_threads(): def test_warp_memory(): s = tir.ScheduleState(warp_memory, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -828,17 +828,17 @@ def test_warp_memory(): def test_warp_memory_negative(): s = tir.ScheduleState(warp_memory_negative, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "root")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=False, ) - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C")) == CachedFlags( affine_binding=True, region_cover=False, stage_pipeline=True, @@ -849,12 +849,12 @@ def test_warp_memory_negative(): def test_non_perfect_tiling_cache(): s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "cache")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, ) - assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "compute")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -865,7 +865,7 @@ def test_non_perfect_tiling_cache(): def test_uncovered_producer_region(): s = tir.ScheduleState(uncovered_producer_region, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "consumer")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "consumer")) == CachedFlags( affine_binding=True, region_cover=False, stage_pipeline=True, @@ -876,7 +876,7 @@ def test_uncovered_producer_region(): def test_matmul_relu_padding(): s = tir.ScheduleState(matmul_relu_padding, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "C_reindex_shared")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "C_reindex_shared")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, @@ -887,7 +887,7 @@ def test_matmul_relu_padding(): def test_splitted_square_sum_with_predicate(): s = tir.ScheduleState(splitted_square_sum_with_predicate, debug_mask="all") # pylint: disable=protected-access - assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( + assert s._get_cached_flags(_get_sblock(s, "B")) == CachedFlags( affine_binding=True, region_cover=True, stage_pipeline=True, diff --git a/tests/python/tir-schedule/test_tir_schedule_storage_align.py b/tests/python/tir-schedule/test_tir_schedule_storage_align.py index 3825234c20e0..ea13cc436bdb 100644 --- a/tests/python/tir-schedule/test_tir_schedule_storage_align.py +++ b/tests/python/tir-schedule/test_tir_schedule_storage_align.py @@ -32,19 +32,19 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1) # body - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=64, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) @@ -56,20 +56,20 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1) # body - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=64, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) - T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) + T.sblock_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) @@ -81,20 +81,20 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1) # body - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=64, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block("B"): - T.block_attr({"buffer_dim_align": [0]}) + with T.sblock("B"): + T.sblock_attr({"buffer_dim_align": [0]}) vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) @@ -106,7 +106,7 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: def test_storage_align(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = 'B' if use_block_name else s.get_block("B") + B = 'B' if use_block_name else s.get_sblock("B") s.storage_align(B, 0, axis=0, factor=128, offset=127) assert_structural_equal_ignore_global_symbol(element_wise_storage_align, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -115,7 +115,7 @@ def test_storage_align(use_block_name): def test_storage_align_update(): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = s.get_sblock("B") s.storage_align(B, 0, axis=0, factor=128, offset=0) s.storage_align(B, 0, axis=0, factor=128, offset=127) assert_structural_equal_ignore_global_symbol(element_wise_storage_align, s.mod["main"]) @@ -125,7 +125,7 @@ def test_storage_align_update(): def test_storage_align_invalid_factor1(): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = s.get_sblock("B") with pytest.raises(tir.ScheduleError): s.storage_align(B, 0, axis=0, factor=0, offset=127) @@ -133,7 +133,7 @@ def test_storage_align_invalid_factor1(): def test_storage_align_invalid_factor2(): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = s.get_sblock("B") with pytest.raises(tir.ScheduleError): s.storage_align(B, 0, axis=0, factor=-1, offset=127) @@ -141,7 +141,7 @@ def test_storage_align_invalid_factor2(): def test_storage_align_invalid_buffer(): func = element_wise s = tir.Schedule(func, debug_mask='all') - C = s.get_block("C") + C = s.get_sblock("C") with pytest.raises(tir.ScheduleError): s.storage_align(C, 0, axis=0, factor=128, offset=127) @@ -149,7 +149,7 @@ def test_storage_align_invalid_buffer(): def test_storage_align_invalid_buffer_index(): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = s.get_sblock("B") with pytest.raises(tir.ScheduleError): s.storage_align(B, 2, axis=0, factor=128, offset=127) @@ -157,7 +157,7 @@ def test_storage_align_invalid_buffer_index(): def test_storage_align_invalid_axis(): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = s.get_sblock("B") with pytest.raises(tir.ScheduleError): s.storage_align(B, 0, axis=2, factor=128, offset=127) @@ -165,7 +165,7 @@ def test_storage_align_invalid_axis(): def test_storage_align_invalid_annotation(): func = element_wise_invalid_annotation s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = s.get_sblock("B") with pytest.raises(tir.ScheduleError): s.storage_align(B, 0, axis=2, factor=128, offset=127) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py index 789d6be3ad0b..958149bca501 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py @@ -46,11 +46,11 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 16), align=64, offset_factor=1) C = T.match_buffer(c, (16, 16), align=64, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) T.writes(C[0 : 16, 0 : 16]) for i, j, k in T.grid(16, 16, 16): - with T.block("update"): + with T.sblock("update"): vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] @@ -61,7 +61,7 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 16), align=64, offset_factor=1) C = T.match_buffer(c, (16, 16), align=64, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) T.writes(C[0 : 16, 0 : 16]) T.evaluate( @@ -85,11 +85,11 @@ def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (4,)) C = T.match_buffer(c, ()) - with T.block("root"): + with T.sblock("root"): T.reads(C[()], A[0 : 4], B[0 : 4]) T.writes(C[()]) for i in range(0, 4): - with T.block("update"): + with T.sblock("update"): vi = T.axis.remap("R", [i]) C[()] = C[()] + A[vi] * B[vi] @@ -100,7 +100,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (4,), offset_factor=1) C = T.match_buffer(c, (), offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[()], A[0 : 4], B[0 : 4]) T.writes(C[()]) T.evaluate( @@ -123,10 +123,10 @@ def dot_product_intrin_annotated(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (4,), offset_factor=1) C = T.match_buffer(c, (), offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[()], A[0 : 4], B[0 : 4]) T.writes(C[()]) - T.block_attr({"test_annotation": True}) + T.sblock_attr({"test_annotation": True}) T.evaluate( T.call_extern( "vec4add", @@ -147,7 +147,7 @@ def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 1), offset_factor=1) C = T.match_buffer(c, (16, 16), offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads( C[0 : 16, 0 : 16], A[0 : 16, 0 : 1], @@ -155,7 +155,7 @@ def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: ) T.writes(C[0 : 16, 0 : 16]) for i, j in T.grid(16, 16): - with T.block("update"): + with T.sblock("update"): vii, vjj = T.axis.remap("SS", [i, j]) C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0] @@ -166,7 +166,7 @@ def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 1), offset_factor=1) C = T.match_buffer(c, (16, 16), offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads( C[0 : 16, 0 : 16], A[0 : 16, 0 : 1], @@ -194,7 +194,7 @@ def matmul( C: T.Buffer((128, 128), "float32"), ) -> None: for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -209,12 +209,12 @@ def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: for i_outer, j_outer in T.grid(8, 8): for i_inner_init, j_inner_init in T.grid(16, 16): - with T.block("init"): + with T.sblock("init"): vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) C[vi_init, vj_init] = T.float32(0) for k_outer in T.grid(8): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) T.reads( [ @@ -264,12 +264,12 @@ def batch_matmul( C: T.Buffer((16, 128, 128), "float32"), ) -> None: for n, i, j in T.grid(16, 128, 128): - with T.block("init"): + with T.sblock("init"): vn, vi, vj = T.axis.remap("SSS", [n, i, j]) C[vn, vi, vj] = T.float32(0) for n, i, j, k in T.grid(16, 128, 128, 128): - with T.block("update"): + with T.sblock("update"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] @@ -281,14 +281,14 @@ def tensorized_batch_matmul_mma( C: T.Buffer((16, 128, 128), "float32"), ) -> None: for n, i, j in T.grid(16, 128, 128): - with T.block("init"): + with T.sblock("init"): vn, vi, vj = T.axis.remap("SSS", [n, i, j]) T.reads() T.writes(C[vn, vi, vj]) C[vn, vi, vj] = T.float32(0) for n in range(0, 16): for i, j, k in T.grid(8, 8, 8): - with T.block("update"): + with T.sblock("update"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) T.reads( C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], @@ -336,13 +336,13 @@ def tensorized_batch_matmul_dot_product( C: T.Buffer((16, 128, 128), "float32"), ) -> None: for n, i, j in T.grid(16, 128, 128): - with T.block("init"): + with T.sblock("init"): vn, vi, vj = T.axis.remap("SSS", [n, i, j]) T.reads() T.writes(C[vn, vi, vj]) C[vn, vi, vj] = T.float32(0) for n, i, j, k_0 in T.grid(16, 128, 128, 32): - with T.block("blockized_update"): + with T.sblock("blockized_update"): vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0]) T.reads( C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko * 4 : vko * 4 + 4] @@ -376,13 +376,13 @@ def tensorized_batch_matmul_outer_product( C: T.Buffer((16, 128, 128), "float32"), ) -> None: for n, i, j in T.grid(16, 128, 128): - with T.block("init"): + with T.sblock("init"): vn, vi, vj = T.axis.remap("SSS", [n, i, j]) T.reads() T.writes(C[vn, vi, vj]) C[vn, vi, vj] = T.float32(0) for n, i_0, j_0, k in T.grid(16, 8, 8, 128): - with T.block("blockized_update"): + with T.sblock("blockized_update"): vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k]) T.reads( C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], @@ -409,12 +409,12 @@ def annotated_mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (16, 16), align=64, offset_factor=1) C = T.match_buffer(c, (16, 16), align=64, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) T.writes(C[0 : 16, 0 : 16]) for i, j, k in T.grid(16, 16, 16): - with T.block("update"): - T.block_attr({"test_annotation": True}) + with T.sblock("update"): + T.sblock_attr({"test_annotation": True}) vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] @@ -426,9 +426,9 @@ def annotated_matmul( C: T.Buffer((128, 128), "float32"), ) -> None: for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.block_attr({"test_annotation": True}) + T.sblock_attr({"test_annotation": True}) with T.init(): C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -442,13 +442,13 @@ def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: for i_outer, j_outer in T.grid(8, 8): for i_inner_init, j_inner_init in T.grid(16, 16): - with T.block("init"): + with T.sblock("init"): vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) - T.block_attr({"test_annotation": True}) + T.sblock_attr({"test_annotation": True}) C[vi_init, vj_init] = T.float32(0) for k_outer in T.grid(8): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) T.reads( [ @@ -505,7 +505,7 @@ def test_tensorize_matmul(): func = matmul # schedule s = tir.Schedule(func, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") i, j, k = s.get_loops(update) io, ii = s.split(i, factors=[None, 16]) jo, ji = s.split(j, factors=[None, 16]) @@ -520,7 +520,7 @@ def test_tensorize_matmul(): def test_tensorize_batch_matmul(): func = batch_matmul s = tir.Schedule(func, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") _, i, j, k = s.get_loops(update) io, ii = s.split(i, factors=[None, 16]) jo, ji = s.split(j, factors=[None, 16]) @@ -534,7 +534,7 @@ def test_tensorize_batch_matmul(): def test_tensorize_dot_product(): func = batch_matmul s = tir.Schedule(func, debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") _, _, _, k = s.get_loops(C) _, ki = s.split(k, factors=[None, 4]) s.tensorize(ki, "test_dot_product_intrin") @@ -545,7 +545,7 @@ def test_tensorize_dot_product(): def test_tensorize_outer_product(): func = batch_matmul s = tir.Schedule(func, debug_mask="all") - C = s.get_block("update") + C = s.get_sblock("update") _, i, j, k = s.get_loops(C) io, ii = s.split(i, factors=[None, 16]) jo, ji = s.split(j, factors=[None, 16]) @@ -558,7 +558,7 @@ def test_tensorize_outer_product(): def test_tensorize_with_annotation(): func = annotated_matmul s = tir.Schedule(func, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") i, j, k = s.get_loops(update) io, ii = s.split(i, factors=[None, 16]) jo, ji = s.split(j, factors=[None, 16]) @@ -573,13 +573,13 @@ def test_tensorize_with_annotation(): def test_tensorize_intrinsic_with_annotation(): func = matmul s = tir.Schedule(func, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") _, _, k = s.get_loops(update) ko, ki = s.split(k, factors=[None, 4]) s.decompose_reduction(update, ko) s.tensorize(ki, "test_dot_product_intrin_annotated") - b = s.get(s.get_block("update_update_o")) + b = s.get(s.get_sblock("update_update_o")) assert b.annotations["test_annotation"] == T.bool(True) verify_trace_roundtrip(sch=s, mod=func) @@ -607,7 +607,7 @@ def tensorize_16x4_test(intrin=VNNI_DOT_16x4_INTRIN): func = get_matmul_packed(m, n, k, "uint8") sch = tir.Schedule(func, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") sch.transform_layout(block, "W", lambda i, j: [i//16, j//4, i%16, j%4]) _, j, k = sch.get_loops(block) @@ -636,7 +636,7 @@ def test_tensorize_arm_dot(): for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]: sch = tir.Schedule(func, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") sch.transform_layout(block, "W", lambda i, j: [i//4, j//4, i%4, j%4]) _, j, k = sch.get_loops(block) @@ -656,7 +656,7 @@ def test_tensorize_vrmpy(): func = get_matmul_packed(m, n, k, "uint8", "uint8") sch = tir.Schedule(func, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") sch.transform_layout(block, "W", lambda i, j: [i//32, j//4, i%32, j%4]) _, j, k = sch.get_loops(block) @@ -676,7 +676,7 @@ def test_tensorize_vdmpy(): func = get_matmul_packed(m, n, k, "int16", "int16") sch = tir.Schedule(func, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") sch.transform_layout(block, "W", lambda i, j: [i//32, j//2, i%32, j%2]) _, j, k = sch.get_loops(block) @@ -710,7 +710,7 @@ def _test_intrin(dtype_a, dtype_b, dtype_c, intrin): func = te.create_prim_func([X, W, matmul]) sch = tir.Schedule(func, debug_mask="all") - block = sch.get_block("compute") + block = sch.get_sblock("compute") i, j, k = sch.get_loops(block) by, ty, yi = sch.split(i, factors=sch.sample_perfect_tile(i, n=3)) @@ -762,12 +762,12 @@ def matmul_int64_shape( ) -> None: for i_0, j_0 in T.grid(T.int64(8), T.int64(8)): for i_1_init, j_1_init in T.grid(T.int64(16), T.int64(16)): - with T.block("init"): + with T.sblock("init"): vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1_init) vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1_init) C[vi, vj] = T.float32(0) for k_0, i_1, j_1, k_1 in T.grid(T.int64(8), T.int64(16), T.int64(16), T.int64(16)): - with T.block("update"): + with T.sblock("update"): vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1) vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1) vk = T.axis.reduce(T.int64(128), k_0 * T.int64(16) + k_1) @@ -781,12 +781,12 @@ def tensorized_matmul_int64_shape( ) -> None: for i_outer, j_outer in T.grid(T.int64(8), T.int64(8)): for i_inner_init, j_inner_init in T.grid(T.int64(16), T.int64(16)): - with T.block("init"): + with T.sblock("init"): vi = T.axis.spatial(T.int64(128), i_outer * T.int64(16) + i_inner_init) vj = T.axis.spatial(T.int64(128), j_outer * T.int64(16) + j_inner_init) C[vi, vj] = T.float32(0) for k_outer in T.grid(T.int64(8)): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) T.reads( [ @@ -830,7 +830,7 @@ def tensorized_matmul_int64_shape( # fmt: on s = tir.Schedule(matmul_int64_shape, debug_mask="all") - update = s.get_block("update") + update = s.get_sblock("update") ii = s.get_loops(update)[-3] s.tensorize(ii, "test_mma_intrin") assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape) @@ -866,11 +866,11 @@ def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None scope="local", ) - with T.block("root"): + with T.sblock("root"): T.reads(Compressed[0:1]) T.writes(Decompressed[0:8]) for i in T.grid(8): - with T.block("decode"): + with T.sblock("decode"): vi = T.axis.remap("S", [i]) Decompressed[vi] = _tir_packed_int_to_int_to_float(32)( 4, @@ -898,7 +898,7 @@ def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None scope="local", ) - with T.block("root"): + with T.sblock("root"): T.reads(Compressed[0:1]) T.writes(Decompressed[0:8]) T.call_extern( @@ -922,7 +922,7 @@ def decode_i4s_to_int32_to_f16(): for ax1_0 in range(32): for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): for ax0, ax1 in T.grid(1, 8): - with T.block("B_decode_local"): + with T.sblock("B_decode_local"): v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1) T.reads(B_local[v0, v1 // 8]) @@ -938,7 +938,7 @@ def tensorized_decode_i4s_to_int32_to_f16(): for ax1_0 in range(32): for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in range(1): - with T.block("B_decode_local_o"): + with T.sblock("B_decode_local_o"): v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1) T.reads(B_local[v0_o, v1_o]) @@ -948,7 +948,7 @@ def tensorized_decode_i4s_to_int32_to_f16(): T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8) s = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all") - update = s.get_block("B_decode_local") + update = s.get_sblock("B_decode_local") ii = s.get_loops(update)[-1] s.tensorize(ii, "test_decode_i4s_to_f16_intrin") assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_decode_i4s_to_int32_to_f16) diff --git a/tests/python/tir-schedule/test_tir_schedule_trace.py b/tests/python/tir-schedule/test_tir_schedule_trace.py index 52c2e1dbc28a..2c70ffb278ad 100644 --- a/tests/python/tir-schedule/test_tir_schedule_trace.py +++ b/tests/python/tir-schedule/test_tir_schedule_trace.py @@ -23,7 +23,7 @@ import tvm.testing from tvm import tir from tvm.script import tir as T -from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV, Trace +from tvm.tir.schedule import SBlockRV, Instruction, InstructionKind, LoopRV, Trace from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol # pylint: disable=no-member,invalid-name,unused-variable @@ -35,11 +35,11 @@ def elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -49,7 +49,7 @@ def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -57,9 +57,9 @@ def elementwise_inlined(a: T.handle, c: T.handle) -> None: # pylint: enable=no-member,invalid-name,unused-variable -def _make_get_block(name, output): +def _make_get_sblock(name, output): return Instruction( - kind=InstructionKind.get("GetBlock"), + kind=InstructionKind.get("GetSBlock"), inputs=[], attrs=[name, "main"], outputs=[output], @@ -102,7 +102,7 @@ def _make_enter_postproc(): ) -def _make_annotate(block: BlockRV, annotation: str): +def _make_annotate(block: SBlockRV, annotation: str): return Instruction( kind=InstructionKind.get("Annotate"), inputs=[block, annotation], @@ -114,7 +114,7 @@ def _make_annotate(block: BlockRV, annotation: str): def _make_trace_1(b0, l1, l2): # pylint: disable=invalid-name return Trace( insts=[ - _make_get_block(name="block", output=b0), + _make_get_sblock(name="block", output=b0), _make_get_loops(input=b0, outputs=[l1, l2]), ], decisions={}, @@ -124,7 +124,7 @@ def _make_trace_1(b0, l1, l2): # pylint: disable=invalid-name def _make_trace_2(b0): # pylint: disable=invalid-name return Trace( insts=[ - _make_get_block(name="B", output=b0), + _make_get_sblock(name="B", output=b0), _make_compute_inline(input=b0), ], decisions={}, @@ -134,17 +134,17 @@ def _make_trace_2(b0): # pylint: disable=invalid-name def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name if add_postproc: insts = [ - _make_get_block(name="B", output=b0), + _make_get_sblock(name="B", output=b0), _make_compute_inline(input=b0), - _make_get_block(name="C", output=b1), + _make_get_sblock(name="C", output=b1), _make_enter_postproc(), _make_compute_inline(input=b1), ] else: insts = [ - _make_get_block(name="B", output=b0), + _make_get_sblock(name="B", output=b0), _make_compute_inline(input=b0), - _make_get_block(name="C", output=b1), + _make_get_sblock(name="C", output=b1), ] return Trace(insts=insts, decisions={}) @@ -152,7 +152,7 @@ def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name def _make_trace_4(b0, l1, l2, l3): # pylint: disable=invalid-name return Trace( insts=[ - _make_get_block(name="B", output=b0), + _make_get_sblock(name="B", output=b0), _make_get_loops(input=b0, outputs=[l1]), _make_split([l1, None, T.int32(32)], [l2, l3]), ], @@ -161,12 +161,12 @@ def _make_trace_4(b0, l1, l2, l3): # pylint: disable=invalid-name def test_trace_construct_1(): - trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + trace = _make_trace_1(SBlockRV(), LoopRV(), LoopRV()) assert str(trace) == "\n".join( ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="block", func_name="main")', + ' b0 = sch.get_sblock(name="block", func_name="main")', " l1, l2 = sch.get_loops(block=b0)", ) ) @@ -175,34 +175,34 @@ def test_trace_construct_1(): def test_trace_construct_get_decision_1(): - trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + trace = _make_trace_1(SBlockRV(), LoopRV(), LoopRV()) assert trace.get_decision(trace.insts[0]) is None assert trace.get_decision(trace.insts[1]) is None def test_trace_construct_append_1(): - trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) - trace.append(inst=_make_get_block("block2", BlockRV())) + trace = _make_trace_1(SBlockRV(), LoopRV(), LoopRV()) + trace.append(inst=_make_get_sblock("block2", SBlockRV())) assert str(trace) == "\n".join( ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="block", func_name="main")', + ' b0 = sch.get_sblock(name="block", func_name="main")', " l1, l2 = sch.get_loops(block=b0)", - ' b3 = sch.get_block(name="block2", func_name="main")', + ' b3 = sch.get_sblock(name="block2", func_name="main")', ) ) def test_trace_construct_pop_1(): - trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + trace = _make_trace_1(SBlockRV(), LoopRV(), LoopRV()) last_inst = trace.insts[-1] assert trace.pop().same_as(last_inst) assert str(trace) == "\n".join( ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="block", func_name="main")', + ' b0 = sch.get_sblock(name="block", func_name="main")', ) ) @@ -227,18 +227,18 @@ def test_trace_construct_pop_2(): def test_trace_apply_to_schedule(): - trace = _make_trace_2(BlockRV()) + trace = _make_trace_2(SBlockRV()) sch = tir.Schedule(elementwise, debug_mask="all") trace.apply_to_schedule(sch, remove_postproc=False, decision_provider=None) assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) def test_trace_as_json_1(): - trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + trace = _make_trace_1(SBlockRV(), LoopRV(), LoopRV()) obj = trace.as_json() assert obj == [ [ - ["GetBlock", [], ["block", "main"], ["b0"]], + ["GetSBlock", [], ["block", "main"], ["b0"]], ["GetLoops", ["b0"], [], ["l1", "l2"]], ], [], @@ -246,14 +246,14 @@ def test_trace_as_json_1(): def test_trace_simplified_1(): - trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) + trace = _make_trace_3(SBlockRV(), SBlockRV(), add_postproc=True) assert str(trace) == "\n".join( ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="B", func_name="main")', + ' b0 = sch.get_sblock(name="B", func_name="main")', " sch.compute_inline(block=b0)", - ' b1 = sch.get_block(name="C", func_name="main")', + ' b1 = sch.get_sblock(name="C", func_name="main")', " sch.enter_postproc()", " sch.compute_inline(block=b1)", ) @@ -263,21 +263,21 @@ def test_trace_simplified_1(): ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="B", func_name="main")', + ' b0 = sch.get_sblock(name="B", func_name="main")', " sch.compute_inline(block=b0)", ) ) def test_trace_simplified_2(): - trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) + trace = _make_trace_3(SBlockRV(), SBlockRV(), add_postproc=True) assert str(trace) == "\n".join( ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="B", func_name="main")', + ' b0 = sch.get_sblock(name="B", func_name="main")', " sch.compute_inline(block=b0)", - ' b1 = sch.get_block(name="C", func_name="main")', + ' b1 = sch.get_sblock(name="C", func_name="main")', " sch.enter_postproc()", " sch.compute_inline(block=b1)", ) @@ -287,9 +287,9 @@ def test_trace_simplified_2(): ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="B", func_name="main")', + ' b0 = sch.get_sblock(name="B", func_name="main")', " sch.compute_inline(block=b0)", - ' b1 = sch.get_block(name="C", func_name="main")', + ' b1 = sch.get_sblock(name="C", func_name="main")', " sch.enter_postproc()", " sch.compute_inline(block=b1)", ) @@ -297,12 +297,14 @@ def test_trace_simplified_2(): def test_trace_simplified_3(): - trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False) + trace = _make_trace_4(SBlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified( + remove_postproc=False + ) assert str(trace) == "\n".join( ( "# from tvm import tir", "def apply_trace(sch: tir.Schedule) -> None:", - ' b0 = sch.get_block(name="B", func_name="main")', + ' b0 = sch.get_sblock(name="B", func_name="main")', " l1, = sch.get_loops(block=b0)", " l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True, disable_predication=False)", ) @@ -310,7 +312,7 @@ def test_trace_simplified_3(): def test_apply_json_to_schedule_1(): - trace = _make_trace_2(BlockRV()) + trace = _make_trace_2(SBlockRV()) json_obj = trace.as_json() sch = tir.Schedule(elementwise, debug_mask="all") Trace.apply_json_to_schedule(json_obj, sch) @@ -347,10 +349,10 @@ def _test_apply_annotation_trace_from_json(annotation: str): Designed to handle some previously failing edge cases like the empty string. """ - b0 = BlockRV() + b0 = SBlockRV() trace = Trace( insts=[ - _make_get_block(name="B", output=b0), + _make_get_sblock(name="B", output=b0), _make_annotate(block=b0, annotation=annotation), ], decisions={}, @@ -365,12 +367,12 @@ def elementwise_expected(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block("B"): - T.block_attr({"meta_schedule.auto_tensorize": annotation}) + with T.sblock("B"): + T.sblock_attr({"meta_schedule.auto_tensorize": annotation}) vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 diff --git a/tests/python/tir-schedule/test_tir_schedule_transform.py b/tests/python/tir-schedule/test_tir_schedule_transform.py index b189d3c39e5b..6bc669a98723 100644 --- a/tests/python/tir-schedule/test_tir_schedule_transform.py +++ b/tests/python/tir-schedule/test_tir_schedule_transform.py @@ -30,11 +30,11 @@ def main( compute: T.Buffer((1024, 1024), "int32"), ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() for i0, i1, i2 in T.grid(1024, 1024, 1024): - with T.block("compute"): + with T.sblock("compute"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) T.writes(compute[i, j]) @@ -56,9 +56,9 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4): - with T.block("compute"): + with T.sblock("compute"): i = T.axis.spatial(1024, i0) j = T.axis.spatial(1024, i1_0 * 16 + i1_1) k = T.axis.reduce(1024, i2_0 * 4 + i2_1) @@ -81,7 +81,7 @@ def main( ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): ( n, oc_chunk, @@ -123,11 +123,11 @@ def main( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid( 1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4 ): - with T.block("conv2d_NCHWc_int8"): + with T.sblock("conv2d_NCHWc_int8"): n, oc_chunk, oh, ow = T.axis.remap("SSSS", [i0, i1, i2, i3]) oc_block = T.axis.spatial(16, i4_0 * 16 + i4_1) kh, kw, ic_outer, ic_f_inner = T.axis.remap("RRRR", [i5, i6, i7, i8]) @@ -152,7 +152,7 @@ def main( def test_tile_with_tensor_intrin_dense(intrin=VNNI_DOT_16x4_INTRIN): s = Schedule(DenseTIRModule) - block = s.get_block("compute") + block = s.get_sblock("compute") tiled_loop = tile_with_tensor_intrin(s, block, intrin) @@ -164,7 +164,7 @@ def test_tile_with_tensor_intrin_dense(intrin=VNNI_DOT_16x4_INTRIN): def test_tile_with_tensor_intrin_conv2d_nchwc(intrin=VNNI_DOT_16x4_INTRIN): s = Schedule(Conv2dNCHWcTIRModule) - block = s.get_block("conv2d_NCHWc_int8") + block = s.get_sblock("conv2d_NCHWc_int8") tiled_loop = tile_with_tensor_intrin(s, block, intrin) tiled_loops = s.get_loops(block) assert len(tiled_loops) == 12 diff --git a/tests/python/tir-schedule/test_tir_schedule_transform_layout.py b/tests/python/tir-schedule/test_tir_schedule_transform_layout.py index 8136d2b5fa78..66150484ca09 100644 --- a/tests/python/tir-schedule/test_tir_schedule_transform_layout.py +++ b/tests/python/tir-schedule/test_tir_schedule_transform_layout.py @@ -39,11 +39,11 @@ def packed_index_map_func(m, n): def two_elementwise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None: B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -54,11 +54,11 @@ def two_elementwise_transformed_intermediate_buffer( ) -> None: B = T.alloc_buffer((8, 8, 16, 16), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 @@ -69,11 +69,11 @@ def two_elementwise_transformed_input_buffer( ) -> None: B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi // 16, vj // 16, vi % 16, vj % 16] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -84,11 +84,11 @@ def two_elementwise_transformed_output_buffer( ) -> None: B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi // 16, vj // 16, vi % 16, vj % 16] = B[vi, vj] + 1.0 @@ -96,7 +96,7 @@ def two_elementwise_transformed_output_buffer( @T.prim_func def elementwise(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")) -> None: for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -104,7 +104,7 @@ def elementwise(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "flo @T.prim_func def elementwise_transformed(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")) -> None: for i in range(16384): - with T.block("B"): + with T.sblock("B"): vi = T.axis.remap("S", [i]) B[vi // 128, vi % 128] = A[vi // 128, vi % 128] * 2.0 @@ -117,7 +117,7 @@ def conv2d_nhwc( ) -> None: PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), @@ -126,7 +126,7 @@ def conv2d_nhwc( dtype="float32", ) for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) with T.init(): Conv2d_nhwc[n, h, w, co] = T.float32(0) @@ -144,7 +144,7 @@ def conv2d_nhwc_transformed( ) -> None: PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): - with T.block("PadInput"): + with T.sblock("PadInput"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) @@ -155,7 +155,7 @@ def conv2d_nhwc_transformed( dtype="float32", ) for ax0, ax1, ax2 in T.grid(12544, 64, 147): - with T.block("conv2d_nhwc"): + with T.sblock("conv2d_nhwc"): v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2]) with T.init(): Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = T.float32(0) @@ -166,11 +166,11 @@ def conv2d_nhwc_transformed( def two_elementwise_unit_dim(A: T.Buffer((1, 128), "float32"), C: T.Buffer((1, 128), "float32")) -> None: B = T.alloc_buffer((1, 128), "float32") for i, j in T.grid(1, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(1, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -187,7 +187,7 @@ def transform_fn(x, y): return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32] sch = tvm.tir.Schedule(mod, debug_mask="all") - block_rv = sch.get_block("T_add") + block_rv = sch.get_sblock("T_add") sch.cache_write(block_rv, 0, "global") sch.transform_layout(block_rv, ("write", 0), transform_fn, pad_value=0.0) return sch.mod @@ -200,9 +200,9 @@ def before( T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"), ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(33), T.int64(128)): - with T.block("T_add"): + with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1]) T.writes(T_add[v_ax0, v_ax1]) @@ -210,16 +210,16 @@ def before( def expected(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2]) for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)): - with T.block("T_add"): + with T.sblock("T_add"): v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1, axis2]) T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1], p1[v_axis0 * T.int64(32) + v_axis2, v_axis1]) T.writes(T_add_global[v_axis0, v_axis1, v_axis2]) T_add_global[v_axis0, v_axis1, v_axis2] = T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0), p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) + v_axis2, v_axis1]) for ax0, ax1 in T.grid(T.int64(33), T.int64(128)): - with T.block("T_add_global"): + with T.sblock("T_add_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)]) T.writes(T_add[v0, v1]) @@ -241,7 +241,7 @@ def test_two_elementwise_transform_intermediate_buffer(use_block_name): index_map=packed_index_map_func, ) else: - block = sch.get_block("B") + block = sch.get_sblock("B") sch.transform_layout(block, ("write", 0), packed_index_map_func) assert_structural_equal_ignore_global_symbol( @@ -252,7 +252,7 @@ def test_two_elementwise_transform_intermediate_buffer(use_block_name): def test_transform_layout_with_sampling(): sch = tir.Schedule(two_elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") loop = sch.get_loops(block_b)[-1] j0, j1, j2 = sch.sample_perfect_tile(loop, 3, decision=[4, 8, 4]) sch.transform_layout(block_b, ("write", 0), lambda i, j: (i, j // (j1 * j2), j % (j1 * j2))) @@ -269,7 +269,7 @@ def test_two_elementwise_transform_input_buffer(use_block_name): buffer="A", ) else: - block = sch.get_block("B") + block = sch.get_sblock("B") sch.transform_layout(block, ("read", 0), packed_index_map_func) assert_structural_equal_ignore_global_symbol( @@ -288,7 +288,7 @@ def test_two_elementwise_transform_output_buffer(use_block_name): buffer="C", ) else: - block = sch.get_block("C") + block = sch.get_sblock("C") sch.transform_layout(block, ("write", 0), packed_index_map_func) assert_structural_equal_ignore_global_symbol( @@ -308,7 +308,7 @@ def test_two_elementwise_unit_dim(use_block_name): buffer="B", ) else: - block = sch.get_block("B") + block = sch.get_sblock("B") sch.transform_layout(block, ("write", 0), index_map) assert_structural_equal_ignore_global_symbol(two_elementwise_unit_dim, sch.mod["main"]) @@ -318,7 +318,7 @@ def test_two_elementwise_unit_dim(use_block_name): def test_simplify(): sch = tir.Schedule(two_elementwise, debug_mask="all") - i, j = sch.get_loops(sch.get_block("C")) + i, j = sch.get_loops(sch.get_sblock("C")) i, i_inner = sch.split(i, factors=[None, 16]) j, j_inner = sch.split(j, factors=[None, 16]) @@ -337,12 +337,12 @@ def test_simplify(): @T.prim_func def ref(B: T.Buffer((8, 8, 16, 16), "float32"), C: T.Buffer((128, 128), "float32")): for i_0, j_0 in T.grid(8, 8): - with T.block("C_o"): + with T.sblock("C_o"): vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) T.reads(B[vi_o, vj_o, 0:16, 0:16]) T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) for i_1, j_1 in T.grid(16, 16): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i_1, j_1]) T.reads(B[vi_o, vj_o, vi, vj]) T.writes(C[vi_o * 16 + vi, vj_o * 16 + vj]) @@ -363,7 +363,7 @@ def summation_3d( ) -> None: B[0] = 0 for i, j, k in T.grid(1024, 1024, 32): - with T.block("compute"): + with T.sblock("compute"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[0] = B[0] + A[vi, vj, vk] @@ -373,7 +373,7 @@ def summation_3d_split( ) -> None: B[0] = 0 for i, j, k in T.grid(1024, 1024, 32): - with T.block("compute"): + with T.sblock("compute"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[0] = B[0] + A[vi, vj, vk // 4, vk % 4] @@ -386,7 +386,7 @@ def summation_3d_split( def test_transform_block_layout_basic(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") + block = "B" if use_block_name else sch.get_sblock("B") sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) assert_structural_equal_ignore_global_symbol(elementwise_transformed, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) @@ -394,7 +394,7 @@ def test_transform_block_layout_basic(use_block_name): def test_transform_block_layout_conv2d_nhwc(use_block_name): sch = tir.Schedule(conv2d_nhwc, debug_mask="all") - block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_sblock("conv2d_nhwc") sch.transform_block_layout( block, lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh * 7 * 3 + rw * 3 + rc), @@ -405,7 +405,7 @@ def test_transform_block_layout_conv2d_nhwc(use_block_name): def test_transform_block_layout_unit_dim(use_block_name): sch = tir.Schedule(two_elementwise_unit_dim, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") + block = "B" if use_block_name else sch.get_sblock("B") sch.transform_block_layout(block, lambda i, j: (j, i)) @T.prim_func @@ -414,11 +414,11 @@ def two_elementwise_unit_dim_transformed( ) -> None: B = T.alloc_buffer((1, 128), "float32") for j, i in T.grid(128, 1): - with T.block("B"): + with T.sblock("B"): vj, vi = T.axis.remap("SS", [j, i]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(1, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -430,14 +430,14 @@ def two_elementwise_unit_dim_transformed( def test_transform_block_layout_fail_non_affine(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") + block = "B" if use_block_name else sch.get_sblock("B") with pytest.raises(tir.ScheduleError): sch.transform_block_layout(block, lambda i, j: (i + j,)) def test_transform_block_layout_fail_mixed_iter_type(use_block_name): sch = tir.Schedule(conv2d_nhwc, debug_mask="all") - block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_sblock("conv2d_nhwc") with pytest.raises(tir.ScheduleError): sch.transform_block_layout( block, @@ -452,7 +452,7 @@ def elementwise_int64_extent( B: T.Buffer((T.int64(128), T.int64(128)), "float32"), ) -> None: for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -462,14 +462,14 @@ def elementwise_int64_extent_transformed( B: T.Buffer((T.int64(128), T.int64(128)), "float32"), ) -> None: for i in range(T.int64(16384)): - with T.block("B"): + with T.sblock("B"): vi = T.axis.remap("S", [i]) B[vi // T.int64(128), vi % T.int64(128)] = ( A[vi // T.int64(128), vi % T.int64(128)] * 2.0 ) sch = tir.Schedule(elementwise_int64_extent, debug_mask="all") - block = "B" if use_block_name else sch.get_block("B") + block = "B" if use_block_name else sch.get_sblock("B") sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) assert_structural_equal_ignore_global_symbol( elementwise_int64_extent_transformed, sch.mod["main"] @@ -510,14 +510,14 @@ class TestNoPadding(BasePaddingCompare): def before(): A = T.alloc_buffer(16, "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi // 4, vi % 4] = 0 @@ -535,26 +535,26 @@ class TestNoPaddingMultipleUsage(BasePaddingCompare): def before(): A = T.alloc_buffer(16, "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 B = T.alloc_buffer(16, "int32") for i in T.serial(16): - with T.block("other"): + with T.sblock("other"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi // 4, vi % 4] = 0 B = T.alloc_buffer(16, "int32") for i in T.serial(16): - with T.block("other"): + with T.sblock("other"): vi = T.axis.remap("S", [i]) B[vi] = A[vi // 4, vi % 4] @@ -570,13 +570,13 @@ class TestNoPaddingOpaqueBlock(BasePaddingCompare): def before(): A = T.alloc_buffer(16, "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): A[i] = 0 def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): A[i // 4, i % 4] = 0 @@ -586,7 +586,7 @@ class TestErrorIfPaddingForbidden(BasePaddingCompare): def before(): A = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 @@ -603,14 +603,14 @@ class TestImplicitPaddingAssumeInjective(BasePaddingCompare): def before(): A = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 def expected(): A = T.alloc_buffer([4, 4], "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi // 4, vi % 4] = 0 @@ -623,7 +623,7 @@ class TestErrorOnWrongPaddingType(BasePaddingCompare): def before(): A = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 @@ -638,7 +638,7 @@ class TestErrorOnNonMatchingTypes(BasePaddingCompare): def before(): A = T.alloc_buffer(14, "float32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 @@ -666,7 +666,7 @@ def before(self, dtype): def func(A: T.Buffer(14, dtype)): B = T.alloc_buffer(14, dtype) for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] @@ -680,7 +680,7 @@ def expected(self, dtype, pad_value): def func(A: T.Buffer(14, dtype)): B = T.alloc_buffer([4, 4], dtype) for i, j in T.grid(4, 4): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.if_then_else( vi == 3 and 2 <= vj, pad_value, A[vi * 4 + vj], dtype=dtype @@ -699,18 +699,18 @@ class TestPaddedTransformWithoutLoop(BasePaddingCompare): pad_value = tvm.testing.parameter(0) def before(A: T.Buffer(14, "int32")): - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block("block"): + with T.sblock("block"): A[0] = 0 def expected(A: T.Buffer((4, 4), "int32")): - with T.block("block"): + with T.sblock("block"): A[0, 0] = 0 for i, j in T.grid(4, 4): - with T.block("buffer_A_padding"): + with T.sblock("buffer_A_padding"): vi, vj = T.axis.remap("SS", [i, j]) T.where(i == 3 and 2 <= j) A[vi, vj] = 0 @@ -725,7 +725,7 @@ class TestPaddedTransformIfThenElseReduction(BasePaddingCompare): def before(A: T.Buffer((14, 32), "int32")): B = T.alloc_buffer(14, "int32") for i, k in T.grid(14, 32): - with T.block("block"): + with T.sblock("block"): vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0 @@ -734,7 +734,7 @@ def before(A: T.Buffer((14, 32), "int32")): def expected(A: T.Buffer((14, 32), "int32")): B = T.alloc_buffer([4, 4], "int32") for i, j, k in T.grid(4, 4, 32): - with T.block("block"): + with T.sblock("block"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32") @@ -754,7 +754,7 @@ def before(A: T.Buffer((14, 32), "int32")): for i in T.serial(14): B[i] = 0 for k in T.serial(32): - with T.block("block"): + with T.sblock("block"): B[i] = B[i] + A[i, k] def expected(A: T.Buffer((14, 32), "int32")): @@ -762,7 +762,7 @@ def expected(A: T.Buffer((14, 32), "int32")): for i, j in T.grid(4, 4): B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") for k in T.serial(32): - with T.block("block"): + with T.sblock("block"): B[i, j] = T.if_then_else( i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" ) @@ -782,7 +782,7 @@ def before(A: T.Buffer(14, "int32")): B = T.alloc_buffer(14, "int32") C = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] C[vi] = 0 @@ -791,13 +791,13 @@ def expected(A: T.Buffer(14, "int32")): B = T.alloc_buffer([4, 4], "int32") C = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi // 4, vi % 4] = A[vi] C[vi] = 0 for i, j in T.grid(4, 4): - with T.block("block_pad_B"): + with T.sblock("block_pad_B"): vi, vj = T.axis.remap("SS", [i, j]) T.where(i == 3 and 2 <= j) B[vi, vj] = 0 @@ -810,18 +810,18 @@ class TestPaddedTransformOfInputCreatesAssumption(BasePaddingCompare): def before(A: T.Buffer(14, "int32"), B: T.Buffer(14, "int32")): for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] def expected(A: T.Buffer((4, 4), "int32"), B: T.Buffer(14, "int32")): for i, j in T.grid(4, 4): - with T.block("buffer_A_assumption"): + with T.sblock("buffer_A_assumption"): vi, vj = T.axis.remap("SS", [i, j]) T.evaluate(T.assume(not (vi == 3 and 2 <= vj) or A[vi, vj] == 42)) for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi // 4, vi % 4] @@ -850,14 +850,14 @@ def transform(mod): def before(A: T.Buffer(14, "int32")): B = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] def expected(A: T.Buffer(14, "int32")): B = T.alloc_buffer([4, 4], "int32") for i, j in T.grid(4, 4): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.if_then_else( vi == 3 and 2 <= vj, vi + vj, A[vi * 4 + vj], dtype="int32" @@ -879,7 +879,7 @@ def transform(self): def transform(mod): sch = tir.Schedule(mod) - A = sch.get(sch.get_block("block")).reads[0].buffer + A = sch.get(sch.get_sblock("block")).reads[0].buffer sch.transform_layout( "block", "A", @@ -893,13 +893,13 @@ def transform(mod): def before(A: T.Buffer(14, "int32")): B = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] def expected(A: T.Buffer((4, 4), "int32")): for i, j in T.grid(4, 4): - with T.block("buffer_A_assumption"): + with T.sblock("buffer_A_assumption"): vi, vj = T.axis.remap("SS", [i, j]) T.evaluate( T.assume( @@ -910,7 +910,7 @@ def expected(A: T.Buffer((4, 4), "int32")): B = T.alloc_buffer(14, "int32") for i in T.grid(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi // 4, vi % 4] @@ -927,7 +927,7 @@ def transform(self): def transform(mod): sch = tir.Schedule(mod) - A = sch.get(sch.get_block("block")).reads[0].buffer + A = sch.get(sch.get_sblock("block")).reads[0].buffer other = tir.decl_buffer(1, A.dtype, name="other") sch.transform_layout( "block", @@ -942,7 +942,7 @@ def transform(mod): def before(A: T.Buffer(14, "int32")): B = T.alloc_buffer(14, "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] @@ -972,14 +972,14 @@ def transform(mod): def before(A: T.Buffer(16, "int32"), n: T.int32): B = T.alloc_buffer(16, "int32") for i in T.serial(16): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] def expected(A: T.Buffer(16, "int32"), n: T.int32): B = T.alloc_buffer([(-16 % n + 16) // n, n], dtype="int32") for i, j in T.grid((-16 % n + 16) // n, n): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.if_then_else( # Checks if the transform introduced padding @@ -1004,14 +1004,14 @@ class TestTransformWithAxisSeparators(BasePaddingCompare): def before(a: T.handle): A = T.match_buffer(a, [14], "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 42 def expected(a: T.handle): A = T.match_buffer(a, [4, 4], "int32", axis_separators=[1]) for i, j in T.grid(4, 4): - with T.block("block"): + with T.sblock("block"): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 42, dtype="int32") @@ -1025,13 +1025,13 @@ class TestTransformWithAxisSeparatorsOpaqueBlock(BasePaddingCompare): def before(a: T.handle): A = T.match_buffer(a, [14], "int32") for i in T.serial(14): - with T.block("block"): + with T.sblock("block"): A[i] = 42 def expected(a: T.handle): A = T.match_buffer(a, [4, 4], "int32", axis_separators=[1]) for i, j in T.grid(4, 4): - with T.block("block"): + with T.sblock("block"): A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32") @@ -1041,7 +1041,7 @@ def test_index_map_dtype_legalize(): @T.prim_func def func(A: T.Buffer(T.int64(58), "int32")): for i in T.serial(T.int64(58)): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) T.writes(A[vi]) A[vi] = 0 @@ -1052,7 +1052,7 @@ def func(A: T.Buffer(T.int64(58), "int32")): # # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) : # # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) sch.transform_layout( - sch.get_block("block"), buffer="A", index_map=lambda h: [h // 8, h % 8], pad_value=0 + sch.get_sblock("block"), buffer="A", index_map=lambda h: [h // 8, h % 8], pad_value=0 ) @@ -1065,7 +1065,7 @@ def test_index_map_dtype_legalize_with_constant(): @T.prim_func def func(A: T.Buffer(T.int64(16), "int32")): for i in T.grid(T.int64(16)): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) A[vi] = 0 @@ -1103,7 +1103,7 @@ def before(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") C = T.match_buffer(c, (T.int64(1), T.int64(32), T.int64(1), n), "float16") for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k]) T.writes(C[v_i0, v_i1, v_i2, v_i3]) @@ -1119,7 +1119,7 @@ def after(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") C = T.match_buffer(c, (n * T.int64(32),), "float16") for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k]) T.writes(C[v_i1 * n + v_i3]) @@ -1131,7 +1131,7 @@ def after(a: T.handle, b: T.handle, c: T.handle): # pylint: disable=invalid-name _, _, n, _ = before.buffer_map[before.params[1]].shape sch = tvm.tir.Schedule(before) - block = sch.get_block("NT_matmul") + block = sch.get_sblock("NT_matmul") sch.transform_layout( block, ("write", 0), @@ -1153,7 +1153,7 @@ def before(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") C = T.match_buffer(c, (n * T.int64(32),), "float16") for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k]) T.writes(C[v_i1 * n + v_i3]) @@ -1169,7 +1169,7 @@ def after(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") C = T.match_buffer(c, (n * T.int64(32),), "float16") for ax0, ax1 in T.grid(n * T.int64(32), T.int64(128)): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): v0, v1 = T.axis.remap("SR", [ax0, ax1]) T.reads(A[T.int64(0), v0 // n, T.int64(0), v1], B[T.int64(0), v0 // n, v0 % n, v1]) T.writes(C[v0]) @@ -1181,7 +1181,7 @@ def after(a: T.handle, b: T.handle, c: T.handle): # pylint: disable=invalid-name _, _, n, _ = before.buffer_map[before.params[1]].shape sch = tvm.tir.Schedule(before) - block = sch.get_block("NT_matmul") + block = sch.get_sblock("NT_matmul") sch.transform_block_layout( block, lambda x, y, z, w, k: ( diff --git a/tests/python/tir-schedule/test_tir_schedule_utilities.py b/tests/python/tir-schedule/test_tir_schedule_utilities.py index 0ad05ea83288..de053be72881 100644 --- a/tests/python/tir-schedule/test_tir_schedule_utilities.py +++ b/tests/python/tir-schedule/test_tir_schedule_utilities.py @@ -37,11 +37,11 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for k in range(0, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -53,13 +53,13 @@ def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None: C = T.alloc_buffer((1024, 1024)) D = T.match_buffer(d, (1024, 1024)) for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(1024, 1024): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -73,13 +73,13 @@ def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None: for i in T.serial(0, 1024, annotations={"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}}): for j in T.serial(0, 1024, annotations={"test2": 612, "test3": ["aa", 1]}): for k in T.serial(0, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(1024, 1024): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -91,16 +91,16 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None: C = T.alloc_buffer((1024, 1024)) D = T.match_buffer(d, (1024, 1024)) for i, j, k in T.grid(1024, 1024, 1024): - with T.block("matmul"): + with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 - T.block_attr({"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}}) + T.sblock_attr({"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}}) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(1024, 1024): - with T.block("relu"): + with T.sblock("relu"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"test2": 0.22, "test3": ["aa", 1]}) + T.sblock_attr({"test2": 0.22, "test3": ["aa", 1]}) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -112,7 +112,7 @@ def vector_add( B: T.Buffer(128, "float32"), ) -> None: for i in range(128): - with T.block("init"): + with T.sblock("init"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] @@ -122,7 +122,7 @@ def vector_add_2( B: T.Buffer(128, "float32"), ) -> None: for i in range(128): - with T.block("init"): + with T.sblock("init"): vi = T.axis.remap("S", [i]) B[vi] = A[vi] @@ -132,13 +132,13 @@ def tuple_reduction(data: T.Buffer((4, 32), "float32"), T_add: T.Buffer((4,), "f # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() data_red_temp_v0 = T.alloc_buffer([4], dtype="float32") data_red_temp_v1 = T.alloc_buffer([4], dtype="float32") for i0, i1 in T.grid(4, 32): - with T.block("data_red_temp"): + with T.sblock("data_red_temp"): ax0, k1 = T.axis.remap("SR", [i0, i1]) T.reads(data[ax0, k1]) T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0]) @@ -152,7 +152,7 @@ def tuple_reduction(data: T.Buffer((4, 32), "float32"), T_add: T.Buffer((4,), "f data_red_temp_v0[ax0] = v_data_red_temp_v0 data_red_temp_v1[ax0] = v_data_red_temp_v1 for i0 in range(4): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.remap("S", [i0]) T.reads(data_red_temp_v0[ax0], data_red_temp_v1[ax0]) T.writes(T_add[ax0]) @@ -175,13 +175,13 @@ def test_tir_schedule_creation(): assert sch_1.state.mod["main"].same_as(sch_2.state.mod["main"]) -def test_tir_schedule_get_block(): +def test_tir_schedule_get_sblock(): # Tests: - # - Schedule.get_block + # - Schedule.get_sblock # - Schedule.get_sref # - Schedule.get sch = tir.Schedule(matmul, debug_mask="all") - block_rv = sch.get_block(name="update") + block_rv = sch.get_sblock(name="update") block_sref = sch.get_sref(block_rv) block = sch.get(block_rv) assert block.name_hint == "update" @@ -193,9 +193,9 @@ def test_tir_schedule_get_block(): def test_tir_schedule_work_on(): sch = tir.Schedule(ModuleWithMultipleFuncs, debug_mask="all") with pytest.raises(ValueError, match="does not know which function to be working on"): - sch.get_block(name="init") + sch.get_sblock(name="init") sch.work_on(func_name="vector_add") - sch.get_block(name="init") + sch.get_sblock(name="init") assert sch.func_working_on == sch.mod.get_global_var("vector_add") @@ -204,7 +204,7 @@ def test_tir_schedule_get_loops(use_block_name): # - Schedule.get_loops # - Schedule.get sch = tir.Schedule(matmul, debug_mask="all") - block = "update" if use_block_name else sch.get_block(name="update") + block = "update" if use_block_name else sch.get_sblock(name="update") i, j, k = sch.get_loops(block) assert sch.get(i).loop_var.name == "i" assert sch.get(j).loop_var.name == "j" @@ -215,7 +215,7 @@ def test_tir_schedule_copy_1(use_block_name): # Tests: # - Schedule.copy sch_1 = tir.Schedule(matmul, debug_mask="all") - block_rv = sch_1.get_block(name="update") + block_rv = sch_1.get_sblock(name="update") i, j, k = sch_1.get_loops(block="update" if use_block_name else block_rv) assert sch_1.get(i).loop_var.name == "i" assert sch_1.get(j).loop_var.name == "j" @@ -230,7 +230,7 @@ def test_tir_schedule_copy_1(use_block_name): def test_tir_schedule_copy_2(): sch = tir.Schedule(mod=matmul, debug_mask="all") - i, j, k = sch.get_loops(sch.get_block("update")) + i, j, k = sch.get_loops(sch.get_sblock("update")) sch_copy = sch.copy() assert not sch.get_sref(i).same_as(sch_copy.get_sref(i)) assert not sch.get_sref(j).same_as(sch_copy.get_sref(j)) @@ -262,7 +262,7 @@ def test_tir_schedule_remove_rv(): # Tests: # - Schedule.remove_rv sch = tir.Schedule(matmul, debug_mask="all") - block_rv = sch.get_block(name="update") + block_rv = sch.get_sblock(name="update") assert sch.get(block_rv).name_hint == "update" sch.remove_rv(block_rv) with pytest.raises(IndexError): @@ -271,15 +271,15 @@ def test_tir_schedule_remove_rv(): def test_get_child_blocks(): s = tir.Schedule(matmul, debug_mask="all") - init = s.get_block("init") - update = s.get_block("update") + init = s.get_sblock("init") + update = s.get_sblock("update") # loop blocks = s.get_child_blocks(s.get_loops(init)[0]) assert len(blocks) == 2 assert s.get(init) == s.get(blocks[0]) assert s.get(update) == s.get(blocks[1]) # block - root = s.get_block("root") + root = s.get_sblock("root") blocks = s.get_child_blocks(root) assert len(blocks) == 2 assert s.get(init) == s.get(blocks[0]) @@ -288,50 +288,50 @@ def test_get_child_blocks(): def test_get_producers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - block = "relu" if use_block_name else sch.get_block("relu") + block = "relu" if use_block_name else sch.get_sblock("relu") (producer,) = sch.get_producers(block) tvm.ir.assert_structural_equal( sch.get_sref(producer).stmt, - sch.get_sref(sch.get_block("matmul")).stmt, + sch.get_sref(sch.get_sblock("matmul")).stmt, ) verify_trace_roundtrip(sch, mod=matmul_relu) def test_get_producers_multiple_buffer_depdencies(use_block_name): sch = tir.Schedule(mod=tuple_reduction, debug_mask="all") - block = "T_add" if use_block_name else sch.get_block("T_add") + block = "T_add" if use_block_name else sch.get_sblock("T_add") (producer,) = sch.get_producers(block) tvm.ir.assert_structural_equal( sch.get_sref(producer).stmt, - sch.get_sref(sch.get_block("data_red_temp")).stmt, + sch.get_sref(sch.get_sblock("data_red_temp")).stmt, ) def test_get_consumers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - block = "matmul" if use_block_name else sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_sblock("matmul") (consumer,) = sch.get_consumers(block) tvm.ir.assert_structural_equal( sch.get_sref(consumer).stmt, - sch.get_sref(sch.get_block("relu")).stmt, + sch.get_sref(sch.get_sblock("relu")).stmt, ) verify_trace_roundtrip(sch, mod=matmul_relu) def test_get_consumers_multiple_buffer_depdencies(use_block_name): sch = tir.Schedule(mod=tuple_reduction, debug_mask="all") - block = "data_red_temp" if use_block_name else sch.get_block("data_red_temp") + block = "data_red_temp" if use_block_name else sch.get_sblock("data_red_temp") (consumer,) = sch.get_consumers(block) tvm.ir.assert_structural_equal( sch.get_sref(consumer).stmt, - sch.get_sref(sch.get_block("T_add")).stmt, + sch.get_sref(sch.get_sblock("T_add")).stmt, ) def test_annotate_unannotate_loop(): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - matmul = sch.get_block("matmul") - relu = sch.get_block("relu") + matmul = sch.get_sblock("matmul") + relu = sch.get_sblock("relu") sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa") sch.annotate(sch.get_loops(matmul)[1], "test2", 612) sch.annotate(sch.get_loops(matmul)[1], "test3", ["aa", 1]) @@ -347,8 +347,8 @@ def test_annotate_unannotate_loop(): def test_annotate_unannotate_block(): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - matmul = sch.get_block("matmul") - relu = sch.get_block("relu") + matmul = sch.get_sblock("matmul") + relu = sch.get_sblock("relu") sch.annotate(matmul, "test1", "aaa") sch.annotate(relu, "test2", 0.22) sch.annotate(relu, "test3", ["aa", 1]) @@ -368,7 +368,7 @@ def test_get_output_blocks_single_output(): assert len(output_blocks) == 1, "Unexpected number of blocks when 1 was expected" block = sch.get(output_blocks[0]) assert block.name_hint == "relu" - relu_block = sch.get_block("relu") + relu_block = sch.get_sblock("relu") assert sch.get(relu_block).same_as(block) @@ -380,9 +380,9 @@ def test_get_output_blocks_multiple_outputs(): assert block_1.name_hint == "init" block_2 = sch.get(output_blocks[1]) assert block_2.name_hint == "update" - init_block = sch.get_block("init") + init_block = sch.get_sblock("init") assert sch.get(init_block).same_as(block_1) - update_block = sch.get_block("update") + update_block = sch.get_sblock("update") assert sch.get(update_block).same_as(block_2) @@ -392,11 +392,11 @@ def blockized( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), ) -> None: - with T.block("blockized_B"): + with T.sblock("blockized_B"): vio = T.axis.spatial(1, 0) vjo = T.axis.spatial(1, 0) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 @@ -407,9 +407,9 @@ def blockized( assert block_1.name_hint == "blockized_B" block_2 = sch.get(output_blocks[1]) assert block_2.name_hint == "B" - blockized_block = sch.get_block("blockized_B") + blockized_block = sch.get_sblock("blockized_B") assert sch.get(blockized_block).same_as(block_1) - b_block = sch.get_block("B") + b_block = sch.get_sblock("B") assert sch.get(b_block).same_as(block_2) sch = tir.Schedule(mod=blockized, debug_mask="all") @@ -417,7 +417,7 @@ def blockized( assert len(output_blocks) == 1, "Unexpected number of blocks when 1 were expected" block = sch.get(output_blocks[0]) assert block.name_hint == "B" - b_block = sch.get_block("B") + b_block = sch.get_sblock("B") assert sch.get(b_block).same_as(block) diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py index 2edf74ebfb3d..ce5ad84f2dc1 100644 --- a/tests/python/tir-transform/test_tir_inline_private_functions.py +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -141,7 +141,7 @@ def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): A = T.decl_buffer(16, "float32", data=A_data) B = T.decl_buffer(16, "float32", data=B_data) for i in range(16): - with T.block("scalar_mul"): + with T.sblock("scalar_mul"): B[i] = A[i] * 2.0 @I.ir_module @@ -153,7 +153,7 @@ def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): B_data_1: T.handle("float32") = T.address_of(B[0, 0]) B_1 = T.decl_buffer(16, "float32", data=B_data_1) for i in range(16): - with T.block("scalar_mul_1"): + with T.sblock("scalar_mul_1"): B_1[i] = A_1[i] * 2.0 with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: @@ -161,7 +161,7 @@ def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): B_data_2: T.handle("float32") = T.address_of(B[1, 0]) B_2 = T.decl_buffer(16, "float32", data=B_data_2) for i in range(16): - with T.block("scalar_mul_2"): + with T.sblock("scalar_mul_2"): B_2[i] = A_2[i] * 2.0 diff --git a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py index 006ebf6a1a0d..a7f93f83a4d3 100644 --- a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py +++ b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py @@ -79,17 +79,17 @@ def before(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -99,17 +99,17 @@ def expected(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), "float32") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[0, j]) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] * 2.0 @@ -121,7 +121,7 @@ def before(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") @@ -140,11 +140,11 @@ def before(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (20, 20), "float32") B = T.match_buffer(c, (20, 20), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(B[i, 0:16]) for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -160,17 +160,17 @@ def before(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="shared") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -182,17 +182,17 @@ def expected(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((8, 16), "float32", scope="shared") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i1 * 4 + i2, j]) B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0 @@ -206,17 +206,17 @@ def before(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="warp") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -228,17 +228,17 @@ def expected(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((4, 16), "float32", scope="warp") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i2, j]) B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0 @@ -250,17 +250,17 @@ def before(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block(): + with T.sblock(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((n * 8,), "float32") for j in range(0, 8): - with T.block(): + with T.sblock(): T.reads(A[i * 8 + j]) T.writes(B[i * 8 + j]) B[i * 8 + j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block(): + with T.sblock(): T.reads(B[i * 8 + j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[i * 8 + j] * 2.0 @@ -270,17 +270,17 @@ def expected(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block(): + with T.sblock(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((8,), "float32") for j in range(0, 8): - with T.block(): + with T.sblock(): T.reads(A[i * 8 + j]) T.writes(B[j]) B[j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block(): + with T.sblock(): T.reads(B[j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[j] * 2.0 @@ -292,12 +292,12 @@ def before(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block(): + with T.sblock(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((8, 8), "float32") for j in range(0, 4): - with T.block(): + with T.sblock(): D = T.alloc_buffer((8, 8), "float32") T.reads(A[i, j]) T.writes(B[i, j]) @@ -306,12 +306,12 @@ def before(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): B[i, j] = A[i, j] + D[k, j] for j in range(3, 5): - with T.block(): + with T.sblock(): T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] for j in range(6, 8): - with T.block(): + with T.sblock(): T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] @@ -321,12 +321,12 @@ def expected(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block(): + with T.sblock(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((1, 8), "float32") for j in range(0, 4): - with T.block(): + with T.sblock(): D = T.alloc_buffer((6, 1), "float32") T.reads(A[i, j]) T.writes(B[0, j]) @@ -335,12 +335,12 @@ def expected(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): B[0, j] = A[i, j] + D[k - 2, 0] for j in range(3, 5): - with T.block(): + with T.sblock(): T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] for j in range(6, 8): - with T.block(): + with T.sblock(): T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] @@ -354,19 +354,19 @@ def before(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block(): + with T.sblock(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((16, 16)) - with T.block(): + with T.sblock(): B0 = T.match_buffer(B[i, 0:16], (16)) for j in range(0, 16): - with T.block(): + with T.sblock(): A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[i, j], ()) C1[()] = B2[()] * 2.0 @@ -376,19 +376,19 @@ def expected(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block(): + with T.sblock(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((1, 16)) - with T.block(): + with T.sblock(): B0 = T.match_buffer(B[0, 0:16], (16)) for j in range(0, 16): - with T.block(): + with T.sblock(): A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[0, j], ()) C1[()] = B2[()] * 2.0 @@ -400,18 +400,18 @@ def before(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[i, j]) - T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -421,18 +421,18 @@ def expected(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[0, j]) - T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] * 2.0 @@ -443,13 +443,13 @@ class TestPaddingPattern(BaseCompactTest): def before(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (20, 20), "float32") - with T.block(): + with T.sblock(): B = T.alloc_buffer((20, 20), dtype="float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): B[i, j] = A[i, j] for i, j in T.grid(20, 20): - with T.block(): + with T.sblock(): C[i, j] = T.if_then_else( 2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], @@ -461,13 +461,13 @@ def before(a: T.handle, c: T.handle) -> None: def expected(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [16, 16], dtype="float32") C = T.match_buffer(c, [20, 20], dtype="float32") - with T.block(): + with T.sblock(): B = T.alloc_buffer([16, 16], dtype="float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): B[i, j] = A[i, j] for i, j in T.grid(20, 20): - with T.block(): + with T.sblock(): C[i, j] = T.if_then_else( 2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], @@ -483,10 +483,10 @@ def before(a: T.handle, b: T.handle) -> None: Y = T.match_buffer(b, [224, 224], dtype="float32") cache = T.alloc_buffer([224, 224], dtype="float32") for h, w in T.grid(224, 224): - with T.block("cache"): + with T.sblock("cache"): cache[h, w] = X[h, w] for h, w, kh, kw in T.grid(224, 224, 3, 3): - with T.block("compute"): + with T.sblock("compute"): Y[h, w] = T.max( Y[h, w], T.if_then_else( @@ -504,10 +504,10 @@ def before(a: T.handle, b: T.handle) -> None: def expected(X: T.Buffer((224, 224), "float32"), Y: T.Buffer((224, 224), "float32")) -> None: cache = T.alloc_buffer([224, 224], dtype="float32") for h, w in T.grid(224, 224): - with T.block("cache"): + with T.sblock("cache"): cache[h, w] = X[h, w] for h, w, kh, kw in T.grid(224, 224, 3, 3): - with T.block("compute"): + with T.sblock("compute"): Y[h, w] = T.max( Y[h, w], T.if_then_else( @@ -526,21 +526,21 @@ class TestMemAccessInBranch(BaseCompactTest): @T.prim_func def before(a: T.handle) -> None: A = T.match_buffer(a, (224, 224), "float32") - with T.block(): + with T.sblock(): B1 = T.alloc_buffer((224, 224), dtype="float32") B2 = T.alloc_buffer((224, 224), dtype="float32") B3 = T.alloc_buffer((224, 224), dtype="float32") B4 = T.alloc_buffer((224, 224), dtype="float32") for i in range(0, 224): for j in range(0, 224): - with T.block(): + with T.sblock(): if i < 112 and j < 112: B1[i, j] = A[i, j] * 2.0 else: B2[i, j] = A[i, j] + 3.0 for i in range(0, 224): for j in range(0, 224): - with T.block(): + with T.sblock(): if i < 112 or j < 112: B3[i, j] = A[i, j] * 2.0 else: @@ -549,19 +549,19 @@ def before(a: T.handle) -> None: @T.prim_func def expected(a: T.handle) -> None: A = T.match_buffer(a, [224, 224], dtype="float32") - with T.block(): + with T.sblock(): B1 = T.alloc_buffer([112, 112], dtype="float32") B2 = T.alloc_buffer([224, 224], dtype="float32") B3 = T.alloc_buffer([224, 224], dtype="float32") B4 = T.alloc_buffer([112, 112], dtype="float32") for i, j in T.grid(224, 224): - with T.block(): + with T.sblock(): if i < 112 and j < 112: B1[i, j] = A[i, j] * 2.0 else: B2[i, j] = A[i, j] + 3.0 for i, j in T.grid(224, 224): - with T.block(): + with T.sblock(): if i < 112 or j < 112: B3[i, j] = A[i, j] * 2.0 else: @@ -574,11 +574,11 @@ class TestAnnotatedOpaqueAccess(BaseCompactTest): @T.prim_func def before(a: T.handle) -> None: A = T.match_buffer(a, (1024,), "float32") - with T.block(): + with T.sblock(): B = T.alloc_buffer((1024,), dtype="float32") C = T.alloc_buffer((1024,), dtype="float32") for i in range(0, 512): - with T.block(): + with T.sblock(): # no annotation, opaque access will cover full region T.reads([]) T.writes([]) @@ -586,7 +586,7 @@ def before(a: T.handle) -> None: T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32") ) B[i] = A[i] - with T.block(): + with T.sblock(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) @@ -599,11 +599,11 @@ def before(a: T.handle) -> None: @T.prim_func def expected(a: T.handle) -> None: A = T.match_buffer(a, (1024,), "float32") - with T.block(): + with T.sblock(): B = T.alloc_buffer((1024,), dtype="float32") C = T.alloc_buffer((520,), dtype="float32") for i in range(0, 512): - with T.block(): + with T.sblock(): # no annotation, opaque access will cover full region T.reads([]) T.writes([]) @@ -611,7 +611,7 @@ def expected(a: T.handle) -> None: T.call_extern("opaque_extern_function", A.data, B.data, dtype="int32") ) B[i] = A[i] - with T.block(): + with T.sblock(): # treat opaque access only access annotated regions, even if # they are not compatible with actual buffer accesses. T.reads([B[i]]) @@ -630,26 +630,26 @@ def before( A_indptr: T.Buffer((129,), "int32"), ) -> None: for i in T.serial(128): - with T.block("rowsum_outer"): + with T.sblock("rowsum_outer"): T.reads( A_indptr[i : i + 1], A_data[A_indptr[i] + 0 : A_indptr[i] + (A_indptr[i + 1] - A_indptr[i])], ) T.writes(B[i]) - with T.block("rowsum_init"): + with T.sblock("rowsum_init"): T.reads() T.writes(B[i]) B[i] = T.float32(0) for k in T.serial(A_indptr[i + 1] - A_indptr[i]): - with T.block(): + with T.sblock(): T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i]) T.writes(B[i]) A_data_local = T.alloc_buffer([819], dtype="float32", scope="local") - with T.block("A_data_cache_read"): + with T.sblock("A_data_cache_read"): T.reads(A_indptr[i], A_data[A_indptr[i] + k]) T.writes(A_data_local[A_indptr[i] + k]) A_data_local[A_indptr[i] + k] = A_data[A_indptr[i] + k] - with T.block("rowsum_inner"): + with T.sblock("rowsum_inner"): T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k]) T.writes(B[i]) B[i] = B[i] + A_data_local[A_indptr[i] + k] @@ -661,26 +661,26 @@ def expected( A_indptr: T.Buffer((129,), "int32"), ) -> None: for i in T.serial(128): - with T.block("rowsum_outer"): + with T.sblock("rowsum_outer"): T.reads( A_indptr[i : i + 1], A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 1] - A_indptr[i])], ) T.writes(B[i]) - with T.block("rowsum_init"): + with T.sblock("rowsum_init"): T.reads() T.writes(B[i]) B[i] = T.float32(0) for k in T.serial(A_indptr[i + 1] - A_indptr[i]): - with T.block(): + with T.sblock(): T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i]) T.writes(B[i]) A_data_local = T.alloc_buffer([1], dtype="float32", scope="local") - with T.block("A_data_cache_read"): + with T.sblock("A_data_cache_read"): T.reads(A_indptr[i], A_data[A_indptr[i] + k]) T.writes(A_data_local[T.min(A_indptr[i] + k, 0)]) A_data_local[T.min(A_indptr[i] + k, 0)] = A_data[A_indptr[i] + k] - with T.block("rowsum_inner"): + with T.sblock("rowsum_inner"): T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k]) T.writes(B[i]) B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)] @@ -724,7 +724,7 @@ def before(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None B_cache = T.alloc_buffer(10, "float32") for j in T.serial(3): for k in T.serial(4): - with T.block("B_cache"): + with T.sblock("B_cache"): T.where(j * 4 + k < 10) B_cache[j * 4 + k] = B[j] for i in T.serial(10): @@ -734,7 +734,7 @@ def before(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None def expected(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) -> None: B_cache = T.alloc_buffer([10], dtype="float32") for j, k in T.grid(3, 4): - with T.block("B_cache"): + with T.sblock("B_cache"): T.where(j * 4 + k < 10) T.reads(B[j]) T.writes(B_cache[j * 4 + k]) @@ -781,10 +781,10 @@ class TestSpatialTiledPadPooling(BaseCompactTest): @T.prim_func def before(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int32")) -> None: for h_o, w_o in T.grid(14, 14): - with T.block(): + with T.sblock(): X_cache = T.alloc_buffer([112, 112, 64], dtype="int32") for ax0, ax1, ax2 in T.grid(64, 9, 9): - with T.block("cache"): + with T.sblock("cache"): T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2) T.reads(X[ax0, h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2]) T.writes(X_cache[h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2, ax0]) @@ -792,7 +792,7 @@ def before(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int3 ax0, h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2 ] for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64): - with T.block("compute"): + with T.sblock("compute"): T.reads( X_cache[(h_o * 4 + h_i) * 2 + kh - 1, (w_o * 4 + w_i) * 2 + kw - 1, c] ) @@ -819,12 +819,12 @@ def before(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int3 @T.prim_func def expected(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "int32")) -> None: for h_o, w_o in T.grid(14, 14): - with T.block(): + with T.sblock(): T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8]) T.writes(Y[h_o * 4 : h_o * 4 + 4, w_o * 4 : w_o * 4 + 4, 0:64]) X_cache = T.alloc_buffer([9, 9, 64], dtype="int32") for ax0, ax1, ax2 in T.grid(64, 9, 9): - with T.block("cache"): + with T.sblock("cache"): T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2) T.reads(X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1]) T.writes( @@ -840,7 +840,7 @@ def expected(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), "in ax0, ] = X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1] for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64): - with T.block("compute"): + with T.sblock("compute"): T.reads( X_cache[ h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1, @@ -876,25 +876,25 @@ def before(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), "float32 for bx in T.thread_binding(144, thread="blockIdx.x"): for vx in T.thread_binding(2, thread="vthread.x"): for tx_p in T.thread_binding(256, thread="threadIdx.x"): - with T.block(): + with T.sblock(): for k_0 in T.serial(193): - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([960, 770], dtype="float32", scope="shared") B_shared = T.alloc_buffer([770, 2304], dtype="float32", scope="shared") for _u in T.serial(1): for tx in T.thread_binding(256, thread="threadIdx.x"): for vec in T.vectorized(3): - with T.block("A_shared"): + with T.sblock("A_shared"): T.where(bx // 18 * 128 + ((_u * 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 < 770 and (_u * 256 + tx) * 3 + vec < 512) A_shared[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] for _u in T.serial(1): for tx in T.thread_binding(256, thread="threadIdx.x"): for vec in T.vectorized(4): - with T.block("B_shared"): + with T.sblock("B_shared"): T.where(k_0 * 4 + ((_u * 256 + tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512) B_shared[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2): - with T.block("update_update"): + with T.sblock("update_update"): C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] @T.prim_func @@ -902,25 +902,25 @@ def expected(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), "float for bx in T.thread_binding(144, thread="blockIdx.x"): for vx in T.thread_binding(2, thread="vthread.x"): for tx_p in T.thread_binding(256, thread="threadIdx.x"): - with T.block(): + with T.sblock(): for k_0 in T.serial(193): - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([128, 4], dtype="float32", scope="shared") B_shared = T.alloc_buffer([4, 128], dtype="float32", scope="shared") for v_u in T.serial(1): for tx in T.thread_binding(256, thread="threadIdx.x"): for vec in T.vectorized(3): - with T.block("A_shared"): + with T.sblock("A_shared"): T.where(bx // 18 * 128 + (tx * 3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 512) A_shared[(tx * 3 + vec) // 4, (tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 + vec) % 4] for v_u in T.serial(1): for tx in T.thread_binding(256, thread="threadIdx.x"): for vec in T.vectorized(4): - with T.block("B_shared"): + with T.sblock("B_shared"): T.where(k_0 * 4 + tx // 32 < 770 and tx * 4 + vec < 512) B_shared[tx // 32, tx % 32 * 4 + vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec] for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2): - with T.block("update_update"): + with T.sblock("update_update"): C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4] # fmt: on @@ -932,20 +932,20 @@ class TestDependentBufferIndices(BaseCompactTest): def before(): """This is a diagnal buffer access pattern""" for i in range(8): - with T.block(): + with T.sblock(): A = T.alloc_buffer((256, 256), "float32") for j, k in T.grid(8, 8): - with T.block(): + with T.sblock(): T.where(j * 8 + k < 60) A[i * 64 + j * 8 + k, i * 64 + j * 8 + k] = 1.0 @T.prim_func def expected() -> None: for i in T.serial(8): - with T.block(): + with T.sblock(): A = T.alloc_buffer([60, 60], dtype="float32") for j, k in T.grid(8, 8): - with T.block(): + with T.sblock(): T.where(j * 8 + k < 60) A[j * 8 + k, j * 8 + k] = 1.0 @@ -960,18 +960,18 @@ def before( C: T.Buffer((1020, 1000), "float32"), ): for i0, i1 in T.grid(4, 1): - with T.block(): + with T.sblock(): C_local2 = T.alloc_buffer([4, 1, 16, 1000, 16], dtype="float32", scope="local") C_local1 = T.alloc_buffer([1020, 1000], dtype="float32", scope="local") for ax0, ax1, ax2 in T.grid(255, 1000, 64): - with T.block("matmul"): + with T.sblock("matmul"): if ax2 == 0: C_local1[i0 * 255 + ax0, ax1] = 0 C_local1[i0 * 255 + ax0, ax1] = ( C_local1[i0 * 255 + ax0, ax1] + A[i0 * 255 + ax0, ax2] * B[ax1, ax2] ) for ax0, ax1 in T.grid(255, 1000): - with T.block("st1"): + with T.sblock("st1"): C_local2[ (i0 * 255 + ax0) // 255, 0, @@ -980,7 +980,7 @@ def before( (i0 * 255 + ax0) % 255 % 16, ] = C_local1[i0 * 255 + ax0, ax1] for ax0, ax1, ax2 in T.grid(16, 16, 1000): - with T.block("st2"): + with T.sblock("st2"): T.where(ax0 * 16 + ax1 < 255) C[i0 * 255 + (ax0 * 16 + ax1), i1 * 1000 + ax2] = C_local2[ (i0 * 255 + ax0 * 16 + ax1) // 255, @@ -997,21 +997,21 @@ def expected( C: T.Buffer((1020, 1000), "float32"), ) -> None: for i0, i1 in T.grid(4, 1): - with T.block(): + with T.sblock(): C_local2 = T.alloc_buffer([1, 1, 16, 1000, 16], dtype="float32", scope="local") C_local1 = T.alloc_buffer([255, 1000], dtype="float32", scope="local") for ax0, ax1, ax2 in T.grid(255, 1000, 64): - with T.block("matmul"): + with T.sblock("matmul"): if ax2 == 0: C_local1[ax0, ax1] = 0 C_local1[ax0, ax1] = ( C_local1[ax0, ax1] + A[i0 * 255 + ax0, ax2] * B[ax1, ax2] ) for ax0, ax1 in T.grid(255, 1000): - with T.block("st1"): + with T.sblock("st1"): C_local2[0, 0, ax0 // 16, ax1, ax0 % 16] = C_local1[ax0, ax1] for ax0, ax1, ax2 in T.grid(16, 16, 1000): - with T.block("st2"): + with T.sblock("st2"): T.where(ax0 * 16 + ax1 < 255) C[i0 * 255 + ax0 * 16 + ax1, ax2] = C_local2[ (ax0 * 16 + ax1) // 255, @@ -1163,24 +1163,24 @@ def before( ): """A mock workload where the intermediate buffer allocation is not enought originally""" for i_0, j_0 in T.grid(4, 4): - with T.block(""): + with T.sblock(""): T.reads(A[i_0 * 32 : i_0 * 32 + 32, 0:128], B[0:128, j_0 * 32 : j_0 * 32 + 32]) T.writes(C[i_0 * 32 : i_0 * 32 + 32, j_0 * 32 : j_0 * 32 + 32]) A_local = T.alloc_buffer((127, 127), scope="local") B_local = T.alloc_buffer((127, 127), scope="local") C_local = T.alloc_buffer((127, 127), scope="local") for ax0, ax1 in T.grid(32, 128): - with T.block("A_local"): + with T.sblock("A_local"): A_local[i_0 * 32 + ax0, ax1] = T.if_then_else( i_0 * 32 + ax0 < 127, A[i_0 * 32 + ax0, ax1], 0.0 ) for ax0, ax1 in T.grid(128, 32): - with T.block("B_local"): + with T.sblock("B_local"): B_local[ax0, j_0 * 32 + ax1] = T.if_then_else( j_0 * 32 + ax1 < 127, B[ax0, j_0 * 32 + ax1], 0.0 ) for i_1, j_1, k in T.grid(32, 32, 128): - with T.block("compute"): + with T.sblock("compute"): T.where(i_0 * 32 + i_1 < 127 and j_0 * 32 + j_1 < 127) if k == 0: C_local[i_0 * 32 + i_1, j_0 * 32 + j_1] = T.float32(0) @@ -1189,7 +1189,7 @@ def before( + A_local[i_0 * 32 + i_1, k] * B_local[k, j_0 * 32 + j_1] ) for ax0, ax1 in T.grid(32, 32): - with T.block("C_local"): + with T.sblock("C_local"): T.where(i_0 * 32 + ax0 < 127 and j_0 * 32 + ax1 < 127) C[i_0 * 32 + ax0, j_0 * 32 + ax1] = C_local[i_0 * 32 + ax0, j_0 * 32 + ax1] @@ -1200,30 +1200,30 @@ def expected( C: T.Buffer((127, 127), "float32"), ): for i_0, j_0 in T.grid(4, 4): - with T.block(""): + with T.sblock(""): T.reads(A[i_0 * 32 : i_0 * 32 + 32, 0:128], B[0:128, j_0 * 32 : j_0 * 32 + 32]) T.writes(C[i_0 * 32 : i_0 * 32 + 32, j_0 * 32 : j_0 * 32 + 32]) A_local = T.alloc_buffer((32, 128), scope="local") B_local = T.alloc_buffer((128, 32), scope="local") C_local = T.alloc_buffer((32, 32), scope="local") for ax0, ax1 in T.grid(32, 128): - with T.block("A_local"): + with T.sblock("A_local"): A_local[ax0, ax1] = T.if_then_else( i_0 * 32 + ax0 < 127, A[i_0 * 32 + ax0, ax1], T.float32(0) ) for ax0, ax1 in T.grid(128, 32): - with T.block("B_local"): + with T.sblock("B_local"): B_local[ax0, ax1] = T.if_then_else( j_0 * 32 + ax1 < 127, B[ax0, j_0 * 32 + ax1], T.float32(0) ) for i_1, j_1, k in T.grid(32, 32, 128): - with T.block("compute"): + with T.sblock("compute"): T.where(i_0 * 32 + i_1 < 127 and j_0 * 32 + j_1 < 127) if k == 0: C_local[i_1, j_1] = T.float32(0) C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] for ax0, ax1 in T.grid(32, 32): - with T.block("C_local"): + with T.sblock("C_local"): T.where(i_0 * 32 + ax0 < 127 and j_0 * 32 + ax1 < 127) C[i_0 * 32 + ax0, j_0 * 32 + ax1] = C_local[ax0, ax1] @@ -1292,13 +1292,13 @@ def before(x: T.handle, y: T.handle, n: T.int64): X = T.match_buffer(x, (T.int64(8), n * T.int64(32))) Y = T.match_buffer(y, (T.int64(8), n * T.int64(32))) for i, k_0 in T.grid(T.int64(8), n): - with T.block(""): + with T.sblock(""): X_global = T.alloc_buffer((T.int64(8), n * T.int64(32))) for ax0 in range(T.int64(32)): - with T.block("X_global"): + with T.sblock("X_global"): X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 * T.int64(32) + ax0] for k_1 in range(T.int64(32)): - with T.block("Y"): + with T.sblock("Y"): Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 * T.int64(32) + k_1] @T.prim_func @@ -1306,13 +1306,13 @@ def expected(x: T.handle, y: T.handle, n: T.int64): X = T.match_buffer(x, (T.int64(8), n * T.int64(32))) Y = T.match_buffer(y, (T.int64(8), n * T.int64(32))) for i, k_0 in T.grid(T.int64(8), n): - with T.block(""): + with T.sblock(""): X_global = T.alloc_buffer((T.int64(1), T.int64(32))) for ax0 in range(T.int64(32)): - with T.block("X_global"): + with T.sblock("X_global"): X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32) + ax0] for k_1 in range(T.int64(32)): - with T.block("Y"): + with T.sblock("Y"): Y[i, k_0 * T.int64(32) + k_1] = X_global[T.int64(0), k_1] @@ -1324,12 +1324,12 @@ def before(x: T.handle, y: T.handle, n: T.int64): X = T.match_buffer(x, (T.int64(8), n * T.int64(32))) Y = T.match_buffer(y, (T.int64(8), n * T.int64(32))) for i, k_0 in T.grid(T.int64(8), n): - with T.block(""): + with T.sblock(""): X_global = T.alloc_buffer((T.int64(8), n * T.int64(32))) - with T.block("X_global"): + with T.sblock("X_global"): for x0 in range(T.int64(32)): X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 * T.int64(32) + x0] - with T.block("Y"): + with T.sblock("Y"): for x1 in range(T.int64(32)): Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 * T.int64(32) + x1] @@ -1337,14 +1337,14 @@ def before(x: T.handle, y: T.handle, n: T.int64): def expected(x: T.handle, y: T.handle, n: T.int64): X = T.match_buffer(x, (T.int64(8), n * T.int64(32))) Y = T.match_buffer(y, (T.int64(8), n * T.int64(32))) - # with T.block("root"): + # with T.sblock("root"): for i, k_0 in T.grid(T.int64(8), n): - with T.block(""): + with T.sblock(""): X_global = T.alloc_buffer((T.int64(1), T.int64(32))) - with T.block("X_global"): + with T.sblock("X_global"): for x0 in range(T.int64(32)): X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32) + x0] - with T.block("Y"): + with T.sblock("Y"): for x1 in range(T.int64(32)): Y[i, k_0 * T.int64(32) + x1] = X_global[T.int64(0), x1] @@ -1359,7 +1359,7 @@ def before(p_output0: T.handle, n: T.int32): for i in T.thread_binding(256, thread="blockIdx.x"): for j in T.thread_binding(256, thread="threadIdx.x"): for k in range((n * n + 65535) // 65536): - with T.block("make_diag_mask_te"): + with T.sblock("make_diag_mask_te"): T.where((k * 256 + i) * 256 + j < n * n) T.reads() T.writes(B[(k * 65536 + i * 256 + j) // n, (k * 65536 + i * 256 + j) % n]) @@ -1371,7 +1371,7 @@ def before(p_output0: T.handle, n: T.int32): for i in T.thread_binding(256, thread="blockIdx.x"): for j in T.thread_binding(256, thread="threadIdx.x"): for k in range((n * n + 65535) // 65536): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): T.where((k * 256 + i) * 256 + j < n * n) T.reads(B[(k * 65536 + i * 256 + j) // n, (k * 65536 + i * 256 + j) % n]) T.writes( @@ -1388,7 +1388,7 @@ def expected(p_output0: T.handle, n: T.int32): for i in T.thread_binding(256, thread="blockIdx.x"): for j in T.thread_binding(256, thread="threadIdx.x"): for k in range((n * n + 65535) // 65536): - with T.block("make_diag_mask_te"): + with T.sblock("make_diag_mask_te"): T.where(k * 65536 + i * 256 + j < n * n) T.reads() T.writes(B[(k * 65536 + i * 256 + j) // n, (k * 65536 + i * 256 + j) % n]) @@ -1400,7 +1400,7 @@ def expected(p_output0: T.handle, n: T.int32): for i in T.thread_binding(256, thread="blockIdx.x"): for k in T.thread_binding(256, thread="threadIdx.x"): for k in range((n * n + 65535) // 65536): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): T.where(k * 65536 + i * 256 + k < n * n) T.reads(B[(k * 65536 + i * 256 + k) // n, (k * 65536 + i * 256 + k) % n]) T.writes( diff --git a/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py index 63a57eeffe29..73c2bdc7f99c 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py @@ -33,17 +33,17 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block(): + with T.sblock(): vi = T.axis.S(16, i) vj = T.axis.S(16, j) B[vi, vj] = A[vi, vj] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): vi = T.axis.S(16, i) vj = T.axis.S(16, j) C[vi, vj] = B[vi, vj] * 2.0 @@ -54,17 +54,17 @@ def substituted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([16, 16], "float32") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads([A[i, j]]) T.writes([B[i, j]]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads([B[i, j]]) T.writes([C[i, j]]) C[i, j] = B[i, j] * 2.0 @@ -80,7 +80,7 @@ class TestErrorIfPredicateUsesBlockVariables(tvm.testing.CompareBeforeAfter): def before(A: T.Buffer(8, "int32")): for i in T.serial(8): - with T.block(): + with T.sblock(): vi = T.axis.remap("S", [i]) T.where(vi < 6) T.evaluate(0) diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py index 925f004cc527..04484d5fabfc 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py @@ -265,14 +265,14 @@ class TestFlattenInsideBlock(BaseCompare): def before(): A = T.alloc_buffer([32, 32]) for i, j in T.grid(32, 32): - with T.block("block"): + with T.sblock("block"): T.reads(A[i, j]) T.evaluate(A[i, j]) def expected(): A = T.alloc_buffer([1024]) for i, j in T.grid(32, 32): - with T.block("block"): + with T.sblock("block"): T.reads(A[i * 32 + j]) T.evaluate(A[i * 32 + j]) diff --git a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py index c85929e4f6bf..0e49fe6c1032 100644 --- a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py +++ b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py @@ -56,7 +56,7 @@ def before( for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)): - with T.block("T_where"): + with T.sblock("T_where"): ax0 = T.axis.spatial(T.int64(1), T.int64(0)) ax1 = T.axis.spatial( T.int64(12), @@ -114,7 +114,7 @@ def expected( for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): for i0_i1_i2_i3_fused_0 in range(7): - with T.block("T_where"): + with T.sblock("T_where"): ax0 = T.axis.spatial(1, 0) ax1 = T.axis.spatial( 12, @@ -165,7 +165,7 @@ def test_block(): def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): for i in T.serial(0, T.int64(16)): for j in T.serial(0, T.int64(8)): - with T.block(): + with T.sblock(): vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j) B[vi] = A[vi] + T.float32(1) @@ -173,7 +173,7 @@ def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): def expected(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): for i in T.serial(0, T.int32(16)): for j in T.serial(0, T.int32(8)): - with T.block(): + with T.sblock(): vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j) B[vi] = A[vi] + T.float32(1) @@ -187,7 +187,7 @@ def test_i16_buffer(): def before(A: T.Buffer((128,), "int16"), B: T.Buffer((128,), "int16")): for i in T.serial(0, T.int64(16)): for j in T.serial(0, T.int64(16)): - with T.block(): + with T.sblock(): vi = T.axis.spatial(T.int64(128), i * 8 + j) B[vi] = A[vi] + T.int16(1) @@ -195,7 +195,7 @@ def before(A: T.Buffer((128,), "int16"), B: T.Buffer((128,), "int16")): def expected(A: T.Buffer((128,), "int16"), B: T.Buffer((128,), "int16")): for i in T.serial(0, 16): for j in T.serial(0, 16): - with T.block(): + with T.sblock(): vi = T.axis.spatial(128, i * 8 + j) B[vi] = A[vi] + T.int16(1) @@ -209,7 +209,7 @@ def test_fail_on_buffer_map(): def func(A: T.Buffer((128,), "int64"), B: T.Buffer((128,), "int64")): for i in T.serial(0, 16): for j in T.serial(0, 8): - with T.block(): + with T.sblock(): vi = T.axis.spatial(128, i * 8 + j) B[vi] = A[vi] + T.int64(1) @@ -224,12 +224,12 @@ def func(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): C = T.alloc_buffer((128,), "int64") for i in T.serial(0, 16): for j in T.serial(0, 8): - with T.block(): + with T.sblock(): vi = T.axis.spatial(128, i * 8 + j) C[vi] = T.cast(A[vi], "int64") + T.int64(1) for i in T.serial(0, 16): for j in T.serial(0, 8): - with T.block(): + with T.sblock(): vi = T.axis.spatial(128, i * 8 + j) B[vi] = T.cast(C[vi] + T.int64(1), "int32") diff --git a/tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py b/tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py index 6495cdb2bd54..0c8013984e18 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py +++ b/tests/python/tir-transform/test_tir_transform_inject_permuted_layout.py @@ -36,35 +36,35 @@ def test_backward_compatibility_shared_a(): # fmt: off @T.prim_func def before(X: T.Buffer((4096, 4096), "float16")): - # with T.block("root"): + # with T.sblock("root"): for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072]) T.writes() for ax2_0_0 in range(128): - with T.block(""): + with T.sblock(""): T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) T.writes() X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") - with T.block("X_reindex_shared.dyn"): + with T.sblock("X_reindex_shared.dyn"): T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) - T.block_attr({"permuted_layout": "g2s_A"}) + T.sblock_attr({"permuted_layout": "g2s_A"}) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_3 in T.vectorized(8): X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] for ax2_0_1 in range(4): - with T.block(""): + with T.sblock(""): T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64:threadIdx_y // 2 * 64 + 64, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) T.writes() X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + with T.sblock("X_reindex_shared.dyn_m16n8k8.matrixA_o"): T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) - T.block_attr({"permuted_layout": "s2l_A"}) + T.sblock_attr({"permuted_layout": "s2l_A"}) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32) @T.prim_func @@ -72,11 +72,11 @@ def expected(X: T.Buffer((4096, 4096), "float16")): for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): for ax2_0_0 in T.serial(128): - with T.block(""): + with T.sblock(""): X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") - with T.block("X_reindex_shared.dyn"): + with T.sblock("X_reindex_shared.dyn"): # annotate the reads and writes because they cannot be inferred from tir.bitwise_xor T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) @@ -84,10 +84,10 @@ def expected(X: T.Buffer((4096, 4096), "float16")): for ax0_ax1_fused_3 in T.vectorized(8): X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] for ax2_0_1 in T.serial(4): - with T.block(""): + with T.sblock(""): X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + with T.sblock("X_reindex_shared.dyn_m16n8k8.matrixA_o"): T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0) @@ -103,36 +103,36 @@ def before(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), "floa for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): for ax2_0_0 in T.serial(128): - with T.block(""): + with T.sblock(""): X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") Y_reindex_shared_dyn = T.alloc_buffer((32, 128), "float16", strides=(128, 1), scope="shared.dyn") - with T.block("X_reindex_shared.dyn"): - T.block_attr({"permuted_layout": "g2s_A"}) + with T.sblock("X_reindex_shared.dyn"): + T.sblock_attr({"permuted_layout": "g2s_A"}) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_3 in T.vectorized(8): X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] - with T.block("Y_reindex_shared.dyn"): - T.block_attr({"permuted_layout": "g2s_B"}) + with T.sblock("Y_reindex_shared.dyn"): + T.sblock_attr({"permuted_layout": "g2s_B"}) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_3 in T.vectorized(8): Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, threadIdx_x % 16 * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + ax0_ax1_fused_3] for ax2_0_1 in T.serial(4): - with T.block(""): + with T.sblock(""): X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", scope="m16n8k8.matrixB") for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + with T.sblock("X_reindex_shared.dyn_m16n8k8.matrixA_o"): T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) - T.block_attr({"permuted_layout": "s2l_A"}) + T.sblock_attr({"permuted_layout": "s2l_A"}) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32) for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"): + with T.sblock("Y_reindex_shared.dyn_m16n8k8.matrixB_o"): T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 + ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32]) T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32]) - T.block_attr({"permuted_layout": "s2l_B"}) + T.sblock_attr({"permuted_layout": "s2l_B"}) T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, ax2_0_1 * 1024 + threadIdx_y % 2 * 64 + ax1_0 * 32, 1024, 1), threadIdx_x % 8 * 128 + threadIdx_x // 8 * 8) @T.prim_func @@ -141,38 +141,38 @@ def expected(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), "fl for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"): for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"): for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072], Y[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 4089, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8]) T.writes() for ax2_0_0 in T.serial(128): - with T.block(""): + with T.sblock(""): T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8], Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8]) T.writes() X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") Y_reindex_shared_dyn = T.alloc_buffer((32, 128), "float16", strides=(128, 1), scope="shared.dyn") - with T.block("X_reindex_shared.dyn"): + with T.sblock("X_reindex_shared.dyn"): T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_3 in T.vectorized(8): X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3] - with T.block("Y_reindex_shared.dyn"): + with T.sblock("Y_reindex_shared.dyn"): T.reads(Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8]) T.writes(Y_reindex_shared_dyn[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 25, threadIdx_x % 16 * 8:threadIdx_x % 16 * 8 + 8]) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_3 in T.vectorized(8): Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, T.bitwise_xor(threadIdx_x % 16, threadIdx_y * 2 + threadIdx_x // 16) * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + ax0_ax1_fused_3] for ax2_0_1 in T.serial(4): - with T.block(""): + with T.sblock(""): X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA") Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", scope="m16n8k8.matrixB") for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"): + with T.sblock("X_reindex_shared.dyn_m16n8k8.matrixA_o"): T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8]) T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8]) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0) for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"): + with T.sblock("Y_reindex_shared.dyn_m16n8k8.matrixB_o"): T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 + ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32]) T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32]) T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, ax2_0_1 * 1024 + threadIdx_x % 8 * 128 + T.bitwise_xor(threadIdx_y % 2 * 8 + ax1_0 * 4 + threadIdx_x // 8, threadIdx_x % 8) * 8, 1024, 1), 0) @@ -192,8 +192,8 @@ def before(p_A: T.handle): for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): for v0 in range(T.int64(4)): for v1 in T.vectorized(T.int64(8)): - with T.block("A_reindex_shared.dyn"): - T.block_attr({"permuted_layout": 1}) + with T.sblock("A_reindex_shared.dyn"): + T.sblock_attr({"permuted_layout": 1}) A_shared_dyn[ v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1 @@ -202,9 +202,9 @@ def before(p_A: T.handle): threadIdx_x % T.int64(4) * T.int64(8) + v1 ] for v0, v1 in T.grid(T.int64(2), T.int64(4)): - with T.block("A_reindex_shared.dyn_warp_o"): - T.block_attr({"permuted_layout": 1}) - with T.block("A_reindex_shared.dyn_warp_o"): + with T.sblock("A_reindex_shared.dyn_warp_o"): + T.sblock_attr({"permuted_layout": 1}) + with T.sblock("A_reindex_shared.dyn_warp_o"): T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", @@ -227,15 +227,15 @@ def expected(A: T.Buffer((T.int64(128), T.int64(32)), "float16")): for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): for v0 in range(T.int64(4)): for v1 in T.vectorized(T.int64(8)): - with T.block("A_reindex_shared.dyn"): + with T.sblock("A_reindex_shared.dyn"): T.reads(A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1]) T.writes(A_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]) A_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) + v1] = A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1] for v0, v1 in T.grid(T.int64(2), T.int64(4)): - with T.block("A_reindex_shared.dyn_warp_o"): + with T.sblock("A_reindex_shared.dyn_warp_o"): T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) - with T.block("A_reindex_shared.dyn_warp_o"): + with T.sblock("A_reindex_shared.dyn_warp_o"): T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", A_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), A_shared_dyn.data, threadIdx_z * T.int64(2048) + v1 * T.int64(512) + threadIdx_x % T.int64(16) * T.int64(32) + T.bitwise_xor(v0 * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), 1), T.int64(0)) @@ -254,16 +254,16 @@ def before(B: T.Buffer((T.int64(128), T.int64(32)), "float16")): for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): for v0 in range(T.int64(4)): for v1 in T.vectorized(T.int64(8)): - with T.block("B_reindex_shared.dyn"): - T.block_attr({"permuted_layout": 1}) + with T.sblock("B_reindex_shared.dyn"): + T.sblock_attr({"permuted_layout": 1}) B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] for v0 in range(T.int64(2)): - with T.block(""): + with T.sblock(""): B_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp") for v1 in range(T.int64(4)): - with T.block("B_reindex_shared.dyn_warp_o"): - T.block_attr({"permuted_layout": 1}) - with T.block("B_reindex_shared.dyn_warp_o"): + with T.sblock("B_reindex_shared.dyn_warp_o"): + T.sblock_attr({"permuted_layout": 1}) + with T.sblock("B_reindex_shared.dyn_warp_o"): T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, threadIdx_y * T.int64(2048) + v1 * T.int64(512) + v0 * T.int64(16), T.int64(512), 1), threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + threadIdx_x % T.int64(16) // T.int64(8) * T.int64(8)) @@ -276,18 +276,18 @@ def expected(B: T.Buffer((T.int64(128), T.int64(32)), "float16")): for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): for v0 in range(T.int64(4)): for v1 in T.vectorized(T.int64(8)): - with T.block("B_reindex_shared.dyn"): + with T.sblock("B_reindex_shared.dyn"): T.reads(B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]) T.writes(B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]) B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] for v0 in range(T.int64(2)): - with T.block(""): + with T.sblock(""): B_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp") for v1 in range(T.int64(4)): - with T.block("B_reindex_shared.dyn_warp_o"): + with T.sblock("B_reindex_shared.dyn_warp_o"): T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) - with T.block("B_reindex_shared.dyn_warp_o"): + with T.sblock("B_reindex_shared.dyn_warp_o"): T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)]) T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, threadIdx_y * T.int64(2048) + v1 * T.int64(512) + threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + T.bitwise_xor(v0 * T.int64(2) + threadIdx_x % T.int64(16) // T.int64(8), threadIdx_x % T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), 1), T.int64(0)) @@ -307,38 +307,38 @@ def before(p_O: T.handle): for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): for v0, v1 in T.grid(T.int64(4), T.int64(4)): - with T.block("O.dyn_warp_o"): - T.block_attr({"permuted_layout": 1}) - with T.block("O.dyn_warp_o"): + with T.sblock("O.dyn_warp_o"): + T.sblock_attr({"permuted_layout": 1}) + with T.sblock("O.dyn_warp_o"): for local_id in range(T.int64(8)): O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_y * T.int64(64) + v1 * T.int64(16) + local_id // T.int64(4) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id] for v0 in range(T.int64(16)): for v1 in T.vectorized(T.int64(8)): - with T.block("O.dyn"): - T.block_attr({"permuted_layout": 1}) + with T.sblock("O.dyn"): + T.sblock_attr({"permuted_layout": 1}) O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1]) @T.prim_func def expected(O: T.Buffer((T.int64(128), T.int64(128)), "float16")): - # with T.block("root"): + # with T.sblock("root"): O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), scope="shared.dyn") O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), T.int64(8)), scope="warp") for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"): for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"): for v0, v1 in T.grid(T.int64(4), T.int64(4)): - with T.block("O.dyn_warp_o"): + with T.sblock("O.dyn_warp_o"): T.reads(O_warp[v0, v1, threadIdx_x, T.int64(0):T.int64(8)]) T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + T.int64(10)]) - with T.block("O.dyn_warp_o"): + with T.sblock("O.dyn_warp_o"): T.reads(O_warp[v0, v1, threadIdx_x, T.int64(0):T.int64(8)]) T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + T.int64(10)]) for local_id in range(T.int64(8)): O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_y * T.int64(8) + v1 * T.int64(2) + local_id // T.int64(4), threadIdx_x // T.int64(4)) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id] for v0 in range(T.int64(16)): for v1 in T.vectorized(T.int64(8)): - with T.block("O.dyn"): + with T.sblock("O.dyn"): T.reads(O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1]) T.writes(O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1]) O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), T.bitwise_xor(threadIdx_x % T.int64(16), threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16)) * T.int64(8) + v1]) diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index aa7e2b357564..51cab991d2cc 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -48,7 +48,7 @@ def ptx_global_to_shared_copy( tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], dtype, scope="shared") T.reads(A[0:32, 0:128]) T.writes(B[0:32, 0:128]) @@ -76,7 +76,7 @@ def ptx_global_to_shared_copy_fp32x1( tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], "float32", scope="shared") T.reads(A[0:32, 0:128]) T.writes(B[0:32, 0:128]) @@ -103,7 +103,7 @@ def ptx_global_to_shared_dyn_copy_fp16x8( tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn") B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn") T.reads(A[0:32, 0:128], B[0:32, 0:128]) @@ -193,7 +193,7 @@ def ptx_global_to_shared_copy_fp32x1_barrier( tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - with T.block(): + with T.sblock(): A_shared = T.alloc_buffer([32, 128], "float32", scope="shared") T.reads(A[0:32, 0:128]) @@ -452,24 +452,24 @@ def simple_compute( "software_pipeline_async_stages": [0], }, ): - with T.block("compute"): + with T.sblock("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(A_shared[tx, 0]) A_shared[tx, 0] = T.if_then_else( 1 <= i and i < 15, A[tx, i - 1], T.float32(0), dtype="float32" ) - with T.block(): + with T.sblock(): T.reads(B[tx, i]) T.writes(B_shared[tx, 0]) B_shared[tx, 0] = T.if_then_else( 1 <= i and i < 15, B[tx, i - 1], T.float32(0), dtype="float32" ) - with T.block(): + with T.sblock(): T.reads(A_shared[tx, 0], B_shared[tx, 0]) T.writes(C[tx, i]) C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] @@ -497,7 +497,7 @@ def complex_compute( Conv: T.Buffer((512, 1280), "float16"), ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): data_im2col_reindex_shared_dyn = T.alloc_buffer((512, 11520), "float16", scope="shared.dyn") data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer( (512, 11520), "float16", scope="wmma.matrix_a" @@ -516,7 +516,7 @@ def complex_compute( for x_0_1 in T.thread_binding(2, thread="threadIdx.y"): for y_0_1 in T.thread_binding(2, thread="threadIdx.z"): for x_0_2_init, y_0_2_init in T.grid(2, 2): - with T.block("Conv_init_o"): + with T.sblock("Conv_init_o"): v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) T.reads() @@ -562,7 +562,7 @@ def complex_compute( for ax0_ax1_0_fused_3 in T.thread_binding( 32, thread="threadIdx.x" ): - with T.block("data_im2col_reindex_shared.dyn_o"): + with T.sblock("data_im2col_reindex_shared.dyn_o"): v0 = T.axis.spatial( 512, x_0_0 * 64 @@ -599,7 +599,7 @@ def complex_compute( ] ) for ax1_1 in T.vectorized(8): - with T.block("data_im2col_reindex_shared.dyn"): + with T.sblock("data_im2col_reindex_shared.dyn"): v1_i = T.axis.spatial(8, ax1_1) T.reads( A[ @@ -614,7 +614,7 @@ def complex_compute( v0, v1_o * 8 + v1_i ] ) - T.block_attr( + T.sblock_attr( {"buffer_dim_align": [[0, 0, 32, 8]]} ) data_im2col_reindex_shared_dyn[ @@ -641,7 +641,7 @@ def complex_compute( 32, thread="threadIdx.x" ): for ax1_1 in T.vectorized(8): - with T.block("weight_flatten_reindex_shared.dyn"): + with T.sblock("weight_flatten_reindex_shared.dyn"): v0 = T.axis.spatial( 1280, y_0_0 * 64 @@ -677,7 +677,7 @@ def complex_compute( T.writes( weight_flatten_reindex_shared_dyn[v0, v1] ) - T.block_attr( + T.sblock_attr( {"buffer_dim_align": [[0, 0, 32, 8]]} ) weight_flatten_reindex_shared_dyn[v0, v1] = W[ @@ -688,7 +688,7 @@ def complex_compute( ] for k_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): + with T.sblock("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) T.reads( @@ -747,7 +747,7 @@ def complex_compute( "row_major", ) for ax0_0, ax1_0 in T.grid(2, 1): - with T.block( + with T.sblock( "weight_flatten_reindex_shared.dyn_wmma.matrix_b_o" ): v0_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) @@ -808,7 +808,7 @@ def complex_compute( "col_major", ) for x_0_2, y_0_2 in T.grid(2, 2): - with T.block("Conv_update_o"): + with T.sblock("Conv_update_o"): v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) v_k_o = T.axis.reduce(720, k_0_0 * 4 + k_0_1) @@ -886,7 +886,7 @@ def complex_compute( + C.elem_offset % C_s0 // 16, ) for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("Conv_reindex_wmma.accumulator_o"): + with T.sblock("Conv_reindex_wmma.accumulator_o"): v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) T.reads( diff --git a/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py b/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py index 697887dc8cbb..f1dc6c3d6fb7 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py @@ -57,15 +57,15 @@ def trivial_pipeline(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "floa for i in T.serial( 0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]} ): - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(C[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, i]) C[tx, i] = B[tx, 0] + T.float32(1) @@ -76,19 +76,19 @@ def transformed_trivial_pipeline( A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32") ) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0]) T.writes(C[tx, 0]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, 0]) T.writes(B[0, tx, 0]) B[0, tx, 0] = A[tx, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads() T.writes() T.evaluate(0) - with T.block(): + with T.sblock(): T.reads(B[0, tx, 0]) T.writes(C[tx, 0]) C[tx, 0] = B[0, tx, 0] + T.float32(1) @@ -106,15 +106,15 @@ def simple_compute(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "floa "software_pipeline_order": [0, 1], }, ): - with T.block("compute"): + with T.sblock("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, i]) C[tx, i] = B[tx, 0] + T.float32(1) @@ -127,27 +127,27 @@ def transformed_simple_compute( A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32") ) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads([A[tx, 0:16]]) T.writes([C[tx, 0:16]]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads([A[tx, 0]]) T.writes([B[0, tx, 0]]) B[0, tx, 0] = A[tx, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([A[tx, 1:16], B[0:2, tx, 0]]) T.writes([B[0:2, tx, 0], C[tx, 0:15]]) for i in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1]]) T.writes([B[(i + 1) % 2, tx, 0]]) B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[i % 2, tx, 0]]) T.writes([C[tx, i]]) C[tx, i] = B[i % 2, tx, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, 0]]) T.writes([C[tx, 15]]) C[tx, 15] = B[1, tx, 0] + T.float32(1) @@ -167,15 +167,15 @@ def dynamic_compute(a_handle: T.handle, c_handle: T.handle): "software_pipeline_order": [0, 1], }, ): - with T.block("compute"): + with T.sblock("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, i]) C[tx, i] = B[tx, 0] + T.float32(1) @@ -187,34 +187,34 @@ def transformed_dynamic_compute(a_handle: T.handle, c_handle: T.handle): A = T.match_buffer(a_handle, (16, k), "float32") C = T.match_buffer(c_handle, (16, k), "float32") for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0 : T.max(1, k)]) T.writes(C[tx, T.min(0, k - 1) : T.min(0, k - 1) + T.max(k, 1)]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(""): + with T.sblock(""): T.reads(A[tx, 0]) T.writes(B[0, tx, 0]) - with T.block(""): + with T.sblock(""): T.where(0 < k) T.reads(A[tx, 0]) T.writes(B[0, tx, 0]) B[0, tx, 0] = A[tx, 0] * T.float32(2) - with T.block(""): + with T.sblock(""): T.reads(A[tx, 1 : 1 + (k - 1)], B[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[tx, 0 : k - 1]) for i in range(k - 1): - with T.block(""): + with T.sblock(""): T.reads(A[tx, i + 1]) T.writes(B[(i + 1) % 2, tx, 0]) B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) - with T.block(""): + with T.sblock(""): T.reads(B[i % 2, tx, 0]) T.writes(C[tx, i]) C[tx, i] = B[i % 2, tx, 0] + T.float32(1) - with T.block(""): + with T.sblock(""): T.reads(B[(k + 1) % 2, tx, 0]) T.writes(C[tx, k - 1]) - with T.block(""): + with T.sblock(""): T.where(1 <= k) T.reads(B[(k + 1) % 2, tx, 0]) T.writes(C[tx, k - 1]) @@ -235,15 +235,15 @@ def simple_compute_with_other_annotation( "pragma_loop_partition_hint": True, }, ): - with T.block("compute"): + with T.sblock("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, i]) C[tx, i] = B[tx, 0] + T.float32(1) @@ -254,15 +254,15 @@ def transformed_simple_compute_with_other_annotation( A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32") ) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads([A[tx, 0:16]]) T.writes([C[tx, 0:16]]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads([A[tx, 0]]) T.writes([B[0, tx, 0]]) B[0, tx, 0] = A[tx, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([A[tx, 1:16], B[0:2, tx, 0]]) T.writes([B[0:2, tx, 0], C[tx, 0:15]]) for i in T.serial( @@ -270,15 +270,15 @@ def transformed_simple_compute_with_other_annotation( 15, annotations={"pragma_loop_partition_hint": True}, ): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1]]) T.writes([B[(i + 1) % 2, tx, 0]]) B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[i % 2, tx, 0]]) T.writes([C[tx, i]]) C[tx, i] = B[i % 2, tx, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, 0]]) T.writes([C[tx, 15]]) C[tx, 15] = B[1, tx, 0] + T.float32(1) @@ -295,20 +295,20 @@ def three_stage_compute(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "software_pipeline_order": [0, 1, 2], }, ): - with T.block("compute"): + with T.sblock("compute"): T.reads(A[tx, i]) T.writes(D[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, 0]) C[tx, 0] = B[tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(C[tx, 0]) T.writes(D[tx, i]) D[tx, i] = C[tx, 0] + T.float32(1) @@ -319,50 +319,50 @@ def transformed_three_stage_compute( A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32") ) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0:16]) T.writes(D[tx, 0:16]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, 0:2], B[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0]) for i in T.unroll(2): - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[0:2, tx, 0]) B[i, tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.where(i == 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) for i in T.serial(14): - with T.block(): + with T.sblock(): T.reads(A[tx, i + 2]) T.writes(B[0:2, tx, 0]) B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i]) D[tx, i] = C[i % 2, tx, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads(B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(C[0:2, tx, 0], D[tx, 14:16]) for i in T.unroll(2): - with T.block(): + with T.sblock(): T.where(i < 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i + 14]) D[tx, i + 14] = C[i, tx, 0] + T.float32(1) @@ -383,30 +383,30 @@ def dag_interleaving( "software_pipeline_order": [0, 2, 1, 3, 4], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(C[tx, i]) AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared") BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared") AL = T.alloc_buffer((1, 1), dtype="float32", scope="local") BL = T.alloc_buffer((1, 1), dtype="float32", scope="local") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(AS[tx, 0]) AS[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(AS[tx, 0]) T.writes(AL[0, 0]) AL[0, 0] = AS[tx, 0] - with T.block(): + with T.sblock(): T.reads(B[tx, i]) T.writes(BS[tx, 0]) BS[tx, 0] = B[tx, i] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(BS[tx, 0]) T.writes(BL[0, 0]) BL[0, 0] = BS[tx, 0] - with T.block(): + with T.sblock(): T.reads(AL[0, 0], BL[0, 0]) T.writes(C[tx, i]) C[tx, i] = AL[0, 0] * BL[0, 0] @@ -419,59 +419,59 @@ def transformed_dag_interleaving( C: T.Buffer((16, 16), "float32"), ) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0:16], B[tx, 0:16]) T.writes(C[tx, 0:16]) AS = T.alloc_buffer([16, 1], dtype="float32", scope="shared") BS = T.alloc_buffer([16, 1], dtype="float32", scope="shared") AL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local") BL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local") - with T.block(): + with T.sblock(): T.reads(A[tx, 0], B[tx, 0], AS[tx, 0], BS[tx, 0]) T.writes(AS[tx, 0], BS[tx, 0], AL[0, 0, 0], BL[0, 0, 0]) - with T.block(): + with T.sblock(): T.reads(A[tx, 0]) T.writes(AS[tx, 0]) AS[tx, 0] = A[tx, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(BS[tx, 0]) BS[tx, 0] = B[tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(AS[tx, 0]) T.writes(AL[0, 0, 0]) AL[0, 0, 0] = AS[tx, 0] - with T.block(): + with T.sblock(): T.reads(BS[tx, 0]) T.writes(BL[0, 0, 0]) BL[0, 0, 0] = BS[tx, 0] - with T.block(): + with T.sblock(): T.reads( A[tx, 1:16], B[tx, 1:16], AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0] ) T.writes(AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0], C[tx, 0:15]) for i in T.serial(15): - with T.block(): + with T.sblock(): T.reads(A[tx, i + 1]) T.writes(AS[tx, 0]) AS[tx, 0] = A[tx, i + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, i + 1]) T.writes(BS[tx, 0]) BS[tx, 0] = B[tx, i + 1] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(AS[tx, 0]) T.writes(AL[(i + 1) % 2, 0, 0]) AL[(i + 1) % 2, 0, 0] = AS[tx, 0] - with T.block(): + with T.sblock(): T.reads(BS[tx, 0]) T.writes(BL[(i + 1) % 2, 0, 0]) BL[(i + 1) % 2, 0, 0] = BS[tx, 0] - with T.block(): + with T.sblock(): T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0]) T.writes(C[tx, i]) C[tx, i] = AL[i % 2, 0, 0] * BL[i % 2, 0, 0] - with T.block(): + with T.sblock(): T.reads(AL[1, 0, 0], BL[1, 0, 0]) T.writes(C[tx, 15]) C[tx, 15] = AL[1, 0, 0] * BL[1, 0, 0] @@ -490,12 +490,12 @@ def nested_pipeline_simple( "software_pipeline_order": [0, 1, 2, 3], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads(A[tx, i, j]) T.writes(A_shared[tx, 0, j]) A_shared[tx, 0, j] = A[tx, i, j] @@ -507,15 +507,15 @@ def nested_pipeline_simple( "software_pipeline_order": [0, 1], }, ): - with T.block(): + with T.sblock(): T.reads(A_shared[tx, 0, j]) T.writes(C[tx, i, j]) B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A_shared[tx, i, j]) T.writes(B[tx, i, 0]) B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, i, 0]) T.writes(C[tx, i, j]) C[tx, i, j] = B[tx, i, 0] + T.float32(1) @@ -526,73 +526,73 @@ def transformed_nested_pipeline_simple( A: T.Buffer((16, 16, 16), "float32"), C: T.Buffer((16, 16, 16), "float32") ) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads([A[tx, 0:16, 0:16]]) T.writes([C[tx, 0:16, 0:16]]) A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16]]) T.writes([A_shared[0, tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, 0, j]]) T.writes([A_shared[0, tx, 0, j]]) A_shared[0, tx, 0, j] = A[tx, 0, j] - with T.block(): + with T.sblock(): T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]]) T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]]) for i in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, 0:16]]) T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, j]]) T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] - with T.block(): + with T.sblock(): T.reads([A_shared[i % 2, tx, i, 0]]) T.writes([B[0, tx, i, 0]]) B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_shared[i % 2, tx, i, j + 1]]) T.writes([B[(j + 1) % 2, tx, i, 0]]) B[(j + 1) % 2, tx, i, 0] = A_shared[ i % 2, tx, 0, j + 1 ] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, 0]]) T.writes([B[0, tx, 15, 0]]) B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, j + 1]]) T.writes([B[(j + 1) % 2, tx, 15, 0]]) B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, 15, 0]]) T.writes([C[tx, 15, 15]]) C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) @@ -611,12 +611,12 @@ def nested_pipeline_prefetch_inner( "software_pipeline_order": [0, 2, 1, 3], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads(A[tx, i, j]) T.writes(A_shared[tx, 0, j]) A_shared[tx, 0, j] = A[tx, i, j] @@ -628,15 +628,15 @@ def nested_pipeline_prefetch_inner( "software_pipeline_order": [0, 1], }, ): - with T.block(): + with T.sblock(): T.reads(A_shared[tx, 0, j]) T.writes(C[tx, i, j]) B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A_shared[tx, i, j]) T.writes(B[tx, i, 0]) B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, i, 0]) T.writes(C[tx, i, j]) C[tx, i, j] = B[tx, i, 0] + T.float32(1) @@ -647,76 +647,76 @@ def transformed_nested_pipeline_prefetch_inner( A: T.Buffer((16, 16, 16), "float32"), C: T.Buffer((16, 16, 16), "float32") ) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads([A[tx, 0:16, 0:16]]) T.writes([C[tx, 0:16, 0:16]]) A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]]) T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]]) - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16]]) T.writes([A_shared[0, tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, 0, j]]) T.writes([A_shared[0, tx, 0, j]]) A_shared[0, tx, 0, j] = A[tx, 0, j] - with T.block(): + with T.sblock(): T.reads([A_shared[0, tx, 0, 0]]) T.writes([B[0, tx, 0, 0]]) B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]]) T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]]) for i in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, 0:16]]) T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, j]]) T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] - with T.block(): + with T.sblock(): T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_shared[i % 2, tx, i, j + 1]]) T.writes([B[(j + 1) % 2, tx, i, 0]]) B[(j + 1) % 2, tx, i, 0] = A_shared[ i % 2, tx, 0, j + 1 ] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_shared[1, tx, 15, j + 1]]) T.writes([B[(j + 1) % 2, tx, 15, 0]]) B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, 15, 0]]) T.writes([C[tx, 15, 15]]) C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) @@ -735,18 +735,18 @@ def nested_pipeline_interleaving( "software_pipeline_order": [0, 2, 3, 1, 4], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads(A[tx, i, j]) T.writes(A_shared[tx, 0, j]) A_shared[tx, 0, j] = A[tx, i, j] for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads(A_shared[tx, 0, j]) T.writes(A_local[0, 0, j]) A_local[0, 0, j] = A_shared[tx, i, j] @@ -758,15 +758,15 @@ def nested_pipeline_interleaving( "software_pipeline_order": [0, 1], }, ): - with T.block(): + with T.sblock(): T.reads(A_local[0, 0, j]) T.writes(C[tx, i, j]) B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A_local[tx, i, j]) T.writes(B[tx, i, 0]) B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, i, 0]) T.writes(C[tx, i, j]) C[tx, i, j] = B[tx, i, 0] + T.float32(1) @@ -777,36 +777,36 @@ def transformed_nested_pipeline_interleaving( A: T.Buffer((16, 16, 16), "float32"), C: T.Buffer((16, 16, 16), "float32") ) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads([A[tx, 0:16, 0:16]]) T.writes([C[tx, 0:16, 0:16]]) A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local") B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 0]]) T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]]) - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16]]) T.writes([A_shared[tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, 0, j]]) T.writes([A_shared[tx, 0, j]]) A_shared[tx, 0, j] = A[tx, 0, j] - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, 0:16]]) T.writes([A_local[0, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, j]]) T.writes([A_local[0, 0, j]]) A_local[0, 0, j] = A_shared[tx, 0, j] - with T.block(): + with T.sblock(): T.reads([A_local[tx, 0, 0]]) T.writes([B[0, tx, 0, 0]]) B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads( [ A[tx, 1:16, 0:16], @@ -824,58 +824,58 @@ def transformed_nested_pipeline_interleaving( ] ) for i in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, 0:16]]) T.writes([A_shared[tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, j]]) T.writes([A_shared[tx, 0, j]]) A_shared[tx, 0, j] = A[tx, i + 1, j] - with T.block(): + with T.sblock(): T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]]) T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_local[tx, i, j + 1]]) T.writes([B[(j + 1) % 2, tx, i, 0]]) B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, 0:16]]) T.writes([A_local[0, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, j]]) T.writes([A_local[0, 0, j]]) A_local[0, 0, j] = A_shared[tx, i + 1, j] - with T.block(): + with T.sblock(): T.reads([A_local[tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) - with T.block(): + with T.sblock(): T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_local[tx, 15, j + 1]]) T.writes([B[(j + 1) % 2, tx, 15, 0]]) B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, 15, 0]]) T.writes([C[tx, 15, 15]]) C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) @@ -894,19 +894,19 @@ def nested_pipeline_double_buffer( "software_pipeline_order": [0, 2, 3, 1, 4], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i, 0:16]) T.writes(C[tx, i, 0:16]) A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads(A[tx, i, j]) T.writes(A_shared[tx, 0, j]) A_shared[tx, 0, j] = A[tx, i, j] for j in T.serial(0, 16): - with T.block(): - T.block_attr({"double_buffer_scope": 0}) + with T.sblock(): + T.sblock_attr({"double_buffer_scope": 0}) T.reads(A_shared[tx, 0, j]) T.writes(A_local[0, 0, j]) A_local[0, 0, j] = A_shared[tx, i, j] @@ -918,15 +918,15 @@ def nested_pipeline_double_buffer( "software_pipeline_order": [0, 1], }, ): - with T.block(): + with T.sblock(): T.reads(A_local[0, 0, j]) T.writes(C[tx, i, j]) B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A_local[tx, i, j]) T.writes(B[tx, i, 0]) B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, i, 0]) T.writes(C[tx, i, j]) C[tx, i, j] = B[tx, i, 0] + T.float32(1) @@ -937,37 +937,37 @@ def transformed_nested_pipeline_double_buffer( A: T.Buffer((16, 16, 16), "float32"), C: T.Buffer((16, 16, 16), "float32") ) -> None: for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads([A[tx, 0:16, 0:16]]) T.writes([C[tx, 0:16, 0:16]]) A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local") B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]]) T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]]) - with T.block(): + with T.sblock(): T.reads([A[tx, 0, 0:16]]) T.writes([A_shared[tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, 0, j]]) T.writes([A_shared[tx, 0, j]]) A_shared[tx, 0, j] = A[tx, 0, j] - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, 0:16]]) T.writes([A_local[0, 0, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, j]]) T.writes([A_local[0, 0, 0, j]]) - T.block_attr({"double_buffer_scope": 0}) + T.sblock_attr({"double_buffer_scope": 0}) A_local[0, 0, 0, j] = A_shared[tx, 0, j] - with T.block(): + with T.sblock(): T.reads([A_local[0, tx, 0, 0]]) T.writes([B[0, tx, 0, 0]]) B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads( [ A[tx, 1:16, 0:16], @@ -985,61 +985,61 @@ def transformed_nested_pipeline_double_buffer( ] ) for i in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, 0:16]]) T.writes([A_shared[tx, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A[tx, i + 1, j]]) T.writes([A_shared[tx, 0, j]]) A_shared[tx, 0, j] = A[tx, i + 1, j] - with T.block(): + with T.sblock(): T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_local[i % 2, tx, i, j + 1]]) T.writes([B[(j + 1) % 2, tx, i, 0]]) B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( 2 ) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, 0:16]]) T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) for j in T.serial(0, 16): - with T.block(): + with T.sblock(): T.reads([A_shared[tx, 0, j]]) T.writes([A_local[(i + 1) % 2, 0, 0, j]]) - T.block_attr({"double_buffer_scope": 0}) + T.sblock_attr({"double_buffer_scope": 0}) A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] - with T.block(): + with T.sblock(): T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) - with T.block(): + with T.sblock(): T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) for j in T.serial(0, 15): - with T.block(): + with T.sblock(): T.reads([A_local[1, tx, 15, j + 1]]) T.writes([B[(j + 1) % 2, tx, 15, 0]]) B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads([B[1, tx, 15, 0]]) T.writes([C[tx, 15, 15]]) C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) @@ -1058,20 +1058,20 @@ def simple_compute_incorrect_reorder( "software_pipeline_order": [0, 2, 1], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(D[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, 0]) C[tx, 0] = B[tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(C[tx, 0]) T.writes(D[tx, i]) D[tx, i] = C[tx, 0] + T.float32(1) @@ -1090,20 +1090,20 @@ def simple_compute_conflicting_order( "software_pipeline_order": [0, 1, 1], }, ): - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(D[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, 0]) C[tx, 0] = B[tx, 0] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(C[tx, 0]) T.writes(D[tx, i]) D[tx, i] = C[tx, 0] + T.float32(1) @@ -1115,15 +1115,15 @@ def simple_compute_missing_annotation( ): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}): - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(C[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(B[tx, 0]) B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(B[tx, 0]) T.writes(C[tx, i]) C[tx, i] = B[tx, 0] + T.float32(1) @@ -1185,42 +1185,42 @@ def test_simple_compute_async(): mod = tvm.IRModule.from_expr(gen_simple_compute(1).with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) - _, loop = sch.get_loops(sch.get_block("compute")) + _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) @T.prim_func def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0:16]) T.writes(C[tx, 0:16]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, 0]) T.writes(B[T.FloorMod(0, 2), tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(A[tx, 1:16], B[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[tx, 0:15]) for i in T.serial(15): - with T.block(): + with T.sblock(): T.where(i + 1 < 16) T.reads(A[tx, i + 1]) T.writes(B[(i + 1) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) - with T.block(): + with T.sblock(): T.where(i + 1 - 1 < 16) T.reads(B[(i - 1 + 1) % 2, tx, 0]) T.writes(C[tx, i - 1 + 1]) with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads(B[T.FloorMod(15, 2), tx, 0]) T.writes(C[tx, 15]) with T.attr(0, "async_wait_queue_scope", 0): @@ -1232,51 +1232,51 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): mod = tvm.IRModule.from_expr(gen_simple_compute(3).with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) - _, loop = sch.get_loops(sch.get_block("compute")) + _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) @T.prim_func def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0:16]) T.writes(C[tx, 0:16]) B = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, 0:3]) T.writes(B[0:3, tx, 0]) for i in T.unroll(3): - with T.block(): + with T.sblock(): T.where(i < 16) T.reads(A[tx, i]) T.writes(B[i % 4, tx, 0]) T.attr(0, "async_commit_queue_scope", 0) T.attr(0, "async_scope", 1) B[i % 4, tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.reads(A[tx, 3:16], B[0:4, tx, 0]) T.writes(B[0:4, tx, 0], C[tx, 0:13]) for i in T.serial(13): - with T.block(): + with T.sblock(): T.where(i + 3 < 16) T.reads(A[tx, i + 3]) T.writes(B[(i + 3) % 4, tx, 0]) T.attr(0, "async_commit_queue_scope", 0) T.attr(0, "async_scope", 1) B[(i + 3) % 4, tx, 0] = A[tx, i + 3] * T.float32(2) - with T.block(): + with T.sblock(): T.where(i + 3 - 3 < 16) T.reads(B[0:4, tx, 0]) T.writes(C[tx, i - 3 + 3]) with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 3): C[tx, i - 3 + 3] = B[(i - 3 + 3) % 4, tx, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads(B[0:4, tx, 0]) T.writes(C[tx, 13:16]) for i in T.unroll(3): - with T.block(): + with T.sblock(): T.where(i + 16 - 3 < 16) T.reads(B[0:4, tx, 0]) T.writes(C[tx, i - 3 + 16]) @@ -1296,20 +1296,20 @@ def simple_compute( ): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): for i in range(16): - with T.block("compute"): + with T.sblock("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, i]) T.writes(A_shared[tx, 0]) A_shared[tx, 0] = A[tx, i] - with T.block(): + with T.sblock(): T.reads(B[tx, i]) T.writes(B_shared[tx, 0]) B_shared[tx, 0] = B[tx, i] - with T.block(): + with T.sblock(): T.reads(A_shared[tx, 0], B_shared[tx, 0]) T.writes(C[tx, i]) C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] @@ -1317,7 +1317,7 @@ def simple_compute( mod = tvm.IRModule.from_expr(simple_compute.with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) - _, loop = sch.get_loops(sch.get_block("compute")) + _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) sch.annotate(loop, ann_key="software_pipeline_order", ann_val=[0, 2, 1]) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) @@ -1330,16 +1330,16 @@ def ref( C: T.Buffer((16, 16), "float32"), ) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0:16], B[tx, 0:16]) T.writes(C[tx, 0:16]) A_shared = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared") B_shared = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, 0:3], B[tx, 0:3]) T.writes(A_shared[0:3, tx, 0], B_shared[0:3, tx, 0]) for i in T.unroll(3): - with T.block(): + with T.sblock(): T.where(i < 16) T.reads(A[tx, i], B[tx, i]) T.writes(A_shared[i % 4, tx, 0], B_shared[i % 4, tx, 0]) @@ -1348,18 +1348,18 @@ def ref( A_shared[i % 4, tx, 0] = A[tx, i] with T.attr(0, "async_scope", 1): B_shared[i % 4, tx, 0] = B[tx, i] - with T.block(): + with T.sblock(): T.reads(A[tx, 3:16], A_shared[0:4, tx, 0], B_shared[0:4, tx, 0], B[tx, 3:16]) T.writes(A_shared[0:4, tx, 0], C[tx, 0:13], B_shared[0:4, tx, 0]) for i in T.serial(13): - with T.block(): + with T.sblock(): T.where(i + 3 < 16) T.reads(A[tx, i + 3]) T.writes(A_shared[(i + 3) % 4, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): A_shared[(i + 3) % 4, tx, 0] = A[tx, i + 3] - with T.block(): + with T.sblock(): T.where(i + 3 - 3 < 16) T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0]) T.writes(C[tx, i - 3 + 3]) @@ -1369,18 +1369,18 @@ def ref( A_shared[(i - 3 + 3) % 4, tx, 0] + B_shared[(i - 3 + 3) % 4, tx, 0] ) - with T.block(): + with T.sblock(): T.where(i + 3 < 16) T.reads(B[tx, i + 3]) T.writes(B_shared[(i + 3) % 4, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B_shared[(i + 3) % 4, tx, 0] = B[tx, i + 3] - with T.block(): + with T.sblock(): T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0]) T.writes(C[tx, 13:16]) for i in T.unroll(3): - with T.block(): + with T.sblock(): T.where(i + 16 - 3 < 16) T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0]) T.writes(C[tx, i - 3 + 16]) @@ -1398,7 +1398,7 @@ def test_three_stage_compute_two_stage_async(): mod = tvm.IRModule.from_expr(three_stage_compute.with_attr("global_symbol", "main")) sch = tvm.tir.Schedule(mod) - _, loop = sch.get_loops(sch.get_block("compute")) + _, loop = sch.get_loops(sch.get_sblock("compute")) sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0, 1]) mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) @@ -1406,23 +1406,23 @@ def test_three_stage_compute_two_stage_async(): @T.prim_func def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> None: for tx in T.thread_binding(16, thread="threadIdx.x"): - with T.block(): + with T.sblock(): T.reads(A[tx, 0:16]) T.writes(D[tx, 0:16]) B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") - with T.block(): + with T.sblock(): T.reads(A[tx, 0:2], B[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0]) for i in T.unroll(2): - with T.block(): + with T.sblock(): T.where(i < 16) T.reads(A[tx, i]) T.writes(B[i % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[i % 2, tx, 0] = A[tx, i] * T.float32(2) - with T.block(): + with T.sblock(): T.where(i == 1 and i - 1 < 16) T.reads(B[(i - 1) % 2, tx, 0]) T.writes(C[(i - 1) % 2, tx, 0]) @@ -1433,18 +1433,18 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N C[(i - 1) % 2, tx, 0] = B[ (i - 1) % 2, tx, 0 ] + T.float32(2) - with T.block(): + with T.sblock(): T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) for i in T.serial(14): - with T.block(): + with T.sblock(): T.where(i + 2 < 16) T.reads(A[tx, i + 2]) T.writes(B[(i + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2) - with T.block(): + with T.sblock(): T.where(i + 2 - 1 < 16) T.reads(B[(i - 1 + 2) % 2, tx, 0]) T.writes(C[(i - 1 + 2) % 2, tx, 0]) @@ -1455,18 +1455,18 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N C[(i - 1 + 2) % 2, tx, 0] = B[ (i - 1 + 2) % 2, tx, 0 ] + T.float32(2) - with T.block(): + with T.sblock(): T.where(i + 2 - 2 < 16) T.reads(C[0:2, tx, 0]) T.writes(D[tx, i - 2 + 2]) with T.attr(0, "async_wait_queue_scope", 1): with T.attr(0, "async_wait_inflight_count", 1): D[tx, i - 2 + 2] = C[(i - 2 + 2) % 2, tx, 0] + T.float32(1) - with T.block(): + with T.sblock(): T.reads(B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(C[0:2, tx, 0], D[tx, 14:16]) for i in T.unroll(2): - with T.block(): + with T.sblock(): T.where(i + 16 - 1 < 16) T.reads(B[(i - 1 + 16) % 2, tx, 0]) T.writes(C[(i - 1 + 16) % 2, tx, 0]) @@ -1477,7 +1477,7 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N C[(i - 1 + 16) % 2, tx, 0] = B[ (i - 1 + 16) % 2, tx, 0 ] + T.float32(2) - with T.block(): + with T.sblock(): T.where(i + 16 - 2 < 16) T.reads(C[0:2, tx, 0]) T.writes(D[tx, i - 2 + 16]) @@ -1549,7 +1549,7 @@ def build_and_run(sch): def test_async_pipelined_mma_gemm_simple(): sch = get_mma_schedule() - k0 = sch.get_loops(sch.get_block("C_o_update"))[3] + k0 = sch.get_loops(sch.get_sblock("C_o_update"))[3] sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) @@ -1590,8 +1590,8 @@ def test_async_pipelined_mma_gemm_simple(): def test_async_nested_pipeline_mma_gemm_ideal_annotation(): sch = get_mma_schedule() - k0 = sch.get_loops(sch.get_block("C_o_update"))[3] - k1 = sch.get_loops(sch.get_block("C_o_update"))[4] + k0 = sch.get_loops(sch.get_sblock("C_o_update"))[3] + k1 = sch.get_loops(sch.get_sblock("C_o_update"))[4] sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 2, 3, 3]) sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 3, 2, 4]) @@ -1644,63 +1644,63 @@ def before(A: T.Buffer((2,), "float32"), E: T.Buffer((2,), "float32")): "software_pipeline_order": [0, 1, 2, 3], }, ): - with T.block("compute"): + with T.sblock("compute"): B = T.alloc_buffer((1), dtype="float32", scope="shared") C = T.alloc_buffer((1), dtype="float32", scope="shared") D = T.alloc_buffer((1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): B[0] = A[i] * T.float32(2) - with T.block(): + with T.sblock(): C[0] = B[0] + T.float32(3) - with T.block(): + with T.sblock(): D[0] = C[0] + T.float32(4) - with T.block(): + with T.sblock(): E[i] = D[0] + T.float32(5) @T.prim_func def after(A: T.Buffer((2,), "float32"), E: T.Buffer((2,), "float32")): - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block(""): + with T.sblock(""): T.reads(A[0:3]) T.writes(E[0:2]) B = T.alloc_buffer((2, 1), scope="shared") C = T.alloc_buffer((2, 1), scope="shared") D = T.alloc_buffer((2, 1), scope="shared") - with T.block(""): + with T.sblock(""): T.reads(A[0:3], B[0:2, 0], C[0:2, 0]) T.writes(B[0:2, 0], C[0:2, 0], D[0:2, 0]) for i in T.unroll(3): - with T.block(""): + with T.sblock(""): T.where(i < 2) T.reads(A[i]) T.writes(B[0:2, 0]) B[i % 2, 0] = A[i] * T.float32(2.0) - with T.block(""): + with T.sblock(""): T.where(1 <= i) T.reads(B[0:2, 0]) T.writes(C[0:2, 0]) C[(i + 1) % 2, 0] = B[(i + 1) % 2, 0] + T.float32(3.0) - with T.block(""): + with T.sblock(""): T.where(i == 2) T.reads(C[0:2, 0]) T.writes(D[0:2, 0]) D[i % 2, 0] = C[i % 2, 0] + T.float32(4.0) - with T.block(""): + with T.sblock(""): T.reads() T.writes() T.evaluate(0) - with T.block(""): + with T.sblock(""): T.reads(C[0:2, 0], D[0:2, 0]) T.writes(D[0:2, 0], E[0:2]) for i in T.unroll(2): - with T.block(""): + with T.sblock(""): T.where(i < 1) T.reads(C[0:2, 0]) T.writes(D[0:2, 0]) D[(i + 1) % 2, 0] = C[(i + 1) % 2, 0] + T.float32(4.0) - with T.block(""): + with T.sblock(""): T.reads(D[0:2, 0]) T.writes(E[i]) E[i] = D[i, 0] + T.float32(5.0) @@ -1722,17 +1722,17 @@ def before(a: T.handle, b: T.handle): "software_pipeline_order": [0, 1, 2, 3], }, ): - with T.block("compute"): + with T.sblock("compute"): B = T.alloc_buffer((1), dtype="float32", scope="shared") C = T.alloc_buffer((1), dtype="float32", scope="shared") D = T.alloc_buffer((1), dtype="float32", scope="shared") - with T.block(): + with T.sblock(): B[0] = A[i] * T.float32(2) - with T.block(): + with T.sblock(): C[0] = B[0] + T.float32(3) - with T.block(): + with T.sblock(): D[0] = C[0] + T.float32(4) - with T.block(): + with T.sblock(): E[i] = D[0] + T.float32(5) @T.prim_func @@ -1740,69 +1740,69 @@ def after(a: T.handle, b: T.handle): K = T.int32() A = T.match_buffer(a, [K], "float32") E = T.match_buffer(b, [K], "float32") - with T.block("root"): + with T.sblock("root"): T.reads() T.writes() - with T.block(""): + with T.sblock(""): T.reads(A[0 : T.max(3, K)]) T.writes(E[T.min(0, K - 3) : T.min(0, K - 3) + T.max(K, 3)]) B = T.alloc_buffer((2, 1), scope="shared") C = T.alloc_buffer((2, 1), scope="shared") D = T.alloc_buffer((2, 1), scope="shared") - with T.block(""): + with T.sblock(""): T.reads(A[0:3], B[0:2, 0], C[0:2, 0]) T.writes(B[0:2, 0], C[0:2, 0], D[0:2, 0]) for i in T.unroll(3): - with T.block(""): + with T.sblock(""): T.where(i < K) T.reads(A[i]) T.writes(B[0:2, 0]) B[i % 2, 0] = A[i] * T.float32(2.0) - with T.block(""): + with T.sblock(""): T.where(1 <= i and i <= K) T.reads(B[0:2, 0]) T.writes(C[0:2, 0]) C[(i + 1) % 2, 0] = B[(i + 1) % 2, 0] + T.float32(3.0) - with T.block(""): + with T.sblock(""): T.where(i == 2 and i < K + 2) T.reads(C[0:2, 0]) T.writes(D[0:2, 0]) D[i % 2, 0] = C[i % 2, 0] + T.float32(4.0) - with T.block(""): + with T.sblock(""): T.reads(A[3 : 3 + (K - 3)], B[0:2, 0], C[0:2, 0], D[0:2, 0]) T.writes(B[0:2, 0], C[0:2, 0], D[0:2, 0], E[0 : K - 3]) for i in range(K - 3): - with T.block(""): + with T.sblock(""): T.reads(A[i + 3]) T.writes(B[0:2, 0]) B[(i + 1) % 2, 0] = A[i + 3] * T.float32(2.0) - with T.block(""): + with T.sblock(""): T.reads(B[0:2, 0]) T.writes(C[0:2, 0]) C[i % 2, 0] = B[i % 2, 0] + T.float32(3.0) - with T.block(""): + with T.sblock(""): T.reads(C[0:2, 0]) T.writes(D[0:2, 0]) D[(i + 1) % 2, 0] = C[(i + 1) % 2, 0] + T.float32(4.0) - with T.block(""): + with T.sblock(""): T.reads(D[0:2, 0]) T.writes(E[i]) E[i] = D[i % 2, 0] + T.float32(5.0) - with T.block(""): + with T.sblock(""): T.reads(B[0:2, 0], C[0:2, 0], D[0:2, 0]) T.writes(C[0:2, 0], D[0:2, 0], E[K - 3 : K - 3 + 3]) for i in T.unroll(3): - with T.block(""): + with T.sblock(""): T.where(1 <= i + K and i + K == K and 3 <= i + K) T.reads(B[0:2, 0]) T.writes(C[0:2, 0]) C[(i + K + 1) % 2, 0] = B[(i + K + 1) % 2, 0] + T.float32(3.0) - with T.block(""): + with T.sblock(""): T.where(2 <= i + K and i < 2 and 3 <= i + K) T.reads(C[0:2, 0]) T.writes(D[0:2, 0]) D[(i + K) % 2, 0] = C[(i + K) % 2, 0] + T.float32(4.0) - with T.block(""): + with T.sblock(""): T.where(3 <= i + K and 3 <= i + K) T.reads(D[0:2, 0]) T.writes(E[i + K - 3]) diff --git a/tests/python/tir-transform/test_tir_transform_lift_thread_binding.py b/tests/python/tir-transform/test_tir_transform_lift_thread_binding.py index 84868ae2ed16..eacd3bbcfe4c 100644 --- a/tests/python/tir-transform/test_tir_transform_lift_thread_binding.py +++ b/tests/python/tir-transform/test_tir_transform_lift_thread_binding.py @@ -28,49 +28,49 @@ def before(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (32, n, 128)) C = T.match_buffer(c, (32, 1, n)) for ax0_ax1_fused in T.thread_binding(n * 32, thread="blockIdx.x"): - with T.block(""): + with T.sblock(""): T.reads(A[ax0_ax1_fused // n, 0, 0:256], B[ax0_ax1_fused // n, ax0_ax1_fused % n, 0:256]) T.writes(C[ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) D_local = T.alloc_buffer((32, 1, n), scope="local") D_rf_local = T.alloc_buffer((256, 32, 1, n), scope="local") for ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): T.reads() T.writes(D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = T.float32(0) for ax2_fused_0 in range(1): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): T.where(ax2_fused_0 * 256 + ax2_fused_1 < 128) T.reads(D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n], A[ax0_ax1_fused // n, 0, ax2_fused_0 * 256 + ax2_fused_1], B[ax0_ax1_fused // n, ax0_ax1_fused % n, ax2_fused_0 * 256 + ax2_fused_1]) T.writes(D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] + A[ax0_ax1_fused // n, 0, ax2_fused_0 * 256 + ax2_fused_1] * B[ax0_ax1_fused // n, ax0_ax1_fused % n, ax2_fused_0 * 256 + ax2_fused_1] for ax1_ax2_fused in range(1): for ax0_fused in T.thread_binding(256, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): T.reads(D_rf_local[ax0_fused, ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) T.writes(D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) cross_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local") in_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local") - with T.block("NT_matmul_in_thread_init"): + with T.sblock("NT_matmul_in_thread_init"): T.reads() T.writes(in_thread_D_local[0]) in_thread_D_local[0] = T.float32(0) - with T.block("NT_matmul_in_thread"): + with T.sblock("NT_matmul_in_thread"): T.where(0 <= ax0_ax1_fused // n and ax0_ax1_fused // n < 32 and 0 <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) T.reads(D_rf_local[ax0_fused, ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) T.writes(in_thread_D_local[0]) in_thread_D_local[0] = in_thread_D_local[0] + D_rf_local[ax0_fused, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] - with T.block("NT_matmul_cross_thread"): + with T.sblock("NT_matmul_cross_thread"): T.reads(in_thread_D_local[0]) T.writes(cross_thread_D_local[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), in_thread_D_local[0], T.bool(True), cross_thread_D_local[0], ax0_fused) - with T.block("NT_matmul_write_back"): + with T.sblock("NT_matmul_write_back"): T.where(ax0_fused == 0) T.reads(cross_thread_D_local[0]) T.writes(D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = cross_thread_D_local[0] - with T.block("T_divide"): + with T.sblock("T_divide"): T.where(0 <= ax0_ax1_fused // n and ax0_ax1_fused // n < 32 and 0 <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) T.reads(D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) T.writes(C[ax0_ax1_fused // n, 0, ax0_ax1_fused % n]) @@ -81,50 +81,50 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle): n = T.int32() B = T.match_buffer(b, (32, n, 128)) C = T.match_buffer(c, (32, 1, n)) - # with T.block("root"): + # with T.sblock("root"): for blockIdx_x in T.thread_binding(n * 32, thread="blockIdx.x"): for threadIdx_x in T.thread_binding(256, thread="threadIdx.x"): - with T.block(""): + with T.sblock(""): T.reads(A[blockIdx_x // n, 0, 0:256], B[blockIdx_x // n, blockIdx_x % n, 0:256]) T.writes(C[blockIdx_x // n, 0, blockIdx_x % n]) D_local = T.alloc_buffer((32, 1, n), scope="local") D_rf_local = T.alloc_buffer((256, 32, 1, n), scope="local") - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): T.reads() T.writes(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n]) D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] = T.float32(0) for ax2_fused_0 in range(1): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): T.where(ax2_fused_0 * 256 + threadIdx_x < 128) T.reads(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n], A[blockIdx_x // n, 0, ax2_fused_0 * 256 + threadIdx_x], B[blockIdx_x // n, blockIdx_x % n, ax2_fused_0 * 256 + threadIdx_x]) T.writes(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n]) D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] = D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] + A[blockIdx_x // n, 0, ax2_fused_0 * 256 + threadIdx_x] * B[blockIdx_x // n, blockIdx_x % n, ax2_fused_0 * 256 + threadIdx_x] for ax1_ax2_fused in range(1): - with T.block(""): + with T.sblock(""): T.reads(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n]) T.writes(D_local[blockIdx_x // n, 0, blockIdx_x % n]) cross_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local") in_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local") - with T.block("NT_matmul_in_thread_init"): + with T.sblock("NT_matmul_in_thread_init"): T.reads() T.writes(in_thread_D_local[0]) in_thread_D_local[0] = T.float32(0) - with T.block("NT_matmul_in_thread"): + with T.sblock("NT_matmul_in_thread"): T.where(0 <= blockIdx_x // n and blockIdx_x // n < 32 and 0 <= blockIdx_x % n and blockIdx_x % n < n) T.reads(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n]) T.writes(in_thread_D_local[0]) in_thread_D_local[0] = in_thread_D_local[0] + D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] - with T.block("NT_matmul_cross_thread"): + with T.sblock("NT_matmul_cross_thread"): T.reads(in_thread_D_local[0]) T.writes(cross_thread_D_local[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), in_thread_D_local[0], T.bool(True), cross_thread_D_local[0], threadIdx_x) - with T.block("NT_matmul_write_back"): + with T.sblock("NT_matmul_write_back"): T.where(threadIdx_x == 0) T.reads(cross_thread_D_local[0]) T.writes(D_local[blockIdx_x // n, 0, blockIdx_x % n]) D_local[blockIdx_x // n, 0, blockIdx_x % n] = cross_thread_D_local[0] - with T.block("T_divide"): + with T.sblock("T_divide"): T.where(0 <= blockIdx_x // n and blockIdx_x // n < 32 and 0 <= blockIdx_x % n and blockIdx_x % n < n) T.reads(D_local[blockIdx_x // n, 0, blockIdx_x % n]) T.writes(C[blockIdx_x // n, 0, blockIdx_x % n]) diff --git a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py index 18e16513f481..8137ae674a8a 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py @@ -46,7 +46,7 @@ def loop_split(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i, ko in T.grid(128, 4): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) vk = T.axis.R(128, ko * 32 + ki) T.reads([A[vi, vk]]) @@ -64,18 +64,18 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.serial(0, 128): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("B_in_thread_init"): + with T.sblock("B_in_thread_init"): T.reads([]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = T.float32(0) for ko in T.serial(0, 4): - with T.block("B_normal_reduction"): + with T.sblock("B_normal_reduction"): vi = T.axis.S(128, i) vk = T.axis.R(128, ko * 32 + ki) T.reads([A[vi, vk]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): T.reads([normal_reduce_temp0[0]]) T.writes([reduce_temp0[0]]) T.attr( @@ -93,7 +93,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: dtype="handle", ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.S(128, i) T.where(ki == 0) T.reads([reduce_temp0[0]]) @@ -107,7 +107,7 @@ def no_normal_reduction(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([B[vi]]) @@ -124,7 +124,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([reduce_temp0[0]]) @@ -138,7 +138,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: T.uint32(1), A[vi, vk], True, reduce_temp0[0], k, dtype="handle" ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.spatial(128, i) T.where(k == 0) T.reads([reduce_temp0[0]]) @@ -153,7 +153,7 @@ def two_bound_loops(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for ko in T.thread_binding(0, 4, thread="threadIdx.x"): for ki in T.thread_binding(0, 32, thread="threadIdx.y"): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i) vk = T.axis.reduce(128, ko * 32 + ki) T.reads([A[vi, vk]]) @@ -172,7 +172,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for ko in T.thread_binding(0, 4, thread="threadIdx.x"): for ki in T.thread_binding(0, 32, thread="threadIdx.y"): - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): vi = T.axis.spatial(128, i) vk = T.axis.reduce(128, ko * 32 + ki) T.reads([A[vi, vk]]) @@ -187,7 +187,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: T.uint32(1), A[vi, vk], True, reduce_temp0[0], ko, ki, dtype="handle" ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.spatial(128, i) T.where(ko == 0 and ki == 0) T.reads([reduce_temp0[0]]) @@ -203,7 +203,7 @@ def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: for i in T.thread_binding(0, 16, thread="blockIdx.x"): for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): for k0i0, k1 in T.grid(4, 16): - with T.block("B_rf"): + with T.sblock("B_rf"): vk0 = T.axis.spatial(16, k0o * 4 + k0i0) vi, vk1 = T.axis.remap("SR", [i, k1]) T.reads([A[vi, vk0, vk1]]) @@ -212,7 +212,7 @@ def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: B_rf_local[vk0, vi] = T.float32(0) B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1] for k0i1 in T.serial(0, 4): - with T.block("B"): + with T.sblock("B"): vk0 = T.axis.reduce(16, k0o * 4 + k0i1) vi = T.axis.spatial(16, i) T.reads([B_rf_local[vk0, vi]]) @@ -231,12 +231,12 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.thread_binding(0, 16, thread="blockIdx.x"): for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block("B_in_thread_init"): + with T.sblock("B_in_thread_init"): T.reads([]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = T.float32(0) for k0i0, k1 in T.grid(4, 16): - with T.block("B_rf"): + with T.sblock("B_rf"): vk0 = T.axis.spatial(16, k0o * 4 + k0i0) vi, vk1 = T.axis.remap("SR", [i, k1]) T.reads([A[vi, vk0, vk1]]) @@ -245,13 +245,13 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No B_rf_local[vk0, vi] = T.float32(0) B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1] for k0i1 in T.serial(0, 4): - with T.block("B_normal_reduction"): + with T.sblock("B_normal_reduction"): vk0 = T.axis.reduce(16, k0o * 4 + k0i1) vi = T.axis.spatial(16, i) T.reads([B_rf_local[vk0, vi]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi] - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): T.reads([normal_reduce_temp0[0]]) T.writes([reduce_temp0[0]]) T.attr( @@ -269,7 +269,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No dtype="handle", ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.spatial(16, i) T.where(k0o == 0) T.reads([reduce_temp0[0]]) @@ -283,7 +283,7 @@ def with_block_predicate(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i, ko in T.grid(128, 4): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi = T.axis.spatial(128, i) vk = T.axis.reduce(120, ko * 32 + ki) T.where(ko * 32 + ki < 120) @@ -302,19 +302,19 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.serial(0, 128): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("B_in_thread_init"): + with T.sblock("B_in_thread_init"): T.reads([]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = T.float32(0) for ko in T.serial(0, 4): - with T.block("B_normal_reduction"): + with T.sblock("B_normal_reduction"): vi = T.axis.spatial(128, i) vk = T.axis.reduce(120, ko * 32 + ki) T.where(ko * 32 + ki < 120) T.reads([A[vi, vk]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): T.reads([normal_reduce_temp0[0]]) T.writes([reduce_temp0[0]]) T.attr( @@ -332,7 +332,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: dtype="handle", ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.spatial(128, i) T.where(ki == 0) T.reads([reduce_temp0[0]]) @@ -349,7 +349,7 @@ def single_reduction_loop_with_block_predicate( for i0 in T.serial(256): for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) @@ -362,7 +362,7 @@ def single_reduction_loop_with_block_predicate( ) for ax0, ax1_0 in T.grid(1, 1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) @@ -375,7 +375,7 @@ def single_reduction_loop_with_block_predicate( ) for i1_0 in T.serial(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_3 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) T.where(i1_0 * 512 + i1_1 < 256) @@ -383,7 +383,7 @@ def single_reduction_loop_with_block_predicate( A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3] ) T.writes(T_softmax_norm[i0_3, i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_3, i1] = ( T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32") / T_softmax_expsum_shared[i0_3] @@ -403,19 +403,19 @@ def lowered_single_reduction_loop_with_block_predicate( for i0 in T.serial(256): for ax0 in T.serial(1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_maxelem_in_thread_init"): + with T.sblock("T_softmax_maxelem_in_thread_init"): T.reads() T.writes(in_thread_0[0]) in_thread_0[0] = T.float32(-3.4028234663852886e38) for ax1_0 in T.serial(1): - with T.block("T_softmax_maxelem_in_thread"): + with T.sblock("T_softmax_maxelem_in_thread"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_1 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) T.reads(A[i0_1, k]) T.writes(in_thread_0[0]) in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k]) - with T.block("T_softmax_maxelem_cross_thread"): + with T.sblock("T_softmax_maxelem_cross_thread"): T.reads(in_thread_0[0]) T.writes(cross_thread_0[0]) T.attr( @@ -435,7 +435,7 @@ def lowered_single_reduction_loop_with_block_predicate( dtype="handle", ) ) - with T.block("T_softmax_maxelem_write_back"): + with T.sblock("T_softmax_maxelem_write_back"): i0_2 = T.axis.spatial(256, i0 + ax0) T.where(ax1_1 == 0) T.reads(cross_thread_0[0]) @@ -443,12 +443,12 @@ def lowered_single_reduction_loop_with_block_predicate( T_softmax_maxelem_shared[i0_2] = cross_thread_0[0] for ax0 in T.serial(1): for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_expsum_in_thread_init"): + with T.sblock("T_softmax_expsum_in_thread_init"): T.reads() T.writes(in_thread_1[0]) in_thread_1[0] = T.float32(0) for ax1_0 in T.serial(1): - with T.block("T_softmax_expsum_in_thread"): + with T.sblock("T_softmax_expsum_in_thread"): T.where(ax1_0 * 512 + ax1_1 < 256) i0_3 = T.axis.spatial(256, i0 + ax0) k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) @@ -457,7 +457,7 @@ def lowered_single_reduction_loop_with_block_predicate( in_thread_1[0] = in_thread_1[0] + T.exp( A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32" ) - with T.block("T_softmax_expsum_cross_thread"): + with T.sblock("T_softmax_expsum_cross_thread"): T.reads(in_thread_1[0]) T.writes(cross_thread_1[0]) T.attr( @@ -475,7 +475,7 @@ def lowered_single_reduction_loop_with_block_predicate( dtype="handle", ) ) - with T.block("T_softmax_expsum_write_back"): + with T.sblock("T_softmax_expsum_write_back"): i0_4 = T.axis.spatial(256, i0 + ax0) T.where(ax1_1 == 0) T.reads(cross_thread_1[0]) @@ -483,7 +483,7 @@ def lowered_single_reduction_loop_with_block_predicate( T_softmax_expsum_shared[i0_4] = cross_thread_1[0] for i1_0 in T.serial(1): for i1_1 in T.thread_binding(512, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_5 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) T.where(i1_0 * 512 + i1_1 < 256) @@ -491,7 +491,7 @@ def lowered_single_reduction_loop_with_block_predicate( A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5] ) T.writes(T_softmax_norm[i0_5, i1]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_5, i1] = ( T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32") / T_softmax_expsum_shared[i0_5] @@ -515,7 +515,7 @@ def spatial_reduction_with_shared_prefetch( for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.serial(4): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial( 128, ax0_0_ax1_0_fused // 16 * 8 @@ -545,7 +545,7 @@ def spatial_reduction_with_shared_prefetch( for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"): for ax0_ax1_fused_3 in T.serial(4): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial( 128, ax0_0_ax1_0_fused % 16 * 8 @@ -572,7 +572,7 @@ def spatial_reduction_with_shared_prefetch( T.writes(B_shared[v0, v1]) B_shared[v0, v1] = B[v0, v1] for ax2_1_0 in range(192): - with T.block("B"): + with T.sblock("B"): v0 = T.axis.spatial( 128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8 ) @@ -585,7 +585,7 @@ def spatial_reduction_with_shared_prefetch( with T.init(): C_local[v0, v1] = T.float32(0) C_local[v0, v1] = C_local[v0, v1] + A_shared[v0, v2] * B_shared[v1, v2] - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8) v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8) T.reads(C_local[v0, v1]) @@ -607,7 +607,7 @@ def lowered_spatial_reduction_with_shared_prefetch( for ax0_0_ax1_0_fused in T.thread_binding(256, thread="blockIdx.x"): for ax0_1_ax1_1_fused in T.thread_binding(64, thread="threadIdx.y"): for ax2_1_1_fused in T.thread_binding(2, thread="threadIdx.x"): - with T.block("B_in_thread_init"): + with T.sblock("B_in_thread_init"): T.reads() T.writes(in_thread_C_local[0]) in_thread_C_local[0] = T.float32(0) @@ -616,7 +616,7 @@ def lowered_spatial_reduction_with_shared_prefetch( for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"): for ax0_ax1_fused_3 in range(4): - with T.block("A_shared"): + with T.sblock("A_shared"): v0 = T.axis.spatial( 128, ax0_0_ax1_0_fused // 16 * 8 @@ -646,7 +646,7 @@ def lowered_spatial_reduction_with_shared_prefetch( for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"): for ax0_ax1_fused_3 in range(4): - with T.block("B_shared"): + with T.sblock("B_shared"): v0 = T.axis.spatial( 128, ax0_0_ax1_0_fused % 16 * 8 @@ -673,7 +673,7 @@ def lowered_spatial_reduction_with_shared_prefetch( T.writes(B_shared[v0, v1]) B_shared[v0, v1] = B[v0, v1] for ax2_1_0 in range(192): - with T.block("B_in_thread"): + with T.sblock("B_in_thread"): v0 = T.axis.spatial( 128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8 ) @@ -686,7 +686,7 @@ def lowered_spatial_reduction_with_shared_prefetch( in_thread_C_local[0] = ( in_thread_C_local[0] + A_shared[v0, v2] * B_shared[v1, v2] ) - with T.block("B_cross_thread"): + with T.sblock("B_cross_thread"): T.reads(in_thread_C_local[0]) T.writes(cross_thread_C_local[0]) T.attr( @@ -701,14 +701,14 @@ def lowered_spatial_reduction_with_shared_prefetch( cross_thread_C_local[0], ax2_1_1_fused, ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8) v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8) T.reads(cross_thread_C_local[0]) T.writes(C_local[v0, v1]) C_local[v0, v1] = cross_thread_C_local[0] for tx in T.thread_binding(2, thread="threadIdx.x"): - with T.block("C_local"): + with T.sblock("C_local"): v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8) v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8) T.where(tx == 0) @@ -723,7 +723,7 @@ def spatial_reduction_loop_predicate(A: T.Buffer((2, 32), "float32"), B: T.Buffe for i_1 in T.thread_binding(16, thread="threadIdx.y"): for k_0 in range(1): for k_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial(2, i_0 * 16 + i_1) vk = T.axis.reduce(32, k_0 * 64 + k_1) T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32) @@ -743,19 +743,19 @@ def lowered_reduction_spatial_loop_predicate( for i_0 in range(1): for i_1 in T.thread_binding(16, thread="threadIdx.y"): for k_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("block_in_thread_init"): + with T.sblock("block_in_thread_init"): T.reads() T.writes(in_thread_B[0]) in_thread_B[0] = T.float32(0) for k_0 in range(1): - with T.block("block_in_thread"): + with T.sblock("block_in_thread"): vi = T.axis.spatial(2, i_0 * 16 + i_1) vk = T.axis.reduce(32, k_0 * 64 + k_1) T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32) T.reads(A[vi, vk]) T.writes(in_thread_B[0]) in_thread_B[0] = in_thread_B[0] + A[vi, vk] - with T.block("block_cross_thread"): + with T.sblock("block_cross_thread"): T.reads(in_thread_B[0]) T.writes(cross_thread_B[0]) T.attr( @@ -767,7 +767,7 @@ def lowered_reduction_spatial_loop_predicate( T.uint32(1), in_thread_B[0], T.bool(True), cross_thread_B[0], k_1 ) k_0 = T.int32() - with T.block("block_write_back"): + with T.sblock("block_write_back"): vi = T.axis.spatial(2, i_0 * 16 + i_1) T.where(i_0 * 16 + i_1 < 2 and k_1 == 0) T.reads(cross_thread_B[0]) @@ -782,9 +782,9 @@ def single_reduction_loop_with_tensorize( output: T.Buffer((1, 16, 7, 7, 32), "int32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for i1, i2, i3, i4, i5 in T.grid(16, 4, 98, 2, 32): - with T.block("compute_o"): + with T.sblock("compute_o"): n = T.axis.spatial(1, 0) oc_chunk = T.axis.spatial(16, i1) oh = T.axis.spatial(7, (i2 * 6272 + i3 * 64 + i4 * 32 + i5) // 3584) @@ -800,12 +800,12 @@ def single_reduction_loop_with_tensorize( T.writes(output[n, oc_chunk, oh, ow, 0:32]) with T.init(): for x in T.serial(32): - with T.block("compute_init"): + with T.sblock("compute_init"): oc_block_i_init = T.axis.spatial(32, x) T.reads() T.writes(output[n, oc_chunk, oh, ow, oc_block_i_init]) output[n, oc_chunk, oh, ow, oc_block_i_init] = 0 - with T.block("compute_o"): + with T.sblock("compute_o"): T.reads( output[n, oc_chunk, oh, ow, 0:32], input_A[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], @@ -843,23 +843,23 @@ def nested_reduction_loop_with_inner_match_buffers( out: T.Buffer((4, 4), "int32"), ) -> None: # body - # with T.block("root") + # with T.sblock("root") for y in T.serial(4): - with T.block("C"): + with T.sblock("C"): yi = T.axis.spatial(4, y) T.reads(in0[yi, 0:16], in1[yi, 0:16]) T.writes(out[yi, 0:4]) for x in T.serial(4): - with T.block("C"): + with T.sblock("C"): xr = T.axis.reduce(4, x) with T.init(): for i in T.serial(4): - with T.block("C_init"): + with T.sblock("C_init"): ii = T.axis.spatial(4, i) T.reads() T.writes(out[yi, ii]) out[yi, ii] = 0 - with T.block("C"): + with T.sblock("C"): T.reads( out[yi, xr], in0[yi, yi * 4 + xr : yi * 4 + xr + 4], @@ -892,7 +892,7 @@ def reducer_max(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([B[vi]]) @@ -909,7 +909,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None: reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([reduce_temp0[0]]) @@ -923,7 +923,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None: T.uint32(1), A[vi, vk], True, reduce_temp0[0], k, dtype="handle" ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.spatial(128, i) T.where(k == 0) T.reads([reduce_temp0[0]]) @@ -936,7 +936,7 @@ def zero_rank_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128], dtype="float32") B = T.match_buffer(b, [], dtype="float32") for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vk = T.axis.reduce(128, k) T.reads([A[vk]]) T.writes([B[()]]) @@ -952,7 +952,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [], dtype="float32") reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): vk = T.axis.reduce(128, k) T.reads([A[vk]]) T.writes([reduce_temp0[0]]) @@ -964,7 +964,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: T.evaluate( T.tvm_thread_allreduce(T.uint32(1), A[vk], True, reduce_temp0[0], k, dtype="handle") ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): T.reads([reduce_temp0[0]]) T.writes([B[()]]) T.where(k == 0) @@ -978,7 +978,7 @@ def multiple_bufferstore(a: T.handle, b: T.handle) -> None: C = T.alloc_buffer([], dtype="float32") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk], B[vi], C[()]]) T.writes([B[vi], C[()]]) @@ -994,7 +994,7 @@ def reduction_loop_not_deepest(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for k in T.thread_binding(0, 128, thread="threadIdx.x"): for i in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([B[vi]]) @@ -1009,7 +1009,7 @@ def reduction_loop_bound_to_blockidx(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="blockIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([B[vi]]) @@ -1024,7 +1024,7 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128], dtype="float32") for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes( @@ -1046,7 +1046,7 @@ def invalid_reducer(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i in T.serial(0, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vk = T.axis.remap("SR", [i, k]) T.reads([A[vi, vk]]) T.writes([B[vi]]) @@ -1064,7 +1064,7 @@ def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: for i0 in T.thread_binding(0, 256, thread="blockIdx.x"): for ax0_0 in T.serial(0, 8): for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): + with T.sblock("T_softmax_maxelem"): i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) T.reads([A[i0_1, k]]) @@ -1076,7 +1076,7 @@ def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ) for ax0_0 in T.serial(0, 8): for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_expsum"): + with T.sblock("T_softmax_expsum"): i0_2 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) T.reads( @@ -1093,7 +1093,7 @@ def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ) for i1_0 in T.serial(0, 8): for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_3 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) T.reads( @@ -1104,7 +1104,7 @@ def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ] ) T.writes([T_softmax_norm[i0_3, i1]]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_3, i1] = ( T.exp( A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], @@ -1126,18 +1126,18 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i0 in T.thread_binding(0, 256, thread="blockIdx.x"): for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_maxelem_normal_reduction_init"): + with T.sblock("T_softmax_maxelem_normal_reduction_init"): T.reads([]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = T.min_value("float32") for ax0_0 in T.serial(0, 8): - with T.block("T_softmax_maxelem_normal_reduction"): + with T.sblock("T_softmax_maxelem_normal_reduction"): i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) T.reads([A[i0_1, k]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k]) - with T.block("T_softmax_maxelem_cross_thread_reduction"): + with T.sblock("T_softmax_maxelem_cross_thread_reduction"): T.reads([normal_reduce_temp0[0]]) T.writes([reduce_temp0[0]]) T.attr( @@ -1155,19 +1155,19 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: dtype="handle", ) ) - with T.block("T_softmax_maxelem_write_back"): + with T.sblock("T_softmax_maxelem_write_back"): i0_2 = T.axis.spatial(256, i0) T.where(ax0_1 == 0) T.reads([reduce_temp0[0]]) T.writes([T_softmax_maxelem_shared[i0_2]]) T_softmax_maxelem_shared[i0_2] = reduce_temp0[0] for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_expsum_normal_reduction_init"): + with T.sblock("T_softmax_expsum_normal_reduction_init"): T.reads([]) T.writes([normal_reduce_temp1[0]]) normal_reduce_temp1[0] = T.float32(0) for ax0_0 in T.serial(0, 8): - with T.block("T_softmax_expsum_normal_reduction"): + with T.sblock("T_softmax_expsum_normal_reduction"): i0_3 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) T.reads( @@ -1180,7 +1180,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp( A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32" ) - with T.block("T_softmax_expsum_cross_thread_reduction"): + with T.sblock("T_softmax_expsum_cross_thread_reduction"): T.reads([normal_reduce_temp1[0]]) T.writes([reduce_temp1[0]]) T.attr( @@ -1198,7 +1198,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: dtype="handle", ) ) - with T.block("T_softmax_expsum_write_back"): + with T.sblock("T_softmax_expsum_write_back"): i0_4 = T.axis.spatial(256, i0) T.where(ax0_1 == 0) T.reads([reduce_temp1[0]]) @@ -1206,7 +1206,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: T_softmax_expsum_shared[i0_4] = reduce_temp1[0] for i1_0 in T.serial(0, 8): for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_norm"): + with T.sblock("T_softmax_norm"): i0_5 = T.axis.spatial(256, i0) i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) T.reads( @@ -1217,7 +1217,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ] ) T.writes([T_softmax_norm[i0_5, i1]]) - T.block_attr({"axis": 1}) + T.sblock_attr({"axis": 1}) T_softmax_norm[i0_5, i1] = ( T.exp( A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], @@ -1236,7 +1236,7 @@ def argmax_split( ) -> None: for i0, i1_0 in T.grid(128, 4): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("argmax"): + with T.sblock("argmax"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1265,13 +1265,13 @@ def lowered_argmax_split( in_thread_argmax_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i0 in T.serial(128): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("argmax_in_thread_init"): + with T.sblock("argmax_in_thread_init"): T.reads() T.writes(in_thread_argmax_v0[0], in_thread_argmax_v1[0]) in_thread_argmax_v0[0] = -1 in_thread_argmax_v1[0] = T.float32(-3.4028234663852886e38) for i1_0 in T.serial(4): - with T.block("argmax_in_thread"): + with T.sblock("argmax_in_thread"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1284,7 +1284,7 @@ def lowered_argmax_split( ) in_thread_argmax_v0[0] = v_argmax_v0 in_thread_argmax_v1[0] = v_argmax_v1 - with T.block("argmax_cross_thread"): + with T.sblock("argmax_cross_thread"): T.reads(in_thread_argmax_v0[0], in_thread_argmax_v1[0]) T.writes(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0]) T.attr( @@ -1310,7 +1310,7 @@ def lowered_argmax_split( dtype="handle", ) ) - with T.block("argmax_write_back"): + with T.sblock("argmax_write_back"): i = T.axis.spatial(128, i0) T.where(i1_1 == 0) T.reads(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0]) @@ -1328,7 +1328,7 @@ def argmin_split_init_update_reordered( ) -> None: for i0, i1_0 in T.grid(128, 4): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("argmin"): + with T.sblock("argmin"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1357,13 +1357,13 @@ def lowered_argmin_split_init_update_reordered( in_thread_argmin_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i0 in T.serial(128): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("argmin_in_thread_init"): + with T.sblock("argmin_in_thread_init"): T.reads() T.writes(in_thread_argmin_v0[0], in_thread_argmin_v1[0]) in_thread_argmin_v0[0] = -1 in_thread_argmin_v1[0] = T.float32(3.4028234663852886e38) for i1_0 in T.serial(4): - with T.block("argmin_in_thread"): + with T.sblock("argmin_in_thread"): i = T.axis.spatial(128, i0) k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) @@ -1376,7 +1376,7 @@ def lowered_argmin_split_init_update_reordered( ) in_thread_argmin_v1[0] = v_argmin_v1 in_thread_argmin_v0[0] = v_argmin_v0 - with T.block("argmin_cross_thread"): + with T.sblock("argmin_cross_thread"): T.reads(in_thread_argmin_v0[0], in_thread_argmin_v1[0]) T.writes(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0]) T.attr( @@ -1402,7 +1402,7 @@ def lowered_argmin_split_init_update_reordered( dtype="handle", ) ) - with T.block("argmin_write_back"): + with T.sblock("argmin_write_back"): i = T.axis.spatial(128, i0) T.where(i1_1 == 0) T.reads(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0]) @@ -1423,7 +1423,7 @@ def layer_norm_tuple_sum( for i0_fused in T.thread_binding(128, thread="blockIdx.x"): for i1_0 in T.serial(24): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("data_red_temp"): + with T.sblock("data_red_temp"): ax0 = T.axis.spatial(128, i0_fused) k1 = T.axis.reduce(768, i1_0 * 32 + i1_1) T.reads(data[ax0, k1]) @@ -1439,7 +1439,7 @@ def layer_norm_tuple_sum( data_red_temp_v1[ax0] = v_data_red_temp_v1 for i0_i1_fused_0 in T.thread_binding(384, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): ax0 = T.axis.spatial(128, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 768) ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 256 + i0_i1_fused_1) % 768) T.reads( @@ -1473,7 +1473,7 @@ def lowered_layer_norm_tuple_sum( bias: T.Buffer(768, "float32"), T_layer_norm: T.Buffer((128, 768), "float32"), ) -> None: - # with T.block("root") + # with T.sblock("root") data_red_temp_v0 = T.alloc_buffer([128], dtype="float32") data_red_temp_v1 = T.alloc_buffer([128], dtype="float32") cross_thread_data_red_temp_v0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") @@ -1482,13 +1482,13 @@ def lowered_layer_norm_tuple_sum( in_thread_data_red_temp_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i0_fused in T.thread_binding(128, thread="blockIdx.x"): for i1_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("data_red_temp_in_thread_init"): + with T.sblock("data_red_temp_in_thread_init"): T.reads() T.writes(in_thread_data_red_temp_v0[0], in_thread_data_red_temp_v1[0]) in_thread_data_red_temp_v0[0] = T.float32(0) in_thread_data_red_temp_v1[0] = T.float32(0) for i1_0 in T.serial(24): - with T.block("data_red_temp_in_thread"): + with T.sblock("data_red_temp_in_thread"): ax0 = T.axis.spatial(128, i0_fused) k1 = T.axis.reduce(768, i1_0 * 32 + i1_1) T.reads(data[ax0, k1]) @@ -1499,7 +1499,7 @@ def lowered_layer_norm_tuple_sum( ) in_thread_data_red_temp_v0[0] = v_data_red_temp_v0 in_thread_data_red_temp_v1[0] = v_data_red_temp_v1 - with T.block("data_red_temp_cross_thread"): + with T.sblock("data_red_temp_cross_thread"): T.reads(in_thread_data_red_temp_v0[0], in_thread_data_red_temp_v1[0]) T.writes(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0]) T.attr( @@ -1521,7 +1521,7 @@ def lowered_layer_norm_tuple_sum( dtype="handle", ) ) - with T.block("data_red_temp_write_back"): + with T.sblock("data_red_temp_write_back"): ax0 = T.axis.spatial(128, i0_fused) T.where(i1_1 == 0) T.reads(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0]) @@ -1530,7 +1530,7 @@ def lowered_layer_norm_tuple_sum( data_red_temp_v1[ax0] = cross_thread_data_red_temp_v1[0] for i0_i1_fused_0 in T.thread_binding(384, thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("T_layer_norm"): + with T.sblock("T_layer_norm"): ax0 = T.axis.spatial(128, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 768) ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 256 + i0_i1_fused_1) % 768) T.reads( @@ -1562,14 +1562,14 @@ def thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), " temp_local = T.alloc_buffer((256,), scope="local") for i in T.thread_binding(256, thread="blockIdx.x"): for k in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum"): + with T.sblock("sum"): vi, vk = T.axis.remap("SR", [i, k]) T.reads(A[vi, vk]) T.writes(temp_local[vi]) with T.init(): temp_local[vi] = T.float32(0) temp_local[vi] = temp_local[vi] + A[vi, vk] - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(256, i) T.reads(temp_local[vi]) T.writes(B[vi]) @@ -1583,7 +1583,7 @@ def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer(( cross_thread_temp_local = T.alloc_buffer((1,), strides=(1,), scope="local") for i in T.thread_binding(256, thread="blockIdx.x"): for k in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum_cross_thread"): + with T.sblock("sum_cross_thread"): vi, vk = T.axis.remap("SR", [i, k]) T.reads(A[vi, vk]) T.writes(cross_thread_temp_local[0]) @@ -1595,13 +1595,13 @@ def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer(( T.tvm_thread_allreduce( T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_local[0], k ) - with T.block("sum_write_back"): + with T.sblock("sum_write_back"): vi = T.axis.spatial(256, i) T.reads(cross_thread_temp_local[0]) T.writes(temp_local[vi]) temp_local[vi] = cross_thread_temp_local[0] for tx in T.thread_binding(256, thread="threadIdx.x"): - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(256, i) T.where(tx == 0) T.reads(temp_local[vi]) @@ -1620,7 +1620,7 @@ def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T. var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local") for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1628,7 +1628,7 @@ def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T. T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) for ax2_fused_0 in range(T.int64(1)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1639,7 +1639,7 @@ def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T. var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1] for ax1_ax2_fused in range(T.int64(1)): for ax0_fused in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("NT_matmul"): + with T.sblock("NT_matmul"): vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1649,7 +1649,7 @@ def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T. with T.init(): var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = T.float16(0) var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) @@ -1670,7 +1670,7 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 in_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), "float16", strides=(1,), scope="local") for ax0_ax1_fused in T.thread_binding(n * T.int64(32), thread="blockIdx.x"): for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("NT_matmul_rf_init"): + with T.sblock("NT_matmul_rf_init"): vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1678,7 +1678,7 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) for ax2_fused_0 in range(T.int64(1)): - with T.block("NT_matmul_rf_update"): + with T.sblock("NT_matmul_rf_update"): vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1689,11 +1689,11 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + vax2_fused_1] for ax1_ax2_fused in range(T.int64(1)): for ax0_fused in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("NT_matmul_in_thread_init"): + with T.sblock("NT_matmul_in_thread_init"): T.reads() T.writes(in_thread_var_NT_matmul_intermediate_local[0]) in_thread_var_NT_matmul_intermediate_local[0] = T.float16(0) - with T.block("NT_matmul_in_thread"): + with T.sblock("NT_matmul_in_thread"): vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused) v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) @@ -1701,12 +1701,12 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) T.writes(in_thread_var_NT_matmul_intermediate_local[0]) in_thread_var_NT_matmul_intermediate_local[0] = in_thread_var_NT_matmul_intermediate_local[0] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] - with T.block("NT_matmul_cross_thread"): + with T.sblock("NT_matmul_cross_thread"): T.reads(in_thread_var_NT_matmul_intermediate_local[0]) T.writes(cross_thread_var_NT_matmul_intermediate_local[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), in_thread_var_NT_matmul_intermediate_local[0], T.bool(True), cross_thread_var_NT_matmul_intermediate_local[0], ax0_fused) - with T.block("NT_matmul_write_back"): + with T.sblock("NT_matmul_write_back"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n) @@ -1714,7 +1714,7 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6 T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1]) var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0] for tx in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("compute"): + with T.sblock("compute"): v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) T.where(tx == T.int64(0) and (T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)) @@ -1730,20 +1730,20 @@ def no_thread_broadcast(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 25 temp_2_local = T.alloc_buffer((1,), scope="local") for i in T.thread_binding(256, thread="blockIdx.x"): for k in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum"): + with T.sblock("sum"): vi, vk = T.axis.remap("SR", [i, k]) T.reads(A[vi, vk]) T.writes(temp_1_local[vi]) with T.init(): temp_1_local[vi] = T.float32(0) temp_1_local[vi] = temp_1_local[vi] + A[vi, vk] - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(256, i) T.reads(temp_1_local[vi]) T.writes(temp_2_local[0]) temp_2_local[0] = temp_1_local[vi] + T.float32(1) for j in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum"): + with T.sblock("sum"): vi, vj = T.axis.remap("SR", [i, j]) T.reads(temp_2_local[0]) T.writes(B[vi, vj]) @@ -1760,7 +1760,7 @@ def lowered_no_thread_broadcast( cross_thread_temp_1_local = T.alloc_buffer((1,), strides=(1,), scope="local") for i in T.thread_binding(256, thread="blockIdx.x"): for k in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum_cross_thread"): + with T.sblock("sum_cross_thread"): vi, vk = T.axis.remap("SR", [i, k]) T.reads(A[vi, vk]) T.writes(cross_thread_temp_1_local[0]) @@ -1772,18 +1772,18 @@ def lowered_no_thread_broadcast( T.tvm_thread_allreduce( T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_1_local[0], k ) - with T.block("sum_write_back"): + with T.sblock("sum_write_back"): vi = T.axis.spatial(256, i) T.reads(cross_thread_temp_1_local[0]) T.writes(temp_1_local[vi]) temp_1_local[vi] = cross_thread_temp_1_local[0] - with T.block("add"): + with T.sblock("add"): vi = T.axis.spatial(256, i) T.reads(temp_1_local[vi]) T.writes(temp_2_local[0]) temp_2_local[0] = temp_1_local[vi] + T.float32(1) for j in T.thread_binding(256, thread="threadIdx.x"): - with T.block("sum"): + with T.sblock("sum"): vi, vj = T.axis.remap("SR", [i, j]) T.reads(temp_2_local[0]) T.writes(B[vi, vj]) diff --git a/tests/python/tir-transform/test_tir_transform_lower_init_block.py b/tests/python/tir-transform/test_tir_transform_lower_init_block.py index d05b8bc71f46..4e1a1cbfa6d7 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_init_block.py +++ b/tests/python/tir-transform/test_tir_transform_lower_init_block.py @@ -30,7 +30,7 @@ def main(a: T.handle, b: T.handle) -> None: for i0, j0 in T.grid(64, 64): for k0 in T.serial(32, 64): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SRR", [i0, j0, k0]) with T.init(): B[i] = T.float32(0) @@ -46,7 +46,7 @@ def main(a: T.handle, b: T.handle) -> None: for i0, j0 in T.grid(64, 64): for k0 in T.serial(32, 64): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SRR", [i0, j0, k0]) T.reads(A[i, j, k]) T.writes(B[i]) @@ -64,7 +64,7 @@ def main(a: T.handle, b: T.handle) -> None: for i0, j0 in T.grid(64, 64): for k0 in T.serial(32, 64): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SRR", [i0, j0, k0]) BB = T.match_buffer(B[i], ()) AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) @@ -82,7 +82,7 @@ def main(a: T.handle, b: T.handle) -> None: for i0, j0 in T.grid(64, 64): for k0 in T.serial(32, 64): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SRR", [i0, j0, k0]) T.reads(A[i, j, k]) T.writes(B[i]) diff --git a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py index 2ba658b73822..41c53437e98c 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py @@ -40,7 +40,7 @@ def buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block(): + with T.sblock(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) sub_A = T.match_buffer( @@ -56,7 +56,7 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block(): + with T.sblock(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) for ii, kk in T.grid(4, 2): @@ -73,7 +73,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) sub_A = T.match_buffer( @@ -93,7 +93,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block(): + with T.sblock(): Bs_0 = T.int32() Bs_1 = T.int32() T.reads([]) @@ -121,7 +121,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) T.evaluate( @@ -135,7 +135,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block(): + with T.sblock(): T.reads([]) T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) T.evaluate( @@ -154,7 +154,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: def high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block(): + with T.sblock(): As_0 = T.int32() As_1 = T.int32() T.reads([]) @@ -181,7 +181,7 @@ def high_dim_opaque_access(a: T.handle) -> None: def transformed_high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -200,7 +200,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None: def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block(): + with T.sblock(): As_0 = T.int32() As_1 = T.int32() T.reads([]) @@ -227,7 +227,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -247,7 +247,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block(): + with T.sblock(): T.reads([]) T.writes( [ @@ -269,7 +269,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: offset_factor=1, ) for jj, kk in T.grid(4, 4): - with T.block(): + with T.sblock(): T.reads([]) T.writes( [ @@ -309,7 +309,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block(): + with T.sblock(): T.reads([]) T.writes( [ @@ -318,7 +318,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: ] ) for jj, kk in T.grid(4, 4): - with T.block(): + with T.sblock(): T.reads([]) T.writes( [ @@ -353,7 +353,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block(): + with T.sblock(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) Bs_0 = T.int32() @@ -382,7 +382,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block(): + with T.sblock(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) for ii, jj in T.grid(m, m): @@ -405,7 +405,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block(): + with T.sblock(): T.reads([]) T.writes([A[i, j], B[i, j]]) sub_A = T.match_buffer(A[i, j], (), offset_factor=1) @@ -428,7 +428,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block(): + with T.sblock(): T.reads([]) T.writes([A[i, j], B[i, j]]) A[i, j] = 1 @@ -448,7 +448,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: def fail_match_load(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], (), elem_offset=0) @@ -459,7 +459,7 @@ def fail_match_load(a: T.handle) -> None: def fail_match_store(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], (), elem_offset=0) @@ -471,7 +471,7 @@ def fail_match_store(a: T.handle) -> None: def fail_buffer_bind(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block(): + with T.sblock(): stride = T.int32() sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 @@ -485,7 +485,7 @@ def fail_buffer_bind(a: T.handle) -> None: def fail_match_func_param(a: T.handle, m: T.handle, n: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block(): + with T.sblock(): sub_A = T.match_buffer(A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1) for jj in range(0, 4): sub_A[i, j * 4 + jj] = 1 @@ -536,7 +536,7 @@ def test_fail_match_func_param(): def scalar_match_buffer_type_coercion(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block(""): + with T.sblock(""): vi = T.axis.spatial(8, i) vj = T.axis.spatial(8, j) T.reads() @@ -550,7 +550,7 @@ def scalar_match_buffer_type_coercion(a: T.handle) -> None: def transformed_scalar_match_buffer_type_coercion(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block(""): + with T.sblock(""): vi = T.axis.spatial(8, i) vj = T.axis.spatial(8, j) T.reads() diff --git a/tests/python/tir-transform/test_tir_transform_lower_opaque_block.py b/tests/python/tir-transform/test_tir_transform_lower_opaque_block.py index dbaafb617aad..294cdc42d2e3 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_opaque_block.py +++ b/tests/python/tir-transform/test_tir_transform_lower_opaque_block.py @@ -35,17 +35,17 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[0, j]) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] * 2.0 @@ -70,17 +70,17 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): for i2 in T.thread_binding(0, 2, thread="vthread"): - with T.block(): + with T.sblock(): T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="local") for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i0 * 4 + i1 * 2 + i2, j]) T.writes(B[0, j]) B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[0, j]) T.writes(C[i0 * 4 + i1 * 2 + i2, j]) C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 @@ -111,17 +111,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - with T.block(): + with T.sblock(): T.reads(A[i, m]) T.writes(C[i, m]) B = T.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): - with T.block(): + with T.sblock(): T.reads(A[i, j]) T.writes(B[j]) B[j] = A[i, j] + 1.0 for j in range(0, m): - with T.block(): + with T.sblock(): T.reads(B[j]) T.writes(C[i, j]) C[i, j] = B[j] * 2.0 @@ -146,7 +146,7 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for i, j in T.grid(5, 7): - with T.block(): + with T.sblock(): T.reads(A[i * 7 + j]) T.writes(C[i * 7 + j]) T.where(i * 7 + j < 32) @@ -169,7 +169,7 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for x, y, z in T.grid(4, 1, 8): - with T.block(): + with T.sblock(): T.reads(A[x * 8 + y * 8 + z]) T.writes(C[x * 8 + y * 8 + z]) C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 @@ -190,7 +190,7 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - with T.block(): + with T.sblock(): T.reads(A[i]) T.writes(D[i]) B = T.alloc_buffer((32,), scope="global") @@ -218,19 +218,19 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in range(0, 4): - with T.block(): + with T.sblock(): T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16]) B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global") for i1 in range(0, 4): for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(A[i0 * 4 + i1, j]) T.writes(B[i1, j]) B[i1, j] = A[i0 * 4 + i1, j] + 1.0 for i1 in range(0, 4): for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads(B[i1, j]) T.writes(C[i0 * 4 + i1, j]) C[i0 * 4 + i1, j] = B[i1, j] * 2.0 @@ -255,14 +255,14 @@ def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: n = T.int32() A = T.match_buffer(a, (1, n, 10240)) padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96)) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): - with T.block(""): + with T.sblock(""): A_pad_shared_dyn = T.alloc_buffer( (1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn" ) for ax0, ax1 in T.grid(96, 64): - with T.block("A_pad_shared.dyn"): + with T.sblock("A_pad_shared.dyn"): T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else( i * 128 + j * 32 + ax0 < n, @@ -304,7 +304,7 @@ def annotated_loops(a: T.handle) -> None: @T.prim_func def boolean_handling_before(a: T.Buffer(10, "bool"), b: T.Buffer(10, "bool")) -> None: for i0 in T.serial(10): - with T.block("b"): + with T.sblock("b"): T.reads(a[i0]) T.writes(b[i0]) b[i0] = a[i0] @@ -365,8 +365,8 @@ def test_annotated_loops(): def test_annotated_block(): @T.prim_func def annotated_block() -> None: - with T.block(): - T.block_attr({"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}) + with T.sblock(): + T.sblock_attr({"pragma_1": "str_value", "pragma_2": 1, "pragma_3": 0.0}) T.evaluate(0) mod = tvm.IRModule.from_expr(annotated_block.with_attr("global_symbol", "main")) @@ -385,8 +385,8 @@ def test_preserved_annotations(): @T.prim_func def before(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")): for i in T.serial(8, annotations={"k_0": 1, "k_1": [2, 3], "k_2": 3.14}): - with T.block("block"): - T.block_attr({"k_3": "oops"}) + with T.sblock("block"): + T.sblock_attr({"k_3": "oops"}) B[i] = A[i] + 1.0 @T.prim_func diff --git a/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py index 15d7118fb8a9..f6fe611c4e8c 100644 --- a/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py +++ b/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py @@ -30,33 +30,33 @@ def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float3 # function attr dict T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"): for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): for k_0 in T.serial(32): - with T.block(): + with T.sblock(): T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("A_shared"): + with T.sblock("A_shared"): T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) - T.block_attr({"tir.manifest_shared_memory_local_stage":1}) + T.sblock_attr({"tir.manifest_shared_memory_local_stage":1}) A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("B_shared"): + with T.sblock("B_shared"): T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) - T.block_attr({"tir.manifest_shared_memory_local_stage":1}) + T.sblock_attr({"tir.manifest_shared_memory_local_stage":1}) B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): - with T.block("C"): + with T.sblock("C"): T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) if k_0 * 32 + k_1 * 16 + k_2 == 0: @@ -71,13 +71,13 @@ def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float3 # function attr dict T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"): for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): for k_0 in T.serial(32): - with T.block(): + with T.sblock(): T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") @@ -86,30 +86,30 @@ def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float3 B_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local") for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block(): + with T.sblock(): T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]) A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("A_shared"): + with T.sblock("A_shared"): T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block(): + with T.sblock(): T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]) B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("B_shared"): + with T.sblock("B_shared"): T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): - with T.block("C"): + with T.sblock("C"): T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) if k_0 * 32 + k_1 * 16 + k_2 == 0: diff --git a/tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py index c133a2afb3a3..f454699d1c9a 100644 --- a/tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py +++ b/tests/python/tir-transform/test_tir_transform_memhammer_lower_auto_copy.py @@ -28,17 +28,17 @@ class Transpose: def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024]) B = T.match_buffer(b, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", scope="shared.dyn") - with T.block("A_shared"): - T.block_attr({"auto_copy": True}) + with T.sblock("A_shared"): + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(128, 16): A_shared_dyn[ax1, ax0] = A[ax0, ax1] - with T.block("B"): - T.block_attr({"auto_copy": True}) + with T.sblock("B"): + T.sblock_attr({"auto_copy": True}) for ax1, ax0 in T.grid(16, 128): B[ax1, ax0] = A_shared_dyn[ax1, ax0] @@ -49,20 +49,20 @@ class GlobalToShared: def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024]) B = T.match_buffer(b, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float32", scope="shared.dyn" ) - with T.block("A_shared"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("A_shared"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for ax0, ax1 in T.grid(128, 128): A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] - with T.block("B"): + with T.sblock("B"): for ax0, ax1 in T.grid(128, 128): B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] @@ -73,20 +73,20 @@ class SharedToGlobal: def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024]) B = T.match_buffer(b, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float32", scope="shared.dyn" ) - with T.block("A_shared"): + with T.sblock("A_shared"): for ax0, ax1 in T.grid(128, 128): A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] - with T.block("B"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("B"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for ax1, ax0 in T.grid(128, 128): B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax1, ax0] @@ -97,22 +97,22 @@ class GlobalToSharedWithLocalStage: def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024]) B = T.match_buffer(b, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float32", scope="shared.dyn" ) - with T.block("A_shared"): - T.block_attr( + with T.sblock("A_shared"): + T.sblock_attr( {"auto_copy": True, "vector_bytes": 16, "local_stage": True} ) for ax0, ax1 in T.grid(128, 128): A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] - with T.block("B"): + with T.sblock("B"): for ax0, ax1 in T.grid(128, 128): B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] @@ -121,20 +121,20 @@ def main(a: T.handle, b: T.handle) -> None: class SharedToWmma: @T.prim_func def main() -> None: - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float16", scope="shared.dyn" ) A_wmma = T.alloc_buffer( [128, 128], dtype="float16", scope="wmma.matrix_a" ) - with T.block("A_wmma"): - T.block_attr({"auto_copy": True}) + with T.sblock("A_wmma"): + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(128, 128): A_wmma[ax0, ax1] = A_shared_dyn[ax0, ax1] @@ -143,20 +143,20 @@ def main() -> None: class WmmaToShared: @T.prim_func def main() -> None: - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): C_accum = T.alloc_buffer( [128, 128], dtype="float32", scope="wmma.accumulator" ) C_shared = T.alloc_buffer( [128, 128], dtype="float32", scope="shared.dyn" ) - with T.block("C_shared"): - T.block_attr({"auto_copy": True}) + with T.sblock("C_shared"): + T.sblock_attr({"auto_copy": True}) for ax0, ax1 in T.grid(128, 128): C_shared[ax0, ax1] = C_accum[ax0, ax1] @@ -166,17 +166,17 @@ class WmmaToGlobal: @T.prim_func def main(c: T.handle) -> None: C = T.match_buffer(c, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): C_accum = T.alloc_buffer( [128, 128], dtype="float32", scope="wmma.accumulator" ) - with T.block("C_global"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("C_global"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for ax0, ax1 in T.grid(128, 128): C[bx * 128 + ax0, by * 128 + ax1] = C_accum[ax0, ax1] @@ -187,17 +187,17 @@ class WmmaToGlobalWithFusion: def main(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [1024]) C = T.match_buffer(c, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): C_accum = T.alloc_buffer( [128, 128], dtype="float32", scope="wmma.accumulator" ) - with T.block("C_global"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("C_global"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for ax0, ax1 in T.grid(128, 128): C[bx * 128 + ax0, by * 128 + ax1] = ( C_accum[ax0, ax1] + A[bx * 128 + ax0] @@ -209,17 +209,17 @@ class MmaToGlobal: @T.prim_func def main(c: T.handle) -> None: C = T.match_buffer(c, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): C_accum = T.alloc_buffer( [128, 128], dtype="float32", scope="m16n8k8.matrixC" ) - with T.block("C_global"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("C_global"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for ax0, ax1 in T.grid(128, 128): C[bx * 128 + ax0, by * 128 + ax1] = C_accum[ax0, ax1] @@ -230,17 +230,17 @@ class TransformedGlobalToShared: def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024]) B = T.match_buffer(b, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn" ) - with T.block("A_shared"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("A_shared"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for outer in T.serial(16): for ty_1 in T.thread_binding(8, thread="threadIdx.y"): for tx in T.thread_binding(32, thread="threadIdx.x"): @@ -260,7 +260,7 @@ def main(a: T.handle, b: T.handle) -> None: + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128, ] - with T.block("B"): + with T.sblock("B"): for ax0, ax1 in T.grid(128, 128): B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] @@ -271,22 +271,22 @@ class TransformedSharedToGlobal: def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024, 1024]) B = T.match_buffer(b, [1024, 1024]) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float32", strides=[129, 1], scope="shared.dyn" ) - with T.block("A_shared"): + with T.sblock("A_shared"): T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) T.writes(A_shared_dyn[0:128, 0:128]) for ax0, ax1 in T.grid(128, 128): A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] - with T.block("B"): - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + with T.sblock("B"): + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) for outer in T.serial(16): for ty_1 in T.thread_binding(8, thread="threadIdx.y"): for tx in T.thread_binding(32, thread="threadIdx.x"): @@ -314,21 +314,21 @@ class TransformedGlobalToSharedWithLocalStage: def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (1024, 1024)) B = T.match_buffer(b, (1024, 1024)) - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(""): + with T.sblock(""): T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) T.writes(B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) A_shared_dyn = T.alloc_buffer( (128, 128), strides=(128, 1), scope="shared.dyn" ) - with T.block("A_shared"): + with T.sblock("A_shared"): T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) T.writes(A_shared_dyn[0:128, 0:128]) - T.block_attr( + T.sblock_attr( {"auto_copy": True, "local_stage": True, "vector_bytes": 16} ) A_shared_dyn_local = T.alloc_buffer((16, 4), scope="local") @@ -406,7 +406,7 @@ def main(a: T.handle, b: T.handle): ax0_ax1_fused_0 * 8 * 32 * 4 // 128 % 128 // 8, ax0_ax1_fused_3 % 128, ] - with T.block("B"): + with T.sblock("B"): T.reads(A_shared_dyn[0:128, 0:128]) T.writes(B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) for ax0 in range(128): @@ -421,24 +421,24 @@ def main() -> None: s0 = T.int32() s1 = T.int32() # body - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): A_shared_dyn = T.alloc_buffer( [128, 128], dtype="float16", strides=[136, 1], scope="shared.dyn" ) A_wmma = T.alloc_buffer( [128, 128], dtype="float16", scope="wmma.matrix_a" ) - with T.block("C_shared"): + with T.sblock("C_shared"): T.reads(A_shared_dyn[0:128, 0:128]) T.writes(A_wmma[0:128, 0:128]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax00, ax10 in T.grid(8, 8): - with T.block("wmma_load"): + with T.sblock("wmma_load"): T.reads( A_shared_dyn[ ax00 * 16 : ax00 * 16 + 16, @@ -502,24 +502,24 @@ def main() -> None: s0 = T.int32() s1 = T.int32() # body - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): C_accum = T.alloc_buffer( [128, 128], dtype="float32", scope="wmma.accumulator" ) C_shared = T.alloc_buffer( [128, 128], dtype="float32", strides=[136, 1], scope="shared.dyn" ) - with T.block("A_wmma"): + with T.sblock("A_wmma"): T.reads(C_accum[0:128, 0:128]) T.writes(C_shared[0:128, 0:128]) - T.block_attr({"auto_copy": True}) + T.sblock_attr({"auto_copy": True}) for ax00, ax10 in T.grid(8, 8): - with T.block("wmma_store"): + with T.sblock("wmma_store"): T.reads( C_accum[ ax00 * 16 : ax00 * 16 + 16, @@ -580,25 +580,25 @@ def main() -> None: class TransformedWmmaToGlobal: @T.prim_func def main(C: T.Buffer((1024, 1024), "float32")): - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(""): + with T.sblock(""): T.reads() T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) C_accum = T.alloc_buffer((128, 128), scope="wmma.accumulator") - with T.block("C_global"): + with T.sblock("C_global"): T.reads(C_accum[0:128, 0:128]) T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) C_accum_shared_dyn = T.alloc_buffer( (8, 8, 16, 16), strides=(2048, 256, 16, 1), scope="shared.dyn" ) for ax0_0 in range(8): for ax1_0 in range(8): - with T.block("wmma_store"): + with T.sblock("wmma_store"): T.reads( C_accum[ ax0_0 * 16 : ax0_0 * 16 + 16, @@ -780,27 +780,27 @@ def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024, 1024), "float32")) s0 = T.int32() s1 = T.int32() # body - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(): + with T.sblock(): T.reads(A[bx * 128 : bx * 128 + 128]) T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) C_accum = T.alloc_buffer( [128, 128], dtype="float32", scope="wmma.accumulator" ) - with T.block("C_global"): + with T.sblock("C_global"): T.reads(C_accum[0:128, 0:128], A[bx * 128 : bx * 128 + 128]) T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) C_accum_shared_dyn = T.alloc_buffer( (8, 8, 16, 16), strides=(2048, 256, 16, 1), scope="shared.dyn" ) for ax0_0 in range(8): for ax1_0 in range(8): - with T.block("wmma_store"): + with T.sblock("wmma_store"): T.reads( C_accum[ ax0_0 * 16 : ax0_0 * 16 + 16, @@ -1002,25 +1002,25 @@ def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024, 1024), "float32")) class TransformedMmaToGlobal: @T.prim_func def main(C: T.Buffer((1024, 1024), "float32")): - with T.block("root"): - T.block_attr({"warp_execution": True}) + with T.sblock("root"): + T.sblock_attr({"warp_execution": True}) for bx in T.thread_binding(8, thread="blockIdx.x"): for by in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(8, thread="threadIdx.y"): - with T.block(""): + with T.sblock(""): T.reads() T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) C_accum = T.alloc_buffer((128, 128), scope="m16n8k8.matrixC") - with T.block("C_global"): + with T.sblock("C_global"): T.reads(C_accum[0:128, 0:128]) T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) - T.block_attr({"auto_copy": True, "vector_bytes": 16}) + T.sblock_attr({"auto_copy": True, "vector_bytes": 16}) C_accum_shared_dyn = T.alloc_buffer( (8, 16, 8, 8), strides=(1152, 72, 8, 1), scope="shared.dyn" ) for ax0_0 in range(16): for ax1_0 in range(16): - with T.block("mma_store"): + with T.sblock("mma_store"): T.reads( C_accum[ ax0_0 * 8 : ax0_0 * 8 + 8, @@ -1133,7 +1133,7 @@ def verify_single_allocation(stmt, alloc_size=None): def verify(n): if ( - isinstance(n, tvm.tir.Block) + isinstance(n, tvm.tir.SBlock) and n.alloc_buffers is not None and (True in ((buf.scope() == "shared.dyn") for buf in n.alloc_buffers)) ): diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index d6324e16c8fa..aa1a5d57e484 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -173,7 +173,7 @@ def test_block(): def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): for i in T.serial(0, T.int64(16)): for j in T.serial(0, T.int64(8)): - with T.block(): + with T.sblock(): vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j) B[vi] = A[vi] + T.float32(1) @@ -181,7 +181,7 @@ def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): def expected_after(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): for i in T.serial(0, T.int32(16)): for j in T.serial(0, T.int32(8)): - with T.block(): + with T.sblock(): vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j) B[vi] = A[vi] + T.float32(1) diff --git a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py index ff3fa8cf7092..b64440ee8285 100644 --- a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py @@ -38,11 +38,11 @@ def element_func(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((16, 16)) for i0 in range(0, 16): for j0 in range(0, 16): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i0, j0]) B[i, j] = A[i, j] + 1.0 for j0 in range(0, 16): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i0, j0]) C[i, j] = B[i, j] * 2.0 @@ -53,16 +53,16 @@ def transformed_element_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16, 16]) for i_0 in range(0, 16): - with T.block(): + with T.sblock(): T.reads([A[i_0, 0:16]]) T.writes([C[i_0, 0:16]]) B = T.alloc_buffer([16, 16]) for j_0 in T.serial(0, 16): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i_0, j_0]) B[i, j] = A[i, j] + 1.0 for j_0 in T.serial(0, 16): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i_0, j_0]) C[i, j] = B[i, j] * 2.0 @@ -71,11 +71,11 @@ def transformed_element_func(a: T.handle, c: T.handle) -> None: def original_func() -> None: A = T.alloc_buffer((128, 128), "float32") for i0, j0 in T.grid(128, 128): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i0, j0]) A[i, j] = T.float32(0) for i0, j0, k0 in T.grid(32, 32, 32): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) B = T.alloc_buffer((128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") @@ -96,18 +96,18 @@ def original_func() -> None: def transformed_func() -> None: A = T.alloc_buffer([128, 128]) for i0, j0 in T.grid(128, 128): - with T.block(): + with T.sblock(): i, j = T.axis.remap("SS", [i0, j0]) A[i, j] = T.float32(0) for i0, j0, k0 in T.grid(32, 32, 32): - with T.block(): + with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) B = T.alloc_buffer([128, 128]) if k == 0: for ii, jj in T.grid(4, 4): B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - with T.block(""): + with T.sblock(""): T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) C = T.alloc_buffer([128, 128]) @@ -116,7 +116,7 @@ def transformed_func() -> None: B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] ) for kk in T.serial(0, 4): - with T.block(""): + with T.sblock(""): T.reads( [ B[((i * 4) + ii), ((j * 4) + jj)], @@ -137,11 +137,11 @@ def transformed_func() -> None: def match_buffer_func() -> None: C = T.alloc_buffer((128, 128)) for i in range(128): - with T.block(): + with T.sblock(): vi = T.axis.S(128, i) C0 = T.match_buffer(C[vi, 0:128], (128)) for j in range(128): - with T.block(): + with T.sblock(): jj = T.axis.S(128, j) C1 = T.match_buffer(C0[jj], ()) C1[()] = 0 @@ -150,12 +150,12 @@ def match_buffer_func() -> None: @T.prim_func def transformed_match_buffer_func() -> None: for i in range(0, 128): - with T.block(): + with T.sblock(): vi = T.axis.S(128, i) C = T.alloc_buffer((128, 128)) C0 = T.match_buffer(C[vi, 0:128], (128)) for j in range(128): - with T.block(): + with T.sblock(): jj = T.axis.S(128, j) C1 = T.match_buffer(C0[jj], ()) C1[()] = 0 @@ -167,9 +167,9 @@ def opaque_access(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [1024]) A_cache = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block(): + with T.sblock(): vi = T.axis.S(8, i) - with T.block(): + with T.sblock(): v = T.axis.S(8, vi) T.reads([A[(v * 128) : ((v * 128) + 128)]]) T.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) @@ -186,7 +186,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block(): + with T.sblock(): v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) @@ -198,12 +198,12 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024]) B = T.match_buffer(b, [1024]) for i in T.serial(0, 8): - with T.block(): + with T.sblock(): vi = T.axis.S(8, i) T.reads(A[vi * 128 : vi * 128 + 128]) T.writes(B[vi * 128 : vi * 128 + 128]) A_cache = T.alloc_buffer([1024]) - with T.block(): + with T.sblock(): v = T.axis.S(8, vi) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([A_cache[v * 128 : v * 128 + 128]]) @@ -213,7 +213,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block(): + with T.sblock(): v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) @@ -248,15 +248,15 @@ def before(A: T.Buffer((8, 8, 8), "int32"), B: T.Buffer((8, 8, 8), "int32")): for i in T.serial(8): for j in T.serial(8): for k in T.serial(8): - with T.block("b0"): + with T.sblock("b0"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] + 1 for k in T.serial(8): - with T.block("b1"): + with T.sblock("b1"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) D[vi, vj, vk] = A[vi, vj, vk] + 2 for k in T.serial(8): - with T.block("b2"): + with T.sblock("b2"): vi, vk = T.axis.remap("SS", [i, k]) vj = T.axis.opaque(8, j) B[vi, vj, vk] = ( @@ -268,22 +268,22 @@ def before(A: T.Buffer((8, 8, 8), "int32"), B: T.Buffer((8, 8, 8), "int32")): @T.prim_func def after(A: T.Buffer((8, 8, 8), "int32"), B: T.Buffer((8, 8, 8), "int32")) -> None: for i in T.serial(8): - with T.block(): + with T.sblock(): T.reads(A[i, 0:8, 0:8]) T.writes(B[i, 0:8, 0:8]) C = T.alloc_buffer([8, 8, 8], dtype="int32") D = T.alloc_buffer([8, 8, 8], dtype="int32") for j in T.serial(8): for k in T.serial(8): - with T.block("b0"): + with T.sblock("b0"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] + 1 for k in T.serial(8): - with T.block("b1"): + with T.sblock("b1"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) D[vi, vj, vk] = A[vi, vj, vk] + 2 for k in T.serial(8): - with T.block("b2"): + with T.sblock("b2"): vi, vk = T.axis.remap("SS", [i, k]) vj = T.axis.opaque(8, j) B[vi, vj, vk] = ( @@ -306,7 +306,7 @@ def before(A: T.Buffer((4, 16), "int32"), C: T.Buffer((4, 8), "int32")): for i in T.serial(0, 2): for j in T.serial(0, 6): for k in T.serial(3): - with T.block("P1"): + with T.sblock("P1"): T.where(i < 1 or j >= 2) cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k]) if vk == 0: @@ -316,7 +316,7 @@ def before(A: T.Buffer((4, 16), "int32"), C: T.Buffer((4, 8), "int32")): ) for j in T.serial(0, 4): for k in T.serial(3): - with T.block("P2"): + with T.sblock("P2"): vi = T.axis.opaque(2, i) cc, vj, vk = T.axis.remap("SSR", [c, j, k]) if vk == 0: @@ -328,13 +328,13 @@ def before(A: T.Buffer((4, 16), "int32"), C: T.Buffer((4, 8), "int32")): @T.prim_func def after(A: T.Buffer((4, 16), "int32"), C: T.Buffer((4, 8), "int32")): for c in T.serial(4): - with T.block(): + with T.sblock(): T.reads(A[c, 0:12], C[c, 0:8]) T.writes(C[c, 0:8]) B = T.alloc_buffer([4, 6], dtype="int32") for i in T.serial(2): for j, k in T.grid(6, 3): - with T.block("P1"): + with T.sblock("P1"): T.where(i < 1 or j >= 2) cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k]) if vk == 0: @@ -343,7 +343,7 @@ def after(A: T.Buffer((4, 16), "int32"), C: T.Buffer((4, 8), "int32")): B[cc, (vi * 4 + vj) % 6] + A[cc, vi * 4 + vj + vk] ) for j, k in T.grid(4, 3): - with T.block("P2"): + with T.sblock("P2"): vi = T.axis.opaque(2, i) cc, vj, vk = T.axis.remap("SSR", [c, j, k]) if vk == 0: @@ -410,7 +410,7 @@ def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), x_red_ = T.alloc_buffer((256, 256)) for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4): for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64): - with T.block("x_red"): + with T.sblock("x_red"): v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1) v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1) v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1) @@ -418,7 +418,7 @@ def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), x_red_[v_ax0, v_ax1] = T.float32(0.0) x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1] for ax0, ax1 in T.grid(64, 64): - with T.block("x_red_"): + with T.sblock("x_red_"): v0 = T.axis.spatial(256, ax0_0 * 64 + ax0) v1 = T.axis.spatial(256, ax1_0 * 64 + ax1) x_red[v0, v1] = x_red_[v0, v1] @@ -426,13 +426,13 @@ def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), @T.prim_func def after(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")): for ax0_0 in range(4): - with T.block(""): + with T.sblock(""): T.reads(x[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256, 0:256]) T.writes(x_red[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256]) x_red_ = T.alloc_buffer((256, 256)) for k1_0, ax1_0 in T.grid(4, 4): for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64): - with T.block("x_red"): + with T.sblock("x_red"): v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1) v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1) v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1) @@ -442,7 +442,7 @@ def after(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), " x_red_[v_ax0, v_ax1] = T.float32(0.0) x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1] for ax0, ax1 in T.grid(64, 64): - with T.block("x_red_"): + with T.sblock("x_red_"): v0 = T.axis.spatial(256, ax0_0 * 64 + ax0) v1 = T.axis.spatial(256, ax1_0 * 64 + ax1) T.reads(x_red_[v0, v1]) diff --git a/tests/python/tir-transform/test_tir_transform_profiling_instr.py b/tests/python/tir-transform/test_tir_transform_profiling_instr.py index 4084ad0feb27..139524fa8325 100644 --- a/tests/python/tir-transform/test_tir_transform_profiling_instr.py +++ b/tests/python/tir-transform/test_tir_transform_profiling_instr.py @@ -36,11 +36,11 @@ def input1(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (8, 8, 128), dtype="int32") for i, j in T.grid(8, 8): for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 for k, l in T.grid(8, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * 2 @@ -54,20 +54,20 @@ def input2(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: for i in T.serial(0, 8): for j in T.serial(0, 8): for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] for j in T.serial(0, 8): for k, l in T.grid(8, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] + 2 for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = C[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] @@ -82,23 +82,23 @@ def input3(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: for j in T.parallel(0, 8): for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] for j in T.serial(0, 8): for k in T.parallel(0, 8): for l in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] + 2 for k in T.parallel(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = C[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] @@ -111,13 +111,13 @@ def test1_expected_output(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j in T.grid(8, 8): T.evaluate(T.start_profile_intrinsic(3, dtype="handle")) for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(3, dtype="handle")) T.evaluate(T.start_profile_intrinsic(5, dtype="handle")) for k, l in T.grid(8, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(5, dtype="handle")) @@ -134,12 +134,12 @@ def test2_expected_output(a: T.handle, b: T.handle, c: T.handle) -> None: for j in T.serial(0, 8): for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(2, dtype="handle")) @@ -158,14 +158,14 @@ def test3_expected_output(a: T.handle, b: T.handle, c: T.handle) -> None: T.evaluate(T.start_profile_intrinsic(3, dtype="handle")) for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(3, dtype="handle")) T.evaluate(T.start_profile_intrinsic(5, dtype="handle")) for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(5, dtype="handle")) @@ -184,13 +184,13 @@ def test4_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> for j in T.serial(0, 8): T.evaluate(T.start_profile_intrinsic(3, dtype="handle")) for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(3, dtype="handle")) T.evaluate(T.start_profile_intrinsic(5, dtype="handle")) for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] T.evaluate(T.end_profile_intrinsic(5, dtype="handle")) @@ -199,13 +199,13 @@ def test4_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> for j in T.serial(0, 8): T.evaluate(T.start_profile_intrinsic(8, dtype="handle")) for k, l in T.grid(8, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] + 2 T.evaluate(T.end_profile_intrinsic(8, dtype="handle")) T.evaluate(T.start_profile_intrinsic(10, dtype="handle")) for k, l in T.grid(8, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = C[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] T.evaluate(T.end_profile_intrinsic(10, dtype="handle")) @@ -223,12 +223,12 @@ def test5_expected_output(a: T.handle, b: T.handle, c: T.handle) -> None: for j in T.serial(0, 8): for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * 2 T.evaluate(T.end_profile_intrinsic(2, dtype="handle")) @@ -246,12 +246,12 @@ def test6_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> for j in T.parallel(0, 8): for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = A[vi, vj, vk * 16 + vl] * 2 for k in T.serial(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] T.evaluate(T.end_profile_intrinsic(2, dtype="handle")) @@ -260,14 +260,14 @@ def test6_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> T.evaluate(T.start_profile_intrinsic(8, dtype="handle")) for k in T.parallel(0, 8): for l in T.serial(0, 16): - with T.block("C"): + with T.sblock("C"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = B[vi, vj, vk * 16 + vl] + 2 T.evaluate(T.end_profile_intrinsic(8, dtype="handle")) T.evaluate(T.start_profile_intrinsic(10, dtype="handle")) for k in T.parallel(0, 8): for l in T.serial(0, 16): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) C[vi, vj, vk * 16 + vl] = C[vi, vj, vk * 16 + vl] * D[vi, vj, vk * 16 + vl] T.evaluate(T.end_profile_intrinsic(10, dtype="handle")) diff --git a/tests/python/tir-transform/test_tir_transform_remove_weight_layout_rewrite_block.py b/tests/python/tir-transform/test_tir_transform_remove_weight_layout_rewrite_block.py index 3b099aa0838e..a2a4eaeb5320 100644 --- a/tests/python/tir-transform/test_tir_transform_remove_weight_layout_rewrite_block.py +++ b/tests/python/tir-transform/test_tir_transform_remove_weight_layout_rewrite_block.py @@ -42,14 +42,14 @@ def before( T.func_attr({"layout_free_buffers": [1]}) B_ = T.alloc_buffer([16, 4, 4], dtype="float32") for i0_o, i1_o in T.grid(16, 16): - with T.block("layout_rewrite"): + with T.sblock("layout_rewrite"): i0, i1 = T.axis.remap("SS", [i0_o, i1_o]) T.reads(B[i0, i1]) T.writes(B_[i1, i0 // 4, i0 % 4]) - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) B_[i1, i0 // 4, i0 % 4] = B[i0, i1] for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.spatial(16, i0 * 4 + i1) vj = T.axis.spatial(16, j) vk = T.axis.reduce(16, k0 * 4 + k1) @@ -67,14 +67,14 @@ def after( ) -> None: T.func_attr({"layout_free_buffers": [1]}) for i0_o, i1_o in T.grid(16, 16): - with T.block("layout_rewrite"): + with T.sblock("layout_rewrite"): i0, i1 = T.axis.remap("SS", [i0_o, i1_o]) T.reads() T.writes() - T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + T.sblock_attr({"meta_schedule.layout_rewrite_preproc": True}) T.evaluate(0) for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): - with T.block("matmul"): + with T.sblock("matmul"): vi = T.axis.spatial(16, i0 * 4 + i1) vj = T.axis.spatial(16, j) vk = T.axis.reduce(16, k0 * 4 + k1) diff --git a/tests/python/tir-transform/test_tir_transform_unify_thread_binding.py b/tests/python/tir-transform/test_tir_transform_unify_thread_binding.py index a419dc3f9976..89b17719ac41 100644 --- a/tests/python/tir-transform/test_tir_transform_unify_thread_binding.py +++ b/tests/python/tir-transform/test_tir_transform_unify_thread_binding.py @@ -46,11 +46,11 @@ def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: for i in T.thread_binding(0, 128, "blockIdx.x"): for j0_0 in T.thread_binding(0, 4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 for j1_0 in T.thread_binding(0, 4, "threadIdx.x"): for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 @@ -63,12 +63,12 @@ def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 ) for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 ) @@ -83,11 +83,11 @@ def element_wise_thread_x_different_dtype( for i in T.thread_binding(128, "blockIdx.x"): for j0_0 in T.thread_binding(4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 for j1_0 in T.thread_binding(T.int64(4), "threadIdx.x"): for j1_1 in T.serial(T.int64(32)): - with T.block(""): + with T.sblock(""): C[i, j1_0 * T.int64(32) + j1_1] = B[i, j1_0 * T.int64(32) + j1_1] + 1.0 @@ -100,12 +100,12 @@ def unified_element_wise_thread_x_different_dtype( for blockIdx_x in T.thread_binding(128, "blockIdx.x"): for threadIdx_x in T.thread_binding(4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 ) for j1_1 in T.serial(T.int64(32)): - with T.block(""): + with T.sblock(""): C[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] = ( B[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] + 1.0 ) @@ -124,10 +124,10 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: T.launch_thread(j1_0, 4) for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 @@ -140,12 +140,12 @@ def unified_element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 ) for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 ) @@ -159,7 +159,7 @@ def element_wise_vthread_x(a: T.handle, b: T.handle) -> None: for i_1 in T.thread_binding(0, 64, "threadIdx.x"): for j_0 in T.thread_binding(0, 2, "vthread.x"): for j_1 in T.serial(0, 64): - with T.block(""): + with T.sblock(""): B[i_0 * 64 + i_1, j_0 * 64 + j_1] = A[i_0 * 64 + i_1, j_0 * 64 + j_1] * 2.0 @@ -170,7 +170,7 @@ def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None: for vthread_x in T.thread_binding(0, 2, "vthread.x"): for threadIdx_x in T.thread_binding(0, 64, "threadIdx.x"): for j_1 in T.serial(0, 64): - with T.block(""): + with T.sblock(""): B[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] = ( A[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] * 2.0 ) @@ -230,11 +230,11 @@ def element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) -> None: for i in T.thread_binding(0, 128, "threadIdx.y"): for j0_0 in T.thread_binding(0, 4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 for j1_0 in T.thread_binding(0, 4, "threadIdx.x"): for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 @@ -247,12 +247,12 @@ def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) - for blockIdx_x in T.thread_binding(0, 128, "threadIdx.y"): for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 ) for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 ) @@ -293,7 +293,7 @@ def test_inner_binding_with_annotation(): def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): for bx in T.thread_binding(32, "blockIdx.x"): for tx in T.thread_binding(2, "threadIdx.x", annotations={"my_annotation": 1}): - with T.block("block"): + with T.sblock("block"): v = T.axis.spatial(64, bx * 2 + tx) B[v] = A[v] @@ -304,7 +304,7 @@ def unified_inner_binding_with_annotation( for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): for var in T.serial(1, annotations={"my_annotation": 1}): - with T.block("block"): + with T.sblock("block"): v = T.axis.spatial(64, blockIdx_x * 2 + threadIdx_x) T.reads(A[v]) T.writes(B[v]) diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py b/tests/python/tir-transform/test_transform_default_gpu_schedule.py index 33f3933cb4e5..5b4e05f0d25e 100644 --- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py +++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py @@ -35,9 +35,9 @@ def broadcast_to( x_0 = T.int64() x_1 = T.int64() T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1)) - # with T.block("root"): + # with T.sblock("root"): for ax0, ax1 in T.grid(x_0, x_1): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, T.int64(0)]) T.writes(T_broadcast_to[v_ax0, v_ax1]) @@ -53,7 +53,7 @@ def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), v for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) // T.int64(262144)): - with T.block("T_broadcast_to"): + with T.sblock("T_broadcast_to"): v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * x_0) // x_1) v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1) T.where((ax0_ax1_fused_0 * T.int64(256) + ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1) @@ -78,9 +78,9 @@ def matmul( C: T.Buffer((32, 32), "float16"), ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): - with T.block("C"): + with T.sblock("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) @@ -102,9 +102,9 @@ def matmul_gpu( "tag": "", "thread_warp_size": 32}), "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): - with T.block("C"): + with T.sblock("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) @@ -121,9 +121,9 @@ def matmul_cpu( T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): - with T.block("C"): + with T.sblock("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) @@ -140,11 +140,11 @@ def matmul( C: T.Buffer((32, 32), "float16"), ): T.func_attr({"tir.is_scheduled": True, "global_symbol": "main", "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): for k in range(32): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial( 32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32 ) @@ -161,9 +161,9 @@ def matmul( @T.prim_func def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")): T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): - with T.block("C"): + with T.sblock("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) @@ -174,11 +174,11 @@ def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16" @T.prim_func def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")): T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): for k in range(32): - with T.block("C"): + with T.sblock("C"): v_i = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32) v_j = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32) v_k = T.axis.reduce(32, k) @@ -204,7 +204,7 @@ class Before: def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_add[ax0, ax1, ax2, ax3]) @@ -221,12 +221,12 @@ def add( T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"), ): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_i2_i3_fused_1 in T.thread_binding( T.int64(72), thread="threadIdx.x" ): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial( T.int64(4), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) @@ -276,7 +276,7 @@ class Before: def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -290,10 +290,10 @@ def full( T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), ): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): - with T.block("T_full"): + with T.sblock("T_full"): ax0 = T.axis.spatial( T.int64(2), (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3), @@ -326,10 +326,10 @@ def full( T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), ): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): - with T.block("T_full"): + with T.sblock("T_full"): ax0 = T.axis.spatial( T.int64(2), (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3), @@ -360,7 +360,7 @@ class Before: def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_add[ax0, ax1, ax2, ax3]) @@ -370,7 +370,7 @@ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32") def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_full"): + with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) @@ -387,12 +387,12 @@ def add( T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"), ): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_i2_i3_fused_1 in T.thread_binding( T.int64(72), thread="threadIdx.x" ): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial( T.int64(4), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) @@ -431,10 +431,10 @@ def full( T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), ): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): - with T.block("T_full"): + with T.sblock("T_full"): ax0 = T.axis.spatial( T.int64(2), (i0_i1_fused_0 * T.int64(6) + i0_i1_fused_1) // T.int64(3), @@ -463,7 +463,7 @@ class Before: def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): - with T.block("T_add"): + with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) T.writes(T_add[ax0, ax1, ax2, ax3]) @@ -476,7 +476,7 @@ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32") T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"): - with T.block("T_add"): + with T.sblock("T_add"): ax0 = T.axis.spatial(T.int64(4), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) // T.int64(18)) ax1 = T.axis.spatial(T.int64(3), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(18) // T.int64(6)) ax2 = T.axis.spatial(T.int64(2), (i0_i1_i2_i3_fused_0 * T.int64(72) + i0_i1_i2_i3_fused_1) % T.int64(6) // T.int64(3)) @@ -500,7 +500,7 @@ class Before: @T.prim_func def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")): T.func_attr({"tir.noalias": True}) - with T.block("T_add"): + with T.sblock("T_add"): vi = T.axis.spatial(1, T.int64(0)) T.reads(rxplaceholder[()]) T.writes(T_add[()]) @@ -511,10 +511,10 @@ class Expected: @T.prim_func def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")): T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) - # with T.block("root"): + # with T.sblock("root"): for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): - with T.block("T_add"): + with T.sblock("T_add"): vi = T.axis.spatial(1, T.int64(0)) T.reads(rxplaceholder[()]) T.writes(T_add[()]) @@ -536,7 +536,7 @@ class Before: @T.prim_func def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")): for k0, k1 in T.grid(T.int64(2), T.int64(2)): - with T.block("A_red"): + with T.sblock("A_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) with T.init(): A_red[()] = T.float64(0) @@ -547,11 +547,11 @@ class Expected: @T.prim_func def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")): T.func_attr({"tir.is_scheduled": True}) - # with T.block("root"): + # with T.sblock("root"): for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): for k0, k1 in T.grid(T.int64(2), T.int64(2)): - with T.block("A_red"): + with T.sblock("A_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) T.reads(A[v_k0, v_k1]) T.writes(A_red[()]) diff --git a/tests/python/tvmscript/test_tvmscript_complete.py b/tests/python/tvmscript/test_tvmscript_complete.py index 60002dbdb08c..353bd547e892 100644 --- a/tests/python/tvmscript/test_tvmscript_complete.py +++ b/tests/python/tvmscript/test_tvmscript_complete.py @@ -27,7 +27,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -41,13 +41,13 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(32, 32): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) for ii, jj in T.grid(4, 4): C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) for k in range(0, 32): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) for ii, jj, kk in T.grid(4, 4, 4): C[vi * 4 + ii, vj * 4 + jj] = ( @@ -62,13 +62,13 @@ def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block(): + with T.sblock(): for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.float32(1) for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -78,11 +78,11 @@ def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block(): - with T.block(): + with T.sblock(): + with T.sblock(): B[0, 0] = A[0, 0] + T.float32(1) for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -93,15 +93,15 @@ def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block(): + with T.sblock(): for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) B[vi, vj] = A[vi, vj] + T.float32(1) for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -112,7 +112,7 @@ def test_complete_matmul(): A, B, C = [func.buffer_map[x] for x in func.params] block = func.body.block.body.body.body.body.block - assert isinstance(block, tvm.tir.Block) + assert isinstance(block, tvm.tir.SBlock) vi, vj, vk = [x.var for x in block.iter_vars] access_A = tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vk, 1)]) access_B = tvm.tir.BufferRegion(B, [Range.from_min_extent(vj, 1), Range.from_min_extent(vk, 1)]) @@ -126,7 +126,7 @@ def test_complete_matmul_original(): A, B, C = [func.buffer_map[x] for x in func.params] block1 = func.body.block.body.body.body[0].block - assert isinstance(block1, tvm.tir.Block) + assert isinstance(block1, tvm.tir.SBlock) vi, vj = [x.var for x in block1.iter_vars] access_C = tvm.tir.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] @@ -135,7 +135,7 @@ def test_complete_matmul_original(): tvm.ir.assert_structural_equal(block1.writes, [access_C]) block2 = func.body.block.body.body.body[1].body.block - assert isinstance(block2, tvm.tir.Block) + assert isinstance(block2, tvm.tir.SBlock) vi, vj, vk = [x.var for x in block2.iter_vars] access_A = tvm.tir.BufferRegion( A, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vk * 4, 4)] @@ -158,7 +158,7 @@ def _check_elementwise(func): assert len(root_block.writes) == 0 block1 = func.body.block.body[0].body.body.block - assert isinstance(block1, tvm.tir.Block) + assert isinstance(block1, tvm.tir.SBlock) vi, vj = [x.var for x in block1.iter_vars] tvm.ir.assert_structural_equal( @@ -171,7 +171,7 @@ def _check_elementwise(func): ) block2 = func.body.block.body[1].body.body.block - assert isinstance(block2, tvm.tir.Block) + assert isinstance(block2, tvm.tir.SBlock) vi, vj = [x.var for x in block2.iter_vars] tvm.ir.assert_structural_equal( block2.reads, @@ -198,7 +198,7 @@ def func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: out_buf = T.alloc_buffer((16, 16), "float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] @@ -207,12 +207,12 @@ def func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=64, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=64, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=64, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[vi, index_buf[0]], index_buf[0]]) T.writes([out_buf[vi, vj]]) @@ -226,7 +226,7 @@ def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> out_buf = T.alloc_buffer((16, 16), "float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @@ -235,12 +235,12 @@ def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=64, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=64, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=64, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i0, i1]) T.reads( [ @@ -271,11 +271,11 @@ def test_complete_buffer_indices(): def match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block(): + with T.sblock(): A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block(): + with T.sblock(): for j in range(0, 16): - with T.block(): + with T.sblock(): A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 @@ -284,15 +284,15 @@ def match_buffer_func(a: T.handle) -> None: def expected_match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i, 0:16]) A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block(): + with T.sblock(): T.reads([]) T.writes(A0[0:16]) for j in range(0, 16): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A0[j]) A1 = T.match_buffer(A0[j], ()) @@ -320,7 +320,7 @@ def alloc_buffer_func(a: T.handle, b: T.handle) -> None: def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [2, 2], dtype="float32", elem_offset=0, align=64, offset_factor=1) B = T.match_buffer(b, [2, 2], dtype="float32", elem_offset=0, align=64, offset_factor=1) - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) C = T.alloc_buffer([2, 2], dtype="float32", elem_offset=0, align=64, offset_factor=1) @@ -343,7 +343,7 @@ def automatic_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")): B = T.decl_buffer(4, "int32", data=B_data) for i in range(4): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) C[vi] = A[vi] + B[vi] @@ -353,7 +353,7 @@ def explicit_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")): B = T.decl_buffer(4, "int32", data=B_data) for i in range(4): - with T.block("compute"): + with T.sblock("compute"): vi = T.axis.remap("S", [i]) T.reads(A[vi], B[vi]) T.writes(C[vi]) diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index 1cbd6af961c7..09e246243e20 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -166,13 +166,13 @@ def inconsistent_binding_type() -> None: def test_error_remap_args(): def error_remap_type() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("TT", [i, j]) # error T.evaluate(1.0) def error_remap_value() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i + j, j]) # error T.evaluate(1.0) @@ -184,7 +184,7 @@ def test_invalid_block_axes(): def invalid_block_axes(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi = T.axis.S(i, A) # error T.evaluate(1.0) @@ -194,14 +194,14 @@ def invalid_block_axes(a: T.handle) -> None: def test_duplicate_block_axes(): def duplicate_block_axes() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi = T.axis.S(16, i) vi = T.axis.S(16, j) # error T.evaluate(1.0) def duplicate_block_axes_remap() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vi = T.axis.remap("SS", [i, j]) # error T.evaluate(1.0) @@ -212,7 +212,7 @@ def duplicate_block_axes_remap() -> None: def test_miss_block_bind(): def miss_block_bind_value() -> None: for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi = T.axis.S(i) # error T.evaluate(1.0) @@ -238,7 +238,7 @@ def inconsistent_grid(A: T.Buffer(16)) -> None: def test_invalid_match_buffer_region(): def invalid_match_buffer_region() -> None: for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) A = T.match_buffer(vi) # error T.evaluate(1.0) @@ -258,7 +258,7 @@ def test_duplicate_block_signature(): def duplicate_reads() -> None: A = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[0:8, 0:8]) T.reads(A[0:16, 0:16]) # error @@ -267,7 +267,7 @@ def duplicate_reads() -> None: def duplicate_writes() -> None: A = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) T.writes(A[0:8, 0:8]) T.writes(A[0:16, 0:16]) # error @@ -275,14 +275,14 @@ def duplicate_writes() -> None: def duplicate_predicate() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) T.where(1) T.where(0) # error def duplicate_init() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) with T.init(): T.evaluate(1.0) @@ -291,17 +291,17 @@ def duplicate_init() -> None: def duplicate_axes() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) vi = T.axis.S(i, 16) # error T.evaluate(1.0) - def duplicate_block_attrs_with_same_key_diff_value() -> None: + def duplicate_sblock_attrs_with_same_key_diff_value() -> None: for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"key1": "block1"}) - T.block_attr({"key1": "block2"}) # error + T.sblock_attr({"key1": "block1"}) + T.sblock_attr({"key1": "block2"}) # error T.evaluate(1.0) check_error(duplicate_reads, 7) @@ -309,14 +309,14 @@ def duplicate_block_attrs_with_same_key_diff_value() -> None: check_error(duplicate_predicate, 6) check_error(duplicate_init, 7) check_error(duplicate_axes, 5) - check_error(duplicate_block_attrs_with_same_key_diff_value, 6) + check_error(duplicate_sblock_attrs_with_same_key_diff_value, 6) def test_opaque_access_during_complete(): def opaque_access_during_complete(a: T.handle) -> None: # error A = T.match_buffer(a, (16, 16), "float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): T.evaluate(T.call_extern("dummy_extern_function", A.data, dtype="int32")) check_error(opaque_access_during_complete, None) @@ -326,7 +326,7 @@ def test_convert_slice_to_bufferload(): def convert_slice_to_bufferload() -> None: A = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = A[vi : vi + 2, vj] + 1 # error @@ -369,7 +369,7 @@ def test_match_buffer_shape_mismatch(): def buffer_shape_mismatch(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block(): + with T.sblock(): T.reads([]) T.writes([A[i, j * 4 : j * 4 + 4]]) sub_A = T.match_buffer( @@ -383,7 +383,7 @@ def buffer_shape_mismatch(a: T.handle) -> None: def test_high_dim_store(): def high_dim_store() -> None: - with T.block("root"): + with T.sblock("root"): B = T.allocate([256], "float32", "global") for i, j in T.grid(16, 16): B[i, j] = 1.0 # error: Store is only allowed with one index @@ -393,7 +393,7 @@ def high_dim_store() -> None: def test_block_has_option_vars(): def block_has_option_vars() -> None: - with T.block("root") as x: # error: block does not support option_vars + with T.sblock("root") as x: # error: block does not support option_vars T.evaluate(0.0) check_error(block_has_option_vars, 2) @@ -409,7 +409,7 @@ def implicit_root_has_write(): T.evaluate(0.0) def implicit_root_has_attrs(): - T.block_attr({}) # error: implicit root does not support block_attr + T.sblock_attr({}) # error: implicit root does not support sblock_attr T.evaluate(0.0) def implicit_root_has_predicate(): @@ -432,7 +432,7 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 8): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) vl = T.axis.S(128, l * 16) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -445,32 +445,32 @@ def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 def test_reorder_fail_block(): sch = tir.Schedule(elementwise_not_affine, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k, l = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(l, i) expected_sub_error_message = ( - " # tir.Block#0\n" - ' with T.block("B"):\n' - " ^^^^^^^^^^^^^^^^^^\n" + " # tir.SBlock#0\n" + ' with T.sblock("B"):\n' + " ^^^^^^^^^^^^^^^^^^^\n" ) assert expected_sub_error_message in str(execinfo.value) def test_reorder_fail_nested_loop_inner(): sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(k, i) @@ -485,7 +485,7 @@ def test_reorder_fail_nested_loop_inner(): def test_fuse_fail_nested_loop_outer(): sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") - block_b = sch.get_block("B") + block_b = sch.get_sblock("B") i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.fuse(k, i) @@ -500,13 +500,13 @@ def test_fuse_fail_nested_loop_outer(): def test_report_error_root_block(): sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") - root = sch.get_block("root") + root = sch.get_sblock("root") with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.compute_inline(root) expected_sub_error_message = ( - " # tir.Block#0\n" - ' with T.block("root"):\n' - " ^^^^^^^^^^^^^^^^^^^^^\n" + " # tir.SBlock#0\n" + ' with T.sblock("root"):\n' + " ^^^^^^^^^^^^^^^^^^^^^^\n" ) assert expected_sub_error_message in str(execinfo.value) @@ -561,7 +561,7 @@ def binop_bad_type(h: T.handle): def test_non_integer_typed_block_iter(): def non_integer_typed_block_iter(): - with T.block(): + with T.sblock(): i = T.axis.S(0.1, 0.1) # error IterVar requires an integer dtype check_error(non_integer_typed_block_iter, 3) @@ -571,7 +571,7 @@ def test_illegal_buffer_slice(): def strided_buffer_region(A: T.handle): # do not allow stride in buffer region A = T.match_buffer((128, 128), "int32") - with T.block(): + with T.sblock(): T.reads([]) T.writes([A[0:128:2, 0:128:3]]) # error T.evaluate(T.call_extern("strided_compute", dtype="")) diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 8352b116443a..834078c3d9e2 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -90,14 +90,14 @@ def test_ir_builder_tir_primfunc_complete(): def test_ir_builder_tir_block_base(): with IRBuilder() as ib: - with T.block("block"): + with T.sblock("block"): T.evaluate(0) # the block generated by IRBuilder block_realize_actual = ib.get() # the expected block - block_expected = tir.Block( + block_expected = tir.SBlock( iter_vars=[], reads=[], writes=[], @@ -107,7 +107,7 @@ def test_ir_builder_tir_block_base(): match_buffers=None, annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)}, ) - block_realize_expected = tir.BlockRealize( + block_realize_expected = tir.SBlockRealize( iter_values=[], predicate=True, block=block_expected, @@ -125,11 +125,11 @@ def test_ir_builder_tir_block_complete(): d = T.int32() e = T.Buffer((128, 128), "float32") f = T.int32() - with T.block("block"): + with T.sblock("block"): T.where(a > 1) T.reads(b[0:16, 0:16]) T.writes(c[d:128, d:128]) - T.block_attr({"key": "value"}) + T.sblock_attr({"key": "value"}) T.alloc_buffer((128, 128), "float32") T.match_buffer(e[0:32, 0:32], (32, 32), "float32") T.axis.spatial(128, f) @@ -145,7 +145,7 @@ def test_ir_builder_tir_block_complete(): var_d = tir.Var("d", "int32") buffer_e = tir.decl_buffer((128, 128), "float32", name="c") var_f = tir.Var("f", "int32") - block_expected = tir.Block( + block_expected = tir.SBlock( iter_vars=[tir.IterVar((0, 128), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar)], reads=[buffer_b[0:16, 0:16]], writes=[buffer_c[var_d:128, var_d:128]], @@ -157,7 +157,7 @@ def test_ir_builder_tir_block_complete(): ], annotations={"key": "value"}, ) - block_realize_expected = tir.BlockRealize( + block_realize_expected = tir.SBlockRealize( iter_values=[var_f], predicate=var_a > 1, block=block_expected, @@ -173,7 +173,7 @@ def test_ir_builder_tir_axis(): b = T.int32() c = T.int32() d = T.int32() - with T.block("block"): + with T.sblock("block"): T.axis.spatial(8, a) T.axis.reduce(16, b) T.axis.scan(32, c) @@ -188,7 +188,7 @@ def test_ir_builder_tir_axis(): var_b = tir.Var("b", "int32") var_c = tir.Var("c", "int32") var_d = tir.Var("d", "int32") - block_expected = tir.Block( + block_expected = tir.SBlock( iter_vars=[ tir.IterVar((0, 8), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar), tir.IterVar((0, 16), tir.Var("", "int32"), iter_type=tir.IterVar.CommReduce), @@ -201,7 +201,7 @@ def test_ir_builder_tir_axis(): body=tir.Evaluate(0), annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)}, ) - block_realize_expected = tir.BlockRealize( + block_realize_expected = tir.SBlockRealize( iter_values=[var_a, var_b, var_c, var_d], predicate=True, block=block_expected, diff --git a/tests/python/tvmscript/test_tvmscript_meta_programming.py b/tests/python/tvmscript/test_tvmscript_meta_programming.py index 83b71e1447c7..58efac54ed28 100644 --- a/tests/python/tvmscript/test_tvmscript_meta_programming.py +++ b/tests/python/tvmscript/test_tvmscript_meta_programming.py @@ -28,7 +28,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [M, N], dtype=dtype) for i, j, k in T.grid(M, N, K): - with T.block(): + with T.sblock(): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -43,7 +43,7 @@ def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], dtype="float16") for i, j, k in T.grid(128, 128, 128): - with T.block(): + with T.sblock(): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -58,7 +58,7 @@ def generate_erf(dtype): @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): for i in range(1): - with T.block("C"): + with T.sblock("C"): C[i] = T.erf(A[i]) return main @@ -66,13 +66,13 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): @T.prim_func def fp32(A: T.Buffer((1,), "float32"), C: T.Buffer((1,), "float32")): for i in range(1): - with T.block("C"): + with T.sblock("C"): C[i] = T.erf(A[i]) @T.prim_func def fp16(A: T.Buffer((1,), "float16"), C: T.Buffer((1,), "float16")): for i in range(1): - with T.block("C"): + with T.sblock("C"): C[i] = T.erf(A[i]) tvm.ir.assert_structural_equal(fp16.with_attr("global_symbol", "main"), generate_erf("float16")) diff --git a/tests/python/tvmscript/test_tvmscript_ops.py b/tests/python/tvmscript/test_tvmscript_ops.py index 0d6beabd7a40..4b271c6c4d9d 100644 --- a/tests/python/tvmscript/test_tvmscript_ops.py +++ b/tests/python/tvmscript/test_tvmscript_ops.py @@ -36,11 +36,11 @@ def get_valid_counts( out_buf = T.match_buffer(out, (1, 2500, 6), "float32") out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32") - with T.block("init"): + with T.sblock("init"): vi = T.axis.S(1, 0) valid_count_buf[vi] = T.int32(0) for j in range(2500): - with T.block("update"): + with T.sblock("update"): vj = T.axis.S(2500, j) T.reads([data_buf[vi, vj, 6]]) T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) @@ -119,7 +119,7 @@ def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None: def alloc_zero_dim_buffer_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (), "float32") B = T.match_buffer(b, (), "float32") - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) C = T.alloc_buffer((), "float32") diff --git a/tests/python/tvmscript/test_tvmscript_parser_source.py b/tests/python/tvmscript/test_tvmscript_parser_source.py index 416bfd719f5c..c54f6322e29f 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_source.py +++ b/tests/python/tvmscript/test_tvmscript_parser_source.py @@ -28,7 +28,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index cc285e9835de..60e969bfe66b 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -66,7 +66,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -82,7 +82,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -99,7 +99,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -153,7 +153,7 @@ def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): assign(i, j, k, t1=A, t2=B, t3=C) @T.prim_func(private=True) @@ -162,7 +162,7 @@ def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -218,7 +218,7 @@ def __init__(self, x: T.Buffer): def load(self, x: T.Buffer): N, M = T.meta_var(self.local_x.shape) for i, j in T.grid(N, M): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i, j]) self.local_x[vi, vj] = x[vi, vj] @@ -235,12 +235,12 @@ def func_no_macro(a: T.handle): A = T.match_buffer(a, [128, 128]) local_a = T.alloc_buffer([128, 128]) for i, j in T.grid(128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i, j]) local_a[vi, vj] = A[vi, vj] local_b = T.alloc_buffer([128, 128]) for i, j in T.grid(128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i, j]) local_b[vi, vj] = local_a[vi, vj] @@ -309,7 +309,7 @@ def starred(a: T.handle, b: T.handle): A = T.match_buffer(a, [*dims, 128], "int32") B = T.match_buffer(b, dims, "int32") for *spatial, reduction in T.grid(*A.shape): - with T.block("reduce"): + with T.sblock("reduce"): with T.init(): B[spatial] = T.int32(0) B[spatial] = B[spatial] + A[(*spatial, reduction)] @@ -319,7 +319,7 @@ def non_starred(a: T.handle, b: T.handle): A = T.match_buffer(a, [128, 128, 128], "int32") B = T.match_buffer(b, [128, 128], "int32") for i, j, k in T.grid(128, 128, 128): - with T.block("reduce"): + with T.sblock("reduce"): with T.init(): B[i, j] = T.int32(0) B[i, j] = B[i, j] + A[i, j, k] @@ -517,7 +517,7 @@ def test_reinterpret_nop(): def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: T.func_attr({"global_symbol": "main"}) for i in T.serial(0, 32): - with T.block(): + with T.sblock(): vi = T.axis.remap("S", [i]) B[vi] = T.reinterpret("float32", A[vi]) @@ -525,7 +525,7 @@ def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: T.func_attr({"global_symbol": "main"}) for i in T.serial(0, 32): - with T.block(): + with T.sblock(): vi = T.axis.remap("S", [i]) B[vi] = A[vi] @@ -580,27 +580,27 @@ def _to_dict(anno: tvm_ffi.container.Map): @T.prim_func def func0(): - with T.block(): - T.block_attr({"key1": "block1"}) - T.block_attr({"key2": "block2"}) + with T.sblock(): + T.sblock_attr({"key1": "block1"}) + T.sblock_attr({"key2": "block2"}) T.evaluate(0) assert _to_dict(func0.body.block.annotations) == {"key1": "block1", "key2": "block2"} @T.prim_func def func1(): - with T.block(): - T.block_attr({"key": {"key1": "block1"}}) - T.block_attr({"key": {"key2": "block2"}}) + with T.sblock(): + T.sblock_attr({"key": {"key1": "block1"}}) + T.sblock_attr({"key": {"key2": "block2"}}) T.evaluate(0) assert _to_dict(func1.body.block.annotations) == {"key": {"key1": "block1", "key2": "block2"}} @T.prim_func def func2(): - with T.block(): - T.block_attr({"key1": "block1"}) - T.block_attr({"key1": "block1"}) + with T.sblock(): + T.sblock_attr({"key1": "block1"}) + T.sblock_attr({"key1": "block1"}) T.evaluate(0) assert _to_dict(func2.body.block.annotations) == {"key1": "block1"} @@ -609,16 +609,16 @@ def func2(): @T.prim_func def func3(): - with T.block(): - T.block_attr({"key1": "block1"}) - T.block_attr({"key1": "block2"}) + with T.sblock(): + T.sblock_attr({"key1": "block1"}) + T.sblock_attr({"key1": "block2"}) T.evaluate(0) def test_alloc_inside_block(): @T.prim_func(private=True) def func() -> None: - with T.block(): + with T.sblock(): A = T.alloc_buffer([10], "float32") for i in T.serial(0, 10): B = T.alloc_buffer([10], "float32") @@ -628,7 +628,7 @@ def func() -> None: @T.prim_func(private=True) def expected() -> None: - with T.block(): + with T.sblock(): A = T.alloc_buffer([10], "float32") B = T.alloc_buffer([10], "float32") for i, j in T.grid(10, 10): @@ -641,7 +641,7 @@ def expected() -> None: def test_tir_macro_block_name_suffix(): @T.macro def operation(A, idx): - with T.block("op"): + with T.sblock("op"): v = T.axis.remap("S", [idx]) A[v] = A[v] * T.float32(2) @@ -657,13 +657,13 @@ def func_w_macro(a: T.handle) -> None: def expected(a: T.handle) -> None: A = T.match_buffer(a, [10]) for i in T.serial(0, 10): - with T.block("op"): + with T.sblock("op"): v = T.axis.remap("S", [i]) A[v] = A[v] * T.float32(2) - with T.block("op_1"): + with T.sblock("op_1"): v = T.axis.remap("S", [i]) A[v] = A[v] * T.float32(2) - with T.block("op_2"): + with T.sblock("op_2"): v = T.axis.remap("S", [i]) A[v] = A[v] * T.float32(2) diff --git a/tests/python/tvmscript/test_tvmscript_printer_highlight.py b/tests/python/tvmscript/test_tvmscript_printer_highlight.py index 5d4173640b48..9ccd7806dedf 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_highlight.py +++ b/tests/python/tvmscript/test_tvmscript_printer_highlight.py @@ -37,7 +37,7 @@ def main( # type: ignore B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) for n, i, j, k in T.grid(16, 128, 128, 128): - with T.block("matmul"): + with T.sblock("matmul"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) with T.init(): C[vn, vi, vj] = 0.0 # type: ignore diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py index 70473954eb9c..6fb0fab3b8a7 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py +++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py @@ -130,13 +130,13 @@ def test_for(): @T.prim_func def func1(): for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): pass @T.prim_func def func2(): for i, j, k in T.grid(128, 128, 128): - with T.block(): + with T.sblock(): pass func1 = func1.with_attr("global_symbol", "main") diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index e4af15807426..f8eea3aee327 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -111,7 +111,7 @@ def test_block_realize(): j = tir.Var("j", "int32") k = tir.Var("k", "int32") with IRBuilder() as ib: - with T.block(name="block", no_realize=False): + with T.sblock(name="block", no_realize=False): vi = ib.name("vi", T.axis.spatial(128, i)) vj = ib.name("vj", T.axis.spatial(64, j)) vk = ib.name("vk", T.axis.reduce(32, k)) @@ -125,7 +125,7 @@ def test_block_realize(): i = T.int32() j = T.int32() k = T.int32() -with T.block("block"): +with T.sblock("block"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(64, j) vk = T.axis.reduce(32, k) @@ -140,7 +140,7 @@ def test_block(): j = tir.Var("j", "int32") k = tir.Var("k", "int32") with IRBuilder() as ib: - with T.block(name="block", no_realize=False): + with T.sblock(name="block", no_realize=False): vi = ib.name("vi", T.axis.spatial(128, i)) vj = ib.name("vj", T.axis.spatial(64, j)) vk = ib.name("vk", T.axis.reduce(32, k)) @@ -151,7 +151,7 @@ def test_block(): _assert_print( obj, """ -with T.block("block", no_realize=True): +with T.sblock("block", no_realize=True): vi = T.axis.spatial(128) vj = T.axis.spatial(64) vk = T.axis.reduce(32) @@ -725,7 +725,7 @@ def test_remap(): @T.prim_func def block_with_remap_implicitly(): for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): + with T.sblock("update"): v0 = T.axis.spatial(128, i0 + 1) v1 = T.axis.spatial(128, i1) v2 = T.axis.reduce(128, i2) @@ -736,7 +736,7 @@ def block_with_remap_implicitly(): @T.prim_func def block_with_remap_explicitly(): for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): + with T.sblock("update"): v0 = T.axis.spatial(128, i0 + 1) v1, v2 = T.axis.remap("SR", [i1, i2]) v3 = T.axis.spatial(128, i3 - 1) @@ -747,9 +747,9 @@ def block_with_remap_explicitly(): @T.prim_func def main(): - # with T.block("root"): + # with T.sblock("root"): for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): + with T.sblock("update"): v0 = T.axis.spatial(128, i0 + 1) v1, v2 = T.axis.remap("SR", [i1, i2]) v3 = T.axis.spatial(128, i3 - 1) @@ -768,15 +768,15 @@ def test_root_block(): def root_block_implicitly(): a = T.alloc_buffer([128, 128]) for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): T.evaluate(0) @T.prim_func def root_block_explicitly(): - with T.block("root"): + with T.sblock("root"): a = T.alloc_buffer([128, 128]) for i, j in T.grid(128, 128): - with T.block(): + with T.sblock(): T.evaluate(0) expected_output = """ @@ -784,10 +784,10 @@ def root_block_explicitly(): @T.prim_func def main(): - # with T.block("root"): + # with T.sblock("root"): a = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): - with T.block(""): + with T.sblock(""): T.reads() T.writes() T.evaluate(0) diff --git a/tests/python/tvmscript/test_tvmscript_regression.py b/tests/python/tvmscript/test_tvmscript_regression.py index d531acc2e993..64dc4cd4050e 100644 --- a/tests/python/tvmscript/test_tvmscript_regression.py +++ b/tests/python/tvmscript/test_tvmscript_regression.py @@ -31,7 +31,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -80,7 +80,7 @@ def test_tir_buffer_region_extent_correct_dtype(): @T.prim_func def func(A: T.Buffer((T.int64(16), T.int64(1)), "float32")): for i in T.grid(T.int64(16)): - with T.block("block"): + with T.sblock("block"): vi = T.axis.remap("S", [i]) T.reads(A[vi, T.int64(0) : T.int64(1)]) T.evaluate(0) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index b3d459b2e67f..cc64db9cf3e7 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2507,7 +2507,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) @@ -2524,12 +2524,12 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block("init"): + with T.sblock("init"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(128): - with T.block("update"): + with T.sblock("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -2544,11 +2544,11 @@ def element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -2562,7 +2562,7 @@ def predicate(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), "float32") for i, jo, ji in T.grid(16, 4, 5): - with T.block("update"): + with T.sblock("update"): vi = T.axis.S(16, i) vj = T.axis.S(16, jo * 4 + ji) T.where(jo * 4 + ji < 16) @@ -2585,13 +2585,13 @@ def test_matmul_original(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block, tir.stmt.SBlock) assert isinstance(rt_func.body.block.body, tir.stmt.For) assert isinstance(rt_func.body.block.body.body, tir.stmt.For) assert isinstance(rt_func.body.block.body.body.body, tir.stmt.SeqStmt) - assert isinstance(rt_func.body.block.body.body.body[0].block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body.body.body[0].block, tir.stmt.SBlock) assert isinstance(rt_func.body.block.body.body.body[1], tir.stmt.For) - assert isinstance(rt_func.body.block.body.body.body[1].body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body.body.body[1].body.block, tir.stmt.SBlock) def test_element_wise(): @@ -2599,15 +2599,15 @@ def test_element_wise(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block, tir.stmt.SBlock) assert isinstance(rt_func.body.block.body, tir.stmt.SeqStmt) assert isinstance(rt_func.body.block.body[0], tir.stmt.For) assert isinstance(rt_func.body.block.body[0].body, tir.stmt.For) - assert isinstance(rt_func.body.block.body[0].body.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body[0].body.body.block, tir.stmt.SBlock) assert isinstance(rt_func.body.block.body[1], tir.stmt.For) assert isinstance(rt_func.body.block.body[1].body, tir.stmt.For) - assert isinstance(rt_func.body.block.body[1].body.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body[1].body.body.block, tir.stmt.SBlock) def test_predicate(): @@ -2615,11 +2615,11 @@ def test_predicate(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block, tir.stmt.SBlock) assert isinstance(rt_func.body.block.body, tir.stmt.For) assert isinstance(rt_func.body.block.body.body, tir.stmt.For) assert isinstance(rt_func.body.block.body.body.body, tir.stmt.For) - assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.SBlock) def for_thread_binding(): @@ -2658,11 +2658,11 @@ def match_buffer_region(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (1), "float32") for i, j in T.grid(16, 4): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) for ii in range(4): - with T.block(): + with T.sblock(): vii = T.axis.S(4, ii) D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) for i, j in T.grid(4, 4): @@ -2676,19 +2676,19 @@ def test_match_buffer_region(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body, tir.stmt.BlockRealize) + assert isinstance(rt_func.body, tir.stmt.SBlockRealize) root = rt_func.body.block assert isinstance(root.body, tir.stmt.For) assert isinstance(root.body.body, tir.stmt.For) - assert isinstance(root.body.body.body, tir.stmt.BlockRealize) + assert isinstance(root.body.body.body, tir.stmt.SBlockRealize) outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) assert isinstance(outer_block.body, tir.stmt.For) - assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) + assert isinstance(outer_block.body.body, tir.stmt.SBlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer @@ -2701,12 +2701,12 @@ def block_elements(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (1, 1), "float32") - with T.block("update"): + with T.sblock("update"): vi = T.axis.S(1, 0) T.where(True) T.reads(A[0:16, 0:16]) T.writes(B[0, 0]) - T.block_attr({"attr_key": "attr_value"}) + T.sblock_attr({"attr_key": "attr_value"}) C = T.alloc_buffer((4, 4), dtype="float32") D = T.match_buffer(A[0:4, 0], (4, 1)) with T.init(): @@ -2721,9 +2721,9 @@ def test_block_elements(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.Block) - assert isinstance(rt_func.body.block.body, tir.stmt.BlockRealize) - assert isinstance(rt_func.body.block.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block, tir.stmt.SBlock) + assert isinstance(rt_func.body.block.body, tir.stmt.SBlockRealize) + assert isinstance(rt_func.body.block.body.block, tir.stmt.SBlock) block = rt_func.body.block.body.block assert isinstance(block.body, tir.stmt.BufferStore) assert isinstance(block.init, tir.stmt.BufferStore) @@ -2739,11 +2739,11 @@ def opaque_block(a: T.handle, b: T.handle) -> None: for i in range(16): for j in range(16): - with T.block(): + with T.sblock(): T.reads([]) T.writes(A[i, j]) A[i, j] = T.float32(0) - with T.block(): + with T.sblock(): T.reads([A[i, 0:16]]) T.writes([B[i, 0:16]]) for j in range(16): @@ -2758,14 +2758,14 @@ def test_opaque_block(): tvm.ir.assert_structural_equal(func, rt_func) root_block = rt_func.body.block - assert isinstance(root_block, tir.stmt.Block) + assert isinstance(root_block, tir.stmt.SBlock) assert isinstance(root_block.body, tir.stmt.For) assert isinstance(root_block.body.body[0], tir.stmt.For) - assert isinstance(root_block.body.body[0].body, tir.stmt.BlockRealize) - assert isinstance(root_block.body.body[0].body.block, tir.stmt.Block) + assert isinstance(root_block.body.body[0].body, tir.stmt.SBlockRealize) + assert isinstance(root_block.body.body[0].body.block, tir.stmt.SBlock) assert len(root_block.body.body[0].body.block.iter_vars) == 0 - assert isinstance(root_block.body.body[1], tir.stmt.BlockRealize) - assert isinstance(root_block.body.body[1].block, tir.stmt.Block) + assert isinstance(root_block.body.body[1], tir.stmt.SBlockRealize) + assert isinstance(root_block.body.body[1].block, tir.stmt.SBlock) assert len(root_block.body.body[1].block.iter_vars) == 0 @@ -2854,7 +2854,7 @@ def rank0_block(a: T.handle) -> None: B = T.alloc_buffer((), "float32") B[()] = A[()] - with T.block("update"): + with T.sblock("update"): T.reads([A[()]]) T.writes([B[()]]) for i in range(1): @@ -2888,7 +2888,7 @@ def abs(a: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block("A"): + with T.sblock("A"): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = T.abs(A[vi, vj]) @@ -2924,11 +2924,11 @@ def var_with_same_name(): def var_with_same_name(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block(): + with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 @@ -2952,7 +2952,7 @@ def while_loop(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (16,), "float32") i = T.alloc_buffer((), "int32", scope="local") for ii in range(16): - with T.block(): + with T.sblock(): vi = T.axis.S(16, ii) B[vi] = 0 while i[()] < 10: @@ -3031,11 +3031,11 @@ def multiple_commreducer() -> None: reduce_temp0 = T.Buffer([1], dtype="float32", strides=[1], scope="local") reduce_temp1 = T.Buffer([1], dtype="float32", strides=[1], scope="local") for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cross_thread_reduction"): + with T.sblock("T_softmax_maxelem_cross_thread_reduction"): T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("T_softmax_expsum_cross_thread_reduction"): + with T.sblock("T_softmax_expsum_cross_thread_reduction"): T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle")) @@ -3082,7 +3082,7 @@ def nontrivial_range_axis(): def nontrivial_range_axis(a: T.handle) -> None: A = T.match_buffer(a, (10), "float32") for i in range(10): - with T.block("block"): + with T.sblock("block"): vi = T.axis.spatial((1, 11), i + 1) A[vi - 1] = A[vi - 1] + 1.0 @@ -3133,8 +3133,8 @@ def func(): def func_root_attr(): @T.prim_func def func_root_attr(): - with T.block("root"): - T.block_attr({"a": "0"}) + with T.sblock("root"): + T.sblock_attr({"a": "0"}) T.evaluate(0) return func_root_attr @@ -3143,7 +3143,7 @@ def func_root_attr(): def func_trivial_root_block(): @T.prim_func def func(A: T.Buffer(1, "int32")): - with T.block("root"): + with T.sblock("root"): A[0] = 0 return func @@ -3152,8 +3152,8 @@ def func(A: T.Buffer(1, "int32")): def func_nested_root_block(): @T.prim_func def func(A: T.Buffer(1, "int32")): - with T.block("root"): - with T.block("block"): + with T.sblock("root"): + with T.sblock("block"): A[0] = 0 return func @@ -3203,7 +3203,7 @@ def llvm_intrin_call(): @T.prim_func def ctpop(A: T.Buffer((16,), "uint8"), B: T.Buffer((16,), "uint8")) -> None: for i in range(0, 16): - with T.block("A"): + with T.sblock("A"): vi = T.axis.remap( "S", [ @@ -3229,12 +3229,12 @@ def segment_sum( B = T.match_buffer(B_ptr, [n], dtype="float32") indptr = T.match_buffer(indptr_ptr, [n + 1], dtype="int32") for i in T.serial(n): - with T.block("outer"): + with T.sblock("outer"): vi = T.axis.spatial(n, i) T.reads(indptr[i : i + 2], B[vi], A[indptr[i] : indptr[i + 1]]) T.writes(B[vi]) for j in T.serial(indptr[i], indptr[i + 1]): - with T.block("inner"): + with T.sblock("inner"): vj = T.axis.reduce(m, j) T.reads(B[vi], A[vj]) T.writes(B[vi]) @@ -3252,11 +3252,11 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32") C = T.match_buffer(c, (T.int64(128), T.int64(128)), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -3303,11 +3303,11 @@ def element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128), "float32", axis_separators=[1]) for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -3436,9 +3436,9 @@ def func( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body - # with T.block("root") + # with T.sblock("root") for i0, i1, i2 in T.grid(1, 512, 768): - with T.block("T_isinf"): + with T.sblock("T_isinf"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(placeholder[ax0, ax1, ax2]) T.writes(T_isinf[ax0, ax1, ax2]) @@ -3988,14 +3988,14 @@ def func( T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) C = T.alloc_buffer([128, 128], dtype="float32") for i0, i1, i2 in T.grid(128, 128, 128): - with T.block("C"): + with T.sblock("C"): x, y, k = T.axis.remap("SSR", [i0, i1, i2]) with T.init(): C[x, y] = T.float32(0) C[x, y] = C[x, y] + A[x, k] * B[y, k] for i0, i1 in T.grid(128, 128): - with T.block("D"): - T.block_attr({"layout_free_placeholders": [C]}) + with T.sblock("D"): + T.sblock_attr({"layout_free_placeholders": [C]}) x, y = T.axis.remap("SS", [i0, i1]) D[x, y] = C[x, y] + T.float32(1) diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index df8675704b67..fe09d68d4a7f 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -32,7 +32,7 @@ def transformed_matmul_no_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> C = T.match_buffer(c, [128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) @@ -50,7 +50,7 @@ def transformed_matmul_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> No C = T.match_buffer(c, [128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block("update"): + with T.sblock("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) @@ -104,7 +104,7 @@ def elementwise_handle( A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -116,7 +116,7 @@ def elementwise_buffer_kwargs( b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"), ) -> None: for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0 @@ -128,7 +128,7 @@ def elementwise_buffer_no_kwargs( b: T.Buffer((128, 128, 128, 128), "float32"), ) -> None: for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0 @@ -165,7 +165,7 @@ def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle): B = T.match_buffer(b, (K, M), "float32") C = T.match_buffer(c, (N, M), "float32") for i, j, k in T.grid(N, M, K): - with T.block("gemm"): + with T.sblock("gemm"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 @@ -183,11 +183,11 @@ def match_buffer_int64(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32") C = T.match_buffer(c, (T.int64(128), T.int64(128)), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -199,11 +199,11 @@ def match_buffer_int64_after_roundtrip( ) -> None: B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32") for i, j in T.grid(128, 128): - with T.block("B"): + with T.sblock("B"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -217,13 +217,13 @@ def test_match_buffer_int64(): def test_match_buffer_region_has_implicit_shape_dtype(): @T.prim_func def explicit_shape_dtype(A: T.Buffer((16, 64), "int32")): - with T.block(): + with T.sblock(): B = T.match_buffer(A[8:16, 32:64], shape=(8, 32), dtype="int32") T.evaluate(0) @T.prim_func def implicit_shape_dtype(A: T.Buffer((16, 64), "int32")): - with T.block(): + with T.sblock(): B = T.match_buffer(A[8:16, 32:64]) T.evaluate(0) @@ -280,11 +280,11 @@ def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (32, 8), "float16", align=64, offset_factor=16, scope="warp") C = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=16, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) T.writes(C[0:32, 0:8]) for i, j, k in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k) @@ -307,11 +307,11 @@ def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> Non B = T.match_buffer(b, (32, 8), "float16", align=64, offset_factor=16, scope="warp") C = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=16, scope="warp") - with T.block("root"): + with T.sblock("root"): T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) T.writes(C[0:32, 0:8]) for i, j, k in T.grid(16, 16, 16): - with T.block("C"): + with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) T.reads( C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], @@ -359,7 +359,7 @@ def int64_grid( B: T.Buffer((T.int64(128), T.int64(128)), "float32"), ) -> None: for i, j in T.grid(T.int64(128), T.int64(128)): - with T.block("C"): + with T.sblock("C"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + 1.0 @@ -370,7 +370,7 @@ def int64_grid_expanded( ) -> None: for i in range(T.int64(0), T.int64(128)): for j in range(T.int64(0), T.int64(128)): - with T.block("C"): + with T.sblock("C"): vi = T.axis.spatial(T.int64(128), i) vj = T.axis.spatial(T.int64(128), j) B[vi, vj] = A[vi, vj] + 1.0 diff --git a/tests/python/tvmscript/test_tvmscript_type.py b/tests/python/tvmscript/test_tvmscript_type.py index 8228363a95ac..4d515d9356bb 100644 --- a/tests/python/tvmscript/test_tvmscript_type.py +++ b/tests/python/tvmscript/test_tvmscript_type.py @@ -28,21 +28,21 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1) # body - with T.block("root"): + with T.sblock("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=64, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i0) vj = T.axis.S(128, ax1) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) - T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) B[vi, vj] = A[vi, vj] * T.float32(2) for i1 in T.serial(0, 128): - with T.block("C"): + with T.sblock("C"): vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) @@ -70,12 +70,12 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): for j0_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 ) for j1_1 in T.serial(0, 32): - with T.block(""): + with T.sblock(""): C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 ) @@ -92,7 +92,7 @@ def loop_split(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") for i, ko in T.grid(128, 4): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi = T.axis.S(128, i) vk = T.axis.R(128, ko * 32 + ki) T.reads([B[vi], A[vi, vk]]) @@ -117,13 +117,13 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: for ki in T.thread_binding(0, 32, thread="threadIdx.x"): normal_reduce_temp0[0] = T.float32(0) for ko in T.serial(0, 4): - with T.block("B_normal_reduction"): + with T.sblock("B_normal_reduction"): vi = T.axis.S(128, i) vk = T.axis.R(128, ko * 32 + ki) T.reads([A[vi, vk], normal_reduce_temp0[0]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] - with T.block("B_cross_thread_reduction"): + with T.sblock("B_cross_thread_reduction"): T.reads([normal_reduce_temp0[0]]) T.writes([reduce_temp0[0]]) T.attr( @@ -141,7 +141,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: dtype="handle", ) ) - with T.block("B_write_back"): + with T.sblock("B_write_back"): vi = T.axis.S(128, i) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) @@ -159,7 +159,7 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128], dtype="float32") for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block("B"): + with T.sblock("B"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) T.reads([B[vi, vj], A[vi, vj, vk]]) T.writes( diff --git a/web/tests/python/relax_rpc_test.py b/web/tests/python/relax_rpc_test.py index c21b98564d78..f533dcbf29cb 100644 --- a/web/tests/python/relax_rpc_test.py +++ b/web/tests/python/relax_rpc_test.py @@ -40,7 +40,7 @@ def main(x: R.Tensor([1024], "float32"), y: R.Tensor([1024], "float32")): sch = tvm.tir.Schedule(mod) # manually transform loop sch.work_on("add") - (i,) = sch.get_loops(block=sch.get_block("T_add")) + (i,) = sch.get_loops(block=sch.get_sblock("T_add")) i0, i1 = sch.split(i, [None, 128]) sch.bind(i0, "blockIdx.x") sch.bind(i1, "threadIdx.x") diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index f1e1c828885f..4d244c754b84 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -41,7 +41,7 @@ def test_rpc(): B = te.compute(A.shape, lambda *i: te.log(te.abs(A(*i) + 1)), name="B") mod = tvm.IRModule.from_expr(te.create_prim_func([A, B])) sch = tvm.tir.Schedule(mod) - (i,) = sch.get_loops(block=sch.get_block("B")) + (i,) = sch.get_loops(block=sch.get_sblock("B")) i0, i1 = sch.split(i, [None, 32]) sch.bind(i0, "blockIdx.x") sch.bind(i1, "threadIdx.x")