Skip to content

Commit 37a06a0

Browse files
authored
Merge pull request #8 from arthw/fix_q4_1
fix ut fault of Q4_1, Q5..
2 parents c69f491 + c5dc8e9 commit 37a06a0

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

ggml/src/ggml-sycl/common.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ void* ggml_sycl_host_malloc(size_t size) try {
2323
if (getenv("GGML_SYCL_NO_PINNED") != nullptr) {
2424
return nullptr;
2525
}
26-
// ggml_sycl_info().device_mgr->first_queue
26+
2727
void* ptr = nullptr;
2828
// allow to use dpct::get_in_order_queue() for host malloc
2929
auto q = dpct::get_in_order_queue();
@@ -32,7 +32,6 @@ void* ggml_sycl_host_malloc(size_t size) try {
3232
dpct::err0 err = CHECK_TRY_ERROR(
3333
ptr = (void*)sycl::malloc_host(size, q));
3434

35-
// printf("zjy ggml_sycl_host_malloc ptr=%p queue=%p size=%lu \n", ptr,q, size);
3635
if (err != 0) {
3736
// clear the error
3837
GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported");
@@ -113,7 +112,6 @@ void print_device_opt_feature(ggml_sycl_device_info &info) {
113112
int device_count = info.device_count;
114113

115114
for (int id = 0; id < device_count; ++id) {
116-
printf("zjy id=%d\n", id);
117115
sycl::device device = dpct::dev_mgr::instance().get_device(id);
118116
std::string backend_type = get_device_backend_and_type(device);
119117
int type_id = DeviceNums[backend_type]++;

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
7777
}
7878

7979
// sum up partial sums and write back result
80+
const int mask_start = ncols > GGML_SYCL_DMMV_X ? warp_size >> 1 : warp_size >> 2;
81+
8082
#pragma unroll
81-
for (int mask = warp_size / 2; mask > 0; mask >>= 1) {
83+
for (int mask = mask_start; mask > 0; mask >>= 1) {
8284
tmp +=
8385
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8486
}

0 commit comments

Comments
 (0)