Commit 4e6310c

cpu : add batching and F16/I32 support to win_part/win_unpart ops/get_rel_pos
1 parent b6b9f02 commit 4e6310c

3 files changed: +193, -55 lines

ggml/include/ggml.h

Lines changed: 4 additions & 8 deletions
@@ -2279,18 +2279,16 @@ extern "C" {
             struct ggml_tensor  * ids);
 
     // partition into non-overlapping windows with padding if needed
-    // example:
-    // a:   768   64 64 1
-    // w:   14
-    // res: 768   14 14 25
-    // used in sam
+    // a:      [B, H, W, C]
+    // result: [B*NPY*NPX, w, w, C]
+    // NPY = ceil(H/w)
+    // NPX = ceil(W/w)
     GGML_API struct ggml_tensor * ggml_win_part(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   w);
 
     // reverse of ggml_win_part
-    // used in sam
     GGML_API struct ggml_tensor * ggml_win_unpart(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2308,14 +2306,12 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_unary_op    op);
 
-    // used in sam
     GGML_API struct ggml_tensor * ggml_get_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   qh,
             int                   kh);
 
-    // used in sam
     GGML_API struct ggml_tensor * ggml_add_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
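The new header comment is effectively a shape contract. Below is a minimal sketch of the batched round trip with illustrative sizes, assuming a live ggml context with enough memory and that the new ggml_win_unpart_ext from ggml.c is declared for the caller; the helper name win_part_roundtrip_sketch is ours, not the commit's.

#include "ggml.h"

// Sketch: ggml stores dims innermost-first, so a [B=2, H=64, W=64, C=768]
// activation is created as ne = (768, 64, 64, 2).
static void win_part_roundtrip_sketch(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768, 64, 64, 2);

    // w = 14 pads 64 up to 70, so NPX = NPY = 5 windows per image and the
    // result shape is (768, 14, 14, 2*5*5 = 50).
    struct ggml_tensor * p = ggml_win_part(ctx, a, 14);

    // Batched inverse added by this commit: pass the original W, H and the
    // batch size to get back a (768, 64, 64, 2) tensor.
    struct ggml_tensor * u = ggml_win_unpart_ext(ctx, p, 64, 64, 2, 14);

    GGML_ASSERT(p->ne[3] == 50);
    GGML_ASSERT(u->ne[3] == 2);
}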

ggml/src/ggml-cpu/ops.cpp

Lines changed: 171 additions & 37 deletions
@@ -8946,35 +8946,80 @@ static void ggml_compute_forward_win_part_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
     const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
+    const int32_t bs   = ((const int32_t *)(dst->op_params))[2];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[3];
 
     assert(ne00 == ne0);
-    assert(ne3  == nep0*nep1);
+    assert(ne3  == nep0*nep1*bs);
 
     // TODO: optimize / multi-thread
-    for (int py = 0; py < nep1; ++py) {
-        for (int px = 0; px < nep0; ++px) {
-            const int64_t i3 = py*nep0 + px;
-            for (int64_t i2 = 0; i2 < ne2; ++i2) {
-                for (int64_t i1 = 0; i1 < ne1; ++i1) {
-                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                        const int64_t i02 = py*w + i2;
-                        const int64_t i01 = px*w + i1;
-                        const int64_t i00 = i0;
-
-                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
-                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
-
-                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
-                            ((float *) dst->data)[i] = 0.0f;
-                        } else {
-                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
-                        }
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        int px = i3 % nep0;
+        int py = (i3 / nep0) % nep1;
+        int b  = i3 / (nep0 * nep1);
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                    const int64_t i03 = b;
+                    const int64_t i02 = py*w + i2;
+                    const int64_t i01 = px*w + i1;
+                    const int64_t i00 = i0;
+
+                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    void * dp = ((void *)  dst->data) +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0;
+
+                    if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                        *((float *) dp) = 0;
+                    } else {
+                        *((float *) dp) = *((float *) sp);
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t bs   = ((const int32_t *)(dst->op_params))[2];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[3];
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1*bs);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        int px = i3 % nep0;
+        int py = (i3 / nep0) % nep1;
+        int b  = i3 / (nep0 * nep1);
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                    const int64_t i03 = b;
+                    const int64_t i02 = py*w + i2;
+                    const int64_t i01 = px*w + i1;
+                    const int64_t i00 = i0;
+
+                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    void * dp = ((void *)  dst->data) +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0;
+
+                    if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                        *((ggml_fp16_t *) dp) = 0;
+                    } else {
+                        *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
                     }
                 }
             }
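The rewritten loop recovers the window coordinates and the batch index from the destination's outer dimension instead of iterating the window grid directly; a worked decomposition with illustrative numbers, not taken from the commit:

// For nep0 = 5, nep1 = 5, bs = 2 and destination index i3 = 27:
//   px = 27 % 5       = 2   // window column within the image
//   py = (27 / 5) % 5 = 0   // window row within the image
//   b  = 27 / (5 * 5) = 1   // which image of the batch
// i.e. window (2, 0) of the second image; px varies fastest and the batch
// index slowest, matching np = npx*npy*bs on the allocation side.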
@@ -8989,10 +9034,16 @@ void ggml_compute_forward_win_part(
     const ggml_tensor * src0 = dst->src[0];
 
     switch (src0->type) {
+        case GGML_TYPE_I32:
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_win_part_f32(params, dst);
             } break;
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_win_part_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
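Reusing the F32 kernel for I32 and the F16 kernel for BF16 is sound here because the loops never interpret values: they only move same-width elements or zero-fill padding, and an all-zeros bit pattern means zero in each of these formats, so the copies are effectively raw 4-byte and 2-byte moves.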
@@ -9009,35 +9060,82 @@ static void ggml_compute_forward_win_unpart_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int32_t w = ((const int32_t *)(dst->op_params))[0];
 
     // padding
     const int px = (w - ne1%w)%w;
-    //const int py = (w - ne2%w)%w;
+    const int py = (w - ne2%w)%w;
 
     const int npx = (px + ne1)/w;
-    //const int npy = (py + ne2)/w;
+    const int npy = (py + ne2)/w;
 
     assert(ne0 == ne00);
+    assert(ne03 == npx*npy*ne3);
 
     // TODO: optimize / multi-thread
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int ip2 = i2/w;
-                const int ip1 = i1/w;
-
-                const int64_t i02 = i2%w;
-                const int64_t i01 = i1%w;
-                const int64_t i00 = i0;
-
-                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
-                const int64_t j =                                   i2*ne1*ne0   + i1*ne0   + i0;
-
-                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+    for (int64_t i3 = 0; i3 < ne3; ++i3) {
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                    const int ip2 = i2/w;
+                    const int ip1 = i1/w;
+
+                    const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
+                    const int64_t i02 = i2%w;
+                    const int64_t i01 = i1%w;
+                    const int64_t i00 = i0;
+
+                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    void * dp = ((void *)  dst->data) +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0;
+
+                    *((float *) dp) = *((float *) sp);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+    assert(ne03 == npx*npy*ne3);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i3 = 0; i3 < ne3; ++i3) {
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                    const int ip2 = i2/w;
+                    const int ip1 = i1/w;
+
+                    const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
+                    const int64_t i02 = i2%w;
+                    const int64_t i01 = i1%w;
+                    const int64_t i00 = i0;
+
+                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    void * dp = ((void *)  dst->data) +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0;
+
+                    *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
+                }
             }
         }
     }
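A worked check, again with illustrative numbers, that this indexing inverts the partition op:

// w = 14, unpartitioned width ne1 = 64  =>  pad px = (14 - 64%14)%14 = 6
// and npx = (6 + 64)/14 = 5, as in win_part. For destination column i1 = 30:
//   ip1 = 30 / 14 = 2   // which window column it came from
//   i01 = 30 % 14 = 2   // offset inside that window
// The forward op stored source column px*w + i1 = 2*14 + 2 = 30 there, so the
// value lands back where it started, and i03 = i3*npx*npy + ip2*npx + ip1
// undoes the i3 = b*nep0*nep1 + py*nep0 + px linearization.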
@@ -9050,10 +9148,16 @@ void ggml_compute_forward_win_unpart(
     const ggml_tensor * src0 = dst->src[0];
 
     switch (src0->type) {
+        case GGML_TYPE_I32:
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_win_unpart_f32(params, dst);
             } break;
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_win_unpart_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -9199,6 +9303,32 @@ void ggml_compute_forward_glu(
 
 // ggml_compute_forward_get_rel_pos
 
+static void ggml_compute_forward_get_rel_pos_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int64_t w = ne1;
+
+    float * src0_data = (float *) src0->data;
+    float * dst_data  = (float *)  dst->data;
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            const int64_t pos = (w - i1 - 1) + i2;
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_get_rel_pos_f16(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
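The lookup pos = (w - i1 - 1) + i2, with w = ne1 = kh, is the usual relative-position indexing: entry (i2, i1) selects the embedding for offset i2 - i1, shifted by kh - 1 so it stays non-negative. A small worked case with kh = qh = 3:

// src0 holds 2*3 - 1 = 5 embeddings, one per offset in -2..+2:
//   i2 = 0, i1 = 2  ->  pos = (3 - 2 - 1) + 0 = 0   // offset -2
//   i2 = 1, i1 = 1  ->  pos = (3 - 1 - 1) + 1 = 2   // offset  0
//   i2 = 2, i1 = 0  ->  pos = (3 - 0 - 1) + 2 = 4   // offset +2
// The largest pos reached is qh + kh - 2, which is exactly why ggml_get_rel_pos
// (see ggml.c below) now only asserts qh + kh - 1 <= a->ne[1].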
@@ -9232,6 +9362,10 @@ void ggml_compute_forward_get_rel_pos(
     const ggml_tensor * src0 = dst->src[0];
 
     switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rel_pos_f32(params, dst);
+            } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
             {

ggml/src/ggml.c

Lines changed: 18 additions & 10 deletions
@@ -5315,21 +5315,19 @@ struct ggml_tensor * ggml_win_part(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   w) {
-    GGML_ASSERT(a->ne[3] == 1);
-    GGML_ASSERT(a->type  == GGML_TYPE_F32);
-
     // padding
     const int px = (w - a->ne[1]%w)%w;
     const int py = (w - a->ne[2]%w)%w;
 
+    const int bs  = a->ne[3];
     const int npx = (px + a->ne[1])/w;
     const int npy = (py + a->ne[2])/w;
-    const int np  = npx*npy;
+    const int np  = npx*npy*bs;
 
     const int64_t ne[4] = { a->ne[0], w, w, np, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    int32_t params[] = { npx, npy, w };
+    int32_t params[] = { npx, npy, bs, w };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_WIN_PART;
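Dropping the two asserts is what opens the op to batches and to non-F32 inputs; the batch size now rides to the CPU kernel as op_params[2], with w shifted to op_params[3]. For example, a->ne = (768, 64, 64, 2) with w = 14 still gives npx = npy = 5, but np = 5*5*2 = 50 output windows.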
@@ -5346,10 +5344,20 @@ struct ggml_tensor * ggml_win_unpart(
         int                   w0,
         int                   h0,
         int                   w) {
-    GGML_ASSERT(a->type == GGML_TYPE_F32);
+    return ggml_win_unpart_ext(ctx, a, w0, h0, 1, w);
+}
+
+struct ggml_tensor * ggml_win_unpart_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w0,
+        int                   h0,
+        int                   b0,
+        int                   w) {
+    const int64_t ne[4] = { a->ne[0], w0, h0, b0 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+    GGML_ASSERT(ggml_is_contiguous(a));
 
     int32_t params[] = { w };
     ggml_set_op_params(result, params, sizeof(params));
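The original entry point survives as a thin wrapper that forwards with a batch of one, so existing callers keep their behavior. Note the result is also now allocated as a true 4-D tensor; the old code passed 3 dims to ggml_new_tensor even though it built a 4-element ne array.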
@@ -5367,8 +5375,7 @@ struct ggml_tensor * ggml_get_rel_pos(
         struct ggml_tensor  * a,
         int                   qh,
         int                   kh) {
-    GGML_ASSERT(qh == kh);
-    GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
+    GGML_ASSERT(qh + kh - 1 <= a->ne[1]);
 
     const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
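The relaxed assertion drops the qh == kh requirement and only demands that the table hold at least qh + kh - 1 rows, the largest index the kernel's pos = (w - i1 - 1) + i2 lookup can reach; presumably this is to allow query and key grids of different sizes, which the referenced SAM code supports.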
@@ -6421,6 +6428,7 @@ static void ggml_compute_backward(
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
+        case GGML_OP_GET_REL_POS:
         case GGML_OP_UNARY: {
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_ABS: {
