@@ -161,7 +161,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
161161 switch (dst->type ) {
162162 case GGML_TYPE_F32:
163163 if (src1->type == GGML_TYPE_I64) {
164- set_rows_sycl<float , float >(
164+ set_rows_sycl<float , int64_t , float >(
165165 (const char *)src0->data , (const int64_t *)src1->data , (char *)dst->data ,
166166 ne00, ne01, ne02, ne03,
167167 ne11, ne12,
@@ -172,7 +172,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
172172 stream
173173 );
174174 } else if (src1->type == GGML_TYPE_I32) {
175- set_rows_sycl<float , float >(
175+ set_rows_sycl<float , int32_t , float >(
176176 (const char *)src0->data , (const int32_t *)src1->data , (char *)dst->data ,
177177 ne00, ne01, ne02, ne03,
178178 ne11, ne12,
@@ -187,7 +187,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
187187 case GGML_TYPE_F16:
188188 dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
189189 if (src1->type == GGML_TYPE_I64) {
190- set_rows_sycl<float , sycl::half>(
190+ set_rows_sycl<float , int64_t , sycl::half>(
191191 (const char *)src0->data , (const int64_t *)src1->data , (char *)dst->data ,
192192 ne00, ne01, ne02, ne03,
193193 ne11, ne12,
@@ -198,7 +198,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
198198 stream
199199 );
200200 } else if (src1->type == GGML_TYPE_I32) {
201- set_rows_sycl<float , sycl::half>(
201+ set_rows_sycl<float , int32_t , sycl::half>(
202202 (const char *)src0->data , (const int32_t *)src1->data , (char *)dst->data ,
203203 ne00, ne01, ne02, ne03,
204204 ne11, ne12,
@@ -212,7 +212,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
212212 break ;
213213 case GGML_TYPE_BF16:
214214 if (src1->type == GGML_TYPE_I64) {
215- set_rows_sycl<float , sycl::ext::oneapi::bfloat16>(
215+ set_rows_sycl<float , int64_t , sycl::ext::oneapi::bfloat16>(
216216 (const char *)src0->data , (const int64_t *)src1->data , (char *)dst->data ,
217217 ne00, ne01, ne02, ne03,
218218 ne11, ne12,
@@ -223,7 +223,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
223223 stream
224224 );
225225 } else if (src1->type == GGML_TYPE_I32) {
226- set_rows_sycl<float , sycl::ext::oneapi::bfloat16>(
226+ set_rows_sycl<float , int32_t , sycl::ext::oneapi::bfloat16>(
227227 (const char *)src0->data , (const int32_t *)src1->data , (char *)dst->data ,
228228 ne00, ne01, ne02, ne03,
229229 ne11, ne12,
0 commit comments