Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2047,8 +2047,8 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru,
ConfinedAttr<OptionalAttr<I64Attr>,
Expand All @@ -2072,19 +2072,19 @@ def Vector_GatherOp :

```mlir
func.func @gather_3D_to_2D(
%base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
%index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
%base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
%indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
%fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
%result = vector.gather %base[%i0, %i1, %i2]
[%index_vec], %mask, %fall_thru : [...]
%result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
[%indices], %mask, %fall_thru : [...]
return %result : vector<2x3xf32>
}
```

The indexing semantics are then,

```
result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
else pass_thru[i,j]
```
The index into `base` only varies in the innermost ((k-1)-th) dimension.
Expand Down Expand Up @@ -2118,16 +2118,16 @@ def Vector_GatherOp :

let extraClassDeclaration = [{
ShapedType getBaseType() { return getBase().getType(); }
VectorType getIndexVectorType() { return getIndexVec().getType(); }
VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getPassThruVectorType() { return getPassThru().getType(); }
VectorType getVectorType() { return getResult().getType(); }
}];

let assemblyFormat =
"$base `[` $indices `]` `[` $index_vec `]` `,` "
"$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
"type($index_vec) `,` type($mask) `,` type($pass_thru) "
"type($indices) `,` type($mask) `,` type($pass_thru) "
"`into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand All @@ -2150,8 +2150,8 @@ def Vector_GatherOp :
def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
ConfinedAttr<OptionalAttr<I64Attr>,
Expand Down Expand Up @@ -2207,15 +2207,15 @@ def Vector_ScatterOp :

let extraClassDeclaration = [{
MemRefType getMemRefType() { return getBase().getType(); }
VectorType getIndexVectorType() { return getIndexVec().getType(); }
VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];

let assemblyFormat =
"$base `[` $indices `]` `[` $index_vec `]` `,` "
"$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
"type($indices) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;

Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,11 @@ class VectorGatherOpConversion

// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
adaptor.getBase(), adaptor.getOffsets());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
base, ptr, adaptor.getIndexVec(), vType);
base, ptr, adaptor.getIndices(), vType);

// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
Expand Down Expand Up @@ -362,10 +362,10 @@ class VectorScatterOpConversion

// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
adaptor.getBase(), adaptor.getOffsets());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
adaptor.getBase(), ptr, adaptor.getIndices(), vType);

// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5782,7 +5782,7 @@ LogicalResult GatherOp::verify() {

if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
if (llvm::size(getIndices()) != baseType.getRank())
if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
Expand Down Expand Up @@ -5854,11 +5854,11 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
if (!isa<MemRefType>(op.getBase().getType()))
return rewriter.notifyMatchFailure(op, "base must be of memref type");

if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();

rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
op.getIndices(), op.getMask(),
op.getOffsets(), op.getMask(),
op.getPassThru());
return success();
}
Expand All @@ -5882,7 +5882,7 @@ LogicalResult ScatterOp::verify() {

if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getIndices()) != memType.getRank())
if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
Expand Down Expand Up @@ -5917,11 +5917,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();

rewriter.replaceOpWithNewOp<MaskedStoreOp>(
op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct GatherOpInterface
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
gatherOp.getPassThru());
return success();
}
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {

LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
Value indexVec = op.getIndexVec();
Value indexVec = op.getIndices();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();

Expand All @@ -69,7 +69,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
op.getIndices(), indexSubVec, maskSubVec,
op.getOffsets(), indexSubVec, maskSubVec,
passThruSubVec);
};

Expand Down Expand Up @@ -141,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
VectorType vType = op.getIndexVec().getType();
VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

Value newIdxs =
arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);

// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = vector::GatherOp::create(
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);

return success();
Expand Down Expand Up @@ -195,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {

Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndexVec());
auto baseOffsets = llvm::to_vector(op.getIndices());
op.getIndices());
auto baseOffsets = llvm::to_vector(op.getOffsets());
Value lastBaseOffset = baseOffsets.back();

Value result = op.getPassThru();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
// Replace the `vector.mask` operation.
rewriter.replaceOpWithNewOp<GatherOp>(
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
passthru);
return success();
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,15 +640,15 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// decomposed shape from each of the index, mask, and pass-through
// vectors.
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
Value passThruSubVec =
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
auto slicedGather = vector::GatherOp::create(
rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
indexSubVec, maskSubVec, passThruSubVec);

result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
Expand Down
Loading