
Commit d82f529

Merge branch 'ggml-org:master' into master
2 parents: ecab78c + a0f98dd

13 files changed, 302 insertions(+), 478 deletions(-)

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -1755,7 +1755,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.warmup = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
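
The only functional change here is the extra LLAMA_EXAMPLE_PERPLEXITY entry: the option whose handler sets params.warmup = false is now surfaced to the perplexity tool as well as main, server, embedding and retrieval. As a rough illustration of that gating pattern (not the actual common_arg implementation; the flag name and helper names below are made up for the sketch):

    #include <set>
    #include <string>

    // Illustrative stand-in for common_arg's per-example gating; names are hypothetical.
    enum example_id { EX_MAIN, EX_SERVER, EX_EMBEDDING, EX_RETRIEVAL, EX_PERPLEXITY };

    struct opt_def {
        std::string flag;
        std::set<example_id> examples;          // tools that accept this option

        bool available_for(example_id tool) const {
            return examples.count(tool) > 0;    // the option is rejected for other tools
        }
    };

    int main() {
        // "--no-warmup" is assumed here; the hunk above shows only the handler body.
        opt_def opt{"--no-warmup", {EX_MAIN, EX_SERVER, EX_EMBEDDING, EX_RETRIEVAL, EX_PERPLEXITY}};
        return opt.available_for(EX_PERPLEXITY) ? 0 : 1;   // accepted after this change
    }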

ggml/src/ggml-backend.cpp

Lines changed: 20 additions & 7 deletions
@@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
     std::vector<int32_t> ids;
     std::vector<ggml_bitset_t> used_ids;

-    for (int i = 0; i < sched->n_splits; i++) {
-        struct ggml_backend_sched_split * split = &splits[i];
+    for (int split_id = 0; split_id < sched->n_splits; split_id++) {
+        struct ggml_backend_sched_split * split = &splits[split_id];
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];

         // copy the input tensors to the split backend
-        for (int j = 0; j < split->n_inputs; j++) {
-            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
-            struct ggml_tensor * input = split->inputs[j];
+        for (int input_id = 0; input_id < split->n_inputs; input_id++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
+            struct ggml_tensor * input = split->inputs[input_id];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);

             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
@@ -1398,17 +1398,30 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s

                 // get the ids
                 ggml_tensor * ids_tensor = node->src[2];
+                ggml_backend_t ids_backend = split_backend;
+
+                // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
+                // in that case, we use the original ids tensor
+                for (int i = input_id + 1; i < split->n_inputs; i++) {
+                    if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
+                        ids_tensor = split->inputs[i];
+                        ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
+                        break;
+                    }
+                }
+
                 if (ids_tensor != prev_ids_tensor) {
                     ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
-                    ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
-                    ggml_backend_synchronize(split_backend);
+                    ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
+                    ggml_backend_synchronize(ids_backend);

                     // find the used experts
                     used_ids.clear();
                     used_ids.resize(ggml_bitset_size(n_expert));
                     for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
                         for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
                             int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
+                            GGML_ASSERT(id >= 0 && id < n_expert);
                             ggml_bitset_set(used_ids.data(), id);
                         }
                     }
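
Two things happen in this hunk: the loop counters get unambiguous names, and the MoE ids tensor is read back from the backend that actually holds its data (ids_backend) when it is itself a split input that has not been copied yet. The ids are then folded into used_ids, a bitset with one bit per expert, so that only the expert rows actually selected by this batch need to be copied. A self-contained sketch of that bitset bookkeeping (plain C++, independent of the ggml_bitset_* helpers):

    #include <cstdint>
    #include <vector>

    // Plain-C++ sketch of the "used experts" bookkeeping; n_expert and ids stand in
    // for the MoE router output read back from the backend above.
    int main() {
        const int n_expert = 64;
        const std::vector<int32_t> ids = {3, 3, 17, 42};      // experts selected by this batch

        std::vector<uint32_t> used((n_expert + 31) / 32, 0);  // one bit per expert
        for (int32_t id : ids) {
            used[id / 32] |= 1u << (id % 32);
        }

        // Later, copies can be skipped for any expert whose bit is not set.
        int copied = 0;
        for (int e = 0; e < n_expert; e++) {
            if (used[e / 32] & (1u << (e % 32))) {
                copied++;                                     // copy only this expert's rows
            }
        }
        return copied == 3 ? 0 : 1;                           // {3, 17, 42}
    }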

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 119 additions & 38 deletions
@@ -867,6 +867,86 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
     return acl_tensor;
 }

+/**
+ * @brief Fills a tensor with a scalar value.
+ *
+ * This function fills the destination tensor `acl_dst` with the scalar value
+ * `scalar`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param scalar The scalar value used to fill the tensor.
+ * @param acl_dst The destination tensor to be filled with the scalar value.
+ */
+static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+                              aclTensor* acl_dst) {
+    auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
+    ggml_cann_release_resources(ctx, acl_scalar);
+}
+
+/**
+ * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ *
+ * This function manages cached device memory for float32 tensors. If the current
+ * cache size is insufficient for the requested tensor shape, the old memory will
+ * be released and new memory will be allocated. The allocated buffer is then
+ * initialized either with zeros (when @p value == 0.0f) or with the given scalar
+ * value using CANN operations. Finally, an aclTensor object is created from the
+ * cached memory and returned.
+ *
+ * @param ctx The CANN backend context that manages device memory.
+ * @param buffer A pointer to the cached device buffer (will be allocated
+ *               or reallocated if necessary).
+ * @param cache_element The current number of cached elements. This will be
+ *                      updated when the cache is expanded.
+ * @param ne The tensor shape array (number of elements in each dimension).
+ * @param nb The stride size for each dimension.
+ * @param dims The number of tensor dimensions.
+ * @param value The scalar value used to fill the tensor (supports zero
+ *              initialization via memset or arbitrary values via fill_scalar).
+ * @return An aclTensor pointer created from the cached buffer.
+ */
+static aclTensor* get_f32_cache_acl_tensor(
+        ggml_backend_cann_context& ctx,
+        void** buffer,
+        int64_t &cache_element,
+        int64_t* ne,
+        size_t* nb,
+        int64_t dims,
+        float value) {
+    // Calculate total number of elements
+    int64_t n_element = 1;
+    for (int i = 0; i < dims; i++) {
+        n_element *= ne[i];
+    }
+    size_t size = n_element * sizeof(float);
+
+    // Allocate or expand cache if needed
+    if (cache_element < n_element) {
+        if (*buffer != nullptr) {
+            aclrtFree(*buffer);
+            *buffer = nullptr;
+        }
+
+        ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        cache_element = n_element;
+
+        // Initialize cache
+        if (value == 0.0f) {
+            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
+        } else {
+            int64_t pool_ne[1] = { n_element };
+            size_t pool_nb[1] = { sizeof(float) };
+            aclTensor* acl_value = ggml_cann_create_tensor(
+                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
+            aclnn_fill_scalar(ctx, 1, acl_value);
+            ggml_cann_release_resources(ctx, acl_value);
+        }
+    }
+
+    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
+}
+
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
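
get_f32_cache_acl_tensor keeps one device buffer per constant (all-ones, all-zeros) on the context and only reallocates and refills it when a larger element count is requested; smaller requests reuse the existing, already-initialized allocation and merely wrap it in a fresh aclTensor view. A host-only analogue of that grow-only behaviour, with malloc/free standing in for aclrtMalloc/aclrtFree:

    #include <cstdlib>

    // Host-only analogue of the grow-only constant cache (illustrative sketch).
    static float* get_f32_cache(void** buffer, int64_t& cache_element, int64_t n_element, float value) {
        if (cache_element < n_element) {
            free(*buffer);                                   // release the smaller allocation
            *buffer = malloc(n_element * sizeof(float));
            cache_element = n_element;
            for (int64_t i = 0; i < n_element; i++) {        // fill once, on growth only
                ((float*)*buffer)[i] = value;
            }
        }
        return (float*)*buffer;                              // any view <= cache_element stays valid
    }

    int main() {
        void* buf = nullptr;
        int64_t cached = 0;
        get_f32_cache(&buf, cached, 4096, 1.0f);  // allocates and fills
        get_f32_cache(&buf, cached, 1024, 1.0f);  // reuses the buffer, no refill
        free(buf);
        return 0;
    }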

@@ -875,20 +955,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
-    size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
-    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
-
-    aclTensor* acl_gamma = aclnn_values(
-        ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
-        ggml_cann_type_mapping(src->type), ggml_element_size(src));
-
-    size_t zero_tensor_n_bytes =
-        src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
-    ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
-    aclTensor* acl_rstd =
-        aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
-                   src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
-                   ggml_element_size(src));
+
+    // build gamma, one...
+    size_t acl_gamma_nb[GGML_MAX_DIMS];
+    acl_gamma_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
+    }
+    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+        ctx,
+        &ctx.f32_one_cache,
+        ctx.f32_one_cache_element,
+        src->ne,
+        acl_gamma_nb,
+        1,      // dims
+        1.0f    // value
+    );
+
+    // build rstd, zero...
+    size_t acl_rstd_nb[GGML_MAX_DIMS];
+    acl_rstd_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
+    }
+    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+        ctx,
+        &ctx.f32_zero_cache,
+        ctx.f32_zero_cache_element,
+        src->ne,
+        acl_rstd_nb,
+        GGML_MAX_DIMS,
+        0.0f    // value
+    );
+
     GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
     ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
 }
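
Both stride arrays follow the standard contiguous row-major rule: nb[0] = sizeof(float) and nb[i] = nb[i-1] * ne[i-1]. A quick standalone check of that formula with a hypothetical shape (the real ne comes from src):

    #include <cstdio>

    int main() {
        const long long ne[4] = {4096, 8, 1, 1};   // hypothetical src shape
        long long nb[4];
        nb[0] = sizeof(float);
        for (int i = 1; i < 4; i++) {
            nb[i] = nb[i - 1] * ne[i - 1];
        }
        // Prints: 4 16384 131072 131072
        printf("%lld %lld %lld %lld\n", nb[0], nb[1], nb[2], nb[3]);
        return 0;
    }
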
@@ -903,14 +1002,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,

     const int n_past = ((int32_t*)dst->op_params)[0];

-    size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
-                                src->ne[3] * ggml_element_size(src);
-    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));
+    void* buffer = one_tensor_allocator.get();

-    aclTensor* mask_tensor =
-        aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
-                     src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
-                     ggml_element_size(src), value);
+    aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
+        ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
+
+    aclnn_fill_scalar(ctx, value, mask_tensor);

     aclScalar* alpha = nullptr;
     float alphaValue = 1.0f;
@@ -1277,23 +1375,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
                               tmp_permute_tensor, tmp_mul_tensor, acl_dst);
 }

-/**
- * @brief Fills a tensor with a scalar value.
- *
- * This function fills the destination tensor `acl_dst` with the scalar value
- * `scalar`.
- *
- * @param ctx The context for the CANN backend operations.
- * @param scalar The scalar value used to fill the tensor.
- * @param acl_dst The destination tensor to be filled with the scalar value.
- */
-static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
-                              aclTensor* acl_dst) {
-    auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
-    ggml_cann_release_resources(ctx, acl_scalar);
-}
-
 /**
  * @brief Raises each element of a tensor to the power of the corresponding
  * element in another tensor.

ggml/src/ggml-cann/common.h

Lines changed: 4 additions & 0 deletions
@@ -379,6 +379,10 @@ struct ggml_backend_cann_context {
     cann_task_queue task_queue;
     bool async_mode;
    bool support_set_rows;
+    void* f32_zero_cache = nullptr;
+    void* f32_one_cache = nullptr;
+    int64_t f32_zero_cache_element = 0;
+    int64_t f32_one_cache_element = 0;

     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
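
These cache pointers are raw device allocations owned by the context rather than pool allocations, so their lifetime must be managed explicitly. A hypothetical cleanup sketch (not part of this hunk) for the context destructor:

    // Hypothetical sketch only; the actual teardown lives elsewhere in ggml-cann.
    ~ggml_backend_cann_context() {
        if (f32_zero_cache != nullptr) {
            aclrtFree(f32_zero_cache);
            f32_zero_cache = nullptr;
        }
        if (f32_one_cache != nullptr) {
            aclrtFree(f32_one_cache);
            f32_one_cache = nullptr;
        }
        // ... existing stream / task-queue cleanup ...
    }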

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 21 additions & 2 deletions
@@ -490,6 +490,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_l2_norm_f32;

     // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_exp[2];
     vk_pipeline pipeline_gelu[2];
     vk_pipeline pipeline_gelu_erf[2];
     vk_pipeline pipeline_gelu_quick[2];
@@ -529,8 +530,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
-    vk_pipeline pipeline_conv2d_dw_whcn_f32;
-    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;

     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
@@ -3066,6 +3067,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+    CREATE_UNARY(exp)
     CREATE_UNARY(gelu)
     CREATE_UNARY(gelu_erf)
     CREATE_UNARY(gelu_quick)
@@ -3255,6 +3257,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

     for (auto &c : compiles) {
         c.wait();
@@ -7133,6 +7137,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     }

     switch (ggml_get_unary_op(dst)) {
+        case GGML_UNARY_OP_EXP:
+            return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_SILU:
             return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
         case GGML_UNARY_OP_GELU:
@@ -7342,6 +7348,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         } else if (ggml_is_contiguous_channels(src1)) {
             return ctx->device->pipeline_conv2d_dw_cwhn_f32;
         }
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        if (ggml_is_contiguous(src1)) {
+            return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
+        } else if (ggml_is_contiguous_channels(src1)) {
+            return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
+        }
     }
     return nullptr;
 default:
@@ -9738,6 +9750,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         return false;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
@@ -10015,6 +10028,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
@@ -10251,6 +10265,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
@@ -11166,6 +11181,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
@@ -11965,6 +11981,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
+        case GGML_UNARY_OP_EXP:
+            tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
+            break;
         case GGML_UNARY_OP_SILU:
             tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
             break;
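
Taken together, these hunks register exp_f32/exp_f16 pipelines, route GGML_UNARY_OP_EXP to them in pipeline selection, graph build, compute, and op-support checks, and teach the result checker to clone the op via ggml_exp. A rough smoke-test sketch against the public ggml API (assumes ggml_backend_vk_init(0) selects Vulkan device 0; input upload and error handling are omitted):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-vulkan.h"

    int main() {
        ggml_backend_t backend = ggml_backend_vk_init(0);          // Vulkan device 0 (assumed)

        struct ggml_init_params ip = {
            /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,                                 // tensor data lives in the backend buffer
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        struct ggml_tensor * y = ggml_exp(ctx, x);                  // GGML_OP_UNARY / GGML_UNARY_OP_EXP

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
        // ... fill x with ggml_backend_tensor_set(x, data, 0, ggml_nbytes(x)) ...
        ggml_backend_graph_compute(backend, gf);                    // runs the new exp pipeline

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }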

ggml/src/ggml-vulkan/vulkan-shaders/exp.comp

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+    data_d[i] = D_TYPE(exp(float(data_a[i])));
+}
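
The shader flattens its 3D invocation id into one element index: each workgroup is 512 threads wide in x, and the host-side dispatch spills further blocks of work into y (worth 512 elements per step) and z (worth 512 * 512 = 262144 elements per step) when one axis alone would exceed the group-count limit. The same folding in host code, with a hypothetical invocation id:

    #include <cstdint>

    int main() {
        // Hypothetical gl_GlobalInvocationID for a large dispatch.
        const uint32_t x = 17, y = 3, z = 2;
        const uint32_t i = z * 262144u + y * 512u + x;   // 262144 == 512 * 512
        return i == 525841u ? 0 : 1;                     // 2*262144 + 3*512 + 17
    }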

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
@@ -586,6 +586,8 @@ void process_shaders() {

     string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

+    string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -678,6 +680,8 @@ void process_shaders() {

     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+    string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+    string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

     string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));