Skip to content

Commit cff131c

Browse files
committed
Qwen3-Coder XML: handle union schema types and sanitize unsupported branches; add tests
- chat-parser: support schema.type as array (e.g. ["number","null"]) in convert_qwen3_param_value() - chat: resolve $refs; allow unions including "string" as freeform; sanitize empty {"not":{}} in anyOf/oneOf before add_schema - tests: add Qwen3-Coder regression ensuring grammar builds with unions and ignores {"not":{}}
1 parent 9a2cca8 commit cff131c

File tree

3 files changed

+217
-46
lines changed

3 files changed

+217
-46
lines changed

common/chat-parser.cpp

Lines changed: 99 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -649,44 +649,106 @@ namespace {
649649

650650
// If we have schema information, use it
651651
if (param_config.contains(param_name)) {
652-
std::string param_type = "string";
653-
if (param_config[param_name].contains("type")) {
654-
param_type = param_config[param_name]["type"];
655-
}
656-
657-
// Convert based on type
658-
if (param_type == "string" || param_type == "str" || param_type == "text") {
659-
// SECURITY FIX: Use nlohmann::json for proper escaping instead of manual concatenation
660-
return json(trimmed_value).dump();
661-
} else if (param_type == "integer" || param_type == "int") {
662-
int int_val;
663-
if (safe_parse_int(trimmed_value, int_val)) {
664-
return std::to_string(int_val);
665-
} else {
666-
// SECURITY FIX: Use proper JSON escaping for fallback string
667-
return json(trimmed_value).dump();
668-
}
669-
} else if (param_type == "number" || param_type == "float") {
670-
float float_val;
671-
if (safe_parse_float(trimmed_value, float_val)) {
672-
return std::to_string(float_val);
673-
} else {
674-
// SECURITY FIX: Use proper JSON escaping for fallback string
675-
return json(trimmed_value).dump();
676-
}
677-
} else if (param_type == "boolean" || param_type == "bool") {
678-
if (trimmed_value == "true" || trimmed_value == "false") {
679-
return trimmed_value;
680-
}
681-
return "false";
682-
} else if (param_type == "object" || param_type == "array") {
683-
try {
684-
auto parsed = json::parse(trimmed_value);
685-
return parsed.dump();
686-
} catch (...) {
687-
// SECURITY FIX: Use proper JSON escaping for fallback string
688-
return json(trimmed_value).dump();
652+
const auto & schema = param_config.at(std::string(param_name));
653+
if (schema.contains("type")) {
654+
const auto & t = schema.at("type");
655+
// Handle union types like ["number","null"]
656+
if (t.is_array()) {
657+
std::vector<std::string> types;
658+
for (const auto & tv : t) {
659+
if (tv.is_string()) {
660+
types.push_back((std::string) tv);
661+
}
662+
}
663+
auto list_contains = [&](const char * s) {
664+
for (const auto & x : types) {
665+
if (x == s) return true;
666+
}
667+
return false;
668+
};
669+
auto has = [&](std::string_view ty) {
670+
for (const auto & s : types) {
671+
if (s == ty) return true;
672+
}
673+
// Back-compat synonyms
674+
if (ty == "string") return list_contains("str") || list_contains("text");
675+
if (ty == "integer") return list_contains("int");
676+
if (ty == "number") return list_contains("float");
677+
if (ty == "boolean") return list_contains("bool");
678+
return false;
679+
};
680+
if (has("null") && trimmed_value == "null") {
681+
return "null";
682+
}
683+
if (has("object") || has("array")) {
684+
try {
685+
auto parsed = json::parse(trimmed_value);
686+
return parsed.dump();
687+
} catch (...) {
688+
return json(trimmed_value).dump();
689+
}
690+
}
691+
if (has("integer")) {
692+
int int_val;
693+
if (safe_parse_int(trimmed_value, int_val)) {
694+
return std::to_string(int_val);
695+
}
696+
// if integer parse fails, try number or fall through
697+
}
698+
if (has("number")) {
699+
float float_val;
700+
if (safe_parse_float(trimmed_value, float_val)) {
701+
return std::to_string(float_val);
702+
}
703+
}
704+
if (has("boolean")) {
705+
if (trimmed_value == "true" || trimmed_value == "false") {
706+
return trimmed_value;
707+
}
708+
return "false";
709+
}
710+
if (has("string")) {
711+
return json(trimmed_value).dump();
712+
}
713+
// Unknown union types: fall through to generic inference below
714+
} else if (t.is_string()) {
715+
std::string param_type = t;
716+
// Convert based on type
717+
if (param_type == "string" || param_type == "str" || param_type == "text") {
718+
// SECURITY FIX: Use nlohmann::json for proper escaping instead of manual concatenation
719+
return json(trimmed_value).dump();
720+
} else if (param_type == "integer" || param_type == "int") {
721+
int int_val;
722+
if (safe_parse_int(trimmed_value, int_val)) {
723+
return std::to_string(int_val);
724+
} else {
725+
// SECURITY FIX: Use proper JSON escaping for fallback string
726+
return json(trimmed_value).dump();
727+
}
728+
} else if (param_type == "number" || param_type == "float") {
729+
float float_val;
730+
if (safe_parse_float(trimmed_value, float_val)) {
731+
return std::to_string(float_val);
732+
} else {
733+
// SECURITY FIX: Use proper JSON escaping for fallback string
734+
return json(trimmed_value).dump();
735+
}
736+
} else if (param_type == "boolean" || param_type == "bool") {
737+
if (trimmed_value == "true" || trimmed_value == "false") {
738+
return trimmed_value;
739+
}
740+
return "false";
741+
} else if (param_type == "object" || param_type == "array") {
742+
try {
743+
auto parsed = json::parse(trimmed_value);
744+
return parsed.dump();
745+
} catch (...) {
746+
// SECURITY FIX: Use proper JSON escaping for fallback string
747+
return json(trimmed_value).dump();
748+
}
749+
}
689750
}
751+
// If schema.type exists but is not string/array, fall through
690752
}
691753
}
692754

common/chat.cpp

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,7 +2097,8 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c
20972097
foreach_function(inputs.tools, [&](const json & tool) {
20982098
const auto & function = tool.at("function");
20992099
const std::string & name = function.at("name");
2100-
const json & parameters = function.at("parameters");
2100+
auto parameters = function.at("parameters");
2101+
builder.resolve_refs(parameters);
21012102

21022103
std::unordered_set<std::string> required;
21032104
if (parameters.contains("required")) {
@@ -2112,16 +2113,93 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c
21122113
for (const auto & [param_name, param_schema] : parameters["properties"].items()) {
21132114
std::string param_rule = "\"<parameter=" + param_name + ">\" space ";
21142115

2115-
// Add parameter value based on type
2116-
if (param_schema.contains("type")) {
2117-
std::string param_type = param_schema["type"];
2118-
if (param_type == "string") {
2119-
param_rule += not_parameter_end;
2120-
} else {
2121-
param_rule += builder.add_schema(name + "-parameter-" + param_name, param_schema);
2116+
// Add parameter value based on type (supports unions and anyOf/oneOf; sanitize unsupported {"not":{}} branches)
2117+
auto schema_local = param_schema;
2118+
2119+
// Recursively remove entries like {"not":{}} inside anyOf/oneOf that json-schema-to-grammar doesn't support
2120+
std::function<void(json &)> sanitize = [&](json &s) {
2121+
if (s.is_object()) {
2122+
if (s.contains("anyOf") && s["anyOf"].is_array()) {
2123+
json filtered = json::array();
2124+
for (auto v : s["anyOf"]) {
2125+
if (v.is_object() && v.contains("not") && v["not"].is_object() && v["not"].empty()) {
2126+
continue;
2127+
}
2128+
sanitize(v);
2129+
filtered.push_back(v);
2130+
}
2131+
s["anyOf"] = filtered;
2132+
if (s["anyOf"].size() == 1) {
2133+
json single = s["anyOf"][0];
2134+
s.erase("anyOf");
2135+
for (auto it = single.begin(); it != single.end(); ++it) {
2136+
s[it.key()] = it.value();
2137+
}
2138+
}
2139+
}
2140+
if (s.contains("oneOf") && s["oneOf"].is_array()) {
2141+
json filtered = json::array();
2142+
for (auto v : s["oneOf"]) {
2143+
if (v.is_object() && v.contains("not") && v["not"].is_object() && v["not"].empty()) {
2144+
continue;
2145+
}
2146+
sanitize(v);
2147+
filtered.push_back(v);
2148+
}
2149+
s["oneOf"] = filtered;
2150+
if (s["oneOf"].size() == 1) {
2151+
json single = s["oneOf"][0];
2152+
s.erase("oneOf");
2153+
for (auto it = single.begin(); it != single.end(); ++it) {
2154+
s[it.key()] = it.value();
2155+
}
2156+
}
2157+
}
2158+
for (auto it = s.begin(); it != s.end(); ++it) {
2159+
sanitize(it.value());
2160+
}
2161+
} else if (s.is_array()) {
2162+
for (auto & v : s) sanitize(v);
21222163
}
2164+
};
2165+
sanitize(schema_local);
2166+
2167+
// Determine if schema allows a plain string (so we can accept unquoted text content in XML)
2168+
std::function<bool(const json &)> allows_string = [&](const json & sch) -> bool {
2169+
if (!sch.is_object()) return false;
2170+
if (sch.contains("type")) {
2171+
const auto & t = sch.at("type");
2172+
if (t.is_string()) {
2173+
std::string ts = t;
2174+
return ts == "string" || ts == "text" || ts == "str";
2175+
}
2176+
if (t.is_array()) {
2177+
for (const auto & tv : t) {
2178+
if (tv.is_string() && (tv == "string" || tv == "text" || tv == "str")) {
2179+
return true;
2180+
}
2181+
}
2182+
}
2183+
}
2184+
if (sch.contains("anyOf") && sch["anyOf"].is_array()) {
2185+
for (const auto & v : sch["anyOf"]) {
2186+
if (allows_string(v)) return true;
2187+
}
2188+
}
2189+
if (sch.contains("oneOf") && sch["oneOf"].is_array()) {
2190+
for (const auto & v : sch["oneOf"]) {
2191+
if (allows_string(v)) return true;
2192+
}
2193+
}
2194+
return false;
2195+
};
2196+
2197+
if (allows_string(schema_local)) {
2198+
// For string-accepting schemas, keep freeform XML text (no JSON quoting)
2199+
param_rule += not_parameter_end;
21232200
} else {
2124-
param_rule += builder.add_schema(name + "-parameter-" + param_name, param_schema);
2201+
// For non-strings (object/array/number/boolean/null), expect JSON per schema
2202+
param_rule += builder.add_schema(name + "-parameter-" + param_name, schema_local);
21252203
}
21262204

21272205
param_rule += "\"</parameter>\" space";

tests/test-chat.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,37 @@ static void test_template_output_parsers() {
20232023

20242024
printf("✅ All Qwen3-Coder XML error handling and edge case tests passed!\n");
20252025
}
2026+
{
2027+
// Qwen3-Coder template: ensure grammar builds with union types and unsupported {"not": {}} branches
2028+
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
2029+
common_chat_templates_inputs inputs;
2030+
inputs.messages = { message_user };
2031+
2032+
common_chat_tool qwen_union_tool {
2033+
/* .name = */ "qwen_union",
2034+
/* .description = */ "Test tool for union/anyOf handling",
2035+
/* .parameters = */ R"({
2036+
"type": "object",
2037+
"properties": {
2038+
"priority": { "type": ["number", "null"] },
2039+
"maybe_text": { "anyOf": [ { "not": {} }, { "type": "string" } ] },
2040+
"config": { "anyOf": [ { "type": "object" }, { "type": "null" } ] }
2041+
},
2042+
"required": []
2043+
})",
2044+
};
2045+
inputs.tools = { qwen_union_tool };
2046+
2047+
auto params = common_chat_templates_apply(tmpls.get(), inputs);
2048+
assert_equals(COMMON_CHAT_FORMAT_QWEN3_CODER_XML, params.format);
2049+
assert_equals(false, params.grammar.empty());
2050+
2051+
// Grammar should compile successfully
2052+
auto grammar = build_grammar(params.grammar);
2053+
if (!grammar) {
2054+
throw std::runtime_error("Failed to build Qwen3-Coder grammar with union types");
2055+
}
2056+
}
20262057

20272058
{
20282059
auto tmpls = read_templates("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja");

0 commit comments

Comments
 (0)