@@ -586,17 +586,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
586586#endif  //  defined(GGML_USE_HIP) && (defined(RDNA2)  || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
587587}
588588
// Element-wise multiply-accumulate on packed half pairs: acc += v*u for both lanes.
//   acc: accumulator, updated in place.
//   v,u: the two factors.
// When FAST_FP16_AVAILABLE is defined the arithmetic is done directly on half2.
// Otherwise all three operands are widened to float2, the multiply-add is done in FP32,
// and the result is converted back to half2 when it is stored in acc.
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
#ifdef FAST_FP16_AVAILABLE
    acc += v*u;
#else
    const float2 tmpv = __half22float2(v);
    const float2 tmpu = __half22float2(u);
    float2 tmpacc = __half22float2(acc);
    tmpacc.x += tmpv.x * tmpu.x;
    tmpacc.y += tmpv.y * tmpu.y;
    acc = make_half2(tmpacc.x, tmpacc.y);
#endif // FAST_FP16_AVAILABLE
}
601+ 
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
//
// Copies nbytes bytes from src to dst using fixed-width loads/stores.
//   nbytes:    total number of bytes to copy (compile-time constant).
//   alignment: width in bytes of each individual transfer, must divide nbytes.
//              0 (the default) means "copy all nbytes with a single transfer".
// The per-transfer width (alignment, or nbytes if alignment == 0) must be 1, 2, 4, 8 or 16;
// any other value fails to compile via the static_assert in the final else branch.
// NOTE(review): dst and src are dereferenced as char/short/int/int2/int4, so both pointers
// presumably have to be aligned to the per-transfer width - confirm against the callers.
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
    // Only check divisibility when an alignment was given, this avoids evaluating nbytes % 0.
    if constexpr (alignment != 0) {
        static_assert(nbytes % alignment == 0, "bad alignment");
    }
    constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;

#pragma unroll
    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
        if constexpr (nb_per_cpy == 1) {
            ((char *) dst)[i] = ((const char *) src)[i];
        } else if constexpr (nb_per_cpy == 2) {
            ((short *) dst)[i] = ((const short *) src)[i];
        } else if constexpr (nb_per_cpy == 4) {
            ((int *) dst)[i] = ((const int *) src)[i];
        } else if constexpr (nb_per_cpy == 8) {
            ((int2 *) dst)[i] = ((const int2 *) src)[i];
        } else if constexpr (nb_per_cpy == 16) {
            ((int4 *) dst)[i] = ((const int4 *) src)[i];
        } else {
            // The condition depends on the template parameter and is always false,
            // so the assertion only fires if this branch is actually instantiated.
            static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
        }
    }
}
602627
0 commit comments