@@ -871,6 +871,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
871871 return success ();
872872}
873873
874+ LogicalResult tosa::ConcatOp::verify () {
875+ // check that each input has same element type as output
876+ auto outType = getOutput ().getType ();
877+ const Operation::operand_range inputList = getInput1 ();
878+
879+ // Check there is at least one input
880+ if (inputList.empty ())
881+ return emitOpError (" expect at least one input" );
882+
883+ if (!llvm::all_of (inputList, [&](auto input) {
884+ return succeeded (verifySameElementTypes (
885+ *this , /* inType = */ input.getType (), outType));
886+ })) {
887+ return failure ();
888+ }
889+
890+ const Type firstInputType = inputList.front ().getType ();
891+ const ShapeAdaptor firstInputShape (firstInputType);
892+ const int32_t axis = getAxis ();
893+
894+ if (firstInputShape.hasRank ()) {
895+ // Check axis is in expected range
896+ if (axis < 0 || axis >= firstInputShape.getRank ())
897+ return emitOpError (" expect axis to be within range 0 < axis < "
898+ " rank(input1[0]), got " )
899+ << axis;
900+ }
901+
902+ const auto allOperandsHasRank = [](const Value input) {
903+ return ShapeAdaptor (input.getType ()).hasRank ();
904+ };
905+ if (llvm::all_of (inputList, allOperandsHasRank)) {
906+ const int64_t firstInputRank = firstInputShape.getRank ();
907+
908+ for (const auto [index, input] : llvm::enumerate (inputList.drop_front ())) {
909+ const ShapeAdaptor inputShape (input.getType ());
910+ const int64_t inputRank = inputShape.getRank ();
911+ const size_t operandNum = index + 1 ;
912+
913+ // Check that each operand has the same rank
914+ if (inputRank != firstInputRank)
915+ return emitOpError (
916+ " expect all operands to have the same rank, but got " )
917+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
918+ << operandNum;
919+
920+ // Check non-axis dims match
921+ for (int i = 0 ; i < inputRank; i++) {
922+ const int64_t inputDim = inputShape.getDimSize (i);
923+ const int64_t firstInputDim = firstInputShape.getDimSize (i);
924+ if (i == axis || firstInputShape.isDynamicDim (i) ||
925+ inputShape.isDynamicDim (i))
926+ continue ;
927+ if (inputDim != firstInputDim)
928+ return emitOpError (" expect all operand shapes to have the same sizes "
929+ " on non-axis dimensions, but got " )
930+ << inputDim << " vs " << firstInputDim << " at index " << i
931+ << " on operands 0 and " << operandNum;
932+ }
933+ }
934+ }
935+
936+ return success ();
937+ }
938+
874939LogicalResult tosa::EqualOp::inferReturnTypeComponents (
875940 MLIRContext *context, ::std::optional<Location> location,
876941 ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +985,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
920985 return success ();
921986}
922987
988+ LogicalResult MatMulOp::verify () {
989+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
990+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
991+
992+ // Must be shaped tensor types
993+ if (!aType) {
994+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
995+ return failure ();
996+ }
997+ if (!bType) {
998+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
999+ return failure ();
1000+ }
1001+
1002+ auto aElementType = aType.getElementType ();
1003+ auto bElementType = bType.getElementType ();
1004+
1005+ auto aQuantizedEType =
1006+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1007+ auto bQuantizedEType =
1008+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1009+
1010+ if (aQuantizedEType || bQuantizedEType) {
1011+ if (!aQuantizedEType || !bQuantizedEType) {
1012+ emitOpError (
1013+ " expect operands to be both quantized or both not quantized, got " )
1014+ << aElementType << " and " << bElementType;
1015+ return failure ();
1016+ }
1017+ // both a and b have quantized element types
1018+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1019+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1020+ if (aQuantWidth != bQuantWidth) {
1021+ emitOpError (" expect quantized operands to have same widths, got " )
1022+ << aQuantWidth << " and " << bQuantWidth;
1023+ return failure ();
1024+ }
1025+
1026+ return success ();
1027+ }
1028+
1029+ // non-quantized element types
1030+ if (aElementType != bElementType) {
1031+ emitOpError (" expect same element type for inputs a and b, got " )
1032+ << aElementType << " and " << bElementType;
1033+ return failure ();
1034+ }
1035+
1036+ return success ();
1037+ }
1038+
9231039LogicalResult tosa::PadOp::inferReturnTypeComponents (
9241040 MLIRContext *context, ::std::optional<Location> location,
9251041 PadOp::Adaptor adaptor,
@@ -968,6 +1084,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
9681084}
9691085
9701086LogicalResult tosa::PadOp::verify () {
1087+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1088+ /* outType = */ getOutput ().getType ())
1089+ .failed ()) {
1090+ return failure ();
1091+ }
1092+
1093+ if (auto padConst = getPadConst ()) {
1094+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1095+ /* outType = */ getOutput ().getType ())
1096+ .failed ()) {
1097+ return failure ();
1098+ }
1099+ }
1100+
9711101 RankedTensorType inputType = getInput1 ().getType ();
9721102 RankedTensorType outputType = getOutput ().getType ();
9731103 auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1041,21 +1171,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10411171}
10421172
10431173LogicalResult tosa::SliceOp::verify () {
1174+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1175+ /* outType = */ getOutput ().getType ())
1176+ .failed ())
1177+ return failure ();
10441178 auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1 ().getType ());
10451179 if (!inputType)
10461180 return success ();
10471181
10481182 auto startShapeRank =
10491183 llvm::cast<tosa::shapeType>(getStart ().getType ()).getRank ();
10501184 if (inputType.getRank () != startShapeRank)
1051- return emitOpError (
1052- " length of start attribute is not equal rank of input shape" );
1185+ return emitOpError (" length of start is not equal to rank of input shape" );
10531186
10541187 auto sizeShapeRank =
10551188 llvm::cast<tosa::shapeType>(getSize ().getType ()).getRank ();
10561189 if (inputType.getRank () != sizeShapeRank)
1057- return emitOpError (
1058- " length of size attribute is not equal rank of input shape" );
1190+ return emitOpError (" length of size is not equal to rank of input shape" );
10591191
10601192 return success ();
10611193}
@@ -1260,6 +1392,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
12601392}
12611393
12621394LogicalResult tosa::TileOp::verify () {
1395+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1396+ /* outType = */ getOutput ().getType ())
1397+ .failed ()) {
1398+ return failure ();
1399+ }
12631400 ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
12641401 ShapedType outputType = llvm::cast<ShapedType>(getType ());
12651402
@@ -1341,6 +1478,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13411478}
13421479
13431480llvm::LogicalResult tosa::ReshapeOp::verify () {
1481+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1482+ /* outType = */ getOutput ().getType ())
1483+ .failed ()) {
1484+ return failure ();
1485+ }
13441486 TensorType inputType = getInput1 ().getType ();
13451487 RankedTensorType outputType = getType ();
13461488
@@ -1528,6 +1670,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
15281670}
15291671
15301672LogicalResult tosa::TransposeOp::verify () {
1673+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1674+ /* outType = */ getOutput ().getType ())
1675+ .failed ()) {
1676+ return failure ();
1677+ }
15311678 TensorType inputType = getInput1 ().getType ();
15321679 TensorType outputType = getOutput ().getType ();
15331680 const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1628,6 +1775,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
16281775 return success ();
16291776}
16301777
1778+ LogicalResult tosa::GatherOp::verify () {
1779+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1780+ /* outType = */ getOutput ().getType ());
1781+ }
1782+
16311783LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
16321784 MLIRContext *context, ::std::optional<Location> location,
16331785 ResizeOp::Adaptor adaptor,
@@ -1789,6 +1941,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17891941 return success ();
17901942}
17911943
1944+ LogicalResult tosa::ScatterOp::verify () {
1945+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
1946+ /* outType = */ getValuesOut ().getType ())
1947+ .failed () ||
1948+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
1949+ /* outType = */ getValuesOut ().getType ())
1950+ .failed ()) {
1951+ return failure ();
1952+ }
1953+ return success ();
1954+ }
1955+
17921956static LogicalResult ReduceInferReturnTypes (
17931957 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
17941958 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2408,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
22442408 inferredReturnShapes);
22452409}
22462410
2411+ LogicalResult MaxPool2dOp::verify () {
2412+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2413+ /* outType = */ getOutput ().getType ());
2414+ }
2415+
22472416LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
22482417 MLIRContext *context, ::std::optional<Location> location,
22492418 DepthwiseConv2DOp::Adaptor adaptor,
@@ -2546,6 +2715,10 @@ void IfOp::print(OpAsmPrinter &p) {
25462715}
25472716
25482717LogicalResult ReverseOp::verify () {
2718+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2719+ /* outType = */ getOutput ().getType ())
2720+ .failed ())
2721+ return failure ();
25492722 TensorType inputType = getInput1 ().getType ();
25502723 TensorType outputType = getOutput ().getType ();
25512724 int32_t reverseAxis = getAxis ();
@@ -2574,6 +2747,33 @@ LogicalResult ReverseOp::verify() {
25742747 return success ();
25752748}
25762749
2750+ LogicalResult tosa::SelectOp::verify () {
2751+ // verify input2 and input3 have same element type as output
2752+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2753+ /* outType = */ getOutput ().getType ())
2754+ .failed () ||
2755+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2756+ /* outType = */ getOutput ().getType ())
2757+ .failed ()) {
2758+ return failure ();
2759+ }
2760+ // verify input1 has element type of bool
2761+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2762+ if (!predicateType) {
2763+ emitOpError (" expect shaped tensor for input1, got " )
2764+ << getInput1 ().getType ();
2765+ return failure ();
2766+ }
2767+ auto predicateElementType = predicateType.getElementType ();
2768+ if (!predicateElementType.isInteger (1 )) {
2769+ emitOpError (" expect element type of bool for input1, got " )
2770+ << predicateElementType;
2771+ return failure ();
2772+ }
2773+
2774+ return success ();
2775+ }
2776+
25772777// parse and print of WhileOp refer to the implementation of SCF dialect.
25782778ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
25792779 SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments