Commit 55b8674

Merge branch 'ggml-org:master' into master
2 parents 0fcce31 + bb115d2

14 files changed: +264 −24 lines


convert_hf_to_gguf.py

Lines changed: 19 additions & 0 deletions
@@ -1747,6 +1747,25 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("Mistral3ForConditionalGeneration")
+class Mistral3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("language_model.", "")
+        if "multi_modal_projector" in name or "vision_tower" in name:
+            return []
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("DeciLMForCausalLM")
 class DeciModel(Model):
     model_arch = gguf.MODEL_ARCH.DECI
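
The new Mistral3Model reuses the existing LLaMA conversion path: its constructor flattens the checkpoint's nested text_config into the root-level hyperparameters before LlamaModel.__init__ runs, and modify_tensors strips the "language_model." prefix and drops the vision tower / projector tensors. A minimal Python sketch of the flattening step (the config keys in this example are illustrative, not taken from a real checkpoint):

# Minimal sketch of the text_config flattening in Mistral3Model.__init__;
# the example keys below are illustrative, not from a real checkpoint.
hparams = {
    "model_type": "mistral3",
    "text_config": {"hidden_size": 5120, "num_hidden_layers": 40},
}

if "text_config" in hparams:
    # nested text_config keys are lifted to (and take precedence at) the root level
    hparams = {**hparams, **hparams["text_config"]}

assert hparams["hidden_size"] == 5120  # now visible where LlamaModel expects it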

docs/backend/SYCL.md

Lines changed: 3 additions & 1 deletion
@@ -660,8 +660,9 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 |--------------------|---------------------------------------|---------------------------------------------|
 | GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
 | GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
-| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
+| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
 | GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
+| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
 | CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
 | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
 
@@ -671,6 +672,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
 | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
+| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
 
 

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
@@ -1872,6 +1872,10 @@ struct server_context {
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel = 1;
 
+            // force F16 KV cache for the draft model for extra performance
+            params_dft.cache_type_k = GGML_TYPE_F16;
+            params_dft.cache_type_v = GGML_TYPE_F16;
+
             llama_init_dft = common_init_from_params(params_dft);
 
             model_dft = llama_init_dft.model.get();
@@ -1892,10 +1896,6 @@ struct server_context {
             cparams_dft = common_context_params_to_llama(params_dft);
             cparams_dft.n_batch = n_ctx_dft;
 
-            // force F16 KV cache for the draft model for extra performance
-            cparams_dft.type_k = GGML_TYPE_F16;
-            cparams_dft.type_v = GGML_TYPE_F16;
-
             // the context is not needed - we will create one for each slot
             llama_init_dft.context.reset();
         }

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
@@ -331,11 +331,11 @@ int main(int argc, char ** argv) {
             }
 
             active_seqs.erase(s);
-            for(int i = 0; i < n_seq_dft; i++) {
+            for (int i = 0; i < n_seq_dft; i++) {
                 if (i == s) {
                     continue;
                 }
-                if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
+                if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
                     // synchronize active status for sequences with the same drafted token
                     drafts[i].active = drafts[i].active && accept;
                     if (!drafts[i].active) {

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ option(GGML_OPENMP "ggml: use OpenMP"
 option(GGML_RPC "ggml: use RPC" OFF)
 option(GGML_SYCL "ggml: use SYCL" OFF)
 option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
+option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
 set (GGML_SYCL_TARGET "INTEL" CACHE STRING
     "ggml: sycl target device")
 set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 15 additions & 7 deletions
@@ -287,17 +287,25 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 endif()
             endif()
         endif()
-    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+    elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
         message(STATUS "PowerPC detected")
-        execute_process(COMMAND bash -c "grep POWER /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER_M)
-        if (${POWER_M} MATCHES "POWER10")
-            list(APPEND ARCH_FLAGS -mcpu=power10)
-        elseif (${POWER_M} MATCHES "POWER9")
-            list(APPEND ARCH_FLAGS -mcpu=power9)
+        if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+            file(READ "/proc/cpuinfo" POWER10_M)
+        elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
+            execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
+        endif()
+
+        string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
+        string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
+
+        if (EXTRACTED_NUMBER GREATER_EQUAL 10)
+            list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
+        elseif (EXTRACTED_NUMBER EQUAL 9)
+            list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
         elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
            list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
         else()
-            list(APPEND ARCH_FLAGS -mcpu=powerpc64 -mtune=native)
+            list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
         endif()
     elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
         message(STATUS "loongarch64 detected")
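
In effect, the new PowerPC branch reads the CPU identification string (from /proc/cpuinfo on ppc64 systems, or from prtconf output on powerpc/AIX-style systems), extracts the POWER generation number with a regex, and picks the -mcpu flags from that number. A rough Python illustration of the same extraction; the sample cpuinfo string is made up, and the real CMake logic additionally keeps the separate ppc64le fallback branch:

import re

# Illustrative only: mirrors the regex and flag selection used by the new CMake code.
cpuinfo = "cpu : POWER10 (architected), altivec supported"  # made-up sample string

m = re.search(r"POWER *([0-9]+)", cpuinfo)
generation = int(m.group(1)) if m else 0

if generation >= 10:
    arch_flags = ["-mcpu=power10", "-mpowerpc64"]
elif generation == 9:
    arch_flags = ["-mcpu=power9", "-mpowerpc64"]
else:
    arch_flags = ["-mcpu=native", "-mtune=native", "-mpowerpc64"]

print(arch_flags)  # ['-mcpu=power10', '-mpowerpc64'] for the sample string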

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 150 additions & 1 deletion
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
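
For reference, the new SVE path computes the same quantity as the existing NEON and scalar paths: a Q6_K · Q8_K dot product per 256-value super-block, with the -32 offset of the 6-bit quants folded in via isum - 32 * isum_mins. Below is a rough scalar sketch in Python of one super-block's contribution, assuming the standard Q6_K layout (ql = 128 bytes of low 4-bit quants, qh = 64 bytes packing the high 2 bits, 16 signed 8-bit sub-block scales, fp16 block scale d); q8 and q8_d stand for the Q8_K block's int8 quants and float scale, and the exact bit ordering here is an assumption for illustration:

QK_K = 256

def q6_k_dot_block(ql, qh, scales, d, q8, q8_d):
    # Unpack 256 signed 6-bit values; subtracting 32 here plays the role of
    # the "isum - 32 * isum_mins" correction in the SIMD kernels.
    q = [0] * QK_K
    for j in range(0, QK_K, 128):
        o = j // 2  # ql advances 64 bytes per 128 values
        h = j // 4  # qh advances 32 bytes per 128 values
        for l in range(32):
            q[j + l +  0] = ((ql[o + l +  0] & 0xF) | (((qh[h + l] >> 0) & 3) << 4)) - 32
            q[j + l + 32] = ((ql[o + l + 32] & 0xF) | (((qh[h + l] >> 2) & 3) << 4)) - 32
            q[j + l + 64] = ((ql[o + l +  0] >>  4) | (((qh[h + l] >> 4) & 3) << 4)) - 32
            q[j + l + 96] = ((ql[o + l + 32] >>  4) | (((qh[h + l] >> 6) & 3) << 4)) - 32
    # One int8 scale per 16-value sub-block, then the two per-block float scales.
    isum = sum(scales[j // 16] * q[j] * q8[j] for j in range(QK_K))
    return d * q8_d * isum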

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 0 deletions
@@ -262,6 +262,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
                 id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
                 device_vmm ? "yes" : "no", prop.warpSize);
 #elif defined(GGML_USE_MUSA)
+        // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
+        info.devices[id].warp_size = 32;
         // TODO: refine the .cc to reflect MUSA's actual CC capabilities
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@ if (WIN32)
     find_package(MKL REQUIRED)
     target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
 else()
+    if (GGML_SYCL_GRAPH)
+        add_compile_definitions(GGML_SYCL_GRAPH)
+    endif()
     if (GGML_SYCL_TARGET STREQUAL "INTEL")
         target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
     elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")

ggml/src/ggml-sycl/common.hpp

Lines changed: 5 additions & 0 deletions
@@ -301,6 +301,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
     return opt;
 }
 
+namespace sycl_ex = sycl::ext::oneapi::experimental;
 struct ggml_backend_sycl_context {
     int device;
     std::string name;
@@ -392,6 +393,10 @@ struct ggml_backend_sycl_context {
         return pool(device);
     }
 
+#ifdef GGML_SYCL_GRAPH
+    std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
+#endif
+
     ggml_sycl_pool & host_pool(int device) {
         if (host_pools[device] == nullptr) {
             host_pools[device] = new_pool_for_host(stream(device, 0), device);
