Skip to content

Commit dcfa304

Browse files
committed
Added comment for half casting. Removed redundant code
1 parent c04b373 commit dcfa304

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,6 @@ constexpr uint32_t CalculateMatrixOperand() {
104104
if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
105105
returnValue += static_cast<uint32_t>(
106106
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
107-
if constexpr (std::is_signed<Ta>::value && std::is_signed<Tb>::value) {
108-
returnValue += static_cast<uint32_t>(
109-
__spv::MatrixOperands::MatrixASignedComponentsKHR) +
110-
static_cast<uint32_t>(
111-
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
112-
}
113107
return returnValue;
114108
}
115109

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,11 @@ void joint_matrix_copy(
495495
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
496496
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
497497
for (int i = 0; i < wi_data_c.length(); i++) {
498-
if constexpr (std::is_same_v<T1, half>) {
498+
if constexpr (std::is_same_v<T1, sycl::half>) {
499+
// Special case for SRC type sycl:half since we can't
500+
// cast direvtly from wi_element(typed half) to other type.
501+
// first cast is from wi_element to half (T1).
502+
// second cast is from half to dst type (T2).
499503
wi_data_dst[i] = static_cast<storage_element_type>(
500504
static_cast<src_storage_element_type>(wi_data_c[i]));
501505
} else {

0 commit comments

Comments
 (0)