@@ -88,6 +88,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(
8888 return UR_RESULT_SUCCESS;
8989}
9090
91+ uint64_t calculateGlobalMemSize (ur_device_handle_t Device) {
92+ // Cache GlobalMemSize
93+ Device->ZeGlobalMemSize .Compute =
94+ [Device](struct ze_global_memsize &GlobalMemSize) {
95+ for (const auto &ZeDeviceMemoryExtProperty :
96+ Device->ZeDeviceMemoryProperties ->second ) {
97+ GlobalMemSize.value += ZeDeviceMemoryExtProperty.physicalSize ;
98+ }
99+ if (GlobalMemSize.value == 0 ) {
100+ for (const auto &ZeDeviceMemoryProperty :
101+ Device->ZeDeviceMemoryProperties ->first ) {
102+ GlobalMemSize.value += ZeDeviceMemoryProperty.totalSize ;
103+ }
104+ }
105+ };
106+ return Device->ZeGlobalMemSize .operator ->()->value ;
107+ }
108+
91109UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo (
92110 ur_device_handle_t Device, // /< [in] handle of the device instance
93111 ur_device_info_t ParamName, // /< [in] type of the info to retrieve
@@ -251,20 +269,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
251269 case UR_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
252270 return ReturnValue (uint64_t {Device->ZeDeviceProperties ->maxMemAllocSize });
253271 case UR_DEVICE_INFO_GLOBAL_MEM_SIZE: {
254- uint64_t GlobalMemSize = 0 ;
255272 // Support to read physicalSize depends on kernel,
256273 // so fallback into reading totalSize if physicalSize
257274 // is not available.
258- for (const auto &ZeDeviceMemoryExtProperty :
259- Device->ZeDeviceMemoryProperties ->second ) {
260- GlobalMemSize += ZeDeviceMemoryExtProperty.physicalSize ;
261- }
262- if (GlobalMemSize == 0 ) {
263- for (const auto &ZeDeviceMemoryProperty :
264- Device->ZeDeviceMemoryProperties ->first ) {
265- GlobalMemSize += ZeDeviceMemoryProperty.totalSize ;
266- }
267- }
275+ uint64_t GlobalMemSize = calculateGlobalMemSize (Device);
268276 return ReturnValue (uint64_t {GlobalMemSize});
269277 }
270278 case UR_DEVICE_INFO_LOCAL_MEM_SIZE:
@@ -637,6 +645,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
637645 static_cast <int32_t >(ZE_RESULT_ERROR_UNINITIALIZED));
638646 return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
639647 }
648+ // Calculate the global memory size as the max limit that can be reported as
649+ // "free" memory for the user to allocate.
650+ uint64_t GlobalMemSize = calculateGlobalMemSize (Device);
640651 // Only report device memory which zeMemAllocDevice can allocate from.
641652 // Currently this is only the one enumerated with ordinal 0.
642653 uint64_t FreeMemory = 0 ;
@@ -661,7 +672,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
661672 }
662673 }
663674 }
664- return ReturnValue (FreeMemory);
675+ return ReturnValue (std::min (GlobalMemSize, FreeMemory) );
665676 }
666677 case UR_DEVICE_INFO_MEMORY_CLOCK_RATE: {
667678 // If there are not any memory modules then return 0.
0 commit comments