Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/intel-llvm-mirror-base-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
8959a5e5a6cebac8993c58c5597638b4510be91f
84518c193adb9d8b03ae449345d892c6c9984846
12 changes: 7 additions & 5 deletions include/ur_api.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 20 additions & 6 deletions include/ur_print.hpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions scripts/core/enqueue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -915,13 +915,16 @@ etors:
value: "$X_BIT(2)"
--- #--------------------------------------------------------------------------
type: enum
desc: "Map flags"
class: $xDevice
desc: "USM migration flags, indicating the direction data is migrated in"
class: $xEnqueue
name: $x_usm_migration_flags_t
etors:
- name: DEFAULT
desc: "Default migration TODO: Add more enums! "
- name: HOST_TO_DEVICE
desc: "Migrate data from host to device"
value: "$X_BIT(0)"
- name: DEVICE_TO_HOST
desc: "Migrate data from device to host"
value: "$X_BIT(1)"
--- #--------------------------------------------------------------------------
type: function
desc: "Enqueue a command to map a region of the buffer object into the host address space and return a pointer to the mapped region"
Expand Down
2 changes: 1 addition & 1 deletion scripts/core/exp-command-buffer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ params:
desc: "[in] size in bytes to be fetched."
- type: $x_usm_migration_flags_t
name: flags
desc: "[in] USM prefetch flags"
desc: "[in] USM migration flags"
- type: uint32_t
name: numSyncPointsInWaitList
desc: "[in] The number of sync points in the provided dependency list."
Expand Down
35 changes: 29 additions & 6 deletions source/adapters/cuda/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,14 +1516,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_migration_flags_t /*flags*/, uint32_t numEventsInWaitList,
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

ur_device_handle_t Device = hQueue->getDevice();
#if CUDA_VERSION >= 13000
CUmemLocation Location;
switch (flags) {
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
Location.id = Device->get();
break;
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
Location.type = CU_MEM_LOCATION_TYPE_HOST;
break;
#else
int dstDevice;
switch (flags) {
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
dstDevice = Device->get();
break;
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
dstDevice = CU_DEVICE_CPU;
break;
#endif
default:
setErrorMessage("Invalid USM migration flag",
UR_RESULT_ERROR_INVALID_ENUMERATION);
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

size_t PointerRangeSize = 0;
UR_CHECK_ERROR(cuPointerGetAttribute(
&PointerRangeSize, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)pMem));
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
ur_device_handle_t Device = hQueue->getDevice();

std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
try {
Expand Down Expand Up @@ -1564,15 +1590,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
}

#if CUDA_VERSION >= 13000
CUmemLocation Location;
Location.id = Device->get();
Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
unsigned int Flags = 0U;
UR_CHECK_ERROR(
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Location, Flags, CuStream));
#else
UR_CHECK_ERROR(
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream));
cuMemPrefetchAsync((CUdeviceptr)pMem, size, dstDevice, CuStream));
#endif
} catch (ur_result_t Err) {
return Err;
Expand Down
20 changes: 16 additions & 4 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1324,11 +1324,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_migration_flags_t /*flags*/, uint32_t numEventsInWaitList,
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

void *HIPDevicePtr = const_cast<void *>(pMem);
ur_device_handle_t Device = hQueue->getDevice();
hipDevice_t TargetDevice;
switch (flags) {
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
TargetDevice = Device->get();
break;
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
TargetDevice = hipCpuDeviceId;
break;
default:
setErrorMessage("Invalid USM migration flag",
UR_RESULT_ERROR_INVALID_ENUMERATION);
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
void *HIPDevicePtr = const_cast<void *>(pMem);

// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
// so we can't perform this check for such cases.
Expand Down Expand Up @@ -1385,8 +1398,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
return UR_RESULT_SUCCESS;
}

UR_CHECK_ERROR(
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
UR_CHECK_ERROR(hipMemPrefetchAsync(pMem, size, TargetDevice, HIPStream));
releaseEvent();
} catch (ur_result_t Err) {
return Err;
Expand Down
115 changes: 57 additions & 58 deletions source/adapters/level_zero/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,72 +506,71 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
bool forceLoadedAdapter = ur_getenv("UR_ADAPTERS_FORCE_LOAD").has_value();
if (!forceLoadedAdapter) {
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
auto [useV2, reason] = shouldUseV2Adapter();
if (!useV2) {
UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason);
return;
}
auto [useV2, reason] = shouldUseV2Adapter();
if (!useV2) {
UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason);
return;
}
#else
auto [useV1, reason] = shouldUseV1Adapter();
if (!useV1) {
UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason);
return;
}
auto [useV1, reason] = shouldUseV1Adapter();
if (!useV1) {
UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason);
return;
}
#endif
}

// Check if the user has enabled the default L0 SysMan initialization.
const int UrSysmanZesinitEnable = [&UserForcedSysManInit] {
const char *UrRet = std::getenv("UR_L0_ENABLE_ZESINIT_DEFAULT");
if (!UrRet)
return 0;
UserForcedSysManInit &= 2;
return std::atoi(UrRet);
}();

bool ZesInitNeeded = UrSysmanZesinitEnable && !UrSysManEnvInitEnabled;
// Unless the user has forced the SysMan init, we will check the device
// version to see if the zesInit is needed.
if (UserForcedSysManInit == 0 && checkDeviceIntelGPUIpVersionOrNewer(
0x05004000) == UR_RESULT_SUCCESS) {
if (UrSysManEnvInitEnabled) {
setEnvVar("ZES_ENABLE_SYSMAN", "0");
}
ZesInitNeeded = true;
}
if (ZesInitNeeded) {
// Check if the user has enabled the default L0 SysMan initialization.
const int UrSysmanZesinitEnable = [&UserForcedSysManInit] {
const char *UrRet = std::getenv("UR_L0_ENABLE_ZESINIT_DEFAULT");
if (!UrRet)
return 0;
UserForcedSysManInit &= 2;
return std::atoi(UrRet);
}();

bool ZesInitNeeded = UrSysmanZesinitEnable && !UrSysManEnvInitEnabled;
// Unless the user has forced the SysMan init, we will check the device
// version to see if the zesInit is needed.
if (UserForcedSysManInit == 0 &&
checkDeviceIntelGPUIpVersionOrNewer(0x05004000) == UR_RESULT_SUCCESS) {
if (UrSysManEnvInitEnabled) {
setEnvVar("ZES_ENABLE_SYSMAN", "0");
}
ZesInitNeeded = true;
}
if (ZesInitNeeded) {
#ifdef UR_STATIC_LEVEL_ZERO
getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp;
getSysManDriversFunctionPtr = zesDriverGet;
sysManInitFunctionPtr = zesInit;
getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp;
getSysManDriversFunctionPtr = zesDriverGet;
sysManInitFunctionPtr = zesInit;
#else
getDeviceByUUIdFunctionPtr = (zes_pfnDriverGetDeviceByUuidExp_t)
ur_loader::LibLoader::getFunctionPtr(processHandle,
"zesDriverGetDeviceByUuidExp");
getSysManDriversFunctionPtr =
(zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr(
processHandle, "zesDriverGet");
sysManInitFunctionPtr =
(zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(processHandle,
"zesInit");
getDeviceByUUIdFunctionPtr =
(zes_pfnDriverGetDeviceByUuidExp_t)ur_loader::LibLoader::getFunctionPtr(
processHandle, "zesDriverGetDeviceByUuidExp");
getSysManDriversFunctionPtr =
(zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr(
processHandle, "zesDriverGet");
sysManInitFunctionPtr = (zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(
processHandle, "zesInit");
#endif
}
if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr &&
sysManInitFunctionPtr) {
ze_init_flags_t L0ZesInitFlags = 0;
UR_LOG(DEBUG, "\nzesInit with flags value of {}\n",
static_cast<int>(L0ZesInitFlags));
ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags));
} else {
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
}
}
if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr &&
sysManInitFunctionPtr) {
ze_init_flags_t L0ZesInitFlags = 0;
UR_LOG(DEBUG, "\nzesInit with flags value of {}\n",
static_cast<int>(L0ZesInitFlags));
ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags));
} else {
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
}

ur_result_t err = initPlatforms(this, platforms, ZesResult);
if (err == UR_RESULT_SUCCESS) {
Platforms = std::move(platforms);
} else {
throw err;
}
ur_result_t err = initPlatforms(this, platforms, ZesResult);
if (err == UR_RESULT_SUCCESS) {
Platforms = std::move(platforms);
} else {
throw err;
}
}

void globalAdapterOnDemandCleanup() {
Expand Down
21 changes: 17 additions & 4 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,7 @@ ur_result_t urCommandBufferAppendMemBufferReadRectExp(

ur_result_t urCommandBufferAppendUSMPrefetchExp(
ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
ur_usm_migration_flags_t /*Flags*/, uint32_t NumSyncPointsInWaitList,
ur_usm_migration_flags_t Flags, uint32_t NumSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
uint32_t /*NumEventsInWaitList*/,
const ur_event_handle_t * /*EventWaitList*/,
Expand All @@ -1327,6 +1327,17 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
UR_COMMAND_USM_PREFETCH, CommandBuffer,
CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
switch (Flags) {
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
break;
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
UR_LOG(WARN, "commandBufferAppendUSMPrefetch: L0 does not support prefetch "
"to host yet");
break;
default:
UR_LOG(ERR, "commandBufferAppendUSMPrefetch: invalid USM migration flag");
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

if (!ZeEventList.empty()) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
Expand All @@ -1335,9 +1346,11 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
}

// Add the prefetch command to the command-buffer.
// Note that L0 does not handle migration flags.
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
(CommandBuffer->ZeComputeCommandList, Mem, Size));
// TODO Support migration flags after L0 backend support is added.
if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) {
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
(CommandBuffer->ZeComputeCommandList, Mem, Size));
}

if (!CommandBuffer->IsInOrderCmdList) {
// Level Zero does not have a completion "event" with the prefetch API,
Expand Down
Loading