-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][LLVM] Add builders for llvm.intr.assume #113317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-sme Author: Sirui Mu (Lancern) ChangesThis PR adds two intrinsic operations, namely This PR also adds a new builder to Full diff: https://github.com/llvm/llvm-project/pull/113317.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32bcaad03..6ea3c9f2e1c7ba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
/*int numResults=*/numResults,
+ /*bit enableMlirBuilder=*/1,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d236cae0d80882..cf721f936cc932 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -33,6 +33,7 @@
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/ADT/PointerEmbeddedInt.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 845c88b1be7750..8bbd2b9053e160 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
}];
let builders = [
- OpBuilder<(ins "Value":$cond)>
+ OpBuilder<(ins "Value":$cond)>,
+ OpBuilder<(ins "Value":$cond,
+ "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
];
let hasVerifier = 1;
}
+class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
+ string opBundleTag>
+ : LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
+ /*overloadedOperands=*/[], /*traits=*/[],
+ /*numResults=*/0, /*enumName=*/"assume",
+ /*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
+ /*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[],
+ /*opBundleOperandPositions=*/[opBundleOperandPositions],
+ /*opBundleTags=*/[opBundleTag]> {
+ dag args = (ins I1:$cond);
+}
+
+def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
+ let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
+
+ let assemblyFormat = [{
+ $cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def LLVM_AssumeSeparateStorageOp
+ : LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
+ let arguments = !con(
+ args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
+
+ let assemblyFormat = [{
+ $cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
[Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyType:$operand);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a38dafa4d9cf34..9f6acbcd3c5104 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
- bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
- bit requiresFastmath = 0, bit requiresOpBundles = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
+ bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+ bit requiresOpBundles = 0, list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
: LLVM_OpBase<dialect, opName, !listconcat(
!if(!gt(requiresAccessGroup, 0),
[DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
"StringLiteral(\"" # name # "\")"), ", ") # "}";
+ string opBundleOperandPositionsCpp = "{" # !interleave(
+ !foreach(positions, opBundleOperandPositions,
+ "ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
+ ), ", ") # "}";
+ string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
+ "StringLiteral(\"" # tag # "\")"), ", ") # "}";
string baseLlvmBuilder = [{
auto *inst = LLVM::detail::createIntrinsicCall(
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
- immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
+ immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
+ opBundleTagsCpp], ",") # [{);
(void) inst;
}];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
$_location, resultTypes, mlirOperands, mlirAttrs);
}];
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
- let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+ let mlirBuilder = !if(enableMlirBuilder,
+ baseMlirBuilder # !if(!gt(requiresFastmath, 0),
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
- # baseMlirBuilderCoda;
+ # baseMlirBuilderCoda, "");
// Code for handling a `range` attribute that holds the constant range of the
// intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
- int numResults, bit requiresAccessGroup = 0,
+ int numResults, string enumName = "",
+ bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
- bit requiresOpBundles = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
- : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
+ bit requiresOpBundles = 0, list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
+ : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
+ !if(!empty(enumName), !subst(".", "_", mnem), enumName),
overloadedResults, overloadedOperands, traits,
- numResults, requiresAccessGroup, requiresAliasAnalysis,
- requiresFastmath, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ numResults, enableMlirBuilder, requiresAccessGroup,
+ requiresAliasAnalysis, requiresFastmath,
+ requiresOpBundles, immArgPositions, immArgAttrNames,
+ opBundleOperandPositions, opBundleTags>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
bit requiresAliasAnalysis = 0,
bit requiresOpBundles = 0,
list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
+ /*enumName=*/"", /*enableMlirBuilder=*/1,
requiresAccessGroup, requiresAliasAnalysis,
/*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ immArgAttrNames, opBundleOperandPositions, opBundleTags>;
// Base class for LLVM intrinsic operations returning one result. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+ /*enumName=*/"", /*enableMlirBuilder=*/1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3695708439d91f..7c204c99525ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
list<string> immArgAttrNames = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
- overloadedOperands, traits, numResults, requiresAccessGroup,
+ overloadedOperands, traits, numResults, 1, requiresAccessGroup,
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0c7e22f8c65596 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
- ArrayRef<unsigned> immArgPositions,
- ArrayRef<StringLiteral> immArgAttrNames);
+ ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+ ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+ ArrayRef<StringLiteral> opBundleTags);
} // namespace detail
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc73878a64ff67..f558cf23411ed6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
/*op_bundle_tags=*/{});
}
+void LLVM::AssumeOp::build(
+ OpBuilder &builder, OperationState &state, mlir::Value cond,
+ ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
+ SmallVector<mlir::ValueRange> opBundleOperands;
+ SmallVector<mlir::Attribute> opBundleTags;
+ opBundleOperands.reserve(opBundles.size());
+ opBundleTags.reserve(opBundles.size());
+
+ for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
+ opBundleOperands.emplace_back(bundle.inputs());
+ opBundleTags.push_back(
+ StringAttr::get(builder.getContext(), bundle.getTag()));
+ }
+
+ auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
+ return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
+}
+
LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..de493891ed7e4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
- ArrayRef<unsigned> immArgPositions,
- ArrayRef<StringLiteral> immArgAttrNames) {
+ ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+ ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+ ArrayRef<StringLiteral> opBundleTags) {
assert(immArgPositions.size() == immArgAttrNames.size() &&
"LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
"length");
+ assert(opBundleOperandPositions.size() == opBundleTags.size() &&
+ "operand bundles and tags do not match");
SmallVector<llvm::OperandBundleDef> opBundles;
- size_t numOpBundleOperands = 0;
+
+ size_t numVariadicOpBundleOperands = 0;
auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
-
if (opBundleSizesAttr && opBundleTagsAttr) {
ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
"operand bundles and tags do not match");
- numOpBundleOperands =
+ numVariadicOpBundleOperands =
std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
- assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+ assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
"operand bundle operands is more than the number of operands");
- ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+ ValueRange operands =
+ intrOp->getOperands().take_back(numVariadicOpBundleOperands);
size_t nextOperandIdx = 0;
opBundles.reserve(opBundleSizesAttr.size());
@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
}
// Map operands and attributes to LLVM values.
- auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+ auto opOperands =
+ intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
auto operands = moduleTranslation.lookupValues(opOperands);
- SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+
+ // Map operand bundle operands to LLVM operand bundles.
+ DenseSet<unsigned> opBundleOperandPositionsSet;
+ for (auto [positions, tag] :
+ llvm::zip(opBundleOperandPositions, opBundleTags)) {
+ opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
+
+ SmallVector<llvm::Value *> bundleArgs;
+ bundleArgs.reserve(positions.size());
+ for (unsigned idx : positions) {
+ assert(idx < operands.size() &&
+ "op bundle operand index is out of range");
+ bundleArgs.push_back(operands[idx]);
+ }
+
+ opBundles.emplace_back(tag.str(), std::move(bundleArgs));
+ }
+
+ SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
+ opBundleOperandPositionsSet.size());
for (auto [immArgPos, immArgName] :
llvm::zip(immArgPositions, immArgAttrNames)) {
auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
}
unsigned opArg = 0;
for (auto &arg : args) {
- if (!arg)
+ if (!arg) {
+ while (opBundleOperandPositionsSet.contains(opArg))
+ ++opArg;
arg = operands[opArg++];
+ }
}
// Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
module, intrinsic, overloadedTypes);
+ llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
+ << ", num op bundles = " << opBundles.size() << "\n";
return builder.CreateCall(llvmIntr, args, opBundles);
}
diff --git a/mlir/test/Dialect/LLVMIR/assume.mlir b/mlir/test/Dialect/LLVMIR/assume.mlir
new file mode 100644
index 00000000000000..4cf43b4828010f
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/assume.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @assume_align
+// CHECK-SAME: (ptr %[[ARG:.+]])
+llvm.func @assume_align(%arg0: !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ %1 = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
+ llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: @assume_separate_storage
+// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
+llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ // CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
+ llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 411a98a48bfb28..6fc3e989074937 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
- os << ", " << intr.getNumResults() << ", "
+ os << ", " << intr.getNumResults() << ", \"\", 1, "
<< (requiresAccessGroup ? "1" : "0") << ", "
<< (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
|
|
@llvm/pr-subscribers-mlir-core Author: Sirui Mu (Lancern) ChangesThis PR adds two intrinsic operations, namely This PR also adds a new builder to Full diff: https://github.com/llvm/llvm-project/pull/113317.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32bcaad03..6ea3c9f2e1c7ba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
/*int numResults=*/numResults,
+ /*bit enableMlirBuilder=*/1,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d236cae0d80882..cf721f936cc932 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -33,6 +33,7 @@
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/ADT/PointerEmbeddedInt.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 845c88b1be7750..8bbd2b9053e160 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
}];
let builders = [
- OpBuilder<(ins "Value":$cond)>
+ OpBuilder<(ins "Value":$cond)>,
+ OpBuilder<(ins "Value":$cond,
+ "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
];
let hasVerifier = 1;
}
+class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
+ string opBundleTag>
+ : LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
+ /*overloadedOperands=*/[], /*traits=*/[],
+ /*numResults=*/0, /*enumName=*/"assume",
+ /*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
+ /*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[],
+ /*opBundleOperandPositions=*/[opBundleOperandPositions],
+ /*opBundleTags=*/[opBundleTag]> {
+ dag args = (ins I1:$cond);
+}
+
+def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
+ let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
+
+ let assemblyFormat = [{
+ $cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def LLVM_AssumeSeparateStorageOp
+ : LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
+ let arguments = !con(
+ args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
+
+ let assemblyFormat = [{
+ $cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
[Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyType:$operand);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a38dafa4d9cf34..9f6acbcd3c5104 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
- bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
- bit requiresFastmath = 0, bit requiresOpBundles = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
+ bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+ bit requiresOpBundles = 0, list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
: LLVM_OpBase<dialect, opName, !listconcat(
!if(!gt(requiresAccessGroup, 0),
[DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
"StringLiteral(\"" # name # "\")"), ", ") # "}";
+ string opBundleOperandPositionsCpp = "{" # !interleave(
+ !foreach(positions, opBundleOperandPositions,
+ "ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
+ ), ", ") # "}";
+ string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
+ "StringLiteral(\"" # tag # "\")"), ", ") # "}";
string baseLlvmBuilder = [{
auto *inst = LLVM::detail::createIntrinsicCall(
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
- immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
+ immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
+ opBundleTagsCpp], ",") # [{);
(void) inst;
}];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
$_location, resultTypes, mlirOperands, mlirAttrs);
}];
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
- let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+ let mlirBuilder = !if(enableMlirBuilder,
+ baseMlirBuilder # !if(!gt(requiresFastmath, 0),
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
- # baseMlirBuilderCoda;
+ # baseMlirBuilderCoda, "");
// Code for handling a `range` attribute that holds the constant range of the
// intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
- int numResults, bit requiresAccessGroup = 0,
+ int numResults, string enumName = "",
+ bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
- bit requiresOpBundles = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
- : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
+ bit requiresOpBundles = 0, list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
+ : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
+ !if(!empty(enumName), !subst(".", "_", mnem), enumName),
overloadedResults, overloadedOperands, traits,
- numResults, requiresAccessGroup, requiresAliasAnalysis,
- requiresFastmath, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ numResults, enableMlirBuilder, requiresAccessGroup,
+ requiresAliasAnalysis, requiresFastmath,
+ requiresOpBundles, immArgPositions, immArgAttrNames,
+ opBundleOperandPositions, opBundleTags>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
bit requiresAliasAnalysis = 0,
bit requiresOpBundles = 0,
list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
+ /*enumName=*/"", /*enableMlirBuilder=*/1,
requiresAccessGroup, requiresAliasAnalysis,
/*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ immArgAttrNames, opBundleOperandPositions, opBundleTags>;
// Base class for LLVM intrinsic operations returning one result. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+ /*enumName=*/"", /*enableMlirBuilder=*/1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3695708439d91f..7c204c99525ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
list<string> immArgAttrNames = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
- overloadedOperands, traits, numResults, requiresAccessGroup,
+ overloadedOperands, traits, numResults, 1, requiresAccessGroup,
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0c7e22f8c65596 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
- ArrayRef<unsigned> immArgPositions,
- ArrayRef<StringLiteral> immArgAttrNames);
+ ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+ ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+ ArrayRef<StringLiteral> opBundleTags);
} // namespace detail
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc73878a64ff67..f558cf23411ed6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
/*op_bundle_tags=*/{});
}
+void LLVM::AssumeOp::build(
+ OpBuilder &builder, OperationState &state, mlir::Value cond,
+ ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
+ SmallVector<mlir::ValueRange> opBundleOperands;
+ SmallVector<mlir::Attribute> opBundleTags;
+ opBundleOperands.reserve(opBundles.size());
+ opBundleTags.reserve(opBundles.size());
+
+ for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
+ opBundleOperands.emplace_back(bundle.inputs());
+ opBundleTags.push_back(
+ StringAttr::get(builder.getContext(), bundle.getTag()));
+ }
+
+ auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
+ return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
+}
+
LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..de493891ed7e4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
- ArrayRef<unsigned> immArgPositions,
- ArrayRef<StringLiteral> immArgAttrNames) {
+ ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+ ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+ ArrayRef<StringLiteral> opBundleTags) {
assert(immArgPositions.size() == immArgAttrNames.size() &&
"LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
"length");
+ assert(opBundleOperandPositions.size() == opBundleTags.size() &&
+ "operand bundles and tags do not match");
SmallVector<llvm::OperandBundleDef> opBundles;
- size_t numOpBundleOperands = 0;
+
+ size_t numVariadicOpBundleOperands = 0;
auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
-
if (opBundleSizesAttr && opBundleTagsAttr) {
ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
"operand bundles and tags do not match");
- numOpBundleOperands =
+ numVariadicOpBundleOperands =
std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
- assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+ assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
"operand bundle operands is more than the number of operands");
- ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+ ValueRange operands =
+ intrOp->getOperands().take_back(numVariadicOpBundleOperands);
size_t nextOperandIdx = 0;
opBundles.reserve(opBundleSizesAttr.size());
@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
}
// Map operands and attributes to LLVM values.
- auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+ auto opOperands =
+ intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
auto operands = moduleTranslation.lookupValues(opOperands);
- SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+
+ // Map operand bundle operands to LLVM operand bundles.
+ DenseSet<unsigned> opBundleOperandPositionsSet;
+ for (auto [positions, tag] :
+ llvm::zip(opBundleOperandPositions, opBundleTags)) {
+ opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
+
+ SmallVector<llvm::Value *> bundleArgs;
+ bundleArgs.reserve(positions.size());
+ for (unsigned idx : positions) {
+ assert(idx < operands.size() &&
+ "op bundle operand index is out of range");
+ bundleArgs.push_back(operands[idx]);
+ }
+
+ opBundles.emplace_back(tag.str(), std::move(bundleArgs));
+ }
+
+ SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
+ opBundleOperandPositionsSet.size());
for (auto [immArgPos, immArgName] :
llvm::zip(immArgPositions, immArgAttrNames)) {
auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
}
unsigned opArg = 0;
for (auto &arg : args) {
- if (!arg)
+ if (!arg) {
+ while (opBundleOperandPositionsSet.contains(opArg))
+ ++opArg;
arg = operands[opArg++];
+ }
}
// Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
module, intrinsic, overloadedTypes);
+ llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
+ << ", num op bundles = " << opBundles.size() << "\n";
return builder.CreateCall(llvmIntr, args, opBundles);
}
diff --git a/mlir/test/Dialect/LLVMIR/assume.mlir b/mlir/test/Dialect/LLVMIR/assume.mlir
new file mode 100644
index 00000000000000..4cf43b4828010f
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/assume.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @assume_align
+// CHECK-SAME: (ptr %[[ARG:.+]])
+llvm.func @assume_align(%arg0: !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ %1 = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
+ llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: @assume_separate_storage
+// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
+llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ // CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
+ llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 411a98a48bfb28..6fc3e989074937 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
- os << ", " << intr.getNumResults() << ", "
+ os << ", " << intr.getNumResults() << ", \"\", 1, "
<< (requiresAccessGroup ? "1" : "0") << ", "
<< (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
|
|
@llvm/pr-subscribers-mlir Author: Sirui Mu (Lancern) ChangesThis PR adds two intrinsic operations, namely This PR also adds a new builder to Full diff: https://github.com/llvm/llvm-project/pull/113317.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32bcaad03..6ea3c9f2e1c7ba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
/*int numResults=*/numResults,
+ /*bit enableMlirBuilder=*/1,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d236cae0d80882..cf721f936cc932 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -33,6 +33,7 @@
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/ADT/PointerEmbeddedInt.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 845c88b1be7750..8bbd2b9053e160 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
}];
let builders = [
- OpBuilder<(ins "Value":$cond)>
+ OpBuilder<(ins "Value":$cond)>,
+ OpBuilder<(ins "Value":$cond,
+ "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
];
let hasVerifier = 1;
}
+class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
+ string opBundleTag>
+ : LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
+ /*overloadedOperands=*/[], /*traits=*/[],
+ /*numResults=*/0, /*enumName=*/"assume",
+ /*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
+ /*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
+ /*requiresOpBundles=*/0, /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[],
+ /*opBundleOperandPositions=*/[opBundleOperandPositions],
+ /*opBundleTags=*/[opBundleTag]> {
+ dag args = (ins I1:$cond);
+}
+
+def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
+ let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
+
+ let assemblyFormat = [{
+ $cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def LLVM_AssumeSeparateStorageOp
+ : LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
+ let arguments = !con(
+ args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
+
+ let assemblyFormat = [{
+ $cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
[Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyType:$operand);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a38dafa4d9cf34..9f6acbcd3c5104 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
- bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
- bit requiresFastmath = 0, bit requiresOpBundles = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
+ bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
+ bit requiresOpBundles = 0, list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
: LLVM_OpBase<dialect, opName, !listconcat(
!if(!gt(requiresAccessGroup, 0),
[DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
"StringLiteral(\"" # name # "\")"), ", ") # "}";
+ string opBundleOperandPositionsCpp = "{" # !interleave(
+ !foreach(positions, opBundleOperandPositions,
+ "ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
+ ), ", ") # "}";
+ string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
+ "StringLiteral(\"" # tag # "\")"), ", ") # "}";
string baseLlvmBuilder = [{
auto *inst = LLVM::detail::createIntrinsicCall(
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
- immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
+ immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
+ opBundleTagsCpp], ",") # [{);
(void) inst;
}];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
$_location, resultTypes, mlirOperands, mlirAttrs);
}];
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
- let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+ let mlirBuilder = !if(enableMlirBuilder,
+ baseMlirBuilder # !if(!gt(requiresFastmath, 0),
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
- # baseMlirBuilderCoda;
+ # baseMlirBuilderCoda, "");
// Code for handling a `range` attribute that holds the constant range of the
// intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
- int numResults, bit requiresAccessGroup = 0,
+ int numResults, string enumName = "",
+ bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
- bit requiresOpBundles = 0,
- list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
- : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
+ bit requiresOpBundles = 0, list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
+ : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
+ !if(!empty(enumName), !subst(".", "_", mnem), enumName),
overloadedResults, overloadedOperands, traits,
- numResults, requiresAccessGroup, requiresAliasAnalysis,
- requiresFastmath, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ numResults, enableMlirBuilder, requiresAccessGroup,
+ requiresAliasAnalysis, requiresFastmath,
+ requiresOpBundles, immArgPositions, immArgAttrNames,
+ opBundleOperandPositions, opBundleTags>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
bit requiresAliasAnalysis = 0,
bit requiresOpBundles = 0,
list<int> immArgPositions = [],
- list<string> immArgAttrNames = []>
+ list<string> immArgAttrNames = [],
+ list<list<int>> opBundleOperandPositions = [],
+ list<string> opBundleTags = []>
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
+ /*enumName=*/"", /*enableMlirBuilder=*/1,
requiresAccessGroup, requiresAliasAnalysis,
/*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ immArgAttrNames, opBundleOperandPositions, opBundleTags>;
// Base class for LLVM intrinsic operations returning one result. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+ /*enumName=*/"", /*enableMlirBuilder=*/1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3695708439d91f..7c204c99525ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
list<string> immArgAttrNames = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
- overloadedOperands, traits, numResults, requiresAccessGroup,
+ overloadedOperands, traits, numResults, 1, requiresAccessGroup,
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0c7e22f8c65596 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
- ArrayRef<unsigned> immArgPositions,
- ArrayRef<StringLiteral> immArgAttrNames);
+ ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+ ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+ ArrayRef<StringLiteral> opBundleTags);
} // namespace detail
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc73878a64ff67..f558cf23411ed6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
/*op_bundle_tags=*/{});
}
+void LLVM::AssumeOp::build(
+ OpBuilder &builder, OperationState &state, mlir::Value cond,
+ ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
+ SmallVector<mlir::ValueRange> opBundleOperands;
+ SmallVector<mlir::Attribute> opBundleTags;
+ opBundleOperands.reserve(opBundles.size());
+ opBundleTags.reserve(opBundles.size());
+
+ for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
+ opBundleOperands.emplace_back(bundle.inputs());
+ opBundleTags.push_back(
+ StringAttr::get(builder.getContext(), bundle.getTag()));
+ }
+
+ auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
+ return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
+}
+
LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..de493891ed7e4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
- ArrayRef<unsigned> immArgPositions,
- ArrayRef<StringLiteral> immArgAttrNames) {
+ ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
+ ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
+ ArrayRef<StringLiteral> opBundleTags) {
assert(immArgPositions.size() == immArgAttrNames.size() &&
"LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
"length");
+ assert(opBundleOperandPositions.size() == opBundleTags.size() &&
+ "operand bundles and tags do not match");
SmallVector<llvm::OperandBundleDef> opBundles;
- size_t numOpBundleOperands = 0;
+
+ size_t numVariadicOpBundleOperands = 0;
auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
-
if (opBundleSizesAttr && opBundleTagsAttr) {
ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
"operand bundles and tags do not match");
- numOpBundleOperands =
+ numVariadicOpBundleOperands =
std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
- assert(numOpBundleOperands <= intrOp->getNumOperands() &&
+ assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
"operand bundle operands is more than the number of operands");
- ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
+ ValueRange operands =
+ intrOp->getOperands().take_back(numVariadicOpBundleOperands);
size_t nextOperandIdx = 0;
opBundles.reserve(opBundleSizesAttr.size());
@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
}
// Map operands and attributes to LLVM values.
- auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
+ auto opOperands =
+ intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
auto operands = moduleTranslation.lookupValues(opOperands);
- SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
+
+ // Map operand bundle operands to LLVM operand bundles.
+ DenseSet<unsigned> opBundleOperandPositionsSet;
+ for (auto [positions, tag] :
+ llvm::zip(opBundleOperandPositions, opBundleTags)) {
+ opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
+
+ SmallVector<llvm::Value *> bundleArgs;
+ bundleArgs.reserve(positions.size());
+ for (unsigned idx : positions) {
+ assert(idx < operands.size() &&
+ "op bundle operand index is out of range");
+ bundleArgs.push_back(operands[idx]);
+ }
+
+ opBundles.emplace_back(tag.str(), std::move(bundleArgs));
+ }
+
+ SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
+ opBundleOperandPositionsSet.size());
for (auto [immArgPos, immArgName] :
llvm::zip(immArgPositions, immArgAttrNames)) {
auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
}
unsigned opArg = 0;
for (auto &arg : args) {
- if (!arg)
+ if (!arg) {
+ while (opBundleOperandPositionsSet.contains(opArg))
+ ++opArg;
arg = operands[opArg++];
+ }
}
// Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
module, intrinsic, overloadedTypes);
+ llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
+ << ", num op bundles = " << opBundles.size() << "\n";
return builder.CreateCall(llvmIntr, args, opBundles);
}
diff --git a/mlir/test/Dialect/LLVMIR/assume.mlir b/mlir/test/Dialect/LLVMIR/assume.mlir
new file mode 100644
index 00000000000000..4cf43b4828010f
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/assume.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @assume_align
+// CHECK-SAME: (ptr %[[ARG:.+]])
+llvm.func @assume_align(%arg0: !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ %1 = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
+ llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
+ llvm.return
+}
+
+// CHECK-LABEL: @assume_separate_storage
+// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
+llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+ %0 = llvm.mlir.constant(1 : i1) : i1
+ // CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
+ llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 411a98a48bfb28..6fc3e989074937 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
- os << ", " << intr.getNumResults() << ", "
+ os << ", " << intr.getNumResults() << ", \"\", 1, "
<< (requiresAccessGroup ? "1" : "0") << ", "
<< (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
|
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This intrinsic does not exist in LLVM IR right?
LLVM dialect operations usually have a 1:1 mapping to an LLVM IR instruction/intrinsic. There are a few exceptions for example for constants, since they are not instructions in LLVM IR. In these cases we always prefix the operation with llvm.mlir. to clarify this operation is special.
So if we add a new intrinsic it should be prefix the new intrinsics with llvm.intr.mlir. or similar. However, we need a good argument for this since this contradicts the LLVM dialect rational (see third paragraph of https://mlir.llvm.org/docs/Dialects/LLVM/).
Is your plan to match on these OPs somehow or would it be good enough to have a convenience builder that builds normal assume operations with the specific tags?
@ftynse what would be your take on adding specialized intrinsics that do not exist in LLVM IR (AFAIK there is no prior art)?
|
Why aren't we modeling the |
Yes I try to model the different assume operand bundles with dedicated operations, since I do not find a way to check whether an arbitrary operand bundle passed to
I agree that having convenience builders for assume operand bundles would be great enough, since the only point of this PR is to have some degree of convenience when building assume intrinsics with tags. |
You could add a verifier to AssumeOp that checks if the operand bundles are correct. If you do that, then it is important to replicate the logic in LLVM's Verifier.cpp. At the moment, only very few intrinsics in LLVM dialect implement a verifier though. Instead, the verification happens after lowering to LLVM proper. I am fine with either of these solutions!
If the goal is to simplify the lowering, then I would go the convenience builder route and avoid introducing custom intrinsics for every operand bundle type. Having different operations may be interesting if we want to use the assume information in transformations. However, if we want this then it would probably make sense to have a separate Assume dialect. That would require an RFC though especially since there is some overlap with existing dialects, such as memref, which implement some of this functionality. |
5606aa3 to
bdc31db
Compare
|
I have updated the patch and kept only the new builders for The first two are general and the last two are for specific tags. Currently I only add builders for As for tests, since we don't have any tests yet for operation builders, I assume it's safe to ignore them for now. |
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM once the comments are addressed.
A low-tech alternative maybe to just define some static getters in the extraClassDeclaration that return the tag strings (e.g. static StringRef getAlignTag() / static StringRef getSeparateStorageTag()) and then have one builder that takes a StringRef and a ValueRange. However, that way there may be a mismatch between tag and the number of arguments. I am fine with both approaches!
bdc31db to
0988dcc
Compare
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch adds several new builders for llvm.intr.assume that build the operation with additional operand bundles.
This PR adds several new builders for llvm.intr.assume that build the operation with additional operand bundles.