diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 289e648eed854..d2c814cc958dd 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -1,5 +1,4 @@ -//===- MapsForPrivatizedSymbols.cpp -//-----------------------------------------===// +//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -28,6 +27,7 @@ #include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/OpenMP/Passes.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/BuiltinAttributes.h" @@ -124,6 +124,8 @@ class MapsForPrivatizedSymbolsPass if (targetOp.getPrivateVars().empty()) return; OperandRange privVars = targetOp.getPrivateVars(); + llvm::SmallVector privVarMapIdx; + std::optional privSyms = targetOp.getPrivateSyms(); SmallVector mapInfoOps; for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) { @@ -133,17 +135,25 @@ class MapsForPrivatizedSymbolsPass SymbolTable::lookupNearestSymbolFrom( targetOp, privatizerName); if (!privatizerNeedsMap(privatizer)) { + privVarMapIdx.push_back(-1); continue; } + + privVarMapIdx.push_back(targetOp.getMapVars().size() + + mapInfoOps.size()); + builder.setInsertionPoint(targetOp); Location loc = targetOp.getLoc(); omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder); mapInfoOps.push_back(mapInfoOp); + LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n"); LLVM_DEBUG(mapInfoOp.dump()); } if (!mapInfoOps.empty()) { mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps}); + targetOp.setPrivateMapsAttr( + mlir::DenseI64ArrayAttr::get(targetOp.getContext(), 
privVarMapIdx)); } }); if (!mapInfoOpsForTarget.empty()) { diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 index b0c76ff3845f8..f3f9bbe4a76a2 100644 --- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 +++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 @@ -171,12 +171,12 @@ end subroutine target_allocatable ! CHECK_SAME %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] : ! CHECK-SAME !fir.ref, !fir.ref>>, !fir.ref>>, !fir.ref>) ! CHECK-SAME: private( -! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]], +! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1], ! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]], ! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]], -! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]], +! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2], ! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]], -! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] : +! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] : ! CHECK-SAME: !fir.ref>>, !fir.ref, !fir.ref, !fir.box>, !fir.ref>, !fir.boxchar<1>) { ! CHECK-NOT: fir.alloca ! 
CHECK: hlfir.declare %[[ALLOC_ARG]] diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 156e6eb371b85..f6c7f19fffddf 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1225,21 +1225,46 @@ def TargetOp : OpenMP_Op<"target", traits = [ The optional `if_expr` parameter specifies a boolean result of a conditional check. If this value is 1 or is not provided then the target region runs on a device, if it is 0 then the target region is executed on the host device. + + The `private_maps` attribute connects `private` operands to their corresponding + `map` operands. For `private` operands that require a map, the value of the + corresponding element in the attribute is the index of the `map` operand + (relative to other `map` operands, not the whole operands of the operation). For + `private` operands that do not require a map, this value is -1 (which is omitted + from the assembly format printing). 
}] # clausesDescription; + let arguments = !con(clausesArgs, + (ins OptionalAttr:$private_maps)); + let builders = [ OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)> ]; let extraClassDeclaration = [{ unsigned numMapBlockArgs() { return getMapVars().size(); } + + mlir::Value getMappedValueForPrivateVar(unsigned privVarIdx) { + std::optional privateMapIdices = getPrivateMapsAttr(); + + if (!privateMapIdices.has_value()) + return {}; + + int64_t mapInfoOpIdx = (*privateMapIdices)[privVarIdx]; + + if (mapInfoOpIdx == -1) + return {}; + + return getMapVars()[mapInfoOpIdx]; + } }] # clausesExtraClassDeclaration; let assemblyFormat = clausesAssemblyFormat # [{ custom( $region, $in_reduction_vars, type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars), - $private_vars, type($private_vars), $private_syms) attr-dict + $private_vars, type($private_vars), $private_syms, $private_maps) + attr-dict }]; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 94e71e089d4b1..8c5f79a49a334 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -487,9 +487,11 @@ struct PrivateParseArgs { llvm::SmallVectorImpl &vars; llvm::SmallVectorImpl &types; ArrayAttr &syms; + DenseI64ArrayAttr *mapIndices; PrivateParseArgs(SmallVectorImpl &vars, - SmallVectorImpl &types, ArrayAttr &syms) - : vars(vars), types(types), syms(syms) {} + SmallVectorImpl &types, ArrayAttr &syms, + DenseI64ArrayAttr *mapIndices = nullptr) + : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {} }; struct ReductionParseArgs { SmallVectorImpl &vars; @@ -517,8 +519,10 @@ static ParseResult parseClauseWithRegionArgs( SmallVectorImpl &operands, SmallVectorImpl &types, SmallVectorImpl ®ionPrivateArgs, - ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) { + ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = 
nullptr, + DenseBoolArrayAttr *byref = nullptr) { SmallVector symbolVec; + SmallVector mapIndicesVec; SmallVector isByRefVec; unsigned regionArgOffset = regionPrivateArgs.size(); @@ -538,6 +542,16 @@ static ParseResult parseClauseWithRegionArgs( parser.parseArgument(regionPrivateArgs.emplace_back())) return failure(); + if (mapIndices) { + if (parser.parseOptionalLSquare().succeeded()) { + if (parser.parseKeyword("map_idx") || parser.parseEqual() || + parser.parseInteger(mapIndicesVec.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else + mapIndicesVec.push_back(-1); + } + return success(); })) return failure(); @@ -571,6 +585,10 @@ static ParseResult parseClauseWithRegionArgs( *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs); } + if (!mapIndicesVec.empty()) + *mapIndices = + mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec); + if (byref) *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); @@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause( static ParseResult parseBlockArgClause( OpAsmParser &parser, llvm::SmallVectorImpl &entryBlockArgs, - StringRef keyword, std::optional reductionArgs) { + StringRef keyword, std::optional privateArgs) { if (succeeded(parser.parseOptionalKeyword(keyword))) { - if (!reductionArgs) + if (!privateArgs) return failure(); - if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars, - reductionArgs->types, entryBlockArgs, - &reductionArgs->syms))) + if (failed(parseClauseWithRegionArgs( + parser, privateArgs->vars, privateArgs->types, entryBlockArgs, + &privateArgs->syms, privateArgs->mapIndices))) return failure(); } return success(); @@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause( if (failed(parseClauseWithRegionArgs( parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs, - &reductionArgs->syms, &reductionArgs->byref))) + &reductionArgs->syms, /*mapIndices=*/nullptr, + &reductionArgs->byref))) return failure(); } return success(); @@ 
-674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion( SmallVectorImpl &mapVars, SmallVectorImpl &mapTypes, llvm::SmallVectorImpl &privateVars, - llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms) { + llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms, + DenseI64ArrayAttr &privateMaps) { AllRegionParseArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + &privateMaps); return parseBlockArgRegion(parser, region, args); } @@ -776,8 +797,10 @@ struct PrivatePrintArgs { ValueRange vars; TypeRange types; ArrayAttr syms; - PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms) - : vars(vars), types(types), syms(syms) {} + DenseI64ArrayAttr mapIndices; + PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms, + DenseI64ArrayAttr mapIndices) + : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {} }; struct ReductionPrintArgs { ValueRange vars; @@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols = nullptr, + DenseI64ArrayAttr mapIndices = nullptr, DenseBoolArrayAttr byref = nullptr) { if (argsSubrange.empty()) return; @@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, symbols = ArrayAttr::get(ctx, values); } + if (!mapIndices) { + llvm::SmallVector values(operands.size(), -1); + mapIndices = DenseI64ArrayAttr::get(ctx, values); + } + if (!byref) { mlir::SmallVector values(operands.size(), false); byref = DenseBoolArrayAttr::get(ctx, values); } - llvm::interleaveComma( - llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p, - [&p](auto t) { - auto [op, arg, sym, isByRef] = t; - if (isByRef) - 
p << "byref "; - if (sym) - p << sym << " "; - p << op << " -> " << arg; - }); + llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols, + mapIndices.asArrayRef(), + byref.asArrayRef()), + p, [&p](auto t) { + auto [op, arg, sym, map, isByRef] = t; + if (isByRef) + p << "byref "; + if (sym) + p << sym << " "; + + p << op << " -> " << arg; + + if (map != -1) + p << " [map_idx=" << map << "]"; + }); p << " : "; llvm::interleaveComma(types, p); p << ") "; @@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, if (privateArgs) printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types, - privateArgs->syms); + privateArgs->syms, privateArgs->mapIndices); } static void @@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, if (reductionArgs) printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, reductionArgs->vars, reductionArgs->types, - reductionArgs->syms, reductionArgs->byref); + reductionArgs->syms, /*mapIndices=*/nullptr, + reductionArgs->byref); } static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, @@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion( OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, - ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) { + ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, + DenseI64ArrayAttr privateMaps) { AllRegionPrintArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps); printBlockArgRegion(p, op, region, args); } @@ -908,7 
+944,8 @@ static void printInReductionPrivateRegion( AllRegionPrintArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + /*mapIndices=*/nullptr); printBlockArgRegion(p, op, region, args); } @@ -921,7 +958,8 @@ static void printInReductionPrivateReductionRegion( AllRegionPrintArgs args; args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + /*mapIndices=*/nullptr); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, reductionSyms); printBlockArgRegion(p, op, region, args); @@ -931,7 +969,8 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) { AllRegionPrintArgs args; - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + /*mapIndices=*/nullptr); printBlockArgRegion(p, op, region, args); } @@ -941,7 +980,8 @@ static void printPrivateReductionRegion( TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) { AllRegionPrintArgs args; - args.privateArgs.emplace(privateVars, privateTypes, privateSyms); + args.privateArgs.emplace(privateVars, privateTypes, privateSyms, + /*mapIndices=*/nullptr); args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, reductionSyms); printBlockArgRegion(p, op, region, args); @@ -1560,6 +1600,24 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { return success(); } +static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) { + std::optional privateMapIndices = + targetOp.getPrivateMapsAttr(); 
+ + // None of the private operands are mapped. + if (!privateMapIndices.has_value() || !privateMapIndices.value()) + return success(); + + OperandRange privateVars = targetOp.getPrivateVars(); + + if (privateMapIndices.value().size() != + static_cast(privateVars.size())) + return emitError(targetOp.getLoc(), "sizes of `private` operand range and " + "`private_maps` attribute mismatch"); + + return success(); +} + //===----------------------------------------------------------------------===// // TargetDataOp //===----------------------------------------------------------------------===// @@ -1656,14 +1714,23 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit); + makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit, + /*private_maps=*/nullptr); } LogicalResult TargetOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDependKinds(), getDependVars()); - return failed(verifyDependVars) ? 
verifyDependVars - : verifyMapClause(*this, getMapVars()); + + if (failed(verifyDependVars)) + return verifyDependVars; + + LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars()); + + if (failed(verifyMapVars)) + return verifyMapVars; + + return verifyPrivateVarsMapping(*this); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index c25a6ef4b4849..94c63dd8e9aa0 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref, %map2: memref, %priv_ return } +// CHECK-LABEL: omp_target_private_with_map_idx +func.func @omp_target_private_with_map_idx(%map1: memref, %map2: memref, %priv_var: !llvm.ptr) -> () { + %mapv1 = omp.map.info var_ptr(%map1 : memref, tensor) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} + %mapv2 = omp.map.info var_ptr(%map2 : memref, tensor) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + + // CHECK: omp.target + + // CHECK-SAME: map_entries( + // CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]], + // CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]] + // CHECK-SAME: : memref, memref + // CHECK-SAME: ) + + // CHECK-SAME: private( + // CHECK-SAME: @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]] [map_idx=1] + // CHECK-SAME: : !llvm.ptr + // CHECK-SAME: ) + omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref, memref) private(@x.privatizer %priv_var -> %priv_arg [map_idx=1] : !llvm.ptr) { + omp.terminator + } + + return +} + // CHECK-LABEL: omp_loop func.func @omp_loop(%lb : index, %ub : index, %step : index) { // CHECK: omp.loop {