-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Intrinsics] Add new MLIR api to automatically resolve overload types #168188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MLIR][Intrinsics] Add new MLIR api to automatically resolve overload types #168188
Conversation
|
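@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-llvm-ir Author: Rajat Bajpai (rajatbajpai) Changes: Currently, the Intrinsic::getOrInsertDeclaration API requires callers to supply the overload types of an overloaded intrinsic themselves. This patch introduces a new getOrInsertDeclaration overload that resolves the overload types automatically from the provided return type and argument types. This patch introduces two levels of convenience APIs to eliminate manual overload detection: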
Key changes:
Full diff: https://github.com/llvm/llvm-project/pull/168188.diff 5 Files Affected:
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 9577d0141f168..ffcc4677830ca 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -104,6 +104,21 @@ namespace Intrinsic {
LLVM_ABI Function *getOrInsertDeclaration(Module *M, ID id,
ArrayRef<Type *> Tys = {});
+ /// Look up the Function declaration of the intrinsic \p IID in the Module
+ /// \p M. If it does not exist, add a declaration and return it. Otherwise,
+ /// return the existing declaration.
+ ///
+ /// This overload automatically resolves overloaded intrinsics based on the
+ /// provided return type and argument types. For non-overloaded intrinsics,
+ /// the return type and argument types are ignored.
+ ///
+ /// \param M - The module to get or insert the intrinsic declaration.
+ /// \param IID - The intrinsic ID.
+ /// \param RetTy - The return type of the intrinsic.
+ /// \param ArgTys - The argument types of the intrinsic.
+ LLVM_ABI Function *getOrInsertDeclaration(Module *M, ID IID, Type *RetTy,
+ ArrayRef<Type *> ArgTys);
+
/// Look up the Function declaration of the intrinsic \p id in the Module
/// \p M and return it if it exists. Otherwise, return nullptr. This version
/// supports non-overloaded intrinsics.
diff --git a/llvm/lib/IR/Intrinsics.cpp b/llvm/lib/IR/Intrinsics.cpp
index 526800e217399..61c83be82a917 100644
--- a/llvm/lib/IR/Intrinsics.cpp
+++ b/llvm/lib/IR/Intrinsics.cpp
@@ -720,14 +720,14 @@ Intrinsic::ID Intrinsic::lookupIntrinsicID(StringRef Name) {
#include "llvm/IR/IntrinsicImpl.inc"
#undef GET_INTRINSIC_ATTRIBUTES
-Function *Intrinsic::getOrInsertDeclaration(Module *M, ID id,
- ArrayRef<Type *> Tys) {
- // There can never be multiple globals with the same name of different types,
- // because intrinsics must be a specific type.
- auto *FT = getType(M->getContext(), id, Tys);
+static Function *getOrInsertIntrinsicDeclarationImpl(Module *M,
+ Intrinsic::ID id,
+ ArrayRef<Type *> Tys,
+ FunctionType *FT) {
Function *F = cast<Function>(
- M->getOrInsertFunction(
- Tys.empty() ? getName(id) : getName(id, Tys, M, FT), FT)
+ M->getOrInsertFunction(Tys.empty() ? Intrinsic::getName(id)
+ : Intrinsic::getName(id, Tys, M, FT),
+ FT)
.getCallee());
if (F->getFunctionType() == FT)
return F;
@@ -739,11 +739,43 @@ Function *Intrinsic::getOrInsertDeclaration(Module *M, ID id,
// invalid declaration will get upgraded later.
F->setName(F->getName() + ".invalid");
return cast<Function>(
- M->getOrInsertFunction(
- Tys.empty() ? getName(id) : getName(id, Tys, M, FT), FT)
+ M->getOrInsertFunction(Tys.empty() ? Intrinsic::getName(id)
+ : Intrinsic::getName(id, Tys, M, FT),
+ FT)
.getCallee());
}
+Function *Intrinsic::getOrInsertDeclaration(Module *M, ID id,
+ ArrayRef<Type *> Tys) {
+ // There can never be multiple globals with the same name of different types,
+ // because intrinsics must be a specific type.
+ FunctionType *FT = getType(M->getContext(), id, Tys);
+ return getOrInsertIntrinsicDeclarationImpl(M, id, Tys, FT);
+}
+
+Function *Intrinsic::getOrInsertDeclaration(Module *M, ID IID, Type *RetTy,
+ ArrayRef<Type *> ArgTys) {
+ // If the intrinsic is not overloaded, use the non-overloaded version.
+ if (!Intrinsic::isOverloaded(IID))
+ return getOrInsertDeclaration(M, IID);
+
+ // Get the intrinsic signature metadata.
+ SmallVector<Intrinsic::IITDescriptor, 8> Table;
+ getIntrinsicInfoTableEntries(IID, Table);
+ ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
+
+ FunctionType *FTy = FunctionType::get(RetTy, ArgTys, /*isVarArg=*/false);
+
+ // Automatically determine the overloaded types.
+ SmallVector<Type *, 4> OverloadTys;
+ [[maybe_unused]] Intrinsic::MatchIntrinsicTypesResult Res =
+ matchIntrinsicSignature(FTy, TableRef, OverloadTys);
+ assert(Res == Intrinsic::MatchIntrinsicTypes_Match && TableRef.empty() &&
+ "intrinsic signature mismatch");
+
+ return getOrInsertIntrinsicDeclarationImpl(M, IID, OverloadTys, FTy);
+}
+
Function *Intrinsic::getDeclarationIfExists(const Module *M, ID id) {
return M->getFunction(getName(id));
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 995ade5c9b033..5747021be189c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3176,12 +3176,7 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch",
let llvmBuilder = [{
auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op,
moduleTranslation, builder);
-
- if(op.getTensormap())
- // Overloaded intrinsic
- createIntrinsicCall(builder, id, args, {args[0]->getType()});
- else
- createIntrinsicCall(builder, id, args);
+ createIntrinsicCall(builder, id, builder.getVoidTy(), args);
}];
}
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index eb7dfa7637e52..039ac8e2e1911 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -512,6 +512,15 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder,
ArrayRef<llvm::Value *> args = {},
ArrayRef<llvm::Type *> tys = {});
+/// Creates a call to an LLVM IR intrinsic function with the given return type
+/// and arguments. If the intrinsic is overloaded, the function signature will
+/// be automatically resolved based on the provided return type and argument
+/// types.
+llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder,
+ llvm::Intrinsic::ID intrinsic,
+ llvm::Type *retTy,
+ ArrayRef<llvm::Value *> args);
+
/// Creates a call to a LLVM IR intrinsic defined by LLVM_IntrOpBase. This
/// resolves the overloads, and maps mixed MLIR value and attribute arguments to
/// LLVM values.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 64e3c5f085bb3..e48e841dae18a 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -898,6 +898,21 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(fn, args);
}
+llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
+ llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
+ llvm::Type *retTy, ArrayRef<llvm::Value *> args) {
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+
+ SmallVector<llvm::Type *> argTys;
+ argTys.reserve(args.size());
+ for (llvm::Value *arg : args)
+ argTys.push_back(arg->getType());
+
+ llvm::Function *fn =
+ llvm::Intrinsic::getOrInsertDeclaration(module, intrinsic, retTy, argTys);
+ return builder.CreateCall(fn, args);
+}
+
llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
|
|
@joker-eph, @ftynse Could you please help with a review (mainly of the MLIR Module Translation updates)? |
🐧 Linux x64 Test Results
|
| llvm::Type *retTy, ArrayRef<llvm::Value *> args) { | ||
| llvm::Module *module = builder.GetInsertBlock()->getModule(); | ||
|
|
||
| SmallVector<llvm::Type *> argTys; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why not just call builder.CreateIntrinsic() here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it fair to say that in the current state, this PR is moving the overload type resolution logic from IRBuilder::CreateIntrinsic() to Intrinsic::getOrInsertDeclaration()? In that case, I'd think most external folks should be using the IRBuilder API to create intrinsic calls, so this change is not visible to them. In which case, what is the motivation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, it is more of a use case than the motivation on the LLVM side. I did find one example in VPlanRecipes.cpp (shown below), but I don’t have enough familiarity with that part of the codebase to confidently refactor it. The modification worked correctly in local testing, but I’m not sure it’s a good idea to merge it alongside this change.
Additionally, the earlier CreateIntrinsic API was suboptimal: it computed the overloaded type, which was then decoded again inside getOrInsertDeclaration. The new API avoids this redundant computation.
nit: why not just call builder.CreateIntrinsic() here?
That’s a good point. Technically, yes, we could call builder.CreateIntrinsic() here. However, I'm not sure whether there were any other reasons for not using the LLVM API here. Perhaps @ftynse might have more insight as the original author of this code.
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -30,6 +30,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
+#include "llvm/IR/VectorTypeUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -1763,11 +1764,8 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent,
void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
assert(State.VF.isVector() && "not widening");
- SmallVector<Type *, 2> TysForDecl;
- // Add return type if intrinsic is overloaded on it.
- if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1, State.TTI))
- TysForDecl.push_back(VectorType::get(getResultType(), State.VF));
SmallVector<Value *, 4> Args;
+ SmallVector<Type *, 4> ArgTys;
for (const auto &I : enumerate(operands())) {
// Some intrinsics have a scalar argument - don't replace it with a
// vector.
@@ -1777,16 +1775,17 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
Arg = State.get(I.value(), VPLane(0));
else
Arg = State.get(I.value(), usesFirstLaneOnly(I.value()));
- if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index(),
- State.TTI))
- TysForDecl.push_back(Arg->getType());
Args.push_back(Arg);
+ ArgTys.push_back(Arg->getType());
}
// Use vector version of the intrinsic.
Module *M = State.Builder.GetInsertBlock()->getModule();
+
+ Type *RetTy = toVectorizedTy(getResultType(), State.VF);
+
Function *VectorF =
- Intrinsic::getOrInsertDeclaration(M, VectorIntrinsicID, TysForDecl);
+ Intrinsic::getOrInsertDeclaration(M, VectorIntrinsicID, RetTy, ArgTys);
assert(VectorF &&
"Can't retrieve vector intrinsic or vector-predication intrinsics.");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I think we should at least split this PR into two. One is the MLIR side that adds the new createIntrinsicCall() that takes the return type and args and just calls Builder.CreateIntrinsic() (the CreateIntrinsic(Type *RetTy, Intrinsic::ID ID, ArrayRef<Value *> Args) variant), as that seems like an independent change.
If you want, at the same time, we can change the existing createIntrinsicCall on MLIR side to also just use the same Builder.CreateIntrinsic API that hides all the overload type deduction inside and remove that code from createIntrinsicCall. It seems that can be done independently of any LLVM side changes.
Then, we can have another PR to add this ability to LLVM's Intrinsic:: API. Does that make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, sounds reasonable. I'll update this PR with MLIR only changes and create a separate PR for LLVM.
…olution Add createIntrinsicCall overload that accepts return type and arguments, automatically resolve overload types rather than requiring manual computation. Simplifies NVVM_PrefetchOp by removing conditional overload logic.
d1bf457 to
b6847cd
Compare
|
Just wanted to +1 for the idea of this patch |
Add createIntrinsicCall overload that accepts return type and arguments, automatically resolve overload types rather than requiring manual computation. Simplifies NVVM_PrefetchOp by removing conditional overload logic.