Skip to content

Commit 686e513

Browse files
committed
Add custom parsers/printers for has_device_addr
1 parent ba7925f commit 686e513

File tree

5 files changed

+59
-33
lines changed

5 files changed

+59
-33
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,10 @@ class OpenMP_HasDeviceAddrClauseSkip<
463463
bit description = false, bit extraClassDeclaration = false
464464
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
465465
extraClassDeclaration> {
466+
let traits = [
467+
BlockArgOpenMPOpInterface
468+
];
469+
466470
let arguments = (ins
467471
Variadic<OpenMP_PointerLikeType>:$has_device_addr_vars
468472
);
@@ -473,11 +477,6 @@ class OpenMP_HasDeviceAddrClauseSkip<
473477
}
474478
}];
475479

476-
let optAssemblyFormat = [{
477-
`has_device_addr` `(` $has_device_addr_vars `:` type($has_device_addr_vars)
478-
`)`
479-
}];
480-
481480
let description = [{
482481
The optional `has_device_addr_vars` indicates that list items already have
483482
device addresses, so they may be directly accessed from the target device.

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,8 +1315,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
13151315
}] # clausesExtraClassDeclaration;
13161316

13171317
let assemblyFormat = clausesAssemblyFormat # [{
1318-
custom<HostEvalInReductionMapPrivateRegion>(
1319-
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
1318+
custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion>(
1319+
$region, $has_device_addr_vars, type($has_device_addr_vars),
1320+
$host_eval_vars, type($host_eval_vars), $in_reduction_vars,
13201321
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
13211322
$map_vars, type($map_vars), $private_vars, type($private_vars),
13221323
$private_syms, $private_maps) attr-dict

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ struct ReductionParseArgs {
508508
};
509509

510510
struct AllRegionParseArgs {
511+
std::optional<MapParseArgs> hasDeviceAddrArgs;
511512
std::optional<MapParseArgs> hostEvalArgs;
512513
std::optional<ReductionParseArgs> inReductionArgs;
513514
std::optional<MapParseArgs> mapArgs;
@@ -666,6 +667,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
666667
AllRegionParseArgs args) {
667668
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
668669

670+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
671+
args.hasDeviceAddrArgs)))
672+
return parser.emitError(parser.getCurrentLocation())
673+
<< "invalid `has_device_addr` format";
674+
669675
if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
670676
args.hostEvalArgs)))
671677
return parser.emitError(parser.getCurrentLocation())
@@ -709,8 +715,12 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
709715
return parser.parseRegion(region, entryBlockArgs);
710716
}
711717

712-
static ParseResult parseHostEvalInReductionMapPrivateRegion(
718+
// See custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion> in the
719+
// definition of TargetOp.
720+
static ParseResult parseHasDeviceAddrHostEvalInReductionMapPrivateRegion(
713721
OpAsmParser &parser, Region &region,
722+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars,
723+
SmallVectorImpl<Type> &hasDeviceAddrTypes,
714724
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
715725
SmallVectorImpl<Type> &hostEvalTypes,
716726
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
@@ -722,6 +732,7 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
722732
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
723733
DenseI64ArrayAttr &privateMaps) {
724734
AllRegionParseArgs args;
735+
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
725736
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
726737
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
727738
inReductionByref, inReductionSyms);
@@ -731,6 +742,7 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
731742
return parseBlockArgRegion(parser, region, args);
732743
}
733744

745+
// See custom<InReductionPrivateRegion> in the definition of TaskOp.
734746
static ParseResult parseInReductionPrivateRegion(
735747
OpAsmParser &parser, Region &region,
736748
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
@@ -745,6 +757,8 @@ static ParseResult parseInReductionPrivateRegion(
745757
return parseBlockArgRegion(parser, region, args);
746758
}
747759

760+
// See custom<InReductionPrivateReductionRegion> in the definition of
761+
// TaskloopOp.
748762
static ParseResult parseInReductionPrivateReductionRegion(
749763
OpAsmParser &parser, Region &region,
750764
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
@@ -765,6 +779,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
765779
return parseBlockArgRegion(parser, region, args);
766780
}
767781

782+
// See custom<PrivateRegion> in the definition of SingleOp.
768783
static ParseResult parsePrivateRegion(
769784
OpAsmParser &parser, Region &region,
770785
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
@@ -774,6 +789,7 @@ static ParseResult parsePrivateRegion(
774789
return parseBlockArgRegion(parser, region, args);
775790
}
776791

792+
// See custom<PrivateReductionRegion> in the definition of LoopOp.
777793
static ParseResult parsePrivateReductionRegion(
778794
OpAsmParser &parser, Region &region,
779795
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
@@ -789,6 +805,7 @@ static ParseResult parsePrivateReductionRegion(
789805
return parseBlockArgRegion(parser, region, args);
790806
}
791807

808+
// See custom<TaskReductionRegion> in the definition of TaskgroupOp.
792809
static ParseResult parseTaskReductionRegion(
793810
OpAsmParser &parser, Region &region,
794811
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
@@ -800,6 +817,8 @@ static ParseResult parseTaskReductionRegion(
800817
return parseBlockArgRegion(parser, region, args);
801818
}
802819

820+
// See custom<UseDeviceAddrUseDevicePtrRegion> in the definition of
821+
// TargetDataOp.
803822
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
804823
OpAsmParser &parser, Region &region,
805824
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
@@ -842,6 +861,7 @@ struct ReductionPrintArgs {
842861
: vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
843862
};
844863
struct AllRegionPrintArgs {
864+
std::optional<MapPrintArgs> hasDeviceAddrArgs;
845865
std::optional<MapPrintArgs> hostEvalArgs;
846866
std::optional<ReductionPrintArgs> inReductionArgs;
847867
std::optional<MapPrintArgs> mapArgs;
@@ -935,6 +955,9 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
935955
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
936956
MLIRContext *ctx = op->getContext();
937957

958+
printBlockArgClause(p, ctx, "has_device_addr",
959+
iface.getHasDeviceAddrBlockArgs(),
960+
args.hasDeviceAddrArgs);
938961
printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
939962
args.hostEvalArgs);
940963
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
@@ -957,14 +980,19 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
957980
p.printRegion(region, /*printEntryBlockArgs=*/false);
958981
}
959982

960-
static void printHostEvalInReductionMapPrivateRegion(
961-
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
962-
TypeRange hostEvalTypes, ValueRange inReductionVars,
963-
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
964-
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
965-
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
983+
// See custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion> in the
984+
// definition of TargetOp.
985+
static void printHasDeviceAddrHostEvalInReductionMapPrivateRegion(
986+
OpAsmPrinter &p, Operation *op, Region &region,
987+
ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
988+
ValueRange hostEvalVars, TypeRange hostEvalTypes,
989+
ValueRange inReductionVars, TypeRange inReductionTypes,
990+
DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
991+
ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
992+
TypeRange privateTypes, ArrayAttr privateSyms,
966993
DenseI64ArrayAttr privateMaps) {
967994
AllRegionPrintArgs args;
995+
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
968996
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
969997
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
970998
inReductionByref, inReductionSyms);
@@ -973,6 +1001,7 @@ static void printHostEvalInReductionMapPrivateRegion(
9731001
printBlockArgRegion(p, op, region, args);
9741002
}
9751003

1004+
// See custom<InReductionPrivateRegion> in the definition of TaskOp.
9761005
static void printInReductionPrivateRegion(
9771006
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
9781007
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
@@ -986,6 +1015,8 @@ static void printInReductionPrivateRegion(
9861015
printBlockArgRegion(p, op, region, args);
9871016
}
9881017

1018+
// See custom<InReductionPrivateReductionRegion> in the definition of
1019+
// TaskloopOp.
9891020
static void printInReductionPrivateReductionRegion(
9901021
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
9911022
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
@@ -1003,6 +1034,7 @@ static void printInReductionPrivateReductionRegion(
10031034
printBlockArgRegion(p, op, region, args);
10041035
}
10051036

1037+
// See custom<PrivateRegion> in the definition of SingleOp.
10061038
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
10071039
ValueRange privateVars, TypeRange privateTypes,
10081040
ArrayAttr privateSyms) {
@@ -1012,6 +1044,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
10121044
printBlockArgRegion(p, op, region, args);
10131045
}
10141046

1047+
// See custom<PrivateReductionRegion> in the definition of LoopOp.
10151048
static void printPrivateReductionRegion(
10161049
OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
10171050
TypeRange privateTypes, ArrayAttr privateSyms,
@@ -1026,6 +1059,7 @@ static void printPrivateReductionRegion(
10261059
printBlockArgRegion(p, op, region, args);
10271060
}
10281061

1062+
// See custom<TaskReductionRegion> in the definition of TaskgroupOp.
10291063
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
10301064
Region &region,
10311065
ValueRange taskReductionVars,
@@ -1038,6 +1072,8 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
10381072
printBlockArgRegion(p, op, region, args);
10391073
}
10401074

1075+
// See custom<UseDeviceAddrUseDevicePtrRegion> in the definition of
1076+
// TargetDataOp.
10411077
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
10421078
Region &region,
10431079
ValueRange useDeviceAddrVars,

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,6 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
761761
return
762762
}
763763

764-
765764
// CHECK-LABEL: omp_target
766765
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
767766

@@ -773,17 +772,19 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
773772
}) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
774773

775774
// Test with optional map clause.
776-
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
777-
// CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
778-
// CHECK: omp.target has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) is_device_ptr(%[[VAL_4:.*]] : memref<i32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
775+
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
776+
// CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
777+
// CHECK: %[[MAP_C:.*]] = omp.map.info var_ptr(%[[VAL_3:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
778+
// CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[MAP_A]] -> {{.*}} : memref<?xi32>) map_entries(%[[MAP_B]] -> {{.*}}, %[[MAP_C]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
779+
%mapv0 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
779780
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
780781
%mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
781-
omp.target is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
782+
omp.target is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%mapv0 -> %arg0 : memref<?xi32>) map_entries(%mapv1 -> %arg1, %mapv2 -> %arg2 : memref<?xi32>, memref<?xi32>) {
782783
omp.terminator
783784
}
784-
// CHECK: %[[MAP_C:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
785-
// CHECK: %[[MAP_D:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
786-
// CHECK: omp.target map_entries(%[[MAP_C]] -> {{.*}}, %[[MAP_D]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
785+
// CHECK: %[[MAP_D:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
786+
// CHECK: %[[MAP_E:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
787+
// CHECK: omp.target map_entries(%[[MAP_D]] -> {{.*}}, %[[MAP_E]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
787788
%mapv3 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
788789
%mapv4 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
789790
omp.target map_entries(%mapv3 -> %arg0, %mapv4 -> %arg1 : memref<?xi32>, memref<?xi32>) {

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,6 @@ llvm.func @target_device(%x : i32) {
308308

309309
// -----
310310

311-
llvm.func @target_has_device_addr(%x : !llvm.ptr) {
312-
// expected-error@below {{not yet implemented: Unhandled clause has_device_addr in omp.target operation}}
313-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
314-
omp.target has_device_addr(%x : !llvm.ptr) {
315-
omp.terminator
316-
}
317-
llvm.return
318-
}
319-
320-
// -----
321-
322311
omp.declare_reduction @add_f32 : f32
323312
init {
324313
^bb0(%arg: f32):

0 commit comments

Comments
 (0)