Skip to content

Commit f1dd6b3

Browse files
authored
[mlir][tensor] Fix createFillOrGenerateOp (llvm#121205)
This PR clones the padding value defined inside the PadOp block to outside to prevent a crash. Fixes llvm#120947.
1 parent 83344da commit f1dd6b3

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -927,8 +927,12 @@ Value DecomposePadOpPattern::createFillOrGenerateOp(
927927
RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
928928
const SmallVector<Value> &dynSizes) const {
929929
auto padValue = padOp.getConstantPaddingValue();
930-
if (padValue)
930+
if (padValue) {
931+
// Move the padding value defined inside the PadOp block to outside.
932+
if (padValue.getParentBlock() == &padOp.getRegion().front())
933+
rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
931934
return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
935+
}
932936

933937
// Fill could not be optimized: Lower to tensor::GenerateOp with region.
934938
auto generateOp = rewriter.create<tensor::GenerateOp>(

mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,22 @@ func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1
4444
} : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
4545
return %out : tensor<4x?x?x?xf32>
4646
}
47+
48+
// -----
49+
50+
// CHECK-LABEL: func.func @generalize_pad_tensor_constant_inside(
51+
// CHECK-SAME: %[[SRC:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
52+
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x32x32x1xf32>
53+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
54+
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32>
55+
// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
56+
// CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32>
57+
// CHECK: }
58+
func.func @generalize_pad_tensor_constant_inside(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
59+
%0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
60+
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
61+
%cst = arith.constant 0.000000e+00 : f32
62+
tensor.yield %cst : f32
63+
} : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
64+
return %0 : tensor<1x32x32x1xf32>
65+
}

0 commit comments

Comments (0)