Skip to content

Commit 5606aa3

Browse files
committed
[mlir][LLVM] Add dedicated operations for assume align and separate_storage
This patch adds two intrinsic operations, namely `llvm.intr.assume.align` and `llvm.intr.assume.separate_storage`. Module translation translates both operations to intrinsic calls to `@llvm.assume`, with different assume operand bundles. This patch also adds a new builder to `llvm.intr.assume` to make it easier to build the operation with assume operand bundles.
1 parent a6d6c00 commit 5606aa3

File tree

10 files changed

+154
-32
lines changed

10 files changed

+154
-32
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class ArmSME_IntrOp<string mnemonic,
6868
/*list<int> overloadedOperands=*/overloadedOperands,
6969
/*list<Trait> traits=*/traits,
7070
/*int numResults=*/numResults,
71+
/*bit enableMlirBuilder=*/1,
7172
/*bit requiresAccessGroup=*/0,
7273
/*bit requiresAliasAnalysis=*/0,
7374
/*bit requiresFastmath=*/0,

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Support/ThreadLocalCache.h"
3434
#include "llvm/ADT/PointerEmbeddedInt.h"
3535
#include "llvm/IR/DerivedTypes.h"
36+
#include "llvm/IR/InstrTypes.h"
3637
#include "llvm/IR/LLVMContext.h"
3738
#include "llvm/IR/Module.h"
3839
#include "llvm/IR/Type.h"

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,46 @@ def LLVM_AssumeOp
450450
}];
451451

452452
let builders = [
453-
OpBuilder<(ins "Value":$cond)>
453+
OpBuilder<(ins "Value":$cond)>,
454+
OpBuilder<(ins "Value":$cond,
455+
"ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>
454456
];
455457

456458
let hasVerifier = 1;
457459
}
458460

461+
class LLVM_AssumeOpBase<string mnem, list<int> opBundleOperandPositions,
462+
string opBundleTag>
463+
: LLVM_IntrOp<"assume." # mnem, /*overloadedResults=*/[],
464+
/*overloadedOperands=*/[], /*traits=*/[],
465+
/*numResults=*/0, /*enumName=*/"assume",
466+
/*enableMlirBuilder=*/0, /*requiresAccessGroup=*/0,
467+
/*requiresAliasAnalysis=*/0, /*requiresFastmath=*/0,
468+
/*requiresOpBundles=*/0, /*immArgPositions=*/[],
469+
/*immArgAttrNames=*/[],
470+
/*opBundleOperandPositions=*/[opBundleOperandPositions],
471+
/*opBundleTags=*/[opBundleTag]> {
472+
dag args = (ins I1:$cond);
473+
}
474+
475+
def LLVM_AssumeAlignOp : LLVM_AssumeOpBase<"align", [1, 2], "align"> {
476+
let arguments = !con(args, (ins LLVM_AnyPointer:$ptr, AnyInteger:$align));
477+
478+
let assemblyFormat = [{
479+
$cond `,` $ptr `,` $align attr-dict `:` functional-type(operands, results)
480+
}];
481+
}
482+
483+
def LLVM_AssumeSeparateStorageOp
484+
: LLVM_AssumeOpBase<"separate_storage", [1, 2], "separate_storage"> {
485+
let arguments = !con(
486+
args, (ins LLVM_AnyPointer:$ptr1, LLVM_AnyPointer:$ptr2));
487+
488+
let assemblyFormat = [{
489+
$cond `,` $ptr1 `,` $ptr2 attr-dict `:` functional-type(operands, results)
490+
}];
491+
}
492+
459493
def LLVM_SSACopyOp : LLVM_OneResultIntrOp<"ssa.copy", [], [0],
460494
[Pure, SameOperandsAndResultType]> {
461495
let arguments = (ins AnyType:$operand);

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,12 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
290290
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
291291
list<int> overloadedResults, list<int> overloadedOperands,
292292
list<Trait> traits, int numResults,
293-
bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
294-
bit requiresFastmath = 0, bit requiresOpBundles = 0,
295-
list<int> immArgPositions = [],
296-
list<string> immArgAttrNames = []>
293+
bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
294+
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
295+
bit requiresOpBundles = 0, list<int> immArgPositions = [],
296+
list<string> immArgAttrNames = [],
297+
list<list<int>> opBundleOperandPositions = [],
298+
list<string> opBundleTags = []>
297299
: LLVM_OpBase<dialect, opName, !listconcat(
298300
!if(!gt(requiresAccessGroup, 0),
299301
[DeclareOpInterfaceMethods<AccessGroupOpInterface>], []),
@@ -325,11 +327,18 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
325327
string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
326328
string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
327329
"StringLiteral(\"" # name # "\")"), ", ") # "}";
330+
string opBundleOperandPositionsCpp = "{" # !interleave(
331+
!foreach(positions, opBundleOperandPositions,
332+
"ArrayRef<unsigned>{" # !interleave(positions, ", ") # "}"
333+
), ", ") # "}";
334+
string opBundleTagsCpp = "{" # !interleave(!foreach(tag, opBundleTags,
335+
"StringLiteral(\"" # tag # "\")"), ", ") # "}";
328336
string baseLlvmBuilder = [{
329337
auto *inst = LLVM::detail::createIntrinsicCall(
330338
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
331339
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
332-
immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
340+
immArgPositionsCpp, immArgAttrNamesCpp, opBundleOperandPositionsCpp,
341+
opBundleTagsCpp], ",") # [{);
333342
(void) inst;
334343
}];
335344
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
@@ -357,9 +366,10 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
357366
$_location, resultTypes, mlirOperands, mlirAttrs);
358367
}];
359368
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
360-
let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
369+
let mlirBuilder = !if(enableMlirBuilder,
370+
baseMlirBuilder # !if(!gt(requiresFastmath, 0),
361371
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
362-
# baseMlirBuilderCoda;
372+
# baseMlirBuilderCoda, "");
363373

364374
// Code for handling a `range` attribute that holds the constant range of the
365375
// intrinsic's result (if one is specified at the call site). This is intended
@@ -387,16 +397,20 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
387397
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
388398
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
389399
list<int> overloadedOperands, list<Trait> traits,
390-
int numResults, bit requiresAccessGroup = 0,
400+
int numResults, string enumName = "",
401+
bit enableMlirBuilder = 1, bit requiresAccessGroup = 0,
391402
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
392-
bit requiresOpBundles = 0,
393-
list<int> immArgPositions = [],
394-
list<string> immArgAttrNames = []>
395-
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
403+
bit requiresOpBundles = 0, list<int> immArgPositions = [],
404+
list<string> immArgAttrNames = [],
405+
list<list<int>> opBundleOperandPositions = [],
406+
list<string> opBundleTags = []>
407+
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem,
408+
!if(!empty(enumName), !subst(".", "_", mnem), enumName),
396409
overloadedResults, overloadedOperands, traits,
397-
numResults, requiresAccessGroup, requiresAliasAnalysis,
398-
requiresFastmath, requiresOpBundles, immArgPositions,
399-
immArgAttrNames>;
410+
numResults, enableMlirBuilder, requiresAccessGroup,
411+
requiresAliasAnalysis, requiresFastmath,
412+
requiresOpBundles, immArgPositions, immArgAttrNames,
413+
opBundleOperandPositions, opBundleTags>;
400414

401415
// Base class for LLVM intrinsic operations returning no results. Places the
402416
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -418,11 +432,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
418432
bit requiresAliasAnalysis = 0,
419433
bit requiresOpBundles = 0,
420434
list<int> immArgPositions = [],
421-
list<string> immArgAttrNames = []>
435+
list<string> immArgAttrNames = [],
436+
list<list<int>> opBundleOperandPositions = [],
437+
list<string> opBundleTags = []>
422438
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
439+
/*enumName=*/"", /*enableMlirBuilder=*/1,
423440
requiresAccessGroup, requiresAliasAnalysis,
424441
/*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
425-
immArgAttrNames>;
442+
immArgAttrNames, opBundleOperandPositions, opBundleTags>;
426443

427444
// Base class for LLVM intrinsic operations returning one result. Places the
428445
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -437,6 +454,7 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
437454
list<int> immArgPositions = [],
438455
list<string> immArgAttrNames = []>
439456
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
457+
/*enumName=*/"", /*enableMlirBuilder=*/1,
440458
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
441459
requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
442460
immArgAttrNames>;

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
9797
list<string> immArgAttrNames = []> :
9898
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
9999
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
100-
overloadedOperands, traits, numResults, requiresAccessGroup,
100+
overloadedOperands, traits, numResults, 1, requiresAccessGroup,
101101
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
102102

103103
//===----------------------------------------------------------------------===//

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,9 @@ llvm::CallInst *createIntrinsicCall(
434434
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
435435
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
436436
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
437-
ArrayRef<unsigned> immArgPositions,
438-
ArrayRef<StringLiteral> immArgAttrNames);
437+
ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
438+
ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
439+
ArrayRef<StringLiteral> opBundleTags);
439440

440441
} // namespace detail
441442

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3441,6 +3441,24 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
34413441
/*op_bundle_tags=*/{});
34423442
}
34433443

3444+
void LLVM::AssumeOp::build(
3445+
OpBuilder &builder, OperationState &state, mlir::Value cond,
3446+
ArrayRef<llvm::OperandBundleDefT<mlir::Value>> opBundles) {
3447+
SmallVector<mlir::ValueRange> opBundleOperands;
3448+
SmallVector<mlir::Attribute> opBundleTags;
3449+
opBundleOperands.reserve(opBundles.size());
3450+
opBundleTags.reserve(opBundles.size());
3451+
3452+
for (const llvm::OperandBundleDefT<mlir::Value> &bundle : opBundles) {
3453+
opBundleOperands.emplace_back(bundle.inputs());
3454+
opBundleTags.push_back(
3455+
StringAttr::get(builder.getContext(), bundle.getTag()));
3456+
}
3457+
3458+
auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
3459+
return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
3460+
}
3461+
34443462
LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
34453463

34463464
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -849,30 +849,34 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
849849
llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
850850
Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
851851
ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
852-
ArrayRef<unsigned> immArgPositions,
853-
ArrayRef<StringLiteral> immArgAttrNames) {
852+
ArrayRef<unsigned> immArgPositions, ArrayRef<StringLiteral> immArgAttrNames,
853+
ArrayRef<ArrayRef<unsigned>> opBundleOperandPositions,
854+
ArrayRef<StringLiteral> opBundleTags) {
854855
assert(immArgPositions.size() == immArgAttrNames.size() &&
855856
"LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
856857
"length");
858+
assert(opBundleOperandPositions.size() == opBundleTags.size() &&
859+
"operand bundles and tags do not match");
857860

858861
SmallVector<llvm::OperandBundleDef> opBundles;
859-
size_t numOpBundleOperands = 0;
862+
863+
size_t numVariadicOpBundleOperands = 0;
860864
auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
861865
intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
862866
auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
863867
intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
864-
865868
if (opBundleSizesAttr && opBundleTagsAttr) {
866869
ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
867870
assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
868871
"operand bundles and tags do not match");
869872

870-
numOpBundleOperands =
873+
numVariadicOpBundleOperands =
871874
std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
872-
assert(numOpBundleOperands <= intrOp->getNumOperands() &&
875+
assert(numVariadicOpBundleOperands <= intrOp->getNumOperands() &&
873876
"operand bundle operands is more than the number of operands");
874877

875-
ValueRange operands = intrOp->getOperands().take_back(numOpBundleOperands);
878+
ValueRange operands =
879+
intrOp->getOperands().take_back(numVariadicOpBundleOperands);
876880
size_t nextOperandIdx = 0;
877881
opBundles.reserve(opBundleSizesAttr.size());
878882

@@ -887,9 +891,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
887891
}
888892

889893
// Map operands and attributes to LLVM values.
890-
auto opOperands = intrOp->getOperands().drop_back(numOpBundleOperands);
894+
auto opOperands =
895+
intrOp->getOperands().drop_back(numVariadicOpBundleOperands);
891896
auto operands = moduleTranslation.lookupValues(opOperands);
892-
SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
897+
898+
// Map operand bundle operands to LLVM operand bundles.
899+
DenseSet<unsigned> opBundleOperandPositionsSet;
900+
for (auto [positions, tag] :
901+
llvm::zip(opBundleOperandPositions, opBundleTags)) {
902+
opBundleOperandPositionsSet.insert(positions.begin(), positions.end());
903+
904+
SmallVector<llvm::Value *> bundleArgs;
905+
bundleArgs.reserve(positions.size());
906+
for (unsigned idx : positions) {
907+
assert(idx < operands.size() &&
908+
"op bundle operand index is out of range");
909+
bundleArgs.push_back(operands[idx]);
910+
}
911+
912+
opBundles.emplace_back(tag.str(), std::move(bundleArgs));
913+
}
914+
915+
SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size() -
916+
opBundleOperandPositionsSet.size());
893917
for (auto [immArgPos, immArgName] :
894918
llvm::zip(immArgPositions, immArgAttrNames)) {
895919
auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
@@ -900,8 +924,11 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
900924
}
901925
unsigned opArg = 0;
902926
for (auto &arg : args) {
903-
if (!arg)
927+
if (!arg) {
928+
while (opBundleOperandPositionsSet.contains(opArg))
929+
++opArg;
904930
arg = operands[opArg++];
931+
}
905932
}
906933

907934
// Resolve overloaded intrinsic declaration.
@@ -923,6 +950,8 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
923950
llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
924951
module, intrinsic, overloadedTypes);
925952

953+
llvm::outs() << "debug: createIntrinsicCall: num args = " << args.size()
954+
<< ", num op bundles = " << opBundles.size() << "\n";
926955
return builder.CreateCall(llvmIntr, args, opBundles);
927956
}
928957

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @assume_align
4+
// CHECK-SAME: (ptr %[[ARG:.+]])
5+
llvm.func @assume_align(%arg0: !llvm.ptr) {
6+
%0 = llvm.mlir.constant(1 : i1) : i1
7+
%1 = llvm.mlir.constant(8 : i32) : i32
8+
// CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[ARG]], i32 8) ]
9+
llvm.intr.assume.align %0, %arg0, %1 : (i1, !llvm.ptr, i32) -> ()
10+
llvm.return
11+
}
12+
13+
// CHECK-LABEL: @assume_separate_storage
14+
// CHECK-SAME: (ptr %[[ARG0:.+]], ptr %[[ARG1:.+]])
15+
llvm.func @assume_separate_storage(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
16+
%0 = llvm.mlir.constant(1 : i1) : i1
17+
// CHECK: call void @llvm.assume(i1 true) [ "separate_storage"(ptr %[[ARG0]], ptr %[[ARG1]]) ]
18+
llvm.intr.assume.separate_storage %0, %arg0, %arg1 : (i1, !llvm.ptr, !llvm.ptr) -> ()
19+
llvm.return
20+
}

mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
237237
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
238238
os << ", ";
239239
printBracketedRange(traits, os);
240-
os << ", " << intr.getNumResults() << ", "
240+
os << ", " << intr.getNumResults() << ", \"\", 1, "
241241
<< (requiresAccessGroup ? "1" : "0") << ", "
242242
<< (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
243243
<< (operands.empty() ? "" : " ");

0 commit comments

Comments
 (0)