@@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
     }
 }
 
+static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        // outside the source extent: zero-fill the padded region
+        dst[offset_dst] = 0.0f;
+    }
+}
+
 static void pad_f32_cuda(const float * x, float * dst,
     const int ne00, const int ne01, const int ne02, const int ne03,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
@@ -33,17 +58,33 @@ static void pad_f32_cuda(const float * x, float * dst,
     pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
 }
 
+static void pad_f16_cuda(const half * x, half * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
-    pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    if (src0->type == GGML_TYPE_F32) {
+        const float * src0_d = (const float *)src0->data;
+        float * dst_d = (float *)dst->data;
+        pad_f32_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    } else {
+        const half * src0_d = (const half *)src0->data;
+        half * dst_d = (half *)dst->data;
+        pad_f16_cuda(src0_d, dst_d,
+            src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+            dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+    }
 }
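
For readers checking the index math: both kernels treat the tensors as contiguous, with each thread covering one element along dim 0, blockIdx.y walking dim 1, and blockIdx.z walking the fused dims 2 and 3 (the launch sets gridDim.y == ne1, which is why the dst offset can use gridDim.y in place of ne1). A minimal CPU sketch of the same zero-padding logic, with a hypothetical helper name that is not part of this commit:

// CPU reference for pad_f32/pad_f16 above: zero-pads a contiguous
// ne00 x ne01 x (ne02*ne03) array into an ne0 x ne1 x (ne2*ne3) array.
// Illustration only; the function name and flattened-loop shape are hypothetical.
static void pad_ref_f32(const float * x, float * dst,
                        int ne00, int ne01, int ne0203,
                        int ne0,  int ne1,  int ne23) {
    for (int i2 = 0; i2 < ne23; i2++) {          // plays the role of blockIdx.z
        for (int i1 = 0; i1 < ne1; i1++) {       // plays the role of blockIdx.y
            for (int i0 = 0; i0 < ne0; i0++) {   // plays the role of nidx
                int offset_dst = i0 + i1*ne0 + i2*ne0*ne1;
                if (i0 < ne00 && i1 < ne01 && i2 < ne0203) {
                    int offset_src = i0 + i1*ne00 + i2*ne00*ne01;
                    dst[offset_dst] = x[offset_src];
                } else {
                    dst[offset_dst] = 0.0f; // padded region is zero-filled
                }
            }
        }
    }
}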
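A host-side sketch of how this path gets exercised through the public ggml API (assuming the ggml backend API of this tree; the sizes are arbitrary and the data upload via ggml_backend_tensor_set is omitted):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main(void) {
    // metadata-only context; tensor data lives in the backend buffer
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 3D F16 tensor (ne[3] == 1, per the assert above), zero-padded 70x70x32 -> 128x128x32
    struct ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 70, 70, 32);
    struct ggml_tensor * out = ggml_pad(ctx, x, 58, 58, 0, 0);

    ggml_backend_t backend = ggml_backend_cuda_init(0 /* device */);
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_backend_graph_compute(backend, gf); // GGML_OP_PAD -> ggml_cuda_op_pad -> pad_f16_cuda

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}

With an F32 input the same graph takes the pad_f32_cuda branch instead.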