Commit ac10142

Merge pull request #474 from Xilinx/jrickert.custom_types
Always try to take the element type of custom op outputs from the model, instead of guessing
2 parents 99c057f + 8e614df commit ac10142
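
In effect: when the importer builds an onnx.Custom op and no output types are given via attributes, it now reads the output type from the model even when --useOnnxModelTypes and --useOnnxModelTypesForCustomOps are off, but it keeps only the element type and erases the shape (via UnrankedTensorType) rather than guessing the element type from the op's inputs.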

File tree

3 files changed: +46 -9 lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 23 additions & 8 deletions

@@ -848,23 +848,33 @@ class FrontendGenImpl {
     // Use the type map or types in input model to determine the data type of
     // output.
     std::vector<int> outputMap = T::getTypeMap();
+    const bool shouldTakeShapeFromModelForCustomOp =
+        isCustomOp &&
+        (options_.useOnnxModelTypesForCustomOps || givenOutputTypes.empty());
     for (unsigned int i = 0; i < (unsigned int)node.output().size(); i++) {
       // Optional outputs using empty string.
       if (node.output()[i].empty()) {
         outputTypes.emplace_back(builder_.getNoneType());
       } else {
-        if (options_.useOnnxModelTypes ||
-            (isCustomOp && options_.useOnnxModelTypesForCustomOps)) {
+        if (options_.useOnnxModelTypes || shouldTakeShapeFromModelForCustomOp) {
          auto onnxModelType = ConvertOnnxType(node.output(i), errorMessage);
          if (onnxModelType) {
            const auto ec = onnxModelType->getError();
            if (!ec) {
-              outputTypes.emplace_back(*onnxModelType.value());
+              Type outputType = *onnxModelType.value();
+              if (!options_.useOnnxModelTypesForCustomOps &&
+                  !options_.useOnnxModelTypes) {
+                if (auto shapedType = mlir::dyn_cast<ShapedType>(outputType)) {
+                  Type elementType = shapedType.getElementType();
+                  outputType = UnrankedTensorType::get(elementType);
+                }
+              }
+              outputTypes.emplace_back(outputType);
              continue;
            }
            if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
              errorMessage +=
-                  "Failed to get type for '" + node.output(i) + "\n";
+                  "Failed to get type for '" + node.output(i) + "'\n";
              return ec;
            }
            llvm::errs() << "Warning: "

@@ -874,13 +884,19 @@ class FrontendGenImpl {
       }
       unsigned int j = i;
       // Variadic output is a single ODS result.
-      if (variadicOut)
+      if (variadicOut) {
        j = 0;
+      }
      if (!givenOutputTypes.empty()) {
+        assert(givenOutputTypes.size() > i &&
+               "givenOutputTypes size is less than number of outputs");
        outputTypes.emplace_back(
            UnrankedTensorType::get(givenOutputTypes[i]));
      } else if (j < outputMap.size() && outputMap[j] >= MAX_NUM_TYPES) {
        // Mapping gives a connection with an input.
+        assert(
+            outputMap[j] - MAX_NUM_TYPES < static_cast<int>(inputs.size()) &&
+            "output type mapping to input is out of range");
        Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType();
        if (mlir::isa<TensorType>(inputType)) {
          Type elementType =

@@ -1570,10 +1586,9 @@ class FrontendGenImpl {
     auto domainAttr = builder_.getNamedAttr(
         "domain_name", builder_.getStringAttr(node.domain()));
     attributes.push_back(domainAttr);
-    int nIn = 0;
-    int nOut = 0;
+    const int nIn = ONNXCustomOp::getNumberOfOperands();
+    const int nOut = ONNXCustomOp::getNumberOfResults();
     getNodeInputs(node, inputs);
-    nOut = node.output().size();
     std::vector<Type> givenOutputTypes;
 
     // We lack a way of specifying import behavior for custom domains. For now
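
To make the core of the change concrete, here is a minimal standalone sketch of the element-type-only adoption the new code performs for custom ops. The helper name adoptElementTypeOnly and the free-function framing are assumptions for illustration; in the actual patch this logic sits inline in FrontendGenImpl:

#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper mirroring the new importer behavior: when neither
// --useOnnxModelTypes nor --useOnnxModelTypesForCustomOps is set, keep only
// the element type declared in the model and erase the shape by wrapping it
// in an UnrankedTensorType (printed as tensor<*x...>).
static mlir::Type adoptElementTypeOnly(mlir::Type modelType) {
  if (auto shapedType = mlir::dyn_cast<mlir::ShapedType>(modelType))
    return mlir::UnrankedTensorType::get(shapedType.getElementType());
  return modelType; // non-shaped types pass through unchanged
}

So a model-declared tensor<625x256xui8> becomes tensor<*xui8>: the element type is trusted, the shape is not. If either flag is set, the model type is used as-is, shape included.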

test/mlir/onnx/parse/add_missing_output_types.json

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 // FAILURE: Could not successfully parse ONNX file
 // FAILURE: ONNX type with id: 0 is not a valid type
-// FAILURE: Failed to import output type for
+// FAILURE: Failed to get type for 'output_c'
 // FAILURE: Failed to import main graph, could not get its function type
 
 // RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=false --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=INFERRED %s
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=false --printIR %s | FileCheck %s
+<
+   ir_version: 10,
+   opset_import: ["" : 17, "com.test" : 1000]
+>
+test_custom (float[1,3,800,800] input) => (float[1,625,256] out)
+<float input_scale = {0.015625}, uint8 input_zero_point = {128}, float dq_scale = {0.5}, uint8 dq_zero_point = {128}, uint8[1,3,800,800] quant_linear, uint8[625,256] customop_res1, uint8[1,625,256] customop_res2> {
+   quant_linear = QuantizeLinear <axis: int = 1> (input, input_scale, input_zero_point)
+   customop_res1, customop_res2 = com.test.super_layer <body: string = "subgraph"> (quant_linear)
+   out = DequantizeLinear <axis: int = 1> (customop_res2, dq_scale, dq_zero_point)
+}
+// CHECK-LABEL: func.func
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x800x800xf32> {onnx.name = "input"}) -> (tensor<1x625x256xf32> {onnx.name = "out"}) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.562500e-02> : tensor<f32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<128> : tensor<ui8>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<5.000000e-01> : tensor<f32>
+// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<128> : tensor<ui8>
+// CHECK: [[VAR_4_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]]) {axis = 1 : si64, block_size = 0 : si64, onnx_node_name = "onnx.QuantizeLinear_0", output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x3x800x800xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x3x800x800xui8>
+// CHECK: [[VAR_5_:%.+]]:2 = "onnx.Custom"([[VAR_4_]]) {body = "subgraph", domain_name = "com.test", function_name = "super_layer", onnx_node_name = "onnx.Custom_1"} : (tensor<1x3x800x800xui8>) -> (tensor<*xui8>, tensor<*xui8>)
+// CHECK: [[VAR_6_:%.+]] = "onnx.DequantizeLinear"([[VAR_5_]]#1, [[VAR_2_]], [[VAR_3_]]) {axis = 1 : si64, block_size = 0 : si64, onnx_node_name = "onnx.DequantizeLinear_2"} : (tensor<*xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x625x256xf32>
+// CHECK: return [[VAR_6_]] : tensor<1x625x256xf32>
+// CHECK: }
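
Note what this new test exercises: the model declares customop_res1 and customop_res2 as uint8 tensors with static shapes, and both model-type flags are off, yet the imported onnx.Custom results come out as tensor<*xui8>: the ui8 element type is taken from the model while the shapes are dropped, and the graph output still carries its declared tensor<1x625x256xf32> type.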
