Skip to content

Commit 8c3dbf8

Browse files
committed
Linearization patterns for vector.load and vector.store
1 parent 4084ffc commit 8c3dbf8

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
623623
}
624624
};
625625

626+
/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
627+
/// It currently supports only lineariztion of <1XN> to <N>
628+
/// Following,
629+
/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
630+
/// is converted to:
631+
/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
632+
/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
633+
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
634+
using OpConversionPattern::OpConversionPattern;
635+
LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
636+
PatternBenefit benefit = 1)
637+
: OpConversionPattern(typeConverter, context, benefit) {}
638+
639+
LogicalResult
640+
matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
641+
ConversionPatternRewriter &rewriter) const override {
642+
VectorType vecTy = loadOp.getType();
643+
if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
644+
return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
645+
auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
646+
vecTy.isScalable());
647+
auto newLoad = rewriter.create<vector::LoadOp>(
648+
loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
649+
auto shapeCast = rewriter.create<vector::ShapeCastOp>(
650+
loadOp.getLoc(), vecTy, newLoad.getResult());
651+
rewriter.replaceOp(loadOp, shapeCast.getResult());
652+
return success();
653+
}
654+
};
655+
656+
/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
657+
/// It currently supports only lineariztion of <1XN> to <N>
658+
/// Following,
659+
/// vector.store %arg0, %arg1[%c0, %c0]
660+
/// : vector<1x4xf32>, memref<1x4xf32>
661+
/// is converted to:
662+
/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
663+
/// vector.store %arg0, %arg1[%c0, %%c0]
664+
/// : vector<4xf32>, memref<1x4xf32>
665+
struct LinearizeVectorStore final
666+
: public OpConversionPattern<vector::StoreOp> {
667+
using OpConversionPattern::OpConversionPattern;
668+
LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
669+
PatternBenefit benefit = 1)
670+
: OpConversionPattern(typeConverter, context, benefit) {}
671+
672+
LogicalResult
673+
matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
674+
ConversionPatternRewriter &rewriter) const override {
675+
VectorType vecTy = storeOp.getValueToStore().getType();
676+
if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
677+
return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
678+
auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
679+
vecTy.isScalable());
680+
681+
Value valueToStore = adaptor.getValueToStore();
682+
if (valueToStore.getType() != linearTy) {
683+
valueToStore = rewriter.create<vector::ShapeCastOp>(
684+
storeOp.getLoc(), linearTy, valueToStore);
685+
}
686+
687+
rewriter.replaceOpWithNewOp<vector::StoreOp>(
688+
storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
689+
return success();
690+
}
691+
};
692+
626693
} // namespace
627694

628695
/// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
714781
RewritePatternSet &patterns) {
715782
patterns
716783
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
717-
LinearizeVectorSplat, LinearizeVectorCreateMask>(
718-
typeConverter, patterns.getContext());
784+
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
785+
LinearizeVectorStore>(typeConverter, patterns.getContext());
719786
}
720787

721788
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
464464
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
465465
return %0 : vector<1x[16]xi1>
466466
}
467+
468+
// CHECK-LABEL: linearize_vector_load
469+
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
470+
func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
471+
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
472+
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
473+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
474+
// CHECK: return %[[CAST]] : vector<1x4xf32>
475+
%c0 = arith.constant 0 : index
476+
%0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
477+
return %0 : vector<1x4xf32>
478+
}
479+
480+
// CHECK-LABEL: linearize_vector_store
481+
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
482+
func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
483+
// CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
484+
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
485+
// CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
486+
%c0 = arith.constant 0 : index
487+
vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
488+
return
489+
}

0 commit comments

Comments
 (0)