@@ -978,6 +978,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
978978 return success ();
979979}
980980
981+ LogicalResult tosa::ConcatOp::verify () {
982+ // check that each input has same element type as output
983+ auto outType = getOutput ().getType ();
984+ const Operation::operand_range inputList = getInput1 ();
985+
986+ // Check there is at least one input
987+ if (inputList.empty ())
988+ return emitOpError (" expect at least one input" );
989+
990+ if (!llvm::all_of (inputList, [&](auto input) {
991+ return succeeded (verifySameElementTypes (
992+ *this , /* inType = */ input.getType (), outType));
993+ })) {
994+ return failure ();
995+ }
996+
997+ const int32_t axis = getAxis ();
998+ ShapeAdaptor firstRankedInputShape = nullptr ;
999+ for (const auto &input : inputList) {
1000+ const Type inputType = input.getType ();
1001+ ShapeAdaptor currShape (inputType);
1002+ if (currShape.hasRank ()) {
1003+ firstRankedInputShape = currShape;
1004+ // Check axis is in expected range
1005+ if (axis < 0 || axis >= firstRankedInputShape.getRank ())
1006+ return emitOpError (" expect axis to be within range 0 < axis < "
1007+ " rank(input1[firstRankedTensorIdx]), got " )
1008+ << axis;
1009+ break ;
1010+ }
1011+ }
1012+
1013+ const auto allOperandsHasRank = [](const Value input) {
1014+ return ShapeAdaptor (input.getType ()).hasRank ();
1015+ };
1016+ if (llvm::all_of (inputList, allOperandsHasRank)) {
1017+ const int64_t firstInputRank = firstRankedInputShape.getRank ();
1018+
1019+ for (const auto &[index, input] : llvm::enumerate (inputList.drop_front ())) {
1020+ const ShapeAdaptor inputShape (input.getType ());
1021+ const int64_t inputRank = inputShape.getRank ();
1022+ const size_t operandNum = index + 1 ;
1023+
1024+ // Check that each operand has the same rank
1025+ if (inputRank != firstInputRank)
1026+ return emitOpError (
1027+ " expect all operands to have the same rank, but got " )
1028+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
1029+ << operandNum;
1030+
1031+ // Check non-axis dims match
1032+ for (int i = 0 ; i < inputRank; i++) {
1033+ const int64_t inputDim = inputShape.getDimSize (i);
1034+ const int64_t firstInputDim = firstRankedInputShape.getDimSize (i);
1035+ if (i == axis || firstRankedInputShape.isDynamicDim (i) ||
1036+ inputShape.isDynamicDim (i))
1037+ continue ;
1038+ if (inputDim != firstInputDim)
1039+ return emitOpError (" expect all operand shapes to have the same sizes "
1040+ " on non-axis dimensions, but got " )
1041+ << inputDim << " vs " << firstInputDim << " at index " << i
1042+ << " on operands 0 and " << operandNum;
1043+ }
1044+ }
1045+ }
1046+
1047+ return success ();
1048+ }
1049+
9811050LogicalResult tosa::EqualOp::inferReturnTypeComponents (
9821051 MLIRContext *context, ::std::optional<Location> location,
9831052 ValueShapeRange operands, DictionaryAttr attributes,
@@ -1027,6 +1096,53 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
10271096 return success ();
10281097}
10291098
1099+ LogicalResult MatMulOp::verify () {
1100+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
1101+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
1102+
1103+ // Must be shaped tensor types
1104+ if (!aType)
1105+ return emitOpError (" expect a shaped tensor for input a, got " )
1106+ << getA ().getType ();
1107+
1108+ if (!bType)
1109+ return emitOpError (" expect a shaped tensor for input b, got " )
1110+ << getB ().getType ();
1111+
1112+ auto aElementType = aType.getElementType ();
1113+ auto bElementType = bType.getElementType ();
1114+
1115+ auto aQuantizedEType =
1116+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1117+ auto bQuantizedEType =
1118+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1119+
1120+ if (aQuantizedEType || bQuantizedEType) {
1121+ if (!aQuantizedEType || !bQuantizedEType) {
1122+ return emitOpError (" expect operands to be both quantized or both not "
1123+ " quantized, got " )
1124+ << aElementType << " and " << bElementType;
1125+ }
1126+ // both a and b have quantized element types
1127+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1128+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1129+ if (aQuantWidth != bQuantWidth) {
1130+ return emitOpError (" expect quantized operands to have same widths, got " )
1131+ << aQuantWidth << " and " << bQuantWidth;
1132+ }
1133+
1134+ return success ();
1135+ }
1136+
1137+ // non-quantized element types
1138+ if (aElementType != bElementType) {
1139+ return emitOpError (" expect same element type for inputs a and b, got " )
1140+ << aElementType << " and " << bElementType;
1141+ }
1142+
1143+ return success ();
1144+ }
1145+
10301146LogicalResult tosa::PadOp::inferReturnTypeComponents (
10311147 MLIRContext *context, ::std::optional<Location> location,
10321148 PadOp::Adaptor adaptor,
@@ -1075,6 +1191,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
10751191}
10761192
10771193LogicalResult tosa::PadOp::verify () {
1194+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1195+ /* outType = */ getOutput ().getType ())
1196+ .failed ()) {
1197+ return failure ();
1198+ }
1199+
1200+ if (auto padConst = getPadConst ()) {
1201+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1202+ /* outType = */ getOutput ().getType ())
1203+ .failed ()) {
1204+ return failure ();
1205+ }
1206+ }
1207+
10781208 RankedTensorType inputType = getInput1 ().getType ();
10791209 RankedTensorType outputType = getOutput ().getType ();
10801210 auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1148,21 +1278,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
11481278}
11491279
11501280LogicalResult tosa::SliceOp::verify () {
1281+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1282+ /* outType = */ getOutput ().getType ())
1283+ .failed ())
1284+ return failure ();
11511285 auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1 ().getType ());
11521286 if (!inputType)
11531287 return success ();
11541288
11551289 auto startShapeRank =
11561290 llvm::cast<tosa::shapeType>(getStart ().getType ()).getRank ();
11571291 if (inputType.getRank () != startShapeRank)
1158- return emitOpError (
1159- " length of start attribute is not equal rank of input shape" );
1292+ return emitOpError (" length of start is not equal to rank of input shape" );
11601293
11611294 auto sizeShapeRank =
11621295 llvm::cast<tosa::shapeType>(getSize ().getType ()).getRank ();
11631296 if (inputType.getRank () != sizeShapeRank)
1164- return emitOpError (
1165- " length of size attribute is not equal rank of input shape" );
1297+ return emitOpError (" length of size is not equal to rank of input shape" );
11661298
11671299 return success ();
11681300}
@@ -1367,6 +1499,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
13671499}
13681500
13691501LogicalResult tosa::TileOp::verify () {
1502+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1503+ /* outType = */ getOutput ().getType ())
1504+ .failed ()) {
1505+ return failure ();
1506+ }
13701507 ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
13711508 ShapedType outputType = llvm::cast<ShapedType>(getType ());
13721509
@@ -1448,6 +1585,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
14481585}
14491586
14501587llvm::LogicalResult tosa::ReshapeOp::verify () {
1588+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1589+ /* outType = */ getOutput ().getType ())
1590+ .failed ()) {
1591+ return failure ();
1592+ }
14511593 TensorType inputType = getInput1 ().getType ();
14521594 RankedTensorType outputType = getType ();
14531595
@@ -1626,6 +1768,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
16261768}
16271769
16281770LogicalResult tosa::TransposeOp::verify () {
1771+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1772+ /* outType = */ getOutput ().getType ())
1773+ .failed ()) {
1774+ return failure ();
1775+ }
16291776 TensorType inputType = getInput1 ().getType ();
16301777 TensorType outputType = getOutput ().getType ();
16311778 const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1726,6 +1873,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
17261873 return success ();
17271874}
17281875
1876+ LogicalResult tosa::GatherOp::verify () {
1877+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1878+ /* outType = */ getOutput ().getType ());
1879+ }
1880+
17291881LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
17301882 MLIRContext *context, ::std::optional<Location> location,
17311883 ResizeOp::Adaptor adaptor,
@@ -1887,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
18872039 return success ();
18882040}
18892041
2042+ LogicalResult tosa::ScatterOp::verify () {
2043+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
2044+ /* outType = */ getValuesOut ().getType ())
2045+ .failed () ||
2046+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
2047+ /* outType = */ getValuesOut ().getType ())
2048+ .failed ()) {
2049+ return failure ();
2050+ }
2051+ return success ();
2052+ }
2053+
18902054static LogicalResult ReduceInferReturnTypes (
18912055 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
18922056 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2342,6 +2506,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
23422506 inferredReturnShapes);
23432507}
23442508
2509+ LogicalResult MaxPool2dOp::verify () {
2510+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2511+ /* outType = */ getOutput ().getType ());
2512+ }
2513+
23452514LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
23462515 MLIRContext *context, ::std::optional<Location> location,
23472516 DepthwiseConv2DOp::Adaptor adaptor,
@@ -2642,6 +2811,10 @@ void IfOp::print(OpAsmPrinter &p) {
26422811}
26432812
26442813LogicalResult ReverseOp::verify () {
2814+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2815+ /* outType = */ getOutput ().getType ())
2816+ .failed ())
2817+ return failure ();
26452818 TensorType inputType = getInput1 ().getType ();
26462819 TensorType outputType = getOutput ().getType ();
26472820 int32_t reverseAxis = getAxis ();
@@ -2670,6 +2843,31 @@ LogicalResult ReverseOp::verify() {
26702843 return success ();
26712844}
26722845
2846+ LogicalResult tosa::SelectOp::verify () {
2847+ // verify input2 and input3 have same element type as output
2848+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2849+ /* outType = */ getOutput ().getType ())
2850+ .failed () ||
2851+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2852+ /* outType = */ getOutput ().getType ())
2853+ .failed ()) {
2854+ return failure ();
2855+ }
2856+ // verify input1 has element type of bool
2857+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2858+ if (!predicateType) {
2859+ return emitOpError (" expect shaped tensor for input1, got " )
2860+ << getInput1 ().getType ();
2861+ }
2862+ auto predicateElementType = predicateType.getElementType ();
2863+ if (!predicateElementType.isInteger (1 )) {
2864+ return emitOpError (" expect element type of bool for input1, got " )
2865+ << predicateElementType;
2866+ }
2867+
2868+ return success ();
2869+ }
2870+
26732871// parse and print of WhileOp refer to the implementation of SCF dialect.
26742872ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
26752873 SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments