diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.h b/flang/include/flang/Optimizer/Dialect/FIROps.h index 62ef8b4b502f2..4651f2bb8038e 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.h +++ b/flang/include/flang/Optimizer/Dialect/FIROps.h @@ -20,6 +20,7 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" namespace fir { diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 58a317cf5d691..bae52d63fda45 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" +include "mlir/Interfaces/ViewLikeInterface.td" include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td" include "flang/Optimizer/Dialect/FIRDialect.td" include "flang/Optimizer/Dialect/FIRTypes.td" @@ -2828,7 +2829,8 @@ def fir_VolatileCastOp : fir_SimpleOneResultOp<"volatile_cast", [Pure]> { let hasFolder = 1; } -def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> { +def fir_ConvertOp + : fir_SimpleOneResultOp<"convert", [NoMemoryEffect, ViewLikeOpInterface]> { let summary = "encapsulates all Fortran entity type conversions"; let description = [{ @@ -2866,6 +2868,7 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> { static bool isPointerCompatible(mlir::Type ty); static bool canBeConverted(mlir::Type inType, mlir::Type outType); static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy); + mlir::Value getViewSource() { return getValue(); } }]; let hasCanonicalizer = 1; } diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h new file mode 100644 index 0000000000000..7afe97aac57e8 --- /dev/null +++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h @@ -0,0 +1,58 @@ +//===- FIROpenACCOpsInterfaces.h --------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains external operation interfaces for FIR. +// +//===----------------------------------------------------------------------===// + +#ifndef FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_ +#define FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_ + +#include "mlir/Dialect/OpenACC/OpenACC.h" + +namespace fir { +class DeclareOp; +} // namespace fir + +namespace hlfir { +class DeclareOp; +class DesignateOp; +} // namespace hlfir + +namespace fir::acc { + +template +struct PartialEntityAccessModel + : public mlir::acc::PartialEntityAccessOpInterface::ExternalModel< + PartialEntityAccessModel, Op> { + mlir::Value getBaseEntity(mlir::Operation *op) const; + + // Default implementation - returns false (partial view) + bool isCompleteView(mlir::Operation *op) const { return false; } +}; + +// Full specializations for declare operations +template <> +struct PartialEntityAccessModel + : public mlir::acc::PartialEntityAccessOpInterface::ExternalModel< + PartialEntityAccessModel, fir::DeclareOp> { + mlir::Value getBaseEntity(mlir::Operation *op) const; + bool isCompleteView(mlir::Operation *op) const; +}; + +template <> +struct PartialEntityAccessModel + : public mlir::acc::PartialEntityAccessOpInterface::ExternalModel< + PartialEntityAccessModel, hlfir::DeclareOp> { + mlir::Value getBaseEntity(mlir::Operation *op) const; + bool isCompleteView(mlir::Operation *op) const; +}; + +} // namespace fir::acc + +#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_ diff --git a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt index ef67ab1549537..898fb00d41dfe 100644 --- a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_flang_library(FIROpenACCSupport FIROpenACCAttributes.cpp + FIROpenACCOpsInterfaces.cpp FIROpenACCTypeInterfaces.cpp RegisterOpenACCExtensions.cpp diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp new file mode 100644 index 0000000000000..c1734be5185f4 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -0,0 +1,62 @@ +//===-- FIROpenACCOpsInterfaces.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of external operation interfaces for FIR. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h" + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" + +namespace fir::acc { + +template <> +mlir::Value PartialEntityAccessModel::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast(op).getMemref(); +} + +template <> +mlir::Value PartialEntityAccessModel::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast(op).getRef(); +} + +template <> +mlir::Value PartialEntityAccessModel::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast(op).getMemref(); +} + +mlir::Value PartialEntityAccessModel::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast(op).getStorage(); +} + +bool PartialEntityAccessModel::isCompleteView( + mlir::Operation *op) const { + // Return false (partial view) only if storage is present + // Return true (complete view) if storage is absent + return !getBaseEntity(op); +} + +mlir::Value PartialEntityAccessModel::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast(op).getStorage(); +} + +bool PartialEntityAccessModel::isCompleteView( + mlir::Operation *op) const { + // Return false (partial view) only if storage is present + // Return true (complete view) if storage is absent + return !getBaseEntity(op); +} + +} // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp index 717bf344e40aa..d71c40dfac03c 100644 --- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp @@ -11,8 +11,13 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.h" + #include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h" #include "flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h" namespace fir::acc { @@ -37,7 +42,24 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) { fir::LLVMPointerType::attachInterface< OpenACCPointerLikeModel>(*ctx); + + fir::ArrayCoorOp::attachInterface< + PartialEntityAccessModel>(*ctx); + fir::CoordinateOp::attachInterface< + PartialEntityAccessModel>(*ctx); + fir::DeclareOp::attachInterface>( + *ctx); }); + + // Register HLFIR operation interfaces + registry.addExtension( + +[](mlir::MLIRContext *ctx, hlfir::hlfirDialect *dialect) { + hlfir::DesignateOp::attachInterface< + PartialEntityAccessModel>(*ctx); + hlfir::DeclareOp::attachInterface< + PartialEntityAccessModel>(*ctx); + }); + registerAttrsExtensions(registry); } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td index 6fb9a950489f8..054c13a88a552 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td @@ -26,4 +26,22 @@ def ComputeRegionOpInterface : OpInterface<"ComputeRegionOpInterface"> { ]; } +def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that access a partial entity such as + field or array element access. + }]; + + let methods = [ + InterfaceMethod<"Get the base entity being accessed", "::mlir::Value", + "getBaseEntity", (ins)>, + InterfaceMethod<"Check if this is a complete view of the entity", "bool", + "isCompleteView", (ins), [{ + return false; + }]>, + ]; +} + #endif // OPENACC_OPS_INTERFACES diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h index 563c1e0099fc0..964735755c4a3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h @@ -47,6 +47,11 @@ std::string getVariableName(mlir::Value v); /// Returns an empty string if not possible to generate a recipe name. std::string getRecipeName(mlir::acc::RecipeKind kind, mlir::Type type); +// Get the base entity from partial entity access. This is used for getting +// the base `struct` from an operation that only accesses a field or the +// base `array` from an operation that only accesses a subarray. +mlir::Value getBaseEntity(mlir::Value val); + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 660c3138af0ec..fbac28e740750 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -145,3 +145,13 @@ std::string mlir::acc::getRecipeName(mlir::acc::RecipeKind kind, return recipeName; } + +mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { + if (auto partialEntityAccessOp = + dyn_cast(val.getDefiningOp())) { + if (!partialEntityAccessOp.isCompleteView()) + return partialEntityAccessOp.getBaseEntity(); + } + + return val; +} diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp index f1fe53c15a6f5..6f4e30585b2c9 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -570,3 +570,107 @@ TEST_F(OpenACCUtilsTest, getRecipeNamePrivateUnrankedMemref) { getRecipeName(RecipeKind::private_recipe, unrankedMemrefTy); EXPECT_EQ(recipeName, "privatization_memref_Zxi32_"); } + +//===----------------------------------------------------------------------===// +// getBaseEntity Tests +//===----------------------------------------------------------------------===// + +// Local implementation of PartialEntityAccessOpInterface for memref.subview. +// This is implemented locally in the test rather than officially because memref +// operations already have ViewLikeOpInterface, which serves a similar purpose +// for walking through views to the base entity. This test demonstrates how +// getBaseEntity() would work if the interface were attached to memref.subview. +namespace { +struct SubViewOpPartialEntityAccessOpInterface + : public acc::PartialEntityAccessOpInterface::ExternalModel< + SubViewOpPartialEntityAccessOpInterface, memref::SubViewOp> { + Value getBaseEntity(Operation *op) const { + auto subviewOp = cast(op); + return subviewOp.getSource(); + } + + bool isCompleteView(Operation *op) const { + // For testing purposes, we'll consider it a partial view (return false). + // The real implementation would need to look at the offsets. + return false; + } +}; +} // namespace + +TEST_F(OpenACCUtilsTest, getBaseEntityFromSubview) { + // Register the local interface implementation for memref.subview + memref::SubViewOp::attachInterface( + context); + + // Create a base memref + auto memrefTy = MemRefType::get({10, 20}, b.getF32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + Value baseMemref = allocOp->getResult(); + + // Create a subview of the base memref with non-zero offsets + // This creates a 5x10 view starting at [2, 3] in the original 10x20 memref + SmallVector offsets = {b.getIndexAttr(2), b.getIndexAttr(3)}; + SmallVector sizes = {b.getIndexAttr(5), b.getIndexAttr(10)}; + SmallVector strides = {b.getIndexAttr(1), b.getIndexAttr(1)}; + + OwningOpRef subviewOp = + memref::SubViewOp::create(b, loc, baseMemref, offsets, sizes, strides); + Value subview = subviewOp->getResult(); + + // Test that getBaseEntity returns the base memref, not the subview + Value baseEntity = getBaseEntity(subview); + EXPECT_EQ(baseEntity, baseMemref); +} + +TEST_F(OpenACCUtilsTest, getBaseEntityNoInterface) { + // Create a memref without the interface + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + Value varPtr = allocOp->getResult(); + + // Test that getBaseEntity returns the value itself when there's no interface + Value baseEntity = getBaseEntity(varPtr); + EXPECT_EQ(baseEntity, varPtr); +} + +TEST_F(OpenACCUtilsTest, getBaseEntityChainedSubviews) { + // Register the local interface implementation for memref.subview + memref::SubViewOp::attachInterface( + context); + + // Create a base memref + auto memrefTy = MemRefType::get({100, 200}, b.getI64Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + Value baseMemref = allocOp->getResult(); + + // Create first subview + SmallVector offsets1 = {b.getIndexAttr(10), b.getIndexAttr(20)}; + SmallVector sizes1 = {b.getIndexAttr(50), b.getIndexAttr(80)}; + SmallVector strides1 = {b.getIndexAttr(1), b.getIndexAttr(1)}; + + OwningOpRef subview1Op = + memref::SubViewOp::create(b, loc, baseMemref, offsets1, sizes1, strides1); + Value subview1 = subview1Op->getResult(); + + // Create second subview (subview of subview) + SmallVector offsets2 = {b.getIndexAttr(5), b.getIndexAttr(10)}; + SmallVector sizes2 = {b.getIndexAttr(20), b.getIndexAttr(30)}; + SmallVector strides2 = {b.getIndexAttr(1), b.getIndexAttr(1)}; + + OwningOpRef subview2Op = + memref::SubViewOp::create(b, loc, subview1, offsets2, sizes2, strides2); + Value subview2 = subview2Op->getResult(); + + // Test that getBaseEntity on the nested subview returns the first subview + // (since our implementation returns the immediate source, not the ultimate + // base) + Value baseEntity = getBaseEntity(subview2); + EXPECT_EQ(baseEntity, subview1); + + // Test that calling getBaseEntity again returns the original base + Value ultimateBase = getBaseEntity(baseEntity); + EXPECT_EQ(ultimateBase, baseMemref); +}