Skip to content
This repository was archived by the owner on Dec 6, 2024. It is now read-only.

Commit 6c9cccf

Browse files
committed
support M1/M2
1 parent 9648bf6 commit 6c9cccf

File tree

4 files changed

+53
-2
lines changed

4 files changed

+53
-2
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ if (GGML_CUBLAS)
3030
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
3131
endif ()
3232

33+
if (GGML_METAL)
34+
add_compile_definitions(GGML_USE_METAL)
35+
configure_file(third_party/ggml/src/ggml-metal.metal ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
36+
endif ()
37+
3338
file(GLOB CPP_SOURCES
3439
${PROJECT_SOURCE_DIR}/*.h
3540
${PROJECT_SOURCE_DIR}/*.cpp)

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ cuBLAS uses NVIDIA GPU to accelerate BLAS. Add the CMake flag `-DGGML_CUBLAS=ON`
8989
cmake -B build -DGGML_CUBLAS=ON && cmake --build build -j
9090
```
9191

92+
**Metal**
93+
94+
MPS (Metal Performance Shaders) allows computation to run on the powerful Apple Silicon GPU. Add the CMake flag `-DGGML_METAL=ON` to enable it.
95+
```sh
96+
cmake -B build -DGGML_METAL=ON && cmake --build build -j
97+
```
98+
9299
## Python Binding
93100

94101
The Python binding provides high-level `chat` and `stream_chat` interface similar to the original Hugging Face Qwen-7B.

qwen.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,21 @@ auto ggml_graph_compute_helper(std::vector<uninitialized_char> &buf, ggml_cgraph
6161
ggml_graph_compute(graph, &plan);
6262
}
6363

64-
auto ModelContext::init_device_context() -> void {}
64+
auto ModelContext::init_device_context() -> void {
65+
#ifdef GGML_USE_METAL
66+
ctx_metal = make_unique_ggml_metal_context(1);
67+
const size_t max_size = ggml_get_max_tensor_size(ctx_w.get());
68+
void *weight_data = weight_buffer.empty() ? ggml_get_mem_buffer(ctx_w.get()) : (void *)weight_buffer.data();
69+
size_t weight_size = weight_buffer.empty() ? ggml_get_mem_size(ctx_w.get()) : weight_buffer.size();
70+
QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "weights", weight_data, weight_size, max_size));
71+
QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "kv", ggml_get_mem_buffer(ctx_kv.get()),
72+
ggml_get_mem_size(ctx_kv.get()), 0));
73+
void *compute_data = ctx_b ? ggml_get_mem_buffer(ctx_b.get()) : compute_buffer.data();
74+
size_t compute_size = ctx_b ? ggml_get_mem_size(ctx_b.get()) : compute_buffer.size();
75+
QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "compute", compute_data, compute_size, 0));
76+
QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "scratch", scratch.data, scratch.size, 0));
77+
#endif
78+
}
6579

6680
// ===== streamer =====
6781

@@ -482,7 +496,7 @@ auto get_num_physical_cores() -> int {
482496
}
483497

484498
auto get_default_num_threads() -> int {
485-
#ifdef GGML_USE_CUBLAS
499+
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_METAL)
486500
return 1;
487501
#else
488502
return std::min(get_num_physical_cores(), 16);
@@ -583,7 +597,11 @@ auto QwenForCausalLM::generate_next_token(
583597
}
584598

585599
ggml_build_forward_expand(&ctx_.gf, lm_logits);
600+
#ifdef GGML_USE_METAL
601+
ggml_metal_graph_compute(ctx_.ctx_metal.get(), &ctx_.gf);
602+
#else
586603
ggml_graph_compute_helper(ctx_.work_buffer, &ctx_.gf, n_threads);
604+
#endif
587605

588606
int vocab_size = lm_logits->ne[0];
589607
float *next_token_logits = (float *)lm_logits->data;

qwen.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include <ggml-cuda.h>
1313
#endif
1414

15+
#ifdef GGML_USE_METAL
16+
#include <ggml-metal.h>
17+
#endif
18+
1519
namespace qwen {
1620

1721
class QwenTokenizer;
@@ -58,6 +62,20 @@ static inline auto make_unique_ggml_context(
5862
return unique_ggml_context_t(ggml_init({mem_size, mem_buffer, no_alloc}));
5963
}
6064

65+
#ifdef GGML_USE_METAL
66+
struct ggml_metal_context_deleter_t {
67+
auto operator()(ggml_metal_context *ctx) const noexcept -> void { ggml_metal_free(ctx); }
68+
};
69+
70+
using unique_ggml_metal_context_t = std::unique_ptr<ggml_metal_context, ggml_metal_context_deleter_t>;
71+
72+
static inline auto make_unique_ggml_metal_context(
73+
int n_cb
74+
) -> unique_ggml_metal_context_t {
75+
return unique_ggml_metal_context_t(ggml_metal_init(n_cb));
76+
}
77+
#endif
78+
6179
struct uninitialized_char {
6280
char m;
6381
uninitialized_char() {}
@@ -70,6 +88,9 @@ struct ModelContext {
7088
unique_ggml_context_t ctx_w; // weight
7189
unique_ggml_context_t ctx_kv; // kv cache
7290
unique_ggml_context_t ctx_b; // buffer
91+
#ifdef GGML_USE_METAL
92+
unique_ggml_metal_context_t ctx_metal;
93+
#endif
7394
ggml_cgraph gf;
7495
ggml_scratch scratch;
7596
std::vector<uninitialized_char> compute_buffer; // BLAS buffer

0 commit comments

Comments
 (0)