@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
6767 return rewriter.create <LLVM::ExtractValueOp>(loc, val, pos);
6868}
6969
70+ // Helper that returns data layout alignment of a vector.
71+ LogicalResult getVectorAlignment (const LLVMTypeConverter &typeConverter,
72+ VectorType vectorType, unsigned &align) {
73+ Type convertedVectorTy = typeConverter.convertType (vectorType);
74+ if (!convertedVectorTy)
75+ return failure ();
76+
77+ llvm::LLVMContext llvmContext;
78+ align = LLVM::TypeToLLVMIRTranslator (llvmContext)
79+ .getPreferredAlignment (convertedVectorTy,
80+ typeConverter.getDataLayout ());
81+
82+ return success ();
83+ }
84+
7085// Helper that returns data layout alignment of a memref.
7186LogicalResult getMemRefAlignment (const LLVMTypeConverter &typeConverter,
7287 MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
8297 return success ();
8398}
8499
100+ // Helper to resolve the alignment for vector load/store, gather and scatter
101+ // ops. If useVectorAlignment is true, get the preferred alignment for the
102+ // vector type in the operation. This option is used for hardware backends with
103+ // vectorization. Otherwise, use the preferred alignment of the element type of
104+ // the memref. Note that if you choose to use vector alignment, the shape of the
105+ // vector type must be resolved before the ConvertVectorToLLVM pass is run.
106+ LogicalResult getVectorToLLVMAlignment (const LLVMTypeConverter &typeConverter,
107+ VectorType vectorType,
108+ MemRefType memrefType, unsigned &align,
109+ bool useVectorAlignment) {
110+ if (useVectorAlignment) {
111+ if (failed (getVectorAlignment (typeConverter, vectorType, align))) {
112+ return failure ();
113+ }
114+ } else {
115+ if (failed (getMemRefAlignment (typeConverter, memrefType, align))) {
116+ return failure ();
117+ }
118+ }
119+ return success ();
120+ }
121+
85122// Check if the last stride is non-unit and has a valid memory space.
86123static LogicalResult isMemRefTypeSupported (MemRefType memRefType,
87124 const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
224261template <class LoadOrStoreOp >
225262class VectorLoadStoreConversion : public ConvertOpToLLVMPattern <LoadOrStoreOp> {
226263public:
264+ explicit VectorLoadStoreConversion (const LLVMTypeConverter &typeConv,
265+ bool useVectorAlign)
266+ : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
267+ useVectorAlignment(useVectorAlign) {}
227268 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
228269
229270 LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
240281
241282 // Resolve alignment.
242283 unsigned align;
243- if (failed (getMemRefAlignment (*this ->getTypeConverter (), memRefTy, align)))
244- return failure ();
284+ if (failed (getVectorToLLVMAlignment (*this ->getTypeConverter (), vectorTy,
285+ memRefTy, align, useVectorAlignment)))
286+ return rewriter.notifyMatchFailure (loadOrStoreOp,
287+ " could not resolve alignment" );
245288
246289 // Resolve address.
247290 auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
252295 rewriter);
253296 return success ();
254297 }
298+
299+ private:
300+ // If true, use the preferred alignment of the vector type.
301+ // If false, use the preferred alignment of the element type
302+ // of the memref. This flag is intended for use with hardware
303+ // backends that require alignment of vector operations.
304+ const bool useVectorAlignment;
255305};
256306
257307// / Conversion pattern for a vector.gather.
258308class VectorGatherOpConversion
259309 : public ConvertOpToLLVMPattern<vector::GatherOp> {
260310public:
311+ explicit VectorGatherOpConversion (const LLVMTypeConverter &typeConv,
312+ bool useVectorAlign)
313+ : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
314+ useVectorAlignment(useVectorAlign) {}
261315 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
262316
263317 LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion
278332
279333 // Resolve alignment.
280334 unsigned align;
281- if (failed (getMemRefAlignment (*getTypeConverter (), memRefType, align))) {
282- return rewriter.notifyMatchFailure (gather,
283- " could not resolve memref alignment" );
284- }
335+ if (failed (getVectorToLLVMAlignment (*this ->getTypeConverter (), vType,
336+ memRefType, align, useVectorAlignment)))
337+ return rewriter.notifyMatchFailure (gather, " could not resolve alignment" );
285338
286339 // Resolve address.
287340 Value ptr = getStridedElementPtr (loc, memRefType, adaptor.getBase (),
@@ -297,12 +350,24 @@ class VectorGatherOpConversion
297350 adaptor.getPassThru (), rewriter.getI32IntegerAttr (align));
298351 return success ();
299352 }
353+
354+ private:
355+ // If true, use the preferred alignment of the vector type.
356+ // If false, use the preferred alignment of the element type
357+ // of the memref. This flag is intended for use with hardware
358+ // backends that require alignment of vector operations.
359+ const bool useVectorAlignment;
300360};
301361
302362// / Conversion pattern for a vector.scatter.
303363class VectorScatterOpConversion
304364 : public ConvertOpToLLVMPattern<vector::ScatterOp> {
305365public:
366+ explicit VectorScatterOpConversion (const LLVMTypeConverter &typeConv,
367+ bool useVectorAlign)
368+ : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
369+ useVectorAlignment(useVectorAlign) {}
370+
306371 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
307372
308373 LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion
322387
323388 // Resolve alignment.
324389 unsigned align;
325- if (failed (getMemRefAlignment (*getTypeConverter (), memRefType, align))) {
390+ if (failed (getVectorToLLVMAlignment (*this ->getTypeConverter (), vType,
391+ memRefType, align, useVectorAlignment)))
326392 return rewriter.notifyMatchFailure (scatter,
327- " could not resolve memref alignment" );
328- }
393+ " could not resolve alignment" );
329394
330395 // Resolve address.
331396 Value ptr = getStridedElementPtr (loc, memRefType, adaptor.getBase (),
@@ -340,6 +405,13 @@ class VectorScatterOpConversion
340405 rewriter.getI32IntegerAttr (align));
341406 return success ();
342407 }
408+
409+ private:
410+ // If true, use the preferred alignment of the vector type.
411+ // If false, use the preferred alignment of the element type
412+ // of the memref. This flag is intended for use with hardware
413+ // backends that require alignment of vector operations.
414+ const bool useVectorAlignment;
343415};
344416
345417// / Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
19282000// / Populate the given list with patterns that convert from Vector to LLVM.
19292001void mlir::populateVectorToLLVMConversionPatterns (
19302002 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1931- bool reassociateFPReductions, bool force32BitVectorIndices) {
2003+ bool reassociateFPReductions, bool force32BitVectorIndices,
2004+ bool useVectorAlignment) {
19322005 // This function populates only ConversionPatterns, not RewritePatterns.
19332006 MLIRContext *ctx = converter.getDialect ()->getContext ();
19342007 patterns.add <VectorReductionOpConversion>(converter, reassociateFPReductions);
19352008 patterns.add <VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2009+ patterns.add <VectorLoadStoreConversion<vector::LoadOp>,
2010+ VectorLoadStoreConversion<vector::MaskedLoadOp>,
2011+ VectorLoadStoreConversion<vector::StoreOp>,
2012+ VectorLoadStoreConversion<vector::MaskedStoreOp>,
2013+ VectorGatherOpConversion, VectorScatterOpConversion>(
2014+ converter, useVectorAlignment);
19362015 patterns.add <VectorBitCastOpConversion, VectorShuffleOpConversion,
19372016 VectorExtractElementOpConversion, VectorExtractOpConversion,
19382017 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
19392018 VectorInsertOpConversion, VectorPrintOpConversion,
19402019 VectorTypeCastOpConversion, VectorScaleOpConversion,
1941- VectorLoadStoreConversion<vector::LoadOp>,
1942- VectorLoadStoreConversion<vector::MaskedLoadOp>,
1943- VectorLoadStoreConversion<vector::StoreOp>,
1944- VectorLoadStoreConversion<vector::MaskedStoreOp>,
1945- VectorGatherOpConversion, VectorScatterOpConversion,
19462020 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
19472021 VectorSplatOpLowering, VectorSplatNdOpLowering,
19482022 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
0 commit comments