Skip to content

Commit 6edf90c

Browse files
committed
[CIR] Add PointerLikeType interface support for cir::PointerType
This adds code to attach the OpenACC PointerLikeType interface to cir::PointerType, along with a unit test for the interface.
1 parent 4f0be94 commit 6edf90c

File tree

11 files changed

+326
-1
lines changed

11 files changed

+326
-1
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains external dialect interfaces for CIR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef CLANG_CIR_DIALECT_OPENACC_CIROPENACCTYPEINTERFACES_H
14+
#define CLANG_CIR_DIALECT_OPENACC_CIROPENACCTYPEINTERFACES_H
15+
16+
#include "mlir/Dialect/OpenACC/OpenACC.h"
17+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
18+
19+
namespace cir::acc {
20+
21+
template <typename T>
22+
struct OpenACCPointerLikeModel
23+
: public mlir::acc::PointerLikeType::ExternalModel<
24+
OpenACCPointerLikeModel<T>, T> {
25+
mlir::Type getElementType(mlir::Type pointer) const {
26+
return mlir::cast<T>(pointer).getPointee();
27+
}
28+
mlir::acc::VariableTypeCategory
29+
getPointeeTypeCategory(mlir::Type pointer,
30+
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
31+
mlir::Type varType) const;
32+
};
33+
34+
} // namespace cir::acc
35+
36+
#endif // CLANG_CIR_DIALECT_OPENACC_CIROPENACCTYPEINTERFACES_H
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef CLANG_CIR_DIALECT_OPENACC_REGISTEROPENACCEXTENSIONS_H
10+
#define CLANG_CIR_DIALECT_OPENACC_REGISTEROPENACCEXTENSIONS_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
} // namespace mlir
15+
16+
namespace cir::acc {
17+
18+
void registerOpenACCExtensions(mlir::DialectRegistry &registry);
19+
20+
} // namespace cir::acc
21+
22+
#endif // CLANG_CIR_DIALECT_OPENACC_REGISTEROPENACCEXTENSIONS_H

clang/lib/CIR/CodeGen/CIRGenerator.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "clang/AST/DeclGroup.h"
1919
#include "clang/CIR/CIRGenerator.h"
2020
#include "clang/CIR/Dialect/IR/CIRDialect.h"
21+
#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
2122

2223
using namespace cir;
2324
using namespace clang;
@@ -38,6 +39,12 @@ void CIRGenerator::Initialize(ASTContext &astContext) {
3839
mlirContext = std::make_unique<mlir::MLIRContext>();
3940
mlirContext->loadDialect<cir::CIRDialect>();
4041
mlirContext->getOrLoadDialect<mlir::acc::OpenACCDialect>();
42+
43+
// Register extensions to integrate CIR types with OpenACC.
44+
mlir::DialectRegistry registry;
45+
cir::acc::registerOpenACCExtensions(registry);
46+
mlirContext->appendDialectRegistry(registry);
47+
4148
cgm = std::make_unique<clang::CIRGen::CIRGenModule>(
4249
*mlirContext.get(), astContext, codeGenOpts, diags);
4350
}

clang/lib/CIR/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_clang_library(clangCIR
3535
clangBasic
3636
clangLex
3737
${dialect_libs}
38+
CIROpenACCSupport
3839
MLIRCIR
3940
MLIRCIRInterfaces
4041
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
2+
add_subdirectory(OpenACC)
23
add_subdirectory(Transforms)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Implementation of external dialect interfaces for CIR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "clang/CIR/Dialect/OpenACC/CIROpenACCTypeInterfaces.h"
14+
15+
namespace cir::acc {
16+
17+
template <>
18+
mlir::acc::VariableTypeCategory
19+
OpenACCPointerLikeModel<cir::PointerType>::getPointeeTypeCategory(
20+
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
21+
mlir::Type varType) const {
22+
mlir::Type eleTy = mlir::cast<cir::PointerType>(pointer).getPointee();
23+
24+
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
25+
return mappableTy.getTypeCategory(varPtr);
26+
27+
if (isAnyIntegerOrFloatingPointType(eleTy) ||
28+
mlir::isa<cir::BoolType>(eleTy) || mlir::isa<cir::PointerType>(eleTy))
29+
return mlir::acc::VariableTypeCategory::scalar;
30+
if (mlir::isa<cir::ArrayType>(eleTy))
31+
return mlir::acc::VariableTypeCategory::array;
32+
if (mlir::isa<cir::RecordType>(eleTy))
33+
return mlir::acc::VariableTypeCategory::composite;
34+
if (mlir::isa<cir::FuncType>(eleTy) || mlir::isa<cir::VectorType>(eleTy))
35+
return mlir::acc::VariableTypeCategory::nonscalar;
36+
37+
// Without further checking, this type cannot be categorized.
38+
return mlir::acc::VariableTypeCategory::uncategorized;
39+
}
40+
41+
} // namespace cir::acc
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
add_clang_library(CIROpenACCSupport
2+
CIROpenACCTypeInterfaces.cpp
3+
RegisterOpenACCExtensions.cpp
4+
5+
DEPENDS
6+
MLIRCIRTypeConstraintsIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRIR
10+
)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Registration for OpenACC extensions as applied to CIR dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
14+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
15+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
16+
#include "clang/CIR/Dialect/OpenACC/CIROpenACCTypeInterfaces.h"
17+
18+
namespace cir::acc {
19+
20+
void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
21+
registry.addExtension(+[](mlir::MLIRContext *ctx, cir::CIRDialect *dialect) {
22+
cir::PointerType::attachInterface<
23+
OpenACCPointerLikeModel<cir::PointerType>>(*ctx);
24+
});
25+
}
26+
27+
} // namespace cir::acc

clang/unittests/CIR/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
add_clang_unittest(CIRUnitTests
2+
PointerLikeTest.cpp
3+
LLVM_COMPONENTS
4+
Core
5+
6+
LINK_LIBS
7+
MLIRCIR
8+
CIROpenACCSupport
9+
MLIRIR
10+
MLIROpenACCDialect
11+
)
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Unit tests for CIR implementation of OpenACC's PointertLikeType interface
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/OpenACC/OpenACC.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "mlir/IR/Diagnostics.h"
16+
#include "mlir/IR/MLIRContext.h"
17+
#include "mlir/IR/Value.h"
18+
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
19+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
20+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
21+
#include "clang/CIR/Dialect/OpenACC/CIROpenACCTypeInterfaces.h"
22+
#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
23+
#include "gtest/gtest.h"
24+
25+
using namespace mlir;
26+
using namespace cir;
27+
28+
//===----------------------------------------------------------------------===//
29+
// Test Fixture
30+
//===----------------------------------------------------------------------===//
31+
32+
class CIROpenACCPointerLikeTest : public ::testing::Test {
33+
protected:
34+
CIROpenACCPointerLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
35+
context.loadDialect<cir::CIRDialect>();
36+
context.loadDialect<mlir::acc::OpenACCDialect>();
37+
38+
// Register extension to integrate CIR types with OpenACC.
39+
mlir::DialectRegistry registry;
40+
cir::acc::registerOpenACCExtensions(registry);
41+
context.appendDialectRegistry(registry);
42+
}
43+
44+
MLIRContext context;
45+
OpBuilder b;
46+
Location loc;
47+
48+
mlir::IntegerAttr getSizeFromCharUnits(mlir::MLIRContext *ctx,
49+
clang::CharUnits size) {
50+
// Note that mlir::IntegerType is used instead of cir::IntType here
51+
// because we don't need sign information for this to be useful, so keep
52+
// it simple.
53+
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
54+
size.getQuantity());
55+
}
56+
57+
// General handler for types without a specific test
58+
void testElementType(mlir::Type ty) {
59+
mlir::Type ptrTy = cir::PointerType::get(ty);
60+
61+
// cir::PointerType should be castable to acc::PointerLikeType
62+
auto pltTy = dyn_cast_if_present<mlir::acc::PointerLikeType>(ptrTy);
63+
ASSERT_NE(pltTy, nullptr);
64+
65+
EXPECT_EQ(pltTy.getElementType(), ty);
66+
67+
OwningOpRef<cir::AllocaOp> varPtrOp = b.create<cir::AllocaOp>(
68+
loc, ptrTy, ty, "",
69+
getSizeFromCharUnits(&context, clang::CharUnits::One()));
70+
71+
mlir::Value val = varPtrOp.get();
72+
mlir::acc::VariableTypeCategory typeCategory = pltTy.getPointeeTypeCategory(
73+
cast<TypedValue<mlir::acc::PointerLikeType>>(val),
74+
mlir::acc::getVarType(varPtrOp.get()));
75+
76+
if (isAnyIntegerOrFloatingPointType(ty) ||
77+
mlir::isa<cir::PointerType>(ty) || mlir::isa<cir::BoolType>(ty)) {
78+
EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::scalar);
79+
} else if (mlir::isa<cir::ArrayType>(ty)) {
80+
EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::array);
81+
} else if (mlir::isa<cir::RecordType>(ty)) {
82+
EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::composite);
83+
} else if (mlir::isa<cir::FuncType, cir::VectorType>(ty)) {
84+
EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::nonscalar);
85+
} else if (mlir::isa<cir::VoidType>(ty)) {
86+
EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::uncategorized);
87+
} else {
88+
EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::uncategorized);
89+
// If we hit this, we need to add support for a new type.
90+
ASSERT_TRUE(false);
91+
}
92+
}
93+
};
94+
95+
TEST_F(CIROpenACCPointerLikeTest, testPointerToInt) {
96+
// Test various scalar types.
97+
testElementType(cir::IntType::get(&context, 8, true));
98+
testElementType(cir::IntType::get(&context, 8, false));
99+
testElementType(cir::IntType::get(&context, 16, true));
100+
testElementType(cir::IntType::get(&context, 16, false));
101+
testElementType(cir::IntType::get(&context, 32, true));
102+
testElementType(cir::IntType::get(&context, 32, false));
103+
testElementType(cir::IntType::get(&context, 64, true));
104+
testElementType(cir::IntType::get(&context, 64, false));
105+
testElementType(cir::IntType::get(&context, 128, true));
106+
testElementType(cir::IntType::get(&context, 128, false));
107+
}
108+
109+
TEST_F(CIROpenACCPointerLikeTest, testPointerToBool) {
110+
testElementType(cir::BoolType::get(&context));
111+
}
112+
113+
TEST_F(CIROpenACCPointerLikeTest, testPointerToFloat) {
114+
testElementType(cir::SingleType::get(&context));
115+
testElementType(cir::DoubleType::get(&context));
116+
}
117+
118+
TEST_F(CIROpenACCPointerLikeTest, testPointerToPointer) {
119+
mlir::Type i32Ty = cir::IntType::get(&context, 32, true);
120+
mlir::Type ptrTy = cir::PointerType::get(i32Ty);
121+
testElementType(ptrTy);
122+
}
123+
124+
TEST_F(CIROpenACCPointerLikeTest, testPointerToArray) {
125+
// Test an array type.
126+
mlir::Type i32Ty = cir::IntType::get(&context, 32, true);
127+
testElementType(cir::ArrayType::get(i32Ty, 10));
128+
}
129+
130+
TEST_F(CIROpenACCPointerLikeTest, testPointerToStruct) {
131+
// Test a struct type.
132+
mlir::Type i32Ty = cir::IntType::get(&context, 32, true);
133+
llvm::ArrayRef<mlir::Type> fields = {i32Ty, i32Ty};
134+
cir::RecordType structTy = cir::RecordType::get(
135+
&context, b.getStringAttr("S"), cir::RecordType::RecordKind::Struct);
136+
structTy.complete(fields, false, false);
137+
testElementType(structTy);
138+
139+
// Test a union type.
140+
cir::RecordType unionTy = cir::RecordType::get(
141+
&context, b.getStringAttr("U"), cir::RecordType::RecordKind::Union);
142+
unionTy.complete(fields, false, false);
143+
testElementType(unionTy);
144+
}
145+
146+
TEST_F(CIROpenACCPointerLikeTest, testPointerToFunction) {
147+
mlir::Type i32Ty = cir::IntType::get(&context, 32, true);
148+
cir::FuncType::get(SmallVector<mlir::Type, 2>{i32Ty, i32Ty}, i32Ty);
149+
}
150+
151+
TEST_F(CIROpenACCPointerLikeTest, testPointerToVector) {
152+
mlir::Type i32Ty = cir::IntType::get(&context, 32, true);
153+
mlir::Type vecTy = cir::VectorType::get(i32Ty, 4);
154+
testElementType(vecTy);
155+
}
156+
157+
TEST_F(CIROpenACCPointerLikeTest, testPointerToVoid) {
158+
mlir::Type voidTy = cir::VoidType::get(&context);
159+
testElementType(voidTy);
160+
}

0 commit comments

Comments
 (0)