-
这个是split之后的script: # split j to [j0, j1] sch.mod:
@tvm.script.ir_module
class Module:
@tir.prim_func
def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
# body
# with tir.block("root")
Y = tir.alloc_buffer([128, 128], dtype="float32")
for i, j_0, j_1, k in tir.grid(128, 32, 4, 128):
with tir.block("Y"):
vi = tir.axis.spatial(128, i)
vj = tir.axis.spatial(128, j_0 * 4 + j_1)
vk = tir.axis.reduce(128, k)
tir.reads(A[vi, vk], B[vk, vj])
tir.writes(Y[vi, vj])
with tir.init():
Y[vi, vj] = tir.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in tir.grid(128, 128):
with tir.block("C"):
vi, vj = tir.axis.remap("SS", [i, j])
tir.reads(Y[vi, vj])
tir.writes(C[vi, vj])
C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
这个是reverse_compute_at变换后的script: # move block C to block Y
@tvm.script.ir_module
class Module:
@tir.prim_func
def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
# body
# with tir.block("root")
Y = tir.alloc_buffer([128, 128], dtype="float32")
for i, j_0 in tir.grid(128, 32):
for j_1, k in tir.grid(4, 128):
with tir.block("Y"):
vi = tir.axis.spatial(128, i)
vj = tir.axis.spatial(128, j_0 * 4 + j_1)
vk = tir.axis.reduce(128, k)
tir.reads(A[vi, vk], B[vk, vj])
tir.writes(Y[vi, vj])
with tir.init():
Y[vi, vj] = tir.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in tir.serial(4):
with tir.block("C"):
vi = tir.axis.spatial(128, i)
vj = tir.axis.spatial(128, j_0 * 4 + ax0)
tir.reads(Y[vi, vj])
tir.writes(C[vi, vj])
C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
然后再做reorder之后,提示这个报错:
Traceback (most recent call last):
File "/home/sanzo/mlc/case_study.py", line 98, in <module>
sch.reorder(j0, k, j1)
File "/home/sanzo/software/miniconda/4.12/envs/sanzo/lib/python3.8/site-packages/tvm/tir/schedule/_type_checker.py", line 237, in wrap
return func(*args, **kwargs)
File "/home/sanzo/software/miniconda/4.12/envs/sanzo/lib/python3.8/site-packages/tvm/tir/schedule/schedule.py", line 691, in reorder
_ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member
File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
3: TVMFuncCall
2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::Array<tvm::tir::LoopRV, void> const&)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::tir::Schedule, tvm::tir::ScheduleNode, void, tvm::runtime::Array<tvm::tir::LoopRV, void> const&, void>(void (tvm::tir::ScheduleNode::*)(tvm::runtime::Array<tvm::tir::LoopRV, void> const&))::{lambda(tvm::tir::Schedule, tvm::runtime::Array<tvm::tir::LoopRV, void> const&)#1}>(tvm::runtime::Registry::set_body_method<tvm::tir::Schedule, tvm::tir::ScheduleNode, void, tvm::runtime::Array<tvm::tir::LoopRV, void> const&, void>(void (tvm::tir::ScheduleNode::*)(tvm::runtime::Array<tvm::tir::LoopRV, void> const&))::{lambda(tvm::tir::Schedule, tvm::runtime::Array<tvm::tir::LoopRV, void> const&)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
1: tvm::tir::TracedScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
0: tvm::tir::ConcreteScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&) [clone .cold]
ScheduleError: An error occurred in the schedule primitive 'reorder'.
The IR with diagnostic is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def mm_relu(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
# body
# with T.block("root")
Y = T.alloc_buffer([128, 128], dtype="float32")
for i in T.serial(128):
# tir.For#0
for j_0 in T.serial(32):
^^^^^^^^^^^^^^^^^^^^^^^^
for j_1, k in T.grid(4, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in T.serial(4):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 4 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
Error message: The loops are not in a chain because there is a non-single-branch stmt in between. Problematic stmt: tir.For#0
报错信息好像说循环里面存在多个分支?(block Y和block C),但是我觉得reorder这个变换并不会影响block C的表示形式,不知道为什么会出现这个问题?
Beta Was this translation helpful? Give feedback.
Answered by
Hzfengsy
Jul 19, 2022
Replies: 1 comment 1 reply
-
感谢提问,这是一个很好的问题。目前primitive的核心是“保证正确”,即任一合法IR我们保证变换之后的结果一定是正确的。 我们有两种策略可以解决问题:
|
Beta Was this translation helpful? Give feedback.
1 reply
Answer selected by
Sanzo00
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
感谢提问,这是一个很好的问题。目前primitive的核心是“保证正确”,即任一合法IR我们保证变换之后的结果一定是正确的。
这样的策略会要求我们有一个正确性检查,但正确性检查往往不能够做到“精确”(即有可能合法变换会被禁用),你提到的就是这样一个例子。
我们有两种策略可以解决问题: