Skip to content

Commit 6fcb1f7

Browse files
clementvalLukacma
authored andcommitted
[flang][cuda] Update c_loc with device variable to get host address (llvm#164317)
Bypass the declare op because it is rewritten in CUFOpConversion and will only provide the device address. c_loc is expected to have the host address of a device address to be used in API like `cudaMemcpyToSymbol` so we need to provide the address of op directly.
1 parent 8be36bf commit 6fcb1f7

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3516,11 +3516,23 @@ static mlir::Value getAddrFromBox(fir::FirOpBuilder &builder,
35163516
return addr;
35173517
}
35183518

3519+
static void clocDeviceArgRewrite(fir::ExtendedValue arg) {
3520+
// Special case for device address in c_loc.
3521+
if (auto emboxOp = mlir::dyn_cast_or_null<fir::EmboxOp>(
3522+
fir::getBase(arg).getDefiningOp()))
3523+
if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
3524+
emboxOp.getMemref().getDefiningOp()))
3525+
if (declareOp.getDataAttr() &&
3526+
declareOp.getDataAttr() == cuf::DataAttribute::Device)
3527+
emboxOp.getMemrefMutable().assign(declareOp.getMemref());
3528+
}
3529+
35193530
static fir::ExtendedValue
35203531
genCLocOrCFunLoc(fir::FirOpBuilder &builder, mlir::Location loc,
35213532
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args,
35223533
bool isFunc = false, bool isDevLoc = false) {
35233534
assert(args.size() == 1);
3535+
clocDeviceArgRewrite(args[0]);
35243536
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
35253537
mlir::Value resAddr;
35263538
if (isDevLoc)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
module symbols
4+
integer(4), device, target :: sdev(100)
5+
end module
6+
7+
subroutine sub1
8+
use iso_c_binding
9+
use symbols
10+
print*, c_loc(sdev)
11+
end subroutine
12+
13+
! CHECK-LABEL: func.func @_QPsub1()
14+
! CHECK: %[[ADDR:.*]] = fir.address_of(@_QMsymbolsEsdev) : !fir.ref<!fir.array<100xi32>>
15+
! CHECK: %[[EMBOX:.*]] = fir.embox %[[ADDR]](%{{.*}}) : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<100xi32>>
16+
! CHECK: %[[__ADDRESS:.*]] = fir.coordinate_of %{{.*}}, __address : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) -> !fir.ref<i64>
17+
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[EMBOX]] : (!fir.box<!fir.array<100xi32>>) -> !fir.ref<!fir.array<100xi32>>
18+
! CHECK: %[[CONV:.*]] = fir.convert %[[BOX_ADDR]] : (!fir.ref<!fir.array<100xi32>>) -> i64
19+
! CHECK: fir.store %[[CONV]] to %[[__ADDRESS]] : !fir.ref<i64>

0 commit comments

Comments
 (0)