Skip to content

Commit 840f42d

Browse files
fix: use ggml_cuda_cast for conversion to bf16
1 parent: cc8c254 · commit: 840f42d

File tree

1 file changed: +8 -7 lines changed

1 file changed: +8 -7 lines changed

ggml/src/ggml-cuda/win.cu

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#include "common.cuh"
2-
#include "ggml.h"
2+
#include "convert.cuh"
33
#include "ggml-cuda/win.cuh"
4+
#include "ggml.h"
45

56
/*
67
@@ -28,7 +29,7 @@ static void ggml_compute_forward_win_part_f16(
2829
for (int64_t i3 = 0; i3 < ne3; i3++) {
2930
int px = i3 % nep0;
3031
int py = (i3 / nep0) % nep1;
31-
int b = i3 / (nep0 * nep1);
32+
int b = i3 / (nep0 * nep1);
3233
for (int64_t i2 = 0; i2 < ne2; ++i2) {
3334
for (int64_t i1 = 0; i1 < ne1; ++i1) {
3435
for (int64_t i0 = 0; i0 < ne0; ++i0) {
@@ -38,7 +39,7 @@ static void ggml_compute_forward_win_part_f16(
3839
const int64_t i00 = i0;
3940
4041
void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
41-
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
42+
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
4243
4344
if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
4445
*((ggml_fp16_t *) dp) = 0;
@@ -138,7 +139,7 @@ __global__ static void win_part_kernel(
138139
if (py*p.w + i2 >= p.ne2 || px*p.w + i1 >= p.ne1) {
139140
for (int i0 = threadIdx.x; i0 < p.C; i0 += blockDim.x) {
140141
char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
141-
*((T *) dp) = 0;
142+
*((T *) dp) = ggml_cuda_cast<T>(0.0f);
142143
}
143144
return;
144145
}
@@ -210,7 +211,7 @@ static unsigned int round_to_pow2(unsigned int v) {
210211
v++;
211212

212213
return v;
213-
}
214+
}
214215

215216
void ggml_cuda_op_win_part(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
216217
const ggml_tensor * src0 = dst->src[0];
@@ -297,12 +298,12 @@ static void ggml_compute_forward_win_unpart_f16(
297298
for (int64_t i0 = 0; i0 < ne0; ++i0) {
298299
const int ip2 = i2/w;
299300
const int ip1 = i1/w;
300-
301+
301302
const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
302303
const int64_t i02 = i2%w;
303304
const int64_t i01 = i1%w;
304305
const int64_t i00 = i0;
305-
306+
306307
void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
307308
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
308309

0 commit comments

Comments (0)