@@ -42,17 +42,40 @@ using namespace error;
4242struct ol_platform_impl_t {
4343 ol_platform_impl_t (std::unique_ptr<GenericPluginTy> Plugin,
4444 ol_platform_backend_t BackendType)
45- : Plugin(std::move(Plugin)), BackendType(BackendType) {}
46- std::unique_ptr<GenericPluginTy> Plugin;
47- llvm::SmallVector<std::unique_ptr<ol_device_impl_t >> Devices;
45+ : BackendType(BackendType), Plugin(std::move(Plugin)) {}
4846 ol_platform_backend_t BackendType;
4947
48+ // / Get the plugin, lazily initializing it if necessary.
49+ llvm::Expected<GenericPluginTy *> getPlugin () {
50+ if (llvm::Error Err = init ())
51+ return Err;
52+ return Plugin.get ();
53+ }
54+
55+ // / Get the device list, lazily initializing it if necessary.
56+ llvm::Expected<llvm::SmallVector<std::unique_ptr<ol_device_impl_t >> &>
57+ getDevices () {
58+ if (llvm::Error Err = init ())
59+ return Err;
60+ return Devices;
61+ }
62+
5063 // / Complete all pending work for this platform and perform any needed
5164 // / cleanup.
5265 // /
5366 // / After calling this function, no liboffload functions should be called with
5467 // / this platform handle.
5568 llvm::Error destroy ();
69+
70+ // / Initialize the associated plugin and devices.
71+ llvm::Error init ();
72+
73+ // / Direct access to the plugin, may be uninitialized if accessed here.
74+ std::unique_ptr<GenericPluginTy> Plugin;
75+
76+ private:
77+ std::once_flag Initialized;
78+ llvm::SmallVector<std::unique_ptr<ol_device_impl_t >> Devices;
5679};
5780
5881// Handle type definitions. Ideally these would be 1:1 with the plugins, but
@@ -130,6 +153,39 @@ llvm::Error ol_platform_impl_t::destroy() {
130153 return Result;
131154}
132155
156+ llvm::Error ol_platform_impl_t::init () {
157+ std::unique_ptr<llvm::Error> Storage;
158+
159+ // This can be called concurrently, make sure we only do the actual
160+ // initialization once.
161+ std::call_once (Initialized, [&]() {
162+ // FIXME: Need better handling for the host platform.
163+ if (!Plugin)
164+ return ;
165+
166+ llvm::Error Err = Plugin->init ();
167+ if (Err) {
168+ Storage = std::make_unique<llvm::Error>(std::move (Err));
169+ return ;
170+ }
171+
172+ for (auto DevNum = 0 ; DevNum < Plugin->number_of_devices (); DevNum++) {
173+ if (Plugin->init_device (DevNum) == OFFLOAD_SUCCESS) {
174+ auto Device = &Plugin->getDevice (DevNum);
175+ auto Info = Device->obtainInfoImpl ();
176+ if (llvm::Error Err = Info.takeError ()) {
177+ Storage = std::make_unique<llvm::Error>(std::move (Err));
178+ return ;
179+ }
180+ Devices.emplace_back (std::make_unique<ol_device_impl_t >(
181+ DevNum, Device, *this , std::move (*Info)));
182+ }
183+ }
184+ });
185+
186+ return Storage ? std::move (*Storage) : llvm::Error::success ();
187+ }
188+
133189struct ol_queue_impl_t {
134190 ol_queue_impl_t (__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
135191 : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
@@ -209,13 +265,9 @@ struct OffloadContext {
209265 // key in AllocInfoMap
210266 llvm::SmallVector<void *> AllocBases{};
211267 SmallVector<std::unique_ptr<ol_platform_impl_t >, 4 > Platforms{};
268+ ol_device_handle_t HostDevice;
212269 size_t RefCount;
213270
214- ol_device_handle_t HostDevice () {
215- // The host platform is always inserted last
216- return Platforms.back ()->Devices [0 ].get ();
217- }
218-
219271 static OffloadContext &get () {
220272 assert (OffloadContextVal);
221273 return *OffloadContextVal;
@@ -259,28 +311,16 @@ Error initPlugins(OffloadContext &Context) {
259311 } while (false );
260312#include " Shared/Targets.def"
261313
262- // Preemptively initialize all devices in the plugin
263- for (auto &Platform : Context.Platforms ) {
264- auto Err = Platform->Plugin ->init ();
265- [[maybe_unused]] std::string InfoMsg = toString (std::move (Err));
266- for (auto DevNum = 0 ; DevNum < Platform->Plugin ->number_of_devices ();
267- DevNum++) {
268- if (Platform->Plugin ->init_device (DevNum) == OFFLOAD_SUCCESS) {
269- auto Device = &Platform->Plugin ->getDevice (DevNum);
270- auto Info = Device->obtainInfoImpl ();
271- if (auto Err = Info.takeError ())
272- return Err;
273- Platform->Devices .emplace_back (std::make_unique<ol_device_impl_t >(
274- DevNum, Device, *Platform, std::move (*Info)));
275- }
276- }
277- }
278-
279314 // Add the special host device
280315 auto &HostPlatform = Context.Platforms .emplace_back (
281316 std::make_unique<ol_platform_impl_t >(nullptr , OL_PLATFORM_BACKEND_HOST));
282- HostPlatform->Devices .emplace_back (std::make_unique<ol_device_impl_t >(
283- -1 , nullptr , *HostPlatform, InfoTreeNode{}));
317+ auto DevicesOrErr = HostPlatform->getDevices ();
318+ if (!DevicesOrErr)
319+ return DevicesOrErr.takeError ();
320+ Context.HostDevice = DevicesOrErr
321+ ->emplace_back (std::make_unique<ol_device_impl_t >(
322+ -1 , nullptr , *HostPlatform, InfoTreeNode{}))
323+ .get ();
284324
285325 Context.TracingEnabled = std::getenv (" OFFLOAD_TRACE" );
286326 Context.ValidationEnabled = !std::getenv (" OFFLOAD_DISABLE_VALIDATION" );
@@ -315,12 +355,12 @@ Error olShutDown_impl() {
315355 llvm::Error Result = Error::success ();
316356 auto *OldContext = OffloadContextVal.exchange (nullptr );
317357
318- for (auto &P : OldContext->Platforms ) {
358+ for (auto &Platform : OldContext->Platforms ) {
319359 // Host plugin is nullptr and has no deinit
320- if (!P ->Plugin || !P ->Plugin ->is_initialized ())
360+ if (!Platform ->Plugin || !Platform ->Plugin ->is_initialized ())
321361 continue ;
322362
323- if (auto Res = P ->destroy ())
363+ if (auto Res = Platform ->destroy ())
324364 Result = llvm::joinErrors (std::move (Result), std::move (Res));
325365 }
326366
@@ -334,9 +374,14 @@ Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
334374 InfoWriter Info (PropSize, PropValue, PropSizeRet);
335375 bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
336376
377+ auto PluginOrErr = Platform->getPlugin ();
378+ if (!PluginOrErr)
379+ return PluginOrErr.takeError ();
380+ GenericPluginTy *Plugin = *PluginOrErr;
381+
337382 switch (PropName) {
338383 case OL_PLATFORM_INFO_NAME:
339- return Info.writeString (IsHost ? " Host" : Platform-> Plugin ->getName ());
384+ return Info.writeString (IsHost ? " Host" : Plugin->getName ());
340385 case OL_PLATFORM_INFO_VENDOR_NAME:
341386 // TODO: Implement this
342387 return Info.writeString (" Unknown platform vendor" );
@@ -373,7 +418,7 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
373418Error olGetDeviceInfoImplDetail (ol_device_handle_t Device,
374419 ol_device_info_t PropName, size_t PropSize,
375420 void *PropValue, size_t *PropSizeRet) {
376- assert (Device != OffloadContext::get ().HostDevice () );
421+ assert (Device != OffloadContext::get ().HostDevice );
377422 InfoWriter Info (PropSize, PropValue, PropSizeRet);
378423
379424 auto makeError = [&](ErrorCode Code, StringRef Err) {
@@ -511,7 +556,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
511556Error olGetDeviceInfoImplDetailHost (ol_device_handle_t Device,
512557 ol_device_info_t PropName, size_t PropSize,
513558 void *PropValue, size_t *PropSizeRet) {
514- assert (Device == OffloadContext::get ().HostDevice () );
559+ assert (Device == OffloadContext::get ().HostDevice );
515560 InfoWriter Info (PropSize, PropValue, PropSizeRet);
516561
517562 constexpr auto uint32_max = std::numeric_limits<uint32_t >::max ();
@@ -579,7 +624,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
579624
580625Error olGetDeviceInfo_impl (ol_device_handle_t Device, ol_device_info_t PropName,
581626 size_t PropSize, void *PropValue) {
582- if (Device == OffloadContext::get ().HostDevice () )
627+ if (Device == OffloadContext::get ().HostDevice )
583628 return olGetDeviceInfoImplDetailHost (Device, PropName, PropSize, PropValue,
584629 nullptr );
585630 return olGetDeviceInfoImplDetail (Device, PropName, PropSize, PropValue,
@@ -588,17 +633,20 @@ Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
588633
589634Error olGetDeviceInfoSize_impl (ol_device_handle_t Device,
590635 ol_device_info_t PropName, size_t *PropSizeRet) {
591- if (Device == OffloadContext::get ().HostDevice () )
636+ if (Device == OffloadContext::get ().HostDevice )
592637 return olGetDeviceInfoImplDetailHost (Device, PropName, 0 , nullptr ,
593638 PropSizeRet);
594639 return olGetDeviceInfoImplDetail (Device, PropName, 0 , nullptr , PropSizeRet);
595640}
596641
597642Error olIterateDevices_impl (ol_device_iterate_cb_t Callback, void *UserData) {
598643 for (auto &Platform : OffloadContext::get ().Platforms ) {
599- for (auto &Device : Platform->Devices ) {
644+ auto DevicesOrErr = Platform->getDevices ();
645+ if (!DevicesOrErr)
646+ return DevicesOrErr.takeError ();
647+ for (auto &Device : *DevicesOrErr) {
600648 if (!Callback (Device.get (), UserData)) {
601- break ;
649+ return Error::success () ;
602650 }
603651 }
604652 }
@@ -949,7 +997,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
949997Error olMemcpy_impl (ol_queue_handle_t Queue, void *DstPtr,
950998 ol_device_handle_t DstDevice, const void *SrcPtr,
951999 ol_device_handle_t SrcDevice, size_t Size) {
952- auto Host = OffloadContext::get ().HostDevice () ;
1000+ auto Host = OffloadContext::get ().HostDevice ;
9531001 if (DstDevice == Host && SrcDevice == Host) {
9541002 if (!Queue) {
9551003 std::memcpy (DstPtr, SrcPtr, Size);
0 commit comments