Skip to content

Conversation

@jplehr
Copy link
Contributor

@jplehr jplehr commented Sep 17, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-offload
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Jan Patrick Lehr (jplehr)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/159256.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (modified) offload/include/device.h (-2)
  • (modified) offload/libomptarget/device.cpp (+1-83)
  • (modified) offload/plugins-nextgen/common/include/PluginInterface.h (+11-2)
  • (modified) offload/plugins-nextgen/common/src/PluginInterface.cpp (+91-1)
  • (modified) offload/plugins-nextgen/host/src/rtl.cpp (+1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8e36ead6993a8..fecf445720173 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -46,6 +46,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRIR
   MLIRLinalgDialect
   MLIRMemRefDialect
+  MLIRMemRefTransforms
   MLIRMemRefUtils
   MLIRSCFDialect
   MLIRSideEffectInterfaces
diff --git a/offload/include/device.h b/offload/include/device.h
index bf93ce0460aef..1e85bb1876c83 100644
--- a/offload/include/device.h
+++ b/offload/include/device.h
@@ -33,9 +33,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
 
-#include "GlobalHandler.h"
 #include "PluginInterface.h"
-
 using GenericPluginTy = llvm::omp::target::plugin::GenericPluginTy;
 
 // Forward declarations.
diff --git a/offload/libomptarget/device.cpp b/offload/libomptarget/device.cpp
index 71423ae0c94d9..6585286bf4285 100644
--- a/offload/libomptarget/device.cpp
+++ b/offload/libomptarget/device.cpp
@@ -37,8 +37,6 @@
 using namespace llvm::omp::target::ompt;
 #endif
 
-using namespace llvm::omp::target::plugin;
-
 int HostDataToTargetTy::addEventIfNecessary(DeviceTy &Device,
                                             AsyncInfoTy &AsyncInfo) const {
   // First, check if the user disabled atomic map transfer/malloc/dealloc.
@@ -99,55 +97,7 @@ llvm::Error DeviceTy::init() {
   return llvm::Error::success();
 }
 
-// Extract the mapping of host function pointers to device function pointers
-// from the entry table. Functions marked as 'indirect' in OpenMP will have
-// offloading entries generated for them which map the host's function pointer
-// to a global containing the corresponding function pointer on the device.
-static llvm::Expected<std::pair<void *, uint64_t>>
-setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
-                       __tgt_device_binary Binary) {
-  AsyncInfoTy AsyncInfo(Device);
-  llvm::ArrayRef<llvm::offloading::EntryTy> Entries(Image->EntriesBegin,
-                                                    Image->EntriesEnd);
-  llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
-  for (const auto &Entry : Entries) {
-    if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
-        Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
-      continue;
-
-    assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
-    auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
-
-    void *Ptr;
-    if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
-      return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
-                                       "failed to load %s", Entry.SymbolName);
-
-    HstPtr = Entry.Address;
-    if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
-      return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
-                                       "failed to load %s", Entry.SymbolName);
-  }
-
-  // If we do not have any indirect globals we exit early.
-  if (IndirectCallTable.empty())
-    return std::pair{nullptr, 0};
-
-  // Sort the array to allow for more efficient lookup of device pointers.
-  llvm::sort(IndirectCallTable,
-             [](const auto &x, const auto &y) { return x.first < y.first; });
-
-  uint64_t TableSize =
-      IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
-  void *DevicePtr = Device.allocData(TableSize, nullptr, TARGET_ALLOC_DEVICE);
-  if (Device.submitData(DevicePtr, IndirectCallTable.data(), TableSize,
-                        AsyncInfo))
-    return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
-                                     "failed to copy data");
-  return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
-}
-
-// Load binary to device and perform global initialization if needed.
+// Load binary to device.
 llvm::Expected<__tgt_device_binary>
 DeviceTy::loadBinary(__tgt_device_image *Img) {
   __tgt_device_binary Binary;
@@ -155,38 +105,6 @@ DeviceTy::loadBinary(__tgt_device_image *Img) {
   if (RTL->load_binary(RTLDeviceID, Img, &Binary) != OFFLOAD_SUCCESS)
     return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
                                      "failed to load binary %p", Img);
-
-  // This symbol is optional.
-  void *DeviceEnvironmentPtr;
-  if (RTL->get_global(Binary, sizeof(DeviceEnvironmentTy),
-                      "__omp_rtl_device_environment", &DeviceEnvironmentPtr))
-    return Binary;
-
-  // Obtain a table mapping host function pointers to device function pointers.
-  auto CallTablePairOrErr = setupIndirectCallTable(*this, Img, Binary);
-  if (!CallTablePairOrErr)
-    return CallTablePairOrErr.takeError();
-
-  GenericDeviceTy &GenericDevice = RTL->getDevice(RTLDeviceID);
-  DeviceEnvironmentTy DeviceEnvironment;
-  DeviceEnvironment.DeviceDebugKind = GenericDevice.getDebugKind();
-  DeviceEnvironment.NumDevices = RTL->getNumDevices();
-  // TODO: The device ID used here is not the real device ID used by OpenMP.
-  DeviceEnvironment.DeviceNum = RTLDeviceID;
-  DeviceEnvironment.DynamicMemSize = GenericDevice.getDynamicMemorySize();
-  DeviceEnvironment.ClockFrequency = GenericDevice.getClockFrequency();
-  DeviceEnvironment.IndirectCallTable =
-      reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
-  DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
-  DeviceEnvironment.HardwareParallelism =
-      GenericDevice.getHardwareParallelism();
-
-  AsyncInfoTy AsyncInfo(*this);
-  if (submitData(DeviceEnvironmentPtr, &DeviceEnvironment,
-                 sizeof(DeviceEnvironment), AsyncInfo))
-    return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
-                                     "failed to copy data");
-
   return Binary;
 }
 
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index ce66d277d6187..afeeff120218b 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -815,6 +815,11 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   Error unloadBinary(DeviceImageTy *Image);
   virtual Error unloadBinaryImpl(DeviceImageTy *Image) = 0;
 
+  /// Set up the device environment if needed. Note that this setup may not be
+  /// run on some plugins. By default it is executed, but plugins can change
+  /// this behavior by overriding the shouldSetupDeviceEnvironment function.
+  Error setupDeviceEnvironment(GenericPluginTy &Plugin, DeviceImageTy &Image);
+
   /// Setup the global device memory pool, if the plugin requires one.
   Error setupDeviceMemoryPool(GenericPluginTy &Plugin, DeviceImageTy &Image,
                               uint64_t PoolSize);
@@ -1014,7 +1019,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   uint32_t getDefaultNumBlocks() const {
     return GridValues.GV_Default_Num_Teams;
   }
-  uint32_t getDebugKind() const { return OMPX_DebugKind; }
   uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }
   virtual uint64_t getClockFrequency() const { return CLOCKS_PER_SEC; }
 
@@ -1155,6 +1159,11 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   virtual Error getDeviceHeapSize(uint64_t &V) = 0;
   virtual Error setDeviceHeapSize(uint64_t V) = 0;
 
+  /// Indicate whether the device should set up the device environment. If
+  /// this function returns false, setupDeviceEnvironment() does nothing and
+  /// returns success.
+  virtual bool shouldSetupDeviceEnvironment() const { return true; }
+
   /// Indicate whether the device should setup the global device memory pool. If
   /// false is return the value on the device will be uninitialized.
   virtual bool shouldSetupDeviceMemoryPool() const { return true; }
@@ -1210,7 +1219,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   enum class PeerAccessState : uint8_t { AVAILABLE, UNAVAILABLE, PENDING };
 
   /// Array of peer access states with the rest of devices. This means that if
-  /// the device I has a matrix PeerAccesses with PeerAccesses == AVAILABLE,
+  /// the device I has a matrix PeerAccesses with PeerAccesses[J] == AVAILABLE,
   /// the device I can access device J's memory directly. However, notice this
   /// does not mean that device J can access device I's memory directly.
   llvm::SmallVector<PeerAccessState> PeerAccesses;
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 9f830874d5dad..166228fe6be94 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -363,6 +363,54 @@ struct RecordReplayTy {
 };
 } // namespace llvm::omp::target::plugin
 
+// Extract the mapping of host function pointers to device function pointers
+// from the entry table. Functions marked as 'indirect' in OpenMP will have
+// offloading entries generated for them which map the host's function pointer
+// to a global containing the corresponding function pointer on the device.
+static Expected<std::pair<void *, uint64_t>>
+setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
+                       DeviceImageTy &Image) {
+  GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
+
+  llvm::ArrayRef<llvm::offloading::EntryTy> Entries(
+      Image.getTgtImage()->EntriesBegin, Image.getTgtImage()->EntriesEnd);
+  llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
+  for (const auto &Entry : Entries) {
+    if (Entry.Kind != object::OffloadKind::OFK_OpenMP || Entry.Size == 0 ||
+        !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
+      continue;
+
+    assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
+    auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
+
+    GlobalTy DeviceGlobal(Entry.SymbolName, Entry.Size);
+    if (auto Err =
+            Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
+      return std::move(Err);
+
+    HstPtr = Entry.Address;
+    if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
+                                       Entry.Size, nullptr))
+      return std::move(Err);
+  }
+
+  // If we do not have any indirect globals we exit early.
+  if (IndirectCallTable.empty())
+    return std::pair{nullptr, 0};
+
+  // Sort the array to allow for more efficient lookup of device pointers.
+  llvm::sort(IndirectCallTable,
+             [](const auto &x, const auto &y) { return x.first < y.first; });
+
+  uint64_t TableSize =
+      IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
+  void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
+  if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
+                                   TableSize, nullptr))
+    return std::move(Err);
+  return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
+}
+
 AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
                                        __tgt_async_info *AsyncInfoPtr)
     : Device(Device),
@@ -881,6 +929,10 @@ Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
   // Add the image to list.
   LoadedImages.push_back(Image);
 
+  // Setup the device environment if needed.
+  if (auto Err = setupDeviceEnvironment(Plugin, *Image))
+    return std::move(Err);
+
   // Setup the global device memory pool if needed.
   if (!Plugin.getRecordReplay().isReplaying() &&
       shouldSetupDeviceMemoryPool()) {
@@ -916,6 +968,43 @@ Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
   return Image;
 }
 
+Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
+                                              DeviceImageTy &Image) {
+  // There are some plugins that do not need this step.
+  if (!shouldSetupDeviceEnvironment())
+    return Plugin::success();
+
+  // Obtain a table mapping host function pointers to device function pointers.
+  auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
+  if (!CallTablePairOrErr)
+    return CallTablePairOrErr.takeError();
+
+  DeviceEnvironmentTy DeviceEnvironment;
+  DeviceEnvironment.DeviceDebugKind = OMPX_DebugKind;
+  DeviceEnvironment.NumDevices = Plugin.getNumDevices();
+  // TODO: The device ID used here is not the real device ID used by OpenMP.
+  DeviceEnvironment.DeviceNum = DeviceId;
+  DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
+  DeviceEnvironment.ClockFrequency = getClockFrequency();
+  DeviceEnvironment.IndirectCallTable =
+      reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
+  DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
+  DeviceEnvironment.HardwareParallelism = getHardwareParallelism();
+
+  // Create the metainfo of the device environment global.
+  GlobalTy DevEnvGlobal("__omp_rtl_device_environment",
+                        sizeof(DeviceEnvironmentTy), &DeviceEnvironment);
+
+  // Write device environment values to the device.
+  GenericGlobalHandlerTy &GHandler = Plugin.getGlobalHandler();
+  if (auto Err = GHandler.writeGlobalToDevice(*this, Image, DevEnvGlobal)) {
+    DP("Missing symbol %s, continue execution anyway.\n",
+       DevEnvGlobal.getName().data());
+    consumeError(std::move(Err));
+  }
+  return Plugin::success();
+}
+
 Error GenericDeviceTy::setupDeviceMemoryPool(GenericPluginTy &Plugin,
                                              DeviceImageTy &Image,
                                              uint64_t PoolSize) {
@@ -2158,7 +2247,8 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size,
   GenericGlobalHandlerTy &GHandler = getGlobalHandler();
   if (auto Err =
           GHandler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal)) {
-    consumeError(std::move(Err));
+    REPORT("Failure to look up global address: %s\n",
+           toString(std::move(Err)).data());
     return OFFLOAD_FAIL;
   }
 
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 0db01ca09ab02..c1edb4506bb7e 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -388,6 +388,7 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
   }
 
   /// This plugin should not setup the device environment or memory pool.
+  virtual bool shouldSetupDeviceEnvironment() const override { return false; };
   virtual bool shouldSetupDeviceMemoryPool() const override { return false; };
 
   /// Getters and setters for stack size and heap size not relevant.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants