
Commit 8972f1d

Merge branch '0cc4m/vulkan-subgroup-size-control' of https://github.com/ggerganov/llama.cpp into vulkan

2 parents: 1c16367 + 595c1a7
12 files changed: +551, -415 lines

convert_hf_to_gguf.py (8 additions, 0 deletions)

```diff
@@ -1992,6 +1992,14 @@ def set_vocab(self):
         except FileNotFoundError:
             self._set_vocab_gpt2()
 
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+
 
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):
```

examples/server/README.md (2 additions, 0 deletions)

````diff
@@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     }
   },
   "total_slots": 1,
+  "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
   "chat_template": "..."
 }
 ```
 
 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
+- `model_path` - the path to model file (same with `-m` argument)
 - `chat_template` - the model's original Jinja2 prompt template
 
 ### POST `/props`: Change server global properties.
````
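For orientation, the extended `GET /props` response can be consumed from any HTTP client. The snippet below is a minimal client-side sketch, not part of the commit; it assumes the `requests` package and a llama-server instance on a hypothetical local port.

```python
import requests

# Hypothetical local server address; adjust host/port to your setup.
res = requests.get("http://localhost:8080/props")
res.raise_for_status()
props = res.json()

# Fields documented above: model_path mirrors the -m argument,
# total_slots reflects the --parallel setting.
print("model path :", props["model_path"])
print("total slots:", props["total_slots"])
```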

examples/server/server.cpp (366 additions, 378 deletions)

Large diffs are not rendered by default.

examples/server/tests/unit/test_basic.py (30 additions, 0 deletions)

```diff
@@ -22,7 +22,12 @@ def test_server_props():
     server.start()
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
+    assert ".gguf" in res.body["model_path"]
     assert res.body["total_slots"] == server.n_slots
+    default_val = res.body["default_generation_settings"]
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert default_val["n_ctx"] == server.n_ctx / server.n_slots
+    assert default_val["params"]["seed"] == server.seed
 
 
 def test_server_models():
@@ -33,6 +38,31 @@ def test_server_models():
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
 
+
+def test_server_slots():
+    global server
+
+    # without slots endpoint enabled, this should return error
+    server.server_slots = False
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
+    assert "error" in res.body
+    server.stop()
+
+    # with slots endpoint enabled, this should return slots info
+    server.server_slots = True
+    server.n_slots = 2
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 200
+    assert len(res.body) == server.n_slots
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
+    assert "params" in res.body[0]
+    assert res.body[0]["params"]["seed"] == server.seed
+
+
 def test_load_split_model():
     global server
     server.model_hf_repo = "ggml-org/models"
```
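The new `test_server_slots` case exercises the now-optional `--slots` flag end to end. Outside the test harness, the same behaviour can be probed with a plain HTTP client; the sketch below is illustrative only and assumes a llama-server started with `--slots -np 2` on a hypothetical local port.

```python
import requests

# Hypothetical local server started with: llama-server -m model.gguf --slots -np 2
res = requests.get("http://localhost:8080/slots")

if res.status_code == 501:
    # Endpoint disabled (server started without --slots): ERROR_TYPE_NOT_SUPPORTED.
    print("slots endpoint disabled:", res.json()["error"])
else:
    slots = res.json()
    print("number of slots:", len(slots))
    for slot in slots:
        # Per the test above, each slot reports its context size and sampling params.
        print(slot["n_ctx"], slot["params"]["seed"])
```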

examples/server/tests/unit/test_chat_completion.py (5 additions, 0 deletions)

```diff
@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
     assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         "stream": True,
     })
     content = ""
+    last_cmpl_id = None
     for data in res:
         choice = data["choices"][0]
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        if last_cmpl_id is None:
+            last_cmpl_id = data["id"]
+        assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
```

examples/server/tests/utils.py (6 additions, 4 deletions)

```diff
@@ -64,6 +64,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    server_slots: bool | None = False
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -91,7 +92,6 @@ def start(self, timeout_seconds: int = 10) -> None:
         else:
             server_path = "../../../build/bin/llama-server"
         server_args = [
-            "--slots", # requires to get slot status via /slots endpoint
             "--host",
             self.server_host,
             "--port",
@@ -129,6 +129,8 @@ def start(self, timeout_seconds: int = 10) -> None:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.server_slots:
+            server_args.append("--slots")
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
@@ -181,7 +183,7 @@ def start(self, timeout_seconds: int = 10) -> None:
         start_time = time.time()
         while time.time() - start_time < timeout_seconds:
             try:
-                response = self.make_request("GET", "/slots", headers={
+                response = self.make_request("GET", "/health", headers={
                     "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                 })
                 if response.status_code == 200:
@@ -224,7 +226,7 @@ def make_request(
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json() if parse_body else None
-        print("Response from server", result.body)
+        print("Response from server", json.dumps(result.body, indent=2))
         return result
 
     def make_stream_request(
@@ -245,7 +247,7 @@ def make_stream_request(
                 break
             elif line.startswith('data: '):
                 data = json.loads(line[6:])
-                print("Partial response from server", data)
+                print("Partial response from server", json.dumps(data, indent=2))
                 yield data
 
 
```

examples/server/utils.hpp (18 additions, 2 deletions)

```diff
@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }
 
@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;
 
-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
 
@@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
     };
 }
+
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+    return data;
+}
+
+static std::string safe_json_to_str(json data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
```
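The first hunk makes the server reject empty prompts during tokenization, so a client gets an explicit error response instead of an empty generation. Below is a hedged client-side sketch of that behaviour (hypothetical local server; the exact error status code is not specified by this diff):

```python
import requests

# An empty prompt list should now be rejected ("prompt" must not be empty).
res = requests.post(
    "http://localhost:8080/completion",
    json={"prompt": [], "n_predict": 4},
)
print(res.status_code)           # expected: an error status rather than 200
print(res.json().get("error"))   # error object describing the rejection
```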

ggml/src/ggml-vulkan/CMakeLists.txt (14 additions, 0 deletions)

```diff
@@ -8,6 +8,20 @@ if (Vulkan_FOUND)
         ../../include/ggml-vulkan.h
         )
 
+    # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
+        message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
+    else()
+        message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
+        add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    endif()
+
     target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
```
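The CMake check shells out to glslc and pattern-matches stderr for the "extension not supported" diagnostic. When debugging a build, the same probe can be reproduced by hand; the wrapper below is a sketch only, assuming glslc is on PATH and the command is run from the repository root:

```python
import subprocess

# Reproduce the CMake probe: try to compile the test shader and check stderr
# for the "extension not supported" diagnostic.
shader = "ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp"
proc = subprocess.run(
    ["glslc", "-o", "-", "-fshader-stage=compute", "--target-env=vulkan1.3", shader],
    capture_output=True,
    text=True,
)

if "extension not supported: GL_NV_cooperative_matrix2" in proc.stderr:
    print("glslc does NOT support GL_NV_cooperative_matrix2")
else:
    print("glslc supports GL_NV_cooperative_matrix2")
```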
