@@ -1857,7 +1857,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
18571857}
18581858
18591859LogicalResult tosa::MulOp::verify () {
1860- auto resElemType = getElementTypeOrSelf (getOutput ());
1860+ const Value output = getOutput ();
1861+ auto resElemType = getElementTypeOrSelf (output);
18611862
18621863 // Verify if the element type among operands and result match tosa
18631864 // specification.
@@ -1897,59 +1898,39 @@ LogicalResult tosa::MulOp::verify() {
18971898 // Verify the op has same ranks for all main operands (excludes extra operands
18981899 // such as shift of mul op, so this is the only difference with the built-in
18991900 // `SameOperandsAndResultRank` trait) and results types, if known.
1900-
1901- // delegate function that returns true if type is a shaped type with known
1902- // rank
1903- auto hasRank = [](const Type type) {
1904- if (auto shaped_type = dyn_cast<ShapedType>(type))
1905- return shaped_type.hasRank ();
1906-
1907- return false ;
1908- };
1909-
1910- auto rankedOperandTypes =
1911- llvm::to_vector (llvm::make_filter_range (getOperandTypes (), hasRank));
1912-
1913- auto rankedResultTypes =
1914- llvm::make_filter_range (getOperation ()->getResultTypes (), hasRank);
1915-
1916- // If all operands and results are unranked, then no further verification.
1917- if (rankedOperandTypes.empty () && rankedResultTypes.empty ())
1901+ TypeRange operandTypes = getOperandTypes ();
1902+ ShapedType aType = cast<ShapedType>(operandTypes[0 ]);
1903+ ShapedType bType = cast<ShapedType>(operandTypes[1 ]);
1904+
1905+ const bool aHasRank = aType.hasRank ();
1906+ const bool bHasRank = bType.hasRank ();
1907+ if (aHasRank && bHasRank) {
1908+ const int64_t aRank = aType.getRank ();
1909+ const int64_t bRank = bType.getRank ();
1910+ if (aRank != bRank)
1911+ return emitOpError (" a and b operands don't have matching ranks, got " )
1912+ << aRank << " and " << bRank;
1913+
1914+ // check for broadcast compatible shapes
1915+ SmallVector<int64_t > resultShape;
1916+ if (!mlir::OpTrait::util::getBroadcastedShape (
1917+ aType.getShape (), bType.getShape (), resultShape))
1918+ return emitOpError (" a and b operands don't have broadcast-compatible "
1919+ " shapes, got " )
1920+ << aType << " and " << bType;
1921+ }
1922+
1923+ ShapedType resultType = cast<ShapedType>(output.getType ());
1924+ if (!resultType.hasRank ())
19181925 return success ();
19191926
1920- // delegate function that returns rank of shaped type with known rank
1921- auto getRank = [](const Type type) {
1922- return cast<ShapedType>(type).getRank ();
1923- };
1924-
1925- auto rank = !rankedOperandTypes.empty () ? getRank (*rankedOperandTypes.begin ())
1926- : getRank (*rankedResultTypes.begin ());
1927-
1928- for (size_t i = 0 ; i < 2 ; ++i) {
1929- if (rank != getRank (rankedOperandTypes[i])) {
1930- return emitOpError (" operands don't have matching ranks" );
1931- }
1932- }
1933-
1934- for (const auto type : rankedResultTypes) {
1935- if (rank != getRank (type)) {
1936- return emitOpError (" result type has different rank than operands" );
1937- }
1938- }
1939-
1940- // check for broadcast compatible shapes in first two operands (ignoring
1941- // shift)
1942-
1943- // delegate function that returns shape of shaped type
1944- auto getShape = [](const Type type) {
1945- return mlir::cast<ShapedType>(type).getShape ();
1946- };
1947- SmallVector<int64_t > resultShape;
1948- if (!mlir::OpTrait::util::getBroadcastedShape (getShape (rankedOperandTypes[0 ]),
1949- getShape (rankedOperandTypes[1 ]),
1950- resultShape)) {
1951- return emitOpError (" operands don't have broadcast-compatible shapes" );
1952- }
1927+ const int64_t resultRank = resultType.getRank ();
1928+ if (aHasRank && resultRank != aType.getRank ())
1929+ return emitOpError (" result type has different rank than a, got " )
1930+ << resultRank << " vs " << aType.getRank ();
1931+ if (bHasRank && resultRank != bType.getRank ())
1932+ return emitOpError (" result type has different rank than b, got " )
1933+ << resultRank << " vs " << bType.getRank ();
19531934
19541935 return success ();
19551936}
0 commit comments