diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 352e2ec91bdbe..5ccddef158d9c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> { ]; } +def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> { + let description = [{ + This interface defines an LLVM operation with a disjoint flag and + provides a uniform API for accessing it. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod<[{ + Get the disjoint flag for the operation. + }], "bool", "getIsDisjoint", (ins), [{}], [{ + return $_op.getProperties().isDisjoint; + }]>, + InterfaceMethod<[{ + Set the disjoint flag for the operation. + }], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{ + $_op.getProperties().isDisjoint = isDisjoint; + }]>, + StaticInterfaceMethod<[{ + Get the attribute name of the isDisjoint property. + }], "StringRef", "getIsDisjointName", (ins), [{}], [{ + return "isDisjoint"; + }]>, + ]; +} + def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> { let description = [{ This interface defines an LLVM operation with an nneg flag and diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 34f3e4b33b829..847ff6def34b8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag traits = []> : + LLVM_ArithmeticOpBase], traits)> { + let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint)); + + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + moduleImport.setDisjointFlag(inst, op); + $res = op; + }]; + let assemblyFormat = [{ + (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res) + }]; + string llvmBuilder = [{ + auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/""); + moduleTranslation.setDisjointFlag(op, inst); + $res = inst; + }]; +} class LLVM_FloatArithmeticOp traits = []> : LLVM_ArithmeticOpBase; def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; -def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> { +def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> { let hasFolder = 1; } def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 30164843f6367..eea0647895b01 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -192,6 +192,11 @@ class ModuleImport { /// implement the exact flag interface. void setExactFlag(llvm::Instruction *inst, Operation *op) const; + /// Sets the disjoint flag attribute for the imported operation `op` + /// given the original instruction `inst`. Asserts if the operation does + /// not implement the disjoint flag interface. + void setDisjointFlag(llvm::Instruction *inst, Operation *op) const; + /// Sets the nneg flag attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not /// implement the nneg flag interface. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index ffeeeae57ae95..1b62437761ed9 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -167,6 +167,12 @@ class ModuleTranslation { /// attribute. void setLoopMetadata(Operation *op, llvm::Instruction *inst); + /// Sets the disjoint flag attribute for the exported instruction `value` + /// given the original operation `op`. Asserts if the operation does + /// not implement the disjoint flag interface, and asserts if the value + /// is an instruction that implements the disjoint flag. + void setDisjointFlag(Operation *op, llvm::Value *value); + /// Converts the type from MLIR LLVM dialect to LLVM. llvm::Type *convertType(Type type); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 71d88d3a62f2b..0d416a5857fac 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const { iface.setIsExact(inst->isExact()); } +void ModuleImport::setDisjointFlag(llvm::Instruction *inst, + Operation *op) const { + auto iface = cast(op); + auto instDisjoint = cast(inst); + + iface.setIsDisjoint(instDisjoint->isDisjoint()); +} + void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const { auto iface = cast(op); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ceb8ba3b33818..9e58d2a29199e 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op, inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD); } +void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) { + auto iface = cast(op); + // We do a dyn_cast here in case the value got folded into a constant. + if (auto disjointInst = dyn_cast(value)) + disjointInst->setIsDisjoint(iface.getIsDisjoint()); +} + llvm::Type *ModuleTranslation::convertType(Type type) { return typeTranslator.translateType(type); } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index aa558bad2299c..06f7b2d9f586f 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32, %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32 %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32 +// Integer disjoint flag. +// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32 + %or_flag = llvm.or disjoint %arg0, %arg0 : i32 + // Floating point binary operations. // // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32 diff --git a/mlir/test/Target/LLVMIR/Import/disjoint.ll b/mlir/test/Target/LLVMIR/Import/disjoint.ll new file mode 100644 index 0000000000000..36091c0904352 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/disjoint.ll @@ -0,0 +1,8 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @disjointflag_inst +define void @disjointflag_inst(i64 %arg1, i64 %arg2) { + ; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64 + %1 = or disjoint i64 %arg1, %arg2 + ret void +} diff --git a/mlir/test/Target/LLVMIR/disjoint.mlir b/mlir/test/Target/LLVMIR/disjoint.mlir new file mode 100644 index 0000000000000..1f5a42e608ba4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/disjoint.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @disjointflag_func +llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) { + // CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}} + %0 = llvm.or disjoint %arg0, %arg1 : i64 + llvm.return +}