
Commit 2bbb6d5

Merge branch 'ggml-org:master' into master
2 parents: c1703c1 + 6db3d1f

7 files changed: +472 lines, -348 lines

examples/gguf/gguf.cpp

Lines changed: 6 additions & 1 deletion
@@ -184,8 +184,13 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
             const char * name   = gguf_get_tensor_name  (ctx, i);
             const size_t size   = gguf_get_tensor_size  (ctx, i);
             const size_t offset = gguf_get_tensor_offset(ctx, i);
+            const auto   type   = gguf_get_tensor_type  (ctx, i);

-            printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset);
+            const char * type_name  = ggml_type_name(type);
+            const size_t type_size  = ggml_type_size(type);
+            const size_t n_elements = size / type_size;
+
+            printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s, n_elts = %zu\n", __func__, i, name, size, offset, type_name, n_elements);
        }
    }

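
For context, the accessors used in this hunk come from the public GGUF API. Below is a minimal standalone sketch (illustrative, not part of the commit) that enumerates tensors the same way; the includes assume the current ggml header layout, the default path model.gguf is hypothetical, and the element count is scaled by ggml_blck_size() so that it also holds for block-quantized types, where ggml_type_size() is the size of one block.

// Standalone sketch: list the tensors of a GGUF file (metadata only).
#include "ggml.h"
#include "gguf.h"

#include <cstdio>

int main(int argc, char ** argv) {
    const char * fname = argc > 1 ? argv[1] : "model.gguf"; // hypothetical default path

    struct gguf_init_params params = {
        /*.no_alloc =*/ true,    // read metadata only, do not load tensor data
        /*.ctx      =*/ nullptr,
    };

    struct gguf_context * gctx = gguf_init_from_file(fname, params);
    if (!gctx) {
        fprintf(stderr, "failed to open %s\n", fname);
        return 1;
    }

    const int64_t n_tensors = gguf_get_n_tensors(gctx);
    for (int64_t i = 0; i < n_tensors; ++i) {
        const char * name = gguf_get_tensor_name(gctx, i);
        const size_t size = gguf_get_tensor_size(gctx, i);
        const enum ggml_type type = gguf_get_tensor_type(gctx, i);

        // ggml_type_size() is per block; multiply by the block size to count elements
        const size_t n_elements = size / ggml_type_size(type) * ggml_blck_size(type);

        printf("tensor[%3d]: name = %-40s type = %-6s n_elements = %zu\n",
               (int) i, name, ggml_type_name(type), n_elements);
    }

    gguf_free(gctx);
    return 0;
}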

ggml/src/ggml-cuda/cpy.cu

Lines changed: 96 additions & 7 deletions
@@ -7,6 +7,10 @@

 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

+const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
+const int CUDA_CPY_BLOCK_NM    = 8;  // block size of 3rd dimension if available
+const int CUDA_CPY_BLOCK_ROWS  = 8;  // block dimension for marching through rows
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,

@@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }

+template <typename T>
+static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+    const int nb12, const int nb13) {
+
+    const T* src = reinterpret_cast<const T*>(cx);
+    T* dst = reinterpret_cast<T*>(cdst);
+
+    const int64_t nmat = ne / (ne00 * ne01);
+    const int64_t n = ne00 * ne01;
+
+    const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+    const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
+    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+
+    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+
+#pragma unroll
+    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
+
+        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+        if (imat >= nmat)
+            break;
+
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if(x < ne01 && y + j < ne00){
+                const int row = threadIdx.y+j;
+                const int col = threadIdx.x * sizeof(float)/sizeof(T);
+                T *tile2 = reinterpret_cast<T*>(tile[row]);
+                tile2[col] = src[imat*n + (y+j)*ne01 + x];
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if (ty + j < ne01 && tx < ne00) {
+                const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
+                const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
+                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
+            }
+        }
+    }
+}
+
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);

@@ -136,15 +189,38 @@ cudaStream_t stream) {
         (cx, cdst, ne);
 }

-template<typename src_t, typename dst_t>
+template<typename src_t, typename dst_t, bool transposed = false>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    if (transposed) {
+        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
+        int ne00n, ne01n, ne02n;
+        if (nb00 < nb02) {
+            ne00n = ne00;
+            ne01n = ne01;
+            ne02n = ne02;
+        } else if (nb00 > nb02) {
+            ne00n = ne00;
+            ne01n = ne01*ne02;
+            ne02n = 1;
+        } else {
+            GGML_ASSERT(false);
+        }
+
+        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    } else {
+        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    }
 }

 static void ggml_cpy_f32_q8_0_cuda(

@@ -310,6 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src1_ddc = (char *) src1->data;

     const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;

     if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));

@@ -322,7 +399,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
             ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);

@@ -361,7 +442,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
             ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);

@@ -375,7 +460,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
         if (contiguous_srcs) {
             ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
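
The kernel added above is the classic tiled shared-memory transpose: a 32x32 tile (CUDA_CPY_TILE_DIM_2D) is staged in shared memory with one extra column of padding to avoid bank conflicts, each 32x8 thread block marches through the tile in CUDA_CPY_BLOCK_ROWS steps so both the global load and the global store stay coalesced, and blockIdx.z walks over batches of CUDA_CPY_BLOCK_NM matrices. A minimal standalone sketch of the same pattern for a single row-major float matrix follows; the names (TILE_DIM, BLOCK_ROWS, transpose_tile) and the test sizes are illustrative, not ggml code.

// Standalone sketch of a tiled shared-memory transpose (illustrative only).
#include <cuda_runtime.h>
#include <cstdio>

constexpr int TILE_DIM   = 32;
constexpr int BLOCK_ROWS = 8;

// src is rows x cols, row-major; dst is cols x rows, row-major.
__global__ void transpose_tile(const float * src, float * dst, int rows, int cols) {
    // +1 column of padding avoids shared-memory bank conflicts on the transposed read
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    const int x = blockIdx.x * TILE_DIM + threadIdx.x; // column in src
    const int y = blockIdx.y * TILE_DIM + threadIdx.y; // row in src

    // coalesced load of one tile, BLOCK_ROWS rows per iteration
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < cols && y + j < rows) {
            tile[threadIdx.y + j][threadIdx.x] = src[(y + j) * cols + x];
        }
    }
    __syncthreads();

    const int tx = blockIdx.y * TILE_DIM + threadIdx.x; // column in dst (= row in src)
    const int ty = blockIdx.x * TILE_DIM + threadIdx.y; // row in dst (= column in src)

    // coalesced store of the transposed tile, reading shared memory across the diagonal
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (tx < rows && ty + j < cols) {
            dst[(ty + j) * rows + tx] = tile[threadIdx.x][threadIdx.y + j];
        }
    }
}

int main() {
    const int rows = 1024, cols = 512;
    float * src = nullptr;
    float * dst = nullptr;
    cudaMallocManaged(&src, rows * cols * sizeof(float));
    cudaMallocManaged(&dst, rows * cols * sizeof(float));
    for (int i = 0; i < rows * cols; ++i) src[i] = (float) i;

    dim3 grid((cols + TILE_DIM - 1) / TILE_DIM, (rows + TILE_DIM - 1) / TILE_DIM);
    dim3 block(TILE_DIM, BLOCK_ROWS);
    transpose_tile<<<grid, block>>>(src, dst, rows, cols);
    cudaDeviceSynchronize();

    printf("dst[1] = %.1f, expected %.1f\n", dst[1], src[cols]); // dst(0,1) == src(1,0)
    cudaFree(src);
    cudaFree(dst);
    return 0;
}

On the host side, the diff only takes this path when can_be_transposed holds, i.e. when src0 is a transposed view (nb01 equals the element size) with ne[3] == 1; every other case keeps using the generic cpy_flt kernel.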

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 22 additions & 6 deletions
@@ -367,7 +367,13 @@ struct ggml_backend_hexagon_buffer_context {
     ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
         size += 4 * 1024; // extra page for padding

-        this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
+        if (rpcmem_alloc2) {
+            this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
+        } else {
+            GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
+            this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
+        }
+
         if (!this->base) {
             GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
             throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");

@@ -1679,12 +1685,13 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
     }

     // Get session URI
-    char htp_uri[256];
-    sprintf(htp_uri, "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);

     char session_uri[256];
     {
-        struct remote_rpc_get_uri u;
+        char htp_uri[256];
+        snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);
+
+        struct remote_rpc_get_uri u = {};
         u.session_id      = this->session_id;
         u.domain_name     = const_cast<char *>(CDSP_DOMAIN_NAME);
         u.domain_name_len = strlen(CDSP_DOMAIN_NAME);

@@ -1695,8 +1702,12 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {

         int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));
         if (err != AEE_SUCCESS) {
-            GGML_LOG_ERROR("ggml-hex: failed to get URI for session %d : error 0x%x\n", dev_id, err);
-            throw std::runtime_error("ggml-hex: remote_session_control(get-uri) failed (see log for details)");
+            // fallback to single session uris
+            int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN;
+
+            snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri);
+
+            GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri);
         }
     }

@@ -3668,6 +3679,11 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
         }
     }

+    if(opt_arch < 75) {
+        opt_ndev = 1;
+        GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
+    }
+
     GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);

     // Create devices / sessions
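
Two defensive changes stand out in this file: buffer allocation now probes for rpcmem_alloc2 before using it, and the URI is formatted with a bounded snprintf instead of sprintf, with a failed FASTRPC_GET_URI downgraded from a fatal error to a fallback single-session URI. As a small illustration of the bounded-formatting part only, the sketch below builds a URI into a fixed buffer and reports truncation; the helper name make_htp_uri and the "&_dom=cdsp" suffix are assumptions for the example, not the ggml-hexagon code.

// Illustrative sketch of bounded URI formatting with truncation detection.
#include <cstdio>

// snprintf writes at most dst_size bytes (terminator included) and returns the length
// the full string would have had, so the caller can detect truncation.
static bool make_htp_uri(char * dst, size_t dst_size, unsigned arch, const char * domain_suffix) {
    const int n = snprintf(dst, dst_size,
        "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0%s",
        arch, domain_suffix);
    return n >= 0 && (size_t) n < dst_size;
}

int main() {
    char uri[256];
    if (make_htp_uri(uri, sizeof(uri), 79, "&_dom=cdsp")) { // suffix value is assumed
        printf("uri: %s\n", uri);
    } else {
        printf("uri would not fit in %zu bytes\n", sizeof(uri));
    }
    return 0;
}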

ggml/src/ggml-hexagon/htp-utils.h

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ extern "C" {
 # pragma weak remote_handle64_control
 # pragma weak fastrpc_mmap
 # pragma weak fastrpc_munmap
+# pragma weak rpcmem_alloc2
 #endif

 #if !defined(_WINDOWS)
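
This weak declaration is what makes the runtime if (rpcmem_alloc2) check in ggml-hexagon.cpp work: a weak, undefined symbol resolves to a null address when the runtime library does not provide it, instead of failing at link or load time. A minimal standalone sketch of the pattern (GCC/Clang on ELF targets; the symbol fancy_alloc is hypothetical):

// Standalone sketch of the weak-symbol fallback pattern.
#include <cstdio>
#include <cstdlib>

extern "C" void * fancy_alloc(size_t size); // optional API, may be absent at run time
#pragma weak fancy_alloc                    // if absent, the symbol resolves to null

int main() {
    if (fancy_alloc) { // address is non-null only when the loader found the symbol
        void * p = fancy_alloc(4096);
        printf("fancy_alloc gave %p\n", p);
    } else {
        printf("fancy_alloc not available, falling back to malloc\n");
        void * p = malloc(4096);
        printf("malloc gave %p\n", p);
        free(p);
    }
    return 0;
}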
