Skip to content

Commit 3f40948

Browse files
committed
Address feedback
1 parent 9d91abe commit 3f40948

File tree

2 files changed

+53
-22
lines changed

2 files changed

+53
-22
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454
return slicedIndices;
5555
}
5656

57-
// compute the new indices for vector.load/store by adding offsets to
58-
// originalIndices.
57+
// Compute the new indices for vector.load/store by adding `offsets` to
58+
// `originalIndices`.
5959
// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
60-
// Last m of originalIndices will be updated.
60+
// Last m of `originalIndices` will be updated.
6161
static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
6262
Location loc,
6363
ArrayRef<Value> originalIndices,
@@ -658,6 +658,20 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
658658
vector::UnrollVectorOptions options;
659659
};
660660

661+
// clang-format off
662+
// This pattern unrolls the vector load into multiple 1D vector loads by
663+
// extracting slices from the base memory and inserting them into the result
664+
// vector using vector.insert_strided_slice.
665+
// Following,
666+
// vector.load %base[%indices] : memref<4x4xf32>, vector<4x4xf32>
667+
// is converted to :
668+
// %cst = arith.constant dense<0.0> : vector<4x4xf32>
669+
// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
670+
// %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
671+
// %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32>
672+
// %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
673+
// ...
674+
// clang-format on
661675
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
662676
UnrollLoadPattern(MLIRContext *context,
663677
const vector::UnrollVectorOptions &options,
@@ -707,6 +721,17 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
707721
vector::UnrollVectorOptions options;
708722
};
709723

724+
// This pattern unrolls the vector store into multiple 1D vector stores by
725+
// extracting slices from the source vector and storing them into the
726+
// destination.
727+
// Following,
728+
// vector.store %source, %base[%indices] : vector<4x4xf32>
729+
// is converted to :
730+
// %slice_0 = vector.extract %source[0] : vector<4xf32>
731+
// vector.store %slice_0, %base[%indices] : vector<4xf32>
732+
// %slice_1 = vector.extract %source[1] : vector<4xf32>
733+
// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
734+
// ...
710735
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
711736
UnrollStorePattern(MLIRContext *context,
712737
const vector::UnrollVectorOptions &options,

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,14 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
380380
// CHECK: return [[r3]] : vector<4x4xf32>
381381

382382

383-
// CHECK-LABEL: func.func @unroll_2D_vector_load(
383+
func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
384+
%c0 = arith.constant 0 : index
385+
%0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
386+
return %0 : vector<4x4xf16>
387+
}
388+
389+
// CHECK-LABEL: func.func @vector_load_2D(
384390
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
385-
func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
386391
// CHECK: %[[C3:.*]] = arith.constant 3 : index
387392
// CHECK: %[[C2:.*]] = arith.constant 2 : index
388393
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -397,14 +402,16 @@ func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
397402
// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
398403
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
399404
// CHECK: return %[[V7]] : vector<4x4xf16>
405+
406+
407+
func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
400408
%c0 = arith.constant 0 : index
401-
%0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
402-
return %0 : vector<4x4xf16>
409+
vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
410+
return
403411
}
404412

405-
// CHECK-LABEL: func.func @unroll_2D_vector_store(
413+
// CHECK-LABEL: func.func @vector_store_2D(
406414
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
407-
func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
408415
// CHECK: %[[C3:.*]] = arith.constant 3 : index
409416
// CHECK: %[[C2:.*]] = arith.constant 2 : index
410417
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -417,14 +424,16 @@ func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>
417424
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
418425
// CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
419426
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
420-
%c0 = arith.constant 0 : index
421-
vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
422-
return
427+
428+
429+
func.func @vector_load_4D_to_2D(%mem: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
430+
%c1 = arith.constant 1 : index
431+
%0 = vector.load %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
432+
return %0 : vector<2x2xf16>
423433
}
424434

425-
// CHECK-LABEL: func.func @unroll_vector_load(
435+
// CHECK-LABEL: func.func @vector_load_4D_to_2D(
426436
// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
427-
func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
428437
// CHECK: %[[C2:.*]] = arith.constant 2 : index
429438
// CHECK: %[[C1:.*]] = arith.constant 1 : index
430439
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
@@ -433,21 +442,18 @@ func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
433442
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
434443
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
435444
// CHECK: return %[[V3]] : vector<2x2xf16>
445+
446+
func.func @vector_store_2D_to_4D(%mem: memref<4x4x4x4xf16>, %v: vector<2x2xf16>) {
436447
%c1 = arith.constant 1 : index
437-
%0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
438-
return %0 : vector<2x2xf16>
448+
vector.store %v, %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
449+
return
439450
}
440451

441-
// CHECK-LABEL: func.func @unroll_vector_store(
452+
// CHECK-LABEL: func.func @vector_store_2D_to_4D(
442453
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
443-
func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
444454
// CHECK: %[[C2:.*]] = arith.constant 2 : index
445455
// CHECK: %[[C1:.*]] = arith.constant 1 : index
446456
// CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
447457
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
448458
// CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
449459
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
450-
%c1 = arith.constant 1 : index
451-
vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
452-
return
453-
}

0 commit comments

Comments
 (0)