@@ -886,6 +886,79 @@ namespace dpct
886886 return -1;
887887 }
888888
889+ inline std::string get_preferred_gpu_platform_name() {
890+ std::string result;
891+
892+ std::string filter = " " ;
893+ char* env = getenv(" ONEAPI_DEVICE_SELECTOR" );
894+ if (env) {
895+ if (std::strstr(env, " level_zero" )) {
896+ filter = " level-zero" ;
897+ }
898+ else if (std::strstr(env, " opencl" )) {
899+ filter = " opencl" ;
900+ }
901+ else if (std::strstr(env, " cuda" )) {
902+ filter = " cuda" ;
903+ }
904+ else if (std::strstr(env, " hip" )) {
905+ filter = " hip" ;
906+ }
907+ else {
908+ throw std::runtime_error(" invalid device filter: " + std::string(env));
909+ }
910+ } else {
911+ auto default_device = sycl::device(sycl::default_selector_v);
912+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
913+
914+ if (std::strstr(default_platform_name.c_str(), " Level-Zero" ) || default_device.is_cpu()) {
915+ filter = " level-zero" ;
916+ }
917+ else if (std::strstr(default_platform_name.c_str(), " CUDA" )) {
918+ filter = " cuda" ;
919+ }
920+ else if (std::strstr(default_platform_name.c_str(), " HIP" )) {
921+ filter = " hip" ;
922+ }
923+ }
924+
925+ auto platform_list = sycl::platform::get_platforms();
926+
927+ for (const auto& platform : platform_list) {
928+ auto devices = platform.get_devices();
929+ auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
930+ return d.is_gpu();
931+ });
932+
933+ if (gpu_dev == devices.end()) {
934+ // cout << " platform [" << platform_name
935+ // << " ] does not contain GPU devices, skipping\n" ;
936+ continue;
937+ }
938+
939+ auto platform_name = platform.get_info<sycl::info::platform::name>();
940+ std::string platform_name_low_case;
941+ platform_name_low_case.resize(platform_name.size());
942+
943+ std::transform(
944+ platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
945+
946+ if (platform_name_low_case.find(filter) == std::string::npos) {
947+ // cout << " platform [" << platform_name
948+ // << " ] does not match with requested "
949+ // << filter << " , skipping\n" ;
950+ continue;
951+ }
952+
953+ result = platform_name;
954+ }
955+
956+ if (result.empty())
957+ throw std::runtime_error(" can not find preferred GPU platform" );
958+
959+ return result;
960+ }
961+
889962 template <class DeviceSelector>
890963 std::enable_if_t<
891964 std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
0 commit comments