@@ -344,55 +344,48 @@ void joint_matrix_mad_hip(
344344 *reinterpret_cast <const float16x4 *>(&A.wi_marray ),
345345 *reinterpret_cast <const float16x4 *>(&B.wi_marray ),
346346 *reinterpret_cast <const floatx4 *>(&C.wi_marray ), 0 , 0 , 0 );
347- for (int i = 0 ; i < 4 ; ++i)
348- D.wi_marray [i] = result[i];
347+ std::memcpy (&D.wi_marray , &result, 4 * sizeof (float ));
349348 } else if constexpr (M == 32 && N == 32 ) {
350349 auto result = __builtin_amdgcn_mfma_f32_32x32x8f16 (
351350 *reinterpret_cast <const float16x4 *>(&A.wi_marray ),
352351 *reinterpret_cast <const float16x4 *>(&B.wi_marray ),
353352 *reinterpret_cast <const floatx16 *>(&C.wi_marray ), 0 , 0 , 0 );
354- for (int i = 0 ; i < 16 ; ++i)
355- D.wi_marray [i] = result[i];
353+ std::memcpy (&D.wi_marray , &result, 16 * sizeof (float ));
356354 }
357355 } else if constexpr (std::is_same_v<Tm, bfloat16>) {
358356 if constexpr (M == 16 && N == 16 ) {
359357 auto result = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k (
360358 *reinterpret_cast <const bfloat16x4 *>(&A.wi_marray ),
361359 *reinterpret_cast <const bfloat16x4 *>(&B.wi_marray ),
362360 *reinterpret_cast <const floatx4 *>(&C.wi_marray ), 0 , 0 , 0 );
363- for (int i = 0 ; i < 4 ; ++i)
364- D.wi_marray [i] = result[i];
361+ std::memcpy (&D.wi_marray , &result, 4 * sizeof (float ));
365362 } else if constexpr (M == 32 && N == 32 ) {
366363 auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k (
367364 *reinterpret_cast <const bfloat16x4 *>(&A.wi_marray ),
368365 *reinterpret_cast <const bfloat16x4 *>(&B.wi_marray ),
369366 *reinterpret_cast <const floatx16 *>(&C.wi_marray ), 0 , 0 , 0 );
370- for (int i = 0 ; i < 16 ; ++i)
371- D.wi_marray [i] = result[i];
367+ std::memcpy (&D.wi_marray , &result, 16 * sizeof (float ));
372368 }
373369 } else if constexpr (std::is_same_v<Tm, double >) {
374370 if constexpr (M == 16 && N == 16 ) {
375371 auto result = __builtin_amdgcn_mfma_f64_16x16x4f64 (
376372 A.wi_marray [0 ], B.wi_marray [0 ],
377373 *reinterpret_cast <const doublex4 *>(&C.wi_marray ), 0 , 0 , 0 );
378- for (int i = 0 ; i < 4 ; ++i)
379- D.wi_marray [i] = result[i];
374+ std::memcpy (&D.wi_marray , &result, 4 * sizeof (double ));
380375 }
381376 } else if constexpr (std::is_same_v<Tm, int8_t >) {
382377 if constexpr (M == 16 && N == 16 ) {
383378 auto result = __builtin_amdgcn_mfma_i32_16x16x16i8 (
384379 *reinterpret_cast <const Tc *>(&A.wi_marray ),
385380 *reinterpret_cast <const Tc *>(&B.wi_marray ),
386381 *reinterpret_cast <const int32x4 *>(&C.wi_marray ), 0 , 0 , 0 );
387- for (int i = 0 ; i < 4 ; ++i)
388- D.wi_marray [i] = result[i];
382+ std::memcpy (&D.wi_marray , &result, 4 * sizeof (int32_t ));
389383 } else if constexpr (M == 32 && N == 32 ) {
390384 auto result = __builtin_amdgcn_mfma_i32_32x32x8i8 (
391385 *reinterpret_cast <const Tc *>(&A.wi_marray ),
392386 *reinterpret_cast <const Tc *>(&B.wi_marray ),
393387 *reinterpret_cast <const int32x16 *>(&C.wi_marray ), 0 , 0 , 0 );
394- for (int i = 0 ; i < 16 ; ++i)
395- D.wi_marray [i] = result[i];
388+ std::memcpy (&D.wi_marray , &result, 16 * sizeof (int32_t ));
396389 }
397390 }
398391}
0 commit comments