diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index cfb18914e8126..c5f25f6cda254 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -34,6 +34,7 @@ #include "flang/Semantics/scope.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index b8aa49752d0a9..e2a60f57940f6 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -152,12 +152,6 @@ mlir::ValueRange getDataOperands(mlir::Operation *accOp); /// Used to get a mutable range iterating over the data operands. mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp); -/// Used to obtain the enclosing compute construct operation that contains -/// the provided `region`. Returns nullptr if no compute construct operation -/// is found. The returns operation is one of types defined by -///`ACC_COMPUTE_CONSTRUCT_OPS`. -mlir::Operation *getEnclosingComputeOp(mlir::Region ®ion); - /// Used to check whether the provided `type` implements the `PointerLikeType` /// interface. inline bool isPointerLikeType(mlir::Type type) { diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h new file mode 100644 index 0000000000000..378f4348f2cf1 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h @@ -0,0 +1,44 @@ +//===- OpenACCUtils.h - OpenACC Utilities -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_OPENACCUTILS_H_ +#define MLIR_DIALECT_OPENACC_OPENACCUTILS_H_ + +#include "mlir/Dialect/OpenACC/OpenACC.h" + +namespace mlir { +namespace acc { + +/// Used to obtain the enclosing compute construct operation that contains +/// the provided `region`. Returns nullptr if no compute construct operation +/// is found. The returned operation is one of types defined by +/// `ACC_COMPUTE_CONSTRUCT_OPS`. +mlir::Operation *getEnclosingComputeOp(mlir::Region ®ion); + +/// Returns true if this value is only used by `acc.private` operations in the +/// `region`. +bool isOnlyUsedByPrivateClauses(mlir::Value val, mlir::Region ®ion); + +/// Returns true if this value is only used by `acc.reduction` operations in +/// the `region`. +bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region ®ion); + +/// Looks for an OpenACC default attribute on the current operation `op` or in +/// a parent operation which encloses `op`. This is useful because OpenACC +/// specification notes that a visible default clause is the nearest default +/// clause appearing on the compute construct or a lexically containing data +/// construct. +std::optional getDefaultAttr(mlir::Operation *op); + +/// Get the type category of an OpenACC variable. +mlir::acc::VariableTypeCategory getTypeCategory(mlir::Value var); + +} // namespace acc +} // namespace mlir + +#endif // MLIR_DIALECT_OPENACC_OPENACCUTILS_H_ diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt index 9f57627c321fb..7117520599fa6 100644 --- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(Utils) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index dcfe2c742407e..05a196a3f3b3c 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -4649,14 +4649,3 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { .Default([&](mlir::Operation *) { return nullptr; })}; return dataOperands; } - -mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { - mlir::Operation *parentOp = region.getParentOp(); - while (parentOp) { - if (mlir::isa(parentOp)) { - return parentOp; - } - parentOp = parentOp->getParentOp(); - } - return nullptr; -} diff --git a/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt new file mode 100644 index 0000000000000..68e124625921f --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIROpenACCUtils + OpenACCUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + DEPENDS + MLIROpenACCPassIncGen + MLIROpenACCOpsIncGen + MLIROpenACCEnumsIncGen + MLIROpenACCAttributesIncGen + MLIROpenACCMPOpsInterfacesIncGen + MLIROpenACCOpsInterfacesIncGen + MLIROpenACCTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIROpenACCDialect + MLIRIR + MLIRSupport +) diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp new file mode 100644 index 0000000000000..12233254f3fb4 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -0,0 +1,80 @@ +//===- OpenACCUtils.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" + +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "llvm/ADT/TypeSwitch.h" + +mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { + mlir::Operation *parentOp = region.getParentOp(); + while (parentOp) { + if (mlir::isa(parentOp)) + return parentOp; + parentOp = parentOp->getParentOp(); + } + return nullptr; +} + +template +static bool isOnlyUsedByOpClauses(mlir::Value val, mlir::Region ®ion) { + auto checkIfUsedOnlyByOpInside = [&](mlir::Operation *user) { + // For any users which are not in the current acc region, we can ignore. + // Return true so that it can be used in a `all_of` check. + if (!region.isAncestor(user->getParentRegion())) + return true; + return mlir::isa(user); + }; + + return llvm::all_of(val.getUsers(), checkIfUsedOnlyByOpInside); +} + +bool mlir::acc::isOnlyUsedByPrivateClauses(mlir::Value val, + mlir::Region ®ion) { + return isOnlyUsedByOpClauses(val, region); +} + +bool mlir::acc::isOnlyUsedByReductionClauses(mlir::Value val, + mlir::Region ®ion) { + return isOnlyUsedByOpClauses(val, region); +} + +std::optional +mlir::acc::getDefaultAttr(Operation *op) { + std::optional defaultAttr; + Operation *currOp = op; + + // Iterate outwards until a default clause is found (since OpenACC + // specification notes that a visible default clause is the nearest default + // clause appearing on the compute construct or a lexically containing data + // construct. + while (!defaultAttr.has_value() && currOp) { + defaultAttr = + llvm::TypeSwitch>(currOp) + .Case( + [&](auto op) { return op.getDefaultAttr(); }) + .Default([&](Operation *) { return std::nullopt; }); + currOp = currOp->getParentOp(); + } + + return defaultAttr; +} + +mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) { + mlir::acc::VariableTypeCategory typeCategory = + mlir::acc::VariableTypeCategory::uncategorized; + if (auto mappableTy = dyn_cast(var.getType())) + typeCategory = mappableTy.getTypeCategory(var); + else if (auto pointerLikeTy = + dyn_cast(var.getType())) + typeCategory = pointerLikeTy.getPointeeTypeCategory( + cast>(var), + pointerLikeTy.getElementType()); + return typeCategory; +} diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt index d5f40a44f8cc6..177c8680b0040 100644 --- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt @@ -1,8 +1,13 @@ add_mlir_unittest(MLIROpenACCTests OpenACCOpsTest.cpp + OpenACCUtilsTest.cpp ) mlir_target_link_libraries(MLIROpenACCTests PRIVATE MLIRIR + MLIRFuncDialect + MLIRMemRefDialect + MLIRArithDialect MLIROpenACCDialect + MLIROpenACCUtils ) diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp new file mode 100644 index 0000000000000..ab817b640edb3 --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -0,0 +1,412 @@ +//===- OpenACCUtilsTest.cpp - Unit tests for OpenACC utilities -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::acc; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class OpenACCUtilsTest : public ::testing::Test { +protected: + OpenACCUtilsTest() : b(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect(); + } + + MLIRContext context; + OpBuilder b; + Location loc; +}; + +//===----------------------------------------------------------------------===// +// getEnclosingComputeOp Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, getEnclosingComputeOpParallel) { + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + parallelRegion.emplaceBlock(); + + // Test that we can find the parallel op from its region + Operation *enclosingOp = getEnclosingComputeOp(parallelRegion); + EXPECT_EQ(enclosingOp, parallelOp.get()); +} + +TEST_F(OpenACCUtilsTest, getEnclosingComputeOpKernels) { + // Create a kernels op with a region + OwningOpRef kernelsOp = + KernelsOp::create(b, loc, TypeRange{}, ValueRange{}); + Region &kernelsRegion = kernelsOp->getRegion(); + kernelsRegion.emplaceBlock(); + + // Test that we can find the kernels op from its region + Operation *enclosingOp = getEnclosingComputeOp(kernelsRegion); + EXPECT_EQ(enclosingOp, kernelsOp.get()); +} + +TEST_F(OpenACCUtilsTest, getEnclosingComputeOpSerial) { + // Create a serial op with a region + OwningOpRef serialOp = + SerialOp::create(b, loc, TypeRange{}, ValueRange{}); + Region &serialRegion = serialOp->getRegion(); + serialRegion.emplaceBlock(); + + // Test that we can find the serial op from its region + Operation *enclosingOp = getEnclosingComputeOp(serialRegion); + EXPECT_EQ(enclosingOp, serialOp.get()); +} + +TEST_F(OpenACCUtilsTest, getEnclosingComputeOpNested) { + // Create nested ops: parallel containing a loop op + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create a loop op inside the parallel region + OwningOpRef loopOp = + LoopOp::create(b, loc, TypeRange{}, ValueRange{}); + Region &loopRegion = loopOp->getRegion(); + loopRegion.emplaceBlock(); + + // Test that from the loop region, we find the parallel op (loop is not a + // compute op) + Operation *enclosingOp = getEnclosingComputeOp(loopRegion); + EXPECT_EQ(enclosingOp, parallelOp.get()); +} + +TEST_F(OpenACCUtilsTest, getEnclosingComputeOpNone) { + // Create a module with a region that's not inside a compute construct + OwningOpRef moduleOp = ModuleOp::create(loc); + Region &moduleRegion = moduleOp->getBodyRegion(); + + // Test that we get nullptr when there's no enclosing compute op + Operation *enclosingOp = getEnclosingComputeOp(moduleRegion); + EXPECT_EQ(enclosingOp, nullptr); +} + +//===----------------------------------------------------------------------===// +// isOnlyUsedByPrivateClauses Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, isOnlyUsedByPrivateClausesTrue) { + // Create a value (memref) outside the compute region + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create a private op using the value + OwningOpRef privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Test that the value is only used by private clauses + EXPECT_TRUE(isOnlyUsedByPrivateClauses(varPtr, parallelRegion)); +} + +TEST_F(OpenACCUtilsTest, isOnlyUsedByPrivateClausesFalse) { + // Create a value (memref) outside the compute region + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create a private op using the value + OwningOpRef privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Also use the value in a function call (escape) + OwningOpRef callOp = func::CallOp::create( + b, loc, "some_func", TypeRange{}, ValueRange{varPtr}); + + // Test that the value is NOT only used by private clauses (it escapes via + // call) + EXPECT_FALSE(isOnlyUsedByPrivateClauses(varPtr, parallelRegion)); +} + +TEST_F(OpenACCUtilsTest, isOnlyUsedByPrivateClausesMultiple) { + // Create a value (memref) outside the compute region + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create multiple private ops using the value + OwningOpRef privateOp1 = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + OwningOpRef privateOp2 = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Test that the value is only used by private clauses even with multiple uses + EXPECT_TRUE(isOnlyUsedByPrivateClauses(varPtr, parallelRegion)); +} + +//===----------------------------------------------------------------------===// +// isOnlyUsedByReductionClauses Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, isOnlyUsedByReductionClausesTrue) { + // Create a value (memref) outside the compute region + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create a reduction op using the value + OwningOpRef reductionOp = ReductionOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Test that the value is only used by reduction clauses + EXPECT_TRUE(isOnlyUsedByReductionClauses(varPtr, parallelRegion)); +} + +TEST_F(OpenACCUtilsTest, isOnlyUsedByReductionClausesFalse) { + // Create a value (memref) outside the compute region + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create a reduction op using the value + OwningOpRef reductionOp = ReductionOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Also use the value in a function call (escape) + OwningOpRef callOp = func::CallOp::create( + b, loc, "some_func", TypeRange{}, ValueRange{varPtr}); + + // Test that the value is NOT only used by reduction clauses (it escapes via + // call) + EXPECT_FALSE(isOnlyUsedByReductionClauses(varPtr, parallelRegion)); +} + +TEST_F(OpenACCUtilsTest, isOnlyUsedByReductionClausesMultiple) { + // Create a value (memref) outside the compute region + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a parallel op with a region + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(parallelBlock); + + // Create multiple reduction ops using the value + OwningOpRef reductionOp1 = ReductionOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + OwningOpRef reductionOp2 = ReductionOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Test that the value is only used by reduction clauses even with multiple + // uses + EXPECT_TRUE(isOnlyUsedByReductionClauses(varPtr, parallelRegion)); +} + +//===----------------------------------------------------------------------===// +// getDefaultAttr Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, getDefaultAttrOnParallel) { + // Create a parallel op with a default attribute + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + parallelOp->setDefaultAttr(ClauseDefaultValue::None); + + // Test that we can retrieve the default attribute + std::optional defaultAttr = + getDefaultAttr(parallelOp.get()); + EXPECT_TRUE(defaultAttr.has_value()); + EXPECT_EQ(defaultAttr.value(), ClauseDefaultValue::None); +} + +TEST_F(OpenACCUtilsTest, getDefaultAttrOnKernels) { + // Create a kernels op with a default attribute + OwningOpRef kernelsOp = + KernelsOp::create(b, loc, TypeRange{}, ValueRange{}); + kernelsOp->setDefaultAttr(ClauseDefaultValue::Present); + + // Test that we can retrieve the default attribute + std::optional defaultAttr = + getDefaultAttr(kernelsOp.get()); + EXPECT_TRUE(defaultAttr.has_value()); + EXPECT_EQ(defaultAttr.value(), ClauseDefaultValue::Present); +} + +TEST_F(OpenACCUtilsTest, getDefaultAttrOnSerial) { + // Create a serial op with a default attribute + OwningOpRef serialOp = + SerialOp::create(b, loc, TypeRange{}, ValueRange{}); + serialOp->setDefaultAttr(ClauseDefaultValue::None); + + // Test that we can retrieve the default attribute + std::optional defaultAttr = + getDefaultAttr(serialOp.get()); + EXPECT_TRUE(defaultAttr.has_value()); + EXPECT_EQ(defaultAttr.value(), ClauseDefaultValue::None); +} + +TEST_F(OpenACCUtilsTest, getDefaultAttrOnData) { + // Create a data op with a default attribute + OwningOpRef dataOp = + DataOp::create(b, loc, TypeRange{}, ValueRange{}); + dataOp->setDefaultAttr(ClauseDefaultValue::Present); + + // Test that we can retrieve the default attribute + std::optional defaultAttr = getDefaultAttr(dataOp.get()); + EXPECT_TRUE(defaultAttr.has_value()); + EXPECT_EQ(defaultAttr.value(), ClauseDefaultValue::Present); +} + +TEST_F(OpenACCUtilsTest, getDefaultAttrNone) { + // Create a parallel op without setting a default attribute + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + // Do not set default attribute + + // Test that we get std::nullopt when there's no default attribute + std::optional defaultAttr = + getDefaultAttr(parallelOp.get()); + EXPECT_FALSE(defaultAttr.has_value()); +} + +TEST_F(OpenACCUtilsTest, getDefaultAttrNearest) { + // Create a data op with a default attribute + OwningOpRef dataOp = + DataOp::create(b, loc, TypeRange{}, ValueRange{}); + dataOp->setDefaultAttr(ClauseDefaultValue::Present); + + Region &dataRegion = dataOp->getRegion(); + Block *dataBlock = &dataRegion.emplaceBlock(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(dataBlock); + + // Create a parallel op inside the data region with NO default attribute + OwningOpRef parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + // Do not set default attribute on parallel op + + Region ¶llelRegion = parallelOp->getRegion(); + Block *parallelBlock = ¶llelRegion.emplaceBlock(); + + b.setInsertionPointToStart(parallelBlock); + + // Create a loop op inside the parallel region + OwningOpRef loopOp = + LoopOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Test that from the loop op, we find the nearest default attribute (from + // data op) + std::optional defaultAttr = getDefaultAttr(loopOp.get()); + EXPECT_TRUE(defaultAttr.has_value()); + EXPECT_EQ(defaultAttr.value(), ClauseDefaultValue::Present); +} + +//===----------------------------------------------------------------------===// +// getTypeCategory Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, getTypeCategoryScalar) { + // Create a scalar memref (no dimensions) + auto scalarMemrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, scalarMemrefTy); + Value varPtr = allocOp->getResult(); + + // Test that a scalar memref returns scalar category + VariableTypeCategory category = getTypeCategory(varPtr); + EXPECT_EQ(category, VariableTypeCategory::scalar); +} + +TEST_F(OpenACCUtilsTest, getTypeCategoryArray) { + // Create an array memref (with dimensions) + auto arrayMemrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, arrayMemrefTy); + Value varPtr = allocOp->getResult(); + + // Test that an array memref returns array category + VariableTypeCategory category = getTypeCategory(varPtr); + EXPECT_EQ(category, VariableTypeCategory::array); +}