@@ -871,6 +871,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
871871 return success ();
872872}
873873
874+ LogicalResult tosa::ConcatOp::verify () {
875+ // check that each input has same element type as output
876+ auto outType = getOutput ().getType ();
877+ const Operation::operand_range inputList = getInput1 ();
878+
879+ // Check there is at least one input
880+ if (inputList.empty ())
881+ return emitOpError (" expect at least one input" );
882+
883+ if (!llvm::all_of (inputList, [&](auto input) {
884+ return succeeded (verifySameElementTypes (
885+ *this , /* inType = */ input.getType (), outType));
886+ })) {
887+ return failure ();
888+ }
889+
890+ const int32_t axis = getAxis ();
891+ ShapeAdaptor firstRankedInputShape = nullptr ;
892+ for (auto input : inputList) {
893+ const Type inputType = input.getType ();
894+ ShapeAdaptor currShape (inputType);
895+ if (currShape.hasRank ()) {
896+ firstRankedInputShape = currShape;
897+ // Check axis is in expected range
898+ if (axis < 0 || axis >= firstRankedInputShape.getRank ())
899+ return emitOpError (" expect axis to be within range 0 < axis < "
900+ " rank(input1[firstRankedTensorIdx]), got " )
901+ << axis;
902+ break ;
903+ }
904+ }
905+
906+ const auto allOperandsHasRank = [](const Value input) {
907+ return ShapeAdaptor (input.getType ()).hasRank ();
908+ };
909+ if (llvm::all_of (inputList, allOperandsHasRank)) {
910+ const int64_t firstInputRank = firstRankedInputShape.getRank ();
911+
912+ for (const auto [index, input] : llvm::enumerate (inputList.drop_front ())) {
913+ const ShapeAdaptor inputShape (input.getType ());
914+ const int64_t inputRank = inputShape.getRank ();
915+ const size_t operandNum = index + 1 ;
916+
917+ // Check that each operand has the same rank
918+ if (inputRank != firstInputRank)
919+ return emitOpError (
920+ " expect all operands to have the same rank, but got " )
921+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
922+ << operandNum;
923+
924+ // Check non-axis dims match
925+ for (int i = 0 ; i < inputRank; i++) {
926+ const int64_t inputDim = inputShape.getDimSize (i);
927+ const int64_t firstInputDim = firstRankedInputShape.getDimSize (i);
928+ if (i == axis || firstRankedInputShape.isDynamicDim (i) ||
929+ inputShape.isDynamicDim (i))
930+ continue ;
931+ if (inputDim != firstInputDim)
932+ return emitOpError (" expect all operand shapes to have the same sizes "
933+ " on non-axis dimensions, but got " )
934+ << inputDim << " vs " << firstInputDim << " at index " << i
935+ << " on operands 0 and " << operandNum;
936+ }
937+ }
938+ }
939+
940+ return success ();
941+ }
942+
874943LogicalResult tosa::EqualOp::inferReturnTypeComponents (
875944 MLIRContext *context, ::std::optional<Location> location,
876945 ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +989,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
920989 return success ();
921990}
922991
992+ LogicalResult MatMulOp::verify () {
993+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
994+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
995+
996+ // Must be shaped tensor types
997+ if (!aType) {
998+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
999+ return failure ();
1000+ }
1001+ if (!bType) {
1002+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
1003+ return failure ();
1004+ }
1005+
1006+ auto aElementType = aType.getElementType ();
1007+ auto bElementType = bType.getElementType ();
1008+
1009+ auto aQuantizedEType =
1010+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1011+ auto bQuantizedEType =
1012+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1013+
1014+ if (aQuantizedEType || bQuantizedEType) {
1015+ if (!aQuantizedEType || !bQuantizedEType) {
1016+ emitOpError (
1017+ " expect operands to be both quantized or both not quantized, got " )
1018+ << aElementType << " and " << bElementType;
1019+ return failure ();
1020+ }
1021+ // both a and b have quantized element types
1022+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1023+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1024+ if (aQuantWidth != bQuantWidth) {
1025+ emitOpError (" expect quantized operands to have same widths, got " )
1026+ << aQuantWidth << " and " << bQuantWidth;
1027+ return failure ();
1028+ }
1029+
1030+ return success ();
1031+ }
1032+
1033+ // non-quantized element types
1034+ if (aElementType != bElementType) {
1035+ emitOpError (" expect same element type for inputs a and b, got " )
1036+ << aElementType << " and " << bElementType;
1037+ return failure ();
1038+ }
1039+
1040+ return success ();
1041+ }
1042+
9231043LogicalResult tosa::PadOp::inferReturnTypeComponents (
9241044 MLIRContext *context, ::std::optional<Location> location,
9251045 PadOp::Adaptor adaptor,
@@ -968,6 +1088,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
9681088}
9691089
9701090LogicalResult tosa::PadOp::verify () {
1091+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1092+ /* outType = */ getOutput ().getType ())
1093+ .failed ()) {
1094+ return failure ();
1095+ }
1096+
1097+ if (auto padConst = getPadConst ()) {
1098+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1099+ /* outType = */ getOutput ().getType ())
1100+ .failed ()) {
1101+ return failure ();
1102+ }
1103+ }
1104+
9711105 RankedTensorType inputType = getInput1 ().getType ();
9721106 RankedTensorType outputType = getOutput ().getType ();
9731107 auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1041,21 +1175,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10411175}
10421176
10431177LogicalResult tosa::SliceOp::verify () {
1178+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1179+ /* outType = */ getOutput ().getType ())
1180+ .failed ())
1181+ return failure ();
10441182 auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1 ().getType ());
10451183 if (!inputType)
10461184 return success ();
10471185
10481186 auto startShapeRank =
10491187 llvm::cast<tosa::shapeType>(getStart ().getType ()).getRank ();
10501188 if (inputType.getRank () != startShapeRank)
1051- return emitOpError (
1052- " length of start attribute is not equal rank of input shape" );
1189+ return emitOpError (" length of start is not equal to rank of input shape" );
10531190
10541191 auto sizeShapeRank =
10551192 llvm::cast<tosa::shapeType>(getSize ().getType ()).getRank ();
10561193 if (inputType.getRank () != sizeShapeRank)
1057- return emitOpError (
1058- " length of size attribute is not equal rank of input shape" );
1194+ return emitOpError (" length of size is not equal to rank of input shape" );
10591195
10601196 return success ();
10611197}
@@ -1260,6 +1396,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
12601396}
12611397
12621398LogicalResult tosa::TileOp::verify () {
1399+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1400+ /* outType = */ getOutput ().getType ())
1401+ .failed ()) {
1402+ return failure ();
1403+ }
12631404 ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
12641405 ShapedType outputType = llvm::cast<ShapedType>(getType ());
12651406
@@ -1341,6 +1482,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13411482}
13421483
13431484llvm::LogicalResult tosa::ReshapeOp::verify () {
1485+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1486+ /* outType = */ getOutput ().getType ())
1487+ .failed ()) {
1488+ return failure ();
1489+ }
13441490 TensorType inputType = getInput1 ().getType ();
13451491 RankedTensorType outputType = getType ();
13461492
@@ -1528,6 +1674,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
15281674}
15291675
15301676LogicalResult tosa::TransposeOp::verify () {
1677+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1678+ /* outType = */ getOutput ().getType ())
1679+ .failed ()) {
1680+ return failure ();
1681+ }
15311682 TensorType inputType = getInput1 ().getType ();
15321683 TensorType outputType = getOutput ().getType ();
15331684 const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1628,6 +1779,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
16281779 return success ();
16291780}
16301781
1782+ LogicalResult tosa::GatherOp::verify () {
1783+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1784+ /* outType = */ getOutput ().getType ());
1785+ }
1786+
16311787LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
16321788 MLIRContext *context, ::std::optional<Location> location,
16331789 ResizeOp::Adaptor adaptor,
@@ -1789,6 +1945,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17891945 return success ();
17901946}
17911947
1948+ LogicalResult tosa::ScatterOp::verify () {
1949+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
1950+ /* outType = */ getValuesOut ().getType ())
1951+ .failed () ||
1952+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
1953+ /* outType = */ getValuesOut ().getType ())
1954+ .failed ()) {
1955+ return failure ();
1956+ }
1957+ return success ();
1958+ }
1959+
17921960static LogicalResult ReduceInferReturnTypes (
17931961 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
17941962 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2412,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
22442412 inferredReturnShapes);
22452413}
22462414
2415+ LogicalResult MaxPool2dOp::verify () {
2416+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2417+ /* outType = */ getOutput ().getType ());
2418+ }
2419+
22472420LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
22482421 MLIRContext *context, ::std::optional<Location> location,
22492422 DepthwiseConv2DOp::Adaptor adaptor,
@@ -2544,6 +2717,10 @@ void IfOp::print(OpAsmPrinter &p) {
25442717}
25452718
25462719LogicalResult ReverseOp::verify () {
2720+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2721+ /* outType = */ getOutput ().getType ())
2722+ .failed ())
2723+ return failure ();
25472724 TensorType inputType = getInput1 ().getType ();
25482725 TensorType outputType = getOutput ().getType ();
25492726 int32_t reverseAxis = getAxis ();
@@ -2572,6 +2749,33 @@ LogicalResult ReverseOp::verify() {
25722749 return success ();
25732750}
25742751
2752+ LogicalResult tosa::SelectOp::verify () {
2753+ // verify input2 and input3 have same element type as output
2754+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2755+ /* outType = */ getOutput ().getType ())
2756+ .failed () ||
2757+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2758+ /* outType = */ getOutput ().getType ())
2759+ .failed ()) {
2760+ return failure ();
2761+ }
2762+ // verify input1 has element type of bool
2763+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2764+ if (!predicateType) {
2765+ emitOpError (" expect shaped tensor for input1, got " )
2766+ << getInput1 ().getType ();
2767+ return failure ();
2768+ }
2769+ auto predicateElementType = predicateType.getElementType ();
2770+ if (!predicateElementType.isInteger (1 )) {
2771+ emitOpError (" expect element type of bool for input1, got " )
2772+ << predicateElementType;
2773+ return failure ();
2774+ }
2775+
2776+ return success ();
2777+ }
2778+
25752779// parse and print of WhileOp refer to the implementation of SCF dialect.
25762780ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
25772781 SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments