Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 28 additions & 46 deletions source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ ur_result_t ur_context_handle_t_::finalize() {
}
{
std::scoped_lock<ur_mutex> Lock(ZeEventPoolCacheMutex);
for (auto &ZePoolCache : ZeEventPoolCaches) {
for (auto &ZePoolCache : ZeEventPoolCache) {
for (auto &ZePool : ZePoolCache) {
auto ZeResult = ZE_CALL_NOCHECK(zeEventPoolDestroy, (ZePool));
// Gracefully handle the case that L0 was already unloaded.
Expand Down Expand Up @@ -494,21 +494,21 @@ static const uint32_t MaxNumEventsPerPool = [] {
}();

ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
ze_event_pool_handle_t &Pool, size_t &Index, v2::event_flags_t Flags,
ur_device_handle_t Device) {
ze_event_pool_handle_t &Pool, size_t &Index, bool HostVisible,
bool ProfilingEnabled, ur_device_handle_t Device,
bool CounterBasedEventEnabled, bool UsingImmCmdList,
bool InterruptBasedEventEnabled) {
// Lock while updating event pool machinery.
std::scoped_lock<ur_mutex> Lock(ZeEventPoolCacheMutex);

ze_device_handle_t ZeDevice = nullptr;
size_t DeviceId;

if (Device) {
ZeDevice = Device->ZeDevice;
DeviceId =
Device->Id.has_value() ? static_cast<size_t>(Device->Id.value()) : 0;
}
std::list<ze_event_pool_handle_t> *ZePoolCache =
getZeEventPoolCache(Flags, ZeDevice, DeviceId);
std::list<ze_event_pool_handle_t> *ZePoolCache = getZeEventPoolCache(
HostVisible, ProfilingEnabled, CounterBasedEventEnabled, UsingImmCmdList,
InterruptBasedEventEnabled, ZeDevice);

if (!ZePoolCache->empty()) {
if (NumEventsAvailableInEventPool[ZePoolCache->front()] == 0) {
Expand Down Expand Up @@ -546,26 +546,26 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
ZeEventPoolDesc.count = MaxNumEventsPerPool;
ZeEventPoolDesc.flags = 0;
ZeEventPoolDesc.pNext = nullptr;
if (Flags & v2::EVENT_FLAGS_HOST_VISIBLE)
if (HostVisible)
ZeEventPoolDesc.flags |= ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
if (Flags & v2::EVENT_FLAGS_PROFILING_ENABLED)
if (ProfilingEnabled)
ZeEventPoolDesc.flags |= ZE_EVENT_POOL_FLAG_KERNEL_TIMESTAMP;
logger::debug("ze_event_pool_desc_t flags set to: {}",
ZeEventPoolDesc.flags);
if (Flags & v2::EVENT_FLAGS_COUNTER) {
if (Flags & v2::EVENT_FLAGS_IMM_CMDLIST) {
if (CounterBasedEventEnabled) {
if (UsingImmCmdList) {
counterBasedExt.flags = ZE_EVENT_POOL_COUNTER_BASED_EXP_FLAG_IMMEDIATE;
} else {
counterBasedExt.flags =
ZE_EVENT_POOL_COUNTER_BASED_EXP_FLAG_NON_IMMEDIATE;
}
logger::debug("ze_event_pool_desc_t counter based flags set to: {}",
counterBasedExt.flags);
if (Flags & EVENT_FLAG_INTERRUPT) {
if (InterruptBasedEventEnabled) {
counterBasedExt.pNext = &eventSyncMode;
}
ZeEventPoolDesc.pNext = &counterBasedExt;
} else if (Flags & EVENT_FLAG_INTERRUPT) {
} else if (InterruptBasedEventEnabled) {
ZeEventPoolDesc.pNext = &eventSyncMode;
}

Expand All @@ -592,23 +592,18 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
return UR_RESULT_SUCCESS;
}

ur_event_handle_t
ur_context_handle_t_::getEventFromContextCache(v2::event_flags_t Flags,
ur_device_handle_t Device) {
ur_event_handle_t ur_context_handle_t_::getEventFromContextCache(
bool HostVisible, bool WithProfiling, ur_device_handle_t Device,
bool CounterBasedEventEnabled, bool InterruptBasedEventEnabled) {
std::scoped_lock<ur_mutex> Lock(EventCacheMutex);

auto Cache = getEventCache(Flags & v2::EVENT_FLAGS_HOST_VISIBLE,
Flags & v2::EVENT_FLAGS_PROFILING_ENABLED, Device,
Flags & v2::EVENT_FLAGS_COUNTER,
Flags & v2::EVENT_FLAGS_INTERRUPT);

auto Cache =
getEventCache(HostVisible, WithProfiling, Device,
CounterBasedEventEnabled, InterruptBasedEventEnabled);
if (Cache->empty()) {
logger::info("Cache empty (Host Visible: {}, Profiling: {}, Counter: {}, "
"Interrupt: {}, Device: {})",
(Flags & v2::EVENT_FLAGS_HOST_VISIBLE),
(Flags & v2::EVENT_FLAGS_PROFILING_ENABLED),
(Flags & v2::EVENT_FLAGS_COUNTER),
(Flags & v2::EVENT_FLAGS_INTERRUPT), Device);
HostVisible, WithProfiling, CounterBasedEventEnabled,
InterruptBasedEventEnabled, Device);
return nullptr;
}

Expand Down Expand Up @@ -637,7 +632,7 @@ void ur_context_handle_t_::addEventToContextCache(ur_event_handle_t Event) {
}

auto Cache = getEventCache(
Event->HostVisibleEvent, Event->isProfilingEnabled(), Device,
Event->isHostVisible(), Event->isProfilingEnabled(), Device,
Event->CounterBasedEventsEnabled, Event->InterruptBasedEventsEnabled);
logger::info("Inserting {} event (Host Visible: {}, Profiling: {}, Counter: "
"{}, Device: {}) into cache {}",
Expand All @@ -658,30 +653,17 @@ ur_context_handle_t_::decrementUnreleasedEventsInPool(ur_event_handle_t Event) {
}

ze_device_handle_t ZeDevice = nullptr;
size_t DeviceId;

bool UsingImmediateCommandlists =
!Event->UrQueue || Event->UrQueue->UsingImmCmdLists;

if (!Event->IsMultiDevice && Event->UrQueue) {
ZeDevice = Event->UrQueue->Device->ZeDevice;
DeviceId = Event->UrQueue->Device->Id.has_value()
? static_cast<size_t>(Event->UrQueue->Device->Id.value())
: 0;
}
v2::event_flags_t Flags = 0;
if (UsingImmediateCommandlists)
Flags |= v2::EVENT_FLAGS_IMM_CMDLIST;
if (Event->isHostVisible())
Flags |= v2::EVENT_FLAGS_HOST_VISIBLE;
if (Event->isProfilingEnabled())
Flags |= v2::EVENT_FLAGS_PROFILING_ENABLED;
if (Event->CounterBasedEventsEnabled)
Flags |= v2::EVENT_FLAGS_COUNTER;
if (Event->InterruptBasedEventsEnabled)
Flags |= v2::EVENT_FLAGS_INTERRUPT;
std::list<ze_event_pool_handle_t> *ZePoolCache =
getZeEventPoolCache(Flags, ZeDevice, DeviceId);

std::list<ze_event_pool_handle_t> *ZePoolCache = getZeEventPoolCache(
Event->isHostVisible(), Event->isProfilingEnabled(),
Event->CounterBasedEventsEnabled, UsingImmediateCommandlists,
Event->InterruptBasedEventsEnabled, ZeDevice);

// Put the empty pool to the cache of the pools.
if (NumEventsUnreleasedInEventPool[Event->ZeEventPool] == 0)
Expand Down
144 changes: 108 additions & 36 deletions source/adapters/level_zero/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <ze_api.h>
#include <zes_api.h>

#include "./v2/event_provider.hpp"
#include "common.hpp"
#include "queue.hpp"

Expand Down Expand Up @@ -169,8 +168,9 @@ struct ur_context_handle_t_ : _ur_object {
// head.
//
// Cache of event pools to which host-visible events are added.
using ZeEventPoolCache = std::list<ze_event_pool_handle_t>;
std::vector<ZeEventPoolCache> ZeEventPoolCaches;
std::vector<std::list<ze_event_pool_handle_t>> ZeEventPoolCache{30};
std::vector<std::unordered_map<ze_device_handle_t, size_t>>
ZeEventPoolCacheDeviceMap{30};

// This map will be used to determine if a pool is full or not
// by storing number of empty slots available in the pool.
Expand Down Expand Up @@ -213,54 +213,124 @@ struct ur_context_handle_t_ : _ur_object {
// slot for a host-visible event. The ProfilingEnabled tells if we need a
// slot for an event with profiling capabilities.
ur_result_t getFreeSlotInExistingOrNewPool(ze_event_pool_handle_t &, size_t &,
v2::event_flags_t Flags,
ur_device_handle_t Device);
bool HostVisible,
bool ProfilingEnabled,
ur_device_handle_t Device,
bool CounterBasedEventEnabled,
bool UsingImmCmdList,
bool InterruptBasedEventEnabled);

// Get ur_event_handle_t from cache.
ur_event_handle_t getEventFromContextCache(v2::event_flags_t Flags,
ur_device_handle_t Device);
ur_event_handle_t getEventFromContextCache(bool HostVisible,
bool WithProfiling,
ur_device_handle_t Device,
bool CounterBasedEventEnabled,
bool InterruptBasedEventEnabled);

// Add ur_event_handle_t to cache.
void addEventToContextCache(ur_event_handle_t);

// Categories of event pool caches, keyed by host visibility, by whether the
// events are counter-based and/or interrupt-based, and by whether a regular
// or an immediate command list is in use.
//
// The numeric value of each enumerator matters: getZeEventPoolCache uses
// `CacheType * 2` (profiling) and `CacheType * 2 + 1` (no profiling) as an
// index, so the enumerators must stay dense and in this order.
enum EventPoolCacheType {
  // Neither counter-based nor interrupt-based.
  HostVisibleCacheType = 0,
  HostInvisibleCacheType = 1,

  // Counter-based only.
  HostVisibleCounterBasedRegularCacheType = 2,
  HostInvisibleCounterBasedRegularCacheType = 3,
  HostVisibleCounterBasedImmediateCacheType = 4,
  HostInvisibleCounterBasedImmediateCacheType = 5,

  // Interrupt-based only.
  HostVisibleInterruptBasedRegularCacheType = 6,
  HostInvisibleInterruptBasedRegularCacheType = 7,
  HostVisibleInterruptBasedImmediateCacheType = 8,
  HostInvisibleInterruptBasedImmediateCacheType = 9,

  // Both interrupt-based and counter-based.
  HostVisibleInterruptAndCounterBasedRegularCacheType = 10,
  HostInvisibleInterruptAndCounterBasedRegularCacheType = 11,
  HostVisibleInterruptAndCounterBasedImmediateCacheType = 12,
  HostInvisibleInterruptAndCounterBasedImmediateCacheType = 13
};

std::list<ze_event_pool_handle_t> *
getZeEventPoolCache(v2::event_flags_t Flags, ze_device_handle_t ZeDevice,
size_t DeviceId) {
size_t index = 0;
index |= uint64_t(Flags);
getZeEventPoolCache(bool HostVisible, bool WithProfiling,
bool CounterBasedEventEnabled, bool UsingImmediateCmdList,
bool InterruptBasedEventEnabled,
ze_device_handle_t ZeDevice) {
EventPoolCacheType CacheType;

calculateCacheIndex(HostVisible, CounterBasedEventEnabled,
UsingImmediateCmdList, InterruptBasedEventEnabled,
CacheType);
if (ZeDevice) {
index |= v2::EVENT_FLAGS_DEVICE | (DeviceId << v2::MAX_EVENT_FLAG_BITS);
}

if (index >= ZeEventPoolCaches.size()) {
ZeEventPoolCaches.resize(index + 1);
auto ZeEventPoolCacheMap =
WithProfiling ? &ZeEventPoolCacheDeviceMap[CacheType * 2]
: &ZeEventPoolCacheDeviceMap[CacheType * 2 + 1];
if (ZeEventPoolCacheMap->find(ZeDevice) == ZeEventPoolCacheMap->end()) {
ZeEventPoolCache.emplace_back();
ZeEventPoolCacheMap->insert(
std::make_pair(ZeDevice, ZeEventPoolCache.size() - 1));
}
return &ZeEventPoolCache[(*ZeEventPoolCacheMap)[ZeDevice]];
} else {
return WithProfiling ? &ZeEventPoolCache[CacheType * 2]
: &ZeEventPoolCache[CacheType * 2 + 1];
}
return &ZeEventPoolCaches[index];
}

/*
std::list<ze_event_pool_handle_t> *
getZeEventPoolCache(v2::event_flags_t Flags, ze_device_handle_t ZeDevice) {
size_t index = 0;
index |= Flags;
bool WithProfiling = Flags & v2::EVENT_FLAGS_PROFILING_ENABLED;

if (ZeDevice) {
auto ZeEventPoolCacheMap =
WithProfiling ? &ZeEventPoolCachesDeviceMap[index * 2]
: &ZeEventPoolCachesDeviceMap[index * 2 + 1];
if (ZeEventPoolCacheMap->find(ZeDevice) == ZeEventPoolCacheMap->end()) {
ZeEventPoolCaches.emplace_back();
ZeEventPoolCacheMap->insert(
std::make_pair(ZeDevice, ZeEventPoolCaches.size() - 1));
ur_result_t calculateCacheIndex(bool HostVisible,
bool CounterBasedEventEnabled,
bool UsingImmediateCmdList,
bool InterruptBasedEventEnabled,
EventPoolCacheType &CacheType) {
if (InterruptBasedEventEnabled) {
if (CounterBasedEventEnabled) {
if (HostVisible) {
if (UsingImmediateCmdList) {
CacheType = HostVisibleInterruptAndCounterBasedImmediateCacheType;
} else {
CacheType = HostVisibleInterruptAndCounterBasedRegularCacheType;
}
} else {
if (UsingImmediateCmdList) {
CacheType = HostInvisibleInterruptAndCounterBasedImmediateCacheType;
} else {
CacheType = HostInvisibleInterruptAndCounterBasedRegularCacheType;
}
}
return &ZeEventPoolCaches[(*ZeEventPoolCacheMap)[ZeDevice]];
} else {
return WithProfiling ? &ZeEventPoolCaches[index * 2]
: &ZeEventPoolCaches[index * 2 + 1];
if (HostVisible) {
if (UsingImmediateCmdList) {
CacheType = HostVisibleInterruptBasedImmediateCacheType;
} else {
CacheType = HostVisibleInterruptBasedRegularCacheType;
}
} else {
if (UsingImmediateCmdList) {
CacheType = HostInvisibleInterruptBasedImmediateCacheType;
} else {
CacheType = HostInvisibleInterruptBasedRegularCacheType;
}
}
}
} else {
if (CounterBasedEventEnabled && HostVisible && !UsingImmediateCmdList) {
CacheType = HostVisibleCounterBasedRegularCacheType;
} else if (CounterBasedEventEnabled && !HostVisible &&
!UsingImmediateCmdList) {
CacheType = HostInvisibleCounterBasedRegularCacheType;
} else if (CounterBasedEventEnabled && HostVisible &&
UsingImmediateCmdList) {
CacheType = HostVisibleCounterBasedImmediateCacheType;
} else if (CounterBasedEventEnabled && !HostVisible &&
UsingImmediateCmdList) {
CacheType = HostInvisibleCounterBasedImmediateCacheType;
} else if (!CounterBasedEventEnabled && HostVisible) {
CacheType = HostVisibleCacheType;
} else {
CacheType = HostInvisibleCacheType;
}
}
*/

return UR_RESULT_SUCCESS;
}

// Decrement number of events living in the pool upon event destroy
// and return the pool to the cache if there are no unreleased events.
Expand Down Expand Up @@ -309,6 +379,7 @@ struct ur_context_handle_t_ : _ur_object {
MAX_EVENT_FLAG_BITS =
5, // this is used as an offset for embedding device id
};

// Mutex to control operations on event caches.
ur_mutex EventCacheMutex;

Expand Down Expand Up @@ -341,6 +412,7 @@ struct ur_context_handle_t_ : _ur_object {
if (index >= EventCaches.size()) {
EventCaches.resize(index + 1);
}

return &EventCaches[index];
}
};
Expand Down
24 changes: 7 additions & 17 deletions source/adapters/level_zero/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,27 +1341,16 @@ ur_result_t EventCreate(ur_context_handle_t Context, ur_queue_handle_t Queue,
bool ProfilingEnabled =
ForceDisableProfiling ? false : (!Queue || Queue->isProfilingEnabled());
bool UsingImmediateCommandlists = !Queue || Queue->UsingImmCmdLists;
v2::event_flags_t Flags = 0;
if (ProfilingEnabled)
Flags |= v2::EVENT_FLAGS_PROFILING_ENABLED;
if (UsingImmediateCommandlists)
Flags |= v2::EVENT_FLAGS_IMM_CMDLIST;
if (HostVisible)
Flags |= v2::EVENT_FLAGS_HOST_VISIBLE;
if (IsMultiDevice)
Flags |= v2::EVENT_FLAGS_MULTIDEVICE;
if (CounterBasedEventEnabled)
Flags |= v2::EVENT_FLAGS_COUNTER;
if (InterruptBasedEventEnabled)
Flags |= v2::EVENT_FLAGS_INTERRUPT;

ur_device_handle_t Device = nullptr;

if (!IsMultiDevice && Queue) {
Device = Queue->Device;
}

if (auto CachedEvent = Context->getEventFromContextCache(Flags, Device)) {
if (auto CachedEvent = Context->getEventFromContextCache(
HostVisible, ProfilingEnabled, Device, CounterBasedEventEnabled,
InterruptBasedEventEnabled)) {
*RetEvent = CachedEvent;
return UR_RESULT_SUCCESS;
}
Expand All @@ -1371,8 +1360,10 @@ ur_result_t EventCreate(ur_context_handle_t Context, ur_queue_handle_t Queue,

size_t Index = 0;

if (auto Res = Context->getFreeSlotInExistingOrNewPool(ZeEventPool, Index,
Flags, Device))
if (auto Res = Context->getFreeSlotInExistingOrNewPool(
ZeEventPool, Index, HostVisible, ProfilingEnabled, Device,
CounterBasedEventEnabled, UsingImmediateCommandlists,
InterruptBasedEventEnabled))
return Res;

ZeStruct<ze_event_desc_t> ZeEventDesc;
Expand Down Expand Up @@ -1409,7 +1400,6 @@ ur_result_t EventCreate(ur_context_handle_t Context, ur_queue_handle_t Queue,
if (HostVisible)
(*RetEvent)->HostVisibleEvent =
reinterpret_cast<ur_event_handle_t>(*RetEvent);
(*RetEvent)->Flags = Flags;

return UR_RESULT_SUCCESS;
}
Expand Down
Loading
Loading