@@ -356,32 +356,32 @@ float4 *input_frag_mem, float4* filter_frag_mem){
356356
357357
358358// Set of functions per row in Gw product
// Row 0 of the filter transform G*g (transform row [1, 0, 0]):
// passes element j of the 3x3 tile's first row through unchanged.
// G points at a 3x3 filter tile stored row-major (row stride 3); j is the column, 0..2.
__device__ float f_row1(float *G, int col) {
    return G[col];
}
// Row 1 of the filter transform G*g (transform row [0.5, 0.5, 0.5]):
// sums element j of all three rows of the 3x3 tile (row-major, row stride 3)
// and halves the result.
__device__ float f_row2(float *G, int j) {
    // 0.5f keeps the multiply in single precision; the previous bare 0.5
    // promoted it to a double operation on the device.
    return 0.5f * (G[j] + G[6 + j] + G[3 + j]);
}
// Row 2 of the filter transform G*g (transform row [0.5, -0.5, 0.5]):
// like f_row2 but the middle row (offset 3) enters with a negative sign.
__device__ float f_row3(float *G, int j) {
    // 0.5f keeps the multiply in single precision; the previous bare 0.5
    // promoted it to a double operation on the device.
    return 0.5f * (G[j] + G[6 + j] - G[3 + j]);
}
// Row 3 of the filter transform G*g (transform row [0, 0, 1]):
// passes element j of the 3x3 tile's last row (offset 6) through unchanged.
__device__ float f_row4(float *G, int col) {
    return G[6 + col];
}
371371 // Set of functions per column in GwGt product
// Column 0 of the (G*g)*G^T product (transform column [1, 0, 0]^T):
// passes the first element of the addressed row through unchanged.
// G points at the 4x3 intermediate (G*g); j is a row start offset (multiple of 3).
__device__ float f_col1(float *G, int rowOff) {
    return G[rowOff];
}
// Column 1 of the (G*g)*G^T product (transform column [0.5, 0.5, 0.5]^T):
// sums the three elements of the addressed row of the 4x3 intermediate and
// halves the result. j is a row start offset (multiple of 3).
__device__ float f_col2(float *G, int j) {
    // 0.5f keeps the multiply in single precision; the previous bare 0.5
    // promoted it to a double operation on the device.
    return 0.5f * (G[j] + G[j + 2] + G[j + 1]);
}
// Column 2 of the (G*g)*G^T product (transform column [0.5, -0.5, 0.5]^T):
// like f_col2 but the middle element (j+1) enters with a negative sign.
__device__ float f_col3(float *G, int j) {
    // 0.5f keeps the multiply in single precision; the previous bare 0.5
    // promoted it to a double operation on the device.
    return 0.5f * (G[j] + G[j + 2] - G[j + 1]);
}
// Column 3 of the (G*g)*G^T product (transform column [0, 0, 1]^T):
// passes the last element of the addressed row through unchanged.
__device__ float f_col4(float *G, int rowOff) {
    return G[rowOff + 2];
}
384-
384+
385385 typedef float (*pointFunction_t)(float *, int );
386386
387387 __global__ void FX (const float *pInputs, float *pOutputs, int filt_k,
@@ -403,9 +403,78 @@ __device__ float f_row1(float *Gw, int j){
403403 pointFunction_t func2[4 ] = {f_col1, f_col2, f_col3, f_col4};
404404
405405 for (int bk=0 ; bk<BK; bk+=blockDim .x ){
406+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
407+ // printf("[");
408+ // }
406409 for (int i=0 ; i<9 ; i++){
407410 Gw[i] = pInputs[c_kernel + i*filt_k];
411+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
412+ // printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
413+ // }
414+ }
415+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
416+ // printf("]\n");
417+ // }
418+
419+ int aux;
420+ for (int i=0 ; i<4 ; i++){
421+ aux = i*3 ;
422+ for (int j=0 ; j<3 ; j++){
423+ Gw_buffer[j+aux] = (*func1[i])(Gw, j);
424+ }
425+ }
426+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
427+ // printf("X[");
428+ // for(int kk = 0; kk < 21; kk++){
429+ // printf("%f, ", Gw[kk]);
430+ // }
431+ // printf("]\n");
432+ // }
433+
434+ int aux2;
435+ for (int i=0 ; i<4 ; i++){
436+ aux = i*3 ; aux2 = i<<2 ;
437+ for (int j=0 ; j<4 ; j++){
438+ pOutputs[c_kernel_s+aux2*filt_k+j*filt_k] = (*func2[j])(Gw_buffer, aux);
439+ }
440+ }
441+
442+ c_kernel += blockDim .x ;
443+ c_kernel_s += blockDim .x ;
444+ }
445+ }
446+
447+ __global__ void FX_FP16 (const half *pInputs, float *pOutputs, int filt_k,
448+ int filt_c, int filt_h, int filt_w){
449+
450+ // assumes CHWK layout
451+ int Inx = threadIdx .x , Iny = threadIdx .y ;
452+ int TileX = blockIdx .x , TileY = blockIdx .y ;
453+
454+ int c_glb_offset = filt_k*filt_h*filt_w;
455+ int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
456+ int c_glb_offset_s = filt_k*4 *4 ;
457+ int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx;
458+
459+ float Gw[21 ]; // 9+12. In registers
460+ float *Gw_buffer = Gw+9 ;
461+
462+ pointFunction_t func1[4 ] = {f_row1, f_row2, f_row3, f_row4};
463+ pointFunction_t func2[4 ] = {f_col1, f_col2, f_col3, f_col4};
464+
465+ for (int bk=0 ; bk<BK; bk+=blockDim .x ){
466+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
467+ // printf("[");
468+ // }
469+ for (int i=0 ; i<9 ; i++){
470+ Gw[i] = __half2float (pInputs[c_kernel + i*filt_k]);
471+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
472+ // printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
473+ // }
408474 }
475+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
476+ // printf("]\n");
477+ // }
409478
410479 int aux;
411480 for (int i=0 ; i<4 ; i++){
@@ -414,6 +483,13 @@ __device__ float f_row1(float *Gw, int j){
414483 Gw_buffer[j+aux] = (*func1[i])(Gw, j);
415484 }
416485 }
486+ // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
487+ // printf("X[");
488+ // for(int kk = 0; kk < 21; kk++){
489+ // printf("%f, ", Gw[kk]);
490+ // }
491+ // printf("]\n");
492+ // }
417493
418494 int aux2;
419495 for (int i=0 ; i<4 ; i++){
@@ -730,7 +806,21 @@ static void conv_winograd_stage0_f32_f32_cuda(
730806 int64_t filt_k = src0_ne0;
731807 int64_t filt_c = src0_ne3;
732808
733- FX<<<dim3 (filt_k/BK, filt_c/BC), dim3 (32 , BC)>>> (src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
809+ FX<<<dim3 (filt_k/BK, filt_c/BC), dim3 (32 , BC), 0 , stream>>> (src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
810+
811+ }
812+
// Host launcher for Winograd stage 0 on half-precision filters: runs the
// FX_FP16 filter transform (CHWK layout, per the kernel's own comment),
// producing the transformed filter as float into dst.
// NOTE(review): grid dims assume filt_k % BK == 0 and filt_c % BC == 0 — confirm callers.
static void conv_winograd_stage0_f16_f32_cuda(
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const half * src0, float * dst,
        cudaStream_t stream) {

    const int64_t filt_k = src0_ne0;  // K: number of filters
    const int64_t filt_c = src0_ne3;  // C: number of input channels

    const dim3 grid(filt_k / BK, filt_c / BC);
    const dim3 block(32, BC);
    // src0_ne2 / src0_ne1 are passed as filter height / width, mirroring the f32 variant.
    FX_FP16<<<grid, block, 0, stream>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
}
736826
@@ -756,38 +846,45 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h,
756846 // printf("B %d, %d, %d \n", in_c, in_h, in_w);
757847 // printf("C %d, %d, %d \n", out_c, out_h, out_w);
758848
759- Winograd_kernel<<<dim3 ((tiles_dim_w+X-1 )/X, (tiles_dim_h+Y-1 )/Y, filt_k/BK), dim3 (BN, 8 ), smem_size>>> (src1, src0, dst,
849+ Winograd_kernel<<<dim3 ((tiles_dim_w+X-1 )/X, (tiles_dim_h+Y-1 )/Y, filt_k/BK), dim3 (BN, 8 ), smem_size, stream >>> (src1, src0, dst,
760850 tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
761851}
762852
763853
// Launch Winograd stage 0 (the filter transform) for dst's source tensor.
// Dispatches on src0's element type: F32 filters use the f32 kernel, F16
// filters use the f16 kernel. Any other type is rejected up front — the f16
// path reinterprets the raw data pointer as half, which would silently
// corrupt quantized filter data.
void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];  // the convolution filter tensor

    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // Guard the type-punning dispatch below: only F32 and F16 are supported.
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);

    if (src0->type == GGML_TYPE_F32) {
        const float * src0_d = (const float *)src0->data;
        conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
            src0_d, dst_d, stream);
    } else {
        const half * src0_d = (const half *)src0->data;
        conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
            src0_d, dst_d, stream);
    }
}
792889
793890
@@ -822,13 +919,14 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor *
822919 cudaMemcpyToSymbol (access_f_s, aux, 64 *sizeof (int ));
823920 cudaMemcpyToSymbol (access_s, aux2, 64 *sizeof (int ));
824921 cudaMemcpyToSymbol (tileid, tid, 64 *sizeof (int ));
825- // printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size);
922+
826923 conv_winograd_stage1_f32_f32_cuda (tiles_dim_w, tiles_dim_h, 4 , 8 ,
827924 tile_size, tile_2d_s,
828925 src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
829926 src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
830927 dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
831928 src0_d, src1_d, dst_d, stream);
929+
832930}
833931
834932
0 commit comments