diff --git a/sycl/source/detail/adapter.hpp b/sycl/source/detail/adapter.hpp index d78743ac6159e..b6ecb2e76a5e6 100644 --- a/sycl/source/detail/adapter.hpp +++ b/sycl/source/detail/adapter.hpp @@ -98,10 +98,10 @@ class Adapter { std::vector &getUrPlatforms() { std::call_once(PlatformsPopulated, [&]() { uint32_t platformCount = 0; - call(&MAdapter, 1, 0, nullptr, &platformCount); + call(MAdapter, 0, nullptr, &platformCount); UrPlatforms.resize(platformCount); if (platformCount) { - call(&MAdapter, 1, platformCount, + call(MAdapter, platformCount, UrPlatforms.data(), nullptr); } // We need one entry in this per platform diff --git a/unified-runtime/examples/codegen/codegen.cpp b/unified-runtime/examples/codegen/codegen.cpp index 7d45789063682..cd6af53e70649 100644 --- a/unified-runtime/examples/codegen/codegen.cpp +++ b/unified-runtime/examples/codegen/codegen.cpp @@ -64,16 +64,21 @@ get_supported_adapters(std::vector &adapters) { std::vector get_platforms(std::vector &adapters) { uint32_t platformCount = 0; - ur_check(urPlatformGet(adapters.data(), adapters.size(), 1, nullptr, - &platformCount)); + std::vector platforms; + for (auto adapter : adapters) { + uint32_t adapterPlatformCount = 0; + urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount); + platforms.reserve(platformCount + adapterPlatformCount); + urPlatformGet(adapter, adapterPlatformCount, &platforms[platformCount], + &adapterPlatformCount); + platformCount += adapterPlatformCount; + } if (!platformCount) { throw std::runtime_error("No platforms available."); } + platforms.resize(platformCount); - std::vector platforms(platformCount); - ur_check(urPlatformGet(adapters.data(), adapters.size(), platformCount, - platforms.data(), nullptr)); return platforms; } diff --git a/unified-runtime/examples/hello_world/hello_world.cpp b/unified-runtime/examples/hello_world/hello_world.cpp index 1653c0482ae55..f5aec12d25d43 100644 --- a/unified-runtime/examples/hello_world/hello_world.cpp +++ b/unified-runtime/examples/hello_world/hello_world.cpp @@ -47,22 +47,26 @@ int main(int, char *[]) { return 1; } - status = - urPlatformGet(adapters.data(), adapterCount, 1, nullptr, &platformCount); - if (status != UR_RESULT_SUCCESS) { - std::cout << "urPlatformGet failed with return code: " << status - << std::endl; - goto out; - } + for (auto adapter : adapters) { + uint32_t adapterPlatformCount = 0; + status = urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount); + if (status != UR_RESULT_SUCCESS) { + std::cout << "urPlatformGet failed with return code: " << status + << std::endl; + goto out; + } - platforms.resize(platformCount); - status = urPlatformGet(adapters.data(), adapterCount, platformCount, - platforms.data(), nullptr); - if (status != UR_RESULT_SUCCESS) { - std::cout << "urPlatformGet failed with return code: " << status - << std::endl; - goto out; + platforms.reserve(platformCount + adapterPlatformCount); + status = urPlatformGet(adapter, adapterPlatformCount, + &platforms[platformCount], &adapterPlatformCount); + if (status != UR_RESULT_SUCCESS) { + std::cout << "urPlatformGet failed with return code: " << status + << std::endl; + goto out; + } + platformCount += adapterPlatformCount; } + platforms.resize(platformCount); for (auto p : platforms) { ur_api_version_t api_version = {}; diff --git a/unified-runtime/include/ur_api.h b/unified-runtime/include/ur_api.h index a6c046a76aded..120b12c618250 100644 --- a/unified-runtime/include/ur_api.h +++ b/unified-runtime/include/ur_api.h @@ -1448,17 +1448,15 @@ typedef enum ur_adapter_backend_t { /// - ::UR_RESULT_ERROR_UNINITIALIZED /// - ::UR_RESULT_ERROR_DEVICE_LOST /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// + `NULL == phAdapters` +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hAdapter` /// - ::UR_RESULT_ERROR_INVALID_SIZE /// + `NumEntries == 0 && phPlatforms != NULL` /// - ::UR_RESULT_ERROR_INVALID_VALUE /// + `pNumPlatforms == NULL && phPlatforms == NULL` UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, @@ -12772,8 +12770,7 @@ typedef struct ur_loader_config_set_mocking_enabled_params_t { /// @details Each entry is a pointer to the parameter passed to the function; /// allowing the callback the ability to modify the parameter's value typedef struct ur_platform_get_params_t { - ur_adapter_handle_t **pphAdapters; - uint32_t *pNumAdapters; + ur_adapter_handle_t *phAdapter; uint32_t *pNumEntries; ur_platform_handle_t **pphPlatforms; uint32_t **ppNumPlatforms; diff --git a/unified-runtime/include/ur_ddi.h b/unified-runtime/include/ur_ddi.h index 68dc0a265d284..b05f225337ef0 100644 --- a/unified-runtime/include/ur_ddi.h +++ b/unified-runtime/include/ur_ddi.h @@ -25,8 +25,8 @@ extern "C" { /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urPlatformGet -typedef ur_result_t(UR_APICALL *ur_pfnPlatformGet_t)(ur_adapter_handle_t *, - uint32_t, uint32_t, +typedef ur_result_t(UR_APICALL *ur_pfnPlatformGet_t)(ur_adapter_handle_t, + uint32_t, ur_platform_handle_t *, uint32_t *); diff --git a/unified-runtime/include/ur_print.hpp b/unified-runtime/include/ur_print.hpp index 6b8a6e2e5f312..3ba8e1a36496f 100644 --- a/unified-runtime/include/ur_print.hpp +++ b/unified-runtime/include/ur_print.hpp @@ -12383,25 +12383,9 @@ inline std::ostream & operator<<(std::ostream &os, [[maybe_unused]] const struct ur_platform_get_params_t *params) { - os << ".phAdapters = "; - ur::details::printPtr(os, - reinterpret_cast(*(params->pphAdapters))); - if (*(params->pphAdapters) != NULL) { - os << " {"; - for (size_t i = 0; i < *params->pNumAdapters; ++i) { - if (i != 0) { - os << ", "; - } - - ur::details::printPtr(os, (*(params->pphAdapters))[i]); - } - os << "}"; - } - - os << ", "; - os << ".NumAdapters = "; + os << ".hAdapter = "; - os << *(params->pNumAdapters); + ur::details::printPtr(os, *(params->phAdapter)); os << ", "; os << ".NumEntries = "; diff --git a/unified-runtime/scripts/core/PROG.rst b/unified-runtime/scripts/core/PROG.rst index 0e6d23300ee4e..84010942dee52 100644 --- a/unified-runtime/scripts/core/PROG.rst +++ b/unified-runtime/scripts/core/PROG.rst @@ -58,11 +58,16 @@ Initialization and Discovery ${x}AdapterGet(adapterCount, adapters.data(), nullptr); // Discover all the platform instances - uint32_t platformCount = 0; - ${x}PlatformGet(adapters.data(), adapterCount, 0, nullptr, &platformCount); - - std::vector<${x}_platform_handle_t> platforms(platformCount); - ${x}PlatformGet(adapters.data(), adapterCount, platform.size(), platforms.data(), &platformCount); + std::vector<${x}_platform_handle_t> platforms; + uint32_t totalPlatformCount = 0; + for (auto adapter : adapters) { + uint32_t adapterPlatformCount = 0; + ${x}PlatformGet(adapter, 0, nullptr, &adapterPlatformCount); + + platforms.reserve(totalPlatformCount + adapterPlatformCount); + ${x}PlatformGet(adapter, adapterPlatformCount, &platforms[totalPlatformCount], &adapterPlatformCount); + totalPlatformCount += adapterPlatformCount; + } // Get number of total GPU devices in the platform uint32_t deviceCount = 0; diff --git a/unified-runtime/scripts/core/platform.yml b/unified-runtime/scripts/core/platform.yml index 3566d42b7b09c..7d4edf5c0b5c0 100644 --- a/unified-runtime/scripts/core/platform.yml +++ b/unified-runtime/scripts/core/platform.yml @@ -24,12 +24,9 @@ details: - "Multiple calls to this function will return identical platforms handles, in the same order." - "The application may call this function from simultaneous threads, the implementation must be thread-safe" params: - - type: "$x_adapter_handle_t*" - name: "phAdapters" - desc: "[in][range(0, NumAdapters)] array of adapters to query for platforms." - - type: "uint32_t" - name: "NumAdapters" - desc: "[in] number of adapters pointed to by phAdapters" + - type: "$x_adapter_handle_t" + name: "hAdapter" + desc: "[in] adapter to query for platforms." - type: "uint32_t" name: NumEntries desc: | diff --git a/unified-runtime/scripts/templates/ldrddi.cpp.mako b/unified-runtime/scripts/templates/ldrddi.cpp.mako index b0d8ba0e563ad..5fff46d725dd2 100644 --- a/unified-runtime/scripts/templates/ldrddi.cpp.mako +++ b/unified-runtime/scripts/templates/ldrddi.cpp.mako @@ -101,49 +101,39 @@ namespace ur_loader } %elif func_basename == "PlatformGet": - uint32_t total_platform_handle_count = 0; - for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++) - { - // extract adapter's function pointer table - auto dditable = - reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']}[adapter_index])->dditable; + // extract adapter's function pointer table + auto dditable = + reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']})->dditable; - if( ( 0 < ${obj['params'][2]['name']} ) && ( ${obj['params'][2]['name']} == total_platform_handle_count)) - break; + uint32_t library_platform_handle_count = 0; - uint32_t library_platform_handle_count = 0; + result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, 0, nullptr, &library_platform_handle_count ); + if( ${X}_RESULT_SUCCESS != result ) return result; - result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, 0, nullptr, &library_platform_handle_count ); - if( ${X}_RESULT_SUCCESS != result ) break; + if( nullptr != ${obj['params'][2]['name']} && ${obj['params'][1]['name']} !=0) + { + if( library_platform_handle_count > ${obj['params'][1]['name']}) { + library_platform_handle_count = ${obj['params'][1]['name']}; + } + result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, library_platform_handle_count, ${obj['params'][2]['name']}, nullptr ); + if( ${X}_RESULT_SUCCESS != result ) return result; - if( nullptr != ${obj['params'][3]['name']} && ${obj['params'][2]['name']} !=0) + try { - if( total_platform_handle_count + library_platform_handle_count > ${obj['params'][2]['name']}) { - library_platform_handle_count = ${obj['params'][2]['name']} - total_platform_handle_count; - } - result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, library_platform_handle_count, &${obj['params'][3]['name']}[ total_platform_handle_count ], nullptr ); - if( ${X}_RESULT_SUCCESS != result ) break; - - try - { - for( uint32_t i = 0; i < library_platform_handle_count; ++i ) { - uint32_t platform_index = total_platform_handle_count + i; - ${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>( - context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) ); - } - } - catch( std::bad_alloc& ) - { - result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY; + for( uint32_t i = 0; i < library_platform_handle_count; ++i ) { + ${obj['params'][2]['name']}[ i ] = reinterpret_cast<${n}_platform_handle_t>( + context->factories.${n}_platform_factory.getInstance( ${obj['params'][2]['name']}[ i ], dditable ) ); } } - - total_platform_handle_count += library_platform_handle_count; + catch( std::bad_alloc& ) + { + result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } } - if( ${X}_RESULT_SUCCESS == result && ${obj['params'][4]['name']} != nullptr ) - *${obj['params'][4]['name']} = total_platform_handle_count; + if( ${X}_RESULT_SUCCESS == result && ${obj['params'][3]['name']} != nullptr ) + *${obj['params'][3]['name']} = library_platform_handle_count; %else: <%param_replacements={}%> diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index 9fe3deeed7672..aa1c0206b2d47 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ b/unified-runtime/source/adapters/cuda/device.cpp @@ -1285,15 +1285,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; ur_adapter_handle_t AdapterHandle = &adapter; - ur_result_t Result = - urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); + ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) return Result; std::vector Platforms(NumPlatforms); Result = - urPlatformGet(&AdapterHandle, 1, NumPlatforms, Platforms.data(), nullptr); + urPlatformGet(AdapterHandle, NumPlatforms, Platforms.data(), nullptr); if (Result != UR_RESULT_SUCCESS) return Result; diff --git a/unified-runtime/source/adapters/cuda/platform.cpp b/unified-runtime/source/adapters/cuda/platform.cpp index 953c655bedff5..47aa873ca70f8 100644 --- a/unified-runtime/source/adapters/cuda/platform.cpp +++ b/unified-runtime/source/adapters/cuda/platform.cpp @@ -114,7 +114,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( /// Triggers the CUDA Driver initialization (cuInit) the first time, so this /// must be the first PI API called. UR_APIEXPORT ur_result_t UR_APICALL -urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, +urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { try { diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index c2a933e0ee896..0b300ed97221e 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -243,7 +243,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // cuda backend has only one platform containing all devices ur_platform_handle_t platform; ur_adapter_handle_t AdapterHandle = &adapter; - Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr); + Result = urPlatformGet(AdapterHandle, 1, &platform, nullptr); // get the device from the platform ur_device_handle_t Device = platform->Devices[DeviceIndex].get(); diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index ffbcfb5bce917..d53805d206289 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -1182,8 +1182,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; ur_adapter_handle_t AdapterHandle = &adapter; - ur_result_t Result = - urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); + ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) return Result; @@ -1193,7 +1192,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( ur_platform_handle_t Platform = nullptr; - Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr); + Result = urPlatformGet(AdapterHandle, NumPlatforms, &Platform, nullptr); if (Result != UR_RESULT_SUCCESS) return Result; diff --git a/unified-runtime/source/adapters/hip/platform.cpp b/unified-runtime/source/adapters/hip/platform.cpp index fa0b07cc8244a..8fc44ec4b3858 100644 --- a/unified-runtime/source/adapters/hip/platform.cpp +++ b/unified-runtime/source/adapters/hip/platform.cpp @@ -50,7 +50,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName, /// Triggers the HIP Driver initialization (hipInit) the first time, so this /// must be the first UR API called. UR_APIEXPORT ur_result_t UR_APICALL -urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, +urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { try { diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index 60b86c730076d..34a2f6774744f 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -200,7 +200,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // hip backend has only one platform containing all devices ur_platform_handle_t platform; ur_adapter_handle_t AdapterHandle = &adapter; - UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr)); + UR_CHECK_ERROR(urPlatformGet(AdapterHandle, 1, &platform, nullptr)); // get the device from the platform ur_device_handle_t Device = platform->Devices[DeviceIdx].get(); diff --git a/unified-runtime/source/adapters/level_zero/platform.cpp b/unified-runtime/source/adapters/level_zero/platform.cpp index 59ec72cb9d004..1b0c923bdc8bf 100644 --- a/unified-runtime/source/adapters/level_zero/platform.cpp +++ b/unified-runtime/source/adapters/level_zero/platform.cpp @@ -15,7 +15,7 @@ namespace ur::level_zero { ur_result_t urPlatformGet( - ur_adapter_handle_t *, uint32_t, + ur_adapter_handle_t, /// [in] the number of platforms to be added to phPlatforms. If phPlatforms /// is not NULL, then NumEntries should be greater than zero, otherwise /// ::UR_RESULT_ERROR_INVALID_SIZE, will be returned. @@ -141,12 +141,12 @@ ur_result_t urPlatformCreateWithNativeHandle( uint32_t NumPlatforms = 0; ur_adapter_handle_t AdapterHandle = GlobalAdapter; - UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, 0, nullptr, - &NumPlatforms)); + UR_CALL( + ur::level_zero::urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms)); if (NumPlatforms) { std::vector Platforms(NumPlatforms); - UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumPlatforms, + UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumPlatforms, Platforms.data(), nullptr)); // The SYCL spec requires that the set of platforms must remain fixed for diff --git a/unified-runtime/source/adapters/level_zero/queue.cpp b/unified-runtime/source/adapters/level_zero/queue.cpp index 6cedd3e5bfbcd..46e71e2439d2d 100644 --- a/unified-runtime/source/adapters/level_zero/queue.cpp +++ b/unified-runtime/source/adapters/level_zero/queue.cpp @@ -786,8 +786,8 @@ ur_result_t urQueueCreateWithNativeHandle( uint32_t NumEntries = 1; ur_platform_handle_t Platform{}; ur_adapter_handle_t AdapterHandle = GlobalAdapter; - UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumEntries, - &Platform, nullptr)); + UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumEntries, &Platform, + nullptr)); ur_device_handle_t UrDevice = Device; if (UrDevice == nullptr) { diff --git a/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp b/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp index 78eb006d4d2ff..216f79f4afbf3 100644 --- a/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp +++ b/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp @@ -21,8 +21,7 @@ ur_result_t urAdapterGetLastError(ur_adapter_handle_t hAdapter, ur_result_t urAdapterGetInfo(ur_adapter_handle_t hAdapter, ur_adapter_info_t propName, size_t propSize, void *pPropValue, size_t *pPropSizeRet); -ur_result_t urPlatformGet(ur_adapter_handle_t *phAdapters, uint32_t NumAdapters, - uint32_t NumEntries, +ur_result_t urPlatformGet(ur_adapter_handle_t hAdapter, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms); ur_result_t urPlatformGetInfo(ur_platform_handle_t hPlatform, diff --git a/unified-runtime/source/adapters/mock/ur_mockddi.cpp b/unified-runtime/source/adapters/mock/ur_mockddi.cpp index 805b612dd69a5..9632ce986701a 100644 --- a/unified-runtime/source/adapters/mock/ur_mockddi.cpp +++ b/unified-runtime/source/adapters/mock/ur_mockddi.cpp @@ -264,10 +264,8 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urPlatformGet __urdlllocal ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, @@ -281,8 +279,8 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( uint32_t *pNumPlatforms) try { ur_result_t result = UR_RESULT_SUCCESS; - ur_platform_get_params_t params = {&phAdapters, &NumAdapters, &NumEntries, - &phPlatforms, &pNumPlatforms}; + ur_platform_get_params_t params = {&hAdapter, &NumEntries, &phPlatforms, + &pNumPlatforms}; auto beforeCallback = reinterpret_cast( mock::getCallbacks().get_before_callback("urPlatformGet")); diff --git a/unified-runtime/source/adapters/native_cpu/platform.cpp b/unified-runtime/source/adapters/native_cpu/platform.cpp index 8e550370792c7..54edc06b3366a 100644 --- a/unified-runtime/source/adapters/native_cpu/platform.cpp +++ b/unified-runtime/source/adapters/native_cpu/platform.cpp @@ -18,7 +18,7 @@ #include UR_APIEXPORT ur_result_t UR_APICALL -urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, +urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { UR_ASSERT(pNumPlatforms || phPlatforms, UR_RESULT_ERROR_INVALID_VALUE); diff --git a/unified-runtime/source/adapters/opencl/device.cpp b/unified-runtime/source/adapters/opencl/device.cpp index 1d862f5fbf3a8..d497e2b45e3ef 100644 --- a/unified-runtime/source/adapters/opencl/device.cpp +++ b/unified-runtime/source/adapters/opencl/device.cpp @@ -1601,10 +1601,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( cl_device_id NativeHandle = reinterpret_cast(hNativeDevice); uint32_t NumPlatforms = 0; - UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, 0, nullptr, &NumPlatforms)); + UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, nullptr, &NumPlatforms)); std::vector Platforms(NumPlatforms); UR_RETURN_ON_FAILURE( - urPlatformGet(nullptr, 0, NumPlatforms, Platforms.data(), nullptr)); + urPlatformGet(nullptr, NumPlatforms, Platforms.data(), nullptr)); for (uint32_t i = 0; i < NumPlatforms; i++) { uint32_t NumDevices = 0; diff --git a/unified-runtime/source/adapters/opencl/platform.cpp b/unified-runtime/source/adapters/opencl/platform.cpp index 5d58000b197a5..460203543b47a 100644 --- a/unified-runtime/source/adapters/opencl/platform.cpp +++ b/unified-runtime/source/adapters/opencl/platform.cpp @@ -66,7 +66,7 @@ urPlatformGetApiVersion([[maybe_unused]] ur_platform_handle_t hPlatform, } UR_APIEXPORT ur_result_t UR_APICALL -urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, +urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { static std::mutex adapterPopulationMutex{}; ur_adapter_handle_t Adapter = nullptr; @@ -143,10 +143,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( reinterpret_cast(hNativePlatform); uint32_t NumPlatforms = 0; - UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, 0, nullptr, &NumPlatforms)); + UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, nullptr, &NumPlatforms)); std::vector Platforms(NumPlatforms); UR_RETURN_ON_FAILURE( - urPlatformGet(nullptr, 0, NumPlatforms, Platforms.data(), nullptr)); + urPlatformGet(nullptr, NumPlatforms, Platforms.data(), nullptr)); for (uint32_t i = 0; i < NumPlatforms; i++) { if (Platforms[i]->CLPlatform == NativeHandle) { diff --git a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp index 47b804becb41e..e0a617d1b2213 100644 --- a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp +++ b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp @@ -214,10 +214,8 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urPlatformGet __urdlllocal ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, @@ -234,16 +232,15 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( if (nullptr == pfnGet) return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - ur_platform_get_params_t params = {&phAdapters, &NumAdapters, &NumEntries, - &phPlatforms, &pNumPlatforms}; + ur_platform_get_params_t params = {&hAdapter, &NumEntries, &phPlatforms, + &pNumPlatforms}; uint64_t instance = getContext()->notify_begin(UR_FUNCTION_PLATFORM_GET, "urPlatformGet", ¶ms); auto &logger = getContext()->logger; logger.info(" ---> urPlatformGet\n"); - ur_result_t result = - pfnGet(phAdapters, NumAdapters, NumEntries, phPlatforms, pNumPlatforms); + ur_result_t result = pfnGet(hAdapter, NumEntries, phPlatforms, pNumPlatforms); getContext()->notify_end(UR_FUNCTION_PLATFORM_GET, "urPlatformGet", ¶ms, &result, instance); diff --git a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp index 10e630ef1fbbd..bf9f742967fe2 100644 --- a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp +++ b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp @@ -199,10 +199,8 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urPlatformGet __urdlllocal ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, @@ -221,8 +219,8 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( } if (getContext()->enableParameterValidation) { - if (NULL == phAdapters) - return UR_RESULT_ERROR_INVALID_NULL_POINTER; + if (NULL == hAdapter) + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; if (NumEntries == 0 && phPlatforms != NULL) return UR_RESULT_ERROR_INVALID_SIZE; @@ -231,8 +229,12 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( return UR_RESULT_ERROR_INVALID_VALUE; } - ur_result_t result = - pfnGet(phAdapters, NumAdapters, NumEntries, phPlatforms, pNumPlatforms); + if (getContext()->enableLifetimeValidation && + !getContext()->refCountContext->isReferenceValid(hAdapter)) { + getContext()->refCountContext->logInvalidReference(hAdapter); + } + + ur_result_t result = pfnGet(hAdapter, NumEntries, phPlatforms, pNumPlatforms); return result; } diff --git a/unified-runtime/source/loader/ur_ldrddi.cpp b/unified-runtime/source/loader/ur_ldrddi.cpp index 842e93969f3e0..519a4b89e5d37 100644 --- a/unified-runtime/source/loader/ur_ldrddi.cpp +++ b/unified-runtime/source/loader/ur_ldrddi.cpp @@ -189,10 +189,8 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urPlatformGet __urdlllocal ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, @@ -207,55 +205,39 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( ur_result_t result = UR_RESULT_SUCCESS; [[maybe_unused]] auto context = getContext(); - uint32_t total_platform_handle_count = 0; - for (uint32_t adapter_index = 0; adapter_index < NumAdapters; - adapter_index++) { - // extract adapter's function pointer table - auto dditable = - reinterpret_cast(phAdapters[adapter_index]) - ->dditable; + // extract adapter's function pointer table + auto dditable = reinterpret_cast(hAdapter)->dditable; - if ((0 < NumEntries) && (NumEntries == total_platform_handle_count)) - break; + uint32_t library_platform_handle_count = 0; - uint32_t library_platform_handle_count = 0; + result = dditable->ur.Platform.pfnGet(hAdapter, 0, nullptr, + &library_platform_handle_count); + if (UR_RESULT_SUCCESS != result) + return result; - result = - dditable->ur.Platform.pfnGet(&phAdapters[adapter_index], 1, 0, nullptr, - &library_platform_handle_count); + if (nullptr != phPlatforms && NumEntries != 0) { + if (library_platform_handle_count > NumEntries) { + library_platform_handle_count = NumEntries; + } + result = dditable->ur.Platform.pfnGet( + hAdapter, library_platform_handle_count, phPlatforms, nullptr); if (UR_RESULT_SUCCESS != result) - break; - - if (nullptr != phPlatforms && NumEntries != 0) { - if (total_platform_handle_count + library_platform_handle_count > - NumEntries) { - library_platform_handle_count = - NumEntries - total_platform_handle_count; - } - result = dditable->ur.Platform.pfnGet( - &phAdapters[adapter_index], 1, library_platform_handle_count, - &phPlatforms[total_platform_handle_count], nullptr); - if (UR_RESULT_SUCCESS != result) - break; + return result; - try { - for (uint32_t i = 0; i < library_platform_handle_count; ++i) { - uint32_t platform_index = total_platform_handle_count + i; - phPlatforms[platform_index] = reinterpret_cast( - context->factories.ur_platform_factory.getInstance( - phPlatforms[platform_index], dditable)); - } - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + try { + for (uint32_t i = 0; i < library_platform_handle_count; ++i) { + phPlatforms[i] = reinterpret_cast( + context->factories.ur_platform_factory.getInstance(phPlatforms[i], + dditable)); } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } - - total_platform_handle_count += library_platform_handle_count; } if (UR_RESULT_SUCCESS == result && pNumPlatforms != nullptr) - *pNumPlatforms = total_platform_handle_count; + *pNumPlatforms = library_platform_handle_count; return result; } diff --git a/unified-runtime/source/loader/ur_libapi.cpp b/unified-runtime/source/loader/ur_libapi.cpp index 78ae7e5a6364f..a69ef8f785386 100644 --- a/unified-runtime/source/loader/ur_libapi.cpp +++ b/unified-runtime/source/loader/ur_libapi.cpp @@ -502,17 +502,15 @@ ur_result_t UR_APICALL urAdapterGetInfo( /// - ::UR_RESULT_ERROR_UNINITIALIZED /// - ::UR_RESULT_ERROR_DEVICE_LOST /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// + `NULL == phAdapters` +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hAdapter` /// - ::UR_RESULT_ERROR_INVALID_SIZE /// + `NumEntries == 0 && phPlatforms != NULL` /// - ::UR_RESULT_ERROR_INVALID_VALUE /// + `pNumPlatforms == NULL && phPlatforms == NULL` ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, @@ -528,8 +526,7 @@ ur_result_t UR_APICALL urPlatformGet( if (nullptr == pfnGet) return UR_RESULT_ERROR_UNINITIALIZED; - return pfnGet(phAdapters, NumAdapters, NumEntries, phPlatforms, - pNumPlatforms); + return pfnGet(hAdapter, NumEntries, phPlatforms, pNumPlatforms); } catch (...) { return exceptionToResult(std::current_exception()); } diff --git a/unified-runtime/source/ur_api.cpp b/unified-runtime/source/ur_api.cpp index 3ed7b5771f936..a71d65ae64630 100644 --- a/unified-runtime/source/ur_api.cpp +++ b/unified-runtime/source/ur_api.cpp @@ -459,17 +459,15 @@ ur_result_t UR_APICALL urAdapterGetInfo( /// - ::UR_RESULT_ERROR_UNINITIALIZED /// - ::UR_RESULT_ERROR_DEVICE_LOST /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// + `NULL == phAdapters` +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hAdapter` /// - ::UR_RESULT_ERROR_INVALID_SIZE /// + `NumEntries == 0 && phPlatforms != NULL` /// - ::UR_RESULT_ERROR_INVALID_VALUE /// + `pNumPlatforms == NULL && phPlatforms == NULL` ur_result_t UR_APICALL urPlatformGet( - /// [in][range(0, NumAdapters)] array of adapters to query for platforms. - ur_adapter_handle_t *phAdapters, - /// [in] number of adapters pointed to by phAdapters - uint32_t NumAdapters, + /// [in] adapter to query for platforms. + ur_adapter_handle_t hAdapter, /// [in] the number of platforms to be added to phPlatforms. /// If phPlatforms is not NULL, then NumEntries should be greater than /// zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, diff --git a/unified-runtime/test/conformance/platform/urPlatformGet.cpp b/unified-runtime/test/conformance/platform/urPlatformGet.cpp index 3194a4e72961d..22355cb83c399 100644 --- a/unified-runtime/test/conformance/platform/urPlatformGet.cpp +++ b/unified-runtime/test/conformance/platform/urPlatformGet.cpp @@ -6,49 +6,36 @@ #include -struct urPlatformGetTest : ::testing::Test { - std::vector &adapters = - uur::PlatformEnvironment::instance->adapters; -}; +using urPlatformGetTest = uur::urAdapterTest; -TEST_F(urPlatformGetTest, Success) { +UUR_INSTANTIATE_ADAPTER_TEST_SUITE(urPlatformGetTest); + +TEST_P(urPlatformGetTest, Success) { uint32_t count; - ASSERT_SUCCESS(urPlatformGet(adapters.data(), - static_cast(adapters.size()), 0, - nullptr, &count)); + ASSERT_SUCCESS(urPlatformGet(adapter, 0, nullptr, &count)); ASSERT_NE(count, 0); std::vector platforms(count); - ASSERT_SUCCESS(urPlatformGet(adapters.data(), - static_cast(adapters.size()), count, - platforms.data(), nullptr)); + ASSERT_SUCCESS(urPlatformGet(adapter, count, platforms.data(), nullptr)); for (auto platform : platforms) { ASSERT_NE(nullptr, platform); } } -TEST_F(urPlatformGetTest, InvalidNumEntries) { +TEST_P(urPlatformGetTest, InvalidNumEntries) { uint32_t count; - ASSERT_SUCCESS(urPlatformGet(adapters.data(), - static_cast(adapters.size()), 0, - nullptr, &count)); + ASSERT_SUCCESS(urPlatformGet(adapter, 0, nullptr, &count)); std::vector platforms(count); ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE, - urPlatformGet(adapters.data(), - static_cast(adapters.size()), 0, - platforms.data(), nullptr)); + urPlatformGet(adapter, 0, platforms.data(), nullptr)); } -TEST_F(urPlatformGetTest, InvalidNullPointer) { +TEST_P(urPlatformGetTest, InvalidNullPointer) { uint32_t count; - ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER, - urPlatformGet(nullptr, - static_cast(adapters.size()), 0, - nullptr, &count)); + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE, + urPlatformGet(nullptr, 0, nullptr, &count)); } -TEST_F(urPlatformGetTest, NullArgs) { +TEST_P(urPlatformGetTest, NullArgs) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_VALUE, - urPlatformGet(adapters.data(), - static_cast(adapters.size()), 0, - nullptr, nullptr)); + urPlatformGet(adapter, 0, nullptr, nullptr)); } diff --git a/unified-runtime/test/conformance/source/environment.cpp b/unified-runtime/test/conformance/source/environment.cpp index fe7f97782d1c0..adf4e7972d0da 100644 --- a/unified-runtime/test/conformance/source/environment.cpp +++ b/unified-runtime/test/conformance/source/environment.cpp @@ -80,12 +80,12 @@ uur::PlatformEnvironment::PlatformEnvironment() : AdapterEnvironment() { void uur::PlatformEnvironment::populatePlatforms() { for (auto a : adapters) { uint32_t count = 0; - ASSERT_SUCCESS(urPlatformGet(&a, 1, 0, nullptr, &count)); + ASSERT_SUCCESS(urPlatformGet(a, 0, nullptr, &count)); if (count == 0) { continue; } std::vector platform_list(count); - ASSERT_SUCCESS(urPlatformGet(&a, 1, count, platform_list.data(), nullptr)); + ASSERT_SUCCESS(urPlatformGet(a, count, platform_list.data(), nullptr)); for (auto p : platform_list) { platforms.push_back(p); diff --git a/unified-runtime/test/fuzz/urFuzz.cpp b/unified-runtime/test/fuzz/urFuzz.cpp index b4cd10fa0edba..89007cddd8e3f 100644 --- a/unified-runtime/test/fuzz/urFuzz.cpp +++ b/unified-runtime/test/fuzz/urFuzz.cpp @@ -21,12 +21,22 @@ namespace fuzz { constexpr int MAX_VECTOR_SIZE = 1024; int ur_platform_get(TestState &state) { - ur_result_t res = urPlatformGet(state.adapters.data(), state.adapters.size(), - state.num_entries, state.platforms.data(), - &state.num_platforms); - if (res != UR_RESULT_SUCCESS) { - return -1; + uint32_t totalPlatformCount = 0; + for (auto adapter : state.adapters) { + uint32_t adapterPlatformCount = 0; + urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount); + + state.platforms.reserve(totalPlatformCount + adapterPlatformCount); + ur_result_t res = urPlatformGet(adapter, adapterPlatformCount, + &state.platforms[totalPlatformCount], + &adapterPlatformCount); + if (res != UR_RESULT_SUCCESS) { + return -1; + } + + totalPlatformCount += adapterPlatformCount; } + if (state.platforms.size() != state.num_platforms) { state.platforms.resize(state.num_platforms); } diff --git a/unified-runtime/test/layers/sanitizer/asan.cpp b/unified-runtime/test/layers/sanitizer/asan.cpp index 43ddb6c74804d..b0d6cf027ba14 100644 --- a/unified-runtime/test/layers/sanitizer/asan.cpp +++ b/unified-runtime/test/layers/sanitizer/asan.cpp @@ -47,7 +47,7 @@ TEST(DeviceAsan, Initialization) { } ur_platform_handle_t platform; - status = urPlatformGet(&adapter, 1, 1, &platform, nullptr); + status = urPlatformGet(adapter, 1, &platform, nullptr); ASSERT_EQ(status, UR_RESULT_SUCCESS); ur_device_handle_t device; @@ -109,7 +109,7 @@ TEST(DeviceAsan, UnsupportedFeature) { } ur_platform_handle_t platform; - status = urPlatformGet(&adapter, 1, 1, &platform, nullptr); + status = urPlatformGet(adapter, 1, &platform, nullptr); ASSERT_EQ(status, UR_RESULT_SUCCESS); ur_device_handle_t device; diff --git a/unified-runtime/test/layers/tracing/hello_world.out.logged.match b/unified-runtime/test/layers/tracing/hello_world.out.logged.match index 9e98f457c7ae8..24d900e56383d 100644 --- a/unified-runtime/test/layers/tracing/hello_world.out.logged.match +++ b/unified-runtime/test/layers/tracing/hello_world.out.logged.match @@ -4,9 +4,9 @@ Platform initialized. ---> urAdapterGet <--- urAdapterGet(.NumEntries = 1, .phAdapters = {{.*}}, .pNumAdapters = nullptr) -> UR_RESULT_SUCCESS; ---> urPlatformGet - <--- urPlatformGet(.phAdapters = {{.*}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = {{.*}}, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; + <--- urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = {{.*}}, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; ---> urPlatformGet - <--- urPlatformGet(.phAdapters = {{.*}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = {{.*}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; + <--- urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = {{.*}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; ---> urPlatformGetApiVersion <--- urPlatformGetApiVersion(.hPlatform = {{.*}}, .pVersion = {{.*}} ({{0\.[0-9]+}})) -> UR_RESULT_SUCCESS; API version: {{0\.[0-9]+}} diff --git a/unified-runtime/test/layers/validation/fixtures.hpp b/unified-runtime/test/layers/validation/fixtures.hpp index d712cabc6df6b..9826ab08639bd 100644 --- a/unified-runtime/test/layers/validation/fixtures.hpp +++ b/unified-runtime/test/layers/validation/fixtures.hpp @@ -72,17 +72,14 @@ struct valPlatformsTest : valAdaptersTest { void SetUp() override { valAdaptersTest::SetUp(); - uint32_t count; - ASSERT_EQ(urPlatformGet(adapters.data(), - static_cast(adapters.size()), 0, nullptr, - &count), - UR_RESULT_SUCCESS); - ASSERT_GT(count, 0); - platforms.resize(count); - ASSERT_EQ(urPlatformGet(adapters.data(), - static_cast(adapters.size()), count, - platforms.data(), nullptr), - UR_RESULT_SUCCESS); + for (auto adapter : adapters) { + uint32_t count; + ASSERT_EQ(urPlatformGet(adapter, 0, nullptr, &count), UR_RESULT_SUCCESS); + ASSERT_GT(count, 0); + platforms.resize(count); + ASSERT_EQ(urPlatformGet(adapter, count, platforms.data(), nullptr), + UR_RESULT_SUCCESS); + } } std::vector platforms; diff --git a/unified-runtime/test/loader/handles/fixtures.hpp b/unified-runtime/test/loader/handles/fixtures.hpp index d2eaab13f3543..4f07812e9993c 100644 --- a/unified-runtime/test/loader/handles/fixtures.hpp +++ b/unified-runtime/test/loader/handles/fixtures.hpp @@ -56,7 +56,7 @@ struct LoaderHandleTest : ::testing::Test { ASSERT_NE(adapter, nullptr); uint32_t nplatforms = 0; platform = nullptr; - ASSERT_SUCCESS(urPlatformGet(&adapter, 1, 1, &platform, &nplatforms)); + ASSERT_SUCCESS(urPlatformGet(adapter, 1, &platform, &nplatforms)); ASSERT_NE(platform, nullptr); uint32_t ndevices; device = nullptr; diff --git a/unified-runtime/test/loader/platforms/platforms.cpp b/unified-runtime/test/loader/platforms/platforms.cpp index bbadbbcb199c9..c80e671c96747 100644 --- a/unified-runtime/test/loader/platforms/platforms.cpp +++ b/unified-runtime/test/loader/platforms/platforms.cpp @@ -49,22 +49,25 @@ int main(int, char *[]) { uint32_t platformCount = 0; std::vector platforms; + for (auto adapter : adapters) { + uint32_t adapterPlatformCount = 0; + status = urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount); + if (status != UR_RESULT_SUCCESS) { + out.error("urPlatformGet failed with return code: {}", status); + goto out; + } + out.info("urPlatformGet found {} platforms", platformCount); - status = - urPlatformGet(adapters.data(), adapterCount, 1, nullptr, &platformCount); - if (status != UR_RESULT_SUCCESS) { - out.error("urPlatformGet failed with return code: {}", status); - goto out; + platforms.reserve(platformCount + adapterPlatformCount); + status = urPlatformGet(adapter, adapterPlatformCount, + &platforms[platformCount], &adapterPlatformCount); + if (status != UR_RESULT_SUCCESS) { + out.error("urPlatformGet failed with return code: {}", status); + goto out; + } + platformCount += adapterPlatformCount; } - out.info("urPlatformGet found {} platforms", platformCount); - platforms.resize(platformCount); - status = urPlatformGet(adapters.data(), adapterCount, platformCount, - platforms.data(), nullptr); - if (status != UR_RESULT_SUCCESS) { - out.error("urPlatformGet failed with return code: {}", status); - goto out; - } for (auto p : platforms) { size_t name_len; diff --git a/unified-runtime/test/mock/mock.cpp b/unified-runtime/test/mock/mock.cpp index 1a24cda24ec78..68032f930ad81 100644 --- a/unified-runtime/test/mock/mock.cpp +++ b/unified-runtime/test/mock/mock.cpp @@ -36,8 +36,7 @@ TEST(Mock, DefaultBehavior) { ur_device_handle_t device = nullptr; ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); - ASSERT_EQ(urPlatformGet(&adapter, 1, 1, &platform, nullptr), - UR_RESULT_SUCCESS); + ASSERT_EQ(urPlatformGet(adapter, 1, &platform, nullptr), UR_RESULT_SUCCESS); ASSERT_EQ(urDeviceGet(platform, UR_DEVICE_TYPE_ALL, 1, &device, nullptr), UR_RESULT_SUCCESS); diff --git a/unified-runtime/test/tools/urtrace/mock_hello.match b/unified-runtime/test/tools/urtrace/mock_hello.match index a0af2152a77bd..f47da33648a2b 100644 --- a/unified-runtime/test/tools/urtrace/mock_hello.match +++ b/unified-runtime/test/tools/urtrace/mock_hello.match @@ -1,8 +1,8 @@ Platform initialized. urAdapterGet(.NumEntries = 0, .phAdapters = nullptr, .pNumAdapters = {{.*}} (1)) -> UR_RESULT_SUCCESS; urAdapterGet(.NumEntries = 1, .phAdapters = {{.*}} {{{.*}}}, .pNumAdapters = nullptr) -> UR_RESULT_SUCCESS; -urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; -urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = {{.*}} {{{.*}}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; +urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; +urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = {{.*}} {{{.*}}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; urPlatformGetApiVersion(.hPlatform = {{.*}}, .pVersion = {{.*}} ({{.*}})) -> UR_RESULT_SUCCESS; API version: {{.*}} urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 0, .phDevices = nullptr, .pNumDevices = {{.*}} (1)) -> UR_RESULT_SUCCESS; diff --git a/unified-runtime/test/tools/urtrace/mock_hello_begin.match b/unified-runtime/test/tools/urtrace/mock_hello_begin.match index 318abc7eec46e..536d9561c95b5 100644 --- a/unified-runtime/test/tools/urtrace/mock_hello_begin.match +++ b/unified-runtime/test/tools/urtrace/mock_hello_begin.match @@ -3,10 +3,10 @@ begin(1) - urAdapterGet(.NumEntries = 0, .phAdapters = nullptr, .pNumAdapters = end(1) - urAdapterGet(.NumEntries = 0, .phAdapters = nullptr, .pNumAdapters = {{.*}} (1)) -> UR_RESULT_SUCCESS; begin(2) - urAdapterGet(.NumEntries = 1, .phAdapters = {{.*}} {{{.*}}}, .pNumAdapters = nullptr); end(2) - urAdapterGet(.NumEntries = 1, .phAdapters = {{.*}} {{{.*}}}, .pNumAdapters = nullptr) -> UR_RESULT_SUCCESS; -begin(3) - urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (0)); -end(3) - urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; -begin(4) - urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = {{.*}} {nullptr}, .pNumPlatforms = nullptr); -end(4) - urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = {{.*}} {{{.*}}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; +begin(3) - urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (0)); +end(3) - urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; +begin(4) - urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = {{.*}} {nullptr}, .pNumPlatforms = nullptr); +end(4) - urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = {{.*}} {{{.*}}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; begin(5) - urPlatformGetApiVersion(.hPlatform = {{.*}}, .pVersion = {{.*}} (0.0)); end(5) - urPlatformGetApiVersion(.hPlatform = {{.*}}, .pVersion = {{.*}} (@PROJECT_VERSION_MAJOR@.@PROJECT_VERSION_MINOR@)) -> UR_RESULT_SUCCESS; API version: {{.*}} diff --git a/unified-runtime/test/tools/urtrace/mock_hello_profiling.match b/unified-runtime/test/tools/urtrace/mock_hello_profiling.match index 0ab658a3ebd3d..0ebfd67955b55 100644 --- a/unified-runtime/test/tools/urtrace/mock_hello_profiling.match +++ b/unified-runtime/test/tools/urtrace/mock_hello_profiling.match @@ -1,8 +1,8 @@ Platform initialized. urAdapterGet(.NumEntries = 0, .phAdapters = nullptr, .pNumAdapters = {{.*}} (1)) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) urAdapterGet(.NumEntries = 1, .phAdapters = {{.*}} {{{.*}}}, .pNumAdapters = nullptr) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) -urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) -urPlatformGet(.phAdapters = {{.*}} {{{.*}}}, .NumAdapters = 1, .NumEntries = 1, .phPlatforms = {{.*}} {{{.*}}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) +urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = nullptr, .pNumPlatforms = {{.*}} (1)) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) +urPlatformGet(.hAdapter = {{.*}}, .NumEntries = 1, .phPlatforms = {{.*}} {{{.*}}}, .pNumPlatforms = nullptr) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) urPlatformGetApiVersion(.hPlatform = {{.*}}, .pVersion = {{.*}} ({{.*}})) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) API version: {{.*}} urDeviceGet(.hPlatform = {{.*}}, .DeviceType = UR_DEVICE_TYPE_GPU, .NumEntries = 0, .phDevices = nullptr, .pNumDevices = {{.*}} (1)) -> UR_RESULT_SUCCESS; ({{[0-9]+}}ns) diff --git a/unified-runtime/test/unit/print.h b/unified-runtime/test/unit/print.h index 70a90ab6472e3..721756e2ed221 100644 --- a/unified-runtime/test/unit/print.h +++ b/unified-runtime/test/unit/print.h @@ -45,19 +45,16 @@ struct UrLoaderInitParamsInvalidFlags : UrLoaderInitParams { struct UrPlatformGet { ur_platform_get_params_t params; - uint32_t num_adapters; - ur_adapter_handle_t *phAdapters; + ur_adapter_handle_t adapter; uint32_t num_entries; uint32_t *pNumPlatforms; ur_platform_handle_t *pPlatforms; UrPlatformGet() { - num_adapters = 0; - phAdapters = nullptr; + adapter = nullptr; num_entries = 0; pPlatforms = nullptr; pNumPlatforms = nullptr; - params.pNumAdapters = &num_adapters; - params.pphAdapters = &phAdapters; + params.phAdapter = &adapter; params.pNumEntries = &num_entries; params.pphPlatforms = &pPlatforms; params.ppNumPlatforms = &pNumPlatforms; diff --git a/unified-runtime/tools/urinfo/urinfo.cpp b/unified-runtime/tools/urinfo/urinfo.cpp index c86d166034bdd..5076077c933ed 100644 --- a/unified-runtime/tools/urinfo/urinfo.cpp +++ b/unified-runtime/tools/urinfo/urinfo.cpp @@ -98,12 +98,12 @@ devices which are currently visible in the local execution environment. auto adapter = adapters[adapterIndex]; // Enumerate platforms uint32_t numPlatforms = 0; - UR_CHECK(urPlatformGet(&adapter, 1, 0, nullptr, &numPlatforms)); + UR_CHECK(urPlatformGet(adapter, 0, nullptr, &numPlatforms)); if (numPlatforms == 0) { continue; } adapterPlatformsMap[adapter].resize(numPlatforms); - UR_CHECK(urPlatformGet(&adapter, 1, numPlatforms, + UR_CHECK(urPlatformGet(adapter, numPlatforms, adapterPlatformsMap[adapter].data(), nullptr)); for (size_t platformIndex = 0; platformIndex < numPlatforms;