2323#include < vector>
2424
2525namespace {
26-
2726template <typename F>
2827auto catchAll (F &&func) {
2928 try {
@@ -41,7 +40,9 @@ auto catchAll(F &&func) {
4140 { \
4241 ze_result_t status = (call); \
4342 if (status != ZE_RESULT_SUCCESS) { \
44- std::cerr << " L0 error " << status << std::endl; \
43+ const char *errorString; \
44+ zeDriverGetLastErrorDescription (NULL , &errorString); \
45+ std::cerr << " L0 error " << status << " : " << errorString << std::endl; \
4546 std::abort (); \
4647 } \
4748 }
@@ -78,21 +79,20 @@ static ze_driver_handle_t getDriver(uint32_t idx = 0) {
7879 return drivers[idx];
7980}
8081
81- static ze_device_handle_t getDefaultDevice (const uint32_t driverIdx = 0 ,
82- const int32_t devIdx = 0 ) {
82+ static ze_device_handle_t getDevice (const uint32_t driverIdx = 0 ,
83+ const int32_t devIdx = 0 ) {
8384 thread_local static ze_device_handle_t l0Device;
84- thread_local static int32_t currDevIdx{-1 };
85- thread_local static uint32_t currDriverIdx{0 };
85+ thread_local int32_t currDevIdx{-1 };
86+ thread_local uint32_t currDriverIdx{0 };
8687 if (currDriverIdx == driverIdx && currDevIdx == devIdx)
8788 return l0Device;
8889 auto driver = getDriver (driverIdx);
8990 uint32_t deviceCount{0 };
9091 L0_SAFE_CALL (zeDeviceGet (driver, &deviceCount, nullptr ));
9192 if (!deviceCount)
92- throw std::runtime_error (
93- " getDefaultDevice failed: did not find L0 device." );
93+ throw std::runtime_error (" getDevice failed: did not find L0 device." );
9494 if (static_cast <int >(deviceCount) < devIdx + 1 )
95- throw std::runtime_error (" getDefaultDevice failed: devIdx out-of-bounds." );
95+ throw std::runtime_error (" getDevice failed: devIdx out-of-bounds." );
9696 std::vector<ze_device_handle_t > devices (deviceCount);
9797 L0_SAFE_CALL (zeDeviceGet (driver, &deviceCount, devices.data ()));
9898 l0Device = devices[devIdx];
@@ -150,8 +150,8 @@ struct L0RtContext {
150150 uint32_t copyEngineMaxMemoryFillPatternSize{-1u };
151151
152152 L0RtContext () = default ;
153- L0RtContext (const int32_t devIdx = 0 )
154- : driver(getDriver()), device(getDefaultDevice (devIdx)) {
153+ L0RtContext (const uint32_t driverIdx = 0 , const int32_t devIdx = 0 )
154+ : driver(getDriver(driverIdx )), device(getDevice (devIdx)) {
155155 // Create context
156156 ze_context_handle_t defaultCtx = getDefaultContext ();
157157 context.reset (defaultCtx);
@@ -488,10 +488,11 @@ extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
488488template <typename PATTERN_TYPE>
489489void mgpuMemset (void *dst, PATTERN_TYPE value, size_t count,
490490 StreamWrapper *stream) {
491+ L0RtContext &rtContext = getRtContext ();
491492 auto listType =
492- getRtContext () .copyEngineMaxMemoryFillPatternSize >= sizeof (PATTERN_TYPE)
493- ? getRtContext () .immCmdListCopy .get ()
494- : getRtContext () .immCmdListCompute .get ();
493+ rtContext .copyEngineMaxMemoryFillPatternSize >= sizeof (PATTERN_TYPE)
494+ ? rtContext .immCmdListCopy .get ()
495+ : rtContext .immCmdListCompute .get ();
495496 stream->enqueueOp ([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
496497 ze_event_handle_t *waitEvents) {
497498 L0_SAFE_CALL (zeCommandListAppendMemoryFill (
0 commit comments