
Commit 2250084

Add ShapeInference for RandomUniform and RandomUniformLike.
Add ResultTypeInference to Bernoulli, CastLike, RandomUniform, RandomUniformLike and SequenceEmpty.
Add a result element type verifier to Bernoulli, Cast, CastLike, EyeLike, RandomNormal, RandomNormalLike, RandomUniform and RandomUniformLike.
Refactor the code shared between ops with a `dtype` attribute to reduce duplication.

Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 5d2f70f commit 2250084

20 files changed: 480 additions, 243 deletions
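Note: the op files below delegate their new verifiers to shared helpers, verifyElementTypeFromDtype and verifyElementTypeFromDtypeWithFallBackToInputType, whose definitions live in changed files not shown in this excerpt. As a rough, non-authoritative sketch of the fallback variant (the helper name is taken from the call sites in this diff; the body is an assumption reconstructed from the per-op verifier code the commit deletes):

// Sketch only: assumed shape of the shared verifier helper; the real
// implementation in the refactored sources may differ.
template <typename OP>
mlir::LogicalResult verifyElementTypeFromDtypeWithFallBackToInputType(OP op) {
  mlir::OpBuilder builder(op.getContext());
  // Expected element type: the `dtype` attribute when present, otherwise the
  // element type of the input operand.
  mlir::Type expected =
      op.getDtypeAttr()
          ? onnx_mlir::convertONNXTypeToMLIRType(
                builder, static_cast<onnx::TensorProto_DataType>(
                             op.getDtypeAttr().getValue().getSExtValue()))
          : mlir::cast<mlir::ShapedType>(op.getInput().getType())
                .getElementType();
  mlir::Type actual =
      mlir::cast<mlir::ShapedType>(op->getResult(0).getType()).getElementType();
  return actual == expected
             ? mlir::success()
             : op.emitOpError("result element type does not match dtype");
}

The variant without the fallback, verifyElementTypeFromDtype, is used below by RandomNormal (and presumably RandomUniform) and would skip the input branch, since those ops take no tensor operand.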

src/Dialect/ONNX/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ add_onnx_mlir_library(OMONNXOps
   ONNXOps/Math/MatMul.cpp
   ONNXOps/Math/RandomNormal.cpp
   ONNXOps/Math/RandomNormalLike.cpp
+  ONNXOps/Math/RandomUniform.cpp
+  ONNXOps/Math/RandomUniformLike.cpp
   ONNXOps/Math/Reduction.cpp
   ONNXOps/Math/Scatter.cpp
   ONNXOps/Math/TopK.cpp

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 12 additions & 5 deletions
@@ -603,7 +603,7 @@ def ONNXBatchNormalizationV9Op:ONNX_Op<"BatchNormalizationV9",
 }
 
 def ONNXBernoulliOp:ONNX_Op<"Bernoulli",
-  [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>, DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>]> {
   let summary = "ONNX Bernoulli operation";
   let description = [{
   Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor
@@ -636,6 +636,7 @@ def ONNXBernoulliOp:ONNX_Op<"Bernoulli",
       return sh;
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXBitShiftOp:ONNX_Op<"BitShift",
@@ -942,10 +943,11 @@ def ONNXCastOp:ONNX_Op<"Cast",
       build($_builder, $_state, resultType, input, saturate, to);
     }] >
   ];
+  let hasVerifier = 1;
 }
 
 def ONNXCastLikeOp:ONNX_Op<"CastLike",
-  [Pure, OpVersionTrait<21>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  [Pure, OpVersionTrait<21>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>, DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>]> {
   let summary = "ONNX CastLike operation";
   let description = [{
   The operator casts the elements of a given input tensor (the first input) to
@@ -975,6 +977,7 @@ def ONNXCastLikeOp:ONNX_Op<"CastLike",
       return sh;
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXCeilOp:ONNX_Op<"Ceil",
@@ -2396,6 +2399,7 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike",
       return sh;
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXFlattenOp:ONNX_Op<"Flatten",
@@ -6304,6 +6308,7 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal",
       return sh;
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike",
@@ -6347,7 +6352,7 @@ def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike",
 }
 
 def ONNXRandomUniformOp:ONNX_Op<"RandomUniform",
-  [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>, DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>]> {
   let summary = "ONNX RandomUniform operation";
   let description = [{
   Generate a tensor with random values drawn from a uniform distribution. The shape
@@ -6382,10 +6387,11 @@ def ONNXRandomUniformOp:ONNX_Op<"RandomUniform",
       return sh;
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike",
-  [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>, DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>]> {
   let summary = "ONNX RandomUniformLike operation";
   let description = [{
   Generate a tensor with random values drawn from a uniform distribution.
@@ -6421,6 +6427,7 @@ def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike",
       return sh;
     }
   }];
+  let hasVerifier = 1;
 }
 
 def ONNXRangeOp:ONNX_Op<"Range",
@@ -8452,7 +8459,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
 }
 
 def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
-  [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>, DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>]> {
   let summary = "ONNX SequenceEmpty operation";
   let description = [{
   Construct an empty tensor sequence, with given data type.

src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp

Lines changed: 10 additions & 19 deletions
@@ -13,34 +13,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
 
 using namespace mlir;
-using namespace mlir::OpTrait::util;
 using namespace onnx_mlir;
 
 //===----------------------------------------------------------------------===//
 // Verify
 //===----------------------------------------------------------------------===//
 
 //===----------------------------------------------------------------------===//
-// Shape Inference
+// Verify
 //===----------------------------------------------------------------------===//
 
-LogicalResult ONNXBernoulliOp::inferShapes(
-    std::function<void(Region &)> doShapeInference) {
-  auto builder = OpBuilder(getContext());
-  if (!hasShapeAndRank(getInput())) {
-    return success();
-  }
-  Type elementType;
-  if (getDtypeAttr()) {
-    elementType = convertONNXTypeToMLIRType(
-        builder, static_cast<onnx::TensorProto_DataType>(
-                     getDtypeAttr().getValue().getSExtValue()));
-  } else {
-    elementType =
-        mlir::cast<RankedTensorType>(getInput().getType()).getElementType();
-  }
-  ONNXBernoulliOpShapeHelper shapeHelper(getOperation(), {});
-  return shapeHelper.computeShapeAndUpdateType(elementType);
+LogicalResult ONNXBernoulliOp::verify() {
+  return verifyElementTypeFromDtypeWithFallBackToInputType(*this);
 }
+
+//===----------------------------------------------------------------------===//
+// Type and Shape Inference
+//===----------------------------------------------------------------------===//
+
+GET_SHAPE_AND_TYPE_INFERENCE_FOR_SHAPE_COPYING_OPS(ONNXBernoulliOp)
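Note: the GET_SHAPE_AND_TYPE_INFERENCE_FOR_SHAPE_COPYING_OPS macro is defined in a header outside this excerpt. A plausible expansion for ONNXBernoulliOp, pieced together from the inferShapes code removed above and the resultTypeInference code removed in RandomNormalLike.cpp below, might look roughly like this (the local bernoulliElementType helper is purely illustrative):

// Hypothetical expansion sketch; the real macro body may differ.
static Type bernoulliElementType(ONNXBernoulliOp op) { // illustrative name
  // Element type comes from `dtype` when set, otherwise from the input.
  if (op.getDtypeAttr()) {
    OpBuilder builder(op.getContext());
    return convertONNXTypeToMLIRType(
        builder, static_cast<onnx::TensorProto_DataType>(
                     op.getDtypeAttr().getValue().getSExtValue()));
  }
  return mlir::cast<ShapedType>(op.getInput().getType()).getElementType();
}

std::vector<Type> ONNXBernoulliOp::resultTypeInference() {
  Type elementType = bernoulliElementType(*this);
  // The result copies the input shape, ranked or unranked.
  if (auto ranked = mlir::dyn_cast<RankedTensorType>(getInput().getType()))
    return {ranked.clone(elementType)};
  return {UnrankedTensorType::get(elementType)};
}

LogicalResult ONNXBernoulliOp::inferShapes(
    std::function<void(Region &)> doShapeInference) {
  if (!hasShapeAndRank(getInput()))
    return success();
  ONNXBernoulliOpShapeHelper shapeHelper(getOperation(), {});
  return shapeHelper.computeShapeAndUpdateType(bernoulliElementType(*this));
}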

src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp

Lines changed: 27 additions & 2 deletions
@@ -141,6 +141,13 @@ LogicalResult ONNXBitwiseNotOp::inferShapes(
 // Cast
 //===----------------------------------------------------------------------===//
 
+LogicalResult ONNXCastOp::verify() {
+  return cast<ShapedType>(this->getResult().getType()).getElementType() ==
+                 getTo()
+             ? success()
+             : emitOpError("element type does not match the 'to' attribute");
+}
+
 std::vector<Type> ONNXCastOp::resultTypeInference() {
   return {UnrankedTensorType::get(getTo())};
 }
@@ -155,6 +162,24 @@ LogicalResult ONNXCastOp::inferShapes(
   return shapeHelper.computeShapeAndUpdateType(elementType);
 }
 
+//===----------------------------------------------------------------------===//
+// CastLike (not really unary, but similar)
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXCastLikeOp::verify() {
+  return cast<ShapedType>(this->getResult().getType()).getElementType() ==
+                 cast<ShapedType>(this->getTargetType().getType())
+                     .getElementType()
+             ? success()
+             : emitOpError("element type does not match the 'target types' "
+                           "operands element type");
+}
+
+std::vector<Type> ONNXCastLikeOp::resultTypeInference() {
+  return {UnrankedTensorType::get(
+      cast<ShapedType>(this->getTargetType().getType()).getElementType())};
+}
+
 //===----------------------------------------------------------------------===//
 // Ceil
 //===----------------------------------------------------------------------===//
@@ -306,8 +331,8 @@ LogicalResult ONNXIsInfOp::verify() {
   int64_t detectPosAttribute = getDetectPositive();
   int64_t detectNegAttribute = getDetectNegative();
 
-  // One of the values for detectPosAttribute and detectNegAttribute must be 1.
-  // If not, then this will result in an error.
+  // One of the values for detectPosAttribute and detectNegAttribute must
+  // be 1. If not, then this will result in an error.
   if (detectPosAttribute == 0 && detectNegAttribute == 0)
     return emitOpError(
         "This variation is currently unsupported. One or both of the "

src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp

Lines changed: 7 additions & 53 deletions
@@ -13,63 +13,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
 
 using namespace mlir;
-using namespace mlir::OpTrait::util;
 using namespace onnx_mlir;
 
-namespace onnx_mlir {
-
-template <>
-LogicalResult ONNXRandomNormalOpShapeHelper::computeShape() {
-  ONNXRandomNormalOp randomOp = llvm::cast<ONNXRandomNormalOp>(op);
-
-  DimsExpr outputDims;
-  createIE->getIntFromArrayAsLiterals(randomOp.getShape(), outputDims);
-  if (!IndexExpr::isNonNegativeLiteral(outputDims))
-    return op->emitError("Random normal tensor has dynamic dimension.");
-  // Save the final result.
-  setOutputDims(outputDims);
-  return success();
-}
-
-} // namespace onnx_mlir
-
 //===----------------------------------------------------------------------===//
 // Verify
 //===----------------------------------------------------------------------===//
 
+LogicalResult ONNXRandomNormalOp::verify() {
+  return verifyElementTypeFromDtype(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Inference
 //===----------------------------------------------------------------------===//
 
-namespace {
-Type getRandomNormalElementType(ONNXRandomNormalOp op) {
-  if (op.getDtypeAttr()) {
-    const auto elementTypeID =
-        static_cast<onnx::TensorProto_DataType>(op.getDtype());
-    if (elementTypeID ==
-        onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
-      return Float16Type::get(op.getContext());
-    } else if (elementTypeID ==
-               onnx::TensorProto_DataType::TensorProto_DataType_FLOAT) {
-      return Float32Type::get(op.getContext());
-    } else if (elementTypeID ==
-               onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE) {
-      return Float64Type::get(op.getContext());
-    } else if (elementTypeID ==
-               onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16) {
-      return BFloat16Type::get(op.getContext());
-    } else {
-      llvm_unreachable("dtype not supported for RandomNormal");
-    }
-  }
-  return Float32Type::get(op.getContext());
-}
-} // namespace
-
 std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
-  return {UnrankedTensorType::get(getRandomNormalElementType(*this))};
+  return {UnrankedTensorType::get(getResultElementTypeFromDtype(*this))};
 }
 
 //===----------------------------------------------------------------------===//
@@ -80,13 +42,5 @@ LogicalResult ONNXRandomNormalOp::inferShapes(
     std::function<void(Region &)> doShapeInference) {
   ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {});
   return shapeHelper.computeShapeAndUpdateType(
-      getRandomNormalElementType(*this));
+      getResultElementTypeFromDtype(*this));
 }
-
-//===----------------------------------------------------------------------===//
-// Template instantiation
-//===----------------------------------------------------------------------===//
-
-namespace onnx_mlir {
-template struct ONNXNonSpecificOpShapeHelper<ONNXRandomNormalOp>;
-} // namespace onnx_mlir
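Note: getResultElementTypeFromDtype is the shared replacement for the deleted getRandomNormalElementType; its definition is not part of this excerpt. Judging from the code it replaces, it presumably performs the same dtype-to-element-type mapping, roughly as sketched here (name from the diff, body and signature assumed):

// Assumed behaviour of the shared helper: map the optional `dtype` attribute
// to an MLIR element type, defaulting to f32 when the attribute is absent.
template <typename OP>
Type getResultElementTypeFromDtype(OP op) {
  if (op.getDtypeAttr()) {
    switch (static_cast<onnx::TensorProto_DataType>(op.getDtype())) {
    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
      return Float16Type::get(op.getContext());
    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
      return Float32Type::get(op.getContext());
    case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
      return Float64Type::get(op.getContext());
    case onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
      return BFloat16Type::get(op.getContext());
    default:
      llvm_unreachable("dtype not supported");
    }
  }
  return Float32Type::get(op.getContext());
}

The new ONNXRandomNormalOp::verify above can then simply check the result element type against this mapping, which is what verifyElementTypeFromDtype appears to do.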

src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp

Lines changed: 4 additions & 84 deletions
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
 
 using namespace mlir;
 using namespace mlir::OpTrait::util;
@@ -23,92 +24,11 @@ using namespace onnx_mlir;
 //===----------------------------------------------------------------------===//
 
 LogicalResult ONNXRandomNormalLikeOp::verify() {
-  ONNXRandomNormalLikeOpAdaptor operandAdaptor(*this);
-  Value input = operandAdaptor.getInput();
-  if (!hasShapeAndRank(input))
-    return success();
-  Value output = this->getOutput();
-  if (!hasShapeAndRank(output))
-    return success();
-
-  auto inputType =
-      mlir::cast<RankedTensorType>(input.getType()).getElementType();
-  auto outputType =
-      mlir::cast<RankedTensorType>(output.getType()).getElementType();
-
-  auto elementTypeIDDType = operandAdaptor.getDtype();
-  if (elementTypeIDDType) {
-    const auto elementTypeID =
-        static_cast<onnx::TensorProto_DataType>(*elementTypeIDDType);
-    if (elementTypeID !=
-            onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16 &&
-        elementTypeID !=
-            onnx::TensorProto_DataType::TensorProto_DataType_FLOAT &&
-        elementTypeID !=
-            onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE &&
-        elementTypeID !=
-            onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16) {
-      return emitOpError("dtype not float16, float, double or bfloat16");
-    }
-    if (elementTypeID ==
-            onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16 &&
-        outputType != Float16Type::get(getContext()))
-      return emitOpError("output tensor does not match float16 dtype.");
-    else if (elementTypeID ==
-                 onnx::TensorProto_DataType::TensorProto_DataType_FLOAT &&
-             outputType != Float32Type::get(getContext()))
-      return emitOpError("output tensor does not match float dtype.");
-    else if (elementTypeID ==
-                 onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE &&
-             outputType != Float64Type::get(getContext()))
-      return emitOpError("output tensor does not match double dtype.");
-    else if (elementTypeID ==
-                 onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16 &&
-             outputType != BFloat16Type::get(getContext()))
-      return emitOpError("output tensor does not match bfloat16 dtype.");
-  } else if (inputType != outputType) {
-    return emitOpError("output and input element types do not match.");
-  }
-
-  return success();
-}
-
-static Type getRandomNormalLikeOutputElementType(ONNXRandomNormalLikeOp op) {
-  auto inputType = mlir::cast<TensorType>(op.getInput().getType());
-  Type elementType = inputType.getElementType();
-  if (op.getDtypeAttr()) {
-    auto builder = OpBuilder(op.getContext());
-    elementType = convertONNXTypeToMLIRType(
-        builder, static_cast<onnx::TensorProto_DataType>(
-                     op.getDtypeAttr().getValue().getSExtValue()));
-  }
-  return elementType;
+  return verifyElementTypeFromDtypeWithFallBackToInputType(*this);
 }
 
 //===----------------------------------------------------------------------===//
-// Type Inference
+// Shape + Type Inference
 //===----------------------------------------------------------------------===//
 
-std::vector<Type> ONNXRandomNormalLikeOp::resultTypeInference() {
-  Type elementType = getRandomNormalLikeOutputElementType(*this);
-  std::vector<Type> resultTypes;
-  if (auto rankedInputType =
-          mlir::dyn_cast<RankedTensorType>(getInput().getType())) {
-    resultTypes.push_back(rankedInputType.clone(elementType));
-  } else {
-    resultTypes.push_back(UnrankedTensorType::get(elementType));
-  }
-  return resultTypes;
-}
-
-//===----------------------------------------------------------------------===//
-// Shape Inference
-//===----------------------------------------------------------------------===//
-
-LogicalResult ONNXRandomNormalLikeOp::inferShapes(
-    std::function<void(Region &)> doShapeInference) {
-  if (!hasShapeAndRank(getInput()))
-    return success();
-  Type elementType = getRandomNormalLikeOutputElementType(*this);
-  return inferShapeForUnaryOps(getOperation(), elementType);
-}
+GET_SHAPE_AND_TYPE_INFERENCE_FOR_SHAPE_COPYING_OPS(ONNXRandomNormalLikeOp)
