Skip to content

Commit 427beff

Browse files
authored
[OpenMP][MLIR] Add private clause to omp.target (#91202)
1 parent d24eaef commit 427beff

File tree

4 files changed

+98
-8
lines changed

4 files changed

+98
-8
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
17871787
UnitAttr:$nowait,
17881788
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
17891789
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
1790-
Variadic<AnyType>:$map_operands);
1790+
Variadic<AnyType>:$map_operands,
1791+
Variadic<AnyType>:$private_vars,
1792+
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
1793+
17911794
let regions = (region AnyRegion:$region);
17921795

17931796
let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
18021805
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
18031806
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
18041807
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
1808+
| `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
18051809
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
18061810
) $region attr-dict
18071811
}];

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,13 +470,17 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
470470
ValueRange argsSubrange,
471471
StringRef clauseName, ValueRange operands,
472472
TypeRange types, ArrayAttr symbols) {
473-
p << clauseName << "(";
473+
if (!clauseName.empty())
474+
p << clauseName << "(";
475+
474476
llvm::interleaveComma(
475477
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
476478
auto [sym, op, arg, type] = t;
477479
p << sym << " " << op << " -> " << arg << " : " << type;
478480
});
479-
p << ") ";
481+
482+
if (!clauseName.empty())
483+
p << ") ";
480484
}
481485

482486
static ParseResult parseParallelRegion(
@@ -1048,6 +1052,49 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
10481052
}
10491053
}
10501054

1055+
static ParseResult parsePrivateList(
1056+
OpAsmParser &parser,
1057+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
1058+
SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
1059+
SmallVector<SymbolRefAttr> privateSymRefs;
1060+
SmallVector<OpAsmParser::Argument> regionPrivateArgs;
1061+
1062+
if (failed(parser.parseCommaSeparatedList([&]() {
1063+
if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
1064+
parser.parseOperand(privateOperands.emplace_back()) ||
1065+
parser.parseArrow() ||
1066+
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
1067+
parser.parseColonType(privateOperandTypes.emplace_back()))
1068+
return failure();
1069+
return success();
1070+
})))
1071+
return failure();
1072+
1073+
SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
1074+
privateSymRefs.end());
1075+
privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
1076+
1077+
return success();
1078+
}
1079+
1080+
static void printPrivateList(OpAsmPrinter &p, Operation *op,
1081+
ValueRange privateVarOperands,
1082+
TypeRange privateVarTypes,
1083+
ArrayAttr privatizerSymbols) {
1084+
// TODO: Remove target-specific logic from this function.
1085+
auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
1086+
assert(targetOp);
1087+
1088+
auto &region = op->getRegion(0);
1089+
auto *argsBegin = region.front().getArguments().begin();
1090+
MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
1091+
argsBegin + targetOp.getMapOperands().size() +
1092+
privateVarTypes.size());
1093+
printClauseWithRegionArgs(
1094+
p, op, argsSubrange, /*clauseName=*/llvm::StringRef{}, privateVarOperands,
1095+
privateVarTypes, privatizerSymbols);
1096+
}
1097+
10511098
static void printCaptureType(OpAsmPrinter &p, Operation *op,
10521099
VariableCaptureKindAttr mapCaptureType) {
10531100
std::string typeCapStr;
@@ -1256,13 +1303,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
12561303
const TargetClauseOps &clauses) {
12571304
MLIRContext *ctx = builder.getContext();
12581305
// TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1259-
// inReductionDeclSymbols, privateVars, privatizers, reductionVars,
1260-
// reductionByRefAttr, reductionDeclSymbols.
1306+
// inReductionDeclSymbols, reductionVars, reductionByRefAttr,
1307+
// reductionDeclSymbols.
12611308
TargetOp::build(
12621309
builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
12631310
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
12641311
clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
1265-
clauses.mapVars);
1312+
clauses.mapVars, clauses.privateVars,
1313+
makeArrayAttr(ctx, clauses.privatizers));
12661314
}
12671315

12681316
LogicalResult TargetOp::verify() {

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
20872087
// expected-error @below {{op expected as many depend values as depend variables}}
20882088
"omp.target"(%data_var) ({
20892089
"omp.terminator"() : () -> ()
2090-
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
2090+
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
20912091
"func.return"() : () -> ()
20922092
}
20932093

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
737737
"omp.target"(%if_cond, %device, %num_threads) ({
738738
// CHECK: omp.terminator
739739
omp.terminator
740-
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
740+
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>} : ( i1, si32, i32 ) -> ()
741741

742742
// Test with optional map clause.
743743
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
25502550
}
25512551
return
25522552
}
2553+
2554+
// CHECK-LABEL: omp_target_private
2555+
func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
2556+
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
2557+
%mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
2558+
2559+
// CHECK: omp.target
2560+
// CHECK-SAME: private(
2561+
// CHECK-SAME: @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
2562+
// CHECK-SAME: : !llvm.ptr
2563+
// CHECK-SAME: )
2564+
omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
2565+
// CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
2566+
^bb0(%priv_arg: !llvm.ptr):
2567+
omp.terminator
2568+
}
2569+
2570+
// CHECK: omp.target
2571+
2572+
// CHECK-SAME: map_entries(
2573+
// CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]],
2574+
// CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]]
2575+
// CHECK-SAME: : memref<?xi32>, memref<?xi32>
2576+
// CHECK-SAME: )
2577+
2578+
// CHECK-SAME: private(
2579+
// CHECK-SAME: @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
2580+
// CHECK-SAME: : !llvm.ptr
2581+
// CHECK-SAME: )
2582+
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
2583+
// CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
2584+
// CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
2585+
^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
2586+
omp.terminator
2587+
}
2588+
2589+
return
2590+
}

0 commit comments

Comments
 (0)