Skip to content

Commit a5f3e4a

Browse files
committed
PoC recursive schemas
1 parent 20a09d6 commit a5f3e4a

File tree

3 files changed

+102
-27
lines changed

3 files changed

+102
-27
lines changed

src/hypothesis_jsonschema/_canonicalise.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,12 @@ def resolve_all_refs(
579579
f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver"
580580
)
581581

582-
if "$ref" in schema:
582+
def is_recursive(reference: str) -> bool:
583+
return reference == "#" or resolver.resolution_scope == reference # type: ignore
584+
585+
# To avoid infinite recursion, we skip all recursive definitions, and such references will be processed later
586+
# A definition is recursive if it contains a reference to itself or one of its ancestors.
587+
if "$ref" in schema and not is_recursive(schema["$ref"]): # type: ignore
583588
s = dict(schema)
584589
ref = s.pop("$ref")
585590
with resolver.resolving(ref) as got:
@@ -590,7 +595,6 @@ def resolve_all_refs(
590595
msg = f"$ref:{ref!r} had incompatible base schema {s!r}"
591596
raise HypothesisRefResolutionError(msg)
592597
return resolve_all_refs(m, resolver=resolver)
593-
assert "$ref" not in schema
594598

595599
for key in SCHEMA_KEYS:
596600
val = schema.get(key, False)

src/hypothesis_jsonschema/_from_schema.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import math
55
import operator
66
import re
7+
from copy import deepcopy
78
from fractions import Fraction
89
from functools import partial
910
from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union
@@ -18,6 +19,8 @@
1819
TRUTHY,
1920
TYPE_STRINGS,
2021
HypothesisRefResolutionError,
22+
JSONType,
23+
LocalResolver,
2124
Schema,
2225
canonicalish,
2326
get_integer_bounds,
@@ -42,11 +45,13 @@
4245

4346

4447
def merged_as_strategies(
45-
schemas: List[Schema], custom_formats: Optional[Dict[str, st.SearchStrategy[str]]]
48+
schemas: List[Schema],
49+
custom_formats: Optional[Dict[str, st.SearchStrategy[str]]],
50+
resolver: LocalResolver,
4651
) -> st.SearchStrategy[JSONType]:
4752
assert schemas, "internal error: must pass at least one schema to merge"
4853
if len(schemas) == 1:
49-
return from_schema(schemas[0], custom_formats=custom_formats)
54+
return from_schema(schemas[0], custom_formats=custom_formats, resolver=resolver)
5055
# Try to merge combinations of strategies.
5156
strats = []
5257
combined: Set[str] = set()
@@ -60,7 +65,7 @@ def merged_as_strategies(
6065
if s is not None and s != FALSEY:
6166
validators = [make_validator(s) for s in schemas]
6267
strats.append(
63-
from_schema(s, custom_formats=custom_formats).filter(
68+
from_schema(s, custom_formats=custom_formats, resolver=resolver).filter(
6469
lambda obj: all(v.is_valid(obj) for v in validators)
6570
)
6671
)
@@ -72,14 +77,15 @@ def from_schema(
7277
schema: Union[bool, Schema],
7378
*,
7479
custom_formats: Dict[str, st.SearchStrategy[str]] = None,
80+
resolver: Optional[LocalResolver] = None,
7581
) -> st.SearchStrategy[JSONType]:
7682
"""Take a JSON schema and return a strategy for allowed JSON objects.
7783
7884
Schema reuse with "definitions" and "$ref" is not yet supported, but
7985
everything else in drafts 04, 05, and 07 is fully tested and working.
8086
"""
8187
try:
82-
return __from_schema(schema, custom_formats=custom_formats)
88+
return __from_schema(schema, custom_formats=custom_formats, resolver=resolver)
8389
except Exception as err:
8490
error = err
8591

@@ -112,9 +118,10 @@ def __from_schema(
112118
schema: Union[bool, Schema],
113119
*,
114120
custom_formats: Dict[str, st.SearchStrategy[str]] = None,
121+
resolver: Optional[LocalResolver] = None,
115122
) -> st.SearchStrategy[JSONType]:
116123
try:
117-
schema = resolve_all_refs(schema)
124+
schema = resolve_all_refs(schema, resolver=resolver)
118125
except RecursionError:
119126
raise HypothesisRefResolutionError(
120127
f"Could not resolve recursive references in schema={schema!r}"
@@ -141,6 +148,9 @@ def __from_schema(
141148
}
142149
custom_formats[_FORMATS_TOKEN] = None # type: ignore
143150

151+
if resolver is None:
152+
resolver = LocalResolver.from_schema(deepcopy(schema))
153+
144154
schema = canonicalish(schema)
145155
# Boolean objects are special schemata; False rejects all and True accepts all.
146156
if schema == FALSEY:
@@ -155,32 +165,44 @@ def __from_schema(
155165

156166
assert isinstance(schema, dict)
157167
# Now we handle as many validation keywords as we can...
168+
if "$ref" in schema:
169+
ref = schema["$ref"]
170+
171+
def _recurse() -> st.SearchStrategy[JSONType]:
172+
_, resolved = resolver.resolve(ref) # type: ignore
173+
return from_schema(
174+
resolved, custom_formats=custom_formats, resolver=resolver
175+
)
176+
177+
return st.deferred(_recurse)
158178
# Applying subschemata with boolean logic
159179
if "not" in schema:
160180
not_ = schema.pop("not")
161181
assert isinstance(not_, dict)
162182
validator = make_validator(not_).is_valid
163-
return from_schema(schema, custom_formats=custom_formats).filter(
164-
lambda v: not validator(v)
165-
)
183+
return from_schema(
184+
schema, custom_formats=custom_formats, resolver=resolver
185+
).filter(lambda v: not validator(v))
166186
if "anyOf" in schema:
167187
tmp = schema.copy()
168188
ao = tmp.pop("anyOf")
169189
assert isinstance(ao, list)
170-
return st.one_of([merged_as_strategies([tmp, s], custom_formats) for s in ao])
190+
return st.one_of(
191+
[merged_as_strategies([tmp, s], custom_formats, resolver) for s in ao]
192+
)
171193
if "allOf" in schema:
172194
tmp = schema.copy()
173195
ao = tmp.pop("allOf")
174196
assert isinstance(ao, list)
175-
return merged_as_strategies([tmp] + ao, custom_formats)
197+
return merged_as_strategies([tmp] + ao, custom_formats, resolver)
176198
if "oneOf" in schema:
177199
tmp = schema.copy()
178200
oo = tmp.pop("oneOf")
179201
assert isinstance(oo, list)
180202
schemas = [merged([tmp, s]) for s in oo]
181203
return st.one_of(
182204
[
183-
from_schema(s, custom_formats=custom_formats)
205+
from_schema(s, custom_formats=custom_formats, resolver=resolver)
184206
for s in schemas
185207
if s is not None
186208
]
@@ -198,8 +220,8 @@ def __from_schema(
198220
"number": number_schema,
199221
"integer": integer_schema,
200222
"string": partial(string_schema, custom_formats),
201-
"array": partial(array_schema, custom_formats),
202-
"object": partial(object_schema, custom_formats),
223+
"array": partial(array_schema, custom_formats, resolver),
224+
"object": partial(object_schema, custom_formats, resolver),
203225
}
204226
assert set(map_) == set(TYPE_STRINGS)
205227
return st.one_of([map_[t](schema) for t in get_type(schema)])
@@ -422,10 +444,14 @@ def string_schema(
422444

423445

424446
def array_schema(
425-
custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict
447+
custom_formats: Dict[str, st.SearchStrategy[str]],
448+
resolver: LocalResolver,
449+
schema: dict,
426450
) -> st.SearchStrategy[List[JSONType]]:
427451
"""Handle schemata for arrays."""
428-
_from_schema_ = partial(from_schema, custom_formats=custom_formats)
452+
_from_schema_ = partial(
453+
from_schema, custom_formats=custom_formats, resolver=resolver
454+
)
429455
items = schema.get("items", {})
430456
additional_items = schema.get("additionalItems", {})
431457
min_size = schema.get("minItems", 0)
@@ -436,14 +462,16 @@ def array_schema(
436462
if max_size is not None:
437463
max_size -= len(items)
438464

439-
items_strats = [_from_schema_(s) for s in items]
465+
items_strats = [_from_schema_(s) for s in deepcopy(items)]
440466
additional_items_strat = _from_schema_(additional_items)
441467

442468
# If we have a contains schema to satisfy, we try generating from it when
443469
# allowed to do so. We'll skip the None (unmergable / no contains) cases
444470
# below, and let Hypothesis ignore the FALSEY cases for us.
445471
if "contains" in schema:
446-
for i, mrgd in enumerate(merged([schema["contains"], s]) for s in items):
472+
for i, mrgd in enumerate(
473+
merged([schema["contains"], s]) for s in deepcopy(items)
474+
):
447475
if mrgd is not None:
448476
items_strats[i] |= _from_schema_(mrgd)
449477
contains_additional = merged([schema["contains"], additional_items])
@@ -480,10 +508,10 @@ def not_seen(elem: JSONType) -> bool:
480508
st.lists(additional_items_strat, min_size=min_size, max_size=max_size),
481509
)
482510
else:
483-
items_strat = _from_schema_(items)
511+
items_strat = _from_schema_(deepcopy(items))
484512
if "contains" in schema:
485513
contains_strat = _from_schema_(schema["contains"])
486-
if merged([items, schema["contains"]]) != schema["contains"]:
514+
if merged([deepcopy(items), schema["contains"]]) != schema["contains"]:
487515
# We only need this filter if we couldn't merge items in when
488516
# canonicalising. Note that for list-items, above, we just skip
489517
# the mixed generation in this case (because they tend to be
@@ -504,7 +532,9 @@ def not_seen(elem: JSONType) -> bool:
504532

505533

506534
def object_schema(
507-
custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict
535+
custom_formats: Dict[str, st.SearchStrategy[str]],
536+
resolver: LocalResolver,
537+
schema: dict,
508538
) -> st.SearchStrategy[Dict[str, JSONType]]:
509539
"""Handle a manageable subset of possible schemata for objects."""
510540
required = schema.get("required", []) # required keys
@@ -518,7 +548,7 @@ def object_schema(
518548
return st.builds(dict)
519549
names["type"] = "string"
520550

521-
properties = schema.get("properties", {}) # exact name: value schema
551+
properties = deepcopy(schema.get("properties", {})) # exact name: value schema
522552
patterns = schema.get("patternProperties", {}) # regex for names: value schema
523553
# schema for other values; handled specially if nothing matches
524554
additional = schema.get("additionalProperties", {})
@@ -533,7 +563,7 @@ def object_schema(
533563
st.sampled_from(sorted(dep_names) + sorted(dep_schemas) + sorted(properties))
534564
if (dep_names or dep_schemas or properties)
535565
else st.nothing(),
536-
from_schema(names, custom_formats=custom_formats)
566+
from_schema(names, custom_formats=custom_formats, resolver=resolver)
537567
if additional_allowed
538568
else st.nothing(),
539569
st.one_of([st.from_regex(p) for p in sorted(patterns)]),
@@ -579,12 +609,20 @@ def from_object_schema(draw: Any) -> Any:
579609
if re.search(rgx, string=key) is not None
580610
]
581611
if key in properties:
582-
pattern_schemas.insert(0, properties[key])
612+
pattern_schemas.insert(0, deepcopy(properties[key]))
583613

584614
if pattern_schemas:
585-
out[key] = draw(merged_as_strategies(pattern_schemas, custom_formats))
615+
out[key] = draw(
616+
merged_as_strategies(pattern_schemas, custom_formats, resolver)
617+
)
586618
else:
587-
out[key] = draw(from_schema(additional, custom_formats=custom_formats))
619+
out[key] = draw(
620+
from_schema(
621+
deepcopy(additional),
622+
custom_formats=custom_formats,
623+
resolver=resolver,
624+
)
625+
)
588626

589627
for k, v in dep_schemas.items():
590628
if k in out and not make_validator(v).is_valid(out):

tests/test_from_schema.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,36 @@ def test_allowed_custom_format(num):
425425
def test_allowed_unknown_custom_format(string):
426426
assert string == "hello world"
427427
assert "not registered" not in jsonschema.FormatChecker().checkers
428+
429+
430+
@pytest.mark.parametrize(
431+
"schema",
432+
(
433+
{
434+
"properties": {"foo": {"$ref": "#"}},
435+
"additionalProperties": False,
436+
"type": "object",
437+
},
438+
{
439+
"definitions": {
440+
"Node": {
441+
"type": "object",
442+
"properties": {
443+
"children": {
444+
"type": "array",
445+
"items": {"$ref": "#/definitions/Node"},
446+
"maxItems": 2,
447+
}
448+
},
449+
"required": ["children"],
450+
"additionalProperties": False,
451+
},
452+
},
453+
"$ref": "#/definitions/Node",
454+
},
455+
),
456+
)
457+
@given(data=st.data())
458+
def test_recursive_reference(data, schema):
459+
value = data.draw(from_schema(schema))
460+
jsonschema.validate(value, schema)

0 commit comments

Comments
 (0)