|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Unit tests for CIR implementation of OpenACC's PointertLikeType interface |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "mlir/Dialect/OpenACC/OpenACC.h" |
| 14 | +#include "mlir/IR/BuiltinTypes.h" |
| 15 | +#include "mlir/IR/Diagnostics.h" |
| 16 | +#include "mlir/IR/MLIRContext.h" |
| 17 | +#include "mlir/IR/Value.h" |
| 18 | +#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" |
| 19 | +#include "clang/CIR/Dialect/IR/CIRDialect.h" |
| 20 | +#include "clang/CIR/Dialect/IR/CIRTypes.h" |
| 21 | +#include "clang/CIR/Dialect/OpenACC/CIROpenACCTypeInterfaces.h" |
| 22 | +#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h" |
| 23 | +#include "gtest/gtest.h" |
| 24 | + |
| 25 | +using namespace mlir; |
| 26 | +using namespace cir; |
| 27 | + |
| 28 | +//===----------------------------------------------------------------------===// |
| 29 | +// Test Fixture |
| 30 | +//===----------------------------------------------------------------------===// |
| 31 | + |
| 32 | +class CIROpenACCPointerLikeTest : public ::testing::Test { |
| 33 | +protected: |
| 34 | + CIROpenACCPointerLikeTest() : b(&context), loc(UnknownLoc::get(&context)) { |
| 35 | + context.loadDialect<cir::CIRDialect>(); |
| 36 | + context.loadDialect<mlir::acc::OpenACCDialect>(); |
| 37 | + |
| 38 | + // Register extension to integrate CIR types with OpenACC. |
| 39 | + mlir::DialectRegistry registry; |
| 40 | + cir::acc::registerOpenACCExtensions(registry); |
| 41 | + context.appendDialectRegistry(registry); |
| 42 | + } |
| 43 | + |
| 44 | + MLIRContext context; |
| 45 | + OpBuilder b; |
| 46 | + Location loc; |
| 47 | + |
| 48 | + mlir::IntegerAttr getSizeFromCharUnits(mlir::MLIRContext *ctx, |
| 49 | + clang::CharUnits size) { |
| 50 | + // Note that mlir::IntegerType is used instead of cir::IntType here |
| 51 | + // because we don't need sign information for this to be useful, so keep |
| 52 | + // it simple. |
| 53 | + return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), |
| 54 | + size.getQuantity()); |
| 55 | + } |
| 56 | + |
| 57 | + // General handler for types without a specific test |
| 58 | + void testElementType(mlir::Type ty) { |
| 59 | + mlir::Type ptrTy = cir::PointerType::get(ty); |
| 60 | + |
| 61 | + // cir::PointerType should be castable to acc::PointerLikeType |
| 62 | + auto pltTy = dyn_cast_if_present<mlir::acc::PointerLikeType>(ptrTy); |
| 63 | + ASSERT_NE(pltTy, nullptr); |
| 64 | + |
| 65 | + EXPECT_EQ(pltTy.getElementType(), ty); |
| 66 | + |
| 67 | + OwningOpRef<cir::AllocaOp> varPtrOp = b.create<cir::AllocaOp>( |
| 68 | + loc, ptrTy, ty, "", |
| 69 | + getSizeFromCharUnits(&context, clang::CharUnits::One())); |
| 70 | + |
| 71 | + mlir::Value val = varPtrOp.get(); |
| 72 | + mlir::acc::VariableTypeCategory typeCategory = pltTy.getPointeeTypeCategory( |
| 73 | + cast<TypedValue<mlir::acc::PointerLikeType>>(val), |
| 74 | + mlir::acc::getVarType(varPtrOp.get())); |
| 75 | + |
| 76 | + if (isAnyIntegerOrFloatingPointType(ty) || |
| 77 | + mlir::isa<cir::PointerType>(ty) || mlir::isa<cir::BoolType>(ty)) { |
| 78 | + EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::scalar); |
| 79 | + } else if (mlir::isa<cir::ArrayType>(ty)) { |
| 80 | + EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::array); |
| 81 | + } else if (mlir::isa<cir::RecordType>(ty)) { |
| 82 | + EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::composite); |
| 83 | + } else if (mlir::isa<cir::FuncType, cir::VectorType>(ty)) { |
| 84 | + EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::nonscalar); |
| 85 | + } else if (mlir::isa<cir::VoidType>(ty)) { |
| 86 | + EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::uncategorized); |
| 87 | + } else { |
| 88 | + EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::uncategorized); |
| 89 | + // If we hit this, we need to add support for a new type. |
| 90 | + ASSERT_TRUE(false); |
| 91 | + } |
| 92 | + } |
| 93 | +}; |
| 94 | + |
| 95 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToInt) { |
| 96 | + // Test various scalar types. |
| 97 | + testElementType(cir::IntType::get(&context, 8, true)); |
| 98 | + testElementType(cir::IntType::get(&context, 8, false)); |
| 99 | + testElementType(cir::IntType::get(&context, 16, true)); |
| 100 | + testElementType(cir::IntType::get(&context, 16, false)); |
| 101 | + testElementType(cir::IntType::get(&context, 32, true)); |
| 102 | + testElementType(cir::IntType::get(&context, 32, false)); |
| 103 | + testElementType(cir::IntType::get(&context, 64, true)); |
| 104 | + testElementType(cir::IntType::get(&context, 64, false)); |
| 105 | + testElementType(cir::IntType::get(&context, 128, true)); |
| 106 | + testElementType(cir::IntType::get(&context, 128, false)); |
| 107 | +} |
| 108 | + |
| 109 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToBool) { |
| 110 | + testElementType(cir::BoolType::get(&context)); |
| 111 | +} |
| 112 | + |
| 113 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToFloat) { |
| 114 | + testElementType(cir::SingleType::get(&context)); |
| 115 | + testElementType(cir::DoubleType::get(&context)); |
| 116 | +} |
| 117 | + |
| 118 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToPointer) { |
| 119 | + mlir::Type i32Ty = cir::IntType::get(&context, 32, true); |
| 120 | + mlir::Type ptrTy = cir::PointerType::get(i32Ty); |
| 121 | + testElementType(ptrTy); |
| 122 | +} |
| 123 | + |
| 124 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToArray) { |
| 125 | + // Test an array type. |
| 126 | + mlir::Type i32Ty = cir::IntType::get(&context, 32, true); |
| 127 | + testElementType(cir::ArrayType::get(i32Ty, 10)); |
| 128 | +} |
| 129 | + |
| 130 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToStruct) { |
| 131 | + // Test a struct type. |
| 132 | + mlir::Type i32Ty = cir::IntType::get(&context, 32, true); |
| 133 | + llvm::ArrayRef<mlir::Type> fields = {i32Ty, i32Ty}; |
| 134 | + cir::RecordType structTy = cir::RecordType::get( |
| 135 | + &context, b.getStringAttr("S"), cir::RecordType::RecordKind::Struct); |
| 136 | + structTy.complete(fields, false, false); |
| 137 | + testElementType(structTy); |
| 138 | + |
| 139 | + // Test a union type. |
| 140 | + cir::RecordType unionTy = cir::RecordType::get( |
| 141 | + &context, b.getStringAttr("U"), cir::RecordType::RecordKind::Union); |
| 142 | + unionTy.complete(fields, false, false); |
| 143 | + testElementType(unionTy); |
| 144 | +} |
| 145 | + |
| 146 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToFunction) { |
| 147 | + mlir::Type i32Ty = cir::IntType::get(&context, 32, true); |
| 148 | + cir::FuncType::get(SmallVector<mlir::Type, 2>{i32Ty, i32Ty}, i32Ty); |
| 149 | +} |
| 150 | + |
| 151 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToVector) { |
| 152 | + mlir::Type i32Ty = cir::IntType::get(&context, 32, true); |
| 153 | + mlir::Type vecTy = cir::VectorType::get(i32Ty, 4); |
| 154 | + testElementType(vecTy); |
| 155 | +} |
| 156 | + |
| 157 | +TEST_F(CIROpenACCPointerLikeTest, testPointerToVoid) { |
| 158 | + mlir::Type voidTy = cir::VoidType::get(&context); |
| 159 | + testElementType(voidTy); |
| 160 | +} |
0 commit comments