diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index d236cae0d8088..63e007cdc335c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -33,6 +33,7 @@ #include "mlir/Support/ThreadLocalCache.h" #include "llvm/ADT/PointerEmbeddedInt.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" @@ -87,6 +88,13 @@ class GEPArg : public PointerUnion { } // namespace LLVM } // namespace mlir +namespace mlir { +namespace LLVM { +struct AssumeAlignTag {}; +struct AssumeSeparateStorageTag {}; +} // namespace LLVM +} // namespace mlir + ///// Ops ///// #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 845c88b1be775..d07ebbacc6043 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -450,7 +450,14 @@ def LLVM_AssumeOp }]; let builders = [ - OpBuilder<(ins "Value":$cond)> + OpBuilder<(ins "Value":$cond)>, + OpBuilder<(ins "Value":$cond, + "ArrayRef>":$opBundles)>, + OpBuilder<(ins "Value":$cond, "llvm::StringRef":$tag, "ValueRange":$args)>, + OpBuilder<(ins "Value":$cond, "AssumeAlignTag":$tag, "Value":$ptr, + "Value":$align)>, + OpBuilder<(ins "Value":$cond, "AssumeSeparateStorageTag":$tag, + "Value":$ptr1, "Value":$ptr2)> ]; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index cc73878a64ff6..c9bc9533ca2a6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3438,7 +3438,44 @@ void InlineAsmOp::getEffects( void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, mlir::Value cond) { return build(builder, state, cond, /*op_bundle_operands=*/{}, - /*op_bundle_tags=*/{}); + /*op_bundle_tags=*/ArrayAttr{}); +} + +void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, + Value cond, + ArrayRef> opBundles) { + SmallVector opBundleOperands; + SmallVector opBundleTags; + opBundleOperands.reserve(opBundles.size()); + opBundleTags.reserve(opBundles.size()); + + for (const llvm::OperandBundleDefT &bundle : opBundles) { + opBundleOperands.emplace_back(bundle.inputs()); + opBundleTags.push_back( + StringAttr::get(builder.getContext(), bundle.getTag())); + } + + auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags); + return build(builder, state, cond, opBundleOperands, opBundleTagsAttr); +} + +void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, + Value cond, llvm::StringRef tag, ValueRange args) { + llvm::OperandBundleDefT opBundle( + tag.str(), SmallVector(args.begin(), args.end())); + return build(builder, state, cond, opBundle); +} + +void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, + Value cond, AssumeAlignTag, Value ptr, Value align) { + return build(builder, state, cond, "align", ValueRange{ptr, align}); +} + +void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, + Value cond, AssumeSeparateStorageTag, Value ptr1, + Value ptr2) { + return build(builder, state, cond, "separate_storage", + ValueRange{ptr1, ptr2}); } LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }