Skip to content

Commit e25d207

Browse files
lmendesp-amdcferry-AMDmgehre-amd
authored
Include comments with template argument names in Cpp code from EmitC (#403)
* Include comments with template arg names in Cpp code from EmitC * Apply suggestions from code review Co-authored-by: Corentin Ferry <[email protected]> Co-authored-by: Matthias Gehre <[email protected]> * Test for the presence of template arg names when there are no template args --------- Co-authored-by: Corentin Ferry <[email protected]> Co-authored-by: Matthias Gehre <[email protected]>
1 parent 20a6720 commit e25d207

File tree

6 files changed

+83
-4
lines changed

6 files changed

+83
-4
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
292292
Arg<StrAttr, "the C++ function to call">:$callee,
293293
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
294294
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
295+
Arg<OptionalAttr<StrArrayAttr>, "template argument names">:$template_arg_names,
295296
Variadic<EmitCType>:$operands
296297
);
297298
let results = (outs Variadic<EmitCType>);
@@ -302,7 +303,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
302303
"::mlir::ValueRange":$operands,
303304
CArg<"::mlir::ArrayAttr", "{}">:$args,
304305
CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{
305-
build($_builder, $_state, resultTypes, callee, args, template_args,
306+
build($_builder, $_state, resultTypes, callee, args, template_args, {},
306307
operands);
307308
}]
308309
>

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,19 @@ LogicalResult emitc::CallOpaqueOp::verify() {
355355
}
356356
}
357357

358+
if (std::optional<ArrayAttr> templateArgNames = getTemplateArgNames()) {
359+
if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
360+
if ((*templateArgNames).size() &&
361+
(*templateArgNames).size() != (*templateArgsAttr).size()) {
362+
return emitOpError("number of template argument names must be equal to "
363+
"number of template arguments");
364+
}
365+
} else {
366+
return emitOpError("should not have names for template arguments if it "
367+
"does not have template arguments");
368+
}
369+
}
370+
358371
if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
359372
return emitOpError() << "cannot return array type";
360373
}

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -659,11 +659,31 @@ static LogicalResult printOperation(CppEmitter &emitter,
659659
return success();
660660
};
661661

662+
auto emitNamedArgs =
663+
[&](std::tuple<const Attribute &, const Attribute &> tuple)
664+
-> LogicalResult {
665+
Attribute attr = std::get<0>(tuple);
666+
StringAttr argName = cast<StringAttr>(std::get<1>(tuple));
667+
668+
os << "/*" << argName.str() << "=*/";
669+
return emitArgs(attr);
670+
};
671+
662672
if (callOpaqueOp.getTemplateArgs()) {
663673
os << "<";
664-
if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
665-
emitArgs)))
666-
return failure();
674+
if (callOpaqueOp.getTemplateArgNames() &&
675+
!callOpaqueOp.getTemplateArgNames()->empty()) {
676+
if (failed(interleaveCommaWithError(
677+
llvm::zip_equal(*callOpaqueOp.getTemplateArgs(),
678+
*callOpaqueOp.getTemplateArgNames()),
679+
os, emitNamedArgs))) {
680+
return failure();
681+
}
682+
} else {
683+
if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
684+
emitArgs)))
685+
return failure();
686+
}
667687
os << ">";
668688
}
669689

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,27 @@ func.func @test_verbatim(%arg0 : !emitc.ptr<i32>, %arg1 : i32) {
524524
emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr<i32>, i32
525525
return
526526
}
527+
528+
// -----
529+
530+
func.func @template_args_with_names(%arg0: i32) {
531+
// expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}}
532+
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N", "P"], template_args = [42 : i32]} : (i32) -> ()
533+
return
534+
}
535+
536+
// -----
537+
538+
func.func @template_args_with_names(%arg0: i32) {
539+
// expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}}
540+
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"], template_args = [42 : i32, 56 : i32]} : (i32) -> ()
541+
return
542+
}
543+
544+
// -----
545+
546+
func.func @template_args_with_names(%arg0: i32) {
547+
// expected-error @+1 {{'emitc.call_opaque' op should not have names for template arguments if it does not have template arguments}}
548+
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"]} : (i32) -> ()
549+
return
550+
}

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,10 @@ func.func @member_access(%arg0: !emitc.opaque<"mystruct">, %arg1: !emitc.opaque<
282282
%2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.ptr<!emitc.opaque<"mystruct">>) -> i32
283283
return
284284
}
285+
286+
func.func @template_args_with_names(%arg0: i32, %arg1: f32) {
287+
emitc.call_opaque "kernel1"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> ()
288+
emitc.call_opaque "kernel2"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [42 : i32]} : (i32, f32) -> ()
289+
emitc.call_opaque "kernel3"(%arg0, %arg1) {template_arg_names = [], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> ()
290+
return
291+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
2+
3+
// CPP-DEFAULT-LABEL: void basic
4+
func.func @basic(%arg0: i32, %arg1: f32) {
5+
emitc.call_opaque "kernel3"(%arg0, %arg1) : (i32, f32) -> ()
6+
// CPP-DEFAULT: kernel3(
7+
emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> ()
8+
// CPP-DEFAULT: kernel4</*N=*/42, /*P=*/56>(
9+
emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> ()
10+
// CPP-DEFAULT: kernel4</*N=*/42>(
11+
return
12+
}
13+
14+

0 commit comments

Comments
 (0)