|
3 | 3 | #include "common.h" |
4 | 4 | #include "log.h" |
5 | 5 | #include "llama.h" |
| 6 | +#include "arg.h" // common_remote_get_content |
6 | 7 | #include "base64.hpp" |
7 | 8 | #include "mtmd.h" |
8 | 9 |
|
@@ -584,6 +585,7 @@ static json oaicompat_completion_params_parse( |
584 | 585 | bool use_jinja, |
585 | 586 | common_reasoning_format reasoning_format, |
586 | 587 | const struct common_chat_templates * tmpls, |
| 588 | + bool allow_non_text, |
587 | 589 | std::vector<raw_buffer> & out_files) |
588 | 590 | { |
589 | 591 | json llama_params; |
@@ -654,21 +656,41 @@ static json oaicompat_completion_params_parse( |
654 | 656 | std::string type = json_value(p, "type", std::string()); |
655 | 657 | json image_url = json_value(p, "image_url", json::object()); |
656 | 658 | if (type == "image_url") { |
| 659 | + if (!allow_non_text) { |
| 660 | + throw std::runtime_error("image input is not supported by this server"); |
| 661 | + } |
| 662 | + |
657 | 663 | std::string url = json_value(image_url, "url", std::string()); |
658 | | - std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ','); |
659 | | - if (parts.size() != 2) { |
660 | | - throw std::runtime_error("Invalid image_url.url value"); |
661 | | - } else if (!string_starts_with(parts[0], "data:image/")) { |
662 | | - throw std::runtime_error("Invalid image_url.url format: " + parts[0]); |
663 | | - } else if (!string_ends_with(parts[0], "base64")) { |
664 | | - throw std::runtime_error("image_url.url must be base64 encoded"); |
| 664 | + if (string_starts_with(url, "http")) { |
| 665 | + // download remote image |
| 666 | + // TODO @ngxson : maybe make these params configurable |
| 667 | + common_remote_params params; |
| 668 | + params.headers.push_back("User-Agent: llama.cpp/" + build_info); |
| 669 | + params.max_size = 1024 * 1024 * 10; // 10MB |
| 670 | + auto res = common_remote_get_content(url, params); |
| 671 | + raw_buffer data; |
| 672 | + data.insert(data.end(), res.second.begin(), res.second.end()); |
| 673 | + out_files.push_back(data); |
| 674 | + |
665 | 675 | } else { |
666 | | - auto base64_data = parts[1]; |
667 | | - auto decoded_data = base64_decode(base64_data); |
668 | | - out_files.push_back(decoded_data); |
| 676 | + // try to decode base64 image |
| 677 | + std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ','); |
| 678 | + if (parts.size() != 2) { |
| 679 | + throw std::runtime_error("Invalid image_url.url value"); |
| 680 | + } else if (!string_starts_with(parts[0], "data:image/")) { |
| 681 | + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); |
| 682 | + } else if (!string_ends_with(parts[0], "base64")) { |
| 683 | + throw std::runtime_error("image_url.url must be base64 encoded"); |
| 684 | + } else { |
| 685 | + auto base64_data = parts[1]; |
| 686 | + auto decoded_data = base64_decode(base64_data); |
| 687 | + out_files.push_back(decoded_data); |
| 688 | + } |
669 | 689 | } |
| 690 | + |
| 691 | + // replace this chunk with a marker |
670 | 692 | p["type"] = "text"; |
671 | | - p["text"] = "<__image__>"; |
| 693 | + p["text"] = MTMD_DEFAULT_IMAGE_MARKER; |
672 | 694 | p.erase("image_url"); |
673 | 695 | } |
674 | 696 | } |
|
0 commit comments