Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions flang/include/flang/Optimizer/OpenMP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ def MapInfoFinalizationPass
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}

def MapsForPrivatizedSymbolsPass
: Pass<"omp-maps-for-privatized-symbols", "mlir::func::FuncOp"> {
let summary = "Creates MapInfoOp instances for privatized symbols when needed";
let description = [{
Adds omp.map.info operations for privatized symbols on omp.target ops
In certain situations, such as when an allocatable is privatized, its
descriptor is needed in the alloc region of the privatizer. This results
in the use of the descriptor inside the target region. As such, the
descriptor then needs to be mapped. This pass adds such MapInfoOp operations.
}];
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}

def MarkDeclareTargetPass
: Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
let summary = "Marks all functions called by an OpenMP declare target function as declare target";
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/OpenMP/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

add_flang_library(FlangOpenMPTransforms
FunctionFiltering.cpp
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp

Expand Down
131 changes: 131 additions & 0 deletions flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//===- MapsForPrivatizedSymbols.cpp
//-----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// An OpenMP dialect related pass for FIR/HLFIR which creates MapInfoOp
/// instances for certain privatized symbols.
/// For example, if an allocatable variable is used in a private clause attached
/// to a omp.target op, then the allocatable variable's descriptor will be
/// needed on the device (e.g. GPU). This descriptor needs to be separately
/// mapped onto the device. This pass creates the necessary omp.map.info ops for
/// this.
//===----------------------------------------------------------------------===//
// TODO:
// 1. Before adding omp.map.info, check if in case we already have an
// omp.map.info for the variable in question.
// 2. Generalize this for more than just omp.target ops.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/Debug.h"
#include <type_traits>

#define DEBUG_TYPE "omp-maps-for-privatized-symbols"

namespace flangomp {
#define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp
using namespace mlir;
namespace {
class MapsForPrivatizedSymbolsPass
: public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
MapsForPrivatizedSymbolsPass> {

bool privatizerNeedsMap(omp::PrivateClauseOp &privatizer) {
Region &allocRegion = privatizer.getAllocRegion();
Value blockArg0 = allocRegion.getArgument(0);
if (blockArg0.use_empty())
return false;
return true;
}
omp::MapInfoOp createMapInfo(Location loc, Value var, OpBuilder &builder) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not only a comment for you but for me as well, it would be nicer to reuse the createMapInfoOp from Utils.cpp. I think you cannot do that at the moment because of linking issues, right? Did you give it a try?

Copy link
Contributor Author

@bhandarkar-pranav bhandarkar-pranav Oct 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I didn't try because my createMapInfo is so specific (Doing llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO only and in all likelihood going to deal with descriptors only). Having said that code reuse is good, so let me try what you are suggesting. Of course, there may even be some linking issues as you point out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kareem, at this point, I didn't choose to do it (See my latest update). This is because, as you'll see I had to write special handlingfor BaseBoxType and BoxCharType. Let me know what you think.

uint64_t mapTypeTo = static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
Operation *definingOp = var.getDefiningOp();
auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp);
assert(declOp &&
"Expected defining Op of privatized var to be hlfir.declare");
Value varPtr = declOp.getOriginalBase();

return builder.create<omp::MapInfoOp>(
loc, varPtr.getType(), varPtr,
TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
.getElementType()),
/*varPtrPtr=*/Value{},
/*members=*/SmallVector<Value>{},
/*member_index=*/DenseIntElementsAttr{},
/*bounds=*/ValueRange{},
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
mapTypeTo),
builder.getAttr<omp::VariableCaptureKindAttr>(
omp::VariableCaptureKind::ByRef),
StringAttr(), builder.getBoolAttr(false));
}
void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
Location loc = targetOp.getLoc();
targetOp.getMapVarsMutable().append(ValueRange{mapInfoOp});
size_t numMapVars = targetOp.getMapVars().size();
targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
loc);
}
void addMapInfoOps(omp::TargetOp targetOp,
llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) {
for (auto mapInfoOp : mapInfoOps)
addMapInfoOp(targetOp, mapInfoOp);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
OpBuilder builder(context);
llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4>>
mapInfoOpsForTarget;
getOperation()->walk([&](omp::TargetOp targetOp) {
if (targetOp.getPrivateVars().empty())
return;
OperandRange privVars = targetOp.getPrivateVars();
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {

SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
omp::PrivateClauseOp privatizer =
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
targetOp, privatizerName);
if (!privatizerNeedsMap(privatizer)) {
continue;
}
builder.setInsertionPoint(targetOp);
Location loc = targetOp.getLoc();
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
mapInfoOps.push_back(mapInfoOp);
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
LLVM_DEBUG(mapInfoOp.dump());
}
if (!mapInfoOps.empty()) {
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
}
});
if (!mapInfoOpsForTarget.empty()) {
for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) {
addMapInfoOps(static_cast<omp::TargetOp>(targetOp), mapInfoOps);
}
}
}
};
} // namespace
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Passes/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
/// rather than the host device.
void createOpenMPFIRPassPipeline(mlir::PassManager &pm, bool isTargetDevice) {
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
if (isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ end subroutine target_allocatable
! CHECK-SAME: @[[VAR_PRIVATIZER_SYM:.*]] :
! CHECK-SAME: [[TYPE:!fir.ref<!fir.box<!fir.heap<i32>>>]] alloc {
! CHECK: ^bb0(%[[PRIV_ARG:.*]]: [[TYPE]]):
! CHECK: %[[PRIV_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "alloc_var", {{.*}}}
! CHECK: %[[PRIV_ALLOC:.*]] = fir.alloca [[DESC_TYPE:!fir.box<!fir.heap<i32>>]] {bindc_name = "alloc_var", {{.*}}}

! CHECK-NEXT: %[[PRIV_ARG_VAL:.*]] = fir.load %[[PRIV_ARG]] : !fir.ref<!fir.box<!fir.heap<i32>>>
! CHECK-NEXT: %[[PRIV_ARG_BOX:.*]] = fir.box_addr %[[PRIV_ARG_VAL]] : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
! CHECK-NEXT: %[[PRIV_ARG_VAL:.*]] = fir.load %[[PRIV_ARG]] : [[TYPE]]
! CHECK-NEXT: %[[PRIV_ARG_BOX:.*]] = fir.box_addr %[[PRIV_ARG_VAL]] : ([[DESC_TYPE]]) -> !fir.heap<i32>
! CHECK-NEXT: %[[PRIV_ARG_ADDR:.*]] = fir.convert %[[PRIV_ARG_BOX]] : (!fir.heap<i32>) -> i64
! CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : i64
! CHECK-NEXT: %[[ALLOC_COND:.*]] = arith.cmpi ne, %[[PRIV_ARG_ADDR]], %[[C0]] : i64

! CHECK-NEXT: fir.if %[[ALLOC_COND]] {
! CHECK: %[[PRIV_ALLOCMEM:.*]] = fir.allocmem i32 {fir.must_be_heap = true, {{.*}}}
! CHECK-NEXT: %[[PRIV_ALLOCMEM_BOX:.*]] = fir.embox %[[PRIV_ALLOCMEM]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
! CHECK-NEXT: fir.store %[[PRIV_ALLOCMEM_BOX]] to %[[PRIV_ALLOC]] : !fir.ref<!fir.box<!fir.heap<i32>>>
! CHECK-NEXT: %[[PRIV_ALLOCMEM_BOX:.*]] = fir.embox %[[PRIV_ALLOCMEM]] : (!fir.heap<i32>) -> [[DESC_TYPE]]
! CHECK-NEXT: fir.store %[[PRIV_ALLOCMEM_BOX]] to %[[PRIV_ALLOC]] : [[TYPE]]
! CHECK-NEXT: } else {
! CHECK-NEXT: %[[ZERO_BITS:.*]] = fir.zero_bits !fir.heap<i32>
! CHECK-NEXT: %[[ZERO_BOX:.*]] = fir.embox %[[ZERO_BITS]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
! CHECK-NEXT: fir.store %[[ZERO_BOX]] to %[[PRIV_ALLOC]] : !fir.ref<!fir.box<!fir.heap<i32>>>
! CHECK-NEXT: %[[ZERO_BOX:.*]] = fir.embox %[[ZERO_BITS]] : (!fir.heap<i32>) -> [[DESC_TYPE]]
! CHECK-NEXT: fir.store %[[ZERO_BOX]] to %[[PRIV_ALLOC]] : [[TYPE]]
! CHECK-NEXT: }

! CHECK-NEXT: %[[PRIV_DECL:.*]]:2 = hlfir.declare %[[PRIV_ALLOC]]
Expand Down Expand Up @@ -63,9 +63,11 @@ end subroutine target_allocatable

! CHECK-LABEL: func.func @_QPtarget_allocatable() {

! CHECK: %[[VAR_ALLOC:.*]] = fir.alloca !fir.box<!fir.heap<i32>>
! CHECK: %[[VAR_ALLOC:.*]] = fir.alloca [[DESC_TYPE]]
! CHECK-SAME: {bindc_name = "alloc_var", {{.*}}}
! CHECK: %[[VAR_DECL:.*]]:2 = hlfir.declare %[[VAR_ALLOC]]

! CHECK: omp.target private(
! CHECK: %[[MAP_VAR:.*]] = omp.map.info var_ptr(%[[VAR_DECL]]#1 : [[TYPE]], [[DESC_TYPE]])
! CHECK-SAME: map_clauses(to) capture(ByRef) -> [[TYPE]]
! CHECK: omp.target map_entries(%[[MAP_VAR]] -> %arg0 : [[TYPE]]) private(
! CHECK-SAME: @[[VAR_PRIVATIZER_SYM]] %[[VAR_DECL]]#0 -> %{{.*}} : [[TYPE]]) {
83 changes: 83 additions & 0 deletions flang/test/Transforms/omp-maps-for-privatized-symbols.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// RUN: fir-opt --split-input-file --omp-maps-for-privatized-symbols %s | FileCheck %s
module attributes {omp.is_target_device = false} {
omp.private {type = private} @_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 : !fir.ref<!fir.box<!fir.heap<i32>>> alloc {
^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
%0 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", pinned, uniq_name = "_QFtarget_simpleEsimple_var"}
%1 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
%2 = fir.box_addr %1 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
%3 = fir.convert %2 : (!fir.heap<i32>) -> i64
%c0_i64 = arith.constant 0 : i64
%4 = arith.cmpi ne, %3, %c0_i64 : i64
fir.if %4 {
%6 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
%7 = fir.box_addr %6 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
%8 = fir.allocmem i32 {fir.must_be_heap = true, uniq_name = "_QFtarget_simpleEsimple_var.alloc"}
%9 = fir.embox %8 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
fir.store %9 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
} else {
%6 = fir.zero_bits !fir.heap<i32>
%7 = fir.embox %6 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
fir.store %7 to %0 : !fir.ref<!fir.box<!fir.heap<i32>>>
}
%5:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
omp.yield(%5#0 : !fir.ref<!fir.box<!fir.heap<i32>>>)
} dealloc {
^bb0(%arg0: !fir.ref<!fir.box<!fir.heap<i32>>>):
%0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
%1 = fir.box_addr %0 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
%2 = fir.convert %1 : (!fir.heap<i32>) -> i64
%c0_i64 = arith.constant 0 : i64
%3 = arith.cmpi ne, %2, %c0_i64 : i64
fir.if %3 {
%false = arith.constant false
%4 = fir.absent !fir.box<none>
%c70 = arith.constant 70 : index
%c10_i32 = arith.constant 10 : i32
%6 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
%7 = fir.box_addr %6 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
fir.freemem %7 : !fir.heap<i32>
%8 = fir.zero_bits !fir.heap<i32>
%9 = fir.embox %8 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
fir.store %9 to %arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>
}
omp.yield
}
func.func @_QPtarget_simple() {
%0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
%1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
%3 = fir.zero_bits !fir.heap<i32>
%4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
%5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
%c2_i32 = arith.constant 2 : i32
hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
%6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>) {
%11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
%c10_i32 = arith.constant 10 : i32
%13 = fir.load %11#0 : !fir.ref<i32>
%14 = arith.addi %c10_i32, %13 : i32
hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
omp.terminator
}
%7 = fir.load %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
%8 = fir.box_addr %7 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
%9 = fir.convert %8 : (!fir.heap<i32>) -> i64
%c0_i64 = arith.constant 0 : i64
%10 = arith.cmpi ne, %9, %c0_i64 : i64
fir.if %10 {
%11 = fir.load %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
%12 = fir.box_addr %11 : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
fir.freemem %12 : !fir.heap<i32>
%13 = fir.zero_bits !fir.heap<i32>
%14 = fir.embox %13 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
fir.store %14 to %5#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
}
return
}
}
// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<i32>>>
// CHECK: omp.target map_entries(%[[MAP0]] -> %arg0, %[[MAP1]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>)
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
objects (e.g. derived types or classes), indicates the bounds to be copied
of the variable. When it's an array slice it is in rank order where rank 0
is the inner-most dimension.
- 'map_clauses': OpenMP map type for this map capture, for example: from, to and
- 'map_type': OpenMP map type for this map capture, for example: from, to and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the typo catch! :-)

always. It's a bitfield composed of the OpenMP runtime flags stored in
OpenMPOffloadMappingFlags.
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
Expand Down
Loading