Skip to content

Commit 1287ed1

Browse files
authored
[mlir][vector] Use source as the source argument name (llvm#158258)
This patch updates the following ops to use `source` (instead of `vector`) as the name for their source argument: * `vector.extract` * `vector.scalable.extract` * `vector.extract_strided_slice` This change ensures naming consistency with the "builders" for these Ops that already use the name `source` rather than `vector`. It also addresses part of: * llvm#131602 Specifically, it ensures that we use `source` and `dest` for read and write operations, respectively (as opposed to `vector` and `dest`).
1 parent 04cd39a commit 1287ed1

File tree

18 files changed

+87
-74
lines changed

18 files changed

+87
-74
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ def Vector_ExtractOp :
675675
}];
676676

677677
let arguments = (ins
678-
AnyVectorOfAnyRank:$vector,
678+
AnyVectorOfAnyRank:$source,
679679
Variadic<Index>:$dynamic_position,
680680
DenseI64ArrayAttr:$static_position
681681
);
@@ -692,7 +692,7 @@ def Vector_ExtractOp :
692692

693693
let extraClassDeclaration = extraPoisonClassDeclaration # [{
694694
VectorType getSourceVectorType() {
695-
return ::llvm::cast<VectorType>(getVector().getType());
695+
return ::llvm::cast<VectorType>(getSource().getType());
696696
}
697697

698698
/// Return a vector with all the static and dynamic position indices.
@@ -709,12 +709,17 @@ def Vector_ExtractOp :
709709
bool hasDynamicPosition() {
710710
return !getDynamicPosition().empty();
711711
}
712+
713+
/// Wrapper for getSource, which replaced getVector.
714+
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
715+
return getSource();
716+
}
712717
}];
713718

714719
let assemblyFormat = [{
715-
$vector ``
720+
$source ``
716721
custom<DynamicIndexList>($dynamic_position, $static_position)
717-
attr-dict `:` type($result) `from` type($vector)
722+
attr-dict `:` type($result) `from` type($source)
718723
}];
719724

720725
let hasCanonicalizer = 1;
@@ -1023,6 +1028,10 @@ def Vector_ScalableExtractOp :
10231028
VectorType getResultVectorType() {
10241029
return ::llvm::cast<VectorType>(getResult().getType());
10251030
}
1031+
/// Wrapper for getSource, which replaced getVector.
1032+
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
1033+
return getSource();
1034+
}
10261035
}];
10271036
}
10281037

@@ -1174,7 +1183,7 @@ def Vector_ExtractStridedSliceOp :
11741183
Vector_Op<"extract_strided_slice", [Pure,
11751184
PredOpTrait<"operand and result have same element type",
11761185
TCresVTEtIsSameAsOpBase<0, 0>>]>,
1177-
Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
1186+
Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
11781187
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
11791188
Results<(outs AnyVectorOfNonZeroRank)> {
11801189
let summary = "extract_strided_slice operation";
@@ -1209,19 +1218,23 @@ def Vector_ExtractStridedSliceOp :
12091218
];
12101219
let extraClassDeclaration = [{
12111220
VectorType getSourceVectorType() {
1212-
return ::llvm::cast<VectorType>(getVector().getType());
1221+
return ::llvm::cast<VectorType>(getSource().getType());
12131222
}
12141223
void getOffsets(SmallVectorImpl<int64_t> &results);
12151224
bool hasNonUnitStrides() {
12161225
return llvm::any_of(getStrides(), [](Attribute attr) {
12171226
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
12181227
});
12191228
}
1229+
/// Wrapper for getSource, which replaced getVector.
1230+
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
1231+
return getSource();
1232+
}
12201233
}];
12211234
let hasCanonicalizer = 1;
12221235
let hasFolder = 1;
12231236
let hasVerifier = 1;
1224-
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
1237+
let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
12251238
}
12261239

12271240
// TODO: Tighten semantics so that masks and inbounds can't be used

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
462462
auto loc = extractOp.getLoc();
463463
auto position = extractOp.getMixedPosition();
464464

465-
Value sourceVector = extractOp.getVector();
465+
Value sourceVector = extractOp.getSource();
466466

467467
// Extract entire vector. Should be handled by folder, but just to be safe.
468468
if (position.empty()) {
@@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
692692
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
693693

694694
auto createMaskOp =
695-
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
695+
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
696696
if (!createMaskOp)
697697
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
698698

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
962962
return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
963963

964964
// Find the vector.transer_read whose result vector is being sliced.
965-
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
965+
auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
966966
if (!transferReadOp)
967967
return rewriter.notifyMatchFailure(op, "no transfer read");
968968

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ class VectorExtractOpConversion
11311131
positionVec.push_back(rewriter.getZeroAttr(idxType));
11321132
}
11331133

1134-
Value extracted = adaptor.getVector();
1134+
Value extracted = adaptor.getSource();
11351135
if (extractsAggregate) {
11361136
ArrayRef<OpFoldResult> position(positionVec);
11371137
if (extractsScalar) {

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
14141414
/// Return the vector from which newly generated ExtracOps will extract.
14151415
Value getDataVector(TransferWriteOp xferOp) const {
14161416
if (auto extractOp = getExtractOp(xferOp))
1417-
return extractOp.getVector();
1417+
return extractOp.getSource();
14181418
return xferOp.getVector();
14191419
}
14201420

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
189189
if (!dstType)
190190
return failure();
191191

192-
if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
193-
rewriter.replaceOp(extractOp, adaptor.getVector());
192+
if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
193+
rewriter.replaceOp(extractOp, adaptor.getSource());
194194
return success();
195195
}
196196

@@ -201,15 +201,15 @@ struct VectorExtractOpConvert final
201201
extractOp,
202202
"Static use of poison index handled elsewhere (folded to poison)");
203203
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
204-
extractOp, dstType, adaptor.getVector(),
204+
extractOp, dstType, adaptor.getSource(),
205205
rewriter.getI32ArrayAttr(id.value()));
206206
} else {
207207
Value sanitizedIndex = sanitizeDynamicIndex(
208208
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
209209
vector::ExtractOp::kPoisonIndex,
210210
extractOp.getSourceVectorType().getNumElements());
211211
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
212-
extractOp, dstType, adaptor.getVector(), sanitizedIndex);
212+
extractOp, dstType, adaptor.getSource(), sanitizedIndex);
213213
}
214214
return success();
215215
}

mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
445445
return rewriter.notifyMatchFailure(
446446
extractOp, "extracted type is not a 1-D scalable vector type");
447447

448-
auto *extendOp = extractOp.getVector().getDefiningOp();
448+
auto *extendOp = extractOp.getSource().getDefiningOp();
449449
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
450450
extendOp))
451451
return rewriter.notifyMatchFailure(extractOp,

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
542542
PatternRewriter &rewriter) const override {
543543
auto loc = extractOp.getLoc();
544544
auto createMaskOp =
545-
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
545+
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
546546
if (!createMaskOp)
547547
return rewriter.notifyMatchFailure(
548548
extractOp, "extract not from vector.create_mask op");

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
105105
return WalkResult::advance();
106106

107107
// Check that the vector to extract from is a BlockArgument.
108-
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
108+
auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
109109
if (!blockArg)
110110
return WalkResult::advance();
111111

@@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
141141
return WalkResult::advance();
142142

143143
rewriter.modifyOpInPlace(broadcast, [&] {
144-
extractOp.getVectorMutable().assign(initArg->get());
144+
extractOp.getSourceMutable().assign(initArg->get());
145145
});
146146
loop.moveOutOfLoop(extractOp);
147147
rewriter.moveOpAfter(broadcast, loop);

mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
7171
if (auto extractOp =
7272
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
7373
if (auto maskOp =
74-
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
74+
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
7575
return TransferMask{maskOp,
7676
SmallVector<int64_t>(extractOp.getStaticPosition())};
7777

0 commit comments

Comments
 (0)