@@ -46,8 +46,14 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue, CUstream Stream,
46
46
}
47
47
}
48
48
49
+ #if CUDA_VERSION >= 13000
50
+ using CuLocationType = CUmemLocation;
51
+ #else
52
+ using CuLocationType = CUdevice;
53
+ #endif
49
54
void setCuMemAdvise (CUdeviceptr DevPtr, size_t Size,
50
- ur_usm_advice_flags_t URAdviceFlags, CUdevice Device) {
55
+ ur_usm_advice_flags_t URAdviceFlags,
56
+ CuLocationType Location) {
51
57
std::unordered_map<ur_usm_advice_flags_t , CUmem_advise>
52
58
URToCUMemAdviseDeviceFlagsMap = {
53
59
{UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, CU_MEM_ADVISE_SET_READ_MOSTLY},
@@ -64,7 +70,7 @@ void setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
64
70
};
65
71
for (auto &FlagPair : URToCUMemAdviseDeviceFlagsMap) {
66
72
if (URAdviceFlags & FlagPair.first ) {
67
- UR_CHECK_ERROR (cuMemAdvise (DevPtr, Size, FlagPair.second , Device ));
73
+ UR_CHECK_ERROR (cuMemAdvise (DevPtr, Size, FlagPair.second , Location ));
68
74
}
69
75
}
70
76
@@ -82,7 +88,14 @@ void setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
82
88
83
89
for (auto &FlagPair : URToCUMemAdviseHostFlagsMap) {
84
90
if (URAdviceFlags & FlagPair.first ) {
85
- UR_CHECK_ERROR (cuMemAdvise (DevPtr, Size, FlagPair.second , CU_DEVICE_CPU));
91
+ #if CUDA_VERSION >= 13000
92
+ CUmemLocation LocationHost;
93
+ LocationHost.id = 0 ; // ignored with HOST_NUMA_CURRENT
94
+ LocationHost.type = CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT;
95
+ #else
96
+ int LocationHost = CU_DEVICE_CPU;
97
+ #endif
98
+ UR_CHECK_ERROR (cuMemAdvise (DevPtr, Size, FlagPair.second , LocationHost));
86
99
}
87
100
}
88
101
@@ -1550,8 +1563,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
1550
1563
return UR_RESULT_SUCCESS;
1551
1564
}
1552
1565
1566
+ #if CUDA_VERSION >= 13000
1567
+ CUmemLocation Location;
1568
+ Location.id = Device->get ();
1569
+ Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
1570
+ unsigned int Flags = 0U ;
1571
+ UR_CHECK_ERROR (
1572
+ cuMemPrefetchAsync ((CUdeviceptr)pMem, size, Location, Flags, CuStream));
1573
+ #else
1553
1574
UR_CHECK_ERROR (
1554
1575
cuMemPrefetchAsync ((CUdeviceptr)pMem, size, Device->get (), CuStream));
1576
+ #endif
1555
1577
} catch (ur_result_t Err) {
1556
1578
return Err;
1557
1579
}
@@ -1619,19 +1641,24 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
1619
1641
return UR_RESULT_SUCCESS;
1620
1642
}
1621
1643
1644
+ #if CUDA_VERSION >= 13000
1645
+ CUmemLocation Location;
1646
+ Location.id = hQueue->getDevice ()->get ();
1647
+ Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
1648
+ #else
1649
+ int Location = hQueue->getDevice ()->get ();
1650
+ #endif
1651
+
1622
1652
if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
1623
1653
UR_CHECK_ERROR (cuMemAdvise ((CUdeviceptr)pMem, size,
1624
- CU_MEM_ADVISE_UNSET_READ_MOSTLY,
1625
- hQueue->getDevice ()->get ()));
1654
+ CU_MEM_ADVISE_UNSET_READ_MOSTLY, Location));
1626
1655
UR_CHECK_ERROR (cuMemAdvise ((CUdeviceptr)pMem, size,
1627
1656
CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
1628
- hQueue-> getDevice ()-> get () ));
1657
+ Location ));
1629
1658
UR_CHECK_ERROR (cuMemAdvise ((CUdeviceptr)pMem, size,
1630
- CU_MEM_ADVISE_UNSET_ACCESSED_BY,
1631
- hQueue->getDevice ()->get ()));
1659
+ CU_MEM_ADVISE_UNSET_ACCESSED_BY, Location));
1632
1660
} else {
1633
- setCuMemAdvise ((CUdeviceptr)pMem, size, advice,
1634
- hQueue->getDevice ()->get ());
1661
+ setCuMemAdvise ((CUdeviceptr)pMem, size, advice, Location);
1635
1662
}
1636
1663
} catch (ur_result_t err) {
1637
1664
return err;
0 commit comments