Commit 5dd04ab

rebase from master

1 parent 6ea37f5 commit 5dd04ab
File tree: 9 files changed, +349 -5 lines changed

docs/ops.md

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@ Legend:
 | ARANGE ||||||||||
 | ARGMAX ||||||||||
 | ARGSORT ||||||||||
-| CEIL |||||||| ||
+| CEIL |||||||| 🟡 ||
 | CLAMP ||||| 🟡 | 🟡 || 🟡 ||
 | CONCAT |||| 🟡 || 🟡 | 🟡 |||
 | CONT || 🟡 |||| 🟡 | 🟡 | 🟡 ||
@@ -42,7 +42,7 @@ Legend:
 | ELU |||| 🟡 | 🟡 || 🟡 |||
 | EXP |||| 🟡 | 🟡 || 🟡 |||
 | FLASH_ATTN_EXT || 🟡 || 🟡 | 🟡 ||| 🟡 ||
-| FLOOR |||||||| ||
+| FLOOR |||||||| 🟡 ||
 | GATED_LINEAR_ATTN ||||||||||
 | GEGLU ||||| 🟡 ||| 🟡 ||
 | GEGLU_ERF ||||| 🟡 ||| 🟡 ||
@@ -84,7 +84,7 @@ Legend:
 | ROLL ||||||||||
 | ROPE || 🟡 ||||||||
 | ROPE_BACK ||||||||||
-| ROUND |||||||| ||
+| ROUND |||||||| 🟡 ||
 | RWKV_WKV6 ||||||||||
 | RWKV_WKV7 ||||||||||
 | SCALE || 🟡 ||||||||
@@ -111,6 +111,6 @@ Legend:
 | TANH |||| 🟡 | 🟡 || 🟡 | 🟡 ||
 | TIMESTEP_EMBEDDING ||||||||||
 | TOPK_MOE ||||||||||
-| TRUNC |||||||| ||
+| TRUNC |||||||| 🟡 ||
 | UPSCALE || 🟡 ||| 🟡 || 🟡 |||
 | XIELU ||||||||||
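
The rows above flip FLOOR, CEIL, ROUND, and TRUNC from unsupported to partial (🟡) in the Vulkan column; as the ggml-vulkan.cpp hunks further down show, the backend accepts them for F32 tensors only. Below is a minimal, hypothetical sketch of exercising the new ops end to end through the public ggml API; it assumes the standard graph/backend helpers (ggml_backend_vk_init, ggml_backend_alloc_ctx_tensors, ggml_backend_graph_compute) and the ggml_floor/ggml_round builders that this commit also uses in ggml_vk_check_results_0 — it is not part of the commit itself.

```cpp
// Sketch: run GGML_OP_FLOOR / GGML_OP_ROUND on the Vulkan backend (F32 only).
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"
#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_vk_init(0); // Vulkan device 0

    ggml_init_params ip = { ggml_tensor_overhead() * 8 + ggml_graph_overhead(), nullptr, /*no_alloc=*/true };
    ggml_context * ctx = ggml_init(ip);

    ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_tensor * f = ggml_floor(ctx, x);
    ggml_tensor * r = ggml_round(ctx, x);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, f);
    ggml_build_forward_expand(gf, r);

    // Place all tensors in backend memory, upload input, compute, read back.
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    const float in[4] = { -1.5f, -0.5f, 0.5f, 1.5f };
    ggml_backend_tensor_set(x, in, 0, sizeof(in));
    ggml_backend_graph_compute(backend, gf);

    float out[4];
    ggml_backend_tensor_get(f, out, 0, sizeof(out));
    std::printf("floor: %g %g %g %g\n", out[0], out[1], out[2], out[3]); // -2 -1 0 1

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```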

docs/ops/Vulkan.csv

Lines changed: 8 additions & 0 deletions
@@ -5434,12 +5434,20 @@
 "Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
 "Vulkan0","SIN","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
 "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
+"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
+"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
+"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
+"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
 "Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
 "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
 "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","Vulkan"
 "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
 "Vulkan0","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
 "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
+"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
+"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
+"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
+"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
 "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
 "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan"
 "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan"

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 90 additions & 1 deletion
@@ -549,6 +549,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_rms_norm_mul_partials_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
+    vk_pipeline pipeline_floor_f32;
+    vk_pipeline pipeline_ceil_f32;
+    vk_pipeline pipeline_round_f32;
+    vk_pipeline pipeline_trunc_f32;

     // [src/dst 0=fp32,1=fp16]
     vk_pipeline pipeline_exp[2];
@@ -3516,6 +3520,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_floor_f32, "floor_f32", floor_f32_len, floor_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_ceil_f32, "ceil_f32", ceil_f32_len, ceil_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_round_f32, "round_f32", round_f32_len, round_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_trunc_f32, "trunc_f32", trunc_f32_len, trunc_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
 #define CREATE_UNARY(name) \
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -7891,6 +7900,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cos_f32;
         }
         return nullptr;
+    case GGML_OP_FLOOR:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_floor_f32;
+        }
+        return nullptr;
+    case GGML_OP_CEIL:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_ceil_f32;
+        }
+        return nullptr;
+    case GGML_OP_ROUND:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_round_f32;
+        }
+        return nullptr;
+    case GGML_OP_TRUNC:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_trunc_f32;
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -8268,6 +8297,10 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_MEAN:
+    case GGML_OP_FLOOR:
+    case GGML_OP_CEIL:
+    case GGML_OP_ROUND:
+    case GGML_OP_TRUNC:
         return true;
     default:
         return false;
@@ -8608,6 +8641,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_UNARY:
     case GGML_OP_GLU:
     case GGML_OP_CONV_2D_DW:
+    case GGML_OP_FLOOR:
+    case GGML_OP_CEIL:
+    case GGML_OP_ROUND:
+    case GGML_OP_TRUNC:
         {
             uint32_t ne = ggml_nelements(dst);
             if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
@@ -9394,6 +9431,22 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
 }

+static void ggml_vk_floor(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_FLOOR, vk_op_unary_push_constants_init(src0, dst), dryrun);
+}
+
+static void ggml_vk_ceil(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CEIL, vk_op_unary_push_constants_init(src0, dst), dryrun);
+}
+
+static void ggml_vk_round(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROUND, vk_op_unary_push_constants_init(src0, dst), dryrun);
+}
+
+static void ggml_vk_trunc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TRUNC, vk_op_unary_push_constants_init(src0, dst), dryrun);
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11140,6 +11193,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
     case GGML_OP_OPT_STEP_SGD:
+    case GGML_OP_FLOOR:
+    case GGML_OP_CEIL:
+    case GGML_OP_ROUND:
+    case GGML_OP_TRUNC:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -11208,6 +11265,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_OPT_STEP_SGD:
+    case GGML_OP_FLOOR:
+    case GGML_OP_CEIL:
+    case GGML_OP_ROUND:
+    case GGML_OP_TRUNC:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -11364,6 +11425,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_COS:
         ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_FLOOR:
+        ggml_vk_floor(ctx, compute_ctx, src0, node, dryrun);
+
+        break;
+    case GGML_OP_CEIL:
+        ggml_vk_ceil(ctx, compute_ctx, src0, node, dryrun);
+
+        break;
+    case GGML_OP_ROUND:
+        ggml_vk_round(ctx, compute_ctx, src0, node, dryrun);
+
+        break;
+    case GGML_OP_TRUNC:
+        ggml_vk_trunc(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
@@ -13275,6 +13352,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_OPT_STEP_ADAMW:
     case GGML_OP_OPT_STEP_SGD:
+    case GGML_OP_FLOOR:
+    case GGML_OP_CEIL:
+    case GGML_OP_ROUND:
+    case GGML_OP_TRUNC:
         return op->src[0]->type == GGML_TYPE_F32;
     case GGML_OP_ARGSORT:
         return op->ne[0] <= max_argsort_cols;
@@ -13788,12 +13869,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COS) {
         tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_FLOOR) {
+        tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_CEIL) {
+        tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_ROUND) {
+        tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_TRUNC) {
+        tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_PAD) {
         tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
-        tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
+                                    tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
     } else if (tensor->op == GGML_OP_REPEAT) {
         tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
     } else if (tensor->op == GGML_OP_REPEAT_BACK) {
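
Taken together, the ggml-vulkan.cpp hunks above are all instances of the same wiring pattern for a simple element-wise op; the annotated excerpt below just restates it in one place using the FLOOR case already shown in the diff (no new APIs are introduced):

```cpp
// 1. vk_device_struct gains a pipeline handle:
//        vk_pipeline pipeline_floor_f32;
// 2. ggml_vk_load_shaders() builds it from the generated SPIR-V blob:
//        ggml_vk_create_pipeline(device, device->pipeline_floor_f32, "floor_f32",
//                                floor_f32_len, floor_f32_data, "main", 2,
//                                sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
// 3. ggml_vk_op_get_pipeline() maps GGML_OP_FLOOR to that pipeline for F32 -> F32.
// 4. A thin wrapper forwards to the generic element-wise dispatcher:
static void ggml_vk_floor(ggml_backend_vk_context * ctx, vk_context & subctx,
                          const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_FLOOR,
                   vk_op_unary_push_constants_init(src0, dst), dryrun);
}
// 5. GGML_OP_FLOOR is listed in ggml_vk_op_supports_incontiguous, in both switches of
//    ggml_vk_build_graph (dryrun short-circuit and dispatch), and in
//    ggml_backend_vk_device_supports_op, which restricts it to
//    op->src[0]->type == GGML_TYPE_F32.
```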
ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(ceil(val));
+}
ggml/src/ggml-vulkan/vulkan-shaders/floor.comp

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(floor(val));
+}
ggml/src/ggml-vulkan/vulkan-shaders/round.comp

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(round(val));
+}
ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(trunc(val));
+}
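
The four shaders are identical apart from the builtin applied on the last line; indexing, bounds checking, and the A_TYPE/D_TYPE/FLOAT_TYPE plumbing all come from generic_unary_head.glsl. One semantic note: GLSL's round() may break exact .5 ties in either direction (the spec leaves that choice to the implementation), whereas the C/C++ round family rounds ties away from zero, so a CPU reference comparison can legitimately differ on such inputs. A small host-side sketch of the reference semantics:

```cpp
// Host-side reference for the four shader bodies above.
#include <cmath>
#include <cstdio>

int main() {
    const float xs[] = { -1.5f, -0.7f, 0.5f, 2.3f };
    for (float x : xs) {
        std::printf("x=% .1f  floor=% .1f  ceil=% .1f  round=% .1f  trunc=% .1f\n",
                    x, std::floor(x), std::ceil(x), std::round(x), std::trunc(x));
    }
    return 0;
}
```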

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 0 deletions
@@ -779,6 +779,14 @@ void process_shaders() {

     string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

+    string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+    string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+    string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
+    string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

     string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
