@@ -86,6 +86,7 @@ class kernel_bundle_impl {
8686
8787 MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
8888 MContext, MDevices, State);
89+ fillUniqueDeviceImages ();
8990 }
9091
9192 // Interop constructor used by make_kernel
@@ -103,7 +104,8 @@ class kernel_bundle_impl {
103104 kernel_bundle_impl (context Ctx, std::vector<device> Devs,
104105 device_image_plain &DevImage)
105106 : kernel_bundle_impl(Ctx, Devs) {
106- MDeviceImages.push_back (DevImage);
107+ MDeviceImages.emplace_back (DevImage);
108+ MUniqueDeviceImages.emplace_back (DevImage);
107109 }
108110
109111 // Matches sycl::build and sycl::compile
@@ -115,10 +117,12 @@ class kernel_bundle_impl {
115117 : MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
116118 MState (TargetState) {
117119
118- MSpecConstValues = getSyclObjImpl (InputBundle)->get_spec_const_map_ref ();
120+ const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
121+ getSyclObjImpl (InputBundle);
122+ MSpecConstValues = InputBundleImpl->get_spec_const_map_ref ();
119123
120124 const std::vector<device> &InputBundleDevices =
121- getSyclObjImpl (InputBundle) ->get_devices ();
125+ InputBundleImpl ->get_devices ();
122126 const bool AllDevsAssociatedWithInputBundle =
123127 std::all_of (MDevices.begin (), MDevices.end (),
124128 [&InputBundleDevices](const device &Dev) {
@@ -132,24 +136,37 @@ class kernel_bundle_impl {
132136 " Not all devices are in the set of associated "
133137 " devices for input bundle or vector of devices is empty" );
134138
135- for (const device_image_plain &DeviceImage : InputBundle) {
139+ for (const DevImgPlainWithDeps &DevImgWithDeps :
140+ InputBundleImpl->MDeviceImages ) {
136141 // Skip images which are not compatible with devices provided
137- if (std::none_of (
138- MDevices. begin (), MDevices. end (),
139- [&DeviceImage]( const device &Dev) {
140- return getSyclObjImpl (DeviceImage) ->compatible_with_device (Dev);
141- }))
142+ if (std::none_of (MDevices. begin (), MDevices. end (),
143+ [&DevImgWithDeps]( const device &Dev) {
144+ return getSyclObjImpl (DevImgWithDeps. getMain ())
145+ ->compatible_with_device (Dev);
146+ }))
142147 continue ;
143148
144149 switch (TargetState) {
145- case bundle_state::object:
146- MDeviceImages.push_back (detail::ProgramManager::getInstance ().compile (
147- DeviceImage, MDevices, PropList));
150+ case bundle_state::object: {
151+ DevImgPlainWithDeps CompiledImgWithDeps =
152+ detail::ProgramManager::getInstance ().compile (DevImgWithDeps,
153+ MDevices, PropList);
154+
155+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
156+ CompiledImgWithDeps.begin (),
157+ CompiledImgWithDeps.end ());
158+ MDeviceImages.push_back (std::move (CompiledImgWithDeps));
148159 break ;
149- case bundle_state::executable:
150- MDeviceImages.push_back (detail::ProgramManager::getInstance ().build (
151- DeviceImage, MDevices, PropList));
160+ }
161+
162+ case bundle_state::executable: {
163+ device_image_plain BuiltImg =
164+ detail::ProgramManager::getInstance ().build (DevImgWithDeps,
165+ MDevices, PropList);
166+ MDeviceImages.emplace_back (BuiltImg);
167+ MUniqueDeviceImages.push_back (BuiltImg);
152168 break ;
169+ }
153170 case bundle_state::input:
154171 case bundle_state::ext_oneapi_source:
155172 throw exception (make_error_code (errc::runtime),
@@ -158,6 +175,7 @@ class kernel_bundle_impl {
158175 break ;
159176 }
160177 }
178+ removeDuplicateImages ();
161179 }
162180
163181 // Matches sycl::link
@@ -201,7 +219,7 @@ class kernel_bundle_impl {
201219 " Not all devices are in the set of associated "
202220 " devices for input bundles" );
203221
204- // TODO: Unify with c'tor for sycl::comile and sycl::build by calling
222+ // TODO: Unify with c'tor for sycl::compile and sycl::build by calling
205223 // sycl::join on vector of kernel_bundles
206224
207225 // The loop below just links each device image separately, not linking any
@@ -213,23 +231,27 @@ class kernel_bundle_impl {
213231 // undefined symbols, then the logic in this loop will need to be changed.
214232 for (const kernel_bundle<bundle_state::object> &ObjectBundle :
215233 ObjectBundles) {
216- for (const device_image_plain &DeviceImage : ObjectBundle) {
234+ for (const DevImgPlainWithDeps &DeviceImageWithDeps :
235+ getSyclObjImpl (ObjectBundle)->MDeviceImages ) {
217236
218237 // Skip images which are not compatible with devices provided
219238 if (std::none_of (MDevices.begin (), MDevices.end (),
220- [&DeviceImage ](const device &Dev) {
221- return getSyclObjImpl (DeviceImage )
239+ [&DeviceImageWithDeps ](const device &Dev) {
240+ return getSyclObjImpl (DeviceImageWithDeps. getMain () )
222241 ->compatible_with_device (Dev);
223242 }))
224243 continue ;
225244
226245 std::vector<device_image_plain> LinkedResults =
227- detail::ProgramManager::getInstance ().link (DeviceImage, MDevices ,
228- PropList);
246+ detail::ProgramManager::getInstance ().link (DeviceImageWithDeps ,
247+ MDevices, PropList);
229248 MDeviceImages.insert (MDeviceImages.end (), LinkedResults.begin (),
230249 LinkedResults.end ());
250+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
251+ LinkedResults.begin (), LinkedResults.end ());
231252 }
232253 }
254+ removeDuplicateImages ();
233255
234256 for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
235257 const KernelBundleImplPtr BundlePtr = getSyclObjImpl (Bundle);
@@ -249,6 +271,7 @@ class kernel_bundle_impl {
249271
250272 MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
251273 MContext, MDevices, KernelIDs, State);
274+ fillUniqueDeviceImages ();
252275 }
253276
254277 kernel_bundle_impl (context Ctx, std::vector<device> Devs,
@@ -259,6 +282,7 @@ class kernel_bundle_impl {
259282
260283 MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
261284 MContext, MDevices, Selector, State);
285+ fillUniqueDeviceImages ();
262286 }
263287
264288 // C'tor matches sycl::join API
@@ -287,11 +311,10 @@ class kernel_bundle_impl {
287311 Bundle->MDeviceImages .end ());
288312 }
289313
290- std::sort (MDeviceImages.begin (), MDeviceImages.end (),
291- LessByHash<device_image_plain>{});
314+ fillUniqueDeviceImages ();
292315
293316 if (get_bundle_state () == bundle_state::input) {
294- // Copy spec constants values from the device images to be removed .
317+ // Copy spec constants values from the device images.
295318 auto MergeSpecConstants = [this ](const device_image_plain &Img) {
296319 const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl (Img);
297320 const std::map<std::string,
@@ -310,16 +333,10 @@ class kernel_bundle_impl {
310333 SpecConst.second .back ().Size );
311334 }
312335 };
313- std::for_each (MDeviceImages .begin (), MDeviceImages .end (),
336+ std::for_each (MUniqueDeviceImages .begin (), MUniqueDeviceImages .end (),
314337 MergeSpecConstants);
315338 }
316339
317- const auto DevImgIt =
318- std::unique (MDeviceImages.begin (), MDeviceImages.end ());
319-
320- // Remove duplicate device images.
321- MDeviceImages.erase (DevImgIt, MDeviceImages.end ());
322-
323340 for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
324341 for (const std::pair<const std::string, std::vector<unsigned char >>
325342 &SpecConst : Bundle->MSpecConstValues ) {
@@ -605,7 +622,7 @@ class kernel_bundle_impl {
605622
606623 assert (MDeviceImages.size () > 0 );
607624 const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
608- detail::getSyclObjImpl (MDeviceImages[0 ]);
625+ detail::getSyclObjImpl (MDeviceImages[0 ]. getMain () );
609626 ur_program_handle_t UrProgram = DeviceImageImpl->get_ur_program_ref ();
610627 ContextImplPtr ContextImpl = getSyclObjImpl (MContext);
611628 const AdapterPtr &Adapter = ContextImpl->getAdapter ();
@@ -634,7 +651,7 @@ class kernel_bundle_impl {
634651 // Collect kernel ids from all device images, then remove duplicates
635652
636653 std::vector<kernel_id> Result;
637- for (const device_image_plain &DeviceImage : MDeviceImages ) {
654+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages ) {
638655 const std::vector<kernel_id> &KernelIDs =
639656 getSyclObjImpl (DeviceImage)->get_kernel_ids ();
640657
@@ -662,8 +679,9 @@ class kernel_bundle_impl {
662679 // Used to track if any of the candidate images has specialization values
663680 // set.
664681 bool SpecConstsSet = false ;
665- for (auto &DeviceImage : MDeviceImages) {
666- if (!DeviceImage.has_kernel (KernelID))
682+ for (const DevImgPlainWithDeps &DeviceImageWithDeps : MDeviceImages) {
683+ const device_image_plain &DeviceImage = DeviceImageWithDeps.getMain ();
684+ if (!DeviceImageWithDeps.getMain ().has_kernel (KernelID))
667685 continue ;
668686
669687 const auto DeviceImageImpl = detail::getSyclObjImpl (DeviceImage);
@@ -718,39 +736,39 @@ class kernel_bundle_impl {
718736 }
719737
720738 bool has_kernel (const kernel_id &KernelID) const noexcept {
721- return std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
739+ return std::any_of (begin (), end (),
722740 [&KernelID](const device_image_plain &DeviceImage) {
723741 return DeviceImage.has_kernel (KernelID);
724742 });
725743 }
726744
727745 bool has_kernel (const kernel_id &KernelID, const device &Dev) const noexcept {
728746 return std::any_of (
729- MDeviceImages. begin (), MDeviceImages. end (),
747+ begin (), end (),
730748 [&KernelID, &Dev](const device_image_plain &DeviceImage) {
731749 return DeviceImage.has_kernel (KernelID, Dev);
732750 });
733751 }
734752
735753 bool contains_specialization_constants () const noexcept {
736754 return std::any_of (
737- MDeviceImages .begin (), MDeviceImages .end (),
755+ MUniqueDeviceImages .begin (), MUniqueDeviceImages .end (),
738756 [](const device_image_plain &DeviceImage) {
739757 return getSyclObjImpl (DeviceImage)->has_specialization_constants ();
740758 });
741759 }
742760
743761 bool native_specialization_constant () const noexcept {
744762 return contains_specialization_constants () &&
745- std::all_of (MDeviceImages .begin (), MDeviceImages .end (),
763+ std::all_of (MUniqueDeviceImages .begin (), MUniqueDeviceImages .end (),
746764 [](const device_image_plain &DeviceImage) {
747765 return getSyclObjImpl (DeviceImage)
748766 ->all_specialization_constant_native ();
749767 });
750768 }
751769
752770 bool has_specialization_constant (const char *SpecName) const noexcept {
753- return std::any_of (MDeviceImages .begin (), MDeviceImages .end (),
771+ return std::any_of (MUniqueDeviceImages .begin (), MUniqueDeviceImages .end (),
754772 [SpecName](const device_image_plain &DeviceImage) {
755773 return getSyclObjImpl (DeviceImage)
756774 ->has_specialization_constant (SpecName);
@@ -761,7 +779,7 @@ class kernel_bundle_impl {
761779 const void *Value,
762780 size_t Size) noexcept {
763781 if (has_specialization_constant (SpecName))
764- for (const device_image_plain &DeviceImage : MDeviceImages )
782+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages )
765783 getSyclObjImpl (DeviceImage)
766784 ->set_specialization_constant_raw_value (SpecName, Value);
767785 else {
@@ -773,7 +791,7 @@ class kernel_bundle_impl {
773791
774792 void get_specialization_constant_raw_value (const char *SpecName,
775793 void *ValueRet) const noexcept {
776- for (const device_image_plain &DeviceImage : MDeviceImages )
794+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages )
777795 if (getSyclObjImpl (DeviceImage)->has_specialization_constant (SpecName)) {
778796 getSyclObjImpl (DeviceImage)
779797 ->get_specialization_constant_raw_value (SpecName, ValueRet);
@@ -796,21 +814,21 @@ class kernel_bundle_impl {
796814
797815 bool is_specialization_constant_set (const char *SpecName) const noexcept {
798816 bool SetInDevImg =
799- std::any_of (MDeviceImages .begin (), MDeviceImages .end (),
817+ std::any_of (MUniqueDeviceImages .begin (), MUniqueDeviceImages .end (),
800818 [SpecName](const device_image_plain &DeviceImage) {
801819 return getSyclObjImpl (DeviceImage)
802820 ->is_specialization_constant_set (SpecName);
803821 });
804822 return SetInDevImg || MSpecConstValues.count (std::string{SpecName}) != 0 ;
805823 }
806824
807- const device_image_plain *begin () const { return MDeviceImages .data (); }
825+ const device_image_plain *begin () const { return MUniqueDeviceImages .data (); }
808826
809827 const device_image_plain *end () const {
810- return MDeviceImages .data () + MDeviceImages .size ();
828+ return MUniqueDeviceImages .data () + MUniqueDeviceImages .size ();
811829 }
812830
813- size_t size () const noexcept { return MDeviceImages .size (); }
831+ size_t size () const noexcept { return MUniqueDeviceImages .size (); }
814832
815833 bundle_state get_bundle_state () const { return MState; }
816834
@@ -827,7 +845,7 @@ class kernel_bundle_impl {
827845
828846 // First try and get images in current bundle state
829847 const bundle_state BundleState = get_bundle_state ();
830- std::vector<device_image_plain > NewDevImgs =
848+ std::vector<DevImgPlainWithDeps > NewDevImgs =
831849 detail::ProgramManager::getInstance ().getSYCLDeviceImages (
832850 MContext, {Dev}, {KernelID}, BundleState);
833851
@@ -836,21 +854,38 @@ class kernel_bundle_impl {
836854 return false ;
837855
838856 // Propagate already set specialization constants to the new images
839- for (device_image_plain &DevImg : NewDevImgs)
840- for (auto SpecConst : MSpecConstValues)
841- getSyclObjImpl (DevImg)->set_specialization_constant_raw_value (
842- SpecConst.first .c_str (), SpecConst.second .data ());
857+ for (DevImgPlainWithDeps &DevImgWithDeps : NewDevImgs)
858+ for (device_image_plain &DevImg : DevImgWithDeps)
859+ for (auto SpecConst : MSpecConstValues)
860+ getSyclObjImpl (DevImg)->set_specialization_constant_raw_value (
861+ SpecConst.first .c_str (), SpecConst.second .data ());
843862
844863 // Add the images to the collection
845864 MDeviceImages.insert (MDeviceImages.end (), NewDevImgs.begin (),
846865 NewDevImgs.end ());
866+ removeDuplicateImages ();
847867 return true ;
848868 }
849869
850870private:
871+ void fillUniqueDeviceImages () {
872+ assert (MUniqueDeviceImages.empty ());
873+ for (const DevImgPlainWithDeps &Imgs : MDeviceImages)
874+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (), Imgs.begin (),
875+ Imgs.end ());
876+ removeDuplicateImages ();
877+ }
878+ void removeDuplicateImages () {
879+ std::sort (MUniqueDeviceImages.begin (), MUniqueDeviceImages.end (),
880+ LessByHash<device_image_plain>{});
881+ const auto It =
882+ std::unique (MUniqueDeviceImages.begin (), MUniqueDeviceImages.end ());
883+ MUniqueDeviceImages.erase (It, MUniqueDeviceImages.end ());
884+ }
851885 context MContext;
852886 std::vector<device> MDevices;
853- std::vector<device_image_plain> MDeviceImages;
887+ std::vector<DevImgPlainWithDeps> MDeviceImages;
888+ std::vector<device_image_plain> MUniqueDeviceImages;
854889 // This map stores values for specialization constants, that are missing
855890 // from any device image.
856891 SpecConstMapT MSpecConstValues;
0 commit comments