Skip to content

Commit 34100e6

Browse files
author
alexju
committed
metal : refactor pool_2d parameters into a struct
1 parent 64e4a7e commit 34100e6

File tree

3 files changed

+56
-59
lines changed

3 files changed

+56
-59
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,4 +506,18 @@ typedef struct {
506506
float step;
507507
} ggml_metal_kargs_arange;
508508

509+
typedef struct {
510+
int32_t k0;
511+
int32_t k1;
512+
int32_t s0;
513+
int32_t s1;
514+
int32_t p0;
515+
int32_t p1;
516+
int64_t IH;
517+
int64_t IW;
518+
int64_t OH;
519+
int64_t OW;
520+
int64_t parallel_elements;
521+
} ggml_metal_kargs_pool_2d;
522+
509523
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4199,22 +4199,25 @@ static void ggml_metal_encode_node(
41994199
const int64_t parallel_elements = N * OC * OH * OW;
42004200
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
42014201
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
4202+
4203+
ggml_metal_kargs_pool_2d args_pool_2d = {
4204+
/* .k0 = */ k0,
4205+
/* .k1 = */ k1,
4206+
/* .s0 = */ s0,
4207+
/* .s1 = */ s1,
4208+
/* .p0 = */ p0,
4209+
/* .p1 = */ p1,
4210+
/* .IH = */ IH,
4211+
/* .IW = */ IW,
4212+
/* .OH = */ OH,
4213+
/* .OW = */ OW,
4214+
/* .parallel_elements = */ parallel_elements
4215+
};
42024216

4203-
// TODO: add ggml_metal_kargs struct
42044217
[encoder setComputePipelineState:pipeline];
4205-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
4206-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
4207-
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
4208-
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
4209-
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
4210-
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
4211-
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
4212-
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
4213-
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
4214-
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
4215-
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
4216-
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
4217-
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
4218+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
4219+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
4220+
[encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
42184221

42194222
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
42204223
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6463,98 +6463,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t
64636463
kernel void kernel_pool_2d_max_f32(
64646464
device const float * src0,
64656465
device float * dst,
6466-
constant int32_t & k0,
6467-
constant int32_t & k1,
6468-
constant int32_t & s0,
6469-
constant int32_t & s1,
6470-
constant int32_t & p0,
6471-
constant int32_t & p1,
6472-
constant int64_t & IH,
6473-
constant int64_t & IW,
6474-
constant int64_t & OH,
6475-
constant int64_t & OW,
6476-
constant int64_t & parallel_elements,
6466+
constant ggml_metal_kargs_pool_2d & args,
64776467
uint gid[[thread_position_in_grid]]) {
64786468

6479-
if (gid >= parallel_elements) {
6469+
if (gid >= args.parallel_elements) {
64806470
return;
64816471
}
64826472

64836473
const int idx = gid;
6484-
const int I_HW = IH * IW;
6485-
const int O_HW = OH * OW;
6474+
const int I_HW = args.IH * args.IW;
6475+
const int O_HW = args.OH * args.OW;
64866476
const int nc = idx / O_HW;
6487-
const int cur_oh = idx % O_HW / OW;
6488-
const int cur_ow = idx % O_HW % OW;
6477+
const int cur_oh = idx % O_HW / args.OW;
6478+
const int cur_ow = idx % O_HW % args.OW;
64896479

64906480
device const float * i_ptr = src0 + nc * I_HW;
64916481
device float * o_ptr = dst + nc * O_HW;
64926482

6493-
const int start_h = cur_oh * s1 - p1;
6483+
const int start_h = cur_oh * args.s1 - args.p1;
64946484
const int bh = MAX(0, start_h);
6495-
const int eh = MIN(IH, start_h + k1);
6496-
const int start_w = cur_ow * s0 - p0;
6485+
const int eh = MIN(args.IH, start_h + args.k1);
6486+
const int start_w = cur_ow * args.s0 - args.p0;
64976487
const int bw = MAX(0, start_w);
6498-
const int ew = MIN(IW, start_w + k0);
6488+
const int ew = MIN(args.IW, start_w + args.k0);
64996489

65006490
float res = -INFINITY;
65016491

65026492
for (int i = bh; i < eh; i += 1) {
65036493
for (int j = bw; j < ew; j += 1) {
6504-
res = MAX(res, i_ptr[i * IW + j]);
6494+
res = MAX(res, i_ptr[i * args.IW + j]);
65056495
}
65066496
}
65076497

6508-
o_ptr[cur_oh * OW + cur_ow] = res;
6498+
o_ptr[cur_oh * args.OW + cur_ow] = res;
65096499
}
65106500

65116501
kernel void kernel_pool_2d_avg_f32(
65126502
device const float * src0,
65136503
device float * dst,
6514-
constant int32_t & k0,
6515-
constant int32_t & k1,
6516-
constant int32_t & s0,
6517-
constant int32_t & s1,
6518-
constant int32_t & p0,
6519-
constant int32_t & p1,
6520-
constant int64_t & IH,
6521-
constant int64_t & IW,
6522-
constant int64_t & OH,
6523-
constant int64_t & OW,
6524-
constant int64_t & parallel_elements,
6504+
constant ggml_metal_kargs_pool_2d & args,
65256505
uint gid[[thread_position_in_grid]]) {
65266506

6527-
if (gid >= parallel_elements) {
6507+
if (gid >= args.parallel_elements) {
65286508
return;
65296509
}
65306510

65316511
const int idx = gid;
6532-
const int I_HW = IH * IW;
6533-
const int O_HW = OH * OW;
6512+
const int I_HW = args.IH * args.IW;
6513+
const int O_HW = args.OH * args.OW;
65346514
const int nc = idx / O_HW;
6535-
const int cur_oh = idx % O_HW / OW;
6536-
const int cur_ow = idx % O_HW % OW;
6515+
const int cur_oh = idx % O_HW / args.OW;
6516+
const int cur_ow = idx % O_HW % args.OW;
65376517

65386518
device const float * i_ptr = src0 + nc * I_HW;
65396519
device float * o_ptr = dst + nc * O_HW;
65406520

6541-
const int start_h = cur_oh * s1 - p1;
6521+
const int start_h = cur_oh * args.s1 - args.p1;
65426522
const int bh = MAX(0, start_h);
6543-
const int eh = MIN(IH, start_h + k1);
6544-
const int start_w = cur_ow * s0 - p0;
6523+
const int eh = MIN(args.IH, start_h + args.k1);
6524+
const int start_w = cur_ow * args.s0 - args.p0;
65456525
const int bw = MAX(0, start_w);
6546-
const int ew = MIN(IW, start_w + k0);
6526+
const int ew = MIN(args.IW, start_w + args.k0);
65476527
// const float scale = 1. / ((eh - bh) * (ew - bw));
6548-
const float scale = 1. / (k0 * k1);
6528+
const float scale = 1. / (args.k0 * args.k1);
65496529

65506530
float res = 0;
65516531

65526532
for (int i = bh; i < eh; i += 1) {
65536533
for (int j = bw; j < ew; j += 1) {
6554-
float cur = i_ptr[i * IW + j];
6534+
float cur = i_ptr[i * args.IW + j];
65556535
res += cur * scale;
65566536
}
65576537
}
65586538

6559-
o_ptr[cur_oh * OW + cur_ow] = res;
6539+
o_ptr[cur_oh * args.OW + cur_ow] = res;
65606540
}

0 commit comments

Comments
 (0)