Skip to content

Commit 2ac194e

Browse files
[GPU] Add pattern to sink extract_slice through generic ops
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 639c7cf commit 2ac194e

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,33 @@ void GPUPackToIntrinsicsPass::runOnOperation() {
289289
return !getLoweringConfig(producer) && !getLoweringConfig(consumer);
290290
};
291291

292+
// Additionally, we do not sink extract slice through generic if slice source
293+
// is a block argument or if the source, slice or generic are in different
294+
// blocks as this would affect how tiling uses extract slice ops.
295+
296+
linalg::ControlPropagationFn controlExtract =
297+
[](OpOperand *opOperand) -> bool {
298+
Operation *producer = opOperand->get().getDefiningOp();
299+
Operation *consumer = opOperand->getOwner();
300+
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
301+
return false;
302+
}
303+
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(producer);
304+
if (!sliceOp) {
305+
return false;
306+
}
307+
Operation *producerSrc = sliceOp.getSource().getDefiningOp();
308+
// If source is not an op, e.g is a block argument then return false.
309+
if (!producerSrc) {
310+
return false;
311+
}
312+
313+
return producerSrc->getBlock() == producer->getBlock() &&
314+
consumer->getBlock() == producer->getBlock();
315+
};
316+
292317
linalg::populateDataLayoutPropagationPatterns(patterns, control);
318+
linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
293319
patterns.add<PackDestinationForOp>(context);
294320
linalg::UnPackOp::getCanonicalizationPatterns(patterns, context);
295321
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_pack_to_instrinsics.mlir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,81 @@ func.func @hoist_pack_unpack_multiple_loop(%arg0 : tensor<1x1x4x2x16x16xbf16>, %
309309
// CHECK: scf.yield %[[INNER_FOR_RESULT]] : tensor<1x1x4x2x16x16xf32>
310310
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
311311
// CHECK: linalg.unpack %[[OUTER_FOR_RESULT]] inner_dims_pos = [2, 3] inner_tiles = [16, 16] into %[[EMPTY:.+]] : tensor<1x1x4x2x16x16xf32> -> tensor<1x1x64x32xf32>
312+
313+
// -----
314+
315+
func.func @propagate_extract_basic(%arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
316+
%empty = tensor.empty() : tensor<128xf32>
317+
%extracted_slice = tensor.extract_slice %empty[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
318+
%generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg1 : tensor<?xbf16>) {
319+
^bb0(%in: f32, %out: bf16):
320+
%1 = arith.truncf %in : f32 to bf16
321+
linalg.yield %1 : bf16
322+
} -> tensor<?xbf16>
323+
return %generic : tensor<?xbf16>
324+
}
325+
326+
// CHECK-LABEL: func.func @propagate_extract_basic
327+
// CHECK: linalg.generic
328+
// CHECK: tensor.extract_slice
329+
330+
// -----
331+
332+
func.func @no_propagate_extract_blockargument(%input : tensor<128xf32>, %arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
333+
%extracted_slice = tensor.extract_slice %input[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
334+
%generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg1 : tensor<?xbf16>) {
335+
^bb0(%in: f32, %out: bf16):
336+
%1 = arith.truncf %in : f32 to bf16
337+
linalg.yield %1 : bf16
338+
} -> tensor<?xbf16>
339+
return %generic : tensor<?xbf16>
340+
}
341+
342+
// CHECK-LABEL: func.func @no_propagate_extract_blockargument
343+
// CHECK: tensor.extract_slice
344+
// CHECK: linalg.generic
345+
346+
347+
// -----
348+
349+
func.func @no_propagate_extract_differentblock_1(%arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
350+
%empty = tensor.empty() : tensor<128xf32>
351+
%c0 = arith.constant 0 : index
352+
%c32 = arith.constant 32 : index
353+
%for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args(%arg4 = %arg1) -> tensor<?xbf16> {
354+
%extracted_slice = tensor.extract_slice %empty[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
355+
%generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg4 : tensor<?xbf16>) {
356+
^bb0(%in: f32, %out: bf16):
357+
%1 = arith.truncf %in : f32 to bf16
358+
linalg.yield %1 : bf16
359+
} -> tensor<?xbf16>
360+
scf.yield %generic : tensor<?xbf16>
361+
}
362+
return %for : tensor<?xbf16>
363+
}
364+
365+
// CHECK-LABEL: func.func @no_propagate_extract_differentblock_1
366+
// CHECK: tensor.extract_slice
367+
// CHECK: linalg.generic
368+
369+
// -----
370+
371+
func.func @no_propagate_extract_differentblock_2(%arg0 : index, %arg1 : tensor<?xbf16>) -> tensor<?xbf16> {
372+
%empty = tensor.empty() : tensor<128xf32>
373+
%c0 = arith.constant 0 : index
374+
%c32 = arith.constant 32 : index
375+
%extracted_slice = tensor.extract_slice %empty[%arg0] [%arg0] [1] : tensor<128xf32> to tensor<?xf32>
376+
%for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args(%arg4 = %arg1) -> tensor<?xbf16> {
377+
%generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<?xf32>) outs(%arg4 : tensor<?xbf16>) {
378+
^bb0(%in: f32, %out: bf16):
379+
%1 = arith.truncf %in : f32 to bf16
380+
linalg.yield %1 : bf16
381+
} -> tensor<?xbf16>
382+
scf.yield %generic : tensor<?xbf16>
383+
}
384+
return %for : tensor<?xbf16>
385+
}
386+
387+
// CHECK-LABEL: func.func @no_propagate_extract_differentblock_2
388+
// CHECK: tensor.extract_slice
389+
// CHECK: linalg.generic

0 commit comments

Comments
 (0)