Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/reference/api/python/tir/tir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ tvm.tir
.. automodule:: tvm.tir
:members:
:imported-members:
:exclude-members: PrimExpr, const, StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
:exclude-members: PrimExpr, const, StmtSRef, SBlockScope, ScheduleState, Schedule, ScheduleError
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/schedule/cuda/thread_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ffi::Array<tir::LoopRV> BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, //
* \param max_threads_per_block The maximum number of threads allowed.
* \param get_factor A function that returns the tiling factor.
*/
void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block, //
void BindBlockThreadIdx(tir::Schedule sch, tir::SBlockRV block, //
int64_t max_threadblocks, int64_t max_threads_per_block,
std::function<tir::ExprRV(int64_t max_extent)> get_factor = nullptr);

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/schedule/generic/winograd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace meta_schedule {
* If there is a constant winograd transform matrix, inline it.
* \return The only producer block.
*/
tir::BlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::BlockRV block);
tir::SBlockRV GetWinogradProducerAndInlineConst(tir::Schedule sch, tir::SBlockRV block);

} // namespace meta_schedule
} // namespace tvm
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ScheduleRuleNode : public runtime::Object {
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
virtual ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0;
virtual ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::SBlockRV& block) = 0;

/*!
* \brief Deep clone the schedule rule.
Expand Down Expand Up @@ -90,7 +90,7 @@ class ScheduleRule : public runtime::ObjectRef {
* \return The list of schedules generated by applying the schedule rule.
*/
using FApply =
ffi::TypedFunction<ffi::Array<tir::Schedule>(const tir::Schedule&, const tir::BlockRV&)>;
ffi::TypedFunction<ffi::Array<tir::Schedule>(const tir::Schedule&, const tir::SBlockRV&)>;
/*!
* \brief Get the schedule rule as string with name.
* \return The string of the schedule rule.
Expand Down Expand Up @@ -151,7 +151,7 @@ class ScheduleRule : public runtime::ObjectRef {
* \param reuse_read Data reuse configuration for reading. std::nullopt means no reuse.
* \param reuse_write Data reuse configuration for writing. std::nullopt means no reuse.
* \param filter_fn A function that can be passed to overwrite the default condition for applying
* MultiLevelTiling to a block. Its signature must be (Schedule, BlockRV) -> bool.
* MultiLevelTiling to a block. Its signature must be (Schedule, SBlockRV) -> bool.
* This is useful if there is a need to apply MultiLevelTiling to an operation / block which is
* ignored by default. This function should return True for a block that should be tiled.
* \return The schedule rule created
Expand Down Expand Up @@ -343,7 +343,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
}

void InitializeWithTuneContext(const TuneContext& context) final;
ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::SBlockRV& block) final;
ScheduleRule Clone() const final;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyScheduleRule", PyScheduleRuleNode,
ScheduleRuleNode);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ TVM_DLL bool WellFormed(ffi::Variant<IRModule, Function> obj, bool check_struct_
* from the object (block or buffer) to it's index map transformation.
*/

TVM_DLL ffi::Map<tir::Block, ffi::Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
TVM_DLL ffi::Map<tir::SBlock, ffi::Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
const Function& fn, ffi::Array<tir::IndexMap> write_buffer_transformations);

/* \brief Collect variables whose value can be computed at compile-time
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/distributed/axis_group_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor {
return true;
}

void VisitStmt_(const BlockNode* op) final {
void VisitStmt_(const SBlockNode* op) final {
if (op->name_hint == "root") {
StmtExprVisitor::VisitStmt_(op);
return;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class BlockFrameNode : public TIRFrameNode {
.def_ro("predicate", &BlockFrameNode::predicate)
.def_ro("no_realize", &BlockFrameNode::no_realize);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BlockFrame", BlockFrameNode,
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.SBlockFrame", BlockFrameNode,
TIRFrameNode);

public:
Expand Down Expand Up @@ -207,7 +207,7 @@ class BlockInitFrameNode : public TIRFrameNode {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BlockInitFrameNode>();
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BlockInitFrame", BlockInitFrameNode,
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.SBlockInitFrame", BlockInitFrameNode,
TIRFrameNode);

public:
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array<PrimExpr> shape,
/*!
* \brief The block declaration statement.
* \param name The name of the block.
* \param no_realize The flag whether to construct BlockRealize or Block.
* \param no_realize The flag whether to construct SBlockRealize or SBlock.
* \return The BlockFrame.
*/
BlockFrame Block(ffi::String name, bool no_realize = false);
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
* - second: write regions
* - third: opaque regions
*/
TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetBlockAccessRegion(
const Block& block, const ffi::Map<Var, Buffer>& buffer_var_map);
TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockAccessRegion(
const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Auto detect the block read/write region according to its body stmt. An opaque access will
Expand All @@ -244,8 +244,8 @@ TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetBlockAccessRegion(
* It is a map from buffer var to the buffer
* \return An array only consisting of the read regions and write regions of the input block
*/
TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetBlockReadWriteRegion(
const Block& block, const ffi::Map<Var, Buffer>& buffer_var_map);
TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockReadWriteRegion(
const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);

/*! \brief Helper struct for return value of IdentifyMemCpy
*
Expand Down Expand Up @@ -329,7 +329,7 @@ TVM_DLL ffi::Map<Buffer, ffi::Optional<Stmt>> DetectBufferAccessLCA(const PrimFu
*
* - Each variable has a single point of definition.
*
* - Expressions within a tir::Block may not reference variables
* - Expressions within a tir::SBlock may not reference variables
* defined outside the block. For example, for a block with iter
* vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement
* `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop
Expand Down Expand Up @@ -379,7 +379,7 @@ const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var);
* \param mod The input TIR module.
* \return The anchor block if found, nullptr otherwise.
*/
const tir::BlockNode* FindAnchorBlock(const IRModule& mod);
const tir::SBlockNode* FindAnchorBlock(const IRModule& mod);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/tir/block_dependence_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
/*!
* \file tvm/tir/block_dependence_info.h
* \brief Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects to
* \brief Define BlockDependenceInfoNode that uses the SBlockScope and StmtSRef objects to
* store the block level dependences
* \sa BlockDependenceInfoNode
*/
Expand All @@ -41,10 +41,10 @@ namespace tir {

/**
* @brief An object that helps build and query block level dependences using the 2 core objects
* BlockScope and StmtSRef
* SBlockScope and StmtSRef
*
* The data structures exposed are:
* 1) sref2scope: Mapping from the srefs to its corresponding BlockScope
* 1) sref2scope: Mapping from the srefs to its corresponding SBlockScope
* 2) stmt2ref: Mapping from blocks to corresponding StmtSRefs
*
* Note that this object does not store SRefs to loops as the purpose is only to expose block level
Expand All @@ -54,28 +54,28 @@ namespace tir {
class BlockDependenceInfoNode : public Object {
public:
/*!
* \brief Mapping from a block sref to its corresponding BlockScope,
* \brief Mapping from a block sref to its corresponding SBlockScope,
* tracking the dependency inside the block scope,
*/
std::unordered_map<StmtSRef, BlockScope, ObjectPtrHash, ObjectPtrEqual> sref2scope;
std::unordered_map<StmtSRef, SBlockScope, ObjectPtrHash, ObjectPtrEqual> sref2scope;
/*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BlockDependenceInfoNode>();
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockDependenceInfo", BlockDependenceInfoNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockDependenceInfo", BlockDependenceInfoNode, Object);

/*!
* \brief Get the BlockScope corresponding to the sref of scope root block
* \brief Get the SBlockScope corresponding to the sref of scope root block
* \param scope_root The block sref to be retrieved
* \return The corresponding BlockScope
* \return The corresponding SBlockScope
*/
BlockScope GetBlockScope(const StmtSRef& scope_root) const {
SBlockScope GetSBlockScope(const StmtSRef& scope_root) const {
auto it = sref2scope.find(scope_root);
CHECK(it != sref2scope.end())
<< "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
<< "IndexError: Cannot find the corresponding SBlockScope to the block sref:\n"
<< ffi::GetRef<Stmt>(scope_root->stmt);
return it->second;
}
Expand Down
34 changes: 17 additions & 17 deletions include/tvm/tir/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
*/
/*!
* \file tvm/tir/block_scope.h
* \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.
* \brief Definition of two pillar data structure for TensorIR scheduling: StmtSRef, SBlockScope.
* \sa StmtSRefNode
* \sa BlockScopeNode
* \sa SBlockScopeNode
*/
#ifndef TVM_TIR_BLOCK_SCOPE_H_
#define TVM_TIR_BLOCK_SCOPE_H_
Expand All @@ -41,7 +41,7 @@ namespace tir {
* \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
*
* Glossary
* - Block sref: A StmtSRef that points to a TensorIR block.
* - SBlock sref: A StmtSRef that points to a TensorIR SBlock.
* - Loop sref: A StmtSRef that points to a TensorIR for loop.
* - Parent sref: The parent reference of an sref is the block or loop reference to the closest
schedulable statement. We define closest to be the nearest schedulable statement of an ancestor in
Expand Down Expand Up @@ -86,7 +86,7 @@ class StmtSRefNode : public Object {
* \brief Get the referenced statement with proper type checking.
* It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
* \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
* tvm::tir::BlockNode or tvm::tir::ForNode
* tvm::tir::SBlockNode or tvm::tir::ForNode
* \return nullptr if type check fails, otherwise the casted result for `this->stmt`
*/
template <typename StmtType>
Expand Down Expand Up @@ -177,7 +177,7 @@ class SRefTreeCreator : private StmtVisitor {

void VisitStmt_(const ForNode* loop) final;

void VisitStmt_(const BlockRealizeNode* realize) final;
void VisitStmt_(const SBlockRealizeNode* realize) final;

void VisitStmt_(const SeqStmtNode* seq_stmt) final;

Expand Down Expand Up @@ -243,14 +243,14 @@ class Dependency : public ObjectRef {
* For example even leaf nodes have a scope node, even though they have no dependencies.
*
* Glossary:
* - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
* - SBlock scope: A contiguous subtree of the sref tree, rooted at each SBlock sref,
* whose components are:
* - scope root: a block sref
* - internal srefs: loop srefs
* - scope leaves: block srefs
* - Child block: The scope leaf blocks under the scope root or a specific internal sref
*/
class BlockScopeNode : public Object {
class SBlockScopeNode : public Object {
public:
/*!
* \brief Lookup table for the `src` of dependencies
Expand All @@ -265,9 +265,9 @@ class BlockScopeNode : public Object {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BlockScopeNode>();
refl::ObjectDef<SBlockScopeNode>();
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockScope", BlockScopeNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockScope", SBlockScopeNode, Object);

public:
/******** Dependency ********/
Expand All @@ -286,29 +286,29 @@ class BlockScopeNode : public Object {
};

/*!
* \brief Managed reference to BlockScopeNode
* \sa BlockScopeNode
* \brief Managed reference to SBlockScopeNode
* \sa SBlockScopeNode
*/
class BlockScope : public ObjectRef {
class SBlockScope : public ObjectRef {
public:
/*!
* \brief Constructor from ObjectPtr<BlockScopeNode>.
* \brief Constructor from ObjectPtr<SBlockScopeNode>.
* \param data The object pointer.
*/
explicit BlockScope(ObjectPtr<BlockScopeNode> data) : ObjectRef(data) {
explicit SBlockScope(ObjectPtr<SBlockScopeNode> data) : ObjectRef(data) {
TVM_FFI_ICHECK(data != nullptr);
}
/*! \brief The constructor creating an empty block scope with on dependency information */
TVM_DLL BlockScope();
TVM_DLL SBlockScope();
/*!
* \brief Create the object with the specific leaf blocks, and compute the dependency information
* between the leaf blocks.
* \param child_block_srefs The srefs to the leaf blocks
* \note We assume the leaf blocks are given in pre-DFS order
*/
TVM_DLL explicit BlockScope(const ffi::Array<StmtSRef>& child_block_srefs);
TVM_DLL explicit SBlockScope(const ffi::Array<StmtSRef>& child_block_srefs);

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockScope, ObjectRef, BlockScopeNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockScope, ObjectRef, SBlockScopeNode);
};

} // namespace tir
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class DataTypeLegalizer : public StmtExprMutator {
protected:
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const SBlockRealizeNode* op) override;
Stmt VisitStmt_(const SBlockNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SelectNode* op) override;
Expand Down Expand Up @@ -101,8 +101,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
using Parent::VisitExpr_;
using Parent::VisitStmt_;

Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const SBlockRealizeNode* op) override;
Stmt VisitStmt_(const SBlockNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const AttrStmtNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/schedule/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class InstructionKindNode : public runtime::Object {
ffi::String name;
/*!
* \brief Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule
* state. For example, the instruction `GetBlock` is pure because it changes
* state. For example, the instruction `GetSBlock` is pure because it changes
* nothing, while `ComputeInline` is not because removing it leads to a different resulting
* schedule.
*/
Expand Down Expand Up @@ -148,7 +148,7 @@ class InstructionNode : public runtime::Object {
/*!
* \brief The input random variables of the instruction, and the type of each element can be one
* of the following:
* - BlockRV
* - SBlockRV
* - LoopRV
* - ExprRV
* - double
Expand All @@ -160,12 +160,12 @@ class InstructionNode : public runtime::Object {
/*!
* \brief The attributes of the instruction. Similar to attributes of an operator,
* attributes of an instruction are arbitrary constant metadata required by the instructions.
* For example, the name of the block to be retrieved in `GetBlock`.
* For example, the name of the block to be retrieved in `GetSBlock`.
*/
ffi::Array<Any> attrs;
/*! \brief The output random variables of the instruction, and the type of each element can be one
* of the following:
* - BlockRV
* - SBlockRV
* - LoopRV
* - ExprRV, atomic variables only, won't be constants or composite PrimExpr
*/
Expand Down
Loading
Loading