Skip to content

Commit a320978

Browse files
committed
Address review comments.
Add the Level Zero (L0) error description to the error message. Find LevelZero only once when either the SYCL or the LEVELZERO runner is enabled.
1 parent 2908cf8 commit a320978

File tree

2 files changed

+28
-31
lines changed

2 files changed

+28
-31
lines changed

mlir/lib/ExecutionEngine/CMakeLists.txt

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -389,19 +389,22 @@ if(LLVM_ENABLE_PIC)
389389
)
390390
endif()
391391

392+
if(MLIR_ENABLE_SYCL_RUNNER OR MLIR_ENABLE_LEVELZERO_RUNNER)
393+
# Both runtimes require LevelZero, so we can find it once.
394+
find_package(LevelZeroRuntime)
395+
396+
if(NOT LevelZeroRuntime_FOUND)
397+
message(FATAL_ERROR "LevelZero not found. Please set LEVEL_ZERO_DIR.")
398+
endif()
399+
endif()
400+
392401
if(MLIR_ENABLE_SYCL_RUNNER)
393402
find_package(SyclRuntime)
394403

395404
if(NOT SyclRuntime_FOUND)
396405
message(FATAL_ERROR "syclRuntime not found. Please set check oneapi installation and run setvars.sh.")
397406
endif()
398407

399-
find_package(LevelZeroRuntime)
400-
401-
if(NOT LevelZeroRuntime_FOUND)
402-
message(FATAL_ERROR "LevelZero not found. Please set LEVEL_ZERO_DIR.")
403-
endif()
404-
405408
add_mlir_library(mlir_sycl_runtime
406409
SHARED
407410
SyclRuntimeWrappers.cpp
@@ -427,25 +430,18 @@ if(LLVM_ENABLE_PIC)
427430
endif()
428431

429432
if(MLIR_ENABLE_LEVELZERO_RUNNER)
430-
find_package(LevelZeroRuntime)
431-
432-
if(NOT LevelZeroRuntime_FOUND)
433-
message(FATAL_ERROR "LevelZero not found. Please set LEVEL_ZERO_DIR.")
434-
endif()
435-
436433
add_mlir_library(mlir_levelzero_runtime
437434
SHARED
438435
LevelZeroRuntimeWrappers.cpp
439436

440437
EXCLUDE_FROM_LIBMLIR
441438
)
442439

443-
check_cxx_compiler_flag("-frtti" CXX_HAS_FRTTI_FLAG)
444-
445-
if(NOT CXX_HAS_FRTTI_FLAG)
446-
message(FATAL_ERROR "CXX compiler does not accept flag -frtti")
447-
endif()
440+
# check_cxx_compiler_flag("-frtti" CXX_HAS_FRTTI_FLAG)
448441

442+
# if(NOT CXX_HAS_FRTTI_FLAG)
443+
# message(FATAL_ERROR "CXX compiler does not accept flag -frtti")
444+
# endif()
449445
target_compile_options(mlir_levelzero_runtime PUBLIC -fexceptions -frtti)
450446

451447
target_include_directories(mlir_levelzero_runtime PRIVATE

mlir/lib/ExecutionEngine/LevelZeroRuntimeWrappers.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include <vector>
2424

2525
namespace {
26-
2726
template <typename F>
2827
auto catchAll(F &&func) {
2928
try {
@@ -41,7 +40,9 @@ auto catchAll(F &&func) {
4140
{ \
4241
ze_result_t status = (call); \
4342
if (status != ZE_RESULT_SUCCESS) { \
44-
std::cerr << "L0 error " << status << std::endl; \
43+
const char *errorString; \
44+
zeDriverGetLastErrorDescription(NULL, &errorString); \
45+
std::cerr << "L0 error " << status << ": " << errorString << std::endl; \
4546
std::abort(); \
4647
} \
4748
}
@@ -78,21 +79,20 @@ static ze_driver_handle_t getDriver(uint32_t idx = 0) {
7879
return drivers[idx];
7980
}
8081

81-
static ze_device_handle_t getDefaultDevice(const uint32_t driverIdx = 0,
82-
const int32_t devIdx = 0) {
82+
static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,
83+
const int32_t devIdx = 0) {
8384
thread_local static ze_device_handle_t l0Device;
84-
thread_local static int32_t currDevIdx{-1};
85-
thread_local static uint32_t currDriverIdx{0};
85+
thread_local int32_t currDevIdx{-1};
86+
thread_local uint32_t currDriverIdx{0};
8687
if (currDriverIdx == driverIdx && currDevIdx == devIdx)
8788
return l0Device;
8889
auto driver = getDriver(driverIdx);
8990
uint32_t deviceCount{0};
9091
L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, nullptr));
9192
if (!deviceCount)
92-
throw std::runtime_error(
93-
"getDefaultDevice failed: did not find L0 device.");
93+
throw std::runtime_error("getDevice failed: did not find L0 device.");
9494
if (static_cast<int>(deviceCount) < devIdx + 1)
95-
throw std::runtime_error("getDefaultDevice failed: devIdx out-of-bounds.");
95+
throw std::runtime_error("getDevice failed: devIdx out-of-bounds.");
9696
std::vector<ze_device_handle_t> devices(deviceCount);
9797
L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
9898
l0Device = devices[devIdx];
@@ -150,8 +150,8 @@ struct L0RtContext {
150150
uint32_t copyEngineMaxMemoryFillPatternSize{-1u};
151151

152152
L0RtContext() = default;
153-
L0RtContext(const int32_t devIdx = 0)
154-
: driver(getDriver()), device(getDefaultDevice(devIdx)) {
153+
L0RtContext(const uint32_t driverIdx = 0, const int32_t devIdx = 0)
154+
: driver(getDriver(driverIdx)), device(getDevice(devIdx)) {
155155
// Create context
156156
ze_context_handle_t defaultCtx = getDefaultContext();
157157
context.reset(defaultCtx);
@@ -488,10 +488,11 @@ extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
488488
template <typename PATTERN_TYPE>
489489
void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,
490490
StreamWrapper *stream) {
491+
L0RtContext &rtContext = getRtContext();
491492
auto listType =
492-
getRtContext().copyEngineMaxMemoryFillPatternSize >= sizeof(PATTERN_TYPE)
493-
? getRtContext().immCmdListCopy.get()
494-
: getRtContext().immCmdListCompute.get();
493+
rtContext.copyEngineMaxMemoryFillPatternSize >= sizeof(PATTERN_TYPE)
494+
? rtContext.immCmdListCopy.get()
495+
: rtContext.immCmdListCompute.get();
495496
stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
496497
ze_event_handle_t *waitEvents) {
497498
L0_SAFE_CALL(zeCommandListAppendMemoryFill(

0 commit comments

Comments
 (0)