Skip to content

Commit 1db66e4

Browse files
JohannesGaessler authored and arthw committed
ggml: refactor cross entropy loss CPU impl. (ggml/976)
1 parent e49b066 commit 1db66e4

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

ggml/include/ggml-backend.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ extern "C" {
247247
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
248248

249249
// Initialize backend buffers from a measure graph
250-
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
250+
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
251251

252252
GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
253253
GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
@@ -262,7 +262,7 @@ extern "C" {
262262
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
263263

264264
// Allocate and compute graph on the backend scheduler
265-
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
265+
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
266266
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
267267
GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
268268
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);

ggml/src/ggml.c

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4232,9 +4232,13 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
42324232
}
42334233

42344234
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
4235+
if (ggml_is_empty(tensor)) {
4236+
return tensor;
4237+
}
42354238
if (tensor->buffer) {
42364239
ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
42374240
} else {
4241+
GGML_ASSERT(tensor->data);
42384242
memset(tensor->data, 0, ggml_nbytes(tensor));
42394243
}
42404244
return tensor;
@@ -16851,41 +16855,40 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
1685116855
const struct ggml_tensor * src0 = dst->src[0];
1685216856
const struct ggml_tensor * src1 = dst->src[1];
1685316857

16854-
GGML_ASSERT(ggml_is_contiguous(src0));
16855-
GGML_ASSERT(ggml_is_contiguous(src1));
16856-
GGML_ASSERT(ggml_is_scalar(dst));
16858+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
16859+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
16860+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
16861+
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
1685716862
GGML_ASSERT(ggml_are_same_shape(src0, src1));
16863+
GGML_ASSERT(ggml_is_scalar(dst));
16864+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
16865+
16866+
// TODO: handle transposed/permuted matrices
16867+
const int64_t nc = src0->ne[0];
16868+
const int64_t nr = ggml_nrows(src0);
1685816869

1685916870
const int ith = params->ith;
1686016871
const int nth = params->nth;
1686116872

16862-
float * sums = (float *) params->wdata;
16863-
16864-
// TODO: handle transposed/permuted matrices
16865-
const int nc = src0->ne[0];
16866-
const int nr = ggml_nrows(src0);
16873+
float * sums = (float *) params->wdata;
16874+
float * st = ((float *) params->wdata) + nth + ith*nc;
16875+
float sum_thread = 0.0f;
1686716876

1686816877
GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
1686916878

16870-
if (ith == 0) {
16871-
memset(sums, 0, sizeof(float) * (nth + nth * nc));
16872-
}
16873-
ggml_barrier(params->threadpool);
16874-
1687516879
// rows per thread
16876-
const int dr = (nr + nth - 1)/nth;
16880+
const int64_t dr = (nr + nth - 1)/nth;
1687716881

1687816882
// row range for this thread
16879-
const int ir0 = dr*ith;
16880-
const int ir1 = MIN(ir0 + dr, nr);
16883+
const int64_t ir0 = dr*ith;
16884+
const int64_t ir1 = MIN(ir0 + dr, nr);
1688116885

16882-
for (int i1 = ir0; i1 < ir1; i1++) {
16883-
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
16884-
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
16885-
float * st = ((float *) params->wdata) + nth + ith*nc;
16886+
for (int64_t i1 = ir0; i1 < ir1; ++i1) {
16887+
const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
16888+
const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
1688616889

1688716890
#ifndef NDEBUG
16888-
for (int i = 0; i < nc; ++i) {
16891+
for (int64_t i = 0; i < nc; ++i) {
1688916892
//printf("p[%d] = %f\n", i, p[i]);
1689016893
assert(!isnan(s0[i]));
1689116894
assert(!isnan(s1[i]));
@@ -16894,23 +16897,24 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
1689416897

1689516898
float max = -INFINITY;
1689616899
ggml_vec_max_f32(nc, &max, s0);
16897-
ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
16898-
assert(sum >= 0.0);
16900+
const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
16901+
assert(sum_softmax >= 0.0);
1689916902

16900-
ggml_vec_add1_f32(nc, st, st, -sum);
16903+
ggml_vec_add1_f32(nc, st, st, -sum_softmax);
1690116904
ggml_vec_mul_f32(nc, st, st, s1);
1690216905

16903-
float st_sum = 0.0f;
16904-
ggml_vec_sum_f32(nc, &st_sum, st);
16905-
sums[ith] += st_sum;
16906+
float sum_st = 0.0f;
16907+
ggml_vec_sum_f32(nc, &sum_st, st);
16908+
sum_thread += sum_st;
1690616909

1690716910
#ifndef NDEBUG
16908-
for (int i = 0; i < nc; ++i) {
16911+
for (int64_t i = 0; i < nc; ++i) {
1690916912
assert(!isnan(st[i]));
1691016913
assert(!isinf(st[i]));
1691116914
}
1691216915
#endif
1691316916
}
16917+
sums[ith] = sum_thread;
1691416918
ggml_barrier(params->threadpool);
1691516919

1691616920
if (ith == 0) {
@@ -16976,7 +16980,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
1697616980
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
1697716981

1697816982
#ifndef NDEBUG
16979-
for (int i = 0; i < nc; ++i) {
16983+
for (int64_t i = 0; i < nc; ++i) {
1698016984
//printf("p[%d] = %f\n", i, p[i]);
1698116985
assert(!isnan(s0[i]));
1698216986
assert(!isnan(s1[i]));
@@ -16995,7 +16999,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
1699516999
ggml_vec_scale_f32(nc, ds0, d_by_nr);
1699617000

1699717001
#ifndef NDEBUG
16998-
for (int i = 0; i < nc; ++i) {
17002+
for (int64_t i = 0; i < nc; ++i) {
1699917003
assert(!isnan(ds0[i]));
1700017004
assert(!isinf(ds0[i]));
1700117005
}

0 commit comments

Comments (0)