@@ -33,8 +33,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst_row[iybs + iqs + 0]        = v.x;
-    dst_row[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0]        = float(v.x);
+    dst_row[iybs + iqs + y_offset] = float(v.y);
 }
 
 template<typename src0_t, typename dst_t>
@@ -60,7 +60,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = src0_row[i00];
+    dst_row[i00] = float(src0_row[i00]);
 }
 
 template<typename grad_t, typename dst_t>
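Why the float(...) conversions in the two hunks above: dst_t can now be float, half, or nv_bfloat16, and under GGML_CUDA_F16 the dequantized dfloat values are themselves half. Routing each store through float keeps every source/destination pairing well-defined (a direct half-to-nv_bfloat16 assignment, for instance, is not implicitly convertible on all CUDA versions). A minimal sketch of the pattern, with an invented name store_as, assuming CUDA's fp16/bf16 headers:

    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // One template body that stores a possibly-half value into float, half,
    // or nv_bfloat16 output without per-type special cases: src_t -> float
    // and float -> dst_t are both defined for all three types, so the
    // round trip through float compiles for every combination.
    template <typename dst_t, typename src_t>
    __device__ void store_as(dst_t * dst, int i, src_t v) {
        dst[i] = float(v);
    }
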
@@ -86,122 +86,161 @@ static __global__ void k_get_rows_back_float(
     dst[dst_row*ncols + col] = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(
-        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
+template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
+static void get_rows_cuda_q(
+        const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
     // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    //const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);
 
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    //const size_t s13 = nb13 / sizeof(int32_t);
 
     GGML_ASSERT(ne00 % 2 == 0);
 
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-        src0_dd, src1_dd, dst_dd,
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /*s0,*/ s1, s2, s3,
         /*nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }
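A note on the launch geometry of get_rows_cuda_q above: each k_get_rows thread dequantizes one two-value pair (v.x, v.y), so the x-dimension of the grid covers ne00 in steps of 2*CUDA_GET_ROWS_BLOCK_SIZE, and ne00 must be even, which the GGML_ASSERT enforces. The same arithmetic as a standalone host-side sketch (illustrative only; the block-size value is an assumption standing in for the real constant):

    #include <cstdint>

    constexpr int block_size = 256; // assumed stand-in for CUDA_GET_ROWS_BLOCK_SIZE

    // x covers ne00 at two outputs per thread; y covers the ne10 row indices;
    // z covers the ne11*ne12 batch product.
    dim3 grid_for_get_rows_q(int64_t ne00, int64_t ne10, int64_t ne11, int64_t ne12) {
        const int block_num_x = (ne00 + 2*block_size - 1) / (2*block_size);
        return dim3(block_num_x, ne10, ne11*ne12);
    }
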
 
-template<typename src0_t>
+template<typename src0_t, typename dst_t>
 static void get_rows_cuda_float(
-        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(ne13 == 1);
-
+        const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
     // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    //const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);
 
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    //const size_t s13 = nb13 / sizeof(int32_t);
 
     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-        src0_dd, src1_dd, dst_dd,
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /*s0,*/ s1, s2, s3,
         /*nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }
 
-void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    const void * src0_d = (const void *) src0->data;
-    const int32_t * src1_d = (const int32_t *) src1->data;
-    float * dst_d = (float *) dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    switch (src0->type) {
+template<typename dst_t>
+static void ggml_cuda_get_rows_switch_src0_type(
+        const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+    switch (src0_type) {
         case GGML_TYPE_F16:
-            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         default:
             // TODO: k-quants
-            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
            break;
     }
 }
 
+void get_rows_cuda(
+        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+        size_t nb1, size_t nb2, size_t nb3,
+        cudaStream_t stream) {
+    switch (dst_type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
+            break;
+    }
+}
+
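Because the new get_rows_cuda entry point takes raw device pointers plus explicit shapes and byte strides, it can be invoked without ggml_tensor objects. A hypothetical caller, with all names and the flat-2D stride setup invented for illustration, gathering rows of a contiguous Q4_0 matrix into a contiguous F16 buffer:

    // Gather n_rows rows (indices in row_ids, a single batch) from a
    // contiguous [ne00 x n_src_rows] Q4_0 matrix into contiguous F16 output.
    // Assumes ne00 is a multiple of the Q4_0 block size.
    static void gather_rows_q4_0_to_f16(
            const void * src_q, int64_t ne00, int64_t n_src_rows,
            const int32_t * row_ids, int64_t n_rows,
            half * dst, cudaStream_t stream) {
        const size_t nb01 = ggml_row_size(GGML_TYPE_Q4_0, ne00); // bytes per src0 row
        const size_t nb1  = ne00*sizeof(half);                   // bytes per dst row
        get_rows_cuda(src_q, GGML_TYPE_Q4_0, row_ids, dst, GGML_TYPE_F16,
            ne00, nb01, nb01*n_src_rows, nb01*n_src_rows,                     // src0 strides
            n_rows, 1, 1,                                                     // src1 shape
            sizeof(int32_t), n_rows*sizeof(int32_t), n_rows*sizeof(int32_t),  // src1 strides
            nb1, n_rows*nb1, n_rows*nb1,                                      // dst strides
            stream);
    }
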
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ne13 == 1);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));
+
+    get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
+        ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+}
+
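The rewritten ggml_cuda_op_get_rows is now a thin wrapper: GGML_TENSOR_BINARY_OP_LOCALS (from ggml.h) pulls the shapes and byte strides of src0, src1, and dst into local variables, roughly as sketched below (abbreviated; the real macro defines all four dimensions of each operand), and the wrapper forwards them to get_rows_cuda:

    // Rough, abbreviated expansion of GGML_TENSOR_BINARY_OP_LOCALS for the
    // names used above:
    const int64_t ne00 = src0->ne[0];   const size_t nb01 = src0->nb[1];
    const int64_t ne10 = src1->ne[0];   const size_t nb10 = src1->nb[0];
    const int64_t ne13 = src1->ne[3];   // asserted == 1 above
    const size_t  nb1  = dst->nb[1];    const size_t nb3  = dst->nb[3];
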
 void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
     const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass