4 changes: 3 additions & 1 deletion ggml/include/ggml-rpc.h
@@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c

GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);

-GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
+                                                    const char * cache_dir,
+                                                    size_t free_mem, size_t total_mem);

GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);

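For orientation, here is a minimal sketch of how a server binary might call the extended entry point. Only the ggml_backend_rpc_start_server signature comes from the hunk above; the CPU backend setup, the endpoint, the cache path, the memory figures, and the assumption that NULL disables caching are all illustrative.

```c
// Hypothetical usage sketch (not part of this PR): start an RPC server that can
// cache tensors under a local directory.
#include "ggml-backend.h"
#include "ggml-cpu.h"   // assumed location of ggml_backend_cpu_init()
#include "ggml-rpc.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();       // any backend could serve here

    size_t free_mem  = 8ull  * 1024 * 1024 * 1024;           // advertised free memory (example value)
    size_t total_mem = 16ull * 1024 * 1024 * 1024;           // advertised total memory (example value)

    // New third argument: a cache directory (assumed: pass NULL to disable caching).
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", "/tmp/ggml-rpc-cache",
                                  free_mem, total_mem);      // blocks and serves requests

    ggml_backend_free(backend);
    return 0;
}
```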
10 changes: 5 additions & 5 deletions ggml/include/ggml.h
@@ -1791,11 +1791,11 @@ extern "C" {

#define GGML_KQ_MASK_PAD 64

-// q: [n_embd, n_batch, n_head, 1]
-// k: [n_embd, n_kv, n_head_kv, 1]
-// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
-// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-// res: [n_embd, n_head, n_batch, 1] !! permuted !!
+// q: [n_embd_k, n_batch, n_head, 1]
+// k: [n_embd_k, n_kv, n_head_kv, 1]
+// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
+// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
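As a usage illustration of the relaxed shapes, here is a hedged sketch that builds a flash-attention node with different K and V head sizes (e.g. DK = 192, DV = 128). The shapes follow the updated comment above; the trailing scale/max_bias/logit_softcap arguments are not visible in this hunk and are recalled from the surrounding header, so treat them — along with the sizes and tensor types — as assumptions.

```c
// Sketch: ggml_flash_attn_ext with DK != DV (values and types are example assumptions).
#include <math.h>
#include "ggml.h"

void build_fa_example(struct ggml_context * ctx) {
    const int64_t DK = 192, DV = 128;              // K/Q head size vs. V head size
    const int64_t n_head = 16, n_head_kv = 16;
    const int64_t n_batch = 32, n_kv = 4096;

    // q: [n_embd_k, n_batch, n_head, 1]
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, DK, n_batch, n_head, 1);
    // k: [n_embd_k, n_kv, n_head_kv, 1]
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DK, n_kv, n_head_kv, 1);
    // v: [n_embd_v, n_kv, n_head_kv, 1] -- not transposed
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, DV, n_kv, n_head_kv, 1);
    // mask: [n_kv, n_batch_pad, 1, 1]
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
                                                   n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));

    // res: [n_embd_v, n_head, n_batch, 1] (permuted)
    struct ggml_tensor * res = ggml_flash_attn_ext(ctx, q, k, v, mask,
                                                   1.0f/sqrtf((float) DK),  // scale (assumed trailing args)
                                                   0.0f,                    // max_bias
                                                   0.0f);                   // logit_softcap
    (void) res;
}
```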
50 changes: 25 additions & 25 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -12238,23 +12238,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int ith = params->ith;
const int nth = params->nth;

-const int64_t D = neq0;
-const int64_t N = neq1;
+const int64_t DK = nek0;
+const int64_t DV = nev0;
+const int64_t N = neq1;

-GGML_ASSERT(ne0 == D);
+GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);

// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));

-GGML_ASSERT(neq0 == D);
-GGML_ASSERT(nek0 == D);
-GGML_ASSERT(nev0 == D);
+GGML_ASSERT(neq0 == DK);
+GGML_ASSERT(nek0 == DK);
+GGML_ASSERT(nev0 == DV);

GGML_ASSERT(neq1 == N);
-GGML_ASSERT(nev0 == D);

// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

-float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
-ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16

if (v->type == GGML_TYPE_F16) {
-memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
} else {
-memset(VKQ32, 0, D*sizeof(float));
+memset(VKQ32, 0, DV*sizeof(float));
}

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -12342,7 +12342,7 @@
const int iv2 = iq2 / rv2;

const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-q_to_vec_dot(pq, Q_q, D);
+q_to_vec_dot(pq, Q_q, DK);

// online softmax / attention
// loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@
float s; // KQ value

const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);

s = s*scale; // scale KQ value

@@ -12380,45 +12380,45 @@
ms = expf(Mold - M);

// V = V*expf(Mold - M)
-ggml_vec_scale_f16(D, VKQ16, ms);
+ggml_vec_scale_f16(DV, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
-ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
-ggml_vec_scale_f32(D, VKQ32, ms);
+ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

-v_to_float(v_data, V32, D);
+v_to_float(v_data, V32, DV);

// V += v*expf(s - M)
-ggml_vec_mad_f32(D, VKQ32, V32, vs);
+ggml_vec_mad_f32(DV, VKQ32, V32, vs);
}

S = S*ms + vs; // scale and increment sum with partial sum
}

if (v->type == GGML_TYPE_F16) {
-for (int64_t d = 0; d < D; ++d) {
+for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}

// V /= S
const float S_inv = 1.0f/S;
-ggml_vec_scale_f32(D, VKQ32, S_inv);
+ggml_vec_scale_f32(DV, VKQ32, S_inv);

// dst indices
const int i1 = iq1;
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
size_t cur = 0;

if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
-const int64_t ne00 = node->src[0]->ne[0]; // D
+const int64_t ne10 = node->src[1]->ne[0]; // DK
+const int64_t ne20 = node->src[2]->ne[0]; // DV

-cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
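To make the D → DK/DV split easier to follow, here is a plain scalar restatement of the online-softmax accumulation that the hunks above modify — an illustrative sketch, not the ggml code itself: scores are dot products over DK-sized K rows, while the running output accumulator is DV-sized. Mask, ALiBi and logit-softcap handling are omitted.

```c
// Illustrative scalar version of the online-softmax / flash-attention inner loop:
// K rows have DK elements, V rows and the output accumulator have DV elements.
#include <math.h>

static void flash_attn_row(const float * q,   // [DK]
                           const float * K,   // [n_kv][DK]
                           const float * V,   // [n_kv][DV]
                           float * out,       // [DV]
                           int n_kv, int DK, int DV, float scale) {
    float S = 0.0f;        // running sum of exp(s - M)
    float M = -INFINITY;   // running maximum score

    for (int d = 0; d < DV; ++d) out[d] = 0.0f;

    for (int ic = 0; ic < n_kv; ++ic) {
        float s = 0.0f;                                   // KQ value for this key
        for (int d = 0; d < DK; ++d) s += K[ic*DK + d]*q[d];
        s *= scale;

        float ms = 1.0f, vs = 1.0f;
        if (s > M) {
            ms = expf(M - s);                             // rescale previous accumulator
            M  = s;
            for (int d = 0; d < DV; ++d) out[d] *= ms;    // V = V*expf(Mold - M)
        } else {
            vs = expf(s - M);
        }

        for (int d = 0; d < DV; ++d) out[d] += vs*V[ic*DV + d];  // V += v*expf(s - M)
        S = S*ms + vs;                                    // scale and increment sum
    }

    for (int d = 0; d < DV; ++d) out[d] /= S;             // V /= S
}
```

This is also why the per-thread scratch in ggml_graph_plan changes from 3 head-size floats to 1*DK + 2*DV: the converted Q row needs DK entries, while the FP32 accumulator and the temporary V buffer need DV each.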
18 changes: 16 additions & 2 deletions ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -2680,13 +2680,25 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_0);
vec_t vec_A[4] {0}, vec_B[4] = {0};
for (int l=0; l<k; l+=4) {
-if (RN >= 4 && RM == 1) {
+/* The 'GEMV forwarding' concept is used in the first two conditional branches:
+ * when one of the matrices has a single row/column, its elements are
+ * broadcast directly instead of being prepacked with the packing
+ * routine.
+ */
+if (RM == 1) {
TA* a = const_cast<TA*>(A+(ii)*lda+l);
-packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+} else if (RN == 1) {
+packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+vec_B[0] = (vec_t)vec_xl(0,b);
+vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
} else {
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
@@ -2790,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
assert(params->ith < params->nth);

// only enable sgemm for prompt processing
+#if !defined(__MMA__)
if (n < 2)
return false;
+#endif

if (Ctype != GGML_TYPE_F32)
return false;
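The PPC intrinsics aside, the "GEMV forwarding" idea is simply: when one operand is a single row (RM == 1) or a single column (RN == 1), skip the pack/transpose step and broadcast that operand's scalars directly. A plain-C sketch of the concept under those assumptions (not the tinyBLAS_PPC code) is below.

```c
// Conceptual, scalar illustration of "GEMV forwarding" (not the MMA implementation):
// with a 1 x k row of A there is nothing to prepack -- each A element is simply
// broadcast across the accumulation, which is what vec_splats() does per 4-lane
// vector in the PPC code above.
static void gemv_forward_row(const float * A_row,  // 1 x k
                             const float * B,      // k x n, column j starts at B[j*ldb]
                             float       * C,      // 1 x n result
                             int k, int n, int ldb) {
    for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int l = 0; l < k; ++l) {
            acc += A_row[l] * B[j*ldb + l];  // broadcast A element times one B column entry
        }
        C[j] = acc;
    }
}
```

The second hunk complements this: on MMA builds the early `n < 2` bail-out is compiled out, so GEMV-shaped calls (single-token generation) can also take the llamafile_sgemm path instead of returning false.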
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
+if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+// different head sizes of K and V are not supported yet
+return false;
+}
+if (op->src[0]->ne[0] == 192) {
+return false;
+}
if (op->src[0]->ne[3] != 1) {
return false;
}
9 changes: 6 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -219,9 +219,12 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
-uint64_t nb_12_1;
-uint64_t nb_12_2;
-uint64_t nb_12_3;
+uint64_t nb11;
+uint64_t nb12;
+uint64_t nb13;
+uint64_t nb21;
+uint64_t nb22;
+uint64_t nb23;
uint64_t nb31;
int32_t ne1;
int32_t ne2;