
Commit 700b0ea

[MLIR] Remove vector transfer lowering patterns.
1 parent 6a5ac3c commit 700b0ea

8 files changed: +97, -372 lines


mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 0 additions & 2 deletions
@@ -1910,8 +1910,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
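
Since the VectorToLLVM conversion no longer registers the transfer lowering, a downstream pipeline that still wants rank-1 vector.transfer_read/vector.transfer_write rewritten into vector.load/vector.store has to request it itself. A minimal sketch, assuming the declaration still lives in the Vector dialect's LoweringPatterns.h header; the driver function below is hypothetical, only the populate call is taken from this diff:

// Sketch only: opt back into the transfer lowering that the VectorToLLVM
// conversion stops registering in this commit. Note that at this revision the
// populate function's own pattern registrations are commented out (see
// LowerVectorTransfer.cpp below), so the call currently adds no patterns.
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" // assumed location
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerRank1Transfers(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  // Same call that was removed above; rank > 1 transfers are still expected
  // to be handled by VectorToSCF first.
  mlir::vector::populateVectorTransferLoweringPatterns(patterns,
                                                       /*maxTransferRank=*/1);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}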

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 0 additions & 3 deletions
@@ -74,8 +74,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   populateVectorInterleaveLoweringPatterns(patterns);
   populateVectorTransposeLoweringPatterns(patterns,
                                           VectorTransformsOptions());
-  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
@@ -84,7 +82,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
   populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
-  populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
       converter, patterns, reassociateFPReductions, force32BitVectorIndices);
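
With both calls gone from the pass, any remaining vector.transfer ops have to be rewritten before convert-vector-to-llvm runs, typically by scheduling convert-vector-to-scf ahead of it, as the removed comments suggest. A minimal sketch of that ordering, assuming the standard upstream pass names; the pipeline string is illustrative and not part of this commit:

// Sketch: run VectorToSCF before VectorToLLVM so transfer ops are gone by the
// time the LLVM conversion runs. Assumes the conversion passes have been
// registered beforehand (e.g. via registerAllPasses()); the exact pass list a
// real pipeline needs depends on the input IR.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

static mlir::LogicalResult runVectorPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext(),
                       mlir::ModuleOp::getOperationName());
  // Equivalent to: mlir-opt --convert-vector-to-scf --convert-vector-to-llvm
  if (mlir::failed(mlir::parsePassPipeline(
          "convert-vector-to-scf,convert-vector-to-llvm", pm)))
    return mlir::failure();
  return pm.run(module);
}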

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 6 additions & 6 deletions
@@ -636,10 +636,10 @@ struct TransferWriteToVectorStoreLowering
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
     PatternBenefit benefit) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext(),
-                                                   maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
+  // patterns.add<TransferReadToVectorLoadLowering,
+  //              TransferWriteToVectorStoreLowering>(patterns.getContext(),
+  //                                                  maxTransferRank, benefit);
+  // patterns
+  //     .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+  //         patterns.getContext(), benefit);
 }
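
For context, the first of the now-disabled patterns rewrites simple rank-1 transfers into plain vector loads. Below is a simplified, illustrative re-creation of that idea; it is not the upstream TransferReadToVectorLoadLowering, which additionally handles masks, broadcasts, permutation maps, tensor sources, and the maxTransferRank cap:

// Illustrative sketch of the rewrite the disabled pattern performs in the
// simplest case: a rank-1, in-bounds, minor-identity vector.transfer_read
// from a memref becomes a vector.load on the same base and indices.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct SimpleTransferReadToLoad
    : public mlir::OpRewritePattern<mlir::vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::TransferReadOp read,
                  mlir::PatternRewriter &rewriter) const override {
    // Bail out on everything but the easy case: 1-D result, memref source,
    // identity indexing, fully in-bounds, no mask.
    if (read.getVectorType().getRank() != 1 ||
        !mlir::isa<mlir::MemRefType>(read.getShapedType()) ||
        !read.getPermutationMap().isMinorIdentity() || read.getMask() ||
        read.hasOutOfBoundsDim())
      return mlir::failure();
    rewriter.replaceOpWithNewOp<mlir::vector::LoadOp>(
        read, read.getVectorType(), read.getSource(), read.getIndices());
    return mlir::success();
  }
};
} // namespace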

mlir/test/Conversion/GPUCommon/transfer_write.mlir

Lines changed: 1 addition & 4 deletions
@@ -3,10 +3,7 @@
 func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: index, %arg3: vector<1xf32>) {
   %c0 = arith.constant 0 : index
   vector.warp_execute_on_lane_0(%arg0)[32] {
-    // CHECK:%[[val:[0-9]+]] = llvm.extractelement
-    // CHECK:%[[base:[0-9]+]] = llvm.extractvalue
-    // CHECK:%[[ptr:[0-9]+]] = llvm.getelementptr %[[base]]
-    // CHECK:llvm.store %[[val]], %[[ptr]]
+    // CHECK: vector.transfer_write %arg9, %[[MEM:.*]][%[[IDX:.*]], %[[IDX]]] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
     vector.transfer_write %arg3, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
   }
   return

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 21 additions & 11 deletions
@@ -2953,12 +2953,16 @@ func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : ind
 }
 
 // CHECK-LABEL: func @vector_load_op_0d
-// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
-// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
-// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
-// CHECK: return %[[cast]] : vector<f32>
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %arg2 : index to i64
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %arg1 : index to i64
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S3:.*]] = llvm.extractvalue %[[S2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S4:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[S5:.*]] = llvm.mul %[[S1]], %[[S4]] : i64
+// CHECK: %[[S6:.*]] = llvm.add %[[S5]], %[[S0]] : i64
+// CHECK: %[[S7:.*]] = llvm.getelementptr %[[S3]][%[[S6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[S8:.*]] = llvm.load %[[S7]] {alignment = 4 : i64} : !llvm.ptr -> vector<1xf32>
+// CHECK: %[[S9:.*]] = builtin.unrealized_conversion_cast %[[S8]] : vector<1xf32> to vector<f32>
 
 // -----
 
@@ -2969,11 +2973,17 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 }
 
 // CHECK-LABEL: func @vector_store_op_0d
-// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
-// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %arg2 : index to i64
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %arg1 : index to i64
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S3:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[S4:.*]] = builtin.unrealized_conversion_cast %[[S3]] : vector<f32> to vector<1xf32>
+// CHECK: %[[S5:.*]] = llvm.extractvalue %[[S2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[S7:.*]] = llvm.mul %[[S1]], %[[S6]] : i64
+// CHECK: %[[S8:.*]] = llvm.add %[[S7]], %[[S0]] : i64
+// CHECK: %[[S9:.*]] = llvm.getelementptr %[[S5]][%[[S8]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.store %[[S4]], %[[S9]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
 
 // -----
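
The new CHECK lines show that, without the scalarizing patterns, the 0-D vector.load/vector.store is handled directly by the LLVM conversion: the memref descriptor is unpacked, the element offset is linearized, and a vector<1xf32> is loaded or stored through llvm.getelementptr, instead of going through memref.load/memref.store of a single f32. A worked restatement of the address arithmetic, assuming the default row-major layout of memref<200x100xf32> (the helper below is purely illustrative):

// Illustrative only: the offset computed by the llvm.mul / llvm.add pair in
// the CHECK lines above is i * 100 + j (stride 100 for the leading dimension,
// stride 1 for the trailing one), applied to the aligned base pointer taken
// from the memref descriptor.
#include <cstdint>

static float *elementPtr(float *alignedBase, std::int64_t i, std::int64_t j) {
  std::int64_t offset = i * 100 + j; // matches llvm.mul then llvm.add
  return alignedBase + offset;       // matches llvm.getelementptr %base[%offset]
}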
