Skip to content

Commit 1fd9c02

Browse files
authored
[mlir] Adopt cast function objects. NFC. (#168228)
These were added in #165803.
1 parent 6b4fef0 commit 1fd9c02

File tree

5 files changed

+11
-17
lines changed

5 files changed

+11
-17
lines changed

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,8 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
159159

160160
return wrap(DIExpressionAttr::get(
161161
unwrap(ctx),
162-
llvm::map_to_vector(
163-
unwrapList(nOperations, operations, attrStorage),
164-
[](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
162+
llvm::map_to_vector(unwrapList(nOperations, operations, attrStorage),
163+
llvm::CastTo<DIExpressionElemAttr>)));
165164
}
166165

167166
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@@ -202,7 +201,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
202201
cast<DIExpressionAttr>(unwrap(allocated)),
203202
cast<DIExpressionAttr>(unwrap(associated)),
204203
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
205-
[](Attribute a) { return cast<DINodeAttr>(a); })));
204+
llvm::CastTo<DINodeAttr>)));
206205
}
207206

208207
MlirAttribute mlirLLVMDIDerivedTypeAttrGet(
@@ -308,7 +307,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
308307
return wrap(DISubroutineTypeAttr::get(
309308
unwrap(ctx), callingConvention,
310309
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
311-
[](Attribute a) { return cast<DITypeAttr>(a); })));
310+
llvm::CastTo<DITypeAttr>)));
312311
}
313312

314313
MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) {
@@ -338,10 +337,10 @@ MlirAttribute mlirLLVMDISubprogramAttrGet(
338337
cast<DISubroutineTypeAttr>(unwrap(type)),
339338
llvm::map_to_vector(
340339
unwrapList(nRetainedNodes, retainedNodes, nodesStorage),
341-
[](Attribute a) { return cast<DINodeAttr>(a); }),
340+
llvm::CastTo<DINodeAttr>),
342341
llvm::map_to_vector(
343342
unwrapList(nAnnotations, annotations, annotationsStorage),
344-
[](Attribute a) { return cast<DINodeAttr>(a); })));
343+
llvm::CastTo<DINodeAttr>)));
345344
}
346345

347346
MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) {
@@ -398,7 +397,7 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet(
398397
cast<DINodeAttr>(unwrap(entity)), cast<DIFileAttr>(unwrap(file)), line,
399398
cast<StringAttr>(unwrap(name)),
400399
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
401-
[](Attribute a) { return cast<DINodeAttr>(a); })));
400+
llvm::CastTo<DINodeAttr>)));
402401
}
403402

404403
MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name,

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
551551
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
552552
assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
553553
auto tensorTypes =
554-
llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
555-
return llvm::cast<RankedTensorType>(type);
556-
}));
554+
llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
557555
int64_t concatRank = tensorTypes[0].getRank();
558556

559557
// The concatenation dim must be in the range [0, rank).

mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
14951495

14961496
template <typename T>
14971497
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
1498-
return llvm::to_vector(llvm::map_range(
1499-
range, [](transform::MappedValue value) { return cast<T>(value); }));
1498+
return llvm::map_to_vector(range, llvm::CastTo<T>);
15001499
}
15011500

15021501
void transform::detail::setApplyToOneResults(

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,8 +926,7 @@ static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
926926
SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
927927
rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
928928
getAsOpFoldResult(origOffsets));
929-
newCoods = llvm::to_vector(llvm::map_range(
930-
ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
929+
newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
931930
return newCoods;
932931
}
933932

mlir/lib/IR/TypeUtilities.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
118118
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
119119
/// dims are equal. The element type does not matter.
120120
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
121-
auto shapedTypes = llvm::map_to_vector<8>(
122-
types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
121+
auto shapedTypes = llvm::map_to_vector<8>(types, llvm::DynCastTo<ShapedType>);
123122
// Return failure if some, but not all are not shaped. Return early if none
124123
// are shaped also.
125124
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))

0 commit comments

Comments
 (0)