@@ -574,6 +574,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
@@ -1117,6 +1119,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
     init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
 }
 
+struct vk_op_conv_transpose_2d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t N;
+
+    uint32_t KW;
+    uint32_t KH;
+    uint32_t W;
+    uint32_t H;
+    uint32_t OW;
+    uint32_t OH;
+
+    uint32_t s0;
+    uint32_t s1;
+    uint32_t p0;
+    uint32_t p1;
+    uint32_t d0;
+    uint32_t d1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+
+    uint32_t nb1;
+    uint32_t nb2;
+    uint32_t nb3;
+
+    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
+    uint32_t KWmp; uint32_t KWL;
+    uint32_t KWKHmp; uint32_t KWKHL;
+    uint32_t OWmp; uint32_t OWL;
+    uint32_t OWOHmp; uint32_t OWOHL;
+    uint32_t s0mp; uint32_t s0L;
+    uint32_t s1mp; uint32_t s1L;
+};
+
+template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
+    // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
+    init_fastdiv_values(p.KW, p.KWmp, p.KWL);
+    init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
+    init_fastdiv_values(p.OW, p.OWmp, p.OWL);
+    init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
+    init_fastdiv_values(p.s0, p.s0mp, p.s0L);
+    init_fastdiv_values(p.s1, p.s1mp, p.s1L);
+}
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
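// [illustration] init_fastdiv_values is defined earlier in the file and is not part of
// this diff. As a reference, a minimal host-side sketch of the multiply-and-shift
// scheme this kind of helper is based on (Granlund/Montgomery division by invariant
// integers) could look like the following; names and rounding details here are
// illustrative assumptions, not the backend's exact code:
//
//   #include <cassert>
//   #include <cstdint>
//
//   static void fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
//       assert(d != 0);
//       L = 0;
//       while (L < 32 && (uint32_t{1} << L) < d) ++L;  // smallest L with 2^L >= d
//       // magic multiplier: floor(2^32 * (2^L - d) / d) + 1
//       mp = uint32_t(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
//   }
//
//   static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
//       const uint32_t hi = uint32_t((uint64_t(n) * mp) >> 32);  // umulhi(n, mp)
//       return uint32_t((uint64_t(hi) + n) >> L);                // == n / d
//   }
//
// The payoff in the shader: `idx / OW` and `idx % OW` become one umulhi, one add,
// one shift and one multiply-subtract (q = fastdiv(idx, OWmp, OWL); r = idx - q * OW),
// avoiding hardware integer division in the hot loop.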
@@ -1322,7 +1374,7 @@ class vk_perf_logger {
             flops[name].push_back(m * n * (k + (k - 1)) * batch);
             return;
         }
-        if (node->op == GGML_OP_CONV_2D) {
+        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
             std::string name = ggml_op_name(node->op);
             ggml_tensor * knl = node->src[0];
             uint64_t OW = node->ne[0];
@@ -1331,7 +1383,7 @@ class vk_perf_logger {
             uint64_t Cout = node->ne[2];
             uint64_t KW = knl->ne[0];
             uint64_t KH = knl->ne[1];
-            uint64_t Cin = knl->ne[2];
+            uint64_t Cin = node->src[1]->ne[2];
             // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
             uint64_t size_M = Cout;
             uint64_t size_K = Cin * KW * KH;
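// [note] With M = Cout, K = CRS = Cin*KW*KH and N = NPQ, both ops feed the same
// m * n * (k + (k - 1)) * batch accounting visible above: CRS multiplies plus
// CRS-1 adds per output element. Worked example: Cout=64, Cin=32, a 3x3 kernel and
// a 128x128 output give 64 * 16384 * (2*288 - 1) ~= 0.6 GFLOP per invocation.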
@@ -3492,7 +3544,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    // conv2d
+    // conv2d, conv_transpose_2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE = 256;
         uint32_t conv2d_BS_K = 128;
@@ -3567,31 +3619,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
         std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
         std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
 
+#define CREATE_CONV(name, type_suffix, spv_suffix) \
+        ggml_vk_create_pipeline( \
+            device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
+            name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
+            sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+#define CREATE_CONVS(spv_suffix) \
+        CREATE_CONV(conv2d, _f32, spv_suffix) \
+        CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
+        if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
+            CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
+            CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
+        }
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
         if (device->coopmat2) {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONVS(_cm2)
         } else
 #endif
         if (conv2d_UNROLL) {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONVS(_unroll)
         } else {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONVS( )
         }
+#undef CREATE_CONV
+#undef CREATE_CONVS
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
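// [illustration] The token pasting above makes each CREATE_CONV expand to one of the
// removed hand-written calls; e.g. CREATE_CONV(conv2d, _f32, _cm2) becomes (modulo
// line breaks):
//
//   ggml_vk_create_pipeline(
//       device, device->pipeline_conv2d_f32[s], "conv2d_f32",
//       conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
//       sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
//
// Note that the sizeof argument pastes only the `name` parameter
// (vk_op_##name##_push_constants), so conv_transpose_2d pipelines correctly pick up
// the larger vk_op_conv_transpose_2d_push_constants struct.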
@@ -7548,6 +7599,33 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
     return elements;
 }
 
+static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    // src0 - kernel: [KW, KH, Cout, Cin]
+    // src1 - input:  [W, H, Cin, N]
+    // dst  - result: [OW, OH, Cout, N]
+
+    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
+    };
+    // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+    int64_t W = src1->ne[0];
+    int64_t H = src1->ne[1];
+    int64_t KW = src0->ne[0];
+    int64_t KH = src0->ne[1];
+    int64_t Cout = src0->ne[2];
+    int64_t N = src1->ne[3];
+    int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
+    int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
+    int64_t NPQ = N * OW * OH;
+
+    // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+    std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+    return elements;
+}
+
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
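// [note] calc_conv_output_size is the usual transposed-convolution output extent,
// OW = (W - 1) * s - 2*p + (KW - 1) * d + 1; ggml_conv_transpose_2d_p0 has no padding
// or dilation, hence the fixed p = 0, d = 1 arguments. Worked example: W = 16, KW = 4,
// s = 2 gives OW = 15*2 + 3 + 1 = 34.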
@@ -7925,9 +8003,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
             ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            auto elements = ggml_vk_get_conv_elements(dst);
+            std::array<uint32_t, 3> elements;
+            if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
+            else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
             vk_conv_shapes shape;
 
             uint32_t tiles[CONV_SHAPE_COUNT];
@@ -7947,10 +8028,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 shape = CONV_SHAPE_64x32;
             }
 
-            if (src0->type == GGML_TYPE_F32) {
-                return ctx->device->pipeline_conv2d_f32[shape];
-            } else if (src0->type == GGML_TYPE_F16) {
-                return ctx->device->pipeline_conv2d_f16_f32[shape];
+            if (op == GGML_OP_CONV_2D) {
+                if (src0->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_conv2d_f32[shape];
+                } else if (src0->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_conv2d_f16_f32[shape];
+                }
+            } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
+                if (src0->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_conv_transpose_2d_f32[shape];
+                } else if (src0->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
+                }
             }
         }
         return nullptr;
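// [note] The tiles[] shape heuristic between these two hunks is untouched: the
// workgroup counts derived from `elements` drive the same CONV_SHAPE_* selection for
// transposed convolutions as for GGML_OP_CONV_2D, so both ops share tile tuning.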
@@ -8350,6 +8439,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         {
             elements = ggml_vk_get_conv_elements(dst);
         } break;
+    case GGML_OP_CONV_TRANSPOSE_2D:
+        {
+            elements = ggml_vk_get_conv_transpose_2d_elements(dst);
+        } break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_DIV:
@@ -9523,6 +9616,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
 }
 
+static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
+                                      const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    vk_op_conv_transpose_2d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne02);
+    p.Cin = static_cast<uint32_t>(ne03);
+    p.N = static_cast<uint32_t>(ne13);
+
+    p.KW = static_cast<uint32_t>(ne00);
+    p.KH = static_cast<uint32_t>(ne01);
+    p.W = static_cast<uint32_t>(ne10);
+    p.H = static_cast<uint32_t>(ne11);
+    p.OW = static_cast<uint32_t>(ne0);
+    p.OH = static_cast<uint32_t>(ne1);
+
+    p.s0 = static_cast<uint32_t>(dst->op_params[0]);
+    p.s1 = static_cast<uint32_t>(dst->op_params[0]);
+    p.p0 = 0;
+    p.p1 = 0;
+    p.d0 = 1;
+    p.d1 = 1;
+
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb03 = static_cast<uint32_t>(nb03 / nb00);
+
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb12 = static_cast<uint32_t>(nb12 / nb10);
+    p.nb13 = static_cast<uint32_t>(nb13 / nb10);
+
+    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
+    p.nb2 = static_cast<uint32_t>(nb2 / nb0);
+    p.nb3 = static_cast<uint32_t>(nb3 / nb0);
+
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne12);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
+}
+
 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     vk_op_conv2d_dw_push_constants p{};
     p.ne = ggml_nelements(dst);
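// [note] Two details above are easy to misread: the byte strides are divided by the
// element size (nb00, nb10, nb0) so the shader indexes in elements rather than bytes,
// and s1 is loaded from op_params[0] on purpose -- ggml_conv_transpose_2d_p0 carries a
// single stride parameter that applies to both spatial dimensions.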
@@ -10615,6 +10757,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -10686,6 +10829,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_OPT_STEP_SGD:
@@ -10997,6 +11141,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
         ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_TRANSPOSE_2D:
+        ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_CONV_2D_DW:
         ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -11137,6 +11285,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -11794,10 +11943,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
-        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
+        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) {
             // Return CRSxNPQxsizeof(*) to account for as many bytes as mul_mat would use in im2col->mul_mat mode.
             auto CRS_size =
-                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
+                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2];
             auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
             total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
         }
@@ -12618,10 +12767,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                 const vk_device& device = ggml_vk_get_device(ctx->device);
+                if (op->op == GGML_OP_CONV_TRANSPOSE_2D &&
+                    device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) {
+                    return false;
+                }
                 // Channel-contiguous format is not supported yet.
                 return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                         op->src[1]->type == GGML_TYPE_F32 &&
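// [note] The maxPushConstantsSize guard matters because Vulkan only guarantees 128
// bytes of push constants, while vk_op_conv_transpose_2d_push_constants holds 36
// uint32_t fields (24 shape/stride values plus 6 fastdiv (mp, L) pairs), i.e. 144
// bytes; devices at the minimum limit therefore report the op as unsupported up
// front instead of failing at pipeline creation.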
@@ -13240,6 +13394,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         const int32_t d0 = tensor->op_params[4];
         const int32_t d1 = tensor->op_params[5];
         tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
+    } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
+        const int32_t s = tensor->op_params[0];
+        tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
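// [illustration] ggml_conv_transpose_2d_p0, used above as the CPU reference for
// result checking, is the public API this backend now accelerates. A minimal
// host-side sketch (the context size, f16 kernel type, and the exact location of
// ggml_graph_compute_with_ctx are assumptions that may vary across ggml versions):
//
//   #include "ggml.h"
//
//   int main() {
//       struct ggml_init_params ip = { /*.mem_size =*/ 64u * 1024 * 1024,
//                                      /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
//       struct ggml_context * ctx = ggml_init(ip);
//
//       // kernel [KW, KH, Cout, Cin] and input [W, H, Cin, N], matching the layout
//       // documented in ggml_vk_get_conv_transpose_2d_elements
//       struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 4, 4, 8, 3);
//       struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 16, 3, 1);
//       struct ggml_tensor * y = ggml_conv_transpose_2d_p0(ctx, k, x, /*stride*/ 2);
//       // y->ne == { 34, 34, 8, 1 }: OW = (16-1)*2 + (4-1) + 1 = 34
//
//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
//       ggml_build_forward_expand(gf, y);
//       ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 4);
//       ggml_free(ctx);
//       return 0;
//   }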