Skip to content

Commit 55b417a

Browse files
authored
[Offload] Cache symbols in program (#148209)
When creating a new symbol, check that it already exists. If it does, return that pointer rather than building a new symbol structure.
1 parent bd0f9dd commit 55b417a

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,20 @@ struct ol_program_impl_t {
8484
DeviceImage(DeviceImage) {}
8585
plugin::DeviceImageTy *Image;
8686
std::unique_ptr<llvm::MemoryBuffer> ImageData;
87-
std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
87+
std::mutex SymbolListMutex;
8888
__tgt_device_image DeviceImage;
89+
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
90+
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
8991
};
9092

9193
struct ol_symbol_impl_t {
92-
ol_symbol_impl_t(GenericKernelTy *Kernel)
93-
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
94-
ol_symbol_impl_t(GlobalTy &&Global)
95-
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
94+
ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
95+
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
96+
ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
97+
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
9698
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
9799
ol_symbol_kind_t Kind;
100+
llvm::StringRef Name;
98101
};
99102

100103
namespace llvm {
@@ -714,32 +717,40 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
714717
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
715718
auto &Device = Program->Image->getDevice();
716719

720+
std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
721+
717722
switch (Kind) {
718723
case OL_SYMBOL_KIND_KERNEL: {
719-
auto KernelImpl = Device.constructKernel(Name);
720-
if (!KernelImpl)
721-
return KernelImpl.takeError();
724+
auto &Kernel = Program->KernelSymbols[Name];
725+
if (!Kernel) {
726+
auto KernelImpl = Device.constructKernel(Name);
727+
if (!KernelImpl)
728+
return KernelImpl.takeError();
722729

723-
if (auto Err = KernelImpl->init(Device, *Program->Image))
724-
return Err;
730+
if (auto Err = KernelImpl->init(Device, *Program->Image))
731+
return Err;
732+
733+
Kernel = std::make_unique<ol_symbol_impl_t>(KernelImpl->getName(),
734+
&*KernelImpl);
735+
}
725736

726-
*Symbol =
727-
Program->Symbols
728-
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
729-
.get();
737+
*Symbol = Kernel.get();
730738
return Error::success();
731739
}
732740
case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
733-
GlobalTy GlobalObj{Name};
734-
if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
735-
Device, *Program->Image, GlobalObj))
736-
return Res;
737-
738-
*Symbol = Program->Symbols
739-
.emplace_back(
740-
std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
741-
.get();
741+
auto &Global = Program->KernelSymbols[Name];
742+
if (!Global) {
743+
GlobalTy GlobalObj{Name};
744+
if (auto Res =
745+
Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
746+
Device, *Program->Image, GlobalObj))
747+
return Res;
748+
749+
Global = std::make_unique<ol_symbol_impl_t>(GlobalObj.getName().c_str(),
750+
std::move(GlobalObj));
751+
}
742752

753+
*Symbol = Global.get();
743754
return Error::success();
744755
}
745756
default:

offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ TEST_P(olGetSymbolKernelTest, Success) {
4141
ASSERT_NE(Kernel, nullptr);
4242
}
4343

44+
TEST_P(olGetSymbolKernelTest, SuccessSamePtr) {
45+
ol_symbol_handle_t KernelA = nullptr;
46+
ol_symbol_handle_t KernelB = nullptr;
47+
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelA));
48+
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelB));
49+
ASSERT_EQ(KernelA, KernelB);
50+
}
51+
4452
TEST_P(olGetSymbolKernelTest, InvalidNullProgram) {
4553
ol_symbol_handle_t Kernel = nullptr;
4654
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
@@ -72,6 +80,16 @@ TEST_P(olGetSymbolGlobalTest, Success) {
7280
ASSERT_NE(Global, nullptr);
7381
}
7482

83+
TEST_P(olGetSymbolGlobalTest, SuccessSamePtr) {
84+
ol_symbol_handle_t GlobalA = nullptr;
85+
ol_symbol_handle_t GlobalB = nullptr;
86+
ASSERT_SUCCESS(
87+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalA));
88+
ASSERT_SUCCESS(
89+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalB));
90+
ASSERT_EQ(GlobalA, GlobalB);
91+
}
92+
7593
TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) {
7694
ol_symbol_handle_t Global = nullptr;
7795
ASSERT_ERROR(

0 commit comments

Comments
 (0)