Commit 6afbf6e
Author: bssrdf
Parent: 893ca79

added an FP16 FX kernel to deal with FP16 filter data; no need to use an FP32 buffer

1 file changed: +135 -37

src/ggml-cuda/conv-winograd.cu
@@ -356,32 +356,32 @@ float4 *input_frag_mem, float4* filter_frag_mem){
 
 
 // Set of functions per row in Gw product
-__device__ float f_row1(float *Gw, int j){
-  return Gw[j];
+__device__ float f_row1(float *G, int j){
+  return G[j];
 }
-__device__ float f_row2(float *Gw, int j){
-  return 0.5*(Gw[j] + Gw[6+j] + Gw[3+j]);
+__device__ float f_row2(float *G, int j){
+  return 0.5*(G[j] + G[6+j] + G[3+j]);
 }
-__device__ float f_row3(float *Gw, int j){
-  return 0.5*(Gw[j] + Gw[6+j] - Gw[3+j]);
+__device__ float f_row3(float *G, int j){
+  return 0.5*(G[j] + G[6+j] - G[3+j]);
 }
-__device__ float f_row4(float *Gw, int j){
-  return Gw[6+j];
+__device__ float f_row4(float *G, int j){
+  return G[6+j];
 }
 // Set of functions per column in GwGt product
-__device__ float f_col1(float *Gw, int j){
-  return Gw[j];
+__device__ float f_col1(float *G, int j){
+  return G[j];
 }
-__device__ float f_col2(float *Gw, int j){
-  return 0.5*(Gw[j] + Gw[j+2] + Gw[j+1]);
+__device__ float f_col2(float *G, int j){
+  return 0.5*(G[j] + G[j+2] + G[j+1]);
 }
-__device__ float f_col3(float *Gw, int j){
-  return 0.5*(Gw[j] + Gw[j+2] - Gw[j+1]);
+__device__ float f_col3(float *G, int j){
+  return 0.5*(G[j] + G[j+2] - G[j+1]);
 }
-__device__ float f_col4(float *Gw, int j){
-  return Gw[j+2];
+__device__ float f_col4(float *G, int j){
+  return G[j+2];
 }
-
+
 typedef float(*pointFunction_t)(float *, int);
 
 __global__ void FX(const float *pInputs, float *pOutputs, int filt_k,
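
Note: the change above is a pure rename of the helper parameter (Gw -> G), presumably to avoid confusion with the Gw register array in the kernels that call them; the arithmetic is untouched. For reference, f_row1..f_row4 evaluate the rows of G*g for a 3x3 filter tile g held row-major, and f_col1..f_col4 evaluate the columns of (G*g)*G^T, i.e. the F(2x2, 3x3) Winograd filter transform

$$U = G\,g\,G^{\mathsf T}, \qquad
G = \begin{pmatrix}
1 & 0 & 0 \\
\tfrac{1}{2} & \tfrac{1}{2} & \tfrac{1}{2} \\
\tfrac{1}{2} & -\tfrac{1}{2} & \tfrac{1}{2} \\
0 & 0 & 1
\end{pmatrix}.$$

The 0.5*(...) sums and differences in f_row2/f_row3 and f_col2/f_col3 are exactly the +1/2 and -1/2 rows and columns of this G.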
@@ -403,9 +403,78 @@ __device__ float f_row1(float *Gw, int j){
   pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4};
 
   for(int bk=0; bk<BK; bk+=blockDim.x){
+    // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+    //   printf("[");
+    // }
     for(int i=0; i<9; i++){
       Gw[i] = pInputs[c_kernel + i*filt_k];
+      // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+      //   printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
+      // }
+    }
+    // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+    //   printf("]\n");
+    // }
+
+    int aux;
+    for(int i=0; i<4; i++){
+      aux = i*3;
+      for(int j=0; j<3; j++){
+        Gw_buffer[j+aux] = (*func1[i])(Gw, j);
+      }
+    }
+    // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+    //   printf("X[");
+    //   for(int kk = 0; kk < 21; kk++){
+    //     printf("%f, ", Gw[kk]);
+    //   }
+    //   printf("]\n");
+    // }
+
+    int aux2;
+    for(int i=0; i<4; i++){
+      aux = i*3; aux2 = i<<2;
+      for(int j=0; j<4; j++){
+        pOutputs[c_kernel_s+aux2*filt_k+j*filt_k] = (*func2[j])(Gw_buffer, aux);
+      }
+    }
+
+    c_kernel += blockDim.x;
+    c_kernel_s += blockDim.x;
+  }
+}
+
+__global__ void FX_FP16(const half *pInputs, float *pOutputs, int filt_k,
+                        int filt_c, int filt_h, int filt_w){
+
+  // assumes CHWK layout
+  int Inx = threadIdx.x, Iny = threadIdx.y;
+  int TileX = blockIdx.x, TileY = blockIdx.y;
+
+  int c_glb_offset = filt_k*filt_h*filt_w;
+  int c_kernel = TileY*BC*c_glb_offset + TileX*BK + Iny*c_glb_offset + Inx;
+  int c_glb_offset_s = filt_k*4*4;
+  int c_kernel_s = TileY*BC*c_glb_offset_s + TileX*BK + Iny*c_glb_offset_s + Inx;
+
+  float Gw[21]; //9+12. In registers
+  float *Gw_buffer = Gw+9;
+
+  pointFunction_t func1[4] = {f_row1, f_row2, f_row3, f_row4};
+  pointFunction_t func2[4] = {f_col1, f_col2, f_col3, f_col4};
+
+  for(int bk=0; bk<BK; bk+=blockDim.x){
+    // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+    //   printf("[");
+    // }
+    for(int i=0; i<9; i++){
+      Gw[i] = __half2float(pInputs[c_kernel + i*filt_k]);
+      // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+      //   printf("(%f,%d) ", Gw[i], c_kernel + i*filt_k);
+      // }
     }
+    // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+    //   printf("]\n");
+    // }
 
     int aux;
     for(int i=0; i<4; i++){
@@ -414,6 +483,13 @@ __device__ float f_row1(float *Gw, int j){
         Gw_buffer[j+aux] = (*func1[i])(Gw, j);
       }
     }
+    // if(blockIdx.x == 0 && blockIdx.y == 0 && Inx == 0 && Iny == 0){
+    //   printf("X[");
+    //   for(int kk = 0; kk < 21; kk++){
+    //     printf("%f, ", Gw[kk]);
+    //   }
+    //   printf("]\n");
+    // }
 
     int aux2;
     for(int i=0; i<4; i++){
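
FX_FP16 above is a line-for-line copy of FX except for the load: each filter tap stored as half is widened to FP32 in registers by __half2float at read time, which is what lets the FP16 path skip the FP32 staging buffer mentioned in the commit message. A minimal, self-contained sketch of that convert-on-load pattern (hypothetical kernel name, not part of this diff):

    #include <cuda_fp16.h>

    // Each thread widens one FP16 value to FP32 as it loads it; any arithmetic
    // that follows runs in FP32 registers, so no intermediate FP32 copy of the
    // source array ever exists in global memory.
    __global__ void widen_on_load(const half *src, float *dst, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = __half2float(src[i]);
        }
    }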
@@ -730,7 +806,21 @@ static void conv_winograd_stage0_f32_f32_cuda(
   int64_t filt_k = src0_ne0;
   int64_t filt_c = src0_ne3;
 
-  FX<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC)>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
+  FX<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC), 0, stream>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
+
+}
+
+static void conv_winograd_stage0_f16_f32_cuda(
+    const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+    const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+    const half * src0, float * dst,
+    cudaStream_t stream) {
+
+
+  int64_t filt_k = src0_ne0;
+  int64_t filt_c = src0_ne3;
+
+  FX_FP16<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC), 0, stream>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
 
 }
 
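
Both stage-0 wrappers now pass the backend stream as the fourth launch-configuration argument, so the filter-transform kernels are ordered with the rest of the ggml CUDA work instead of running on the legacy default stream. A standalone illustration of that launch form (hypothetical no-op kernel, plain CUDA runtime API only):

    #include <cuda_runtime.h>

    __global__ void noop() {}

    int main() {
        cudaStream_t stream;
        cudaStreamCreate(&stream);
        // <<<grid, block, dynamicSmemBytes, stream>>>: 0 bytes of dynamic shared
        // memory; the kernel is enqueued on `stream`, not the default stream.
        noop<<<dim3(1, 1), dim3(32, 1), 0, stream>>>();
        cudaStreamSynchronize(stream);
        cudaStreamDestroy(stream);
        return 0;
    }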

@@ -756,38 +846,45 @@ static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h,
   // printf("B %d, %d, %d \n", in_c, in_h, in_w);
   // printf("C %d, %d, %d \n", out_c, out_h, out_w);
 
-  Winograd_kernel<<<dim3((tiles_dim_w+X-1)/X, (tiles_dim_h+Y-1)/Y, filt_k/BK), dim3(BN, 8), smem_size>>>(src1, src0, dst,
+  Winograd_kernel<<<dim3((tiles_dim_w+X-1)/X, (tiles_dim_h+Y-1)/Y, filt_k/BK), dim3(BN, 8), smem_size, stream>>>(src1, src0, dst,
     tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
 }
 
 
 void ggml_cuda_op_winograd_stage0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    // const half * src0_d = (const float *)src0->data;
+    // const half * src0_d = (const half *)src0->data;
 
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
-    int id = ggml_cuda_get_device();
+    // int id = ggml_cuda_get_device();
 
     // GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
-    if (src0->type != GGML_TYPE_F32) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-        GGML_ASSERT(to_fp32_cuda != nullptr);
-        int64_t nle = ggml_nelements(src0);
-        src0_ddq_as_f32.alloc(nle);
-        const half * src0_dd = (const half *)src0->data;
-        to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream);
+    // ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
+    // if (src0->type != GGML_TYPE_F32) {
+    //     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+    //     GGML_ASSERT(to_fp32_cuda != nullptr);
+    //     int64_t nle = ggml_nelements(src0);
+    //     src0_ddq_as_f32.alloc(nle);
+    //     const char * src0_dd = (char *)src0->data;
+    //     to_fp32_cuda(src0_dd, src0_ddq_as_f32.get(), nle, stream);
+    // }
+
+    // // GGML_ASSERT(ggml_is_contiguous(src0));
+    // const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get();
+    if(src0->type == GGML_TYPE_F32){
+       const float* src0_d = (const float *)src0->data;
+       conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+          dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+          src0_d, dst_d, stream);
+    }else{
+       const half * src0_d = (const half *)src0->data;
+       conv_winograd_stage0_f16_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+          dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+          src0_d, dst_d, stream);
     }
-
-    // GGML_ASSERT(ggml_is_contiguous(src0));
-    const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *)src0->data : src0_ddq_as_f32.get();
-
-    conv_winograd_stage0_f32_f32_cuda(src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
-        src0_ddf_i, dst_d, stream);
 }
 
 
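
This hunk drops the pool-allocated FP32 buffer and the ggml_get_to_fp32_cuda conversion pass (kept only as commented-out history) and instead branches once on src0->type, handing the raw data pointer to the matching wrapper. A self-contained sketch of that dispatch shape, with hypothetical names standing in for the ggml types:

    #include <cuda_fp16.h>
    #include <cstdio>

    enum class elem_type { F32, F16 };

    // Branch once on the stored element type and reinterpret the raw pointer
    // for the matching path; no FP32 staging copy of FP16 data is made.
    static void stage0_dispatch_sketch(elem_type t, const void *data) {
        if (t == elem_type::F32) {
            printf("f32 path: %f\n", (double)((const float *) data)[0]);
        } else {
            printf("f16 path: %f\n", (double)__half2float(((const half *) data)[0]));
        }
    }

    int main() {
        const float one = 1.0f;
        stage0_dispatch_sketch(elem_type::F32, &one);
        return 0;
    }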

@@ -822,13 +919,14 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor *
     cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int));
     cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int));
     cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int));
-    // printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size);
+
     conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8,
         tile_size, tile_2d_s,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         src0_d, src1_d, dst_d, stream);
+
 }
 
 