
Commit b956f04

[mlir][linalg] Vectorize directly to a named contraction (#147296)
Extends the linalg vectorizer with a path to lower contraction ops directly into `vector.contract`. The direct rewriting preserves the high-level op semantics and provides a more progressive lowering than reconstructing the contraction from a multi-dimensional reduction. The added lowering focuses on named linalg ops and leverages their well-defined semantics to avoid complex precondition verification. The new path is optional and disabled by default, so the default vectorizer behavior is unchanged.
1 parent: cae7650
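For illustration, a minimal hand-written sketch of the new path (not output copied from this patch's tests; transfer attributes abbreviated): a static `linalg.matmul` is rewritten into whole-operand reads, a single `vector.contract` that keeps the op's indexing maps and combining kind, and a write of the result.

```mlir
// Input: a named contraction with static shapes.
%0 = linalg.matmul ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
                   outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>

// Vectorized form (roughly): operands are read as-is, permutations and the
// combining kind stay inside the contraction, and the result is written back.
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%lhs = vector.transfer_read %A[%c0, %c0], %cst : tensor<8x4xf32>, vector<8x4xf32>
%rhs = vector.transfer_read %B[%c0, %c0], %cst : tensor<4x16xf32>, vector<4x16xf32>
%acc = vector.transfer_read %C[%c0, %c0], %cst : tensor<8x16xf32>, vector<8x16xf32>
%res = vector.contract {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
%out = vector.transfer_write %res, %C[%c0, %c0] : vector<8x16xf32>, tensor<8x16xf32>
```

Compare this with the default path, which first vectorizes through a multi-dimensional-reduction form and only later reconstructs a contraction from it.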

7 files changed: +603 −11 lines

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 1 addition & 0 deletions
@@ -2435,6 +2435,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
                        OptionalAttr<UnitAttr>:$vectorize_nd_extract,
                        OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+                       OptionalAttr<UnitAttr>:$create_named_contraction,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
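A sketch of exercising the new flag from a transform script; the unit-attribute spelling `create_named_contraction` is inferred from the snake_case TableGen name above, and the sequence boilerplate follows the usual vectorization tests:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    // Request direct lowering to vector.contract; omitting the attribute
    // keeps the default vectorizer behavior.
    transform.structured.vectorize %0 {create_named_contraction}
        : !transform.any_op
    transform.yield
  }
}
```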

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
@@ -876,12 +876,15 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeDynamicDimsMatchVecSizes = false);
+          bool assumeDynamicDimsMatchVecSizes = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 2 additions & 1 deletion
@@ -227,7 +227,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 4 additions & 2 deletions
@@ -3920,8 +3920,10 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false), false,
-                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
+                        getVectorizeNdExtract().value_or(false),
+                        /*flatten1DDepthwiseConv=*/false,
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false),
+                        getCreateNamedContraction().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 99 additions & 3 deletions
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1709,10 +1710,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
 
@@ -2118,6 +2122,92 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs the contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination
+/// The operands' shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within the contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  // For simplicity, contraction vectorization is limited to linalg named ops.
+  // Generic ops are ignored, as not every arbitrary contraction body can be
+  // expressed by a vector.contract.
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind) {
+    LDBG("Failed to determine contraction combining kind.\n");
+    return failure();
+  }
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be decomposed before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+    LDBG("Contractions with broadcasts are not supported.\n");
+    return failure();
+  }
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Operation *contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
+
+  // Store result.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, contractOp->getResult(0), outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2557,7 +2647,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
+    bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2604,6 +2695,11 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
             return failure();
           }
 
+          if (createNamedContraction &&
+              isa<ContractionOpInterface>(linalgOp.getOperation()))
+            return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
+                                                results);
+
           LDBG("Vectorize generic by broadcasting to the canonical vector "
                "shape\n");
 
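For dynamic shapes the same path emits masked operations: `createReadOrMaskedRead` builds a mask from the runtime sizes and wraps each `vector.transfer_read` in `vector.mask`, and `state.maskOperation` masks the contraction analogously. A rough sketch of the read side, assuming a dynamic LHS read at the fixed shape `vector<8x4xf32>` (names and shapes are illustrative):

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%pad = arith.constant 0.0 : f32
// Runtime sizes of the dynamic LHS drive the mask for the fixed read shape.
%d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %A, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %d0, %d1 : vector<8x4xi1>
%lhs = vector.mask %mask {
  vector.transfer_read %A[%c0, %c0], %pad : tensor<?x?xf32>, vector<8x4xf32>
} : vector<8x4xi1> -> vector<8x4xf32>
```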
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 9 additions & 4 deletions
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
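Two effects of this change are visible in the IR: the new `scalableDims` argument propagates into both the read vector type and the mask type, and memref sources now take the `memref::getMixedSizes` branch. An illustrative sketch under assumed shapes, with a dynamic memref source and a scalable trailing dimension:

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%pad = arith.constant 0.0 : f32
// The mask type mirrors the scalable read type vector<4x[4]xf32>, and the
// runtime sizes come from memref.dim rather than tensor.dim.
%d0 = memref.dim %src, %c0 : memref<?x?xf32>
%d1 = memref.dim %src, %c1 : memref<?x?xf32>
%mask = vector.create_mask %d0, %d1 : vector<4x[4]xi1>
%v = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0], %pad : memref<?x?xf32>, vector<4x[4]xf32>
} : vector<4x[4]xi1> -> vector<4x[4]xf32>
```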
