@@ -102,13 +102,12 @@ static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,
102102}
103103
104104// Returns the default L0 context of the default driver.
105- static ze_context_handle_t getDefaultContext ( ) {
105+ static ze_context_handle_t getContext ( ze_driver_handle_t driver ) {
106106 thread_local static ze_context_handle_t context;
107107 thread_local static bool isContextInitialised{false };
108108 if (isContextInitialised)
109109 return context;
110110 ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, nullptr , 0 };
111- auto driver = getDriver ();
112111 L0_SAFE_CALL (zeContextCreate (driver, &ctxtDesc, &context));
113112 isContextInitialised = true ;
114113 return context;
@@ -137,7 +136,7 @@ using UniqueZeContext =
137136using UniqueZeCommandList =
138137 std::unique_ptr<std::remove_pointer<ze_command_list_handle_t >::type,
139138 ZeCommandListDeleter>;
140- struct L0RtContext {
139+ struct L0RTContextWrapper {
141140 ze_driver_handle_t driver{nullptr };
142141 ze_device_handle_t device{nullptr };
143142 UniqueZeContext context;
@@ -149,12 +148,12 @@ struct L0RtContext {
149148 UniqueZeCommandList immCmdListCopy;
150149 uint32_t copyEngineMaxMemoryFillPatternSize{-1u };
151150
152- L0RtContext () = default ;
153- L0RtContext (const uint32_t driverIdx = 0 , const int32_t devIdx = 0 )
151+ L0RTContextWrapper () = default ;
152+ L0RTContextWrapper (const uint32_t driverIdx = 0 , const int32_t devIdx = 0 )
154153 : driver(getDriver(driverIdx)), device(getDevice(devIdx)) {
155154 // Create context
156- ze_context_handle_t defaultCtx = getDefaultContext ( );
157- context.reset (defaultCtx );
155+ ze_context_handle_t ctx = getContext (driver );
156+ context.reset (ctx );
158157
159158 // Determine ordinals
160159 uint32_t computeEngineOrdinal = -1u , copyEngineOrdinal = -1u ;
@@ -210,12 +209,12 @@ struct L0RtContext {
210209 context.get (), device, &cmdQueueDesc, &rawCmdListCompute));
211210 immCmdListCompute.reset (rawCmdListCompute);
212211 }
213- L0RtContext (const L0RtContext &) = delete ;
214- L0RtContext &operator =(const L0RtContext &) = delete ;
212+ L0RTContextWrapper (const L0RTContextWrapper &) = delete ;
213+ L0RTContextWrapper &operator =(const L0RTContextWrapper &) = delete ;
215214 // Allow move
216- L0RtContext (L0RtContext &&) noexcept = default ;
217- L0RtContext &operator =(L0RtContext &&) noexcept = default ;
218- ~L0RtContext () = default ;
215+ L0RTContextWrapper (L0RTContextWrapper &&) noexcept = default ;
216+ L0RTContextWrapper &operator =(L0RTContextWrapper &&) noexcept = default ;
217+ ~L0RTContextWrapper () = default ;
219218};
220219
221220struct ZeEventDeleter {
@@ -249,11 +248,15 @@ struct DynamicEventPool {
249248 std::vector<UniqueZeEvent> availableEvents;
250249 std::unordered_map<ze_event_handle_t , UniqueZeEvent> takenEvents;
251250
251+ // Limit the number of events to avoid running out of memory.
252+ // The limit is set to 32K events, which should be sufficient for most use
253+ // cases.
254+ size_t maxEventsCount{32768 }; // 32K events
252255 size_t currentEventsLimit{0 };
253256 size_t currentEventsCnt{0 };
254- L0RtContext *rtCtx;
257+ L0RTContextWrapper *rtCtx;
255258
256- DynamicEventPool (L0RtContext *rtCtx) : rtCtx(rtCtx) {
259+ DynamicEventPool (L0RTContextWrapper *rtCtx) : rtCtx(rtCtx) {
257260 createNewPool (numEventsPerPool);
258261 }
259262
@@ -291,6 +294,9 @@ struct DynamicEventPool {
291294 rawEvent = uniqueEvent.get ();
292295 takenEvents[rawEvent] = std::move (uniqueEvent);
293296 } else {
297+ if (currentEventsCnt >= maxEventsCount) {
298+ throw std::runtime_error (" DynamicEventPool: reached max events limit" );
299+ }
294300 if (currentEventsCnt == currentEventsLimit)
295301 createNewPool (numEventsPerPool);
296302
@@ -322,8 +328,8 @@ struct DynamicEventPool {
322328 }
323329};
324330
325- L0RtContext &getRtContext () {
326- thread_local static L0RtContext rtContext (0 );
331+ L0RTContextWrapper &getRtContext () {
332+ thread_local static L0RTContextWrapper rtContext (0 );
327333 return rtContext;
328334}
329335
@@ -488,7 +494,7 @@ extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
488494template <typename PATTERN_TYPE>
489495void mgpuMemset (void *dst, PATTERN_TYPE value, size_t count,
490496 StreamWrapper *stream) {
491- L0RtContext &rtContext = getRtContext ();
497+ L0RTContextWrapper &rtContext = getRtContext ();
492498 auto listType =
493499 rtContext.copyEngineMaxMemoryFillPatternSize >= sizeof (PATTERN_TYPE)
494500 ? rtContext.immCmdListCopy .get ()
@@ -561,7 +567,7 @@ extern "C" void mgpuSetDefaultDevice(int32_t devIdx) {
561567 catchAll ([&]() {
562568 // For now, a user must ensure that streams and events complete
563569 // and are destroyed before switching a device.
564- getRtContext () = L0RtContext (devIdx);
570+ getRtContext () = L0RTContextWrapper (devIdx);
565571 getDynamicEventPool () = DynamicEventPool (&getRtContext ());
566572 });
567573}
0 commit comments