Skip to content

Commit 7057faf

Browse files
authored
json : support enum values within allOf (ggml-org#15830)
1 parent fe1c92c commit 7057faf

File tree

4 files changed

+94
-3
lines changed

4 files changed

+94
-3
lines changed

common/json-schema-to-grammar.cpp

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -843,9 +843,10 @@ class SchemaConverter {
843843
_build_object_rule(
844844
properties, required, name,
845845
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
846-
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
846+
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
847847
std::unordered_set<std::string> required;
848848
std::vector<std::pair<std::string, json>> properties;
849+
std::map<std::string, size_t> enum_values;
849850
std::string hybrid_name = name;
850851
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
851852
if (comp_schema.contains("$ref")) {
@@ -857,6 +858,14 @@ class SchemaConverter {
857858
required.insert(prop.key());
858859
}
859860
}
861+
} else if (comp_schema.contains("enum")) {
862+
for (const auto & v : comp_schema["enum"]) {
863+
const auto rule = _generate_constant_rule(v);
864+
if (enum_values.find(rule) == enum_values.end()) {
865+
enum_values[rule] = 0;
866+
}
867+
enum_values[rule] += 1;
868+
}
860869
} else {
861870
// todo warning
862871
}
@@ -870,6 +879,17 @@ class SchemaConverter {
870879
add_component(t, true);
871880
}
872881
}
882+
if (!enum_values.empty()) {
883+
std::vector<std::string> enum_intersection;
884+
for (const auto & p : enum_values) {
885+
if (p.second == schema["allOf"].size()) {
886+
enum_intersection.push_back(p.first);
887+
}
888+
}
889+
if (!enum_intersection.empty()) {
890+
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
891+
}
892+
}
873893
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
874894
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
875895
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];

examples/json_schema_to_grammar.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -586,9 +586,10 @@ def visit(self, schema, name):
586586
properties = list(schema.get('properties', {}).items())
587587
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
588588

589-
elif schema_type in (None, 'object') and 'allOf' in schema:
589+
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
590590
required = set()
591591
properties = []
592+
enum_sets = []
592593
hybrid_name = name
593594
def add_component(comp_schema, is_required):
594595
if (ref := comp_schema.get('$ref')) is not None:
@@ -600,13 +601,25 @@ def add_component(comp_schema, is_required):
600601
if is_required:
601602
required.add(prop_name)
602603

604+
if 'enum' in comp_schema:
605+
enum_sets.append(set(comp_schema['enum']))
606+
603607
for t in schema['allOf']:
604608
if 'anyOf' in t:
605609
for tt in t['anyOf']:
606610
add_component(tt, is_required=False)
607611
else:
608612
add_component(t, is_required=True)
609613

614+
if enum_sets:
615+
enum_intersection = enum_sets[0]
616+
for s in enum_sets[1:]:
617+
enum_intersection &= s
618+
619+
if enum_intersection:
620+
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
621+
return self._add_rule(rule_name, rule)
622+
610623
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
611624

612625
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):

tests/test-json-schema-to-grammar.cpp

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1209,6 +1209,51 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
12091209
)"""
12101210
});
12111211

1212+
test({
1213+
SUCCESS,
1214+
"allOf with enum schema",
1215+
R"""({
1216+
"allOf": [
1217+
{"$ref": "#/definitions/foo"}
1218+
],
1219+
"definitions": {
1220+
"foo": {
1221+
"type": "string",
1222+
"enum": ["a", "b"]
1223+
}
1224+
}
1225+
})""",
1226+
R"""(
1227+
root ::= ("\"a\"" | "\"b\"") space
1228+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
1229+
)"""
1230+
});
1231+
1232+
test({
1233+
SUCCESS,
1234+
"allOf with multiple enum schemas",
1235+
R"""({
1236+
"allOf": [
1237+
{"$ref": "#/definitions/foo"},
1238+
{"$ref": "#/definitions/bar"}
1239+
],
1240+
"definitions": {
1241+
"foo": {
1242+
"type": "string",
1243+
"enum": ["a", "b", "c"]
1244+
},
1245+
"bar": {
1246+
"type": "string",
1247+
"enum": ["b", "c", "d"]
1248+
}
1249+
}
1250+
})""",
1251+
R"""(
1252+
root ::= ("\"b\"" | "\"c\"") space
1253+
space ::= | " " | "\n"{1,2} [ \t]{0,20}
1254+
)"""
1255+
});
1256+
12121257
test({
12131258
SUCCESS,
12141259
"conflicting names",

tools/server/public_legacy/json-schema-to-grammar.mjs

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -631,9 +631,10 @@ export class SchemaConverter {
631631
const required = new Set(schema.required || []);
632632
const properties = Object.entries(schema.properties ?? {});
633633
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
634-
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
634+
} else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) {
635635
const required = new Set();
636636
const properties = [];
637+
const enumSets = [];
637638
const addComponent = (compSchema, isRequired) => {
638639
const ref = compSchema.$ref;
639640
if (ref !== undefined) {
@@ -648,6 +649,10 @@ export class SchemaConverter {
648649
}
649650
}
650651
}
652+
653+
if ('enum' in compSchema) {
654+
enumSets.push(new Set(compSchema.enum || []));
655+
}
651656
};
652657

653658
for (const t of schema.allOf) {
@@ -660,6 +665,14 @@ export class SchemaConverter {
660665
}
661666
}
662667

668+
if (enumSets.length > 0) {
669+
const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v))));
670+
if (enumIntersection.size > 0) {
671+
const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b));
672+
const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
673+
return this._addRule(ruleName, rule);
674+
}
675+
}
663676
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
664677
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
665678
const items = schema.items ?? schema.prefixItems;

0 commit comments

Comments
 (0)