@@ -61,7 +61,21 @@ auto ggml_graph_compute_helper(std::vector<uninitialized_char> &buf, ggml_cgraph
6161 ggml_graph_compute (graph, &plan);
6262}
6363
64- auto ModelContext::init_device_context () -> void {}
64+ auto ModelContext::init_device_context () -> void {
65+ #ifdef GGML_USE_METAL
66+ ctx_metal = make_unique_ggml_metal_context (1 );
67+ const size_t max_size = ggml_get_max_tensor_size (ctx_w.get ());
68+ void *weight_data = weight_buffer.empty () ? ggml_get_mem_buffer (ctx_w.get ()) : (void *)weight_buffer.data ();
69+ size_t weight_size = weight_buffer.empty () ? ggml_get_mem_size (ctx_w.get ()) : weight_buffer.size ();
70+ QWEN_CHECK (ggml_metal_add_buffer (ctx_metal.get (), " weights" , weight_data, weight_size, max_size));
71+ QWEN_CHECK (ggml_metal_add_buffer (ctx_metal.get (), " kv" , ggml_get_mem_buffer (ctx_kv.get ()),
72+ ggml_get_mem_size (ctx_kv.get ()), 0 ));
73+ void *compute_data = ctx_b ? ggml_get_mem_buffer (ctx_b.get ()) : compute_buffer.data ();
74+ size_t compute_size = ctx_b ? ggml_get_mem_size (ctx_b.get ()) : compute_buffer.size ();
75+ QWEN_CHECK (ggml_metal_add_buffer (ctx_metal.get (), " compute" , compute_data, compute_size, 0 ));
76+ QWEN_CHECK (ggml_metal_add_buffer (ctx_metal.get (), " scratch" , scratch.data , scratch.size , 0 ));
77+ #endif
78+ }
6579
6680// ===== streamer =====
6781
@@ -482,7 +496,7 @@ auto get_num_physical_cores() -> int {
482496}
483497
484498auto get_default_num_threads () -> int {
485- #ifdef GGML_USE_CUBLAS
499+ #if defined( GGML_USE_CUBLAS) || defined(GGML_USE_METAL)
486500 return 1 ;
487501#else
488502 return std::min (get_num_physical_cores (), 16 );
@@ -583,7 +597,11 @@ auto QwenForCausalLM::generate_next_token(
583597 }
584598
585599 ggml_build_forward_expand (&ctx_.gf , lm_logits);
600+ #ifdef GGML_USE_METAL
601+ ggml_metal_graph_compute (ctx_.ctx_metal .get (), &ctx_.gf );
602+ #else
586603 ggml_graph_compute_helper (ctx_.work_buffer , &ctx_.gf , n_threads);
604+ #endif
587605
588606 int vocab_size = lm_logits->ne [0 ];
589607 float *next_token_logits = (float *)lm_logits->data ;
0 commit comments