@@ -84,9 +84,17 @@ struct ol_program_impl_t {
8484 DeviceImage (DeviceImage) {}
8585 plugin::DeviceImageTy *Image;
8686 std::unique_ptr<llvm::MemoryBuffer> ImageData;
87+ std::vector<std::unique_ptr<ol_symbol_impl_t >> Symbols;
8788 __tgt_device_image DeviceImage;
8889};
8990
91+ struct ol_symbol_impl_t {
92+ ol_symbol_impl_t (GenericKernelTy *Kernel)
93+ : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
94+ std::variant<GenericKernelTy *> PluginImpl;
95+ ol_symbol_kind_t Kind;
96+ };
97+
9098namespace llvm {
9199namespace offload {
92100
@@ -653,7 +661,7 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
653661}
654662
655663Error olGetKernel_impl (ol_program_handle_t Program, const char *KernelName,
656- ol_kernel_handle_t *Kernel) {
664+ ol_symbol_handle_t *Kernel) {
657665
658666 auto &Device = Program->Image ->getDevice ();
659667 auto KernelImpl = Device.constructKernel (KernelName);
@@ -663,13 +671,15 @@ Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
663671 if (auto Err = KernelImpl->init (Device, *Program->Image ))
664672 return Err;
665673
666- *Kernel = &*KernelImpl;
674+ *Kernel = Program->Symbols
675+ .emplace_back (std::make_unique<ol_symbol_impl_t >(&*KernelImpl))
676+ .get ();
667677
668678 return Error::success ();
669679}
670680
671681Error olLaunchKernel_impl (ol_queue_handle_t Queue, ol_device_handle_t Device,
672- ol_kernel_handle_t Kernel, const void *ArgumentsData,
682+ ol_symbol_handle_t Kernel, const void *ArgumentsData,
673683 size_t ArgumentsSize,
674684 const ol_kernel_launch_size_args_t *LaunchSizeArgs,
675685 ol_event_handle_t *EventOut) {
@@ -680,6 +690,10 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
680690 " device specified does not match the device of the given queue" );
681691 }
682692
693+ if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
694+ return createOffloadError (ErrorCode::SYMBOL_KIND,
695+ " provided symbol is not a kernel" );
696+
683697 auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr ;
684698 AsyncInfoWrapperTy AsyncInfoWrapper (*DeviceImpl, QueueImpl);
685699 KernelArgsTy LaunchArgs{};
@@ -698,7 +712,7 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
698712 // Don't do anything with pointer indirection; use arg data as-is
699713 LaunchArgs.Flags .IsCUDA = true ;
700714
701- auto *KernelImpl = reinterpret_cast <GenericKernelTy *>(Kernel);
715+ auto *KernelImpl = std::get <GenericKernelTy *>(Kernel-> PluginImpl );
702716 auto Err = KernelImpl->launch (*DeviceImpl, LaunchArgs.ArgPtrs , nullptr ,
703717 LaunchArgs, AsyncInfoWrapper);
704718
0 commit comments