Skip to content

Commit 13ea9dc

Browse files
wrongtest-intellifbaoxinqi
andauthored
[TIR] Add step attribute to ForNode (Initial codes) (apache#18421)
An initial change to add `ForNode::step`. - Add `Optional<PrimExpr>` typed step attribute to ForNode. Then add minimal codes for - Roundtrip support for TIR tvmscript grammar - Correctness of TIR lowering pipeline: - Canonicalize the loop in default pipeline - Ensure the original `ForNode::step` is not dropped by mutations on `ForNode`. - CodeGen support for non-zero min and non-trivial step. - TODOs in the future (hopefully) - For **all transformations and analysis tools**, make adaptions to non-consecutive loop iteration indices - Correctness of TensorIR schedule and MetaSchedule --------- Co-authored-by: baoxinqi <[email protected]>
1 parent 9e905f9 commit 13ea9dc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+619
-126
lines changed

include/tvm/script/ir_builder/tir/frame.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,15 @@ class ForFrameNode : public TIRFrameNode {
251251
* \param loop_body The loop body
252252
* \return A stmt, the loop nest
253253
*/
254-
using FMakeForLoop =
255-
ffi::TypedFunction<tvm::tir::Stmt(ffi::Array<tvm::tir::Var> loop_vars,
256-
ffi::Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
254+
using FMakeForLoop = ffi::TypedFunction<tvm::tir::Stmt(
255+
ffi::Array<tvm::tir::Var> loop_vars, ffi::Array<Range> loop_extents,
256+
ffi::Array<ffi::Optional<PrimExpr>> loop_steps, tvm::tir::Stmt loop_body)>;
257257
/*! \brief The loop variable. */
258258
ffi::Array<tvm::tir::Var> vars;
259259
/*! \brief The domains of iteration. */
260260
ffi::Array<Range> doms;
261+
/*! \brief The optional steps of iteration. */
262+
ffi::Array<ffi::Optional<PrimExpr>> steps;
261263
/*! \brief The for loop generating function. */
262264
FMakeForLoop f_make_for_loop;
263265

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,37 +228,45 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
228228
* \param start The minimum value of iteration.
229229
* \param stop The maximum value of iteration.
230230
* \param annotations The optional annotations of the For statement.
231+
* \param step The optional step value of iteration.
231232
* \return The ForFrame.
232233
*/
233234
ForFrame Serial(PrimExpr start, PrimExpr stop,
234-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
235+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
236+
ffi::Optional<PrimExpr> step = std::nullopt);
235237
/*!
236238
* \brief The parallel For statement.
237239
* \param start The minimum value of iteration.
238240
* \param stop The maximum value of iteration.
239241
* \param annotations The optional annotations of the For statement.
242+
* \param step The optional step value of iteration.
240243
* \return The ForFrame.
241244
*/
242245
ForFrame Parallel(PrimExpr start, PrimExpr stop,
243-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
246+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
247+
ffi::Optional<PrimExpr> step = std::nullopt);
244248
/*!
245249
* \brief The vectorized For statement.
246250
* \param start The minimum value of iteration.
247251
* \param stop The maximum value of iteration.
248252
* \param annotations The optional annotations of the For statement.
253+
* \param step The optional step value of iteration.
249254
* \return The ForFrame.
250255
*/
251256
ForFrame Vectorized(PrimExpr start, PrimExpr stop,
252-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
257+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
258+
ffi::Optional<PrimExpr> step = std::nullopt);
253259
/*!
254260
* \brief The unrolled For statement.
255261
* \param start The minimum value of iteration.
256262
* \param stop The maximum value of iteration.
257263
* \param annotations The optional annotations of the For statement.
264+
* \param step The optional step value of iteration.
258265
* \return The ForFrame.
259266
*/
260267
ForFrame Unroll(PrimExpr start, PrimExpr stop,
261-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
268+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
269+
ffi::Optional<PrimExpr> step = std::nullopt);
262270
/*!
263271
* \brief The thread-binding For statement.
264272
* \param start The minimum value of iteration.

include/tvm/tir/stmt.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ enum class ForKind : int {
717717
*
718718
* \code
719719
*
720-
* for (loop_var = min; loop_var < min + extent; ++loop_var) {
720+
* for (loop_var = min; loop_var < min + extent; loop_var += step) {
721721
* // body
722722
* }
723723
* \endcode
@@ -748,6 +748,10 @@ class ForNode : public StmtNode {
748748
* and can be ignored in most passes.
749749
*/
750750
ffi::Map<ffi::String, ffi::Any> annotations;
751+
/*!
752+
* \brief The loop step. It is one if not specified.
753+
*/
754+
ffi::Optional<PrimExpr> step;
751755

752756
static void RegisterReflection() {
753757
namespace refl = tvm::ffi::reflection;
@@ -758,8 +762,13 @@ class ForNode : public StmtNode {
758762
.def_ro("kind", &ForNode::kind)
759763
.def_ro("body", &ForNode::body)
760764
.def_ro("thread_binding", &ForNode::thread_binding)
761-
.def_ro("annotations", &ForNode::annotations);
765+
.def_ro("annotations", &ForNode::annotations)
766+
.def_ro("step", &ForNode::step);
762767
}
768+
769+
/*! \brief Check it is a loop without nontrivial loop step. */
770+
bool HasTrivialStep() const;
771+
763772
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode);
764773
};
765774

@@ -771,8 +780,8 @@ class For : public Stmt {
771780
public:
772781
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
773782
ffi::Optional<IterVar> thread_binding = std::nullopt,
774-
ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
775-
Span span = Span());
783+
ffi::Map<ffi::String, ffi::Any> annotations = {},
784+
ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
776785

777786
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode);
778787
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,11 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L
677677

678678

679679
def serial(
680-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
680+
start: PrimExpr,
681+
stop: PrimExpr = None,
682+
*,
683+
annotations: Dict[str, Any] = None,
684+
step: Optional[PrimExpr] = None,
681685
) -> frame.ForFrame:
682686
"""The serial For statement.
683687
@@ -692,6 +696,9 @@ def serial(
692696
annotations : Dict[str, Any]
693697
The optional annotations of the For statement.
694698
699+
step : PrimExpr
700+
The optional step value of iteration.
701+
695702
Returns
696703
-------
697704
res : frame.ForFrame
@@ -703,11 +710,15 @@ def serial(
703710
start = IntImm(start.dtype, 0)
704711
else:
705712
start = 0
706-
return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
713+
return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
707714

708715

709716
def parallel(
710-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
717+
start: PrimExpr,
718+
stop: PrimExpr = None,
719+
*,
720+
annotations: Dict[str, Any] = None,
721+
step: Optional[PrimExpr] = None,
711722
) -> frame.ForFrame:
712723
"""The parallel For statement.
713724
@@ -722,6 +733,9 @@ def parallel(
722733
annotations : Dict[str, Any]
723734
The optional annotations of the For statement.
724735
736+
step : PrimExpr
737+
The optional step value of iteration.
738+
725739
Returns
726740
-------
727741
res : frame.ForFrame
@@ -733,11 +747,15 @@ def parallel(
733747
start = IntImm(start.dtype, 0)
734748
else:
735749
start = 0
736-
return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
750+
return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
737751

738752

739753
def vectorized(
740-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
754+
start: PrimExpr,
755+
stop: PrimExpr = None,
756+
*,
757+
annotations: Dict[str, Any] = None,
758+
step: Optional[PrimExpr] = None,
741759
) -> frame.ForFrame:
742760
"""The vectorized For statement.
743761
@@ -752,6 +770,9 @@ def vectorized(
752770
annotations : Dict[str, Any]
753771
The optional annotations of the For statement.
754772
773+
step : PrimExpr
774+
The optional step value of iteration.
775+
755776
Returns
756777
-------
757778
res : frame.ForFrame
@@ -763,11 +784,15 @@ def vectorized(
763784
start = IntImm(start.dtype, 0)
764785
else:
765786
start = 0
766-
return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
787+
return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
767788

768789

769790
def unroll(
770-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
791+
start: PrimExpr,
792+
stop: PrimExpr = None,
793+
*,
794+
annotations: Dict[str, Any] = None,
795+
step: Optional[PrimExpr] = None,
771796
) -> frame.ForFrame:
772797
"""The unrolled For statement.
773798
@@ -782,6 +807,9 @@ def unroll(
782807
annotations : Dict[str, Any]
783808
The optional annotations of the For statement.
784809
810+
step : PrimExpr
811+
The optional step value of iteration.
812+
785813
Returns
786814
-------
787815
res : frame.ForFrame
@@ -793,7 +821,7 @@ def unroll(
793821
start = IntImm(start.dtype, 0)
794822
else:
795823
start = 0
796-
return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
824+
return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
797825

798826

799827
def thread_binding(

python/tvm/script/parser/tir/parser.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import contextlib
2020
from functools import partial
21-
from typing import Any
21+
from typing import Any, Dict, Optional
2222

2323
import tvm
2424
from tvm.ir import GlobalVar, PrimType
@@ -168,6 +168,28 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b
168168
return default
169169

170170

171+
def range_sugar(
172+
start: PrimExpr,
173+
stop: PrimExpr = None,
174+
step: Optional[PrimExpr] = None,
175+
*,
176+
annotations: Dict[str, Any] = None,
177+
) -> T.frame.ForFrame:
178+
"""The sugar for python range builtin."""
179+
180+
# Since `tir.For` do not support reversed iteration semantic,
181+
# the step must be checked to be positive integer when use range sugar
182+
if step is not None:
183+
try:
184+
step = int(step)
185+
if step <= 0:
186+
raise ValueError(f"Only support positive step in range(), get {step}")
187+
except TypeError: # pylint: disable=broad-except
188+
raise ValueError(f"Only support literal step in range(), get {step}")
189+
190+
return T.serial(start, stop, annotations=annotations, step=step)
191+
192+
171193
@dispatch.register(token="tir", type_name="For")
172194
def visit_for(self: Parser, node: doc.For) -> None:
173195
"""The for visiting method for tir.
@@ -379,7 +401,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
379401
privacy = find_decorator_annotation(node, "private", default=False)
380402
self.function_annotations = None
381403
with self.var_table.with_frame():
382-
self.var_table.add("range", T.serial)
404+
405+
self.var_table.add("range", range_sugar)
383406
with T.prim_func(is_private=privacy):
384407
T.func_name(node.name)
385408
if node.returns is not None:

python/tvm/tir/ir_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def scope_attr(self, node, attr_key, value):
202202
value = op.max(1, value)
203203
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
204204

205-
def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
205+
def for_range(self, begin, end, name="i", dtype=None, kind="serial", step=None):
206206
"""Create a for iteration scope.
207207
208208
Parameters
@@ -223,6 +223,10 @@ def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
223223
kind : str, optional
224224
The special tag on the for loop.
225225
226+
step : PrimExpr
227+
The loop step. Default to none which
228+
represent one.
229+
226230
Returns
227231
-------
228232
loop_scope : With.Scope of Var
@@ -275,7 +279,7 @@ def _exit_cb():
275279
kind_id = _stmt.ForKind.UNROLLED
276280
else:
277281
raise ValueError("Unknown kind")
278-
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
282+
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), step=step))
279283

280284
return WithScope(loop_var, _exit_cb)
281285

python/tvm/tir/pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
3131
pass_ctx = tvm.transform.PassContext.current()
3232
config = pass_ctx.config
3333
passes = [
34+
tir.transform.CanonicalizeLoop(),
3435
tir.transform.LowerCrossThreadReduction(),
3536
tir.transform.LowerInitBlock(),
3637
tir.transform.PlanAndUpdateBufferAllocationLocation(),

python/tvm/tir/stmt.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ class For(Stmt):
145145
The thread this loop binds to. Only valid
146146
if kind is ThreadBinding
147147
148+
step : PrimExpr
149+
The loop step. Default to none which
150+
represent one.
151+
148152
annotations: Optional[Mapping[str, Object]]
149153
Additional annotation hints.
150154
@@ -159,6 +163,7 @@ class For(Stmt):
159163
body: Stmt
160164
thread_binding: Optional[IterVar]
161165
annotations: Mapping[str, Object]
166+
step: Optional[PrimExpr]
162167
span: Optional[Span]
163168

164169
def __init__(
@@ -170,6 +175,7 @@ def __init__(
170175
body: Stmt,
171176
thread_binding: Optional[IterVar] = None,
172177
annotations: Optional[Mapping[str, Object]] = None,
178+
step: Optional[PrimExpr] = None,
173179
span: Optional[Span] = None,
174180
) -> None:
175181
self.__init_handle_by_constructor__(
@@ -181,6 +187,7 @@ def __init__(
181187
body,
182188
thread_binding,
183189
annotations,
190+
step,
184191
span,
185192
)
186193

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,3 +1171,14 @@ def LowerVtcmAlloc():
11711171
The result pass
11721172
"""
11731173
return _ffi_api.LowerVtcmAlloc() # type: ignore
1174+
1175+
1176+
def CanonicalizeLoop():
1177+
"""Canonicalize the loop to start from zero and use trivial step
1178+
1179+
Returns
1180+
-------
1181+
fpass : tvm.transform.Pass
1182+
The result pass
1183+
"""
1184+
return _ffi_api.CanonicalizeLoop() # type: ignore

src/relax/distributed/transform/lower_global_view_to_local_view.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
330330
if (shard > 1) {
331331
arith::Analyzer analyzer;
332332
ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
333-
return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard),
334-
new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations);
333+
new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
334+
return new_loop;
335335
}
336336
}
337337
return new_loop;

0 commit comments

Comments
 (0)