Skip to content

Commit 8798e80

Browse files
authored
Merge pull request #444 from Xilinx/koshrai2.backport.pad.lowering.changes
Allow padValue to be non-constant
2 parents 86d29ea + 3ba47e1 commit 8798e80

File tree

2 files changed

+66
-34
lines changed

2 files changed

+66
-34
lines changed

src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
2323
#include "src/Dialect/ONNX/ONNXOps.hpp"
2424
#include "llvm/ADT/APFloat.h"
25+
#include "llvm/ADT/ArrayRef.h"
2526
#include "llvm/ADT/SmallVector.h"
26-
#include "llvm/Support/Casting.h"
2727

2828
using namespace mlir;
2929

@@ -61,11 +61,9 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
6161
op, "only 'constant' mode is supported");
6262
}
6363

64-
if (!pads.getDefiningOp<mlir::tosa::ConstOp>() ||
65-
!(constValue.getDefiningOp<mlir::tosa::ConstOp>() ||
66-
constValue.getDefiningOp<ONNXNoneOp>())) {
64+
if (!pads.getDefiningOp<mlir::tosa::ConstOp>()) {
6765
return rewriter.notifyMatchFailure(
68-
op, "only tosa.const operands are supported");
66+
op, "only tosa.const 'padding' values are supported");
6967
}
7068
// creating the DenseElementsAttr using pads values.
7169
auto denseAttr = tosa::getValueFromTosaConst<ElementsAttr>(pads);
@@ -90,33 +88,7 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
9088
mlir::Type resultType =
9189
getTypeConverter()->convertType(op.getResult().getType());
9290

93-
if (!isa<NoneType>(constValue.getType())) {
94-
auto valueAttr = tosa::getValueFromTosaConst<ElementsAttr>(constValue);
95-
TosaBuilder tosaBuilder(rewriter, loc);
96-
97-
Value constTosaTensor;
98-
if (isa<FloatType>(valueAttr.getElementType())) {
99-
auto valueIt = valueAttr.getValues<FloatAttr>().begin();
100-
const float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
101-
constTosaTensor = tosaBuilder.getSplattedConst(
102-
valueFloat, valueAttr.getElementType(), 0);
103-
} else {
104-
assert(isTOSAInt(elementDtype) && "Already validated");
105-
auto valueIt = valueAttr.getValues<IntegerAttr>().begin();
106-
auto valueAsAPInt = cast<IntegerAttr>(*valueIt).getValue();
107-
auto asIntegerTy = cast<IntegerType>(valueAttr.getElementType());
108-
if (asIntegerTy.isUnsigned()) {
109-
constTosaTensor = tosaBuilder.getSplattedConst(
110-
valueAsAPInt.getZExtValue(), asIntegerTy, 0);
111-
} else {
112-
constTosaTensor = tosaBuilder.getSplattedConst(
113-
valueAsAPInt.getSExtValue(), asIntegerTy, 0);
114-
}
115-
}
116-
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
117-
op, resultType, data, padsList1, constTosaTensor);
118-
119-
} else {
91+
if (isa<NoneType>(constValue.getType())) {
12092
auto constType = RankedTensorType::get({}, elementDtype);
12193

12294
DenseElementsAttr constAttr;
@@ -134,8 +106,12 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
134106
padsList1,
135107
rewriter.create<mlir::tosa::ConstOp>(
136108
op->getLoc(), constType, constAttr));
109+
} else {
110+
TosaBuilder tosaBuilder(rewriter, loc);
111+
Value reshapeToSplattedConst = tosaBuilder.reshape(constValue, {});
112+
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
113+
op, resultType, data, padsList1, reshapeToSplattedConst);
137114
}
138-
139115
return success();
140116
}
141117
};

test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --cse %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --canonicalize --cse %s -split-input-file | FileCheck %s
22

33
func.func @test_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
44
%noval = "onnx.NoValue"() {value} : () -> none
@@ -160,3 +160,59 @@ func.func @test_pad_f16_constant_none(%arg0: tensor<256x1x1x5x1xf16>) -> tensor<
160160
// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] : (tensor<256x1x1x5x1xf16>, !tosa.shape<10>, tensor<f16>) -> tensor<256x1x1x5x2xf16>
161161
// CHECK: return %[[VAR2]] : tensor<256x1x1x5x2xf16>
162162
}
163+
164+
// -----
165+
166+
func.func @test_pad_f32_non_constant_padval(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<f32>) -> tensor<24x22x52x42xf32> {
167+
%noval = "onnx.NoValue"() {value} : () -> none
168+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
169+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<24x22x52x42xf32>
170+
return %2 : tensor<24x22x52x42xf32>
171+
// CHECK-LABEL: func.func @test_pad_f32_non_constant_padval
172+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<24x22x52x42xf32> {
173+
// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
174+
// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<24x22x52x42xf32>
175+
// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf32>
176+
}
177+
178+
// -----
179+
180+
func.func @test_pad_f32_non_constant_1Dpadval(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<1xf32>) -> tensor<24x22x52x42xf32> {
181+
%noval = "onnx.NoValue"() {value} : () -> none
182+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
183+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<24x22x52x42xf32>
184+
return %2 : tensor<24x22x52x42xf32>
185+
// CHECK-LABEL: func.func @test_pad_f32_non_constant_1Dpadval
186+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<24x22x52x42xf32> {
187+
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
188+
// CHECK-DAG: [[VAL_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
189+
// CHECK: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAL_1_]] : (tensor<20x16x44x32xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<24x22x52x42xf32>
190+
// CHECK: return [[VAR_2_]] : tensor<24x22x52x42xf32>
191+
}
192+
193+
// -----
194+
195+
func.func @test_pad_i64_non_constant_padval(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<i64>) -> tensor<24x22x52x42xi64> {
196+
%noval = "onnx.NoValue"() {value} : () -> none
197+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
198+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<i64>, none) -> tensor<24x22x52x42xi64>
199+
return %2 : tensor<24x22x52x42xi64>
200+
// CHECK-LABEL: func.func @test_pad_i64_non_constant_padval
201+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xi64>, [[PARAM_1_:%.+]]: tensor<i64>) -> tensor<24x22x52x42xi64> {
202+
// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
203+
// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xi64>, !tosa.shape<8>, tensor<i64>) -> tensor<24x22x52x42xi64>
204+
// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xi64>
205+
}
206+
207+
// -----
208+
func.func @test_pad_f16_non_constant_padval(%arg0: tensor<20x16x44x32xf16>, %arg1: tensor<f16>) -> tensor<24x22x52x42xf16> {
209+
%noval = "onnx.NoValue"() {value} : () -> none
210+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
211+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf16>, tensor<8xi64>, tensor<f16>, none) -> tensor<24x22x52x42xf16>
212+
return %2 : tensor<24x22x52x42xf16>
213+
// CHECK-LABEL: func.func @test_pad_f16_non_constant_padval
214+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf16>, [[PARAM_1_:%.+]]: tensor<f16>) -> tensor<24x22x52x42xf16> {
215+
// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
216+
// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf16>, !tosa.shape<8>, tensor<f16>) -> tensor<24x22x52x42xf16>
217+
// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf16>
218+
}

0 commit comments

Comments
 (0)