@@ -166,8 +166,8 @@ __device__ void __inline__ outer_product(float4* input_frag, float4* filter_frag
166166 accumulator[1 ][15 ].w += input_frag[3 ].w *filter_frag[3 ].w ;
167167 }
168168
169- extern " C"
170- {
169+ // extern "C"
170+ // {
171171
172172__device__ __forceinline__ void transform_output_tile (float *pOutputs, float2 *C_tile, float2 *At,
173173 int round, int c_tensor, int c_glb_offset, int i1, int i2,
@@ -248,7 +248,7 @@ float4 *input_frag_mem, float4* filter_frag_mem){
248248
249249 float2 *output_smem = (float2 *) shared_mem;
250250 float2 *accumulator = (float2 *) acumm_smem;
251- float2 *C_out = (float2 *)C;
251+ // float2 *C_out = (float2*)C;
252252
253253 float2 *C_tile = (float2 *) input_frag_mem;
254254 float2 *At = (float2 *) filter_frag_mem;
@@ -295,12 +295,11 @@ float4 *input_frag_mem, float4* filter_frag_mem){
295295 // blockIdx.x*BN + (threadIdx.x%16)*2+
296296 // ((threadIdx.x/16)*16 + (threadIdx.y%4)*4 + threadIdx.y/4)*c_glb_offset;
297297
298- int tx = TW, ty = TH;
299298 // int c_tile = blockIdx.x * tx + blockIdx.y * in_w * ty;
300299 // int c_tensor = c_tile + (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * in_w * 2 +
301300 // threadIdx.y*(in_h*in_w) - (in_w+1);
302301
303- int c_tensor = blockIdx .z *c_glb_offset*BK + blockIdx .x * tx + blockIdx .y * out_w * ty +
302+ int c_tensor = blockIdx .z *c_glb_offset*BK + blockIdx .x * TW + blockIdx .y * out_w * TH +
304303 // (threadIdx.x % tw) * 2 + (threadIdx.x / tw) * out_w * 2 +
305304 ((threadIdx .x /16 )*16 + (threadIdx .y %4 )*4 + threadIdx .y /4 )*c_glb_offset;
306305
@@ -382,77 +381,31 @@ __device__ float f_row1(float *G, int j){
382381 return G[j+2 ];
383382 }
384383
385- typedef float (*pointFunction_t)( float *, int );
386-
387- __global__ void FX ( const float *pInputs, float *pOutputs, int filt_k,
388- int filt_c, int filt_h, int filt_w){
384+ template < typename T>
385+ static __device__ __forceinline__ float t2f32 (T val) {
386+ return ( float ) val;
387+ }
389388
390- // assumes CHWK layout
391- int Inx = threadIdx .x , Iny = threadIdx .y ;
392- int TileX = blockIdx .x , TileY = blockIdx .y ;
393-
394- int c_glb_offset = filt_k*filt_h*filt_w;
395- int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
396- int c_glb_offset_s = filt_k*4 *4 ;
397- int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx;
398-
399- float Gw[21 ]; // 9+12. In registers
400- float *Gw_buffer = Gw+9 ;
401-
402- pointFunction_t func1[4 ] = {f_row1, f_row2, f_row3, f_row4};
403- pointFunction_t func2[4 ] = {f_col1, f_col2, f_col3, f_col4};
404-
405- for (int bk=0 ; bk<BK; bk+=blockDim .x ){
406- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
407- // printf("[");
408- // }
409- for (int i=0 ; i<9 ; i++){
410- Gw[i] = pInputs[c_kernel + i*filt_k];
411- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
412- // printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
413- // }
414- }
415- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
416- // printf("]\n");
417- // }
418-
419- int aux;
420- for (int i=0 ; i<4 ; i++){
421- aux = i*3 ;
422- for (int j=0 ; j<3 ; j++){
423- Gw_buffer[j+aux] = (*func1[i])(Gw, j);
424- }
425- }
426- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
427- // printf("X[");
428- // for(int kk = 0; kk < 21; kk++){
429- // printf("%f, ", Gw[kk]);
430- // }
431- // printf("]\n");
432- // }
433-
434- int aux2;
435- for (int i=0 ; i<4 ; i++){
436- aux = i*3 ; aux2 = i<<2 ;
437- for (int j=0 ; j<4 ; j++){
438- pOutputs[c_kernel_s+aux2*filt_k+j*filt_k] = (*func2[j])(Gw_buffer, aux);
439- }
440- }
441-
442- c_kernel += blockDim .x ;
443- c_kernel_s += blockDim .x ;
444- }
389+ template <>
390+ __device__ float __forceinline__ t2f32<half>(half val) {
391+ return __half2float (val);
445392 }
446393
447- __global__ void FX_FP16 (const half *pInputs, float *pOutputs, int filt_k,
394+ typedef float (*pointFunction_t)(float *, int );
395+
396+ template <typename T>
397+ __global__ void FX (const T *pInputs, float *pOutputs, int filt_k,
448398 int filt_c, int filt_h, int filt_w){
449399
450- // assumes CHWK layout
400+ // assumes KCHW layout
451401 int Inx = threadIdx .x , Iny = threadIdx .y ;
452402 int TileX = blockIdx .x , TileY = blockIdx .y ;
453403
454- int c_glb_offset = filt_k*filt_h*filt_w;
455- int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
404+ // int c_glb_offset = filt_k*filt_h*filt_w;
405+ // int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
406+ int c_glb_offset = filt_h*filt_w;
407+ // int c_kernel = TileY*BC*c_glb_offset + TileX*BK*filt_c*c_glb_offset + Iny*c_glb_offset+ Inx*filt_c*c_glb_offset;
408+ int c_kernel = (TileY*BC + (TileX*BK+Inx)*filt_c + Iny)*c_glb_offset;
456409 int c_glb_offset_s = filt_k*4 *4 ;
457410 int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx;
458411
@@ -462,19 +415,11 @@ __device__ float f_row1(float *G, int j){
462415 pointFunction_t func1[4 ] = {f_row1, f_row2, f_row3, f_row4};
463416 pointFunction_t func2[4 ] = {f_col1, f_col2, f_col3, f_col4};
464417
465- for (int bk=0 ; bk<BK; bk+=blockDim .x ){
466- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
467- // printf("[");
468- // }
418+ for (int bk=0 ; bk<BK; bk+=blockDim .x ){
469419 for (int i=0 ; i<9 ; i++){
470- Gw[i] = __half2float (pInputs[c_kernel + i*filt_k]);
471- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
472- // printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
473- // }
474- }
475- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
476- // printf("]\n");
477- // }
420+ Gw[i] = t2f32 (pInputs[c_kernel + i]);
421+
422+ }
478423
479424 int aux;
480425 for (int i=0 ; i<4 ; i++){
@@ -483,14 +428,7 @@ __device__ float f_row1(float *G, int j){
483428 Gw_buffer[j+aux] = (*func1[i])(Gw, j);
484429 }
485430 }
486- // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
487- // printf("X[");
488- // for(int kk = 0; kk < 21; kk++){
489- // printf("%f, ", Gw[kk]);
490- // }
491- // printf("]\n");
492- // }
493-
431+
494432 int aux2;
495433 for (int i=0 ; i<4 ; i++){
496434 aux = i*3 ; aux2 = i<<2 ;
@@ -499,7 +437,7 @@ __device__ float f_row1(float *G, int j){
499437 }
500438 }
501439
502- c_kernel += blockDim .x ;
440+ c_kernel += blockDim .x *(filt_c*c_glb_offset) ;
503441 c_kernel_s += blockDim .x ;
504442 }
505443 }
@@ -793,34 +731,16 @@ cudaError_t convolutionForward_32Tx64x8(float *k, int in_h, int in_w, float *w,
793731 return cudaGetLastError ();
794732}
795733
796- }
797-
734+ // }
798735
799- static void conv_winograd_stage0_f32_f32_cuda (
736+ template <typename T>
737+ static void conv_winograd_stage0_f32_cuda (
800738 const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
801739 const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
802- const float * src0, float * dst,
740+ const T * src0, float * dst,
803741 cudaStream_t stream) {
804742
805-
806- int64_t filt_k = src0_ne0;
807- int64_t filt_c = src0_ne3;
808-
809- FX<<<dim3 (filt_k/BK, filt_c/BC), dim3 (32 , BC), 0 , stream>>> (src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
810-
811- }
812-
813- static void conv_winograd_stage0_f16_f32_cuda (
814- const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
815- const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
816- const half * src0, float * dst,
817- cudaStream_t stream) {
818-
819-
820- int64_t filt_k = src0_ne0;
821- int64_t filt_c = src0_ne3;
822-
823- FX_FP16<<<dim3 (filt_k/BK, filt_c/BC), dim3 (32 , BC), 0 , stream>>> (src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
743+ FX<<<dim3 (src0_ne3/BK, src0_ne2/BC), dim3 (32 , BC), 0 , stream>>> (src0, dst, src0_ne3, src0_ne2, src0_ne1, src0_ne0);
824744
825745}
826746
@@ -842,12 +762,9 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h,
842762 int64_t out_w = in_w;
843763 int smem_size = (16 *BN*BC + 16 *BC*BK)*4 ;
844764
845- // printf("A %d, %d\n", filt_k, filt_c);
846- // printf("B %d, %d, %d \n", in_c, in_h, in_w);
847- // printf("C %d, %d, %d \n", out_c, out_h, out_w);
848-
849765 Winograd_kernel<<<dim3 ((tiles_dim_w+X-1 )/X, (tiles_dim_h+Y-1 )/Y, filt_k/BK), dim3 (BN, 8 ), smem_size, stream>>> (src1, src0, dst,
850- tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
766+ tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y,
767+ filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
851768}
852769
853770
@@ -876,12 +793,14 @@ void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor *
876793 // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get();
877794 if (src0->type == GGML_TYPE_F32){
878795 const float * src0_d = (const float *)src0->data ;
879- conv_winograd_stage0_f32_f32_cuda (src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
796+ // conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
797+ conv_winograd_stage0_f32_cuda (src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
880798 dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
881799 src0_d, dst_d, stream);
882800 }else {
883801 const half * src0_d = (const half *)src0->data ;
884- conv_winograd_stage0_f16_f32_cuda (src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
802+ // conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
803+ conv_winograd_stage0_f32_cuda (src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
885804 dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
886805 src0_d, dst_d, stream);
887806 }
0 commit comments