Skip to content

Commit a262ac0

Browse files
authored
[flang][cuda] Make operations dynamically legal in cuf op conversion (llvm#102220)
1 parent 5972819 commit a262ac0

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -234,9 +234,20 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
234234
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
235235
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
236236
/*forceUnifiedTBAATree=*/false, *dl);
237-
238-
target.addIllegalOp<cuf::AllocOp, cuf::AllocateOp, cuf::DeallocateOp,
239-
cuf::FreeOp>();
237+
target.addDynamicallyLegalOp<cuf::AllocOp>([](::cuf::AllocOp op) {
238+
return !mlir::isa<fir::BaseBoxType>(op.getInType());
239+
});
240+
target.addDynamicallyLegalOp<cuf::FreeOp>([](::cuf::FreeOp op) {
241+
if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
242+
op.getDevptr().getType())) {
243+
return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy());
244+
}
245+
return true;
246+
});
247+
target.addDynamicallyLegalOp<cuf::AllocateOp>(
248+
[](::cuf::AllocateOp op) { return isBoxGlobal(op); });
249+
target.addDynamicallyLegalOp<cuf::DeallocateOp>(
250+
[](::cuf::DeallocateOp op) { return isBoxGlobal(op); });
240251
patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
241252
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
242253
CufFreeOpConversion>(ctx);

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@ func.func @_QPsub1() {
1414
return
1515
}
1616

17-
1817
// CHECK-LABEL: func.func @_QPsub1()
1918
// CHECK: %[[DESC_RT_CALL:.*]] = fir.call @_FortranACUFAllocDesciptor(%{{.*}}, %{{.*}}, %{{.*}}) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
2019
// CHECK: %[[DESC:.*]] = fir.convert %[[DESC_RT_CALL]] : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
@@ -27,4 +26,37 @@ func.func @_QPsub1() {
2726
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
2827
// CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
2928

29+
// Check operations that should not be transformed yet.
30+
func.func @_QPsub2() {
31+
%0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
32+
cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
33+
return
3034
}
35+
36+
// CHECK-LABEL: func.func @_QPsub2()
37+
// CHECK: cuf.alloc !fir.array<10xf32>
38+
// CHECK: cuf.free %{{.*}} : !fir.ref<!fir.array<10xf32>>
39+
40+
fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
41+
%0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
42+
%c0 = arith.constant 0 : index
43+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
44+
%2 = fir.embox %0(%1) : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
45+
fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xf32>>>
46+
}
47+
48+
func.func @_QPsub3() {
49+
%0 = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
50+
%1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
51+
%2 = cuf.allocate %1#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
52+
%3 = cuf.deallocate %1#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
53+
return
54+
}
55+
56+
// CHECK-LABEL: func.func @_QPsub3()
57+
// CHECK: cuf.allocate
58+
// CHECK: cuf.deallocate
59+
60+
}
61+
62+

0 commit comments

Comments
 (0)