Skip to content

Commit bc3a47e

Browse files
committed
Add linearization patterns for vector.load & vector.store
1 parent d0ce861 commit bc3a47e

File tree

3 files changed

+261
-12
lines changed

3 files changed

+261
-12
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 149 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626

2727
using namespace mlir;
2828

29+
constexpr unsigned defaultTargetVectorBitWidth =
30+
std::numeric_limits<unsigned>::max();
31+
2932
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
33+
if (targetBitWidth == 0)
34+
return false;
3035
auto resultTypes = op->getResultTypes();
3136
for (auto resType : resultTypes) {
3237
VectorType vecType = dyn_cast<VectorType>(resType);
@@ -82,7 +87,7 @@ struct LinearizeConstantLike final
8287

8388
LinearizeConstantLike(
8489
const TypeConverter &typeConverter, MLIRContext *context,
85-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
90+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
8691
PatternBenefit benefit = 1)
8792
: OpTraitConversionPattern(typeConverter, context, benefit),
8893
targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +141,7 @@ struct LinearizeVectorizable final
136141
public:
137142
LinearizeVectorizable(
138143
const TypeConverter &typeConverter, MLIRContext *context,
139-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
144+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
140145
PatternBenefit benefit = 1)
141146
: OpTraitConversionPattern(typeConverter, context, benefit),
142147
targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +180,7 @@ struct LinearizeVectorExtractStridedSlice final
175180
using OpConversionPattern::OpConversionPattern;
176181
LinearizeVectorExtractStridedSlice(
177182
const TypeConverter &typeConverter, MLIRContext *context,
178-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
183+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
179184
PatternBenefit benefit = 1)
180185
: OpConversionPattern(typeConverter, context, benefit),
181186
targetVectorBitWidth(targetVectBitWidth) {}
@@ -289,7 +294,7 @@ struct LinearizeVectorShuffle final
289294
using OpConversionPattern::OpConversionPattern;
290295
LinearizeVectorShuffle(
291296
const TypeConverter &typeConverter, MLIRContext *context,
292-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
297+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
293298
PatternBenefit benefit = 1)
294299
: OpConversionPattern(typeConverter, context, benefit),
295300
targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +367,17 @@ struct LinearizeVectorExtract final
362367
using OpConversionPattern::OpConversionPattern;
363368
LinearizeVectorExtract(
364369
const TypeConverter &typeConverter, MLIRContext *context,
365-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
370+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
366371
PatternBenefit benefit = 1)
367372
: OpConversionPattern(typeConverter, context, benefit),
368373
targetVectorBitWidth(targetVectBitWidth) {}
369374
LogicalResult
370375
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
371376
ConversionPatternRewriter &rewriter) const override {
377+
// Skip if result is not a vector type
378+
if (!isa<VectorType>(extractOp.getType()))
379+
return rewriter.notifyMatchFailure(extractOp,
380+
"scalar extract is not supported.");
372381
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
373382
if (!dstTy)
374383
return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +434,7 @@ struct LinearizeVectorInsert final
425434
using OpConversionPattern::OpConversionPattern;
426435
LinearizeVectorInsert(
427436
const TypeConverter &typeConverter, MLIRContext *context,
428-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
437+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
429438
PatternBenefit benefit = 1)
430439
: OpConversionPattern(typeConverter, context, benefit),
431440
targetVectorBitWidth(targetVectBitWidth) {}
@@ -506,7 +515,7 @@ struct LinearizeVectorBitCast final
506515
using OpConversionPattern::OpConversionPattern;
507516
LinearizeVectorBitCast(
508517
const TypeConverter &typeConverter, MLIRContext *context,
509-
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
518+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
510519
PatternBenefit benefit = 1)
511520
: OpConversionPattern(typeConverter, context, benefit),
512521
targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +540,139 @@ struct LinearizeVectorBitCast final
531540
unsigned targetVectorBitWidth;
532541
};
533542

543+
// clang-format off
544+
/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
545+
/// that works on a linearized vector.
546+
/// Following,
547+
/// vector.load %base[%indices] : vector<4x4xf32>
548+
/// is converted to :
549+
/// %result = arith.constant dense<0.0> : vector<4x4xf32>
550+
/// %slice_0 = vector.load %base[%indices] : vector<4xf32>
551+
/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
552+
/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
553+
/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
554+
/// ...
555+
/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
556+
/// them into the result vector. The pattern currently supports only 2D vectors
557+
// clang-format on
558+
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
559+
using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
560+
561+
LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
562+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
563+
PatternBenefit benefit = 1)
564+
: OpConversionPattern(typeConverter, context, benefit),
565+
targetVectorBitWidth(targetVectBitWidth) {}
566+
567+
LogicalResult
568+
matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
569+
ConversionPatternRewriter &rewriter) const override {
570+
auto loc = loadOp->getLoc();
571+
VectorType vecType = loadOp.getVectorType();
572+
auto shape = vecType.getShape();
573+
574+
if (shape.size() != 2)
575+
return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
576+
577+
auto unrollCount = shape[0];
578+
auto vecSize = shape[1];
579+
VectorType newVecType =
580+
VectorType::get({vecSize}, vecType.getElementType());
581+
582+
llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
583+
Value xBaseIndex = indices[0];
584+
585+
// Construct the 2D vector.
586+
Value resultVec =
587+
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
588+
// Emit unrolled loads for each 1D vector slice.
589+
for (auto i = 0; i < unrollCount; i++) {
590+
Value xIndex = xBaseIndex;
591+
if (i) {
592+
auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
593+
xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
594+
}
595+
indices[0] = xIndex;
596+
auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
597+
adaptor.getBase(), indices);
598+
resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
599+
}
600+
601+
rewriter.replaceOp(loadOp, resultVec);
602+
return success();
603+
}
604+
605+
private:
606+
unsigned targetVectorBitWidth;
607+
};
608+
609+
/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
610+
/// that works on a linearized vector.
611+
/// Following,
612+
/// vector.store %source, %base[%indices] : vector<4x4xf32>
613+
/// is converted to :
614+
/// %slice_0 = vector.extract %source[0] : vector<4xf32>
615+
/// vector.store %slice_0, %base[%indices] : vector<4xf32>
616+
/// %slice_1 = vector.extract %source[1] : vector<4xf32>
617+
/// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
618+
/// ...
619+
/// This unrolls the 2D vector store into multiple 1D vector stores by
620+
/// extracting slices from the source vector and storing them into the
621+
/// destination. The pattern currently supports only 2D vectors
622+
struct LinearizeVectorStore final
623+
: public OpConversionPattern<vector::StoreOp> {
624+
using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
625+
626+
LinearizeVectorStore(
627+
const TypeConverter &typeConverter, MLIRContext *context,
628+
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
629+
PatternBenefit benefit = 1)
630+
: OpConversionPattern(typeConverter, context, benefit),
631+
targetVectorBitWidth(targetVectBitWidth) {}
632+
633+
LogicalResult
634+
matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
635+
ConversionPatternRewriter &rewriter) const override {
636+
auto loc = storeOp->getLoc();
637+
VectorType vecType = storeOp.getVectorType();
638+
auto shape = vecType.getShape();
639+
640+
if (shape.size() != 2)
641+
return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
642+
643+
auto unrollCount = shape[0];
644+
llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
645+
Value xBaseIndex = indices[0];
646+
647+
auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
648+
adaptor.getValueToStore());
649+
650+
for (auto i = 0; i < unrollCount; i++) {
651+
auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
652+
Value xIndex = xBaseIndex;
653+
if (i) {
654+
auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
655+
xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
656+
}
657+
indices[0] = xIndex;
658+
rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
659+
indices);
660+
}
661+
rewriter.eraseOp(storeOp);
662+
return success();
663+
}
664+
665+
private:
666+
unsigned targetVectorBitWidth;
667+
};
668+
534669
} // namespace
535670

536671
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
537672
TypeConverter &typeConverter, RewritePatternSet &patterns,
538673
ConversionTarget &target, unsigned targetBitWidth) {
539674

675+
typeConverter.addConversion([](Type type) -> Type { return type; });
540676
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
541677
if (!isLinearizableVector(type))
542678
return type;
@@ -555,9 +691,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
555691
};
556692
typeConverter.addSourceMaterialization(materializeCast);
557693
typeConverter.addTargetMaterialization(materializeCast);
694+
target.addLegalOp<vector::ShapeCastOp>();
558695
target.markUnknownOpDynamicallyLegal(
559696
[=](Operation *op) -> std::optional<bool> {
560-
if ((isa<vector::BitCastOp>(op) ||
697+
if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp>(op) ||
561698
op->hasTrait<OpTrait::ConstantLike>() ||
562699
op->hasTrait<OpTrait::Vectorizable>())) {
563700
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -567,9 +704,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
567704
return std::nullopt;
568705
});
569706

570-
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
571-
LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
572-
targetBitWidth);
707+
patterns
708+
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
709+
LinearizeVectorLoad, LinearizeVectorStore>(
710+
typeConverter, patterns.getContext(), targetBitWidth);
573711
}
574712

575713
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,113 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
399399
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
400400
return %1 : vector<[4]x4xf16>
401401
}
402+
403+
// -----
404+
// ALL-LABEL: linearize_vector_load
405+
// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
406+
func.func @linearize_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
407+
// DEFAULT: %[[C1:.*]] = arith.constant 1 : index
408+
// DEFAULT: %[[C2:.*]] = arith.constant 2 : index
409+
// DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
410+
// DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
411+
// DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
412+
// DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
413+
// DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
414+
// DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
415+
// DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
416+
// DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
417+
// DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
418+
// DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
419+
// DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
420+
// DEFAULT: %[[C3:.*]] = arith.constant 3 : index
421+
// DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
422+
// DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
423+
// DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
424+
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
425+
// DEFAULT: return %[[CAST]] : vector<4x4xf16>
426+
427+
// BW-128: %[[C1:.*]] = arith.constant 1 : index
428+
// BW-128: %[[C2:.*]] = arith.constant 2 : index
429+
// BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
430+
// BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
431+
// BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
432+
// BW-128: %[[C1_0:.*]] = arith.constant 1 : index
433+
// BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
434+
// BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
435+
// BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
436+
// BW-128: %[[C2_1:.*]] = arith.constant 2 : index
437+
// BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
438+
// BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
439+
// BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
440+
// BW-128: %[[C3:.*]] = arith.constant 3 : index
441+
// BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
442+
// BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
443+
// BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
444+
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
445+
// BW-128: return %[[CAST]] : vector<4x4xf16>
446+
447+
// BW-0: %[[C1:.*]] = arith.constant 1 : index
448+
// BW-0: %[[C2:.*]] = arith.constant 2 : index
449+
// BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
450+
// BW-0: return %[[LOAD]] : vector<4x4xf16>
451+
%c1 = arith.constant 1 : index
452+
%c2 = arith.constant 2 : index
453+
%0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
454+
return %0 : vector<4x4xf16>
455+
}
456+
457+
// -----
458+
// ALL-LABEL: linearize_vector_store
459+
// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) {
460+
func.func @linearize_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
461+
// DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
462+
// DEFAULT: %[[C1:.*]] = arith.constant 1 : index
463+
// DEFAULT: %[[C2:.*]] = arith.constant 2 : index
464+
// DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
465+
// DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
466+
// DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
467+
// DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
468+
// DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
469+
// DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
470+
// DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
471+
// DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
472+
// DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
473+
// DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
474+
// DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
475+
// DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
476+
// DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
477+
// DEFAULT: %[[C3:.*]] = arith.constant 3 : index
478+
// DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
479+
// DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
480+
// DEFAULT: return
481+
482+
// BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
483+
// BW-128: %[[C1:.*]] = arith.constant 1 : index
484+
// BW-128: %[[C2:.*]] = arith.constant 2 : index
485+
// BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
486+
// BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
487+
// BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
488+
// BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
489+
// BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
490+
// BW-128: %[[C1_0:.*]] = arith.constant 1 : index
491+
// BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
492+
// BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
493+
// BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
494+
// BW-128: %[[C2_1:.*]] = arith.constant 2 : index
495+
// BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
496+
// BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
497+
// BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
498+
// BW-128: %[[C3:.*]] = arith.constant 3 : index
499+
// BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
500+
// BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
501+
// BW-128: return
502+
503+
// BW-0: %[[C1:.*]] = arith.constant 1 : index
504+
// BW-0: %[[C2:.*]] = arith.constant 2 : index
505+
// BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
506+
// BW-0: return
507+
%c1 = arith.constant 1 : index
508+
%c2 = arith.constant 2 : index
509+
vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
510+
return
511+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,8 @@ struct TestVectorLinearize final
852852
return "Linearizes ND vectors for N >= 2 into 1D vectors";
853853
}
854854
void getDependentDialects(DialectRegistry &registry) const override {
855-
registry.insert<vector::VectorDialect>();
855+
registry.insert<vector::VectorDialect, memref::MemRefDialect,
856+
arith::ArithDialect>();
856857
}
857858

858859
Option<unsigned> targetVectorBitwidth{

0 commit comments

Comments
 (0)