Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class LLVMImportDialectInterface
/// returns the list of supported intrinsic identifiers.
virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }

/// Whether dialect have a generic way to represent unsupported intrinsics
/// (i.e. as oposed to supported ones aboves).
virtual bool getUnregisteredIntrinsics() const { return false; }

/// Hook for derived dialect interfaces to publish the supported instructions.
/// As every LLVM IR instruction has a unique integer identifier, the function
/// returns the list of supported instruction identifiers. These identifiers
Expand Down Expand Up @@ -145,6 +149,12 @@ class LLVMImportInterface
// Add a mapping for all supported metadata kinds.
for (unsigned kind : iface.getSupportedMetadata(llvmContext))
metadataToDialect[kind].push_back(iface.getDialect());

// There can be only one dialect dealing with unregistered
// intrinsics, the last one to support the interface is the
// one to be used.
if (iface.getUnregisteredIntrinsics())
unregisteredIntrinscToDialect = iface.getDialect();
}

return success();
Expand All @@ -155,7 +165,19 @@ class LLVMImportInterface
LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst,
LLVM::ModuleImport &moduleImport) const {
// Lookup the dialect interface for the given intrinsic.
Dialect *dialect = intrinsicToDialect.lookup(inst->getIntrinsicID());
llvm::Intrinsic::ID intrinId = inst->getIntrinsicID();
if (intrinId == llvm::Intrinsic::not_intrinsic)
return failure();

// First lookup intrinsic across different dialects for known
// supported converstions, examples include arm-neon, nvm-sve, etc
Dialect *dialect = intrinsicToDialect.lookup(intrinId);

// No specialized (supported) intrinsics, attempt to generate a generic
// version via llvm.call_intrinsic (if available).
if (!dialect)
dialect = unregisteredIntrinscToDialect;

if (!dialect)
return failure();

Expand Down Expand Up @@ -227,6 +249,9 @@ class LLVMImportInterface
DenseMap<unsigned, Dialect *> intrinsicToDialect;
DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;

/// Unregistered generic and target independent intrinsics.
Dialect *unregisteredIntrinscToDialect = nullptr;
};

} // namespace mlir
Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,49 @@ static ArrayRef<unsigned> getSupportedIntrinsicsImpl() {
return convertibleIntrinsics;
}

/// Converts the LLVM intrinsic to a generic LLVM intrinsic call using
/// llvm.intrinsic_call. Returns failure otherwise.
static LogicalResult
convertUnregisteredIntrinsicImpl(OpBuilder &odsBuilder, llvm::CallInst *inst,
LLVM::ModuleImport &moduleImport) {
StringRef intrinName = inst->getCalledFunction()->getName();

// Sanity check the intrinsic ID.
SmallVector<llvm::Value *> args(inst->args());
ArrayRef<llvm::Value *> llvmOperands(args);

SmallVector<llvm::OperandBundleUse> llvmOpBundles;
llvmOpBundles.reserve(inst->getNumOperandBundles());
for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i)
llvmOpBundles.push_back(inst->getOperandBundleAt(i));

SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs)))
return failure();

mlir::Type results = moduleImport.convertType(inst->getType());
auto op = odsBuilder.create<::mlir::LLVM::CallIntrinsicOp>(
moduleImport.translateLoc(inst->getDebugLoc()), results,
StringAttr::get(odsBuilder.getContext(), intrinName),
ValueRange{mlirOperands}, FastmathFlagsAttr{});

moduleImport.setFastmathFlagsAttr(inst, op);

// Update importer tracking of results.
unsigned numRes = op.getNumResults();
if (numRes == 1)
moduleImport.mapValue(inst) = op.getResult(0);
else if (numRes == 0)
moduleImport.mapNoResultOp(inst);
else
return op.emitError(
"expected at most one result from target intrinsic call");

return success();
}

/// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a
/// conversion exits. Returns failure otherwise.
static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
Expand All @@ -75,6 +118,8 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
llvmOpBundles.push_back(inst->getOperandBundleAt(i));

#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
} else if (intrinsicID != llvm::Intrinsic::not_intrinsic) {
return convertUnregisteredIntrinsicImpl(odsBuilder, inst, moduleImport);
}

return failure();
Expand Down Expand Up @@ -422,6 +467,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
return getSupportedIntrinsicsImpl();
}

/// Cnvertible to generic llvm.call_intrinsic.
bool getUnregisteredIntrinsics() const final { return true; }

/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
/// LLVM dialect attributes.
ArrayRef<unsigned>
Expand Down
12 changes: 0 additions & 12 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,6 @@ bb1:

; // -----

declare void @llvm.gcroot(ptr %arg1, ptr %arg2)

; CHECK: <unknown>
; CHECK-SAME: error: unhandled intrinsic: call void @llvm.gcroot(ptr %arg1, ptr null)
define void @unhandled_intrinsic() gc "example" {
%arg1 = alloca ptr
call void @llvm.gcroot(ptr %arg1, ptr null)
ret void
}

; // -----

; Check that debug intrinsics with an unsupported argument are dropped.

declare void @llvm.dbg.value(metadata, metadata, metadata)
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Target/LLVMIR/Import/intrinsic-unregistered.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; RUN: mlir-translate -import-llvm %s -split-input-file | FileCheck %s

declare i64 @llvm.aarch64.ldxr.p0(ptr)

define dso_local void @t0(ptr %a) {
%x = call i64 @llvm.aarch64.ldxr.p0(ptr elementtype(i8) %a)
ret void
}

; CHECK-LABEL: llvm.func @llvm.aarch64.ldxr.p0(!llvm.ptr)
; CHECK-LABEL: llvm.func @t0
; CHECK: llvm.call_intrinsic "llvm.aarch64.ldxr.p0"({{.*}}) : (!llvm.ptr) -> i64
; CHECK: llvm.return
; CHECK: }

; -----

declare <8 x i8> @llvm.aarch64.neon.uabd.v8i8(<8 x i8>, <8 x i8>)

define dso_local <8 x i8> @t1(<8 x i8> %lhs, <8 x i8> %rhs) {
%r = call <8 x i8> @llvm.aarch64.neon.uabd.v8i8(<8 x i8> %lhs, <8 x i8> %rhs)
ret <8 x i8> %r
}

; CHECK: llvm.func @t1(%[[A0:.*]]: vector<8xi8>, %[[A1:.*]]: vector<8xi8>) -> vector<8xi8> {{.*}} {
; CHECK: %[[R:.*]] = llvm.call_intrinsic "llvm.aarch64.neon.uabd.v8i8"(%[[A0]], %[[A1]]) : (vector<8xi8>, vector<8xi8>) -> vector<8xi8>
; CHECK: llvm.return %[[R]] : vector<8xi8>
; CHECK: }

; -----

declare void @llvm.aarch64.neon.st2.v8i8.p0(<8 x i8>, <8 x i8>, ptr)

define dso_local void @t2(<8 x i8> %lhs, <8 x i8> %rhs, ptr %a) {
call void @llvm.aarch64.neon.st2.v8i8.p0(<8 x i8> %lhs, <8 x i8> %rhs, ptr %a)
ret void
}

; CHECK: llvm.func @t2(%[[A0:.*]]: vector<8xi8>, %[[A1:.*]]: vector<8xi8>, %[[A2:.*]]: !llvm.ptr) {{.*}} {
; CHECK: llvm.call_intrinsic "llvm.aarch64.neon.st2.v8i8.p0"(%[[A0]], %[[A1]], %[[A2]]) : (vector<8xi8>, vector<8xi8>, !llvm.ptr) -> !llvm.void
; CHECK: llvm.return
; CHECK: }

; -----

declare void @llvm.gcroot(ptr %arg1, ptr %arg2)
define void @gctest() gc "example" {
%arg1 = alloca ptr
call void @llvm.gcroot(ptr %arg1, ptr null)
ret void
}

; CHECK-LABEL: @gctest
; CHECK: llvm.call_intrinsic "llvm.gcroot"({{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr) -> !llvm.void

; -----

; Test we get the supported version, not the unregistered one.

declare i32 @llvm.lround.i32.f32(float)

; CHECK-LABEL: llvm.func @lround_test
define void @lround_test(float %0, double %1) {
; CHECK-NOT: llvm.call_intrinsic "llvm.lround
; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i32
%3 = call i32 @llvm.lround.i32.f32(float %0)
ret void
}