@@ -29,6 +29,28 @@ class platform;
2929
3030namespace detail {
3131
32+ // Note that UR's enums have weird *_FORCE_UINT32 values, we ignore them in the
33+ // callers. But we also can't write a fully-covered switch without mentioning it
34+ // there, which wouldn't make any sense. As such, ensure that "real" values
35+ // match and then just `static_cast` them (in the caller).
36+ template <typename T0, typename T1>
37+ constexpr bool enums_match (std::initializer_list<T0> l0,
38+ std::initializer_list<T1> l1) {
39+ using U0 = std::underlying_type_t <T0>;
40+ using U1 = std::underlying_type_t <T1>;
41+ using C = std::common_type_t <U0, U1>;
42+ // std::equal isn't constexpr until C++20.
43+ if (l0.size () != l1.size ())
44+ return false ;
45+ auto i0 = l0.begin ();
46+ auto e = l0.end ();
47+ auto i1 = l1.begin ();
48+ for (; i0 != e; ++i0, ++i1)
49+ if (static_cast <C>(*i0) != static_cast <C>(*i1))
50+ return false ;
51+ return true ;
52+ }
53+
3254// Forward declaration
3355class platform_impl ;
3456
@@ -208,6 +230,30 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
208230
209231 // device_traits.def
210232
233+ CASE (info::device::device_type) {
234+ using device_type = info::device_type;
235+ switch (get_info_impl<ur_device_type_t , UR_DEVICE_INFO_TYPE>()) {
236+ case UR_DEVICE_TYPE_DEFAULT:
237+ return device_type::automatic;
238+ case UR_DEVICE_TYPE_ALL:
239+ return device_type::all;
240+ case UR_DEVICE_TYPE_GPU:
241+ return device_type::gpu;
242+ case UR_DEVICE_TYPE_CPU:
243+ return device_type::cpu;
244+ case UR_DEVICE_TYPE_FPGA:
245+ return device_type::accelerator;
246+ case UR_DEVICE_TYPE_MCA:
247+ case UR_DEVICE_TYPE_VPU:
248+ return device_type::custom;
249+ default : {
250+ assert (false );
251+ // FIXME: what is that???
252+ return device_type::custom;
253+ }
254+ }
255+ }
256+
211257 CASE (info::device::max_work_item_sizes<3 >) {
212258 auto result = get_info_impl<std::array<size_t , 3 >,
213259 UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES>();
@@ -242,24 +288,46 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
242288 return get_fp_config<UR_DEVICE_INFO_DOUBLE_FP_CONFIG>();
243289 }
244290
291+ CASE (info::device::global_mem_cache_type) {
292+ using cache = info::global_mem_cache_type;
293+ static_assert (
294+ enums_match ({UR_DEVICE_MEM_CACHE_TYPE_NONE,
295+ UR_DEVICE_MEM_CACHE_TYPE_READ_ONLY_CACHE,
296+ UR_DEVICE_MEM_CACHE_TYPE_READ_WRITE_CACHE},
297+ {cache::none, cache::read_only, cache::read_write}));
298+ return static_cast <cache>(
299+ get_info_impl<ur_device_mem_cache_type_t ,
300+ UR_DEVICE_INFO_GLOBAL_MEM_CACHE_TYPE>());
301+ }
302+
303+ CASE (info::device::local_mem_type) {
304+ using mem = info::local_mem_type;
305+ static_assert (enums_match ({UR_DEVICE_LOCAL_MEM_TYPE_NONE,
306+ UR_DEVICE_LOCAL_MEM_TYPE_LOCAL,
307+ UR_DEVICE_LOCAL_MEM_TYPE_GLOBAL},
308+ {mem::none, mem::local, mem::global}));
309+ return static_cast <mem>(get_info_impl<ur_device_local_mem_type_t ,
310+ UR_DEVICE_INFO_LOCAL_MEM_TYPE>());
311+ }
312+
245313 CASE (info::device::atomic_memory_order_capabilities) {
246314 return readMemoryOrderBitfield (
247- get_info_impl<ur_memory_order_capability_flag_t ,
315+ get_info_impl<ur_memory_order_capability_flags_t ,
248316 UR_DEVICE_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES>());
249317 }
250318 CASE (info::device::atomic_fence_order_capabilities) {
251319 return readMemoryOrderBitfield (
252- get_info_impl<ur_memory_order_capability_flag_t ,
320+ get_info_impl<ur_memory_order_capability_flags_t ,
253321 UR_DEVICE_INFO_ATOMIC_FENCE_ORDER_CAPABILITIES>());
254322 }
255323 CASE (info::device::atomic_memory_scope_capabilities) {
256324 return readMemoryScopeBitfield (
257- get_info_impl<size_t ,
325+ get_info_impl<ur_memory_scope_capability_flags_t ,
258326 UR_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES>());
259327 }
260328 CASE (info::device::atomic_fence_scope_capabilities) {
261329 return readMemoryScopeBitfield (
262- get_info_impl<size_t ,
330+ get_info_impl<ur_memory_scope_capability_flags_t ,
263331 UR_DEVICE_INFO_ATOMIC_FENCE_SCOPE_CAPABILITIES>());
264332 }
265333
@@ -269,8 +337,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
269337 " info::device::execution_capabilities is available for "
270338 " backend::opencl only" );
271339
272- ur_device_exec_capability_flag_t bits =
273- get_info_impl<ur_device_exec_capability_flag_t ,
340+ ur_device_exec_capability_flags_t bits =
341+ get_info_impl<ur_device_exec_capability_flags_t ,
274342 UR_DEVICE_INFO_EXECUTION_CAPABILITIES>();
275343 std::vector<info::execution_capability> result;
276344 if (bits & UR_DEVICE_EXEC_CAPABILITY_FLAG_KERNEL)
@@ -593,6 +661,12 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
593661 return get_matrix_combinations ();
594662 }
595663
664+ CASE (ext::oneapi::experimental::info::device::mipmap_max_anisotropy) {
665+ // Implicit conversion:
666+ return get_info_impl<uint32_t ,
667+ UR_DEVICE_INFO_MIPMAP_MAX_ANISOTROPY_EXP>();
668+ }
669+
596670 CASE (ext::oneapi::experimental::info::device::component_devices) {
597671 auto Devs = get_info_impl_nocheck<std::vector<ur_device_handle_t >,
598672 UR_DEVICE_INFO_COMPONENT_DEVICES>();
@@ -628,6 +702,10 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
628702 " A component with aspect::ext_oneapi_is_component "
629703 " must have a composite device." );
630704 }
705+ CASE (ext::oneapi::info::device::num_compute_units) {
706+ // uint32_t -> size_t
707+ return get_info_impl<uint32_t , UR_DEVICE_INFO_NUM_COMPUTE_UNITS>();
708+ }
631709
632710 // ext_intel_device_traits.def
633711
@@ -718,6 +796,11 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
718796 " The device does not have the ext_intel_memory_bus_width aspect" );
719797 return get_info_impl<uint32_t , UR_DEVICE_INFO_MEMORY_BUS_WIDTH>();
720798 }
799+ CASE (ext::intel::info::device::max_compute_queue_indices) {
800+ // uint32_t->int implicit conversion.
801+ return get_info_impl<uint32_t ,
802+ UR_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES>();
803+ }
721804 CASE (ext::intel::esimd::info::device::has_2d_block_io_support) {
722805 if (!has (aspect::ext_intel_esimd))
723806 return false ;
@@ -904,6 +987,17 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
904987 MDevice, Desc, 0 , nullptr , &return_size) == UR_RESULT_SUCCESS;
905988 }
906989
990+ template <ur_device_info_t Desc> static constexpr auto ur_ret_type_impl () {
991+ if constexpr (false ) {
992+ }
993+ #define MAP (VALUE, ...) else if constexpr (Desc == VALUE) return __VA_ARGS__{};
994+ #include " ur_device_info_ret_types.inc"
995+ #undef MAP
996+ }
997+
998+ template <ur_device_info_t Desc>
999+ using ur_ret_type = decltype (ur_ret_type_impl<Desc>());
1000+
9071001 // This should really be
9081002 // std::expected<ReturnT, ur_result_t>
9091003 // but we don't have C++23. Emulate close enough with as little code as
@@ -932,62 +1026,70 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
9321026
9331027 template <typename ReturnT, ur_device_info_t Desc>
9341028 expected<ReturnT, ur_result_t > get_info_impl_nocheck () const {
1029+ using ur_ret_t = ur_ret_type<Desc>;
9351030 static_assert (!std::is_same_v<ReturnT, std::string>,
9361031 " Wasn't needed before." );
9371032 if constexpr (std::is_same_v<ReturnT, bool >) {
9381033 return get_info_impl_nocheck<ur_bool_t , Desc>();
939- } else if constexpr (is_std_vector_v<ReturnT>) {
940- static_assert (
941- !check_type_in_v<typename ReturnT::value_type, bool , std::string>);
942- size_t ResultSize = 0 ;
943- ur_result_t Error =
944- getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
945- getHandleRef (), Desc, 0 , nullptr , &ResultSize);
946- if (Error != UR_RESULT_SUCCESS)
947- return {Error};
948- if (ResultSize == 0 )
949- return {ReturnT{}};
950-
951- ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
952- Error = getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
953- getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
954- if (Error != UR_RESULT_SUCCESS)
955- return {Error};
956- return {Result};
9571034 } else {
958- ReturnT Result;
959- ur_result_t Error =
960- getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
961- getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
962- if (Error == UR_RESULT_SUCCESS)
1035+ static_assert (std::is_same_v<ur_ret_t , ReturnT>);
1036+ if constexpr (is_std_vector_v<ReturnT>) {
1037+ static_assert (
1038+ !check_type_in_v<typename ReturnT::value_type, bool , std::string>);
1039+ size_t ResultSize = 0 ;
1040+ ur_result_t Error =
1041+ getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
1042+ getHandleRef (), Desc, 0 , nullptr , &ResultSize);
1043+ if (Error != UR_RESULT_SUCCESS)
1044+ return {Error};
1045+ if (ResultSize == 0 )
1046+ return {ReturnT{}};
1047+
1048+ ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
1049+ Error = getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
1050+ getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
1051+ if (Error != UR_RESULT_SUCCESS)
1052+ return {Error};
9631053 return {Result};
964- else
965- return {Error};
1054+ } else {
1055+ ReturnT Result;
1056+ ur_result_t Error =
1057+ getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
1058+ getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
1059+ if (Error == UR_RESULT_SUCCESS)
1060+ return {Result};
1061+ else
1062+ return {Error};
1063+ }
9661064 }
9671065 }
9681066
9691067 template <typename ReturnT, ur_device_info_t Desc>
9701068 ReturnT get_info_impl () const {
1069+ using ur_ret_t = ur_ret_type<Desc>;
9711070 if constexpr (std::is_same_v<ReturnT, bool >) {
9721071 return get_info_impl<ur_bool_t , Desc>();
973- } else if constexpr (std::is_same_v<ReturnT, std::string>) {
974- return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this , Desc);
975- } else if constexpr (is_std_vector_v<ReturnT>) {
976- size_t ResultSize = 0 ;
977- getAdapter ()->call <UrApiKind::urDeviceGetInfo>(getHandleRef (), Desc, 0 ,
978- nullptr , &ResultSize);
979- if (ResultSize == 0 )
980- return {};
981-
982- ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
983- getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
984- getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
985- return Result;
9861072 } else {
987- ReturnT Result;
988- getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
989- getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
990- return Result;
1073+ static_assert (std::is_same_v<ur_ret_t , ReturnT>);
1074+ if constexpr (std::is_same_v<ReturnT, std::string>) {
1075+ return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this , Desc);
1076+ } else if constexpr (is_std_vector_v<ReturnT>) {
1077+ size_t ResultSize = 0 ;
1078+ getAdapter ()->call <UrApiKind::urDeviceGetInfo>(getHandleRef (), Desc, 0 ,
1079+ nullptr , &ResultSize);
1080+ if (ResultSize == 0 )
1081+ return {};
1082+
1083+ ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
1084+ getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1085+ getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
1086+ return Result;
1087+ } else {
1088+ ReturnT Result;
1089+ getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1090+ getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
1091+ return Result;
1092+ }
9911093 }
9921094 }
9931095
0 commit comments