Skip to content

Commit f44c177

Browse files
committed
ggml : add ggml_fill()
1 parent 4032ca4 commit f44c177

File tree

5 files changed

+105
-1
lines changed

5 files changed

+105
-1
lines changed

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ extern "C" {
517517
GGML_OP_CROSS_ENTROPY_LOSS,
518518
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
519519
GGML_OP_OPT_STEP_ADAMW,
520+
GGML_OP_FILL,
520521

521522
GGML_OP_COUNT,
522523
};
@@ -1818,6 +1819,12 @@ extern "C" {
18181819
float stop,
18191820
float step);
18201821

1822+
// fill in-place the tensor with a constant value, return view(a)
1823+
GGML_API struct ggml_tensor * ggml_fill(
1824+
struct ggml_context * ctx,
1825+
struct ggml_tensor * a,
1826+
float value);
1827+
18211828
// top k elements per row
18221829
GGML_API struct ggml_tensor * ggml_top_k(
18231830
struct ggml_context * ctx,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1959,6 +1959,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
19591959
{
19601960
ggml_compute_forward_arange(params, tensor);
19611961
} break;
1962+
case GGML_OP_FILL:
1963+
{
1964+
ggml_compute_forward_fill(params, tensor);
1965+
} break;
19621966
case GGML_OP_TIMESTEP_EMBEDDING:
19631967
{
19641968
ggml_compute_forward_timestep_embedding(params, tensor);
@@ -2242,6 +2246,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22422246
case GGML_OP_TRANSPOSE:
22432247
case GGML_OP_GET_ROWS_BACK:
22442248
case GGML_OP_DIAG:
2249+
case GGML_OP_ARANGE:
22452250
{
22462251
n_tasks = 1;
22472252
} break;
@@ -2279,7 +2284,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22792284
case GGML_OP_UPSCALE:
22802285
case GGML_OP_PAD:
22812286
case GGML_OP_PAD_REFLECT_1D:
2282-
case GGML_OP_ARANGE:
2287+
case GGML_OP_FILL:
22832288
case GGML_OP_TIMESTEP_EMBEDDING:
22842289
case GGML_OP_ARGSORT:
22852290
case GGML_OP_FLASH_ATTN_EXT:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6833,6 +6833,56 @@ void ggml_compute_forward_arange(
68336833
}
68346834
}
68356835

6836+
// ggml_compute_forward_fill
6837+
6838+
static void ggml_compute_forward_fill_f32(
6839+
const ggml_compute_params * params,
6840+
ggml_tensor * dst) {
6841+
float v;
6842+
memcpy(&v, dst->op_params, sizeof(float));
6843+
6844+
const int ith = params->ith;
6845+
const int nth = params->nth;
6846+
6847+
const int n = ggml_nrows(dst);
6848+
const int nc = dst->ne[0];
6849+
6850+
const size_t nb00 = dst->nb[0];
6851+
const size_t nb01 = dst->nb[1];
6852+
6853+
const size_t nb0 = dst->nb[0];
6854+
const size_t nb1 = dst->nb[1];
6855+
6856+
GGML_ASSERT( nb0 == sizeof(float));
6857+
GGML_ASSERT(nb00 == sizeof(float));
6858+
6859+
for (int j = ith; j < n; j += nth) {
6860+
float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
6861+
6862+
for (int i = 0; i < nc; i++) {
6863+
dst_ptr[i] = v;
6864+
}
6865+
}
6866+
}
6867+
6868+
void ggml_compute_forward_fill(
6869+
const ggml_compute_params * params,
6870+
ggml_tensor * dst) {
6871+
6872+
const ggml_tensor * src0 = dst->src[0];
6873+
6874+
switch (src0->type) {
6875+
case GGML_TYPE_F32:
6876+
{
6877+
ggml_compute_forward_fill_f32(params, dst);
6878+
} break;
6879+
default:
6880+
{
6881+
GGML_ABORT("fatal error");
6882+
}
6883+
}
6884+
}
6885+
68366886
static void ggml_compute_forward_timestep_embedding_f32(
68376887
const ggml_compute_params * params,
68386888
ggml_tensor * dst) {

ggml/src/ggml.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4342,6 +4342,20 @@ struct ggml_tensor * ggml_arange(
43424342
return result;
43434343
}
43444344

4345+
struct ggml_tensor * ggml_fill(
4346+
struct ggml_context * ctx,
4347+
struct ggml_tensor * a,
4348+
float value) {
4349+
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
4350+
4351+
ggml_set_op_params(result, &value, sizeof(value));
4352+
4353+
result->op = GGML_OP_FILL;
4354+
result->src[0] = a;
4355+
4356+
return result;
4357+
}
4358+
43454359
// ggml_timestep_embedding
43464360

43474361
struct ggml_tensor * ggml_timestep_embedding(

tests/test-backend-ops.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2422,6 +2422,32 @@ struct test_clamp : public test_case {
24222422
}
24232423
};
24242424

2425+
// GGML_OP_FILL
2426+
struct test_fill : public test_case {
2427+
const ggml_type type;
2428+
const std::array<int64_t, 4> ne;
2429+
float v;
2430+
2431+
std::string vars() override {
2432+
return VARS_TO_STR3(type, ne, v);
2433+
}
2434+
2435+
test_fill(ggml_type type = GGML_TYPE_F32,
2436+
std::array<int64_t, 4> ne = {10, 5, 4, 3},
2437+
float v = 0.5f)
2438+
: type(type), ne(ne), v(v) {}
2439+
2440+
ggml_tensor * build_graph(ggml_context * ctx) override {
2441+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2442+
ggml_set_name(a, "a");
2443+
2444+
ggml_tensor * out = ggml_fill(ctx, a, v);
2445+
ggml_set_name(out, "out");
2446+
2447+
return out;
2448+
}
2449+
};
2450+
24252451
// GGML_OP_DIAG_MASK_INF
24262452
struct test_diag_mask_inf : public test_case {
24272453
const ggml_type type;
@@ -4199,6 +4225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
41994225
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
42004226
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
42014227

4228+
test_cases.emplace_back(new test_fill(GGML_TYPE_F32));
4229+
42024230
for (ggml_type type_a : all_types) {
42034231
for (int i = 1; i < 10; ++i) {
42044232
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));

0 commit comments

Comments (0)