Skip to content

Commit 4a3fe1e

Browse files
[SYCL] Delete symbol based info with the last image referencing it (#19659)
Prior to this patch, symbol-based info (e.g. kernel ID, kernel assert usage, or the images containing an exported symbol) was deleted whenever any image referencing it was removed. This is incorrect, since multiple images can contain the same symbol. While this is unlikely to cause problems today (those images usually all get removed by consecutive calls to `removeImages`), it will cause issues once kernel-name-based kernel caches start getting cleaned up in the same manner.
1 parent 975c772 commit 4a3fe1e

File tree

3 files changed

+181
-91
lines changed

3 files changed

+181
-91
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,9 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
20422042
}
20432043
m_KernelIDs2BinImage.insert(std::make_pair(It->second, Img.get()));
20442044
KernelIDs->push_back(It->second);
2045+
2046+
// Keep track of image to kernel name reference count for cleanup.
2047+
m_KernelNameRefCount[name]++;
20452048
}
20462049

20472050
cacheKernelUsesAssertInfo(*Img);
@@ -2115,6 +2118,18 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
21152118
addImage(&(DeviceBinary->DeviceBinaries[I]));
21162119
}
21172120

2121+
/// Erases the single entry of \p Map stored under \p Key whose mapped value
/// compares equal to \p Val. Other entries sharing the same key are kept.
///
/// \param Map multimap-like container providing equal_range() and erase().
/// \param Key key whose equal range is searched.
/// \param Val mapped value identifying the entry to remove.
/// \param AssertContains when true (the default), a missing entry is treated
///        as a logic error and trips an assertion in debug builds; when
///        false, a missing entry is silently ignored.
///
/// In either mode nothing is erased when no matching entry exists.
template <typename MultimapT, typename KeyT, typename ValT>
void removeFromMultimapByVal(MultimapT &Map, const KeyT &Key, const ValT &Val,
                             bool AssertContains = true) {
  auto [RangeBegin, RangeEnd] = Map.equal_range(Key);
  auto It = std::find_if(RangeBegin, RangeEnd,
                         [&](const auto &Pair) { return Pair.second == Val; });
  if (It == RangeEnd) {
    // RangeEnd is either end() or the first entry of a different key, so it
    // must never be passed to erase(). Asserts compile out under NDEBUG,
    // hence the explicit return even when the entry was required.
    assert(!AssertContains && "Expected entry is missing from the multimap");
    (void)AssertContains;
    return;
  }
  Map.erase(It);
}
2132+
21182133
void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
21192134
if (DeviceBinary->NumDeviceBinaries == 0)
21202135
return;
@@ -2140,44 +2155,68 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
21402155
// Unmap the unique kernel IDs for the offload entries
21412156
for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
21422157
EntriesIt = EntriesIt->Increment()) {
2143-
2158+
detail::KernelNameStrT Name = EntriesIt->GetName();
21442159
// Drop entry for service kernel
2145-
if (std::strstr(EntriesIt->GetName(), "__sycl_service_kernel__")) {
2146-
m_ServiceKernels.erase(EntriesIt->GetName());
2160+
if (Name.find("__sycl_service_kernel__") != std::string::npos) {
2161+
removeFromMultimapByVal(m_ServiceKernels, Name, Img);
21472162
continue;
21482163
}
21492164

21502165
// Exported device functions won't have a kernel ID
2151-
if (m_ExportedSymbolImages.find(EntriesIt->GetName()) !=
2166+
if (m_ExportedSymbolImages.find(std::string(Name)) !=
21522167
m_ExportedSymbolImages.end()) {
21532168
continue;
21542169
}
21552170

2156-
// remove everything associated with this KernelName
2157-
m_KernelUsesAssert.erase(EntriesIt->GetName());
2158-
m_KernelImplicitLocalArgPos.erase(EntriesIt->GetName());
2159-
2160-
if (auto It = m_KernelName2KernelIDs.find(EntriesIt->GetName());
2161-
It != m_KernelName2KernelIDs.end()) {
2162-
m_KernelIDs2BinImage.erase(It->second);
2163-
m_KernelName2KernelIDs.erase(It);
2171+
auto Name2IDIt = m_KernelName2KernelIDs.find(Name);
2172+
if (Name2IDIt != m_KernelName2KernelIDs.end())
2173+
removeFromMultimapByVal(m_KernelIDs2BinImage, Name2IDIt->second, Img);
2174+
2175+
auto RefCountIt = m_KernelNameRefCount.find(Name);
2176+
assert(RefCountIt != m_KernelNameRefCount.end());
2177+
int &RefCount = RefCountIt->second;
2178+
assert(RefCount > 0);
2179+
2180+
// Remove everything associated with this KernelName if this is the last
2181+
// image referencing it.
2182+
if (--RefCount == 0) {
2183+
// TODO aggregate all these maps into a single one since their entries
2184+
// share lifetime.
2185+
m_KernelUsesAssert.erase(Name);
2186+
m_KernelImplicitLocalArgPos.erase(Name);
2187+
m_KernelNameRefCount.erase(RefCountIt);
2188+
if (Name2IDIt != m_KernelName2KernelIDs.end())
2189+
m_KernelName2KernelIDs.erase(Name2IDIt);
21642190
}
21652191
}
21662192

21672193
// Drop reverse mapping
21682194
m_BinImg2KernelIDs.erase(Img);
21692195

2170-
// Unregister exported symbols (needs to happen after the ID unmap loop)
2196+
// Unregister exported symbol -> Img pair (needs to happen after the ID
2197+
// unmap loop)
21712198
for (const sycl_device_binary_property &ESProp :
21722199
Img->getExportedSymbols()) {
2173-
m_ExportedSymbolImages.erase(ESProp->Name);
2200+
removeFromMultimapByVal(m_ExportedSymbolImages, ESProp->Name, Img,
2201+
/*AssertContains*/ false);
21742202
}
21752203

21762204
for (const sycl_device_binary_property &VFProp :
21772205
Img->getVirtualFunctions()) {
21782206
std::string StrValue = DeviceBinaryProperty(VFProp).asCString();
2179-
for (const auto &SetName : detail::split_string(StrValue, ','))
2180-
m_VFSet2BinImage.erase(SetName);
2207+
// Unregister the image from all referenced virtual function sets.
2208+
for (const auto &VFSetName : detail::split_string(StrValue, ',')) {
2209+
auto It = m_VFSet2BinImage.find(VFSetName);
2210+
assert(It != m_VFSet2BinImage.end());
2211+
std::set<const RTDeviceBinaryImage *> &ImgSet = It->second;
2212+
auto ImgIt = ImgSet.find(Img);
2213+
assert(ImgIt != ImgSet.end());
2214+
ImgSet.erase(ImgIt);
2215+
// If no images referencing this virtual function set remain, drop
2216+
// it from the map.
2217+
if (ImgSet.empty())
2218+
m_VFSet2BinImage.erase(It);
2219+
}
21812220
}
21822221

21832222
m_DeviceGlobals.eraseEntries(Img);

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,12 @@ class ProgramManager {
459459
/// \ref Sync::getGlobalLock() while holding this mutex.
460460
std::mutex m_KernelIDsMutex;
461461

462+
/// Keeps track of binary image to kernel name reference count.
463+
/// Used for checking if the last image referencing the kernel name
464+
/// is removed in order to trigger cleanup of kernel name based information.
465+
/// Access must be guarded by the m_KernelIDsMutex mutex.
466+
std::unordered_map<KernelNameStrT, int> m_KernelNameRefCount;
467+
462468
/// Caches all found service kernels to expedite future checks. A SYCL service
463469
/// kernel is a kernel that has not been defined by the user but is instead
464470
/// generated by the SYCL runtime. Service kernel name types must be declared

sycl/unittests/program_manager/Cleanup.cpp

Lines changed: 120 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ class ProgramManagerExposed : public sycl::detail::ProgramManager {
6161
return NativePrograms;
6262
}
6363

64+
std::unordered_map<sycl::detail::KernelNameStrT, int> &
65+
getKernelNameRefCount() {
66+
return m_KernelNameRefCount;
67+
}
68+
6469
std::unordered_map<const sycl::detail::RTDeviceBinaryImage *,
6570
std::unordered_map<sycl::detail::KernelNameStrT,
6671
sycl::detail::KernelArgMask>> &
@@ -132,6 +137,16 @@ std::string generateRefName(const std::string &ImageId,
132137
return FeatureName + "_" + ImageId;
133138
}
134139

140+
std::vector<std::string>
141+
generateRefNames(const std::vector<std::string> &ImageIds,
142+
const std::string &FeatureName) {
143+
std::vector<std::string> RefNames;
144+
RefNames.reserve(ImageIds.size());
145+
for (const std::string &ImageId : ImageIds)
146+
RefNames.push_back(generateRefName(ImageId, FeatureName));
147+
return RefNames;
148+
}
149+
135150
sycl::ext::oneapi::experimental::device_global<int> DeviceGlobalA;
136151
sycl::ext::oneapi::experimental::device_global<int> DeviceGlobalB;
137152
sycl::ext::oneapi::experimental::device_global<int> DeviceGlobalC;
@@ -143,7 +158,8 @@ using PipeA = sycl::ext::intel::experimental::pipe<PipeIDA, int, 10>;
143158
using PipeB = sycl::ext::intel::experimental::pipe<PipeIDB, int, 10>;
144159
using PipeC = sycl::ext::intel::experimental::pipe<PipeIDC, int, 10>;
145160

146-
sycl::unittest::MockDeviceImage generateImage(const std::string &ImageId) {
161+
sycl::unittest::MockDeviceImage generateImage(const std::string &ImageId,
162+
bool AddHostPipes = true) {
147163
sycl::unittest::MockPropertySet PropSet;
148164

149165
std::initializer_list<std::string> KernelNames{
@@ -181,11 +197,11 @@ sycl::unittest::MockDeviceImage generateImage(const std::string &ImageId) {
181197
std::vector<sycl::unittest::MockProperty>{
182198
sycl::unittest::makeDeviceGlobalInfo(
183199
generateRefName(ImageId, "DeviceGlobal"), sizeof(int), 0)});
184-
185-
PropSet.insert(__SYCL_PROPERTY_SET_SYCL_HOST_PIPES,
186-
std::vector<sycl::unittest::MockProperty>{
187-
sycl::unittest::makeHostPipeInfo(
188-
generateRefName(ImageId, "HostPipe"), sizeof(int))});
200+
if (AddHostPipes)
201+
PropSet.insert(__SYCL_PROPERTY_SET_SYCL_HOST_PIPES,
202+
std::vector<sycl::unittest::MockProperty>{
203+
sycl::unittest::makeHostPipeInfo(
204+
generateRefName(ImageId, "HostPipe"), sizeof(int))});
189205
std::vector<unsigned char> Bin{0};
190206

191207
std::vector<sycl::unittest::MockOffloadEntry> Entries =
@@ -229,6 +245,11 @@ static std::array<sycl::unittest::MockDeviceImage, 2> ImagesToKeep = {
229245
static std::array<sycl::unittest::MockDeviceImage, 1> ImagesToRemove = {
230246
generateImage("C")};
231247

248+
static std::array<sycl::unittest::MockDeviceImage, 1> ImagesToKeepSameEntries =
249+
{generateImage("A", /*AddHostPipe*/ false)};
250+
static std::array<sycl::unittest::MockDeviceImage, 1>
251+
ImagesToRemoveSameEntries = {generateImage("A", /*AddHostPipe*/ false)};
252+
232253
static std::array<sycl::unittest::MockDeviceImage, 2> ImagesToKeepKernelOnly = {
233254
generateImageKernelOnly("A"), generateImageKernelOnly("B")};
234255
static std::array<sycl::unittest::MockDeviceImage, 1> ImagesToRemoveKernelOnly =
@@ -251,76 +272,75 @@ void convertAndAddImages(
251272
PM.addImages(&AllBinaries);
252273
}
253274

254-
void checkAllInvolvedContainers(ProgramManagerExposed &PM, size_t ExpectedCount,
255-
const std::string &Comment) {
256-
EXPECT_EQ(PM.getKernelID2BinImage().size(), ExpectedCount) << Comment;
257-
{
258-
EXPECT_EQ(PM.getKernelName2KernelID().size(), ExpectedCount) << Comment;
259-
EXPECT_TRUE(
260-
PM.getKernelName2KernelID().count(generateRefName("A", "Kernel")) > 0)
261-
<< Comment;
262-
EXPECT_TRUE(
263-
PM.getKernelName2KernelID().count(generateRefName("B", "Kernel")) > 0)
264-
<< Comment;
265-
}
266-
EXPECT_EQ(PM.getBinImage2KernelId().size(), ExpectedCount) << Comment;
267-
{
268-
EXPECT_EQ(PM.getServiceKernels().size(), ExpectedCount) << Comment;
269-
EXPECT_TRUE(PM.getServiceKernels().count(
270-
generateRefName("A", "__sycl_service_kernel__")) > 0)
271-
<< Comment;
272-
EXPECT_TRUE(PM.getServiceKernels().count(
273-
generateRefName("B", "__sycl_service_kernel__")) > 0)
274-
<< Comment;
275-
}
276-
{
277-
EXPECT_EQ(PM.getExportedSymbolImages().size(), ExpectedCount) << Comment;
278-
EXPECT_TRUE(PM.getExportedSymbolImages().count(
279-
generateRefName("A", "Exported")) > 0)
280-
<< Comment;
281-
EXPECT_TRUE(PM.getExportedSymbolImages().count(
282-
generateRefName("B", "Exported")) > 0)
283-
<< Comment;
284-
}
285-
EXPECT_EQ(PM.getDeviceImages().size(), ExpectedCount) << Comment;
286-
{
287-
EXPECT_EQ(PM.getVFSet2BinImage().size(), ExpectedCount) << Comment;
288-
EXPECT_TRUE(PM.getVFSet2BinImage().count(generateRefName("A", "VF")) > 0)
289-
<< Comment;
290-
EXPECT_TRUE(PM.getVFSet2BinImage().count(generateRefName("B", "VF")) > 0)
291-
<< Comment;
275+
/// Expects \p Container to hold exactly \p ExpectedCount elements and to
/// contain every key listed in \p ExpectedEntries. \p Comment is streamed
/// into each expectation so a failure identifies the container and the test
/// phase.
template <typename T>
void checkContainer(const T &Container, size_t ExpectedCount,
                    const std::vector<std::string> &ExpectedEntries,
                    const std::string &Comment) {
  EXPECT_EQ(Container.size(), ExpectedCount) << Comment;

  // Presence check only: count() > 0 also holds for multimap-like
  // containers that store several values under one key.
  for (const auto &Entry : ExpectedEntries)
    EXPECT_TRUE(Container.count(Entry) > 0) << Comment;
}
293284

294-
EXPECT_EQ(PM.getEliminatedKernelArgMask().size(), ExpectedCount) << Comment;
295-
{
296-
EXPECT_EQ(PM.getKernelUsesAssert().size(), ExpectedCount) << Comment;
297-
EXPECT_TRUE(PM.getKernelUsesAssert().count(generateRefName("A", "Kernel")) >
298-
0)
299-
<< Comment;
300-
EXPECT_TRUE(PM.getKernelUsesAssert().count(generateRefName("B", "Kernel")) >
301-
0)
302-
<< Comment;
303-
}
304-
EXPECT_EQ(PM.getKernelImplicitLocalArgPos().size(), ExpectedCount) << Comment;
305-
306-
{
307-
sycl::detail::DeviceGlobalMap &DeviceGlobalMap = PM.getDeviceGlobals();
308-
EXPECT_EQ(DeviceGlobalMap.size(), ExpectedCount) << Comment;
309-
EXPECT_TRUE(DeviceGlobalMap.count(generateRefName("A", "DeviceGlobal")) > 0)
310-
<< Comment;
311-
EXPECT_TRUE(DeviceGlobalMap.count(generateRefName("B", "DeviceGlobal")) > 0)
312-
<< Comment;
313-
EXPECT_EQ(DeviceGlobalMap.getPointerMap().size(), ExpectedCount) << Comment;
285+
/// Checks the size (and, where keyed by name, the expected keys) of the
/// ProgramManager containers affected by adding and removing device images.
///
/// \param PM program manager under test.
/// \param ExpectedImgCount expected size of the containers checked against
///        the number of registered images.
/// \param ExpectedEntryCount expected size of the containers checked against
///        the number of distinct entries (symbols).
/// \param ImgIds image ids from which the expected entry names are generated.
/// \param CommentPostfix text appended to each failure message.
/// \param MultipleImgsPerEntryTestCase when true, the device global and host
///        pipe checks are skipped.
void checkAllInvolvedContainers(ProgramManagerExposed &PM,
                                size_t ExpectedImgCount,
                                size_t ExpectedEntryCount,
                                const std::vector<std::string> &ImgIds,
                                const std::string &CommentPostfix,
                                bool MultipleImgsPerEntryTestCase = false) {
  // Kernel names key several containers; generate them once.
  const std::vector<std::string> KernelRefNames =
      generateRefNames(ImgIds, "Kernel");

  EXPECT_EQ(PM.getKernelID2BinImage().size(), ExpectedImgCount)
      << "KernelID2BinImg " + CommentPostfix;
  checkContainer(PM.getKernelName2KernelID(), ExpectedEntryCount,
                 KernelRefNames, "KernelName2KernelID " + CommentPostfix);
  EXPECT_EQ(PM.getBinImage2KernelId().size(), ExpectedImgCount)
      << CommentPostfix;
  checkContainer(PM.getServiceKernels(), ExpectedImgCount,
                 generateRefNames(ImgIds, "__sycl_service_kernel__"),
                 "Service kernels " + CommentPostfix);
  checkContainer(PM.getExportedSymbolImages(), ExpectedImgCount,
                 generateRefNames(ImgIds, "Exported"),
                 "Exported symbol images " + CommentPostfix);
  EXPECT_EQ(PM.getDeviceImages().size(), ExpectedImgCount)
      << "Device images " + CommentPostfix;

  checkContainer(PM.getVFSet2BinImage(), ExpectedEntryCount,
                 generateRefNames(ImgIds, "VF"),
                 "VFSet2BinImage " + CommentPostfix);
  checkContainer(PM.getKernelNameRefCount(), ExpectedEntryCount,
                 KernelRefNames,
                 "Kernel name reference count " + CommentPostfix);
  EXPECT_EQ(PM.getEliminatedKernelArgMask().size(), ExpectedImgCount)
      << "Eliminated kernel arg mask " + CommentPostfix;
  checkContainer(PM.getKernelUsesAssert(), ExpectedEntryCount, KernelRefNames,
                 "KernelUsesAssert " + CommentPostfix);
  EXPECT_EQ(PM.getKernelImplicitLocalArgPos().size(), ExpectedEntryCount)
      << "Kernel implicit local arg pos " + CommentPostfix;

  // Nothing below applies to the shared-entries test case.
  if (MultipleImgsPerEntryTestCase)
    return;

  // FIXME: skipped for the shared-entries case because it is expected to
  // fail there for now: device globals cleanup seems to purge all info for
  // symbols associated with the removed image.
  checkContainer(PM.getDeviceGlobals(), ExpectedEntryCount,
                 generateRefNames(ImgIds, "DeviceGlobal"),
                 "Device globals " + CommentPostfix);

  // Host pipes are assumed to be unique, so the test case with the same
  // entries in multiple images does not support them.
  checkContainer(PM.getHostPipes(), ExpectedEntryCount,
                 generateRefNames(ImgIds, "HostPipe"),
                 "Host pipes " + CommentPostfix);
  EXPECT_EQ(PM.getPtrToHostPipe().size(), ExpectedEntryCount)
      << "Pointer to host pipe " + CommentPostfix;
}
315337

316-
{
317-
EXPECT_EQ(PM.getHostPipes().size(), ExpectedCount) << Comment;
318-
EXPECT_TRUE(PM.getHostPipes().count(generateRefName("A", "HostPipe")) > 0)
319-
<< Comment;
320-
EXPECT_TRUE(PM.getHostPipes().count(generateRefName("B", "HostPipe")) > 0)
321-
<< Comment;
322-
}
323-
EXPECT_EQ(PM.getPtrToHostPipe().size(), ExpectedCount) << Comment;
338+
/// Convenience overload for the case where the expected image count and the
/// expected entry count coincide: \p ExpectedCount is used for both.
///
/// \param PM program manager under test.
/// \param ExpectedCount expected size used for both the image-count and the
///        entry-count checks.
/// \param ImgIds image ids from which the expected entry names are generated.
/// \param CommentPostfix text appended to each failure message.
/// \param MultipleImgsPerEntryTestCase forwarded unchanged to the
///        six-parameter overload, where `true` skips the device global and
///        host pipe checks.
void checkAllInvolvedContainers(ProgramManagerExposed &PM, size_t ExpectedCount,
                                const std::vector<std::string> &ImgIds,
                                const std::string &CommentPostfix,
                                bool MultipleImgsPerEntryTestCase = false) {
  checkAllInvolvedContainers(PM, ExpectedCount, ExpectedCount, ImgIds,
                             CommentPostfix, MultipleImgsPerEntryTestCase);
}
325345

326346
TEST(ImageRemoval, BaseContainers) {
@@ -348,12 +368,37 @@ TEST(ImageRemoval, BaseContainers) {
348368
generateRefName("C", "HostPipe").c_str());
349369

350370
checkAllInvolvedContainers(PM, ImagesToRemove.size() + ImagesToKeep.size(),
351-
"Check failed before removal");
371+
{"A", "B", "C"}, "check failed before removal");
372+
373+
PM.removeImages(&TestBinaries);
374+
375+
checkAllInvolvedContainers(PM, ImagesToKeep.size(), {"A", "B"},
376+
"check failed after removal");
377+
}
378+
379+
TEST(ImageRemoval, MultipleImagesPerEntry) {
380+
ProgramManagerExposed PM;
381+
382+
sycl_device_binary_struct NativeImages[ImagesToKeepSameEntries.size()];
383+
sycl_device_binaries_struct AllBinaries;
384+
convertAndAddImages(PM, ImagesToKeepSameEntries, NativeImages, AllBinaries);
385+
386+
sycl_device_binary_struct
387+
NativeImagesForRemoval[ImagesToRemoveSameEntries.size()];
388+
sycl_device_binaries_struct TestBinaries;
389+
convertAndAddImages(PM, ImagesToRemoveSameEntries, NativeImagesForRemoval,
390+
TestBinaries);
391+
392+
checkAllInvolvedContainers(
393+
PM, ImagesToRemoveSameEntries.size() + ImagesToKeepSameEntries.size(),
394+
/*ExpectedEntryCount*/ 1, {"A"}, "check failed before removal",
395+
/*MultipleImgsPerEntryTestCase*/ true);
352396

353397
PM.removeImages(&TestBinaries);
354398

355-
checkAllInvolvedContainers(PM, ImagesToKeep.size(),
356-
"Check failed after removal");
399+
checkAllInvolvedContainers(PM, ImagesToKeepSameEntries.size(), {"A"},
400+
"check failed after removal",
401+
/*MultipleImgsPerEntryTestCase*/ true);
357402
}
358403

359404
TEST(ImageRemoval, NativePrograms) {

0 commit comments

Comments
 (0)