@@ -949,6 +949,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
949949 return success ();
950950}
951951
952+ LogicalResult tosa::ConcatOp::verify () {
953+ // check that each input has same element type as output
954+ auto outType = getOutput ().getType ();
955+ const Operation::operand_range inputList = getInput1 ();
956+
957+ // Check there is at least one input
958+ if (inputList.empty ())
959+ return emitOpError (" expect at least one input" );
960+
961+ if (!llvm::all_of (inputList, [&](auto input) {
962+ return succeeded (verifySameElementTypes (
963+ *this , /* inType = */ input.getType (), outType));
964+ })) {
965+ return failure ();
966+ }
967+
968+ const int32_t axis = getAxis ();
969+ ShapeAdaptor firstRankedInputShape = nullptr ;
970+ for (const auto &input : inputList) {
971+ const Type inputType = input.getType ();
972+ ShapeAdaptor currShape (inputType);
973+ if (currShape.hasRank ()) {
974+ firstRankedInputShape = currShape;
975+ // Check axis is in expected range
976+ if (axis < 0 || axis >= firstRankedInputShape.getRank ())
977+ return emitOpError (" expect axis to be within range 0 < axis < "
978+ " rank(input1[firstRankedTensorIdx]), got " )
979+ << axis;
980+ break ;
981+ }
982+ }
983+
984+ const auto allOperandsHasRank = [](const Value input) {
985+ return ShapeAdaptor (input.getType ()).hasRank ();
986+ };
987+ if (llvm::all_of (inputList, allOperandsHasRank)) {
988+ const int64_t firstInputRank = firstRankedInputShape.getRank ();
989+
990+ for (const auto &[index, input] : llvm::enumerate (inputList.drop_front ())) {
991+ const ShapeAdaptor inputShape (input.getType ());
992+ const int64_t inputRank = inputShape.getRank ();
993+ const size_t operandNum = index + 1 ;
994+
995+ // Check that each operand has the same rank
996+ if (inputRank != firstInputRank)
997+ return emitOpError (
998+ " expect all operands to have the same rank, but got " )
999+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
1000+ << operandNum;
1001+
1002+ // Check non-axis dims match
1003+ for (int i = 0 ; i < inputRank; i++) {
1004+ const int64_t inputDim = inputShape.getDimSize (i);
1005+ const int64_t firstInputDim = firstRankedInputShape.getDimSize (i);
1006+ if (i == axis || firstRankedInputShape.isDynamicDim (i) ||
1007+ inputShape.isDynamicDim (i))
1008+ continue ;
1009+ if (inputDim != firstInputDim)
1010+ return emitOpError (" expect all operand shapes to have the same sizes "
1011+ " on non-axis dimensions, but got " )
1012+ << inputDim << " vs " << firstInputDim << " at index " << i
1013+ << " on operands 0 and " << operandNum;
1014+ }
1015+ }
1016+ }
1017+
1018+ return success ();
1019+ }
1020+
9521021LogicalResult tosa::EqualOp::inferReturnTypeComponents (
9531022 MLIRContext *context, ::std::optional<Location> location,
9541023 ValueShapeRange operands, DictionaryAttr attributes,
@@ -998,6 +1067,51 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
9981067 return success ();
9991068}
10001069
1070+ LogicalResult MatMulOp::verify () {
1071+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
1072+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
1073+
1074+ // Must be shaped tensor types
1075+ if (!aType)
1076+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
1077+
1078+ if (!bType)
1079+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
1080+
1081+ auto aElementType = aType.getElementType ();
1082+ auto bElementType = bType.getElementType ();
1083+
1084+ auto aQuantizedEType =
1085+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1086+ auto bQuantizedEType =
1087+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1088+
1089+ if (aQuantizedEType || bQuantizedEType) {
1090+ if (!aQuantizedEType || !bQuantizedEType) {
1091+ emitOpError (
1092+ " expect operands to be both quantized or both not quantized, got " )
1093+ << aElementType << " and " << bElementType;
1094+ }
1095+ // both a and b have quantized element types
1096+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1097+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1098+ if (aQuantWidth != bQuantWidth) {
1099+ emitOpError (" expect quantized operands to have same widths, got " )
1100+ << aQuantWidth << " and " << bQuantWidth;
1101+ }
1102+
1103+ return success ();
1104+ }
1105+
1106+ // non-quantized element types
1107+ if (aElementType != bElementType) {
1108+ emitOpError (" expect same element type for inputs a and b, got " )
1109+ << aElementType << " and " << bElementType;
1110+ }
1111+
1112+ return success ();
1113+ }
1114+
10011115LogicalResult tosa::PadOp::inferReturnTypeComponents (
10021116 MLIRContext *context, ::std::optional<Location> location,
10031117 PadOp::Adaptor adaptor,
@@ -1046,6 +1160,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
10461160}
10471161
10481162LogicalResult tosa::PadOp::verify () {
1163+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1164+ /* outType = */ getOutput ().getType ())
1165+ .failed ()) {
1166+ return failure ();
1167+ }
1168+
1169+ if (auto padConst = getPadConst ()) {
1170+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1171+ /* outType = */ getOutput ().getType ())
1172+ .failed ()) {
1173+ return failure ();
1174+ }
1175+ }
1176+
10491177 RankedTensorType inputType = getInput1 ().getType ();
10501178 RankedTensorType outputType = getOutput ().getType ();
10511179 auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1119,21 +1247,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
11191247}
11201248
11211249LogicalResult tosa::SliceOp::verify () {
1250+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1251+ /* outType = */ getOutput ().getType ())
1252+ .failed ())
1253+ return failure ();
11221254 auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1 ().getType ());
11231255 if (!inputType)
11241256 return success ();
11251257
11261258 auto startShapeRank =
11271259 llvm::cast<tosa::shapeType>(getStart ().getType ()).getRank ();
11281260 if (inputType.getRank () != startShapeRank)
1129- return emitOpError (
1130- " length of start attribute is not equal rank of input shape" );
1261+ return emitOpError (" length of start is not equal to rank of input shape" );
11311262
11321263 auto sizeShapeRank =
11331264 llvm::cast<tosa::shapeType>(getSize ().getType ()).getRank ();
11341265 if (inputType.getRank () != sizeShapeRank)
1135- return emitOpError (
1136- " length of size attribute is not equal rank of input shape" );
1266+ return emitOpError (" length of size is not equal to rank of input shape" );
11371267
11381268 return success ();
11391269}
@@ -1338,6 +1468,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
13381468}
13391469
13401470LogicalResult tosa::TileOp::verify () {
1471+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1472+ /* outType = */ getOutput ().getType ())
1473+ .failed ()) {
1474+ return failure ();
1475+ }
13411476 ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
13421477 ShapedType outputType = llvm::cast<ShapedType>(getType ());
13431478
@@ -1419,6 +1554,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
14191554}
14201555
14211556llvm::LogicalResult tosa::ReshapeOp::verify () {
1557+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1558+ /* outType = */ getOutput ().getType ())
1559+ .failed ()) {
1560+ return failure ();
1561+ }
14221562 TensorType inputType = getInput1 ().getType ();
14231563 RankedTensorType outputType = getType ();
14241564
@@ -1606,6 +1746,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
16061746}
16071747
16081748LogicalResult tosa::TransposeOp::verify () {
1749+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1750+ /* outType = */ getOutput ().getType ())
1751+ .failed ()) {
1752+ return failure ();
1753+ }
16091754 TensorType inputType = getInput1 ().getType ();
16101755 TensorType outputType = getOutput ().getType ();
16111756 const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1706,6 +1851,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
17061851 return success ();
17071852}
17081853
1854+ LogicalResult tosa::GatherOp::verify () {
1855+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1856+ /* outType = */ getOutput ().getType ());
1857+ }
1858+
17091859LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
17101860 MLIRContext *context, ::std::optional<Location> location,
17111861 ResizeOp::Adaptor adaptor,
@@ -1867,6 +2017,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
18672017 return success ();
18682018}
18692019
2020+ LogicalResult tosa::ScatterOp::verify () {
2021+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
2022+ /* outType = */ getValuesOut ().getType ())
2023+ .failed () ||
2024+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
2025+ /* outType = */ getValuesOut ().getType ())
2026+ .failed ()) {
2027+ return failure ();
2028+ }
2029+ return success ();
2030+ }
2031+
18702032static LogicalResult ReduceInferReturnTypes (
18712033 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
18722034 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2322,6 +2484,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
23222484 inferredReturnShapes);
23232485}
23242486
2487+ LogicalResult MaxPool2dOp::verify () {
2488+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2489+ /* outType = */ getOutput ().getType ());
2490+ }
2491+
23252492LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
23262493 MLIRContext *context, ::std::optional<Location> location,
23272494 DepthwiseConv2DOp::Adaptor adaptor,
@@ -2622,6 +2789,10 @@ void IfOp::print(OpAsmPrinter &p) {
26222789}
26232790
26242791LogicalResult ReverseOp::verify () {
2792+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2793+ /* outType = */ getOutput ().getType ())
2794+ .failed ())
2795+ return failure ();
26252796 TensorType inputType = getInput1 ().getType ();
26262797 TensorType outputType = getOutput ().getType ();
26272798 int32_t reverseAxis = getAxis ();
@@ -2650,6 +2821,31 @@ LogicalResult ReverseOp::verify() {
26502821 return success ();
26512822}
26522823
2824+ LogicalResult tosa::SelectOp::verify () {
2825+ // verify input2 and input3 have same element type as output
2826+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2827+ /* outType = */ getOutput ().getType ())
2828+ .failed () ||
2829+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2830+ /* outType = */ getOutput ().getType ())
2831+ .failed ()) {
2832+ return failure ();
2833+ }
2834+ // verify input1 has element type of bool
2835+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2836+ if (!predicateType) {
2837+ emitOpError (" expect shaped tensor for input1, got " )
2838+ << getInput1 ().getType ();
2839+ }
2840+ auto predicateElementType = predicateType.getElementType ();
2841+ if (!predicateElementType.isInteger (1 )) {
2842+ emitOpError (" expect element type of bool for input1, got " )
2843+ << predicateElementType;
2844+ }
2845+
2846+ return success ();
2847+ }
2848+
26532849// parse and print of WhileOp refer to the implementation of SCF dialect.
26542850ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
26552851 SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments