Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();

// vector.load requires the most minor memref dim to have unit stride
// vector.load requires the most minor memref dim to have unit stride,
// or the result vector type to have only one element
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1)
if (stridesAttr.getStrides().back() != 1 &&
resultTy.getNumElements() != 1)
return failure();
}
}
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Vector/vector-gather-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
return %0 : vector<2xf32>
}

// CHECK-LABEL: @gather_strided_memref_1d
// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
// CHECK: %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
// CHECK: scf.yield %[[RES]] : vector<1xf32>
// CHECK: } else {
// CHECK: scf.yield %arg3 : vector<1xf32>
// CHECK: }
// CHECK: return %[[RET]] : vector<1xf32>
func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
%c0 = arith.constant 0 : index
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
return %0 : vector<1xf32>
}

// CHECK-LABEL: @gather_tensor_2d
// CHECK: scf.if
// CHECK: tensor.extract
Expand Down
Loading