Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
Arg<StrAttr, "the C++ function to call">:$callee,
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
Arg<OptionalAttr<StrArrayAttr>, "template argument names">:$template_arg_names,
Variadic<EmitCType>:$operands
);
let results = (outs Variadic<EmitCType>);
Expand All @@ -302,7 +303,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
"::mlir::ValueRange":$operands,
CArg<"::mlir::ArrayAttr", "{}">:$args,
CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{
build($_builder, $_state, resultTypes, callee, args, template_args,
build($_builder, $_state, resultTypes, callee, args, template_args, {},
operands);
}]
>
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,14 @@ LogicalResult emitc::CallOpaqueOp::verify() {
if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
return emitOpError("template argument has invalid type");
}

if (std::optional<ArrayAttr> templateArgNames = getTemplateArgNames()) {
if ((*templateArgNames).size() &&
(*templateArgNames).size() != (*templateArgsAttr).size()) {
return emitOpError("number of template argument names must be equal to "
"number of template arguments");
}
}
}

if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
Expand Down
26 changes: 23 additions & 3 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,31 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
};

auto emitNamedArgs =
[&](std::tuple<const Attribute &, const Attribute &> tuple)
-> LogicalResult {
Attribute attr = std::get<0>(tuple);
StringAttr argName = cast<StringAttr>(std::get<1>(tuple));

os << "/*" << argName.str() << "=*/";
return emitArgs(attr);
};

if (callOpaqueOp.getTemplateArgs()) {
os << "<";
if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
emitArgs)))
return failure();
if (callOpaqueOp.getTemplateArgNames() &&
!callOpaqueOp.getTemplateArgNames()->empty()) {
if (failed(interleaveCommaWithError(
llvm::zip(*callOpaqueOp.getTemplateArgs(),
*callOpaqueOp.getTemplateArgNames()),
os, emitNamedArgs))) {
return failure();
}
} else {
if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
emitArgs)))
return failure();
}
os << ">";
}

Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,19 @@ func.func @test_verbatim(%arg0 : !emitc.ptr<i32>, %arg1 : i32) {
emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr<i32>, i32
return
}

// -----

func.func @template_args_with_names(%arg0: i32) {
// expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}}
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N", "P"], template_args = [42 : i32]} : (i32) -> ()
return
}

// -----

func.func @template_args_with_names2(%arg0: i32) {
// expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}}
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"], template_args = [42 : i32, 56 : i32]} : (i32) -> ()
return
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,10 @@ func.func @member_access(%arg0: !emitc.opaque<"mystruct">, %arg1: !emitc.opaque<
%2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.ptr<!emitc.opaque<"mystruct">>) -> i32
return
}

func.func @template_args_with_names(%arg0: i32, %arg1: f32) {
emitc.call_opaque "kernel1"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> ()
emitc.call_opaque "kernel2"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [42 : i32]} : (i32, f32) -> ()
emitc.call_opaque "kernel3"(%arg0, %arg1) {template_arg_names = [], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> ()
return
}
20 changes: 20 additions & 0 deletions mlir/test/Target/Cpp/template_arg_names.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT

// CPP-DEFAULT: void basic(int32_t v1, float v2) {
func.func @basic(%arg0: i32, %arg1: f32) {
emitc.call_opaque "kernel1"() : () -> ()
// CPP-DEFAULT: kernel1();
%0 = emitc.call_opaque "kernel2"(%arg0) : (i32) -> i16
// CPP-DEFAULT: int16_t v3 = kernel2(v1);
emitc.call_opaque "kernel3"(%arg0, %arg1) : (i32, f32) -> ()
// CPP-DEFAULT: kernel3(v1, v2);
emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> ()
// CPP-DEFAULT: kernel4</*N=*/42, /*P=*/56>(v1, v2);
emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> ()
// CPP-DEFAULT: kernel4</*N=*/42>(v1, v2);
return
// CPP-DEFAULT: return;
}
// CPP-DEFAULT: }