Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sycl/source/detail/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ class Adapter {
std::vector<ur_platform_handle_t> &getUrPlatforms() {
std::call_once(PlatformsPopulated, [&]() {
uint32_t platformCount = 0;
call<UrApiKind::urPlatformGet>(&MAdapter, 1, 0, nullptr, &platformCount);
call<UrApiKind::urPlatformGet>(MAdapter, 0, nullptr, &platformCount);
UrPlatforms.resize(platformCount);
if (platformCount) {
call<UrApiKind::urPlatformGet>(&MAdapter, 1, platformCount,
call<UrApiKind::urPlatformGet>(MAdapter, platformCount,
UrPlatforms.data(), nullptr);
}
// We need one entry in this per platform
Expand Down
15 changes: 10 additions & 5 deletions unified-runtime/examples/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,21 @@ get_supported_adapters(std::vector<ur_adapter_handle_t> &adapters) {
std::vector<ur_platform_handle_t>
get_platforms(std::vector<ur_adapter_handle_t> &adapters) {
uint32_t platformCount = 0;
ur_check(urPlatformGet(adapters.data(), adapters.size(), 1, nullptr,
&platformCount));
std::vector<ur_platform_handle_t> platforms;
for (auto adapter : adapters) {
uint32_t adapterPlatformCount = 0;
urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount);

platforms.reserve(platformCount + adapterPlatformCount);
urPlatformGet(adapter, adapterPlatformCount, &platforms[platformCount],
&adapterPlatformCount);
platformCount += adapterPlatformCount;
}
if (!platformCount) {
throw std::runtime_error("No platforms available.");
}
platforms.resize(platformCount);
Comment on lines +72 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is UB, the access to platforms[platformCount] is out of bounds, because platforms.size() is always 0 here. This should be using resize() and not reserve().
In practice the later .resize() will reset all platforms to nullptr and make the example crash. See #18032


std::vector<ur_platform_handle_t> platforms(platformCount);
ur_check(urPlatformGet(adapters.data(), adapters.size(), platformCount,
platforms.data(), nullptr));
return platforms;
}

Expand Down
32 changes: 18 additions & 14 deletions unified-runtime/examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,26 @@ int main(int, char *[]) {
return 1;
}

status =
urPlatformGet(adapters.data(), adapterCount, 1, nullptr, &platformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}
for (auto adapter : adapters) {
uint32_t adapterPlatformCount = 0;
status = urPlatformGet(adapter, 0, nullptr, &adapterPlatformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}

platforms.resize(platformCount);
status = urPlatformGet(adapters.data(), adapterCount, platformCount,
platforms.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
platforms.reserve(platformCount + adapterPlatformCount);
status = urPlatformGet(adapter, adapterPlatformCount,
&platforms[platformCount], &adapterPlatformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}
platformCount += adapterPlatformCount;
}
platforms.resize(platformCount);

for (auto p : platforms) {
ur_api_version_t api_version = {};
Expand Down
13 changes: 5 additions & 8 deletions unified-runtime/include/ur_api.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions unified-runtime/include/ur_ddi.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 2 additions & 18 deletions unified-runtime/include/ur_print.hpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 10 additions & 5 deletions unified-runtime/scripts/core/PROG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ Initialization and Discovery
${x}AdapterGet(adapterCount, adapters.data(), nullptr);

// Discover all the platform instances
uint32_t platformCount = 0;
${x}PlatformGet(adapters.data(), adapterCount, 0, nullptr, &platformCount);

std::vector<${x}_platform_handle_t> platforms(platformCount);
${x}PlatformGet(adapters.data(), adapterCount, platform.size(), platforms.data(), &platformCount);
std::vector<${x}_platform_handle_t> platforms;
uint32_t totalPlatformCount = 0;
for (auto adapter : adapters) {
uint32_t adapterPlatformCount = 0;
${x}PlatformGet(adapter, 0, nullptr, &adapterPlatformCount);

platforms.reserve(totalPlatformCount + adapterPlatformCount);
${x}PlatformGet(adapter, adapterPlatformCount, &platforms[totalPlatformCount], &adapterPlatformCount);
totalPlatformCount += adapterPlatformCount;
}

// Get number of total GPU devices in the platform
uint32_t deviceCount = 0;
Expand Down
9 changes: 3 additions & 6 deletions unified-runtime/scripts/core/platform.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ details:
- "Multiple calls to this function will return identical platforms handles, in the same order."
- "The application may call this function from simultaneous threads, the implementation must be thread-safe"
params:
- type: "$x_adapter_handle_t*"
name: "phAdapters"
desc: "[in][range(0, NumAdapters)] array of adapters to query for platforms."
- type: "uint32_t"
name: "NumAdapters"
desc: "[in] number of adapters pointed to by phAdapters"
- type: "$x_adapter_handle_t"
name: "hAdapter"
desc: "[in] adapter to query for platforms."
- type: "uint32_t"
name: NumEntries
desc: |
Expand Down
56 changes: 23 additions & 33 deletions unified-runtime/scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -101,49 +101,39 @@ namespace ur_loader
}

%elif func_basename == "PlatformGet":
uint32_t total_platform_handle_count = 0;

for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
{
// extract adapter's function pointer table
auto dditable =
reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']}[adapter_index])->dditable;
// extract adapter's function pointer table
auto dditable =
reinterpret_cast<${n}_platform_object_t *>( ${obj['params'][0]['name']})->dditable;

if( ( 0 < ${obj['params'][2]['name']} ) && ( ${obj['params'][2]['name']} == total_platform_handle_count))
break;
uint32_t library_platform_handle_count = 0;

uint32_t library_platform_handle_count = 0;
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, 0, nullptr, &library_platform_handle_count );
if( ${X}_RESULT_SUCCESS != result ) return result;

result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, 0, nullptr, &library_platform_handle_count );
if( ${X}_RESULT_SUCCESS != result ) break;
if( nullptr != ${obj['params'][2]['name']} && ${obj['params'][1]['name']} !=0)
{
if( library_platform_handle_count > ${obj['params'][1]['name']}) {
library_platform_handle_count = ${obj['params'][1]['name']};
}
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( ${obj['params'][0]['name']}, library_platform_handle_count, ${obj['params'][2]['name']}, nullptr );
if( ${X}_RESULT_SUCCESS != result ) return result;

if( nullptr != ${obj['params'][3]['name']} && ${obj['params'][2]['name']} !=0)
try
{
if( total_platform_handle_count + library_platform_handle_count > ${obj['params'][2]['name']}) {
library_platform_handle_count = ${obj['params'][2]['name']} - total_platform_handle_count;
}
result = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &${obj['params'][0]['name']}[adapter_index], 1, library_platform_handle_count, &${obj['params'][3]['name']}[ total_platform_handle_count ], nullptr );
if( ${X}_RESULT_SUCCESS != result ) break;

try
{
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
uint32_t platform_index = total_platform_handle_count + i;
${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>(
context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
}
}
catch( std::bad_alloc& )
{
result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY;
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
${obj['params'][2]['name']}[ i ] = reinterpret_cast<${n}_platform_handle_t>(
context->factories.${n}_platform_factory.getInstance( ${obj['params'][2]['name']}[ i ], dditable ) );
}
}

total_platform_handle_count += library_platform_handle_count;
catch( std::bad_alloc& )
{
result = ${X}_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
}

if( ${X}_RESULT_SUCCESS == result && ${obj['params'][4]['name']} != nullptr )
*${obj['params'][4]['name']} = total_platform_handle_count;
if( ${X}_RESULT_SUCCESS == result && ${obj['params'][3]['name']} != nullptr )
*${obj['params'][3]['name']} = library_platform_handle_count;

%else:
<%param_replacements={}%>
Expand Down
5 changes: 2 additions & 3 deletions unified-runtime/source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1285,15 +1285,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
ur_result_t Result =
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
if (Result != UR_RESULT_SUCCESS)
return Result;

std::vector<ur_platform_handle_t> Platforms(NumPlatforms);

Result =
urPlatformGet(&AdapterHandle, 1, NumPlatforms, Platforms.data(), nullptr);
urPlatformGet(AdapterHandle, NumPlatforms, Platforms.data(), nullptr);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
/// Triggers the CUDA Driver initialization (cuInit) the first time, so this
/// must be the first PI API called.
UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

try {
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
// cuda backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
Result = urPlatformGet(AdapterHandle, 1, &platform, nullptr);

// get the device from the platform
ur_device_handle_t Device = platform->Devices[DeviceIndex].get();
Expand Down
5 changes: 2 additions & 3 deletions unified-runtime/source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1182,8 +1182,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
ur_result_t Result =
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand All @@ -1193,7 +1192,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(

ur_platform_handle_t Platform = nullptr;

Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr);
Result = urPlatformGet(AdapterHandle, NumPlatforms, &Platform, nullptr);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/hip/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
/// Triggers the HIP Driver initialization (hipInit) the first time, so this
/// must be the first UR API called.
UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

try {
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
// hip backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr));
UR_CHECK_ERROR(urPlatformGet(AdapterHandle, 1, &platform, nullptr));

// get the device from the platform
ur_device_handle_t Device = platform->Devices[DeviceIdx].get();
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace ur::level_zero {

ur_result_t urPlatformGet(
ur_adapter_handle_t *, uint32_t,
ur_adapter_handle_t,
/// [in] the number of platforms to be added to phPlatforms. If phPlatforms
/// is not NULL, then NumEntries should be greater than zero, otherwise
/// ::UR_RESULT_ERROR_INVALID_SIZE, will be returned.
Expand Down Expand Up @@ -141,12 +141,12 @@ ur_result_t urPlatformCreateWithNativeHandle(

uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, 0, nullptr,
&NumPlatforms));
UR_CALL(
ur::level_zero::urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms));

if (NumPlatforms) {
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumPlatforms,
UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumPlatforms,
Platforms.data(), nullptr));

// The SYCL spec requires that the set of platforms must remain fixed for
Expand Down
4 changes: 2 additions & 2 deletions unified-runtime/source/adapters/level_zero/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,8 @@ ur_result_t urQueueCreateWithNativeHandle(
uint32_t NumEntries = 1;
ur_platform_handle_t Platform{};
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
UR_CALL(ur::level_zero::urPlatformGet(&AdapterHandle, 1, NumEntries,
&Platform, nullptr));
UR_CALL(ur::level_zero::urPlatformGet(AdapterHandle, NumEntries, &Platform,
nullptr));

ur_device_handle_t UrDevice = Device;
if (UrDevice == nullptr) {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading