Skip to content

Commit fccfee4

Browse files
author
Razvan Lupusoru
committed
[flang][acc] Ensure fir.class is handled in type categorization
fir.class is treated similarly as fir.box - but it has one key distinction which is that it doesn't hold an element type. Thus the categorization logic was mishandling this case for this reason (and also the fact that it assumed that a base object is always a fir.ref). This PR improves this handling and adds appropriate test exercising both a class and a class field to ensure categorization works.
1 parent 68239b7 commit fccfee4

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,13 @@ template <>
320320
mlir::acc::VariableTypeCategory
321321
OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
322322
mlir::Value var) const {
323+
// Class-type does not behave like a normal box because it does not hold an
324+
// element type. Thus special handle it here.
325+
if (mlir::isa<fir::ClassType>(type))
326+
return mlir::acc::VariableTypeCategory::composite;
323327

324328
mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
329+
assert(eleTy && "expect to be able to unwrap the element type");
325330

326331
// If the type enclosed by the box is a mappable type, then have it
327332
// provide the type category.
@@ -346,7 +351,7 @@ OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
346351
return mlir::acc::VariableTypeCategory::nonscalar;
347352
}
348353

349-
static mlir::TypedValue<mlir::acc::PointerLikeType>
354+
static mlir::Value
350355
getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
351356
// If there is no defining op - the unwrapped reference is the base one.
352357
mlir::Operation *op = varPtr.getDefiningOp();
@@ -372,7 +377,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
372377
})
373378
.Default([&](mlir::Operation *) { return varPtr; });
374379

375-
return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
380+
return baseRef;
376381
}
377382

378383
static mlir::acc::VariableTypeCategory
@@ -384,10 +389,17 @@ categorizePointee(mlir::Type pointer,
384389
// value would both be represented as !fir.ref<f32>. We do not want to treat
385390
// such a reference as a scalar. Thus unwrap interior pointer calculations.
386391
auto baseRef = getBaseRef(varPtr);
387-
mlir::Type eleTy = baseRef.getType().getElementType();
388392

389-
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
390-
return mappableTy.getTypeCategory(varPtr);
393+
if (auto mappableTy =
394+
mlir::dyn_cast<mlir::acc::MappableType>(baseRef.getType()))
395+
return mappableTy.getTypeCategory(baseRef);
396+
397+
// It must be a pointer-like type since it is not a MappableType.
398+
auto ptrLikeTy = mlir::cast<mlir::acc::PointerLikeType>(baseRef.getType());
399+
mlir::Type eleTy = ptrLikeTy.getElementType();
400+
401+
if (auto mappableEleTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
402+
return mappableEleTy.getTypeCategory(varPtr);
391403

392404
if (isScalarLike(eleTy))
393405
return mlir::acc::VariableTypeCategory::scalar;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module mm
2+
type, public :: polyty
3+
real :: field
4+
end type
5+
contains
6+
subroutine init(this)
7+
class(polyty), intent(inout) :: this
8+
!$acc enter data copyin(this, this%field)
9+
end subroutine
10+
end module
11+
12+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
13+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this", structured = false}
14+
! CHECK: Mappable: !fir.class<!fir.type<_QMmmTpolyty{field:f32}>>
15+
! CHECK: Type category: composite
16+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this%field", structured = false}
17+
! CHECK: Pointer-like: !fir.ref<f32>
18+
! CHECK: Type category: composite

flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Arith/IR/Arith.h"
910
#include "mlir/Dialect/OpenACC/OpenACC.h"
1011
#include "mlir/IR/Builders.h"
1112
#include "mlir/IR/BuiltinOps.h"
1213
#include "mlir/Pass/Pass.h"
1314
#include "mlir/Support/LLVM.h"
15+
#include "flang/Optimizer/Dialect/FIRDialect.h"
16+
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
1417
#include "flang/Optimizer/Support/DataLayout.h"
18+
#include "mlir/Dialect/DLTI/DLTI.h"
1519

1620
using namespace mlir;
1721

@@ -25,6 +29,11 @@ struct TestFIROpenACCInterfaces
2529
StringRef getDescription() const final {
2630
return "Test FIR implementation of the OpenACC interfaces.";
2731
}
32+
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
33+
registry.insert<fir::FIROpsDialect, hlfir::hlfirDialect,
34+
mlir::arith::ArithDialect, mlir::acc::OpenACCDialect,
35+
mlir::DLTIDialect>();
36+
}
2837
void runOnOperation() override {
2938
mlir::ModuleOp mod = getOperation();
3039
auto datalayout =

0 commit comments

Comments
 (0)