Skip to content

Commit 4a516b5

Browse files
committed
Merge branch 'master' into esocrok
2 parents 187994b + 4807e8f commit 4a516b5

File tree

142 files changed

+1571
-1975
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

142 files changed

+1571
-1975
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -591,17 +591,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
591591
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
592592
}
593593

594+
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
595+
#ifdef FAST_FP16_AVAILABLE
596+
acc += v*u;
597+
#else
598+
const float2 tmpv = __half22float2(v);
599+
const float2 tmpu = __half22float2(u);
600+
float2 tmpacc = __half22float2(acc);
601+
tmpacc.x += tmpv.x * tmpu.x;
602+
tmpacc.y += tmpv.y * tmpu.y;
603+
acc = make_half2(tmpacc.x, tmpacc.y);
604+
#endif // FAST_FP16_AVAILABLE
605+
}
606+
594607
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
595-
template <int nbytes>
608+
template <int nbytes, int alignment = 0>
596609
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
597-
if constexpr (nbytes == 4) {
598-
*(int *) dst = *(const int *) src;
599-
} else if constexpr (nbytes == 8) {
600-
*(int2 *) dst = *(const int2 *) src;
601-
} else if constexpr (nbytes == 16) {
602-
*(int4 *) dst = *(const int4 *) src;
603-
} else {
604-
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
610+
if constexpr (alignment != 0) {
611+
static_assert(nbytes % alignment == 0, "bad alignment");
612+
}
613+
constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
614+
615+
#pragma unroll
616+
for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
617+
if constexpr (nb_per_cpy == 1) {
618+
((char *) dst)[i] = ((const char *) src)[i];
619+
} else if constexpr (nb_per_cpy == 2) {
620+
((short *) dst)[i] = ((const short *) src)[i];
621+
} else if constexpr (nb_per_cpy == 4) {
622+
((int *) dst)[i] = ((const int *) src)[i];
623+
} else if constexpr (nb_per_cpy == 8) {
624+
((int2 *) dst)[i] = ((const int2 *) src)[i];
625+
} else if constexpr (nb_per_cpy == 16) {
626+
((int4 *) dst)[i] = ((const int4 *) src)[i];
627+
} else {
628+
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
629+
}
605630
}
606631
}
607632

0 commit comments

Comments
 (0)