Skip to content

Commit 28cf832

Browse files
mmoadeli and mmoadeli authored
[SYCL][HIP] Replace for loop to copy results with std::memcpy. (#11779)
Replace the for loops that copy results with `std::memcpy`, which showed slightly better performance. Co-authored-by: mmoadeli <[email protected]>
1 parent f54f61d commit 28cf832

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -344,55 +344,48 @@ void joint_matrix_mad_hip(
344344
*reinterpret_cast<const float16x4 *>(&A.wi_marray),
345345
*reinterpret_cast<const float16x4 *>(&B.wi_marray),
346346
*reinterpret_cast<const floatx4 *>(&C.wi_marray), 0, 0, 0);
347-
for (int i = 0; i < 4; ++i)
348-
D.wi_marray[i] = result[i];
347+
std::memcpy(&D.wi_marray, &result, 4 * sizeof(float));
349348
} else if constexpr (M == 32 && N == 32) {
350349
auto result = __builtin_amdgcn_mfma_f32_32x32x8f16(
351350
*reinterpret_cast<const float16x4 *>(&A.wi_marray),
352351
*reinterpret_cast<const float16x4 *>(&B.wi_marray),
353352
*reinterpret_cast<const floatx16 *>(&C.wi_marray), 0, 0, 0);
354-
for (int i = 0; i < 16; ++i)
355-
D.wi_marray[i] = result[i];
353+
std::memcpy(&D.wi_marray, &result, 16 * sizeof(float));
356354
}
357355
} else if constexpr (std::is_same_v<Tm, bfloat16>) {
358356
if constexpr (M == 16 && N == 16) {
359357
auto result = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
360358
*reinterpret_cast<const bfloat16x4 *>(&A.wi_marray),
361359
*reinterpret_cast<const bfloat16x4 *>(&B.wi_marray),
362360
*reinterpret_cast<const floatx4 *>(&C.wi_marray), 0, 0, 0);
363-
for (int i = 0; i < 4; ++i)
364-
D.wi_marray[i] = result[i];
361+
std::memcpy(&D.wi_marray, &result, 4 * sizeof(float));
365362
} else if constexpr (M == 32 && N == 32) {
366363
auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
367364
*reinterpret_cast<const bfloat16x4 *>(&A.wi_marray),
368365
*reinterpret_cast<const bfloat16x4 *>(&B.wi_marray),
369366
*reinterpret_cast<const floatx16 *>(&C.wi_marray), 0, 0, 0);
370-
for (int i = 0; i < 16; ++i)
371-
D.wi_marray[i] = result[i];
367+
std::memcpy(&D.wi_marray, &result, 16 * sizeof(float));
372368
}
373369
} else if constexpr (std::is_same_v<Tm, double>) {
374370
if constexpr (M == 16 && N == 16) {
375371
auto result = __builtin_amdgcn_mfma_f64_16x16x4f64(
376372
A.wi_marray[0], B.wi_marray[0],
377373
*reinterpret_cast<const doublex4 *>(&C.wi_marray), 0, 0, 0);
378-
for (int i = 0; i < 4; ++i)
379-
D.wi_marray[i] = result[i];
374+
std::memcpy(&D.wi_marray, &result, 4 * sizeof(double));
380375
}
381376
} else if constexpr (std::is_same_v<Tm, int8_t>) {
382377
if constexpr (M == 16 && N == 16) {
383378
auto result = __builtin_amdgcn_mfma_i32_16x16x16i8(
384379
*reinterpret_cast<const Tc *>(&A.wi_marray),
385380
*reinterpret_cast<const Tc *>(&B.wi_marray),
386381
*reinterpret_cast<const int32x4 *>(&C.wi_marray), 0, 0, 0);
387-
for (int i = 0; i < 4; ++i)
388-
D.wi_marray[i] = result[i];
382+
std::memcpy(&D.wi_marray, &result, 4 * sizeof(int32_t));
389383
} else if constexpr (M == 32 && N == 32) {
390384
auto result = __builtin_amdgcn_mfma_i32_32x32x8i8(
391385
*reinterpret_cast<const Tc *>(&A.wi_marray),
392386
*reinterpret_cast<const Tc *>(&B.wi_marray),
393387
*reinterpret_cast<const int32x16 *>(&C.wi_marray), 0, 0, 0);
394-
for (int i = 0; i < 16; ++i)
395-
D.wi_marray[i] = result[i];
388+
std::memcpy(&D.wi_marray, &result, 16 * sizeof(int32_t));
396389
}
397390
}
398391
}

0 commit comments

Comments
 (0)