Skip to content

Commit 0a96b24

Browse files
[mlir][acc][flang] Introduce OpenACC interfaces for globals (#168614)
Introduce two new OpenACC operation interfaces for identifying global variables and their address computations: - `GlobalVariableOpInterface`: Identifies operations that define global variables. Provides an `isConstant()` method to query whether the global is constant. - `AddressOfGlobalOpInterface`: Identifies operations that compute the address of a global variable. Provides a `getSymbol()` method to retrieve the symbol reference. This is being done in preparation for `ACCImplicitDeclare` pass which will automatically ensure that `acc declare` is applied to globals when needed. The following operations now implement these interfaces: - `memref::GlobalOp` implements `GlobalVariableOpInterface` - `memref::GetGlobalOp` implements `AddressOfGlobalOpInterface` - `fir::GlobalOp` implements `GlobalVariableOpInterface` - `fir::AddrOfOp` implements `AddressOfGlobalOpInterface`
1 parent 1262acf commit 0a96b24

File tree

7 files changed

+176
-0
lines changed

7 files changed

+176
-0
lines changed

flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
#include "mlir/Dialect/OpenACC/OpenACC.h"
1717

1818
namespace fir {
19+
class AddrOfOp;
1920
class DeclareOp;
21+
class GlobalOp;
2022
} // namespace fir
2123

2224
namespace hlfir {
@@ -53,6 +55,18 @@ struct PartialEntityAccessModel<hlfir::DeclareOp>
5355
bool isCompleteView(mlir::Operation *op) const;
5456
};
5557

58+
struct AddressOfGlobalModel
59+
: public mlir::acc::AddressOfGlobalOpInterface::ExternalModel<
60+
AddressOfGlobalModel, fir::AddrOfOp> {
61+
mlir::SymbolRefAttr getSymbol(mlir::Operation *op) const;
62+
};
63+
64+
struct GlobalVariableModel
65+
: public mlir::acc::GlobalVariableOpInterface::ExternalModel<
66+
GlobalVariableModel, fir::GlobalOp> {
67+
bool isConstant(mlir::Operation *op) const;
68+
};
69+
5670
} // namespace fir::acc
5771

5872
#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_

flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,13 @@ bool PartialEntityAccessModel<hlfir::DeclareOp>::isCompleteView(
5959
return !getBaseEntity(op);
6060
}
6161

62+
mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const {
63+
return mlir::cast<fir::AddrOfOp>(op).getSymbolAttr();
64+
}
65+
66+
bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
67+
auto globalOp = mlir::cast<fir::GlobalOp>(op);
68+
return globalOp.getConstant().has_value();
69+
}
70+
6271
} // namespace fir::acc

flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
4949
PartialEntityAccessModel<fir::CoordinateOp>>(*ctx);
5050
fir::DeclareOp::attachInterface<PartialEntityAccessModel<fir::DeclareOp>>(
5151
*ctx);
52+
53+
fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
54+
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
5255
});
5356

5457
// Register HLFIR operation interfaces

mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,35 @@ def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface
4444
];
4545
}
4646

47+
def AddressOfGlobalOpInterface : OpInterface<"AddressOfGlobalOpInterface"> {
48+
let cppNamespace = "::mlir::acc";
49+
50+
let description = [{
51+
An interface for operations that compute the address of a global variable
52+
or symbol.
53+
}];
54+
55+
let methods = [
56+
InterfaceMethod<"Get the symbol reference to the global", "::mlir::SymbolRefAttr",
57+
"getSymbol", (ins)>,
58+
];
59+
}
60+
61+
def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
62+
let cppNamespace = "::mlir::acc";
63+
64+
let description = [{
65+
An interface for operations that define global variables. This interface
66+
provides a uniform way to query properties of global variables across
67+
different dialects.
68+
}];
69+
70+
let methods = [
71+
InterfaceMethod<"Check if the global variable is constant", "bool",
72+
"isConstant", (ins), [{
73+
return false;
74+
}]>,
75+
];
76+
}
77+
4778
#endif // OPENACC_OPS_INTERFACES

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,24 @@ struct LLVMPointerPointerLikeModel
211211
Type getElementType(Type pointer) const { return Type(); }
212212
};
213213

214+
struct MemrefAddressOfGlobalModel
215+
: public AddressOfGlobalOpInterface::ExternalModel<
216+
MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
217+
SymbolRefAttr getSymbol(Operation *op) const {
218+
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
219+
return getGlobalOp.getNameAttr();
220+
}
221+
};
222+
223+
struct MemrefGlobalVariableModel
224+
: public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
225+
memref::GlobalOp> {
226+
bool isConstant(Operation *op) const {
227+
auto globalOp = cast<memref::GlobalOp>(op);
228+
return globalOp.getConstant();
229+
}
230+
};
231+
214232
/// Helper function for any of the times we need to modify an ArrayAttr based on
215233
/// a device type list. Returns a new ArrayAttr with all of the
216234
/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
@@ -302,6 +320,11 @@ void OpenACCDialect::initialize() {
302320
MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
303321
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
304322
*getContext());
323+
324+
// Attach operation interfaces
325+
memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
326+
*getContext());
327+
memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
305328
}
306329

307330
//===----------------------------------------------------------------------===//

mlir/unittests/Dialect/OpenACC/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIROpenACCTests
22
OpenACCOpsTest.cpp
3+
OpenACCOpsInterfacesTest.cpp
34
OpenACCUtilsTest.cpp
45
)
56
mlir_target_link_libraries(MLIROpenACCTests
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===- OpenACCOpsInterfacesTest.cpp - Unit tests for OpenACC interfaces --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
10+
#include "mlir/Dialect/OpenACC/OpenACC.h"
11+
#include "mlir/IR/Builders.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/MLIRContext.h"
14+
#include "mlir/IR/OwningOpRef.h"
15+
#include "gtest/gtest.h"
16+
17+
using namespace mlir;
18+
using namespace mlir::acc;
19+
20+
//===----------------------------------------------------------------------===//
21+
// Test Fixture
22+
//===----------------------------------------------------------------------===//
23+
24+
class OpenACCOpsInterfacesTest : public ::testing::Test {
25+
protected:
26+
OpenACCOpsInterfacesTest()
27+
: context(), builder(&context), loc(UnknownLoc::get(&context)) {
28+
context.loadDialect<acc::OpenACCDialect, memref::MemRefDialect>();
29+
}
30+
31+
MLIRContext context;
32+
OpBuilder builder;
33+
Location loc;
34+
};
35+
36+
//===----------------------------------------------------------------------===//
37+
// GlobalVariableOpInterface Tests
38+
//===----------------------------------------------------------------------===//
39+
40+
TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceNonConstant) {
41+
// Test that a non-constant global returns false for isConstant()
42+
43+
auto memrefType = MemRefType::get({10}, builder.getF32Type());
44+
OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create(
45+
builder, loc,
46+
/*sym_name=*/builder.getStringAttr("mutable_global"),
47+
/*sym_visibility=*/builder.getStringAttr("private"),
48+
/*type=*/TypeAttr::get(memrefType),
49+
/*initial_value=*/Attribute(),
50+
/*constant=*/UnitAttr(),
51+
/*alignment=*/IntegerAttr());
52+
53+
auto globalVarIface =
54+
dyn_cast<GlobalVariableOpInterface>(globalOp->getOperation());
55+
ASSERT_TRUE(globalVarIface != nullptr);
56+
EXPECT_FALSE(globalVarIface.isConstant());
57+
}
58+
59+
TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceConstant) {
60+
// Test that a constant global returns true for isConstant()
61+
62+
auto memrefType = MemRefType::get({5}, builder.getI32Type());
63+
OwningOpRef<memref::GlobalOp> constantGlobalOp = memref::GlobalOp::create(
64+
builder, loc,
65+
/*sym_name=*/builder.getStringAttr("constant_global"),
66+
/*sym_visibility=*/builder.getStringAttr("public"),
67+
/*type=*/TypeAttr::get(memrefType),
68+
/*initial_value=*/Attribute(),
69+
/*constant=*/builder.getUnitAttr(),
70+
/*alignment=*/IntegerAttr());
71+
72+
auto globalVarIface =
73+
dyn_cast<GlobalVariableOpInterface>(constantGlobalOp->getOperation());
74+
ASSERT_TRUE(globalVarIface != nullptr);
75+
EXPECT_TRUE(globalVarIface.isConstant());
76+
}
77+
78+
//===----------------------------------------------------------------------===//
79+
// AddressOfGlobalOpInterface Tests
80+
//===----------------------------------------------------------------------===//
81+
82+
TEST_F(OpenACCOpsInterfacesTest, AddressOfGlobalOpInterfaceGetSymbol) {
83+
// Test that getSymbol() returns the correct symbol reference
84+
85+
auto memrefType = MemRefType::get({5}, builder.getI32Type());
86+
const auto *symbolName = "test_global_symbol";
87+
88+
OwningOpRef<memref::GetGlobalOp> getGlobalOp = memref::GetGlobalOp::create(
89+
builder, loc, memrefType, FlatSymbolRefAttr::get(&context, symbolName));
90+
91+
auto addrOfGlobalIface =
92+
dyn_cast<AddressOfGlobalOpInterface>(getGlobalOp->getOperation());
93+
ASSERT_TRUE(addrOfGlobalIface != nullptr);
94+
EXPECT_EQ(addrOfGlobalIface.getSymbol().getLeafReference(), symbolName);
95+
}

0 commit comments

Comments
 (0)