Skip to content

Commit 0efff7c

Browse files
electricliliesJeff Niu
authored andcommitted
[mlir] Add call_intrinsic op to LLVMIIR
The call_intrinsic op allows us to call LLVM intrinsics from the LLVMDialect without implementing a new op every time. Reviewed By: lattner, rriddle Differential Revision: https://reviews.llvm.org/D137187
1 parent 1ceafe5 commit 0efff7c

File tree

3 files changed

+169
-4
lines changed

3 files changed

+169
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,25 @@ def LLVM_vector_extract
707707
}];
708708
}
709709

710+
//===--------------------------------------------------------------------===//
711+
// CallIntrinsicOp
712+
//===--------------------------------------------------------------------===//
713+
def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic", [Pure]> {
714+
let summary = "Call to an LLVM intrinsic function.";
715+
let description = [{
716+
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
717+
the MLIR function type of this op to determine which intrinsic to call.
718+
}];
719+
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args);
720+
let results = (outs Variadic<LLVM_Type>:$results);
721+
let llvmBuilder = [{
722+
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
723+
}];
724+
let assemblyFormat = [{
725+
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
726+
}];
727+
}
728+
710729
//
711730
// LLVM Vector Predication operations.
712731
//

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,70 @@ static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) {
258258
return position;
259259
}
260260

261+
/// Get the declaration of an overloaded llvm intrinsic. First we get the
262+
/// overloaded argument types and/or result type from the CallIntrinsicOp, and
263+
/// then use those to get the correct declaration of the overloaded intrinsic.
264+
static FailureOr<llvm::Function *>
265+
getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id,
266+
llvm::Module *module,
267+
LLVM::ModuleTranslation &moduleTranslation) {
268+
SmallVector<llvm::Type *, 8> allArgTys;
269+
for (Type type : op->getOperandTypes())
270+
allArgTys.push_back(moduleTranslation.convertType(type));
271+
272+
llvm::Type *resTy;
273+
if (op.getNumResults() == 0)
274+
resTy = llvm::Type::getVoidTy(module->getContext());
275+
else
276+
resTy = moduleTranslation.convertType(op.getResult(0).getType());
277+
278+
// ATM we do not support variadic intrinsics.
279+
llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false);
280+
281+
SmallVector<llvm::Intrinsic::IITDescriptor, 8> table;
282+
getIntrinsicInfoTableEntries(id, table);
283+
ArrayRef<llvm::Intrinsic::IITDescriptor> tableRef = table;
284+
285+
SmallVector<llvm::Type *, 8> overloadedArgTys;
286+
if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
287+
overloadedArgTys) !=
288+
llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
289+
return op.emitOpError("intrinsic type is not a match");
290+
}
291+
292+
ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
293+
return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef);
294+
}
295+
296+
/// Builder for LLVM_CallIntrinsicOp
297+
static LogicalResult
298+
convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder,
299+
LLVM::ModuleTranslation &moduleTranslation) {
300+
llvm::Module *module = builder.GetInsertBlock()->getModule();
301+
llvm::Intrinsic::ID id =
302+
llvm::Function::lookupIntrinsicID(op.getIntrinAttr());
303+
if (!id)
304+
return op.emitOpError()
305+
<< "couldn't find intrinsic: " << op.getIntrinAttr();
306+
307+
llvm::Function *fn = nullptr;
308+
if (llvm::Intrinsic::isOverloaded(id)) {
309+
auto fnOrFailure =
310+
getOverloadedDeclaration(op, id, module, moduleTranslation);
311+
if (failed(fnOrFailure))
312+
return failure();
313+
fn = fnOrFailure.value();
314+
} else {
315+
fn = llvm::Intrinsic::getDeclaration(module, id, {});
316+
}
317+
318+
auto *inst =
319+
builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands()));
320+
if (op.getNumResults() == 1)
321+
moduleTranslation.mapValue(op->getResults().front()) = inst;
322+
return success();
323+
}
324+
261325
static LogicalResult
262326
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
263327
LLVM::ModuleTranslation &moduleTranslation) {
@@ -272,8 +336,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
272336
// Emit function calls. If the "callee" attribute is present, this is a
273337
// direct function call and we also need to look up the remapped function
274338
// itself. Otherwise, this is an indirect call and the callee is the first
275-
// operand, look it up as a normal value. Return the llvm::Value representing
276-
// the function result, which may be of llvm::VoidTy type.
339+
// operand, look it up as a normal value. Return the llvm::Value
340+
// representing the function result, which may be of llvm::VoidTy type.
277341
auto convertCall = [&](Operation &op) -> llvm::Value * {
278342
auto operands = moduleTranslation.lookupValues(op.getOperands());
279343
ArrayRef<llvm::Value *> operandsRef(operands);
@@ -404,8 +468,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
404468
return success();
405469
}
406470

407-
// Emit branches. We need to look up the remapped blocks and ignore the block
408-
// arguments that were transformed into PHI nodes.
471+
// Emit branches. We need to look up the remapped blocks and ignore the
472+
// block arguments that were transformed into PHI nodes.
409473
if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
410474
llvm::BranchInst *branch =
411475
builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK: ; ModuleID = 'LLVMDialectModule'
4+
// CHECK: source_filename = "LLVMDialectModule"
5+
// CHECK: declare ptr @malloc(i64)
6+
// CHECK: declare void @free(ptr)
7+
// CHECK: define <4 x float> @round_sse41() {
8+
// CHECK: %1 = call <4 x float> @llvm.x86.sse41.round.ss(<4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, <4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, i32 1)
9+
// CHECK: ret <4 x float> %1
10+
// CHECK: }
11+
llvm.func @round_sse41() -> vector<4xf32> {
12+
%0 = llvm.mlir.constant(1 : i32) : i32
13+
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
14+
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {}
15+
llvm.return %res: vector<4xf32>
16+
}
17+
18+
// -----
19+
20+
// CHECK: ; ModuleID = 'LLVMDialectModule'
21+
// CHECK: source_filename = "LLVMDialectModule"
22+
23+
// CHECK: declare ptr @malloc(i64)
24+
25+
// CHECK: declare void @free(ptr)
26+
27+
// CHECK: define float @round_overloaded() {
28+
// CHECK: %1 = call float @llvm.round.f32(float 1.000000e+00)
29+
// CHECK: ret float %1
30+
// CHECK: }
31+
llvm.func @round_overloaded() -> f32 {
32+
%0 = llvm.mlir.constant(1.0 : f32) : f32
33+
%res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
34+
llvm.return %res: f32
35+
}
36+
37+
// -----
38+
39+
// CHECK: ; ModuleID = 'LLVMDialectModule'
40+
// CHECK: source_filename = "LLVMDialectModule"
41+
// CHECK: declare ptr @malloc(i64)
42+
// CHECK: declare void @free(ptr)
43+
// CHECK: define void @lifetime_start() {
44+
// CHECK: %1 = alloca float, i8 1, align 4
45+
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1)
46+
// CHECK: ret void
47+
// CHECK: }
48+
llvm.func @lifetime_start() {
49+
%0 = llvm.mlir.constant(4 : i64) : i64
50+
%1 = llvm.mlir.constant(1 : i8) : i8
51+
%2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
52+
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
53+
llvm.return
54+
}
55+
56+
// -----
57+
58+
llvm.func @variadic() {
59+
%0 = llvm.mlir.constant(1 : i8) : i8
60+
%1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
61+
llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> ()
62+
llvm.return
63+
}
64+
65+
// -----
66+
67+
llvm.func @no_intrinsic() {
68+
// expected-error@below {{'llvm.call_intrinsic' op couldn't find intrinsic: "llvm.does_not_exist"}}
69+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
70+
llvm.call_intrinsic "llvm.does_not_exist"() : () -> ()
71+
llvm.return
72+
}
73+
74+
// -----
75+
76+
llvm.func @bad_types() {
77+
%0 = llvm.mlir.constant(1 : i8) : i8
78+
// expected-error@below {{'llvm.call_intrinsic' op intrinsic type is not a match}}
79+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
80+
llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
81+
llvm.return
82+
}

0 commit comments

Comments
 (0)