Skip to content

Commit d2d05bd

Browse files
committed
Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	ggml/src/ggml-rpc/ggml-rpc.cpp
2 parents b30f09d + 73955f7 commit d2d05bd

File tree

5 files changed

+75
-3
lines changed

5 files changed

+75
-3
lines changed

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ extern "C" {
88
#endif
99

1010
#define RPC_PROTO_MAJOR_VERSION 3
11-
#define RPC_PROTO_MINOR_VERSION 0
11+
#define RPC_PROTO_MINOR_VERSION 5
1212
#define RPC_PROTO_PATCH_VERSION 0
1313
#define GGML_RPC_MAX_SERVERS 16
1414

ggml/src/ggml-cuda/mmf.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
151151
return false;
152152
}
153153
} else {
154-
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
154+
if (src1_ncols > 16) {
155155
return false;
156156
}
157157
}

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ struct vk_device_struct {
665665
vk_pipeline pipeline_sin_f32;
666666
vk_pipeline pipeline_cos_f32;
667667
vk_pipeline pipeline_log[2];
668+
vk_pipeline pipeline_tri[2];
668669
vk_pipeline pipeline_clamp_f32;
669670
vk_pipeline pipeline_pad_f32;
670671
vk_pipeline pipeline_roll_f32;
@@ -3892,6 +3893,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
38923893
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38933894
}
38943895

3896+
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3897+
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3898+
38953899
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38963900

38973901
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8320,6 +8324,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
83208324
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
83218325
}
83228326
return nullptr;
8327+
case GGML_OP_TRI:
8328+
if (src0->type == dst->type &&
8329+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8330+
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
8331+
}
8332+
return nullptr;
83238333
case GGML_OP_CLAMP:
83248334
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
83258335
return ctx->device->pipeline_clamp_f32;
@@ -9021,6 +9031,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
90219031
case GGML_OP_SIN:
90229032
case GGML_OP_COS:
90239033
case GGML_OP_LOG:
9034+
case GGML_OP_TRI:
90249035
case GGML_OP_CLAMP:
90259036
case GGML_OP_PAD:
90269037
case GGML_OP_ROLL:
@@ -9701,6 +9712,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
97019712
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
97029713
}
97039714

9715+
static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9716+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
9717+
p.param1 = ggml_get_op_params_f32(dst, 0);
9718+
9719+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
9720+
}
9721+
97049722
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
97059723
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
97069724
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11824,6 +11842,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1182411842
case GGML_OP_LOG:
1182511843
ggml_vk_log(ctx, compute_ctx, src0, node);
1182611844

11845+
break;
11846+
case GGML_OP_TRI:
11847+
ggml_vk_tri(ctx, compute_ctx, src0, node);
11848+
1182711849
break;
1182811850
case GGML_OP_CLAMP:
1182911851
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -13949,7 +13971,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1394913971
case GGML_OP_OPT_STEP_SGD:
1395013972
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1395113973
case GGML_OP_LOG:
13952-
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
13974+
case GGML_OP_TRI:
13975+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13976+
op->type == op->src[0]->type;
1395313977
case GGML_OP_ARGSORT:
1395413978
{
1395513979
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@@ -14540,6 +14564,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1454014564
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
1454114565
} else if (tensor->op == GGML_OP_LOG) {
1454214566
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
14567+
} else if (tensor->op == GGML_OP_TRI) {
14568+
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
1454314569
} else if (tensor->op == GGML_OP_CLAMP) {
1454414570
const float * params = (const float *)tensor->op_params;
1454514571
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
ggml/src/ggml-vulkan/vulkan-shaders/tri.comp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#version 450
2+
3+
#include "rte.glsl"
4+
#include "types.glsl"
5+
#include "generic_unary_head.glsl"
6+
7+
#define GGML_TRI_TYPE_UPPER_DIAG 0
8+
#define GGML_TRI_TYPE_UPPER 1
9+
#define GGML_TRI_TYPE_LOWER_DIAG 2
10+
#define GGML_TRI_TYPE_LOWER 3
11+
12+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
13+
14+
void main() {
15+
const uint idx = get_idx();
16+
17+
if (idx >= p.ne) {
18+
return;
19+
}
20+
21+
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
22+
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
23+
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
24+
const uint i02_offset = i02*p.ne01*p.ne00;
25+
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
26+
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
27+
28+
int param = floatBitsToInt(p.param1);
29+
bool pass = false;
30+
switch (param) {
31+
case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
32+
case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
33+
case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
34+
case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
35+
}
36+
37+
if (pass) {
38+
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
39+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
40+
} else {
41+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
42+
}
43+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,9 @@ void process_shaders() {
863863
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
864864
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
865865

866+
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
867+
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
868+
866869
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
867870
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
868871

0 commit comments

Comments (0)