@@ -272,21 +272,36 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
272272    : SYCLBINContentCopy{ContentCopy (SYCLBINContent, SYCLBINSize)},
273273      SYCLBINContentCopySize{SYCLBINSize},
274274      ParsedSYCLBIN (SYCLBIN{SYCLBINContentCopy.get (), SYCLBINSize}) {
275-   size_t  NumJITBinaries = 0 , NumNativeBinaries = 0 ;
276-   for  (const  SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules ) {
277-     NumJITBinaries += AM.IRModules .size ();
278-     NumNativeBinaries += AM.NativeDeviceCodeImages .size ();
279-   }
280-   DeviceBinaries.reserve (NumJITBinaries + NumNativeBinaries);
281-   JITDeviceBinaryImages.reserve (NumJITBinaries);
282-   NativeDeviceBinaryImages.reserve (NumNativeBinaries);
275+   AbstractModuleDescriptors = std::unique_ptr<AbstractModuleDesc[]>(
276+       new  AbstractModuleDesc[ParsedSYCLBIN.AbstractModules .size ()]);
277+ 
278+   size_t  NumBinaries = 0 ;
279+   for  (const  SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules )
280+     NumBinaries += AM.IRModules .size () + AM.NativeDeviceCodeImages .size ();
281+   DeviceBinaries.reserve (NumBinaries);
282+   BinaryImages = std::unique_ptr<RTDeviceBinaryImage[]>(
283+       new  RTDeviceBinaryImage[NumBinaries]);
284+ 
285+   RTDeviceBinaryImage *CurrentBinaryImagesStart = BinaryImages.get ();
286+   for  (size_t  I = 0 ; I < getNumAbstractModules (); ++I) {
287+     SYCLBIN::AbstractModule &AM = ParsedSYCLBIN.AbstractModules [I];
288+     AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I];
289+ 
290+     //  Set up the abstract module descriptor.
291+     AMDesc.NumJITBinaries  = AM.IRModules .size ();
292+     AMDesc.NumNativeBinaries  = AM.NativeDeviceCodeImages .size ();
293+     AMDesc.JITBinaries  = CurrentBinaryImagesStart;
294+     AMDesc.NativeBinaries  = CurrentBinaryImagesStart + AMDesc.NumJITBinaries ;
295+     CurrentBinaryImagesStart +=
296+         AMDesc.NumJITBinaries  + AM.NativeDeviceCodeImages .size ();
283297
284-   for  (SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules ) {
285298    //  Construct properties from SYCLBIN metadata.
286299    std::vector<_sycl_device_binary_property_set_struct> &BinPropertySets =
287300        convertAbstractModuleProperties (AM);
288301
289-     for  (SYCLBIN::IRModule &IRM : AM.IRModules ) {
302+     for  (size_t  J = 0 ; J < AM.IRModules .size (); ++J) {
303+       SYCLBIN::IRModule &IRM = AM.IRModules [J];
304+ 
290305      sycl_device_binary_struct &DeviceBinary = DeviceBinaries.emplace_back ();
291306      DeviceBinary.Version  = SYCL_DEVICE_BINARY_VERSION;
292307      DeviceBinary.Kind  = 4 ;
@@ -309,11 +324,12 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
309324      DeviceBinary.PropertySetsEnd  =
310325          BinPropertySets.data () + BinPropertySets.size ();
311326      //  Create an image from it.
312-       JITDeviceBinaryImages. emplace_back ( &DeviceBinary) ;
327+       AMDesc. JITBinaries [J] = RTDeviceBinaryImage{ &DeviceBinary} ;
313328    }
314329
315-     for  (const  SYCLBIN::NativeDeviceCodeImage &NDCI :
316-          AM.NativeDeviceCodeImages ) {
330+     for  (size_t  J = 0 ; J < AM.NativeDeviceCodeImages .size (); ++J) {
331+       const  SYCLBIN::NativeDeviceCodeImage &NDCI = AM.NativeDeviceCodeImages [J];
332+ 
317333      assert (NDCI.Metadata  != nullptr );
318334      PropertySet &NDCIMetadataProps = (*NDCI.Metadata )
319335          [PropertySetRegistry::SYCLBIN_NATIVE_DEVICE_CODE_IMAGE_METADATA];
@@ -346,7 +362,7 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
346362      DeviceBinary.PropertySetsEnd  =
347363          BinPropertySets.data () + BinPropertySets.size ();
348364      //  Create an image from it.
349-       NativeDeviceBinaryImages. emplace_back ( &DeviceBinary) ;
365+       AMDesc. NativeBinaries [J] = RTDeviceBinaryImage{ &DeviceBinary} ;
350366    }
351367  }
352368}
@@ -394,33 +410,44 @@ SYCLBINBinaries::convertAbstractModuleProperties(SYCLBIN::AbstractModule &AM) {
394410}
395411
// NOTE(review): this span is a diff hunk, not plain source. Lines marked `-`
// are the removed implementation, lines marked `+` are its replacement, and
// the leading digits are the old/new line numbers of the diff view.
//
// Removed version: takes only the device. It collects every image in
// NativeDeviceBinaryImages that passes both doesDevSupportDeviceRequirements
// and doesImageTargetMatchDevice and returns that list when it is non-empty;
// otherwise it returns the images in JITDeviceBinaryImages that pass the same
// two checks.
//
// Added version: takes an additional bundle_state argument and visits the
// abstract module descriptors one at a time, appending at most one image per
// abstract module to the returned vector.
396412std::vector<const  RTDeviceBinaryImage *>
397- SYCLBINBinaries::getBestCompatibleImages (device_impl &Dev) {
398-   auto  SelectCompatibleImages =
399-       [&](const  std::vector<RTDeviceBinaryImage> &Imgs) {
400-         std::vector<const  RTDeviceBinaryImage *> CompatImgs;
401-         for  (const  RTDeviceBinaryImage &Img : Imgs)
402-           if  (doesDevSupportDeviceRequirements (Dev, Img) &&
403-               doesImageTargetMatchDevice (Img, Dev))
404-             CompatImgs.push_back (&Img);
405-         return  CompatImgs;
406-       };
407- 
408-   //  Try with native images first.
409-   std::vector<const  RTDeviceBinaryImage *> NativeImgs =
410-       SelectCompatibleImages (NativeDeviceBinaryImages);
411-   if  (!NativeImgs.empty ())
412-     return  NativeImgs;
413- 
414-   //  If there were no native images, pick JIT images.
415-   return  SelectCompatibleImages (JITDeviceBinaryImages);
// Replacement definition begins on the next line.
413+ SYCLBINBinaries::getBestCompatibleImages (device_impl &Dev, bundle_state State) {
// GetCompatibleImage: returns a pointer to the first image in
// [Imgs, Imgs + NumImgs) that passes both device checks, or nullptr when
// std::find_if reaches the end of the range without a match.
414+   auto  GetCompatibleImage = [&](const  RTDeviceBinaryImage *Imgs,
415+                                 size_t  NumImgs) {
416+     const  RTDeviceBinaryImage *CompatImagePtr =
417+         std::find_if (Imgs, Imgs + NumImgs, [&](const  RTDeviceBinaryImage &Img) {
418+           return  doesDevSupportDeviceRequirements (Dev, Img) &&
419+                  doesImageTargetMatchDevice (Img, Dev);
420+         });
421+     return  (CompatImagePtr != Imgs + NumImgs) ? CompatImagePtr : nullptr ;
422+   };
423+ 
// One pass over the abstract modules. A module for which neither lookup below
// finds an image adds nothing to Images.
424+   std::vector<const  RTDeviceBinaryImage *> Images;
425+   for  (size_t  I = 0 ; I < getNumAbstractModules (); ++I) {
426+     const  AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I];
427+     //  If the target state is executable, try with native images first.
// For any State other than bundle_state::executable the native images of the
// module are not consulted and control goes straight to the JIT lookup.
428+     if  (State == bundle_state::executable) {
429+       if  (const  RTDeviceBinaryImage *CompatImagePtr = GetCompatibleImage (
430+               AMDesc.NativeBinaries , AMDesc.NumNativeBinaries )) {
431+         Images.push_back (CompatImagePtr);
// A native match was recorded for this module, so its JIT images are skipped.
432+         continue ;
433+       }
434+     }
435+ 
436+     //  Otherwise, select the first compatible JIT binary.
437+     if  (const  RTDeviceBinaryImage *CompatImagePtr =
438+             GetCompatibleImage (AMDesc.JITBinaries , AMDesc.NumJITBinaries ))
439+       Images.push_back (CompatImagePtr);
440+   }
441+   return  Images;
416442}
417443
418444std::vector<const  RTDeviceBinaryImage *>
419- SYCLBINBinaries::getBestCompatibleImages (devices_range Devs) {
445+ SYCLBINBinaries::getBestCompatibleImages (devices_range Devs,
446+                                          bundle_state State) {
420447  std::set<const  RTDeviceBinaryImage *> Images;
421448  for  (device_impl &Dev : Devs) {
422449    std::vector<const  RTDeviceBinaryImage *> BestImagesForDev =
423-         getBestCompatibleImages (Dev);
450+         getBestCompatibleImages (Dev, State );
424451    Images.insert (BestImagesForDev.cbegin (), BestImagesForDev.cend ());
425452  }
426453  return  {Images.cbegin (), Images.cend ()};
0 commit comments