Skip to content
1 change: 1 addition & 0 deletions sycl/source/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class RTDeviceBinaryImage {
ConstIterator begin() const { return ConstIterator(Begin); }
ConstIterator end() const { return ConstIterator(End); }
size_t size() const { return std::distance(begin(), end()); }
bool empty() const { return begin() == end(); }
friend class RTDeviceBinaryImage;
bool isAvailable() const { return !(Begin == nullptr); }

Expand Down
141 changes: 87 additions & 54 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class kernel_bundle_impl {

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, State);
fillUniqueDeviceImages();
}

// Interop constructor used by make_kernel
Expand All @@ -103,7 +104,8 @@ class kernel_bundle_impl {
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
device_image_plain &DevImage)
: kernel_bundle_impl(Ctx, Devs) {
MDeviceImages.push_back(DevImage);
MDeviceImages.emplace_back(DevImage);
MUniqueDeviceImages.emplace_back(DevImage);
}

// Matches sycl::build and sycl::compile
Expand All @@ -115,10 +117,12 @@ class kernel_bundle_impl {
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
MState(TargetState) {

MSpecConstValues = getSyclObjImpl(InputBundle)->get_spec_const_map_ref();
const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
getSyclObjImpl(InputBundle);
MSpecConstValues = InputBundleImpl->get_spec_const_map_ref();

const std::vector<device> &InputBundleDevices =
getSyclObjImpl(InputBundle)->get_devices();
InputBundleImpl->get_devices();
const bool AllDevsAssociatedWithInputBundle =
std::all_of(MDevices.begin(), MDevices.end(),
[&InputBundleDevices](const device &Dev) {
Expand All @@ -132,24 +136,37 @@ class kernel_bundle_impl {
"Not all devices are in the set of associated "
"devices for input bundle or vector of devices is empty");

for (const device_image_plain &DeviceImage : InputBundle) {
for (const DevImgPlainWithDeps &DevImgWithDeps :
InputBundleImpl->MDeviceImages) {
// Skip images which are not compatible with devices provided
if (std::none_of(
MDevices.begin(), MDevices.end(),
[&DeviceImage](const device &Dev) {
return getSyclObjImpl(DeviceImage)->compatible_with_device(Dev);
}))
if (std::none_of(MDevices.begin(), MDevices.end(),
[&DevImgWithDeps](const device &Dev) {
return getSyclObjImpl(DevImgWithDeps.getMain())
->compatible_with_device(Dev);
}))
continue;

switch (TargetState) {
case bundle_state::object:
MDeviceImages.push_back(detail::ProgramManager::getInstance().compile(
DeviceImage, MDevices, PropList));
case bundle_state::object: {
DevImgPlainWithDeps CompiledImgWithDeps =
detail::ProgramManager::getInstance().compile(DevImgWithDeps,
MDevices, PropList);

MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
CompiledImgWithDeps.begin(),
CompiledImgWithDeps.end());
MDeviceImages.push_back(std::move(CompiledImgWithDeps));
break;
case bundle_state::executable:
MDeviceImages.push_back(detail::ProgramManager::getInstance().build(
DeviceImage, MDevices, PropList));
}

case bundle_state::executable: {
device_image_plain BuiltImg =
detail::ProgramManager::getInstance().build(DevImgWithDeps,
MDevices, PropList);
MDeviceImages.emplace_back(BuiltImg);
MUniqueDeviceImages.push_back(BuiltImg);
break;
}
case bundle_state::input:
case bundle_state::ext_oneapi_source:
throw exception(make_error_code(errc::runtime),
Expand All @@ -158,6 +175,7 @@ class kernel_bundle_impl {
break;
}
}
removeDuplicateImages();
}

// Matches sycl::link
Expand Down Expand Up @@ -201,7 +219,7 @@ class kernel_bundle_impl {
"Not all devices are in the set of associated "
"devices for input bundles");

// TODO: Unify with c'tor for sycl::comile and sycl::build by calling
// TODO: Unify with c'tor for sycl::compile and sycl::build by calling
// sycl::join on vector of kernel_bundles

// The loop below just links each device image separately, not linking any
Expand All @@ -213,23 +231,27 @@ class kernel_bundle_impl {
// undefined symbols, then the logic in this loop will need to be changed.
for (const kernel_bundle<bundle_state::object> &ObjectBundle :
ObjectBundles) {
for (const device_image_plain &DeviceImage : ObjectBundle) {
for (const DevImgPlainWithDeps &DeviceImageWithDeps :
getSyclObjImpl(ObjectBundle)->MDeviceImages) {

// Skip images which are not compatible with devices provided
if (std::none_of(MDevices.begin(), MDevices.end(),
[&DeviceImage](const device &Dev) {
return getSyclObjImpl(DeviceImage)
[&DeviceImageWithDeps](const device &Dev) {
return getSyclObjImpl(DeviceImageWithDeps.getMain())
->compatible_with_device(Dev);
}))
continue;

std::vector<device_image_plain> LinkedResults =
detail::ProgramManager::getInstance().link(DeviceImage, MDevices,
PropList);
detail::ProgramManager::getInstance().link(DeviceImageWithDeps,
MDevices, PropList);
MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(),
LinkedResults.end());
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
LinkedResults.begin(), LinkedResults.end());
}
}
removeDuplicateImages();

for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
const KernelBundleImplPtr BundlePtr = getSyclObjImpl(Bundle);
Expand All @@ -249,6 +271,7 @@ class kernel_bundle_impl {

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, KernelIDs, State);
fillUniqueDeviceImages();
}

kernel_bundle_impl(context Ctx, std::vector<device> Devs,
Expand All @@ -259,6 +282,7 @@ class kernel_bundle_impl {

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, Selector, State);
fillUniqueDeviceImages();
}

// C'tor matches sycl::join API
Expand Down Expand Up @@ -287,11 +311,10 @@ class kernel_bundle_impl {
Bundle->MDeviceImages.end());
}

std::sort(MDeviceImages.begin(), MDeviceImages.end(),
LessByHash<device_image_plain>{});
fillUniqueDeviceImages();

if (get_bundle_state() == bundle_state::input) {
// Copy spec constants values from the device images to be removed.
// Copy spec constants values from the device images.
auto MergeSpecConstants = [this](const device_image_plain &Img) {
const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl(Img);
const std::map<std::string,
Expand All @@ -310,16 +333,9 @@ class kernel_bundle_impl {
SpecConst.second.back().Size);
}
};
std::for_each(MDeviceImages.begin(), MDeviceImages.end(),
MergeSpecConstants);
std::for_each(begin(), end(), MergeSpecConstants);
}

const auto DevImgIt =
std::unique(MDeviceImages.begin(), MDeviceImages.end());

// Remove duplicate device images.
MDeviceImages.erase(DevImgIt, MDeviceImages.end());

for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
for (const std::pair<const std::string, std::vector<unsigned char>>
&SpecConst : Bundle->MSpecConstValues) {
Expand Down Expand Up @@ -605,7 +621,7 @@ class kernel_bundle_impl {

assert(MDeviceImages.size() > 0);
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
detail::getSyclObjImpl(MDeviceImages[0]);
detail::getSyclObjImpl(MDeviceImages[0].getMain());
ur_program_handle_t UrProgram = DeviceImageImpl->get_ur_program_ref();
ContextImplPtr ContextImpl = getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();
Expand Down Expand Up @@ -634,7 +650,7 @@ class kernel_bundle_impl {
// Collect kernel ids from all device images, then remove duplicates

std::vector<kernel_id> Result;
for (const device_image_plain &DeviceImage : MDeviceImages) {
for (const device_image_plain &DeviceImage : MUniqueDeviceImages) {
const std::vector<kernel_id> &KernelIDs =
getSyclObjImpl(DeviceImage)->get_kernel_ids();

Expand Down Expand Up @@ -662,8 +678,9 @@ class kernel_bundle_impl {
// Used to track if any of the candidate images has specialization values
// set.
bool SpecConstsSet = false;
for (auto &DeviceImage : MDeviceImages) {
if (!DeviceImage.has_kernel(KernelID))
for (const DevImgPlainWithDeps &DeviceImageWithDeps : MDeviceImages) {
const device_image_plain &DeviceImage = DeviceImageWithDeps.getMain();
if (!DeviceImageWithDeps.getMain().has_kernel(KernelID))
continue;

const auto DeviceImageImpl = detail::getSyclObjImpl(DeviceImage);
Expand Down Expand Up @@ -718,39 +735,38 @@ class kernel_bundle_impl {
}

bool has_kernel(const kernel_id &KernelID) const noexcept {
return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
return std::any_of(begin(), end(),
[&KernelID](const device_image_plain &DeviceImage) {
return DeviceImage.has_kernel(KernelID);
});
}

bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept {
return std::any_of(
MDeviceImages.begin(), MDeviceImages.end(),
begin(), end(),
[&KernelID, &Dev](const device_image_plain &DeviceImage) {
return DeviceImage.has_kernel(KernelID, Dev);
});
}

bool contains_specialization_constants() const noexcept {
return std::any_of(
MDeviceImages.begin(), MDeviceImages.end(),
[](const device_image_plain &DeviceImage) {
begin(), end(), [](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)->has_specialization_constants();
});
}

bool native_specialization_constant() const noexcept {
return contains_specialization_constants() &&
std::all_of(MDeviceImages.begin(), MDeviceImages.end(),
std::all_of(begin(), end(),
[](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->all_specialization_constant_native();
});
}

bool has_specialization_constant(const char *SpecName) const noexcept {
return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
return std::any_of(begin(), end(),
[SpecName](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->has_specialization_constant(SpecName);
Expand All @@ -761,7 +777,7 @@ class kernel_bundle_impl {
const void *Value,
size_t Size) noexcept {
if (has_specialization_constant(SpecName))
for (const device_image_plain &DeviceImage : MDeviceImages)
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
getSyclObjImpl(DeviceImage)
->set_specialization_constant_raw_value(SpecName, Value);
else {
Expand All @@ -773,7 +789,7 @@ class kernel_bundle_impl {

void get_specialization_constant_raw_value(const char *SpecName,
void *ValueRet) const noexcept {
for (const device_image_plain &DeviceImage : MDeviceImages)
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
if (getSyclObjImpl(DeviceImage)->has_specialization_constant(SpecName)) {
getSyclObjImpl(DeviceImage)
->get_specialization_constant_raw_value(SpecName, ValueRet);
Expand All @@ -796,21 +812,21 @@ class kernel_bundle_impl {

bool is_specialization_constant_set(const char *SpecName) const noexcept {
bool SetInDevImg =
std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
std::any_of(begin(), end(),
[SpecName](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->is_specialization_constant_set(SpecName);
});
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
}

const device_image_plain *begin() const { return MDeviceImages.data(); }
const device_image_plain *begin() const { return MUniqueDeviceImages.data(); }

const device_image_plain *end() const {
return MDeviceImages.data() + MDeviceImages.size();
return MUniqueDeviceImages.data() + MUniqueDeviceImages.size();
}

size_t size() const noexcept { return MDeviceImages.size(); }
size_t size() const noexcept { return MUniqueDeviceImages.size(); }

bundle_state get_bundle_state() const { return MState; }

Expand All @@ -827,7 +843,7 @@ class kernel_bundle_impl {

// First try and get images in current bundle state
const bundle_state BundleState = get_bundle_state();
std::vector<device_image_plain> NewDevImgs =
std::vector<DevImgPlainWithDeps> NewDevImgs =
detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, {Dev}, {KernelID}, BundleState);

Expand All @@ -836,21 +852,38 @@ class kernel_bundle_impl {
return false;

// Propagate already set specialization constants to the new images
for (device_image_plain &DevImg : NewDevImgs)
for (auto SpecConst : MSpecConstValues)
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
SpecConst.first.c_str(), SpecConst.second.data());
for (DevImgPlainWithDeps &DevImgWithDeps : NewDevImgs)
for (device_image_plain &DevImg : DevImgWithDeps)
for (auto SpecConst : MSpecConstValues)
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
SpecConst.first.c_str(), SpecConst.second.data());

// Add the images to the collection
MDeviceImages.insert(MDeviceImages.end(), NewDevImgs.begin(),
NewDevImgs.end());
removeDuplicateImages();
return true;
}

private:
void fillUniqueDeviceImages() {
assert(MUniqueDeviceImages.empty());
for (const DevImgPlainWithDeps &Imgs : MDeviceImages)
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(), Imgs.begin(),
Imgs.end());
removeDuplicateImages();
}
void removeDuplicateImages() {
std::sort(MUniqueDeviceImages.begin(), MUniqueDeviceImages.end(),
LessByHash<device_image_plain>{});
const auto It =
std::unique(MUniqueDeviceImages.begin(), MUniqueDeviceImages.end());
MUniqueDeviceImages.erase(It, MUniqueDeviceImages.end());
}
context MContext;
std::vector<device> MDevices;
std::vector<device_image_plain> MDeviceImages;
std::vector<DevImgPlainWithDeps> MDeviceImages;
std::vector<device_image_plain> MUniqueDeviceImages;
// This map stores values for specialization constants, that are missing
// from any device image.
SpecConstMapT MSpecConstValues;
Expand Down
Loading
Loading