
Commit e859e0c

sayeddlakosh-rai authored and committed

review comments - cleanup handling of constant pad values

1 parent: 80be0e5

2 files changed: +22, -31 lines changed


src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp

Lines changed: 3 additions & 27 deletions
@@ -22,8 +22,8 @@
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"

 using namespace mlir;

@@ -106,36 +106,12 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
           padsList1,
           rewriter.create<mlir::tosa::ConstOp>(
              op->getLoc(), constType, constAttr));
-    } else if (!constValue.getDefiningOp<mlir::tosa::ConstOp>()) {
-      rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
-          op, resultType, data, padsList1, constValue);
     } else {
-      auto valueAttr = tosa::getValueFromTosaConst<ElementsAttr>(constValue);
       TosaBuilder tosaBuilder(rewriter, loc);
-
-      Value constTosaTensor;
-      if (isa<FloatType>(valueAttr.getElementType())) {
-        auto valueIt = valueAttr.getValues<FloatAttr>().begin();
-        const float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
-        constTosaTensor = tosaBuilder.getSplattedConst(
-            valueFloat, valueAttr.getElementType(), 0);
-      } else {
-        assert(isTOSAInt(elementDtype) && "Already validated");
-        auto valueIt = valueAttr.getValues<IntegerAttr>().begin();
-        auto valueAsAPInt = cast<IntegerAttr>(*valueIt).getValue();
-        auto asIntegerTy = cast<IntegerType>(valueAttr.getElementType());
-        if (asIntegerTy.isUnsigned()) {
-          constTosaTensor = tosaBuilder.getSplattedConst(
-              valueAsAPInt.getZExtValue(), asIntegerTy, 0);
-        } else {
-          constTosaTensor = tosaBuilder.getSplattedConst(
-              valueAsAPInt.getSExtValue(), asIntegerTy, 0);
-        }
-      }
+      Value reshapeToSplattedConst = tosaBuilder.reshape(constValue, {});
       rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
-          op, resultType, data, padsList1, constTosaTensor);
+          op, resultType, data, padsList1, reshapeToSplattedConst);
     }
-
     return success();
   }
 };

test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir

Lines changed: 19 additions & 4 deletions
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --cse %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --canonicalize --cse %s -split-input-file | FileCheck %s

 func.func @test_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
   %noval = "onnx.NoValue"() {value} : () -> none
@@ -170,21 +170,36 @@ func.func @test_pad_f32_non_constant_padval(%arg0: tensor<20x16x44x32xf32>, %arg
   return %2 : tensor<24x22x52x42xf32>
 // CHECK-LABEL: func.func @test_pad_f32_non_constant_padval
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<24x22x52x42xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
 // CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<f32>) -> tensor<24x22x52x42xf32>
 // CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf32>
 }

 // -----

+func.func @test_pad_f32_non_constant_1Dpadval(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<1xf32>) -> tensor<24x22x52x42xf32> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<24x22x52x42xf32>
+  return %2 : tensor<24x22x52x42xf32>
+// CHECK-LABEL: func.func @test_pad_f32_non_constant_1Dpadval
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<24x22x52x42xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
+// CHECK-DAG: [[VAL_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
+// CHECK: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAL_1_]] : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<f32>) -> tensor<24x22x52x42xf32>
+// CHECK: return [[VAR_2_]] : tensor<24x22x52x42xf32>
+}
+
+// -----
+
 func.func @test_pad_i64_non_constant_padval(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<i64>) -> tensor<24x22x52x42xi64> {
   %noval = "onnx.NoValue"() {value} : () -> none
   %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
   %2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<i64>, none) -> tensor<24x22x52x42xi64>
   return %2 : tensor<24x22x52x42xi64>
 // CHECK-LABEL: func.func @test_pad_i64_non_constant_padval
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xi64>, [[PARAM_1_:%.+]]: tensor<i64>) -> tensor<24x22x52x42xi64> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
 // CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<i64>) -> tensor<24x22x52x42xi64>
 // CHECK: return [[VAR_1_]] : tensor<24x22x52x42xi64>
 }
@@ -197,7 +212,7 @@ func.func @test_pad_f16_non_constant_padval(%arg0: tensor<20x16x44x32xf16>, %arg
   return %2 : tensor<24x22x52x42xf16>
 // CHECK-LABEL: func.func @test_pad_f16_non_constant_padval
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf16>, [[PARAM_1_:%.+]]: tensor<f16>) -> tensor<24x22x52x42xf16> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xi64>}> : () -> tensor<8xi64>
 // CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf16>, tensor<8xi64>, tensor<f16>) -> tensor<24x22x52x42xf16>
 // CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf16>
 }
