Skip to content

Commit 80be0e5

Browse files
sayeddlakosh-rai
authored and committed
Allow padValue to be non-constant
1 parent 189c863 commit 80be0e5

File tree

2 files changed

+65
-24
lines changed

2 files changed

+65
-24
lines changed

src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -61,11 +61,9 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
6161
op, "only 'constant' mode is supported");
6262
}
6363

64-
if (!pads.getDefiningOp<mlir::tosa::ConstOp>() ||
65-
!(constValue.getDefiningOp<mlir::tosa::ConstOp>() ||
66-
constValue.getDefiningOp<ONNXNoneOp>())) {
64+
if (!pads.getDefiningOp<mlir::tosa::ConstOp>()) {
6765
return rewriter.notifyMatchFailure(
68-
op, "only tosa.const operands are supported");
66+
op, "only tosa.const 'padding' values are supported");
6967
}
7068
// creating the DenseElementsAttr using pads values.
7169
auto denseAttr = tosa::getValueFromTosaConst<ElementsAttr>(pads);
@@ -90,7 +88,28 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
9088
mlir::Type resultType =
9189
getTypeConverter()->convertType(op.getResult().getType());
9290

93-
if (!isa<NoneType>(constValue.getType())) {
91+
if (isa<NoneType>(constValue.getType())) {
92+
auto constType = RankedTensorType::get({}, elementDtype);
93+
94+
DenseElementsAttr constAttr;
95+
if (auto floatType = dyn_cast<FloatType>(elementDtype)) {
96+
constAttr = DenseElementsAttr::get(
97+
constType, APFloat::getZero(floatType.getFloatSemantics()));
98+
} else {
99+
assert(isTOSAInt(elementDtype) && "Already validated");
100+
auto tyAsInt = cast<IntegerType>(elementDtype);
101+
constAttr = DenseElementsAttr::get(constType,
102+
llvm::APInt(tyAsInt.getWidth(), 0, tyAsInt.getSignedness()));
103+
}
104+
105+
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(op, resultType, data,
106+
padsList1,
107+
rewriter.create<mlir::tosa::ConstOp>(
108+
op->getLoc(), constType, constAttr));
109+
} else if (!constValue.getDefiningOp<mlir::tosa::ConstOp>()) {
110+
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
111+
op, resultType, data, padsList1, constValue);
112+
} else {
94113
auto valueAttr = tosa::getValueFromTosaConst<ElementsAttr>(constValue);
95114
TosaBuilder tosaBuilder(rewriter, loc);
96115

@@ -115,25 +134,6 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
115134
}
116135
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
117136
op, resultType, data, padsList1, constTosaTensor);
118-
119-
} else {
120-
auto constType = RankedTensorType::get({}, elementDtype);
121-
122-
DenseElementsAttr constAttr;
123-
if (auto floatType = dyn_cast<FloatType>(elementDtype)) {
124-
constAttr = DenseElementsAttr::get(
125-
constType, APFloat::getZero(floatType.getFloatSemantics()));
126-
} else {
127-
assert(isTOSAInt(elementDtype) && "Already validated");
128-
auto tyAsInt = cast<IntegerType>(elementDtype);
129-
constAttr = DenseElementsAttr::get(constType,
130-
llvm::APInt(tyAsInt.getWidth(), 0, tyAsInt.getSignedness()));
131-
}
132-
133-
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(op, resultType, data,
134-
padsList1,
135-
rewriter.create<mlir::tosa::ConstOp>(
136-
op->getLoc(), constType, constAttr));
137137
}
138138

139139
return success();

test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -160,3 +160,44 @@ func.func @test_pad_f16_constant_none(%arg0: tensor<256x1x1x5x1xf16>) -> tensor<
160160
// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] : (tensor<256x1x1x5x1xf16>, !tosa.shape<10>, tensor<f16>) -> tensor<256x1x1x5x2xf16>
161161
// CHECK: return %[[VAR2]] : tensor<256x1x1x5x2xf16>
162162
}
163+
164+
// -----
165+
166+
func.func @test_pad_f32_non_constant_padval(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<f32>) -> tensor<24x22x52x42xf32> {
167+
%noval = "onnx.NoValue"() {value} : () -> none
168+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
169+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<24x22x52x42xf32>
170+
return %2 : tensor<24x22x52x42xf32>
171+
// CHECK-LABEL: func.func @test_pad_f32_non_constant_padval
172+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<24x22x52x42xf32> {
173+
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
174+
// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<f32>) -> tensor<24x22x52x42xf32>
175+
// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf32>
176+
}
177+
178+
// -----
179+
180+
func.func @test_pad_i64_non_constant_padval(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<i64>) -> tensor<24x22x52x42xi64> {
181+
%noval = "onnx.NoValue"() {value} : () -> none
182+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
183+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<i64>, none) -> tensor<24x22x52x42xi64>
184+
return %2 : tensor<24x22x52x42xi64>
185+
// CHECK-LABEL: func.func @test_pad_i64_non_constant_padval
186+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xi64>, [[PARAM_1_:%.+]]: tensor<i64>) -> tensor<24x22x52x42xi64> {
187+
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
188+
// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<i64>) -> tensor<24x22x52x42xi64>
189+
// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xi64>
190+
}
191+
192+
// -----
193+
func.func @test_pad_f16_non_constant_padval(%arg0: tensor<20x16x44x32xf16>, %arg1: tensor<f16>) -> tensor<24x22x52x42xf16> {
194+
%noval = "onnx.NoValue"() {value} : () -> none
195+
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
196+
%2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf16>, tensor<8xi64>, tensor<f16>, none) -> tensor<24x22x52x42xf16>
197+
return %2 : tensor<24x22x52x42xf16>
198+
// CHECK-LABEL: func.func @test_pad_f16_non_constant_padval
199+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf16>, [[PARAM_1_:%.+]]: tensor<f16>) -> tensor<24x22x52x42xf16> {
200+
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
201+
// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf16>, tensor<8xi64>, tensor<f16>) -> tensor<24x22x52x42xf16>
202+
// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf16>
203+
}

0 commit comments

Comments
 (0)