@@ -86,6 +86,7 @@ class kernel_bundle_impl {
8686
8787 MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
8888 MContext, MDevices, State);
89+ fillUniqueDeviceImages ();
8990 }
9091
9192 // Interop constructor used by make_kernel
@@ -103,7 +104,8 @@ class kernel_bundle_impl {
103104 kernel_bundle_impl (context Ctx, std::vector<device> Devs,
104105 device_image_plain &DevImage)
105106 : kernel_bundle_impl(Ctx, Devs) {
106- MDeviceImages.push_back (DevImage);
107+ MDeviceImages.emplace_back (DevImage);
108+ MUniqueDeviceImages.emplace_back (DevImage);
107109 }
108110
109111 // Matches sycl::build and sycl::compile
@@ -115,10 +117,12 @@ class kernel_bundle_impl {
115117 : MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
116118 MState (TargetState) {
117119
118- MSpecConstValues = getSyclObjImpl (InputBundle)->get_spec_const_map_ref ();
120+ const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
121+ getSyclObjImpl (InputBundle);
122+ MSpecConstValues = InputBundleImpl->get_spec_const_map_ref ();
119123
120124 const std::vector<device> &InputBundleDevices =
121- getSyclObjImpl (InputBundle) ->get_devices ();
125+ InputBundleImpl ->get_devices ();
122126 const bool AllDevsAssociatedWithInputBundle =
123127 std::all_of (MDevices.begin (), MDevices.end (),
124128 [&InputBundleDevices](const device &Dev) {
@@ -132,24 +136,37 @@ class kernel_bundle_impl {
132136 " Not all devices are in the set of associated "
133137 " devices for input bundle or vector of devices is empty" );
134138
135- for (const device_image_plain &DeviceImage : InputBundle) {
139+ for (const DevImgPlainWithDeps &DevImgWithDeps :
140+ InputBundleImpl->MDeviceImages ) {
136141 // Skip images which are not compatible with devices provided
137- if (std::none_of (
138- MDevices. begin (), MDevices. end (),
139- [&DeviceImage]( const device &Dev) {
140- return getSyclObjImpl (DeviceImage) ->compatible_with_device (Dev);
141- }))
142+ if (std::none_of (MDevices. begin (), MDevices. end (),
143+ [&DevImgWithDeps]( const device &Dev) {
144+ return getSyclObjImpl (DevImgWithDeps. getMain ())
145+ ->compatible_with_device (Dev);
146+ }))
142147 continue ;
143148
144149 switch (TargetState) {
145- case bundle_state::object:
146- MDeviceImages.push_back (detail::ProgramManager::getInstance ().compile (
147- DeviceImage, MDevices, PropList));
150+ case bundle_state::object: {
151+ DevImgPlainWithDeps CompiledImgWithDeps =
152+ detail::ProgramManager::getInstance ().compile (DevImgWithDeps,
153+ MDevices, PropList);
154+
155+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
156+ CompiledImgWithDeps.begin (),
157+ CompiledImgWithDeps.end ());
158+ MDeviceImages.push_back (std::move (CompiledImgWithDeps));
148159 break ;
149- case bundle_state::executable:
150- MDeviceImages.push_back (detail::ProgramManager::getInstance ().build (
151- DeviceImage, MDevices, PropList));
160+ }
161+
162+ case bundle_state::executable: {
163+ device_image_plain BuiltImg =
164+ detail::ProgramManager::getInstance ().build (DevImgWithDeps,
165+ MDevices, PropList);
166+ MDeviceImages.emplace_back (BuiltImg);
167+ MUniqueDeviceImages.push_back (BuiltImg);
152168 break ;
169+ }
153170 case bundle_state::input:
154171 case bundle_state::ext_oneapi_source:
155172 throw exception (make_error_code (errc::runtime),
@@ -158,6 +175,7 @@ class kernel_bundle_impl {
158175 break ;
159176 }
160177 }
178+ removeDuplicateImages ();
161179 }
162180
163181 // Matches sycl::link
@@ -201,7 +219,7 @@ class kernel_bundle_impl {
201219 " Not all devices are in the set of associated "
202220 " devices for input bundles" );
203221
204- // TODO: Unify with c'tor for sycl::comile and sycl::build by calling
222+ // TODO: Unify with c'tor for sycl::compile and sycl::build by calling
205223 // sycl::join on vector of kernel_bundles
206224
207225 // The loop below just links each device image separately, not linking any
@@ -213,23 +231,27 @@ class kernel_bundle_impl {
213231 // undefined symbols, then the logic in this loop will need to be changed.
214232 for (const kernel_bundle<bundle_state::object> &ObjectBundle :
215233 ObjectBundles) {
216- for (const device_image_plain &DeviceImage : ObjectBundle) {
234+ for (const DevImgPlainWithDeps &DeviceImageWithDeps :
235+ getSyclObjImpl (ObjectBundle)->MDeviceImages ) {
217236
218237 // Skip images which are not compatible with devices provided
219238 if (std::none_of (MDevices.begin (), MDevices.end (),
220- [&DeviceImage ](const device &Dev) {
221- return getSyclObjImpl (DeviceImage )
239+ [&DeviceImageWithDeps ](const device &Dev) {
240+ return getSyclObjImpl (DeviceImageWithDeps. getMain () )
222241 ->compatible_with_device (Dev);
223242 }))
224243 continue ;
225244
226245 std::vector<device_image_plain> LinkedResults =
227- detail::ProgramManager::getInstance ().link (DeviceImage, MDevices ,
228- PropList);
246+ detail::ProgramManager::getInstance ().link (DeviceImageWithDeps ,
247+ MDevices, PropList);
229248 MDeviceImages.insert (MDeviceImages.end (), LinkedResults.begin (),
230249 LinkedResults.end ());
250+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
251+ LinkedResults.begin (), LinkedResults.end ());
231252 }
232253 }
254+ removeDuplicateImages ();
233255
234256 for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
235257 const KernelBundleImplPtr BundlePtr = getSyclObjImpl (Bundle);
@@ -249,6 +271,7 @@ class kernel_bundle_impl {
249271
250272 MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
251273 MContext, MDevices, KernelIDs, State);
274+ fillUniqueDeviceImages ();
252275 }
253276
254277 kernel_bundle_impl (context Ctx, std::vector<device> Devs,
@@ -259,6 +282,7 @@ class kernel_bundle_impl {
259282
260283 MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
261284 MContext, MDevices, Selector, State);
285+ fillUniqueDeviceImages ();
262286 }
263287
264288 // C'tor matches sycl::join API
@@ -287,11 +311,10 @@ class kernel_bundle_impl {
287311 Bundle->MDeviceImages .end ());
288312 }
289313
290- std::sort (MDeviceImages.begin (), MDeviceImages.end (),
291- LessByHash<device_image_plain>{});
314+ fillUniqueDeviceImages ();
292315
293316 if (get_bundle_state () == bundle_state::input) {
294- // Copy spec constants values from the device images to be removed .
317+ // Copy spec constants values from the device images.
295318 auto MergeSpecConstants = [this ](const device_image_plain &Img) {
296319 const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl (Img);
297320 const std::map<std::string,
@@ -310,16 +333,9 @@ class kernel_bundle_impl {
310333 SpecConst.second .back ().Size );
311334 }
312335 };
313- std::for_each (MDeviceImages.begin (), MDeviceImages.end (),
314- MergeSpecConstants);
336+ std::for_each (begin (), end (), MergeSpecConstants);
315337 }
316338
317- const auto DevImgIt =
318- std::unique (MDeviceImages.begin (), MDeviceImages.end ());
319-
320- // Remove duplicate device images.
321- MDeviceImages.erase (DevImgIt, MDeviceImages.end ());
322-
323339 for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
324340 for (const std::pair<const std::string, std::vector<unsigned char >>
325341 &SpecConst : Bundle->MSpecConstValues ) {
@@ -605,7 +621,7 @@ class kernel_bundle_impl {
605621
606622 assert (MDeviceImages.size () > 0 );
607623 const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
608- detail::getSyclObjImpl (MDeviceImages[0 ]);
624+ detail::getSyclObjImpl (MDeviceImages[0 ]. getMain () );
609625 ur_program_handle_t UrProgram = DeviceImageImpl->get_ur_program_ref ();
610626 ContextImplPtr ContextImpl = getSyclObjImpl (MContext);
611627 const AdapterPtr &Adapter = ContextImpl->getAdapter ();
@@ -634,7 +650,7 @@ class kernel_bundle_impl {
634650 // Collect kernel ids from all device images, then remove duplicates
635651
636652 std::vector<kernel_id> Result;
637- for (const device_image_plain &DeviceImage : MDeviceImages ) {
653+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages ) {
638654 const std::vector<kernel_id> &KernelIDs =
639655 getSyclObjImpl (DeviceImage)->get_kernel_ids ();
640656
@@ -662,8 +678,9 @@ class kernel_bundle_impl {
662678 // Used to track if any of the candidate images has specialization values
663679 // set.
664680 bool SpecConstsSet = false ;
665- for (auto &DeviceImage : MDeviceImages) {
666- if (!DeviceImage.has_kernel (KernelID))
681+ for (const DevImgPlainWithDeps &DeviceImageWithDeps : MDeviceImages) {
682+ const device_image_plain &DeviceImage = DeviceImageWithDeps.getMain ();
683+ if (!DeviceImageWithDeps.getMain ().has_kernel (KernelID))
667684 continue ;
668685
669686 const auto DeviceImageImpl = detail::getSyclObjImpl (DeviceImage);
@@ -718,39 +735,38 @@ class kernel_bundle_impl {
718735 }
719736
720737 bool has_kernel (const kernel_id &KernelID) const noexcept {
721- return std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
738+ return std::any_of (begin (), end (),
722739 [&KernelID](const device_image_plain &DeviceImage) {
723740 return DeviceImage.has_kernel (KernelID);
724741 });
725742 }
726743
727744 bool has_kernel (const kernel_id &KernelID, const device &Dev) const noexcept {
728745 return std::any_of (
729- MDeviceImages. begin (), MDeviceImages. end (),
746+ begin (), end (),
730747 [&KernelID, &Dev](const device_image_plain &DeviceImage) {
731748 return DeviceImage.has_kernel (KernelID, Dev);
732749 });
733750 }
734751
735752 bool contains_specialization_constants () const noexcept {
736753 return std::any_of (
737- MDeviceImages.begin (), MDeviceImages.end (),
738- [](const device_image_plain &DeviceImage) {
754+ begin (), end (), [](const device_image_plain &DeviceImage) {
739755 return getSyclObjImpl (DeviceImage)->has_specialization_constants ();
740756 });
741757 }
742758
743759 bool native_specialization_constant () const noexcept {
744760 return contains_specialization_constants () &&
745- std::all_of (MDeviceImages. begin (), MDeviceImages. end (),
761+ std::all_of (begin (), end (),
746762 [](const device_image_plain &DeviceImage) {
747763 return getSyclObjImpl (DeviceImage)
748764 ->all_specialization_constant_native ();
749765 });
750766 }
751767
752768 bool has_specialization_constant (const char *SpecName) const noexcept {
753- return std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
769+ return std::any_of (begin (), end (),
754770 [SpecName](const device_image_plain &DeviceImage) {
755771 return getSyclObjImpl (DeviceImage)
756772 ->has_specialization_constant (SpecName);
@@ -761,7 +777,7 @@ class kernel_bundle_impl {
761777 const void *Value,
762778 size_t Size) noexcept {
763779 if (has_specialization_constant (SpecName))
764- for (const device_image_plain &DeviceImage : MDeviceImages )
780+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages )
765781 getSyclObjImpl (DeviceImage)
766782 ->set_specialization_constant_raw_value (SpecName, Value);
767783 else {
@@ -773,7 +789,7 @@ class kernel_bundle_impl {
773789
774790 void get_specialization_constant_raw_value (const char *SpecName,
775791 void *ValueRet) const noexcept {
776- for (const device_image_plain &DeviceImage : MDeviceImages )
792+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages )
777793 if (getSyclObjImpl (DeviceImage)->has_specialization_constant (SpecName)) {
778794 getSyclObjImpl (DeviceImage)
779795 ->get_specialization_constant_raw_value (SpecName, ValueRet);
@@ -796,21 +812,21 @@ class kernel_bundle_impl {
796812
797813 bool is_specialization_constant_set (const char *SpecName) const noexcept {
798814 bool SetInDevImg =
799- std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
815+ std::any_of (begin (), end (),
800816 [SpecName](const device_image_plain &DeviceImage) {
801817 return getSyclObjImpl (DeviceImage)
802818 ->is_specialization_constant_set (SpecName);
803819 });
804820 return SetInDevImg || MSpecConstValues.count (std::string{SpecName}) != 0 ;
805821 }
806822
807- const device_image_plain *begin () const { return MDeviceImages .data (); }
823+ const device_image_plain *begin () const { return MUniqueDeviceImages .data (); }
808824
809825 const device_image_plain *end () const {
810- return MDeviceImages .data () + MDeviceImages .size ();
826+ return MUniqueDeviceImages .data () + MUniqueDeviceImages .size ();
811827 }
812828
813- size_t size () const noexcept { return MDeviceImages .size (); }
829+ size_t size () const noexcept { return MUniqueDeviceImages .size (); }
814830
815831 bundle_state get_bundle_state () const { return MState; }
816832
@@ -827,7 +843,7 @@ class kernel_bundle_impl {
827843
828844 // First try and get images in current bundle state
829845 const bundle_state BundleState = get_bundle_state ();
830- std::vector<device_image_plain > NewDevImgs =
846+ std::vector<DevImgPlainWithDeps > NewDevImgs =
831847 detail::ProgramManager::getInstance ().getSYCLDeviceImages (
832848 MContext, {Dev}, {KernelID}, BundleState);
833849
@@ -836,21 +852,38 @@ class kernel_bundle_impl {
836852 return false ;
837853
838854 // Propagate already set specialization constants to the new images
839- for (device_image_plain &DevImg : NewDevImgs)
840- for (auto SpecConst : MSpecConstValues)
841- getSyclObjImpl (DevImg)->set_specialization_constant_raw_value (
842- SpecConst.first .c_str (), SpecConst.second .data ());
855+ for (DevImgPlainWithDeps &DevImgWithDeps : NewDevImgs)
856+ for (device_image_plain &DevImg : DevImgWithDeps)
857+ for (auto SpecConst : MSpecConstValues)
858+ getSyclObjImpl (DevImg)->set_specialization_constant_raw_value (
859+ SpecConst.first .c_str (), SpecConst.second .data ());
843860
844861 // Add the images to the collection
845862 MDeviceImages.insert (MDeviceImages.end (), NewDevImgs.begin (),
846863 NewDevImgs.end ());
864+ removeDuplicateImages ();
847865 return true ;
848866 }
849867
850868private:
869+ void fillUniqueDeviceImages () {
870+ assert (MUniqueDeviceImages.empty ());
871+ for (const DevImgPlainWithDeps &Imgs : MDeviceImages)
872+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (), Imgs.begin (),
873+ Imgs.end ());
874+ removeDuplicateImages ();
875+ }
876+ void removeDuplicateImages () {
877+ std::sort (MUniqueDeviceImages.begin (), MUniqueDeviceImages.end (),
878+ LessByHash<device_image_plain>{});
879+ const auto It =
880+ std::unique (MUniqueDeviceImages.begin (), MUniqueDeviceImages.end ());
881+ MUniqueDeviceImages.erase (It, MUniqueDeviceImages.end ());
882+ }
851883 context MContext;
852884 std::vector<device> MDevices;
853- std::vector<device_image_plain> MDeviceImages;
885+ std::vector<DevImgPlainWithDeps> MDeviceImages;
886+ std::vector<device_image_plain> MUniqueDeviceImages;
854887 // This map stores values for specialization constants, that are missing
855888 // from any device image.
856889 SpecConstMapT MSpecConstValues;
0 commit comments