Skip to content

Commit 22fcde2

Browse files
authored
[SYCL][Joint Matrix] Add more cases in common JM tests functions (#15712)
1 parent 0eb08d9 commit 22fcde2

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

sycl/source/detail/device_info.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -841,8 +841,8 @@ struct get_device_info_impl<
841841
};
842842
else if ((architecture::intel_gpu_pvc == DeviceArch) ||
843843
(architecture::intel_gpu_bmg_g21 == DeviceArch) ||
844-
(architecture::intel_gpu_lnl_m == DeviceArch))
845-
return {
844+
(architecture::intel_gpu_lnl_m == DeviceArch)) {
845+
std::vector<ext::oneapi::experimental::matrix::combination> pvc_combs = {
846846
{8, 0, 0, 0, 16, 32, matrix_type::uint8, matrix_type::uint8,
847847
matrix_type::sint32, matrix_type::sint32},
848848
{8, 0, 0, 0, 16, 32, matrix_type::uint8, matrix_type::sint8,
@@ -950,10 +950,11 @@ struct get_device_info_impl<
950950
{8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
951951
matrix_type::fp32, matrix_type::fp32},
952952
};
953-
else if ((architecture::intel_gpu_dg2_g10 == DeviceArch) ||
954-
(architecture::intel_gpu_dg2_g11 == DeviceArch) ||
955-
(architecture::intel_gpu_dg2_g12 == DeviceArch) ||
956-
(architecture::intel_gpu_arl_h == DeviceArch))
953+
return pvc_combs;
954+
} else if ((architecture::intel_gpu_dg2_g10 == DeviceArch) ||
955+
(architecture::intel_gpu_dg2_g11 == DeviceArch) ||
956+
(architecture::intel_gpu_dg2_g12 == DeviceArch) ||
957+
(architecture::intel_gpu_arl_h == DeviceArch))
957958
return {
958959
{8, 0, 0, 0, 8, 32, matrix_type::uint8, matrix_type::uint8,
959960
matrix_type::sint32, matrix_type::sint32},

sycl/test-e2e/Matrix/common.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
6363
if constexpr (std::is_same_v<Ta, bfloat16> &&
6464
std::is_same_v<Tc, float>)
6565
acc += make_fp32(va[i]) * make_fp32(vb[i]);
66-
else if constexpr (std::is_same_v<Ta, sycl::half> &&
67-
std::is_same_v<Tc, float>)
66+
else if constexpr (std::is_same_v<Ta, sycl::half>)
6867
acc += (float)va[i] * (float)vb[i];
6968
else if constexpr (std::is_same_v<Ta, float> &&
7069
std::is_same_v<Tc, float> ||
@@ -135,7 +134,8 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
135134

136135
for (unsigned int i = 0; i < rows; i++) {
137136
for (unsigned int j = 0; j < cols; j++) {
138-
if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float> ||
137+
if constexpr (std::is_same_v<T, sycl::half> ||
138+
std::is_same_v<T, bfloat16> || std::is_same_v<T, float> ||
139139
std::is_same_v<T, double>) {
140140
src[i * cols + j] = T(fdistr(dev));
141141
} else if constexpr (std::is_integral_v<T>) {

0 commit comments

Comments
 (0)