@@ -67,6 +67,62 @@ void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock,
67
67
--ThreadsPerBlock[0 ];
68
68
}
69
69
}
70
+
71
+ ur_result_t setHipMemAdvise (const void *DevPtr, const size_t Size,
72
+ ur_usm_advice_flags_t URAdviceFlags,
73
+ hipDevice_t Device) {
74
+ // Handle unmapped memory advice flags
75
+ if (URAdviceFlags &
76
+ (UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY |
77
+ UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY |
78
+ UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED)) {
79
+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
80
+ }
81
+
82
+ using ur_to_hip_advice_t = std::pair<ur_usm_advice_flags_t , hipMemoryAdvise>;
83
+
84
+ static constexpr std::array<ur_to_hip_advice_t , 6 >
85
+ URToHIPMemAdviseDeviceFlags{
86
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_READ_MOSTLY,
87
+ hipMemAdviseSetReadMostly),
88
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY,
89
+ hipMemAdviseUnsetReadMostly),
90
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION,
91
+ hipMemAdviseSetPreferredLocation),
92
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION,
93
+ hipMemAdviseUnsetPreferredLocation),
94
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE,
95
+ hipMemAdviseSetAccessedBy),
96
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE,
97
+ hipMemAdviseUnsetAccessedBy),
98
+ };
99
+ for (auto &FlagPair : URToHIPMemAdviseDeviceFlags) {
100
+ if (URAdviceFlags & FlagPair.first ) {
101
+ UR_CHECK_ERROR (hipMemAdvise (DevPtr, Size, FlagPair.second , Device));
102
+ }
103
+ }
104
+
105
+ static constexpr std::array<ur_to_hip_advice_t , 4 > URToHIPMemAdviseHostFlags{
106
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST,
107
+ hipMemAdviseSetPreferredLocation),
108
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST,
109
+ hipMemAdviseUnsetPreferredLocation),
110
+ std::make_pair (UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST,
111
+ hipMemAdviseSetAccessedBy),
112
+ std::make_pair (UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_HOST,
113
+ hipMemAdviseUnsetAccessedBy),
114
+ };
115
+
116
+ for (auto &FlagPair : URToHIPMemAdviseHostFlags) {
117
+ if (URAdviceFlags & FlagPair.first ) {
118
+ UR_CHECK_ERROR (
119
+ hipMemAdvise (DevPtr, Size, FlagPair.second , hipCpuDeviceId));
120
+ }
121
+ }
122
+
123
+ return UR_RESULT_SUCCESS;
124
+ }
125
+
70
126
} // namespace
71
127
72
128
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite (
@@ -1386,87 +1442,184 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
1386
1442
ur_queue_handle_t hQueue, const void *pMem, size_t size,
1387
1443
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
1388
1444
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1445
+ std::ignore = flags;
1446
+
1389
1447
void *HIPDevicePtr = const_cast <void *>(pMem);
1390
1448
ur_device_handle_t Device = hQueue->getDevice ();
1391
1449
1392
- // If the device does not support managed memory access, we can't set
1393
- // mem_advise.
1394
- if (!getAttribute (Device, hipDeviceAttributeManagedMemory)) {
1395
- setErrorMessage (" mem_advise ignored as device does not support "
1396
- " managed memory access" ,
1397
- UR_RESULT_SUCCESS);
1398
- return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1399
- }
1400
-
1401
- hipPointerAttribute_t attribs;
1402
- // TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1403
- // memory, as it is neither registered as host memory, nor into the address
1404
- // space for the current device, meaning the pMem ptr points to a
1405
- // system-allocated memory. This means we may need to check system-alloacted
1406
- // memory and handle the failure more gracefully.
1407
- UR_CHECK_ERROR (hipPointerGetAttributes (&attribs, pMem));
1408
- // async prefetch requires USM pointer (or hip SVM) to work.
1409
- if (!attribs.isManaged ) {
1410
- setErrorMessage (" Prefetch hint ignored as prefetch only works with USM" ,
1411
- UR_RESULT_SUCCESS);
1412
- return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1413
- }
1414
-
1415
- // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1416
- // so we can't perform this check for such cases.
1450
+ // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1451
+ // so we can't perform this check for such cases.
1417
1452
#if HIP_VERSION_MAJOR >= 5
1418
1453
unsigned int PointerRangeSize = 0 ;
1419
1454
UR_CHECK_ERROR (hipPointerGetAttribute (&PointerRangeSize,
1420
1455
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1421
1456
(hipDeviceptr_t)HIPDevicePtr));
1422
1457
UR_ASSERT (size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
1423
1458
#endif
1424
- // flags is currently unused so fail if set
1425
- if (flags != 0 )
1426
- return UR_RESULT_ERROR_INVALID_VALUE;
1459
+
1427
1460
ur_result_t Result = UR_RESULT_SUCCESS;
1428
- std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr };
1429
1461
1430
1462
try {
1431
1463
ScopedContext Active (hQueue->getDevice ());
1432
1464
hipStream_t HIPStream = hQueue->getNextTransferStream ();
1433
1465
Result = enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
1434
1466
phEventWaitList);
1467
+
1468
+ std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr };
1469
+
1435
1470
if (phEvent) {
1436
1471
EventPtr =
1437
1472
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1438
1473
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
1439
1474
UR_CHECK_ERROR (EventPtr->start ());
1440
1475
}
1476
+
1477
+ // Helper to ensure returning a valid event on early exit.
1478
+ auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1479
+ if (phEvent) {
1480
+ UR_CHECK_ERROR (EventPtr->record ());
1481
+ *phEvent = EventPtr.release ();
1482
+ }
1483
+ };
1484
+
1485
+ // If the device does not support managed memory access, we can't set
1486
+ // mem_advise.
1487
+ if (!getAttribute (Device, hipDeviceAttributeManagedMemory)) {
1488
+ releaseEvent ();
1489
+ setErrorMessage (" mem_advise ignored as device does not support "
1490
+ " managed memory access" ,
1491
+ UR_RESULT_SUCCESS);
1492
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1493
+ }
1494
+
1495
+ hipPointerAttribute_t attribs;
1496
+ // TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
1497
+ // memory, as it is neither registered as host memory, nor into the address
1498
+ // space for the current device, meaning the pMem ptr points to a
1499
+ // system-allocated memory. This means we may need to check system-alloacted
1500
+ // memory and handle the failure more gracefully.
1501
+ UR_CHECK_ERROR (hipPointerGetAttributes (&attribs, pMem));
1502
+ // async prefetch requires USM pointer (or hip SVM) to work.
1503
+ if (!attribs.isManaged ) {
1504
+ releaseEvent ();
1505
+ setErrorMessage (" Prefetch hint ignored as prefetch only works with USM" ,
1506
+ UR_RESULT_SUCCESS);
1507
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1508
+ }
1509
+
1441
1510
UR_CHECK_ERROR (
1442
1511
hipMemPrefetchAsync (pMem, size, hQueue->getDevice ()->get (), HIPStream));
1443
- if (phEvent) {
1444
- UR_CHECK_ERROR (EventPtr->record ());
1445
- *phEvent = EventPtr.release ();
1446
- }
1512
+ releaseEvent ();
1447
1513
} catch (ur_result_t Err) {
1448
1514
Result = Err;
1449
1515
}
1450
1516
1451
1517
return Result;
1452
1518
}
1453
1519
1520
+ // / USM: memadvise API to govern behavior of automatic migration mechanisms
1454
1521
UR_APIEXPORT ur_result_t UR_APICALL
1455
1522
urEnqueueUSMAdvise (ur_queue_handle_t hQueue, const void *pMem, size_t size,
1456
- ur_usm_advice_flags_t , ur_event_handle_t *phEvent) {
1523
+ ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) {
1524
+ UR_ASSERT (pMem && size > 0 , UR_RESULT_ERROR_INVALID_VALUE);
1457
1525
void *HIPDevicePtr = const_cast <void *>(pMem);
1458
- // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
1459
- // so we can't perform this check for such cases.
1526
+ ur_device_handle_t Device = hQueue-> getDevice ();
1527
+
1460
1528
#if HIP_VERSION_MAJOR >= 5
1461
- unsigned int PointerRangeSize = 0 ;
1462
- UR_CHECK_ERROR (hipPointerGetAttribute (&PointerRangeSize,
1463
- HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1464
- (hipDeviceptr_t)HIPDevicePtr));
1529
+ // NOTE: The hipPointerGetAttribute API is marked as beta, meaning, while this
1530
+ // is feature complete, it is still open to changes and outstanding issues.
1531
+ size_t PointerRangeSize = 0 ;
1532
+ UR_CHECK_ERROR (hipPointerGetAttribute (
1533
+ &PointerRangeSize, HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
1534
+ static_cast <hipDeviceptr_t>(HIPDevicePtr)));
1465
1535
UR_ASSERT (size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
1466
1536
#endif
1467
- // TODO implement a mapping to hipMemAdvise once the expected behaviour
1468
- // of urEnqueueUSMAdvise is detailed in the USM extension
1469
- return urEnqueueEventsWait (hQueue, 0 , nullptr , phEvent);
1537
+
1538
+ ur_result_t Result = UR_RESULT_SUCCESS;
1539
+
1540
+ try {
1541
+ ScopedContext Active (Device);
1542
+ std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr };
1543
+
1544
+ if (phEvent) {
1545
+ EventPtr =
1546
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1547
+ UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream ()));
1548
+ EventPtr->start ();
1549
+ }
1550
+
1551
+ // Helper to ensure returning a valid event on early exit.
1552
+ auto releaseEvent = [&EventPtr, &phEvent]() -> void {
1553
+ if (phEvent) {
1554
+ UR_CHECK_ERROR (EventPtr->record ());
1555
+ *phEvent = EventPtr.release ();
1556
+ }
1557
+ };
1558
+
1559
+ // If the device does not support managed memory access, we can't set
1560
+ // mem_advise.
1561
+ if (!getAttribute (Device, hipDeviceAttributeManagedMemory)) {
1562
+ releaseEvent ();
1563
+ setErrorMessage (" mem_advise ignored as device does not support "
1564
+ " managed memory access" ,
1565
+ UR_RESULT_SUCCESS);
1566
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1567
+ }
1568
+
1569
+ // Passing MEM_ADVICE_SET/MEM_ADVICE_CLEAR_PREFERRED_LOCATION to
1570
+ // hipMemAdvise on a GPU device requires the GPU device to report a non-zero
1571
+ // value for hipDeviceAttributeConcurrentManagedAccess. Therefore, ignore
1572
+ // the mem advice if concurrent managed memory access is not available.
1573
+ if (advice & (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION |
1574
+ UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION |
1575
+ UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE |
1576
+ UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE |
1577
+ UR_USM_ADVICE_FLAG_DEFAULT)) {
1578
+ if (!getAttribute (Device, hipDeviceAttributeConcurrentManagedAccess)) {
1579
+ releaseEvent ();
1580
+ setErrorMessage (" mem_advise ignored as device does not support "
1581
+ " concurrent managed access" ,
1582
+ UR_RESULT_SUCCESS);
1583
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1584
+ }
1585
+
1586
+ // TODO: If pMem points to valid system-allocated pageable memory, we
1587
+ // should check that the device also has the
1588
+ // hipDeviceAttributePageableMemoryAccess property, so that a valid
1589
+ // read-only copy can be created on the device. This also applies for
1590
+ // UR_USM_MEM_ADVICE_SET/MEM_ADVICE_CLEAR_READ_MOSTLY.
1591
+ }
1592
+
1593
+ const auto DeviceID = Device->get ();
1594
+ if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
1595
+ UR_CHECK_ERROR (
1596
+ hipMemAdvise (pMem, size, hipMemAdviseUnsetReadMostly, DeviceID));
1597
+ UR_CHECK_ERROR (hipMemAdvise (
1598
+ pMem, size, hipMemAdviseUnsetPreferredLocation, DeviceID));
1599
+ UR_CHECK_ERROR (
1600
+ hipMemAdvise (pMem, size, hipMemAdviseUnsetAccessedBy, DeviceID));
1601
+ } else {
1602
+ Result = setHipMemAdvise (HIPDevicePtr, size, advice, DeviceID);
1603
+ // UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using a valid but
1604
+ // currently unmapped advice arguments as not supported by this platform.
1605
+ // Therefore, warn the user instead of throwing and aborting the runtime.
1606
+ if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
1607
+ releaseEvent ();
1608
+ setErrorMessage (" mem_advise is ignored as the advice argument is not "
1609
+ " supported by this device" ,
1610
+ UR_RESULT_SUCCESS);
1611
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1612
+ }
1613
+ }
1614
+
1615
+ releaseEvent ();
1616
+ } catch (ur_result_t err) {
1617
+ Result = err;
1618
+ } catch (...) {
1619
+ Result = UR_RESULT_ERROR_UNKNOWN;
1620
+ }
1621
+
1622
+ return Result;
1470
1623
}
1471
1624
1472
1625
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D (
0 commit comments