@@ -64,6 +64,12 @@ def __repr__(self) -> str:
6464
6565
6666get_blockrealize = get_global_func ("tir.schedule.GetBlockRealize" )
67+ # BufferIndex Types
68+ Index = namedtuple ("Index" , ["sub" ]) # c
69+ RemIndex = namedtuple ("RemIndex" , ["sub" , "div" ]) # c%len
70+ DivIndex = namedtuple ("DivIndex" , ["sub" , "div" ]) # c//len
71+ MergeIndex = namedtuple ("MulIndex" , ["dom" , "mul" , "sub" ]) # co*len + cb
72+ BufIndex = List [Union [Index , RemIndex , DivIndex , MergeIndex , None ]]
6773
6874
6975# TODO: Shift Vlen Calculation here...
@@ -74,13 +80,6 @@ class BufferInfo:
7480 assoc_lps : List [Union [tir .schedule .LoopRV , None ]]
7581 assoc_lps_info : List [Union [tir .For , None ]]
7682
77- # BufferIndex Types
78- Index = namedtuple ("Index" , ["sub" ]) # c
79- RemIndex = namedtuple ("RemIndex" , ["sub" , "div" ]) # c%len
80- DivIndex = namedtuple ("DivIndex" , ["sub" , "div" ]) # c//len
81- MergeIndex = namedtuple ("MulIndex" , ["dom" , "mul" , "sub" ]) # co*len + cb
82- BufIndex = List [Union [Index , RemIndex , DivIndex , MergeIndex , None ]]
83-
8483 def __init__ (
8584 self ,
8685 sch : tir .Schedule ,
@@ -172,8 +171,6 @@ class BlockInfo:
172171 iters : List [IterInfo ]
173172 block_rv : tir .schedule .BlockRV
174173 _reduction_block : bool
175- read_bufs : List [BufferInfo ]
176- write_bufs : List [BufferInfo ]
177174
178175 def __init__ (
179176 self ,
@@ -192,6 +189,16 @@ def dom(self) -> List[Union[int, tir.PrimExpr]]:
192189 """The iteration domain of the block."""
193190 return [i .dom for i in self .iters ]
194191
192+ def read_bufs (self , sch : tir .Schedule ) -> List [BufferInfo ]:
193+ block_stmt = sch .get (self .block_rv )
194+ lps = sch .get_loops (self .block_rv )
195+ return [BufferInfo (sch , self .block_rv , buf , lps ) for buf in block_stmt .reads ]
196+
197+ def write_bufs (self , sch : tir .Schedule ) -> List [BufferInfo ]:
198+ block_stmt = sch .get (self .block_rv )
199+ lps = sch .get_loops (self .block_rv )
200+ return [BufferInfo (sch , self .block_rv , buf , lps ) for buf in block_stmt .writes ]
201+
195202 def dom_kind (self ) -> str :
196203 """The iteration domain kind of the block, for example, SSSS, SSSR."""
197204 return "" .join (i .kind for i in self .iters )
@@ -216,7 +223,7 @@ def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool:
216223 if len (r_region ) != len (w_region ):
217224 return False
218225 for var , r_dom , w_dom in zip (block .iter_vars , r_region , w_region ):
219- if not _check_unit_var_range (var , r_dom ) or not _check_unit_var_range (var , w_dom ):
226+ if not _check_unit_var_range (r_dom , var ) or not _check_unit_var_range (w_dom , var ):
220227 return False
221228 return True
222229
@@ -230,31 +237,23 @@ def is_reduction(self) -> bool:
230237
231238 def is_layout_transform (self , sch : tir .Schedule ) -> bool :
232239 """Whether the Block can be considered having a Layout Transform Pattern"""
233- block_stmt = sch .get (self .block_rv )
234- lps = sch .get_loops (block_rv )
235- read_bufs = [BufferInfo (sch , self .block_rv , buf , lps ) for buf in block_stmt .reads ]
236- write_bufs = [BufferInfo (sch , self .block_rv , buf , lps ) for buf in block_stmt .writes ]
237240 return (
238241 all (k == "S" for k in self .dom_kind ())
239- and len (write_bufs ) == 1
240- and len (read_bufs ) == 1
242+ and len (self . write_bufs ( sch ) ) == 1
243+ and len (self . read_bufs ( sch ) ) == 1
241244 and not self .is_elementwise (sch )
242245 and not get_global_func ("tir.schedule.HasIfThenElse" )(sch .get (self .block_rv ))
243246 )
244247
245248 def is_data_pad (self , sch : tir .Schedule ) -> bool :
246249 """Whether the Block can be considered having a data pad pattern"""
247- block_stmt = sch .get (self .block_rv )
248- lps = sch .get_loops (block_rv )
249- read_bufs = [BufferInfo (sch , self .block_rv , buf , lps ) for buf in block_stmt .reads ]
250- write_bufs = [BufferInfo (sch , self .block_rv , buf , lps ) for buf in block_stmt .writes ]
251250 return (
252251 all (k == "S" for k in self .dom_kind ())
253- and len (write_bufs ) == 1
254- and len (read_bufs ) == 1
252+ and len (self . write_bufs ( sch ) ) == 1
253+ and len (self . read_bufs ( sch ) ) == 1
255254 and not self .is_elementwise (sch )
256- and len (self .write_bufs [0 ].buf_region .region )
257- == len (self .read_bufs [0 ].buf_region .region )
255+ and len (self .write_bufs ( sch ) [0 ].buf_region .region )
256+ == len (self .read_bufs ( sch ) [0 ].buf_region .region )
258257 and get_global_func ("tir.schedule.HasIfThenElse" )(sch .get (self .block_rv ))
259258 )
260259
0 commit comments