@@ -58,6 +58,84 @@ module {
   }
 }

+// -----
+// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+
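+// Since the pack below is pad-like (both outer dims are 1 and no padding is
+// needed), the default lowering folds it into a single rank-expanding
+// insert_slice instead of a pad + expand_shape + transpose chain, roughly
+// (illustration only, not part of the CHECKed output):
+//
+//   %0 = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 256] [1, 1, 1, 1]
+//       : tensor<128x256xf32> into tensor<1x1x128x256xf32>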
+module {
+  // CHECK-LABEL: func @fuse_pack_as_producer_blocked_by_insert_slice
+  // CHECK: tensor.insert_slice
+  // CHECK: scf.forall {{.*}} {
+  // CHECK:   scf.forall.in_parallel
+  // CHECK: }
+  func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+           {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                             affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                             affine_map<(i, j, k, l) -> (i, j, k, l)>],
+            iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+           ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+           outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower the pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %padded, %expanded, %transpose = transform.structured.lower_pack %pack
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
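+      // With the default lowerPadLikeWithInsertSlice = true, this pad-like
+      // pack lowers to the tensor.insert_slice checked above, so no
+      // linalg.transpose is created for the later fusion step.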
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with a parallel scf.forall loop, num_threads [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Attempt to fuse the transpose operation into the tiled loop. With the
+      // insert_slice lowering above there is no transpose to fuse, so this is
+      // a no-op and the insert_slice remains outside the loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+        : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
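+
+// Note: to make the transpose fusable as a producer, the insert_slice path
+// can be disabled when lowering the pack, mirroring the
+// lowerUnpadLikeWithExtractSlice usage described below (a sketch):
+//
+//   transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+//     : (!transform.op<"tensor.pack">)
+//     -> (!transform.op<"tensor.pad">,
+//         !transform.op<"tensor.expand_shape">,
+//         !transform.op<"linalg.transpose">)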
+
 // -----
 // For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
 // This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
@@ -119,3 +197,73 @@ module {
     }
   }
 }
+
+// -----
+// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
+
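+// Since the unpack below is unpad-like (both outer dims are 1), the default
+// lowering folds it into a single rank-reducing extract_slice, roughly
+// (illustration only, not part of the CHECKed output):
+//
+//   %0 = tensor.extract_slice %res[0, 0, 0, 0] [1, 1, 128, 256] [1, 1, 1, 1]
+//       : tensor<1x1x128x256xf32> to tensor<128x256xf32>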
+module {
+  // CHECK-LABEL: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+  // CHECK: scf.forall {{.*}} {
+  // CHECK:   linalg.generic
+  // CHECK:   scf.forall.in_parallel
+  // CHECK: }
+  // CHECK: tensor.extract_slice
+  func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+           {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                             affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                             affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+            iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+           ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+           outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower the unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
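+      // With the default lowerUnpadLikeWithExtractSlice = true, this
+      // unpad-like unpack lowers to the tensor.extract_slice checked above,
+      // which sits after the loop and blocks consumer fusion.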
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with a parallel scf.forall loop, num_threads [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Try to fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+        : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      // Note that we cannot apply transform.test.fuse_consumer here because the
+      // extract_slice is not a qualified consumer operation. Forcing it would
+      // yield a "could not fetch consumer to fuse" error.
+      transform.yield
+    }
+  }
+}