@@ -1813,6 +1813,84 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
18131813 unsigned maxNumElementsToExtract = 0 ;
18141814};
18151815
1816+ // / Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
1817+ // / B)`.
1818+ // / Example:
1819+ // / %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1820+ // / %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
1821+ // / vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
1822+ // / vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1823+ // /
1824+ // / Becomes :
1825+ // /
1826+ // / %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1827+ // /
1828+ // / Supports only 1D-to-2D broadcasts. The following cases are not supported.
1829+ // / %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1830+ // / %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1831+ // / %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1832+ template <typename MulOpType>
1833+ struct FoldArithToVectorOuterProduct : public OpRewritePattern <MulOpType> {
1834+ using OpRewritePattern<MulOpType>::OpRewritePattern;
1835+ // Returns whether a vector.broadcast matches requirements for an outerproduct
1836+ // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
1837+ bool isValidBroadcastSource (vector::BroadcastOp broadcastOp) const {
1838+ // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
1839+ // shape_casts/broadcasts which does not belong in this pattern.
1840+ if (!broadcastOp.computeBroadcastedUnitDims ().empty ())
1841+ return false ;
1842+ // Avoid broadcast like f32 or vector<f32> -> ResType
1843+ auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
1844+ return srcType && srcType.getRank () != 2 ;
1845+ }
1846+
1847+ LogicalResult matchAndRewrite (MulOpType mulOp,
1848+ PatternRewriter &rewriter) const override {
1849+ auto resType = llvm::cast<VectorType>(mulOp.getResult ().getType ());
1850+ if (!resType)
1851+ return failure ();
1852+ if (resType.getRank () != 2 )
1853+ return failure ();
1854+ // / If operandA can be written as tr(broadcast(A)) and operandB as
1855+ // / broadcast(B) where broadcasts are 1D-to-2D, create and return
1856+ // / vector.outerproduct(A, B). Returns failure() otherwise.
1857+ auto matchOuterProduct =
1858+ [&](Value operandA,
1859+ Value operandB) -> FailureOr<vector::OuterProductOp> {
1860+ auto transposedLhs = operandA.getDefiningOp <vector::TransposeOp>();
1861+ if (!transposedLhs)
1862+ return failure ();
1863+ // Fail unless this is a true 2-D matrix transpose.
1864+ ArrayRef<int64_t > permutation = transposedLhs.getPermutation ();
1865+ if (permutation.size () != 2 || permutation[0 ] != 1 || permutation[1 ] != 0 )
1866+ return failure ();
1867+
1868+ auto broadcastedLhs =
1869+ transposedLhs.getVector ().getDefiningOp <vector::BroadcastOp>();
1870+ if (!broadcastedLhs || !isValidBroadcastSource (broadcastedLhs))
1871+ return failure ();
1872+
1873+ auto broadcastedRhs = operandB.getDefiningOp <vector::BroadcastOp>();
1874+ if (!broadcastedRhs || !isValidBroadcastSource (broadcastedRhs))
1875+ return failure ();
1876+
1877+ return rewriter.create <vector::OuterProductOp>(
1878+ mulOp->getLoc (), resType, broadcastedLhs.getSource (),
1879+ broadcastedRhs.getSource (), Value (), vector::CombiningKind::ADD);
1880+ };
1881+
1882+ Value lhs = mulOp->getOperand (0 ), rhs = mulOp->getOperand (1 );
1883+ auto maybeOuterP = matchOuterProduct (lhs, rhs);
1884+ // Handle commutativity, the transposed op is the outerproduct LHS.
1885+ if (failed (maybeOuterP))
1886+ maybeOuterP = matchOuterProduct (rhs, lhs);
1887+ if (failed (maybeOuterP))
1888+ return failure ();
1889+ rewriter.replaceOp (mulOp, maybeOuterP->getResult ());
1890+ return success ();
1891+ }
1892+ };
1893+
18161894} // namespace
18171895
18181896void mlir::vector::populateFoldArithExtensionPatterns (
@@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
19001978 maxNumElementsToExtract, benefit);
19011979}
19021980
1981+ void mlir::vector::populateElementwiseToVectorOpsPatterns (
1982+ RewritePatternSet &patterns) {
1983+ patterns.add <FoldArithToVectorOuterProduct<arith::MulFOp>,
1984+ FoldArithToVectorOuterProduct<arith::MulIOp>>(
1985+ patterns.getContext ());
1986+ }
1987+
19031988// ===----------------------------------------------------------------------===//
19041989// TableGen'd enum attribute definitions
19051990// ===----------------------------------------------------------------------===//
0 commit comments