@@ -4232,9 +4232,13 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    if (ggml_is_empty(tensor)) {
+        return tensor;
+    }
     if (tensor->buffer) {
         ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
     } else {
+        GGML_ASSERT(tensor->data);
         memset(tensor->data, 0, ggml_nbytes(tensor));
     }
     return tensor;
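The change above makes ggml_set_zero a no-op for empty tensors and asserts that data is non-NULL before falling back to a plain memset. A minimal usage sketch under that behavior; the zero-sized tensor and the ggml context ctx are illustrative, not taken from this commit:

    // empty tensor: one dimension is 0, so ggml_nbytes() == 0 and there may be no allocation behind it
    struct ggml_tensor * empty = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 0);
    ggml_set_zero(empty); // returns immediately instead of touching tensor->data or the backend buffer

    // regular tensor: zeroed through its backend buffer if it has one, otherwise via memset on tensor->data
    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    ggml_set_zero(t);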
@@ -16851,41 +16855,40 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
     GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    float * sums = (float *) params->wdata;
-
-    // TODO: handle transposed/permuted matrices
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    float * sums = (float *) params->wdata;
+    float * st = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
 
     GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
 
-    if (ith == 0) {
-        memset(sums, 0, sizeof(float) * (nth + nth * nc));
-    }
-    ggml_barrier(params->threadpool);
-
     // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int64_t dr = (nr + nth - 1)/nth;
 
     // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * st = ((float *) params->wdata) + nth + ith*nc;
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16894,23 +16897,24 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
-        assert(sum >= 0.0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
 
-        ggml_vec_add1_f32(nc, st, st, -sum);
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        float st_sum = 0.0f;
-        ggml_vec_sum_f32(nc, &st_sum, st);
-        sums[ith] += st_sum;
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(st[i]));
             assert(!isinf(st[i]));
         }
 #endif
     }
+    sums[ith] = sum_thread;
     ggml_barrier(params->threadpool);
 
     if (ith == 0) {
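With this rewrite each thread keeps its partial loss in a local sum_thread and writes sums[ith] exactly once after its row loop, so the earlier thread-0 memset of the scratch buffer and the extra ggml_barrier before the loop are no longer needed. A condensed sketch of the resulting reduction pattern; row_loss() and the final comment are illustrative stand-ins for the per-row work shown in the hunks above:

    float sum_thread = 0.0f;
    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
        sum_thread += row_loss(i1);       // per-row contribution (log-softmax of s0 weighted by s1)
    }
    sums[ith] = sum_thread;               // single store per thread, no prior memset required
    ggml_barrier(params->threadpool);     // all per-thread partials are visible past this point

    if (ith == 0) {
        float total = 0.0f;
        for (int t = 0; t < nth; ++t) {
            total += sums[t];             // reduce the nth partial sums
        }
        // thread 0 then turns 'total' into the final scalar loss written to dst
    }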
@@ -16976,7 +16980,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16995,7 +16999,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }