@@ -850,6 +850,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
850850 return success ();
851851}
852852
853+ LogicalResult tosa::ConcatOp::verify () {
854+ // check that each input has same element type as output
855+ auto outType = getOutput ().getType ();
856+ const Operation::operand_range inputList = getInput1 ();
857+
858+ if (!llvm::all_of (inputList, [&](auto input) {
859+ return succeeded (verifySameElementTypes (
860+ *this , /* inType = */ input.getType (), outType));
861+ })) {
862+ return failure ();
863+ }
864+
865+ // Check there is at least one input
866+ if (inputList.empty ())
867+ return emitOpError (" expect at least one input" );
868+
869+ const Type firstInputType = inputList.front ().getType ();
870+ const ShapeAdaptor firstInputShape (firstInputType);
871+ const int32_t axis = getAxis ();
872+
873+ if (firstInputShape.hasRank ()) {
874+ // Check axis is in expected range
875+ if (axis < 0 || axis >= firstInputShape.getRank ())
876+ return emitOpError (" expect axis to be within range 0 < axis < "
877+ " rank(input1[0]), got " )
878+ << axis;
879+ }
880+
881+ const auto allOperandsHasRank = [](const Value input) {
882+ return ShapeAdaptor (input.getType ()).hasRank ();
883+ };
884+ if (llvm::all_of (inputList, allOperandsHasRank)) {
885+ const int64_t firstInputRank = firstInputShape.getRank ();
886+
887+ for (const auto [index, input] : llvm::enumerate (inputList.drop_front ())) {
888+ const ShapeAdaptor inputShape (input.getType ());
889+ const int64_t inputRank = inputShape.getRank ();
890+ const size_t operandNum = index + 1 ;
891+
892+ // Check that each operand has the same rank
893+ if (inputRank != firstInputRank)
894+ return emitOpError (
895+ " expect all operands to have the same rank, but got " )
896+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
897+ << operandNum;
898+
899+ // Check non-axis dims match
900+ for (int i = 0 ; i < inputRank; i++) {
901+ const int64_t inputDim = inputShape.getDimSize (i);
902+ const int64_t firstInputDim = firstInputShape.getDimSize (i);
903+ if (i == axis || firstInputShape.isDynamicDim (i) ||
904+ inputShape.isDynamicDim (i))
905+ continue ;
906+ if (inputDim != firstInputDim)
907+ return emitOpError (" expect all operand shapes to have the same sizes "
908+ " on non-axis dimensions, but got " )
909+ << inputDim << " vs " << firstInputDim << " at index " << i
910+ << " on operands 0 and " << operandNum;
911+ }
912+ }
913+ }
914+
915+ return success ();
916+ }
917+
853918LogicalResult tosa::EqualOp::inferReturnTypeComponents (
854919 MLIRContext *context, ::std::optional<Location> location,
855920 ValueShapeRange operands, DictionaryAttr attributes,
@@ -899,6 +964,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
899964 return success ();
900965}
901966
967+ LogicalResult MatMulOp::verify () {
968+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
969+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
970+
971+ // Must be shaped tensor types
972+ if (!aType) {
973+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
974+ return failure ();
975+ }
976+ if (!bType) {
977+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
978+ return failure ();
979+ }
980+
981+ auto aElementType = aType.getElementType ();
982+ auto bElementType = bType.getElementType ();
983+
984+ auto aQuantizedEType =
985+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
986+ auto bQuantizedEType =
987+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
988+
989+ if (aQuantizedEType || bQuantizedEType) {
990+ if (!aQuantizedEType || !bQuantizedEType) {
991+ emitOpError (
992+ " expect operands to be both quantized or both not quantized, got " )
993+ << aElementType << " and " << bElementType;
994+ return failure ();
995+ }
996+ // both a and b have quantized element types
997+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
998+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
999+ if (aQuantWidth != bQuantWidth) {
1000+ emitOpError (" expect quantized operands to have same widths, got " )
1001+ << aQuantWidth << " and " << bQuantWidth;
1002+ return failure ();
1003+ }
1004+
1005+ return success ();
1006+ }
1007+
1008+ // non-quantized element types
1009+ if (aElementType != bElementType) {
1010+ emitOpError (" expect same element type for inputs a and b, got " )
1011+ << aElementType << " and " << bElementType;
1012+ return failure ();
1013+ }
1014+
1015+ return success ();
1016+ }
1017+
9021018LogicalResult tosa::PadOp::inferReturnTypeComponents (
9031019 MLIRContext *context, ::std::optional<Location> location,
9041020 PadOp::Adaptor adaptor,
@@ -947,6 +1063,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
9471063}
9481064
9491065LogicalResult tosa::PadOp::verify () {
1066+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1067+ /* outType = */ getOutput ().getType ())
1068+ .failed ()) {
1069+ return failure ();
1070+ }
1071+ if (auto padConst = getPadConst ()) {
1072+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1073+ /* outType = */ getOutput ().getType ())
1074+ .failed ()) {
1075+ return failure ();
1076+ }
1077+ }
9501078 RankedTensorType inputType = getInput1 ().getType ();
9511079 RankedTensorType outputType = getOutput ().getType ();
9521080 auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1020,21 +1148,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10201148}
10211149
10221150LogicalResult tosa::SliceOp::verify () {
1151+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1152+ /* outType = */ getOutput ().getType ())
1153+ .failed ())
1154+ return failure ();
10231155 auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1 ().getType ());
10241156 if (!inputType)
10251157 return success ();
10261158
10271159 auto startShapeRank =
10281160 llvm::cast<tosa::shapeType>(getStart ().getType ()).getRank ();
10291161 if (inputType.getRank () != startShapeRank)
1030- return emitOpError (
1031- " length of start attribute is not equal rank of input shape" );
1162+ return emitOpError (" length of start is not equal to rank of input shape" );
10321163
10331164 auto sizeShapeRank =
10341165 llvm::cast<tosa::shapeType>(getSize ().getType ()).getRank ();
10351166 if (inputType.getRank () != sizeShapeRank)
1036- return emitOpError (
1037- " length of size attribute is not equal rank of input shape" );
1167+ return emitOpError (" length of size is not equal to rank of input shape" );
10381168
10391169 return success ();
10401170}
@@ -1239,6 +1369,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
12391369}
12401370
12411371LogicalResult tosa::TileOp::verify () {
1372+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1373+ /* outType = */ getOutput ().getType ())
1374+ .failed ()) {
1375+ return failure ();
1376+ }
12421377 ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
12431378 ShapedType outputType = llvm::cast<ShapedType>(getType ());
12441379
@@ -1320,6 +1455,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13201455}
13211456
13221457llvm::LogicalResult tosa::ReshapeOp::verify () {
1458+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1459+ /* outType = */ getOutput ().getType ())
1460+ .failed ()) {
1461+ return failure ();
1462+ }
13231463 TensorType inputType = getInput1 ().getType ();
13241464 RankedTensorType outputType = getType ();
13251465
@@ -1434,6 +1574,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14341574}
14351575
14361576LogicalResult tosa::TransposeOp::verify () {
1577+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1578+ /* outType = */ getOutput ().getType ())
1579+ .failed ()) {
1580+ return failure ();
1581+ }
14371582 TensorType inputType = getInput1 ().getType ();
14381583 TensorType outputType = getOutput ().getType ();
14391584 const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1534,6 +1679,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
15341679 return success ();
15351680}
15361681
1682+ LogicalResult tosa::GatherOp::verify () {
1683+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1684+ /* outType = */ getOutput ().getType ());
1685+ }
1686+
15371687LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
15381688 MLIRContext *context, ::std::optional<Location> location,
15391689 ResizeOp::Adaptor adaptor,
@@ -1702,6 +1852,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17021852 return success ();
17031853}
17041854
1855+ LogicalResult tosa::ScatterOp::verify () {
1856+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
1857+ /* outType = */ getValuesOut ().getType ())
1858+ .failed () ||
1859+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
1860+ /* outType = */ getValuesOut ().getType ())
1861+ .failed ()) {
1862+ return failure ();
1863+ }
1864+ return success ();
1865+ }
1866+
17051867static LogicalResult ReduceInferReturnTypes (
17061868 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
17071869 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2066,6 +2228,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
20662228 inferredReturnShapes);
20672229}
20682230
2231+ LogicalResult MaxPool2dOp::verify () {
2232+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2233+ /* outType = */ getOutput ().getType ());
2234+ }
2235+
20692236LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
20702237 MLIRContext *context, ::std::optional<Location> location,
20712238 DepthwiseConv2DOp::Adaptor adaptor,
@@ -2368,6 +2535,10 @@ void IfOp::print(OpAsmPrinter &p) {
23682535}
23692536
23702537LogicalResult ReverseOp::verify () {
2538+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2539+ /* outType = */ getOutput ().getType ())
2540+ .failed ())
2541+ return failure ();
23712542 TensorType inputType = getInput1 ().getType ();
23722543 TensorType outputType = getOutput ().getType ();
23732544 int32_t reverseAxis = getAxis ();
@@ -2396,6 +2567,33 @@ LogicalResult ReverseOp::verify() {
23962567 return success ();
23972568}
23982569
2570+ LogicalResult tosa::SelectOp::verify () {
2571+ // verify input2 and input3 have same element type as output
2572+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2573+ /* outType = */ getOutput ().getType ())
2574+ .failed () ||
2575+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2576+ /* outType = */ getOutput ().getType ())
2577+ .failed ()) {
2578+ return failure ();
2579+ }
2580+ // verify input1 has element type of bool
2581+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2582+ if (!predicateType) {
2583+ emitOpError (" expect shaped tensor for input1, got " )
2584+ << getInput1 ().getType ();
2585+ return failure ();
2586+ }
2587+ auto predicateElementType = predicateType.getElementType ();
2588+ if (!predicateElementType.isInteger (1 )) {
2589+ emitOpError (" expect element type of bool for input1, got " )
2590+ << predicateElementType;
2591+ return failure ();
2592+ }
2593+
2594+ return success ();
2595+ }
2596+
23992597// parse and print of WhileOp refer to the implementation of SCF dialect.
24002598ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
24012599 SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments