11#include " common.cuh"
2- #include " ggml.h "
2+ #include " convert.cuh "
33#include " ggml-cuda/win.cuh"
4+ #include " ggml.h"
45
56/*
67
@@ -28,7 +29,7 @@ static void ggml_compute_forward_win_part_f16(
2829 for (int64_t i3 = 0; i3 < ne3; i3++) {
2930 int px = i3 % nep0;
3031 int py = (i3 / nep0) % nep1;
31- int b = i3 / (nep0 * nep1);
32+ int b = i3 / (nep0 * nep1);
3233 for (int64_t i2 = 0; i2 < ne2; ++i2) {
3334 for (int64_t i1 = 0; i1 < ne1; ++i1) {
3435 for (int64_t i0 = 0; i0 < ne0; ++i0) {
@@ -38,7 +39,7 @@ static void ggml_compute_forward_win_part_f16(
3839 const int64_t i00 = i0;
3940
4041 void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
41- void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
42+ void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
4243
4344 if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
4445 *((ggml_fp16_t *) dp) = 0;
@@ -138,7 +139,7 @@ __global__ static void win_part_kernel(
138139 if (py*p.w + i2 >= p.ne2 || px*p.w + i1 >= p.ne1 ) {
139140 for (int i0 = threadIdx .x ; i0 < p.C ; i0 += blockDim .x ) {
140141 char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
141- *((T *) dp) = 0 ;
142+ *((T *) dp) = ggml_cuda_cast<T>( 0 . 0f ) ;
142143 }
143144 return ;
144145 }
@@ -210,7 +211,7 @@ static unsigned int round_to_pow2(unsigned int v) {
210211 v++;
211212
212213 return v;
213- }
214+ }
214215
215216void ggml_cuda_op_win_part (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
216217 const ggml_tensor * src0 = dst->src [0 ];
@@ -297,12 +298,12 @@ static void ggml_compute_forward_win_unpart_f16(
297298 for (int64_t i0 = 0; i0 < ne0; ++i0) {
298299 const int ip2 = i2/w;
299300 const int ip1 = i1/w;
300-
301+
301302 const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
302303 const int64_t i02 = i2%w;
303304 const int64_t i01 = i1%w;
304305 const int64_t i00 = i0;
305-
306+
306307 void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
307308 void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
308309
0 commit comments