 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1709,10 +1710,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;

   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());

   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());

@@ -2118,6 +2122,92 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }

+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///                             destination
+/// The operands' shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within the
+/// contraction op.
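+///
+/// For example (an illustrative sketch; the exact types, maps, and
+/// attributes depend on the op being vectorized), a linalg.matmul becomes
+/// roughly:
+///   %lhs = vector.transfer_read %A[...] : tensor<8x4xf32>, vector<8x4xf32>
+///   %rhs = vector.transfer_read %B[...] : tensor<4x16xf32>, vector<4x16xf32>
+///   %acc = vector.transfer_read %C[...] : tensor<8x16xf32>, vector<8x16xf32>
+///   %res = vector.contract {...} %lhs, %rhs, %acc
+///       : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
+///   vector.transfer_write %res, %C[...] : vector<8x16xf32>, tensor<8x16xf32>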
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  // For simplicity, contraction vectorization is limited to linalg named ops.
+  // Generic ops are ignored, as not every arbitrary contraction body can be
+  // expressed by a vector.contract.
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
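+  // Determine the combining kind from the reduction on the init operand,
+  // e.g. an arith.addf accumulator maps to vector::CombiningKind::ADD.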
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind) {
+    LDBG("Failed to determine contraction combining kind.\n");
+    return failure();
+  }
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be decomposed before vectorization.
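+  // E.g. an iteration dimension used by neither input map would require an
+  // implicit broadcast and is rejected here.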
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+    LDBG("Contractions with broadcasts are not supported.\n");
+    return failure();
+  }
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part
+    // of the contraction.
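+    // E.g. for a matmul LHS with indexing map (m, n, k) -> (m, k) and
+    // canonical vector shape [M, N, K], the operand is read as vector<MxK>.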
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
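+  // E.g. matmul's [parallel, parallel, reduction] iterators carry over 1:1.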
+
+  // Create contraction.
+  Operation *contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
+
+  // Store result.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, contractOp->getResult(0), outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2557,7 +2647,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
+    bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2604,6 +2695,11 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
     return failure();
   }

+  if (createNamedContraction &&
+      isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return vectorizeAsLinalgContraction(rewriter, state, linalgOp, results);
+
   LDBG("Vectorize generic by broadcasting to the canonical vector "
        "shape\n");
