Skip to content

Commit 88461f2

Browse files
committed
support remote image_url
1 parent 3304b44 commit 88461f2

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

tools/server/server.cpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4327,7 +4327,13 @@ int main(int argc, char ** argv) {
43274327

43284328
auto body = json::parse(req.body);
43294329
std::vector<raw_buffer> files;
4330-
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get(), files);
4330+
json data = oaicompat_completion_params_parse(
4331+
body,
4332+
params.use_jinja,
4333+
params.reasoning_format,
4334+
ctx_server.chat_templates.get(),
4335+
ctx_server.mctx,
4336+
files);
43314337

43324338
return handle_completions_impl(
43334339
SERVER_TASK_TYPE_COMPLETION,
@@ -4342,7 +4348,13 @@ int main(int argc, char ** argv) {
43424348
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
43434349
auto body = json::parse(req.body);
43444350
std::vector<raw_buffer> files; // dummy, unused
4345-
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get(), files);
4351+
json data = oaicompat_completion_params_parse(
4352+
body,
4353+
params.use_jinja,
4354+
params.reasoning_format,
4355+
ctx_server.chat_templates.get(),
4356+
ctx_server.mctx,
4357+
files);
43464358
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
43474359
};
43484360

tools/server/utils.hpp

Lines changed: 33 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
#include "common.h"
44
#include "log.h"
55
#include "llama.h"
6+
#include "arg.h" // common_remote_get_content
67
#include "base64.hpp"
78
#include "mtmd.h"
89

@@ -584,6 +585,7 @@ static json oaicompat_completion_params_parse(
584585
bool use_jinja,
585586
common_reasoning_format reasoning_format,
586587
const struct common_chat_templates * tmpls,
588+
bool allow_non_text,
587589
std::vector<raw_buffer> & out_files)
588590
{
589591
json llama_params;
@@ -654,21 +656,41 @@ static json oaicompat_completion_params_parse(
654656
std::string type = json_value(p, "type", std::string());
655657
json image_url = json_value(p, "image_url", json::object());
656658
if (type == "image_url") {
659+
if (!allow_non_text) {
660+
throw std::runtime_error("image input is not supported by this server");
661+
}
662+
657663
std::string url = json_value(image_url, "url", std::string());
658-
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
659-
if (parts.size() != 2) {
660-
throw std::runtime_error("Invalid image_url.url value");
661-
} else if (!string_starts_with(parts[0], "data:image/")) {
662-
throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
663-
} else if (!string_ends_with(parts[0], "base64")) {
664-
throw std::runtime_error("image_url.url must be base64 encoded");
664+
if (string_starts_with(url, "http")) {
665+
// download remote image
666+
// TODO @ngxson : maybe make these params configurable
667+
common_remote_params params;
668+
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
669+
params.max_size = 1024 * 1024 * 10; // 10MB
670+
auto res = common_remote_get_content(url, params);
671+
raw_buffer data;
672+
data.insert(data.end(), res.second.begin(), res.second.end());
673+
out_files.push_back(data);
674+
665675
} else {
666-
auto base64_data = parts[1];
667-
auto decoded_data = base64_decode(base64_data);
668-
out_files.push_back(decoded_data);
676+
// try to decode base64 image
677+
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
678+
if (parts.size() != 2) {
679+
throw std::runtime_error("Invalid image_url.url value");
680+
} else if (!string_starts_with(parts[0], "data:image/")) {
681+
throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
682+
} else if (!string_ends_with(parts[0], "base64")) {
683+
throw std::runtime_error("image_url.url must be base64 encoded");
684+
} else {
685+
auto base64_data = parts[1];
686+
auto decoded_data = base64_decode(base64_data);
687+
out_files.push_back(decoded_data);
688+
}
669689
}
690+
691+
// replace this chunk with a marker
670692
p["type"] = "text";
671-
p["text"] = "<__image__>";
693+
p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
672694
p.erase("image_url");
673695
}
674696
}

0 commit comments

Comments
 (0)