21 changes: 21 additions & 0 deletions common/arg.cpp
@@ -2338,6 +2338,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.port = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
    add_opt(common_arg(
        {"--allowed-local-media-path"}, "PATH",
        "directory from which local media files are allowed to be read (default: none)",
        [](common_params & params, const std::string & value) {
            try {
                params.allowed_local_media_path = std::filesystem::canonical(std::filesystem::path(value));
                if (!std::filesystem::is_directory(params.allowed_local_media_path)) {
                    throw std::invalid_argument(string_format("allowed local media path must be a directory: %s", params.allowed_local_media_path.string().c_str()));
                }
            } catch (const std::filesystem::filesystem_error & err) {
                throw std::invalid_argument(string_format("invalid allowed local media path: %s", err.what()));
            }
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALLOWED_LOCAL_MEDIA_PATH"));
    add_opt(common_arg(
        {"--local-media-max-size-mb"}, "N",
        string_format("maximum size in MiB for local media files, 0 = no limit (default: %zu)", params.local_media_max_size_mb),
        [](common_params & params, int value) {
            if (value < 0) {
                throw std::invalid_argument("local media max size must be non-negative");
            }
            params.local_media_max_size_mb = static_cast<size_t>(value);
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_LOCAL_MEDIA_MAX_SIZE_MB"));
add_opt(common_arg(
{"--path"}, "PATH",
string_format("path to serve static files from (default: %s)", params.public_path.c_str()),
3 changes: 3 additions & 0 deletions common/common.h
@@ -5,6 +5,7 @@
#include "ggml-opt.h"
#include "llama-cpp.h"

#include <filesystem>
#include <set>
#include <sstream>
#include <string>
@@ -454,9 +455,11 @@ struct common_params {
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
    int32_t cache_ram_mib = 8192; // -1 = no limit, 0 = disable, 1 = 1 MiB, etc.
    size_t local_media_max_size_mb = 15; // 0 = no limit, 1 = 1 MiB, etc.; max size of loaded local media files

std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::filesystem::path allowed_local_media_path; // NOLINT
std::string api_prefix = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
4 changes: 4 additions & 0 deletions tools/server/README.md
@@ -171,6 +171,8 @@ The project is under active development, and we are [looking for feedback and co
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
| `--allowed-local-media-path PATH` | directory from which local media files are allowed to be read (default: none)<br/>(env: LLAMA_ARG_ALLOWED_LOCAL_MEDIA_PATH) |
| `--local-media-max-size-mb N` | maximum size in MiB for local media files, 0 = no limit (default: 15)<br/>(env: LLAMA_ARG_LOCAL_MEDIA_MAX_SIZE_MB) |
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
| `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
@@ -1213,6 +1215,8 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte

If the model supports multimodal input, you can pass media files via the `image_url` content part. Both base64-encoded data and remote URLs are supported as input. See the OpenAI documentation for more details.

Local files are also supported as input via `file://` URLs when enabled (see `--allowed-local-media-path` and `--local-media-max-size-mb` for details); a minimal request sketch follows.
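As a rough illustration — assuming a server started with `--allowed-local-media-path /data/media` (a placeholder path), a multimodal model loaded, and the default host and port — a request might look like this in Python:

```python
import requests

# hypothetical setup: llama-server started with --allowed-local-media-path /data/media;
# the file must lie under that directory and within the configured size limit
resp = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "file:///data/media/cat.png"}},
            ]},
        ],
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```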

*Options:*

See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
32 changes: 32 additions & 0 deletions tools/server/server-common.cpp
@@ -11,6 +11,7 @@

#include <random>
#include <sstream>
#include <fstream>

json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
@@ -881,6 +882,37 @@ json oaicompat_chat_params_parse(
throw std::runtime_error("Failed to download image");
}

            } else if (string_starts_with(url, "file://")) {
                if (opt.allowed_local_media_path.empty()) {
                    throw std::runtime_error("Local media paths are not enabled");
                }
                // strip off the leading "file://"
                const std::string fname = url.substr(7);
                std::filesystem::path input_path;
                try {
                    // canonical() resolves symlinks and ".." components, and throws if the path does not exist
                    input_path = std::filesystem::canonical(std::filesystem::path(fname));
                } catch (const std::filesystem::filesystem_error &) {
                    throw std::runtime_error("Local media file does not exist: " + fname);
                }
                // the allowed directory must be a component-wise prefix of the canonical path
                auto [allowed_end, input_it] = std::mismatch(opt.allowed_local_media_path.begin(), opt.allowed_local_media_path.end(), input_path.begin(), input_path.end());
                if (allowed_end != opt.allowed_local_media_path.end()) {
                    throw std::runtime_error("Local media file path not allowed: " + fname);
                }
                if (!std::filesystem::is_regular_file(input_path)) {
                    throw std::runtime_error("Local media file is not a regular file: " + fname);
                }
                const auto file_size = std::filesystem::file_size(input_path);
                if (opt.local_media_max_size_mb > 0 && file_size > opt.local_media_max_size_mb * 1024 * 1024) {
                    throw std::runtime_error("Local media file exceeds maximum allowed size");
                }
                // load the local file
                std::ifstream f(input_path, std::ios::binary);
                if (!f) {
                    SRV_ERR("Unable to open file %s: %s\n", fname.c_str(), strerror(errno));
                    throw std::runtime_error("Unable to open local media file: " + fname);
                }
                raw_buffer buf((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
                if (buf.size() != file_size) {
                    SRV_ERR("Failed to read entire file %s\n", fname.c_str());
                    throw std::runtime_error("Failed to read entire media file");
                }
                out_files.push_back(buf);

} else {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
2 changes: 2 additions & 0 deletions tools/server/server-common.h
@@ -286,6 +286,8 @@ struct oaicompat_parser_options {
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
    size_t local_media_max_size_mb = 15; // 0 = no limit
std::filesystem::path allowed_local_media_path;
};

// used by /chat/completions endpoint
2 changes: 2 additions & 0 deletions tools/server/server.cpp
@@ -751,6 +751,8 @@ struct server_context {
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* local_media_max_size_mb */ params_base.local_media_max_size_mb,
/* allowed_local_media_path */ params_base.allowed_local_media_path,
};

// print sample chat example to make it clear which template is used
87 changes: 85 additions & 2 deletions tools/server/tests/unit/test_vision_api.py
@@ -2,12 +2,14 @@
from utils import *
import base64
import requests
from pathlib import Path

server: ServerProcess

def get_img_url(id: str) -> str:
def get_img_url(id: str, tmp_path: str | None = None) -> str:
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
IMG_FILE_2 = "https://picsum.photos/id/237/5000"
if id == "IMG_URL_0":
return IMG_URL_0
elif id == "IMG_URL_1":
@@ -28,6 +30,45 @@ def get_img_url(id: str) -> str:
response = requests.get(IMG_URL_1)
response.raise_for_status() # Raise an exception for bad status codes
return base64.b64encode(response.content).decode("utf-8")
    elif id in ("IMG_FILE_0", "IMG_FILE_1", "IMG_FILE_2"):
        if tmp_path is None:
            raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
        # map each id to its source URL and local file name, downloading on first use
        url, image_name = {
            "IMG_FILE_0": (IMG_URL_0, IMG_URL_0.split('/')[-1]),
            "IMG_FILE_1": (IMG_URL_1, IMG_URL_1.split('/')[-1]),
            "IMG_FILE_2": (IMG_FILE_2, "dog.jpg"),
        }[id]
        file_name = Path(tmp_path) / image_name
        if not file_name.exists():
            response = requests.get(url)
            response.raise_for_status()  # raise an exception for bad status codes
            file_name.write_bytes(response.content)
        return f"file://{file_name}"
    else:
        return id

@@ -70,6 +111,9 @@ def test_v1_models_supports_multimodal_capability():
("What is this:\n", "malformed", False, None),
("What is this:\n", "https://google.com/404", False, None), # non-existent image
("What is this:\n", "https://ggml.ai", False, None), # non-image data
("What is this:\n", "IMG_FILE_0", False, None),
("What is this:\n", "IMG_FILE_1", False, None),
("What is this:\n", "IMG_FILE_2", False, None),
# TODO @ngxson : test with multiple images, no images and with audio
]
)
@@ -83,7 +127,46 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
{"role": "user", "content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": get_img_url(image_url),
"url": get_img_url(image_url, "./tmp"),
}},
]},
],
})
if success:
assert res.status_code == 200
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200

@pytest.mark.parametrize(
"allowed_mb_size, allowed_path, img_dir_path, prompt, image_url, success, re_content",
[
# test model is trained on CIFAR-10, but it's quite dumb due to small size
(0, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_0", True, "(cat)+"),
(0, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_1", True, "(frog)+"),
(1, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_2", False, None),
(0, "./tmp/allowed", "./tmp", "What is this:\n", "IMG_FILE_0", False, None),
(0, "./tm", "./tmp", "What is this:\n", "IMG_FILE_0", False, None),
(0, "./tmp/allowed", "./tmp/allowed/..", "What is this:\n", "IMG_FILE_0", False, None),
(0, "./tmp/allowed", "./tmp/allowed/../.", "What is this:\n", "IMG_FILE_0", False, None),
]
)
def test_vision_chat_completion_local_files(allowed_mb_size, allowed_path, img_dir_path, prompt, image_url, success, re_content):
global server
server.local_media_max_size_mb = allowed_mb_size
server.allowed_local_media_path = allowed_path
    Path(allowed_path).mkdir(parents=True, exist_ok=True)
server.start()
res = server.make_request("POST", "/chat/completions", data={
"temperature": 0.0,
"top_k": 1,
"messages": [
{"role": "user", "content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": get_img_url(image_url, img_dir_path),
}},
]},
],
6 changes: 6 additions & 0 deletions tools/server/tests/utils.py
@@ -95,6 +95,8 @@ class ServerProcess:
chat_template_file: str | None = None
server_path: str | None = None
mmproj_url: str | None = None
local_media_max_size_mb: int | None = None
allowed_local_media_path: str | None = None

# session variables
process: subprocess.Popen | None = None
@@ -215,6 +217,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.mmproj_url:
server_args.extend(["--mmproj-url", self.mmproj_url])
        if self.local_media_max_size_mb is not None:
server_args.extend(["--local-media-max-size-mb", self.local_media_max_size_mb])
if self.allowed_local_media_path:
server_args.extend(["--allowed-local-media-path", self.allowed_local_media_path])

args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")