Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
/// function confirms that the Operation has the desired properties.
bool satisfiesLLVMModule(Operation *op);

/// Lookup parent Module satisfying LLVM conditions on the Module Operation.
Operation *parentLLVMModule(Operation *op);

/// Convert an array of integer attributes to a vector of integers that can be
/// used as indices in LLVM operations.
template <typename IntT = int64_t>
Expand Down
65 changes: 65 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,10 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
/// Return the llvm.mlir.alias operation that defined the value referenced
/// here.
AliasOp getAlias(SymbolTableCollection &symbolTable);

/// Return the llvm.mlir.ifunc operation that defined the value referenced
/// here.
IFuncOp getIFunc(SymbolTableCollection &symbolTable);
}];

let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))";
Expand Down Expand Up @@ -1601,6 +1605,67 @@ def LLVM_AliasOp : LLVM_Op<"mlir.alias",
let hasRegionVerifier = 1;
}

def LLVM_IFuncOp : LLVM_Op<"mlir.ifunc",
[IsolatedFromAbove, Symbol, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttr:$i_func_type,
FlatSymbolRefAttr:$resolver,
TypeAttr:$resolver_type,
Linkage:$linkage,
UnitAttr:$dso_local,
DefaultValuedAttr<ConfinedAttr<I32Attr, [IntNonNegative]>, "0">:$address_space,
DefaultValuedAttr<UnnamedAddr, "mlir::LLVM::UnnamedAddr::None">:$unnamed_addr,
DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_
);
let summary = "LLVM dialect ifunc";
let description = [{
`llvm.mlir.ifunc` is a top level operation that defines a global ifunc.
It defines a new symbol and takes a symbol refering to a resolver function.
IFuncs can be called as regular functions. The function type is the same
as the IFuncType. The symbol is resolved at runtime by calling a resolver
function.

Examples:

```mlir
// IFuncs resolve a symbol at runtime using a resovler function.
llvm.mlir.ifunc external @foo: !llvm.func<f32 (i64)>, !llvm.ptr @resolver

llvm.func @foo_1(i64) -> f32
llvm.func @foo_2(i64) -> f32

llvm.func @resolve_foo() -> !llvm.ptr attributes {
%0 = llvm.mlir.addressof @foo_2 : !llvm.ptr
%1 = llvm.mlir.addressof @foo_1 : !llvm.ptr

// ... Logic selecting from foo_{1, 2}

// Return function pointer to the selected function
llvm.return %7 : !llvm.ptr
}

llvm.func @use_foo() {
// IFuncs are called as regular functions
%res = llvm.call @foo(%value) : i64 -> f32
}
```
}];

let builders = [
OpBuilder<(ins "StringRef":$name, "Type":$i_func_type,
"StringRef":$resolver, "Type":$resolver_type,
"Linkage":$linkage, "LLVM::Visibility":$visibility)>
];

let assemblyFormat = [{
custom<LLVMLinkage>($linkage) ($visibility_^)? ($unnamed_addr^)?
$sym_name `:` $i_func_type `,` $resolver_type $resolver attr-dict
}];
let hasVerifier = 1;
}


def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent",
[Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins FlatSymbolRefAttr:$function_name);
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class ModuleImport {
/// Converts all aliases of the LLVM module to MLIR variables.
LogicalResult convertAliases();

/// Converts all ifuncs of the LLVM module to MLIR variables.
LogicalResult convertIFuncs();

/// Converts the data layout of the LLVM module to an MLIR data layout
/// specification.
LogicalResult convertDataLayout();
Expand Down Expand Up @@ -320,6 +323,8 @@ class ModuleImport {
/// Converts an LLVM global alias variable into an MLIR LLVM dialect alias
/// operation if a conversion exists. Otherwise, returns failure.
LogicalResult convertAlias(llvm::GlobalAlias *alias);
// Converts an LLVM global ifunc into an MLIR LLVM dialect ifunc operation.
LogicalResult convertIFunc(llvm::GlobalIFunc *ifunc);
/// Returns personality of `func` as a FlatSymbolRefAttr.
FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
/// Imports `bb` into `block`, which must be initially empty.
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ class ModuleTranslation {
return aliasesMapping.lookup(op);
}

/// Finds an LLVM IR global value that corresponds to the given MLIR operation
/// defining an IFunc.
llvm::GlobalValue *lookupIFunc(Operation *op) {
return ifuncMapping.lookup(op);
}

/// Returns the OpenMP IR builder associated with the LLVM IR module being
/// constructed.
llvm::OpenMPIRBuilder *getOpenMPBuilder();
Expand Down Expand Up @@ -308,6 +314,7 @@ class ModuleTranslation {
bool recordInsertions = false);
LogicalResult convertFunctionSignatures();
LogicalResult convertFunctions();
LogicalResult convertIFuncs();
LogicalResult convertComdats();

LogicalResult convertUnresolvedBlockAddress();
Expand Down Expand Up @@ -369,6 +376,10 @@ class ModuleTranslation {
/// aliases.
DenseMap<Operation *, llvm::GlobalValue *> aliasesMapping;

/// Mappings between llvm.mlir.ifunc definitions and corresponding global
/// ifuncs.
DenseMap<Operation *, llvm::GlobalValue *> ifuncMapping;

/// A stateful object used to translate types.
TypeToLLVMIRTranslator typeTranslator;

Expand Down
119 changes: 101 additions & 18 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
return static_cast<RetTy>(index);
}

static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) {
p << stringifyLinkage(val.getLinkage());
}

static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
val = LinkageAttr::get(
p.getContext(),
parseOptionalLLVMKeyword<LLVM::Linkage>(p, LLVM::Linkage::External));
return success();
}

//===----------------------------------------------------------------------===//
// Operand bundle helpers.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1175,14 +1186,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a symbol in the current scope";
auto fn = dyn_cast<LLVMFuncOp>(callee);
if (!fn)
return emitOpError() << "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function";

if (failed(verifyCallOpDebugInfo(*this, fn)))
return failure();
fnType = fn.getFunctionType();
if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
if (failed(verifyCallOpDebugInfo(*this, fn)))
return failure();
fnType = fn.getFunctionType();
} else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
fnType = ifunc.getIFuncType();
} else {
return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function or IFunc";
}
}

LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
Expand Down Expand Up @@ -2038,14 +2052,6 @@ LogicalResult ReturnOp::verify() {
// LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//

static Operation *parentLLVMModule(Operation *op) {
Operation *module = op->getParentOp();
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return module;
}

GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
return dyn_cast_or_null<GlobalOp>(
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
Expand All @@ -2061,6 +2067,11 @@ AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
}

IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) {
return dyn_cast_or_null<IFuncOp>(
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
}

LogicalResult
AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Operation *symbol =
Expand All @@ -2069,10 +2080,11 @@ AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto global = dyn_cast_or_null<GlobalOp>(symbol);
auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
auto alias = dyn_cast_or_null<AliasOp>(symbol);
auto ifunc = dyn_cast_or_null<IFuncOp>(symbol);

if (!global && !function && !alias)
if (!global && !function && !alias && !ifunc)
return emitOpError("must reference a global defined by 'llvm.mlir.global', "
"'llvm.mlir.alias' or 'llvm.func'");
"'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'");

LLVMPointerType type = getType();
if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
Expand Down Expand Up @@ -2682,6 +2694,69 @@ unsigned AliasOp::getAddrSpace() {
return ptrTy.getAddressSpace();
}

//===----------------------------------------------------------------------===//
// IFuncOp
//===----------------------------------------------------------------------===//

void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
Type iFuncType, StringRef resolverName, Type resolverType,
Linkage linkage, LLVM::Visibility visibility) {
return build(builder, result, name, iFuncType, resolverName, resolverType,
linkage, /*dso_local=*/false, /*address_space=*/0,
UnnamedAddr::None, visibility);
}

LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Operation *symbol =
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr());
// This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
auto resolver = dyn_cast<LLVMFuncOp>(symbol);
auto alias = dyn_cast<AliasOp>(symbol);
while (alias) {
Block &initBlock = alias.getInitializerBlock();
auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp());
// FIXME: This is a best effort solution. The AliasOp body might be more
// complex and in that case we bail out with success. To completely match
// the LLVM IR logic it would be necessary to implement proper alias and
// cast stripping.
if (!addrOp)
return success();
resolver = addrOp.getFunction(symbolTable);
alias = addrOp.getAlias(symbolTable);
}
if (!resolver)
return emitOpError("must have a function resolver");
Linkage linkage = resolver.getLinkage();
if (resolver.isExternal() || linkage == Linkage::AvailableExternally)
return emitOpError("resolver must be a definition");
if (!isa<LLVMPointerType>(resolver.getFunctionType().getReturnType()))
return emitOpError("resolver must return a pointer");
auto resolverPtr = dyn_cast<LLVMPointerType>(getResolverType());
if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace())
return emitOpError("resolver has incorrect type");
return success();
}

LogicalResult IFuncOp::verify() {
switch (getLinkage()) {
case Linkage::External:
case Linkage::Internal:
case Linkage::Private:
case Linkage::Weak:
case Linkage::WeakODR:
case Linkage::Linkonce:
case Linkage::LinkonceODR:
break;
default:
return emitOpError() << "'" << stringifyLinkage(getLinkage())
<< "' linkage not supported in ifuncs, available "
"options: private, internal, linkonce, weak, "
"linkonce_odr, weak_odr, or external linkage";
}
return success();
}

//===----------------------------------------------------------------------===//
// ShuffleVectorOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4329,3 +4404,11 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}

Operation *mlir::LLVM::parentLLVMModule(Operation *op) {
Operation *module = op->getParentOp();
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return module;
}
24 changes: 18 additions & 6 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,18 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
ArrayRef<llvm::Value *> operandsRef(operands);
llvm::CallInst *call;
if (auto attr = callOp.getCalleeAttr()) {
call =
builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
operandsRef, opBundles);
if (llvm::Function *function =
moduleTranslation.lookupFunction(attr.getValue())) {
call = builder.CreateCall(function, operandsRef, opBundles);
} else {
Operation *moduleOp = parentLLVMModule(&opInst);
Operation *ifuncOp =
moduleTranslation.symbolTable().lookupSymbolIn(moduleOp, attr);
llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp);
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles);
}
} else {
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
Expand Down Expand Up @@ -648,18 +657,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::LLVMFuncOp function =
addressOfOp.getFunction(moduleTranslation.symbolTable());
LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable());
LLVM::IFuncOp ifunc = addressOfOp.getIFunc(moduleTranslation.symbolTable());

// The verifier should not have allowed this.
assert((global || function || alias) &&
"referencing an undefined global, function, or alias");
assert((global || function || alias || ifunc) &&
"referencing an undefined global, function, alias, or ifunc");

llvm::Value *llvmValue = nullptr;
if (global)
llvmValue = moduleTranslation.lookupGlobal(global);
else if (alias)
llvmValue = moduleTranslation.lookupAlias(alias);
else
else if (function)
llvmValue = moduleTranslation.lookupFunction(function.getName());
else
llvmValue = moduleTranslation.lookupIFunc(ifunc);

moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue);
return success();
Expand Down
Loading