Skip to content

Commit a364ec7

Browse files
committed
fix UT of concat
1 parent e700d37 commit a364ec7

File tree

1 file changed

+4
-34
lines changed

1 file changed

+4
-34
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 4 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -3343,10 +3343,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33433343
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
33443344
queue_ptr main_stream = ctx.stream();;
33453345

3346-
bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
3347-
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;
3348-
3349-
33503346
void * src0_ddq = src0->data;
33513347
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
33523348
float * src1_ddf = (float *) src1->data;
@@ -3364,15 +3360,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33643360
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
33653361
: src1_f16_alloc.get();
33663362

3367-
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
33683363
char * dst_t;
33693364

33703365
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
33713366
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
3372-
if (no_mixed_dtypes) {
3373-
cu_compute_type = dpct::library_data_t::real_half;
3374-
cu_data_type = dpct::library_data_t::real_half;
3375-
}
33763367

33773368
// dst strides
33783369
size_t nbd2 = dst->nb[2];
@@ -3381,26 +3372,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33813372
const float alpha_f32 = 1.0f;
33823373
const float beta_f32 = 0.0f;
33833374

3384-
const sycl::half alpha_f16 = 1.0f;
3385-
const sycl::half beta_f16 = 0.0f;
3386-
33873375
const void * alpha = &alpha_f32;
33883376
const void * beta = &beta_f32;
3389-
if (no_mixed_dtypes) {
3390-
alpha = &alpha_f16;
3391-
beta = &beta_f16;
3392-
}
3393-
3394-
// TODO: Re-enable (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
3395-
// when oneMKL open source supports half, half, float, float: datatypes
33963377

33973378
dst_t = (char *) dst_ddf;
3398-
if (no_mixed_dtypes) {
3399-
dst_t = (char *) dst_f16.alloc(ne_dst);
3400-
3401-
nbd2 /= sizeof(float) / sizeof(sycl::half);
3402-
nbd3 /= sizeof(float) / sizeof(sycl::half);
3403-
}
34043379

34053380
GGML_ASSERT(ne12 % ne02 == 0);
34063381
GGML_ASSERT(ne13 % ne03 == 0);
@@ -3462,11 +3437,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
34623437
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
34633438
cu_compute_type)));
34643439
}
3465-
3466-
if (no_mixed_dtypes) {
3467-
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
3468-
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
3469-
}
34703440
}
34713441
catch (sycl::exception const &exc) {
34723442
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -5069,9 +5039,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
50695039

50705040
ggml_type a_type = a->type;
50715041

5072-
if (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MUL_MAT_ID){
5073-
if (op->src[0]->type == GGML_TYPE_BF16) return false;
5074-
}
50755042

50765043
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
50775044
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
@@ -5082,10 +5049,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
50825049
return false;
50835050
}
50845051
}
5052+
50855053
ggml_type src0_type = op->src[0]->type;
50865054
if (src0_type == GGML_TYPE_BF16) {
50875055
return false;
50885056
}
5057+
50895058
return true;
50905059
} break;
50915060
case GGML_OP_GET_ROWS:
@@ -5133,7 +5102,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
51335102
case GGML_OP_CONCAT:
51345103
{
51355104
ggml_type src0_type = op->src[0]->type;
5136-
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
5105+
int dim = op->op_params[0];
5106+
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
51375107
} break;
51385108
case GGML_OP_DUP:
51395109
case GGML_OP_NONE:

0 commit comments

Comments
 (0)