Skip to content

Commit c9853d6

Browse files
committed
[MLIR][Vector] Allow strided memref for one-element vector.load in lowering vector.gather
Signed-off-by: PragmaTwice <[email protected]>
1 parent 86b1b06 commit c9853d6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
205205
Value condMask = op.getMask();
206206
Value base = op.getBase();
207207

208-
// vector.load requires the most minor memref dim to have unit stride
208+
// vector.load requires the most minor memref dim to have unit stride,
209+
// or the result vector type to have only one element
209210
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
210211
if (auto stridesAttr =
211212
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
212-
if (stridesAttr.getStrides().back() != 1)
213+
if (stridesAttr.getStrides().back() != 1 &&
214+
resultTy.getNumElements() != 1)
213215
return failure();
214216
}
215217
}

0 commit comments

Comments
 (0)