Skip to content

Commit 6133d21

Browse files
authored
[UR][CUDA] Add support for CUDA 13 (#19752)
A couple function signatures changed. Only UR change here. Choice of `CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT` for host advice is somewhat arbitrary. Full CUDA 13 support requires a compiler driver patch from upstream (12eab1a).
1 parent b2f8ead commit 6133d21

File tree

1 file changed

+37
-10
lines changed

1 file changed

+37
-10
lines changed

unified-runtime/source/adapters/cuda/enqueue.cpp

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,14 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue, CUstream Stream,
4646
}
4747
}
4848

49+
#if CUDA_VERSION >= 13000
50+
using CuLocationType = CUmemLocation;
51+
#else
52+
using CuLocationType = CUdevice;
53+
#endif
4954
void setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
50-
ur_usm_advice_flags_t URAdviceFlags, CUdevice Device) {
55+
ur_usm_advice_flags_t URAdviceFlags,
56+
CuLocationType Location) {
5157
std::unordered_map<ur_usm_advice_flags_t, CUmem_advise>
5258
URToCUMemAdviseDeviceFlagsMap = {
5359
{UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, CU_MEM_ADVISE_SET_READ_MOSTLY},
@@ -64,7 +70,7 @@ void setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
6470
};
6571
for (auto &FlagPair : URToCUMemAdviseDeviceFlagsMap) {
6672
if (URAdviceFlags & FlagPair.first) {
67-
UR_CHECK_ERROR(cuMemAdvise(DevPtr, Size, FlagPair.second, Device));
73+
UR_CHECK_ERROR(cuMemAdvise(DevPtr, Size, FlagPair.second, Location));
6874
}
6975
}
7076

@@ -82,7 +88,14 @@ void setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
8288

8389
for (auto &FlagPair : URToCUMemAdviseHostFlagsMap) {
8490
if (URAdviceFlags & FlagPair.first) {
85-
UR_CHECK_ERROR(cuMemAdvise(DevPtr, Size, FlagPair.second, CU_DEVICE_CPU));
91+
#if CUDA_VERSION >= 13000
92+
CUmemLocation LocationHost;
93+
LocationHost.id = 0; // ignored with HOST_NUMA_CURRENT
94+
LocationHost.type = CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT;
95+
#else
96+
int LocationHost = CU_DEVICE_CPU;
97+
#endif
98+
UR_CHECK_ERROR(cuMemAdvise(DevPtr, Size, FlagPair.second, LocationHost));
8699
}
87100
}
88101

@@ -1550,8 +1563,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
15501563
return UR_RESULT_SUCCESS;
15511564
}
15521565

1566+
#if CUDA_VERSION >= 13000
1567+
CUmemLocation Location;
1568+
Location.id = Device->get();
1569+
Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
1570+
unsigned int Flags = 0U;
1571+
UR_CHECK_ERROR(
1572+
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Location, Flags, CuStream));
1573+
#else
15531574
UR_CHECK_ERROR(
15541575
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream));
1576+
#endif
15551577
} catch (ur_result_t Err) {
15561578
return Err;
15571579
}
@@ -1619,19 +1641,24 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
16191641
return UR_RESULT_SUCCESS;
16201642
}
16211643

1644+
#if CUDA_VERSION >= 13000
1645+
CUmemLocation Location;
1646+
Location.id = hQueue->getDevice()->get();
1647+
Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
1648+
#else
1649+
int Location = hQueue->getDevice()->get();
1650+
#endif
1651+
16221652
if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
16231653
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
1624-
CU_MEM_ADVISE_UNSET_READ_MOSTLY,
1625-
hQueue->getDevice()->get()));
1654+
CU_MEM_ADVISE_UNSET_READ_MOSTLY, Location));
16261655
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
16271656
CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
1628-
hQueue->getDevice()->get()));
1657+
Location));
16291658
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
1630-
CU_MEM_ADVISE_UNSET_ACCESSED_BY,
1631-
hQueue->getDevice()->get()));
1659+
CU_MEM_ADVISE_UNSET_ACCESSED_BY, Location));
16321660
} else {
1633-
setCuMemAdvise((CUdeviceptr)pMem, size, advice,
1634-
hQueue->getDevice()->get());
1661+
setCuMemAdvise((CUdeviceptr)pMem, size, advice, Location);
16351662
}
16361663
} catch (ur_result_t err) {
16371664
return err;

0 commit comments

Comments
 (0)