@@ -1544,19 +1544,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
15441544 return Result;
15451545}
15461546
1547- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
1547+ namespace {
1548+
1549+ enum class GlobalVariableCopy { Read, Write };
1550+
1551+ ur_result_t deviceGlobalCopyHelper (
15481552 ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1549- bool blockingWrite , size_t count, size_t offset, const void *pSrc ,
1553+ bool blocking , size_t count, size_t offset, void *ptr ,
15501554 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1551- ur_event_handle_t *phEvent) {
1552- // Since HIP requires the global variable to be referenced by name, we use
1555+ ur_event_handle_t *phEvent, GlobalVariableCopy CopyType ) {
1556+ // Since HIP requires the global variable to be referenced by name, we use
15531557 // metadata to find the correct name to access it by.
15541558 auto DeviceGlobalNameIt = hProgram->GlobalIDMD .find (name);
15551559 if (DeviceGlobalNameIt == hProgram->GlobalIDMD .end ())
15561560 return UR_RESULT_ERROR_INVALID_VALUE;
15571561 std::string DeviceGlobalName = DeviceGlobalNameIt->second ;
15581562
1559- ur_result_t Result = UR_RESULT_SUCCESS;
15601563 try {
15611564 hipDeviceptr_t DeviceGlobal = 0 ;
15621565 size_t DeviceGlobalSize = 0 ;
@@ -1567,49 +1570,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
15671570 if (offset + count > DeviceGlobalSize)
15681571 return UR_RESULT_ERROR_INVALID_VALUE;
15691572
1570- return urEnqueueUSMMemcpy (
1571- hQueue, blockingWrite,
1572- reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(DeviceGlobal) +
1573- offset),
1574- pSrc, count, numEventsInWaitList, phEventWaitList, phEvent);
1573+ void *pSrc, *pDst;
1574+ if (CopyType == GlobalVariableCopy::Write) {
1575+ pSrc = ptr;
1576+ pDst = reinterpret_cast <uint8_t *>(DeviceGlobal) + offset;
1577+ } else {
1578+ pSrc = reinterpret_cast <uint8_t *>(DeviceGlobal) + offset;
1579+ pDst = ptr;
1580+ }
1581+ return urEnqueueUSMMemcpy (hQueue, blocking, pDst, pSrc, count,
1582+ numEventsInWaitList, phEventWaitList, phEvent);
15751583 } catch (ur_result_t Err) {
1576- Result = Err;
1584+ return Err;
15771585 }
1578- return Result;
1586+ }
1587+ } // namespace
1588+
1589+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
1590+ ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1591+ bool blockingWrite, size_t count, size_t offset, const void *pSrc,
1592+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1593+ ur_event_handle_t *phEvent) {
1594+ return deviceGlobalCopyHelper (hQueue, hProgram, name, blockingWrite, count,
1595+ offset, const_cast <void *>(pSrc),
1596+ numEventsInWaitList, phEventWaitList, phEvent,
1597+ GlobalVariableCopy::Write);
15791598}
15801599
15811600UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead (
15821601 ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
15831602 bool blockingRead, size_t count, size_t offset, void *pDst,
15841603 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
15851604 ur_event_handle_t *phEvent) {
1586- // Since HIP requires the global variable to be referenced by name, we use
1587- // metadata to find the correct name to access it by.
1588- auto DeviceGlobalNameIt = hProgram->GlobalIDMD .find (name);
1589- if (DeviceGlobalNameIt == hProgram->GlobalIDMD .end ())
1590- return UR_RESULT_ERROR_INVALID_VALUE;
1591- std::string DeviceGlobalName = DeviceGlobalNameIt->second ;
1592-
1593- ur_result_t Result = UR_RESULT_SUCCESS;
1594- try {
1595- hipDeviceptr_t DeviceGlobal = 0 ;
1596- size_t DeviceGlobalSize = 0 ;
1597- UR_CHECK_ERROR (hipModuleGetGlobal (&DeviceGlobal, &DeviceGlobalSize,
1598- hProgram->get (),
1599- DeviceGlobalName.c_str ()));
1600-
1601- if (offset + count > DeviceGlobalSize)
1602- return UR_RESULT_ERROR_INVALID_VALUE;
1603-
1604- return urEnqueueUSMMemcpy (
1605- hQueue, blockingRead, pDst,
1606- reinterpret_cast <const void *>(
1607- reinterpret_cast <uint8_t *>(DeviceGlobal) + offset),
1608- count, numEventsInWaitList, phEventWaitList, phEvent);
1609- } catch (ur_result_t Err) {
1610- Result = Err;
1611- }
1612- return Result;
1605+ return deviceGlobalCopyHelper (
1606+ hQueue, hProgram, name, blockingRead, count, offset, pDst,
1607+ numEventsInWaitList, phEventWaitList, phEvent, GlobalVariableCopy::Read);
16131608}
16141609
16151610UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe (
0 commit comments