Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
/// function confirms that the Operation has the desired properties.
bool satisfiesLLVMModule(Operation *op);

/// Lookup parent Module satisfying LLVM conditions on the Module Operation.
Operation *parentLLVMModule(Operation *op);

/// Convert an array of integer attributes to a vector of integers that can be
/// used as indices in LLVM operations.
template <typename IntT = int64_t>
Expand Down
65 changes: 65 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,10 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
/// Return the llvm.mlir.alias operation that defined the value referenced
/// here.
AliasOp getAlias(SymbolTableCollection &symbolTable);

/// Return the llvm.mlir.ifunc operation that defined the value referenced
/// here.
IFuncOp getIFunc(SymbolTableCollection &symbolTable);
}];

let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))";
Expand Down Expand Up @@ -1601,6 +1605,67 @@ def LLVM_AliasOp : LLVM_Op<"mlir.alias",
let hasRegionVerifier = 1;
}

def LLVM_IFuncOp : LLVM_Op<"mlir.ifunc",
[IsolatedFromAbove, Symbol, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttr:$i_func_type,
FlatSymbolRefAttr:$resolver,
TypeAttr:$resolver_type,
Linkage:$linkage,
UnitAttr:$dso_local,
DefaultValuedAttr<ConfinedAttr<I32Attr, [IntNonNegative]>, "0">:$address_space,
DefaultValuedAttr<UnnamedAddr, "mlir::LLVM::UnnamedAddr::None">:$unnamed_addr,
DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_
);
let summary = "LLVM dialect ifunc";
let description = [{
`llvm.mlir.ifunc` is a top level operation that defines a global ifunc.
It defines a new symbol and takes a symbol refering to a resolver function.
IFuncs can be called as regular functions. The function type is the same
as the IFuncType. The symbol is resolved at runtime by calling a resolver
function.

Examples:

```mlir
// IFuncs resolve a symbol at runtime using a resovler function.
llvm.mlir.ifunc external @foo: !llvm.func<f32 (i64)>, !llvm.ptr @resolver

llvm.func @foo_1(i64) -> f32
llvm.func @foo_2(i64) -> f32

llvm.func @resolve_foo() -> !llvm.ptr attributes {
%0 = llvm.mlir.addressof @foo_2 : !llvm.ptr
%1 = llvm.mlir.addressof @foo_1 : !llvm.ptr

// ... Logic selecting from foo_{1, 2}

// Return function pointer to the selected function
llvm.return %7 : !llvm.ptr
}

llvm.func @use_foo() {
// IFuncs are called as regular functions
%res = llvm.call @foo(%value) : i64 -> f32
}
```
}];

let builders = [
OpBuilder<(ins "StringRef":$name, "Type":$i_func_type,
"StringRef":$resolver, "Type":$resolver_type,
"Linkage":$linkage, "LLVM::Visibility":$visibility)>
];

let assemblyFormat = [{
custom<LLVMLinkage>($linkage) ($visibility_^)? ($unnamed_addr^)?
$sym_name `:` $i_func_type `,` $resolver_type $resolver attr-dict
}];
let hasVerifier = 1;
}


def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent",
[Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins FlatSymbolRefAttr:$function_name);
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class ModuleImport {
/// Converts all aliases of the LLVM module to MLIR variables.
LogicalResult convertAliases();

/// Converts all ifuncs of the LLVM module to MLIR variables.
LogicalResult convertIFuncs();

/// Converts the data layout of the LLVM module to an MLIR data layout
/// specification.
LogicalResult convertDataLayout();
Expand Down Expand Up @@ -320,6 +323,8 @@ class ModuleImport {
/// Converts an LLVM global alias variable into an MLIR LLVM dialect alias
/// operation if a conversion exists. Otherwise, returns failure.
LogicalResult convertAlias(llvm::GlobalAlias *alias);
// Converts an LLVM global ifunc into an MLIR LLVM dialect ifunc operation.
LogicalResult convertIFunc(llvm::GlobalIFunc *ifunc);
/// Returns personality of `func` as a FlatSymbolRefAttr.
FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
/// Imports `bb` into `block`, which must be initially empty.
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ class ModuleTranslation {
return aliasesMapping.lookup(op);
}

/// Finds an LLVM IR global value that corresponds to the given MLIR operation
/// defining an IFunc.
llvm::GlobalValue *lookupIFunc(Operation *op) {
return ifuncMapping.lookup(op);
}

/// Returns the OpenMP IR builder associated with the LLVM IR module being
/// constructed.
llvm::OpenMPIRBuilder *getOpenMPBuilder();
Expand Down Expand Up @@ -308,6 +314,7 @@ class ModuleTranslation {
bool recordInsertions = false);
LogicalResult convertFunctionSignatures();
LogicalResult convertFunctions();
LogicalResult convertIFuncs();
LogicalResult convertComdats();

LogicalResult convertUnresolvedBlockAddress();
Expand Down Expand Up @@ -369,6 +376,10 @@ class ModuleTranslation {
/// aliases.
DenseMap<Operation *, llvm::GlobalValue *> aliasesMapping;

/// Mappings between llvm.mlir.ifunc definitions and corresponding global
/// ifuncs.
DenseMap<Operation *, llvm::GlobalValue *> ifuncMapping;

/// A stateful object used to translate types.
TypeToLLVMIRTranslator typeTranslator;

Expand Down
119 changes: 101 additions & 18 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
return static_cast<RetTy>(index);
}

static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) {
p << stringifyLinkage(val.getLinkage());
}

static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
val = LinkageAttr::get(
p.getContext(),
parseOptionalLLVMKeyword<LLVM::Linkage>(p, LLVM::Linkage::External));
return success();
}

//===----------------------------------------------------------------------===//
// Operand bundle helpers.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1175,14 +1186,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a symbol in the current scope";
auto fn = dyn_cast<LLVMFuncOp>(callee);
if (!fn)
return emitOpError() << "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function";

if (failed(verifyCallOpDebugInfo(*this, fn)))
return failure();
fnType = fn.getFunctionType();
if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
if (failed(verifyCallOpDebugInfo(*this, fn)))
return failure();
fnType = fn.getFunctionType();
} else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
fnType = ifunc.getIFuncType();
} else {
return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function or IFunc";
}
}

LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
Expand Down Expand Up @@ -2038,14 +2052,6 @@ LogicalResult ReturnOp::verify() {
// LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//

static Operation *parentLLVMModule(Operation *op) {
Operation *module = op->getParentOp();
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return module;
}

GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
return dyn_cast_or_null<GlobalOp>(
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
Expand All @@ -2061,6 +2067,11 @@ AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
}

IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) {
return dyn_cast_or_null<IFuncOp>(
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
}

LogicalResult
AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Operation *symbol =
Expand All @@ -2069,10 +2080,11 @@ AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto global = dyn_cast_or_null<GlobalOp>(symbol);
auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
auto alias = dyn_cast_or_null<AliasOp>(symbol);
auto ifunc = dyn_cast_or_null<IFuncOp>(symbol);

if (!global && !function && !alias)
if (!global && !function && !alias && !ifunc)
return emitOpError("must reference a global defined by 'llvm.mlir.global', "
"'llvm.mlir.alias' or 'llvm.func'");
"'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'");

LLVMPointerType type = getType();
if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
Expand Down Expand Up @@ -2682,6 +2694,69 @@ unsigned AliasOp::getAddrSpace() {
return ptrTy.getAddressSpace();
}

//===----------------------------------------------------------------------===//
// IFuncOp
//===----------------------------------------------------------------------===//

void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
Type iFuncType, StringRef resolverName, Type resolverType,
Linkage linkage, LLVM::Visibility visibility) {
return build(builder, result, name, iFuncType, resolverName, resolverType,
linkage, /*dso_local=*/false, /*address_space=*/0,
UnnamedAddr::None, visibility);
}

LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Operation *symbol =
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr());
// This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
auto resolver = dyn_cast<LLVMFuncOp>(symbol);
auto alias = dyn_cast<AliasOp>(symbol);
while (alias) {
Block &initBlock = alias.getInitializerBlock();
auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp());
// FIXME: This is a best effort solution. The AliasOp body might be more
// complex and in that case we bail out with success. To completely match
// the LLVM IR logic it would be necessary to implement proper alias and
// cast stripping.
if (!addrOp)
return success();
resolver = addrOp.getFunction(symbolTable);
alias = addrOp.getAlias(symbolTable);
}
if (!resolver)
return emitOpError("must have a function resolver");
Linkage linkage = resolver.getLinkage();
if (resolver.isExternal() || linkage == Linkage::AvailableExternally)
return emitOpError("resolver must be a definition");
if (!isa<LLVMPointerType>(resolver.getFunctionType().getReturnType()))
return emitOpError("resolver must return a pointer");
auto resolverPtr = dyn_cast<LLVMPointerType>(getResolverType());
if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace())
return emitOpError("resolver has incorrect type");
return success();
}

LogicalResult IFuncOp::verify() {
switch (getLinkage()) {
case Linkage::External:
case Linkage::Internal:
case Linkage::Private:
case Linkage::Weak:
case Linkage::WeakODR:
case Linkage::Linkonce:
case Linkage::LinkonceODR:
break;
default:
return emitOpError() << "'" << stringifyLinkage(getLinkage())
<< "' linkage not supported in ifuncs, available "
"options: private, internal, linkonce, weak, "
"linkonce_odr, weak_odr, or external linkage";
}
return success();
}

//===----------------------------------------------------------------------===//
// ShuffleVectorOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4329,3 +4404,11 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}

Operation *mlir::LLVM::parentLLVMModule(Operation *op) {
Operation *module = op->getParentOp();
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return module;
}
24 changes: 18 additions & 6 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,18 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
ArrayRef<llvm::Value *> operandsRef(operands);
llvm::CallInst *call;
if (auto attr = callOp.getCalleeAttr()) {
call =
builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
operandsRef, opBundles);
if (llvm::Function *function =
moduleTranslation.lookupFunction(attr.getValue())) {
call = builder.CreateCall(function, operandsRef, opBundles);
} else {
Operation *moduleOp = parentLLVMModule(&opInst);
Operation *ifuncOp =
moduleTranslation.symbolTable().lookupSymbolIn(moduleOp, attr);
llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp);
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles);
}
} else {
llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
moduleTranslation.convertType(callOp.getCalleeFunctionType()));
Expand Down Expand Up @@ -648,18 +657,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::LLVMFuncOp function =
addressOfOp.getFunction(moduleTranslation.symbolTable());
LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable());
LLVM::IFuncOp ifunc = addressOfOp.getIFunc(moduleTranslation.symbolTable());

// The verifier should not have allowed this.
assert((global || function || alias) &&
"referencing an undefined global, function, or alias");
assert((global || function || alias || ifunc) &&
"referencing an undefined global, function, alias, or ifunc");

llvm::Value *llvmValue = nullptr;
if (global)
llvmValue = moduleTranslation.lookupGlobal(global);
else if (alias)
llvmValue = moduleTranslation.lookupAlias(alias);
else
else if (function)
llvmValue = moduleTranslation.lookupFunction(function.getName());
else
llvmValue = moduleTranslation.lookupIFunc(ifunc);

moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue);
return success();
Expand Down
Loading
Loading