Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ poetry.toml
# Local scripts
/run-vim.sh
/run-chat.sh
.history
66 changes: 63 additions & 3 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4215,7 +4215,7 @@ int main(int argc, char ** argv) {
throw std::runtime_error("prompt must be a string");
}

if (oaicompat && has_mtmd) {
if (has_mtmd) {
// multimodal
std::string prompt_str = prompt.get<std::string>();
mtmd_input_text inp_txt = {
Expand Down Expand Up @@ -4332,9 +4332,69 @@ int main(int argc, char ** argv) {
}
};

const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
std::vector<raw_buffer> files; // dummy
json medias = json_value(data, "medias", json::array());
auto & opt = ctx_server.oai_parser_opt;
std::vector<raw_buffer> files;

if (medias.is_array()) {
for (auto & m : medias) {
std::string type = json_value(m, "type", std::string());
std::string data = json_value(m, "data", std::string());
if (type.empty() || data.empty()) {
continue;
}
if (type == "image_url" || type == "image" || type == "img") {
if (!opt.allow_image) {
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
if (string_starts_with(data, "http")) {
// download remote image
common_remote_params params;
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
params.max_size = 1024 * 1024 * 10; // 10MB
params.timeout = 10; // seconds
SRV_INF("downloading image from '%s'\n", data.c_str());
auto res = common_remote_get_content(data, params);
Comment on lines +4504 to +4522
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of duplicating this whole code block, extract it to a general function and reuse it in /chat/completion and /completion. DRY code principle

if (200 <= res.first && res.first < 300) {
SRV_INF("downloaded %ld bytes\n", res.second.size());
raw_buffer buf;
buf.insert(buf.end(), res.second.begin(), res.second.end());
files.push_back(buf);
} else {
throw std::runtime_error("Failed to download image");
}
} else {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(data, /*separator*/ ',');
if (parts.size() != 2) {
throw std::runtime_error("Invalid image_url.url value");
} else if (!string_starts_with(parts[0], "data:image/")) {
throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
} else if (!string_ends_with(parts[0], "base64")) {
throw std::runtime_error("image_url.url must be base64 encoded");
} else {
auto base64_data = parts[1];
auto decoded_data = base64_decode(base64_data);
files.push_back(decoded_data);
}
}
} else if (type == "input_audio" || type == "audio") {
if (!opt.allow_audio) {
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
std::string format = json_value(m, "format", std::string());
// while we also support flac, we don't allow it here so we matches the OAI spec
if (format != "wav" && format != "mp3") {
throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
}
auto decoded_data = base64_decode(data); // expected to be base64 encoded
files.push_back(decoded_data);
}
}
}

handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
Expand Down