Skip to content

Commit 1f1f06a

Browse files
authored
Merge branch 'master' into r1-toolcall
2 parents 5d60ceb + 9f4cc8f commit 1f1f06a

File tree

9 files changed

+126
-54
lines changed

9 files changed

+126
-54
lines changed

common/arg.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,15 +1465,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14651465
{"--list-devices"},
14661466
"print list of available devices and exit",
14671467
[](common_params &) {
1468-
printf("Available devices:\n");
1468+
std::vector<ggml_backend_dev_t> rpc_devices;
1469+
std::vector<ggml_backend_dev_t> all_devices;
14691470
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
14701471
auto * dev = ggml_backend_dev_get(i);
14711472
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1472-
size_t free, total;
1473-
ggml_backend_dev_memory(dev, &free, &total);
1474-
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
1473+
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
1474+
if (ggml_backend_reg_name(reg) == std::string("RPC")) {
1475+
rpc_devices.push_back(dev);
1476+
} else {
1477+
all_devices.push_back(dev);
1478+
}
14751479
}
14761480
}
1481+
// insert RPC devices in front
1482+
all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end());
1483+
printf("Available devices:\n");
1484+
for (size_t i = 0; i < all_devices.size(); ++i) {
1485+
auto * dev = all_devices[i];
1486+
size_t free, total;
1487+
ggml_backend_dev_memory(dev, &free, &total);
1488+
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
1489+
}
14771490
exit(0);
14781491
}
14791492
));

examples/server/server.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3357,6 +3357,8 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
33573357
return;
33583358
}
33593359

3360+
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
3361+
33603362
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
33613363

33623364
LOG_DBG("request: %s\n", req.body.c_str());
@@ -3443,9 +3445,13 @@ int main(int argc, char ** argv) {
34433445
message = "Unknown Exception";
34443446
}
34453447

3446-
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
3447-
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
3448-
res_error(res, formatted_error);
3448+
try {
3449+
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
3450+
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
3451+
res_error(res, formatted_error);
3452+
} catch (const std::exception & e) {
3453+
LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
3454+
}
34493455
});
34503456

34513457
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "ggml-cuda/upscale.cuh"
3939
#include "ggml-cuda/wkv6.cuh"
4040
#include "ggml-cuda/gla.cuh"
41+
#include "ggml.h"
4142

4243
#include <algorithm>
4344
#include <array>
@@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31393140
break;
31403141
case GGML_OP_NORM:
31413142
case GGML_OP_RMS_NORM:
3143+
return true;
31423144
case GGML_OP_RMS_NORM_BACK:
31433145
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
31443146
break;
@@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31813183
case GGML_OP_SUM_ROWS:
31823184
case GGML_OP_ARGSORT:
31833185
case GGML_OP_ACC:
3186+
return true;
31843187
case GGML_OP_GROUP_NORM:
3188+
return ggml_is_contiguous(op->src[0]);
31853189
case GGML_OP_UPSCALE:
31863190
case GGML_OP_PAD:
31873191
case GGML_OP_ARANGE:

ggml/src/ggml-cuda/norm.cu

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
#include "norm.cuh"
2+
#include <cstdint>
23

34
template <int block_size>
4-
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
5-
const int row = blockIdx.x*blockDim.y + threadIdx.y;
6-
const int tid = threadIdx.x;
5+
static __global__ void norm_f32(
6+
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
7+
const int64_t stride_sample, const float eps) {
8+
const int nrows = gridDim.x;
9+
const int nchannels = gridDim.y;
710

8-
x += int64_t(row)*ncols;
9-
dst += int64_t(row)*ncols;
11+
const int row = blockIdx.x;
12+
const int channel = blockIdx.y;
13+
const int sample = blockIdx.z;
14+
const int tid = threadIdx.x;
15+
16+
x += sample*stride_sample + channel*stride_channel + row*stride_row;
17+
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
1018

1119
float2 mean_var = make_float2(0.0f, 0.0f);
1220

@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
97105
}
98106

99107
template <int block_size>
100-
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
101-
const int row = blockIdx.x*blockDim.y + threadIdx.y;
102-
const int tid = threadIdx.x;
108+
static __global__ void rms_norm_f32(
109+
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
110+
const int64_t stride_sample, const float eps) {
111+
const int nrows = gridDim.x;
112+
const int nchannels = gridDim.y;
113+
114+
const int row = blockIdx.x;
115+
const int channel = blockIdx.y;
116+
const int sample = blockIdx.z;
117+
const int tid = threadIdx.x;
103118

104-
x += int64_t(row)*ncols;
105-
dst += int64_t(row)*ncols;
119+
x += sample*stride_sample + channel*stride_channel + row*stride_row;
120+
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
106121

107122
float tmp = 0.0f; // partial sum for thread in warp
108123

@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
186201
}
187202
}
188203

189-
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
204+
static void norm_f32_cuda(
205+
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
206+
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
207+
const dim3 blocks_num(nrows, nchannels, nsamples);
190208
if (ncols < 1024) {
191209
const dim3 block_dims(WARP_SIZE, 1, 1);
192-
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
210+
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
193211
} else {
194212
const dim3 block_dims(1024, 1, 1);
195-
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
213+
norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
196214
}
197215
}
198216

@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
207225
}
208226
}
209227

210-
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
228+
static void rms_norm_f32_cuda(
229+
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
230+
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
231+
const dim3 blocks_num(nrows, nchannels, nsamples);
211232
if (ncols < 1024) {
212233
const dim3 block_dims(WARP_SIZE, 1, 1);
213-
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
234+
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
214235
} else {
215236
const dim3 block_dims(1024, 1, 1);
216-
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
237+
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
217238
}
218239
}
219240

@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
229250

230251
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
231252
const ggml_tensor * src0 = dst->src[0];
232-
const float * src0_d = (const float *)src0->data;
233-
float * dst_d = (float *)dst->data;
253+
const float * src0_d = (const float *) src0->data;
254+
float * dst_d = (float *) dst->data;
234255
cudaStream_t stream = ctx.stream();
235256

236-
GGML_ASSERT(ggml_is_contiguous(src0));
237-
238257
GGML_ASSERT(src0->type == GGML_TYPE_F32);
239258
GGML_ASSERT( dst->type == GGML_TYPE_F32);
240259

241-
const int64_t ne00 = src0->ne[0];
242-
const int64_t nrows = ggml_nrows(src0);
260+
GGML_TENSOR_UNARY_OP_LOCALS;
243261

244262
float eps;
245263
memcpy(&eps, dst->op_params, sizeof(float));
246264
GGML_ASSERT(eps >= 0.0f);
247265

248-
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
266+
const size_t ts0 = ggml_type_size(src0->type);
267+
GGML_ASSERT(nb00 == ts0);
268+
const int64_t s01 = nb01 / ts0;
269+
const int64_t s02 = nb02 / ts0;
270+
const int64_t s03 = nb03 / ts0;
271+
272+
norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
249273
}
250274

251275
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
254278
float * dst_d = (float *)dst->data;
255279
cudaStream_t stream = ctx.stream();
256280

257-
GGML_ASSERT(ggml_is_contiguous(src0));
258-
259281
GGML_ASSERT(src0->type == GGML_TYPE_F32);
260282
GGML_ASSERT( dst->type == GGML_TYPE_F32);
261283

@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
271293

272294
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
273295
const ggml_tensor * src0 = dst->src[0];
274-
const float * src0_d = (const float *)src0->data;
275-
float * dst_d = (float *)dst->data;
296+
const float * src0_d = (const float *) src0->data;
297+
float * dst_d = (float *) dst->data;
276298
cudaStream_t stream = ctx.stream();
277299

278-
GGML_ASSERT(ggml_is_contiguous(src0));
279-
280300
GGML_ASSERT(src0->type == GGML_TYPE_F32);
281301
GGML_ASSERT( dst->type == GGML_TYPE_F32);
282302

283-
const int64_t ne00 = src0->ne[0];
284-
const int64_t nrows = ggml_nrows(src0);
303+
GGML_TENSOR_UNARY_OP_LOCALS;
285304

286305
float eps;
287306
memcpy(&eps, dst->op_params, sizeof(float));
288307
GGML_ASSERT(eps >= 0.0f);
289308

290-
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
309+
const size_t ts0 = ggml_type_size(src0->type);
310+
GGML_ASSERT(nb00 == ts0);
311+
const int64_t s01 = nb01 / ts0;
312+
const int64_t s02 = nb02 / ts0;
313+
const int64_t s03 = nb03 / ts0;
314+
315+
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
291316
}
292317

293318
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ endif()
4646

4747
message(STATUS "HIP and hipBLAS found")
4848

49+
# Workaround old compilers
50+
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
51+
4952
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
5053
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
5154

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
12061206
case GGML_OP_GROUP_NORM:
12071207
return has_simdgroup_reduction;
12081208
case GGML_OP_RMS_NORM:
1209-
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
1209+
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
12101210
case GGML_OP_ARGMAX:
1211-
case GGML_OP_NORM:
12121211
return true;
1212+
case GGML_OP_NORM:
1213+
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
12131214
case GGML_OP_ROPE:
12141215
{
12151216
const int mode = ((const int32_t *) op->op_params)[2];

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
81828182
case GGML_OP_VIEW:
81838183
case GGML_OP_PERMUTE:
81848184
case GGML_OP_TRANSPOSE:
8185+
return true;
81858186
case GGML_OP_NORM:
81868187
case GGML_OP_GROUP_NORM:
81878188
case GGML_OP_RMS_NORM:
8189+
return ggml_is_contiguous(op->src[0]);
81888190
case GGML_OP_ADD:
81898191
case GGML_OP_ACC:
81908192
case GGML_OP_MUL:

src/llama.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4610,7 +4610,8 @@ struct llm_build_context {
46104610
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
46114611
cb(k_pe, "k_pe", il);
46124612

4613-
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
4613+
// TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
4614+
kv_compressed = ggml_cont(ctx0, kv_compressed);
46144615
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
46154616
model.layers[il].attn_kv_a_norm, NULL,
46164617
LLM_NORM_RMS, cb, il);
@@ -6464,7 +6465,8 @@ struct llm_build_context {
64646465
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
64656466
cb(k_pe, "k_pe", il);
64666467

6467-
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
6468+
// TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
6469+
kv_compressed = ggml_cont(ctx0, kv_compressed);
64686470
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
64696471
model.layers[il].attn_kv_a_norm, NULL,
64706472
LLM_NORM_RMS, cb, il);

0 commit comments

Comments
 (0)