diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 157100c..93175f3 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -17,7 +17,8 @@ import math import re from copy import deepcopy -from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union +from typing import Any, Dict, List, NoReturn, Optional, Set, Tuple, Union +from urllib.parse import urljoin import jsonschema from hypothesis.errors import InvalidArgument @@ -68,6 +69,13 @@ def next_down(val: float) -> float: return out +class LocalResolver(jsonschema.RefResolver): + def resolve_remote(self, uri: str) -> NoReturn: + raise HypothesisRefResolutionError( + f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})" + ) + + def _get_validator_class(schema: Schema) -> JSONSchemaValidator: try: validator = jsonschema.validators.validator_for(schema) @@ -78,9 +86,9 @@ def _get_validator_class(schema: Schema) -> JSONSchemaValidator: return validator -def make_validator(schema: Schema) -> JSONSchemaValidator: +def make_validator(schema: Schema, resolver: LocalResolver) -> JSONSchemaValidator: validator = _get_validator_class(schema) - return validator(schema) + return validator(schema, resolver=resolver) class HypothesisRefResolutionError(jsonschema.exceptions.RefResolutionError): @@ -202,7 +210,9 @@ def get_integer_bounds(schema: Schema) -> Tuple[Optional[int], Optional[int]]: return lower, upper -def canonicalish(schema: JSONType) -> Dict[str, Any]: +def canonicalish( + schema: JSONType, resolver: Optional[LocalResolver] = None +) -> Dict[str, Any]: """Convert a schema into a more-canonical form. This is obviously incomplete, but improves best-effort recognition of @@ -224,12 +234,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: "but expected a dict." ) + if resolver is None: + resolver = LocalResolver.from_schema(deepcopy(schema)) + if "const" in schema: - if not make_validator(schema).is_valid(schema["const"]): + if not make_validator(schema, resolver=resolver).is_valid(schema["const"]): return FALSEY return {"const": schema["const"]} if "enum" in schema: - validator = make_validator(schema) + validator = make_validator(schema, resolver=resolver) enum_ = sorted( (v for v in schema["enum"] if validator.is_valid(v)), key=sort_key ) @@ -253,15 +266,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: # Recurse into the value of each keyword with a schema (or list of them) as a value for key in SCHEMA_KEYS: if isinstance(schema.get(key), list): - schema[key] = [canonicalish(v) for v in schema[key]] + schema[key] = [canonicalish(v, resolver=resolver) for v in schema[key]] elif isinstance(schema.get(key), (bool, dict)): - schema[key] = canonicalish(schema[key]) + schema[key] = canonicalish(schema[key], resolver=resolver) else: assert key not in schema, (key, schema[key]) for key in SCHEMA_OBJECT_KEYS: if key in schema: schema[key] = { - k: v if isinstance(v, list) else canonicalish(v) + k: v if isinstance(v, list) else canonicalish(v, resolver=resolver) for k, v in schema[key].items() } @@ -307,7 +320,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: if "array" in type_ and "contains" in schema: if isinstance(schema.get("items"), dict): - contains_items = merged([schema["contains"], schema["items"]]) + contains_items = merged( + [schema["contains"], schema["items"]], resolver=resolver + ) if contains_items is not None: schema["contains"] = contains_items @@ -432,7 +447,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: type_.remove("object") else: propnames = schema.get("propertyNames", {}) - validator = make_validator(propnames) + validator = make_validator(propnames, resolver=resolver) if not all(validator.is_valid(name) for name in schema["required"]): type_.remove("object") @@ -461,9 +476,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: type_.remove(t) if t not in ("integer", "number"): not_["type"].remove(t) - not_ = canonicalish(not_) + not_ = canonicalish(not_, resolver=resolver) - m = merged([not_, {**schema, "type": type_}]) + m = merged([not_, {**schema, "type": type_}], resolver=resolver) if m is not None: not_ = m if not_ != FALSEY: @@ -525,7 +540,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: else: tmp = schema.copy() ao = tmp.pop("allOf") - out = merged([tmp] + ao) + out = merged([tmp] + ao, resolver=resolver) if isinstance(out, dict): # pragma: no branch schema = out # TODO: this assertion is soley because mypy 0.750 doesn't know @@ -537,7 +552,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: one_of = sorted(one_of, key=encode_canonical_json) one_of = [s for s in one_of if s != FALSEY] if len(one_of) == 1: - m = merged([schema, one_of[0]]) + m = merged([schema, one_of[0]], resolver=resolver) if m is not None: # pragma: no branch return m if (not one_of) or one_of.count(TRUTHY) > 1: @@ -552,23 +567,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: FALSEY = canonicalish(False) -class LocalResolver(jsonschema.RefResolver): - def resolve_remote(self, uri: str) -> NoReturn: - raise HypothesisRefResolutionError( - f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})" - ) - - def resolve_all_refs( - schema: Union[bool, Schema], *, resolver: LocalResolver = None + schema: Union[bool, Schema], + *, + resolver: LocalResolver = None, + seen_map: Dict[str, Set[str]] = None, ) -> Schema: - """ - Resolve all references in the given schema. - - This handles nested definitions, but not recursive definitions. - The latter require special handling to convert to strategies and are much - less common, so we just ignore them (and error out) for now. - """ + """Resolve all non-recursive references in the given schema.""" + if seen_map is None: + seen_map = {} if isinstance(schema, bool): return canonicalish(schema) assert isinstance(schema, dict), schema @@ -579,28 +586,42 @@ def resolve_all_refs( f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver" ) + def is_recursive(reference: str) -> bool: + full_ref = urljoin(resolver.base_uri, reference) # type: ignore + return reference == "#" or reference in resolver._scopes_stack or full_ref in resolver._scopes_stack # type: ignore + + # To avoid infinite recursion, we skip all recursive definitions, and such references will be processed later + # A definition is recursive if it contains a reference to itself or one of its ancestors. if "$ref" in schema: - s = dict(schema) - ref = s.pop("$ref") - with resolver.resolving(ref) as got: - if s == {}: - return resolve_all_refs(got, resolver=resolver) - m = merged([s, got]) - if m is None: # pragma: no cover - msg = f"$ref:{ref!r} had incompatible base schema {s!r}" - raise HypothesisRefResolutionError(msg) - return resolve_all_refs(m, resolver=resolver) - assert "$ref" not in schema + path = "-".join(resolver._scopes_stack) + seen_paths = seen_map.setdefault(path, set()) + if schema["$ref"] not in seen_paths and not is_recursive(schema["$ref"]): # type: ignore + seen_paths.add(schema["$ref"]) # type: ignore + s = dict(schema) + ref = s.pop("$ref") + with resolver.resolving(ref) as got: + if s == {}: + return resolve_all_refs(got, resolver=resolver, seen_map=seen_map) + m = merged([s, got]) + if m is None: # pragma: no cover + msg = f"$ref:{ref!r} had incompatible base schema {s!r}" + raise HypothesisRefResolutionError(msg) + + return resolve_all_refs(m, resolver=resolver, seen_map=seen_map) for key in SCHEMA_KEYS: val = schema.get(key, False) if isinstance(val, list): schema[key] = [ - resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v + resolve_all_refs(deepcopy(v), resolver=resolver, seen_map=seen_map) + if isinstance(v, dict) + else v for v in val ] elif isinstance(val, dict): - schema[key] = resolve_all_refs(val, resolver=resolver) + schema[key] = resolve_all_refs( + deepcopy(val), resolver=resolver, seen_map=seen_map + ) else: assert isinstance(val, bool) for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas @@ -608,14 +629,18 @@ def resolve_all_refs( subschema = schema[key] assert isinstance(subschema, dict) schema[key] = { - k: resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v + k: resolve_all_refs(deepcopy(v), resolver=resolver, seen_map=seen_map) + if isinstance(v, dict) + else v for k, v in subschema.items() } assert isinstance(schema, dict) return schema -def merged(schemas: List[Any]) -> Optional[Schema]: +def merged( + schemas: List[Any], resolver: Optional[LocalResolver] = None +) -> Optional[Schema]: """Merge *n* schemas into a single schema, or None if result is invalid. Takes the logical intersection, so any object that validates against the returned @@ -628,7 +653,9 @@ def merged(schemas: List[Any]) -> Optional[Schema]: It's currently also used for keys that could be merged but aren't yet. """ assert schemas, "internal error: must pass at least one schema to merge" - schemas = sorted((canonicalish(s) for s in schemas), key=upper_bound_instances) + schemas = sorted( + (canonicalish(s, resolver=resolver) for s in schemas), key=upper_bound_instances + ) if any(s == FALSEY for s in schemas): return FALSEY out = schemas[0] @@ -637,11 +664,11 @@ def merged(schemas: List[Any]) -> Optional[Schema]: continue # If we have a const or enum, this is fairly easy by filtering: if "const" in out: - if make_validator(s).is_valid(out["const"]): + if make_validator(s, resolver=resolver).is_valid(out["const"]): continue return FALSEY if "enum" in out: - validator = make_validator(s) + validator = make_validator(s, resolver=resolver) enum_ = [v for v in out["enum"] if validator.is_valid(v)] if not enum_: return FALSEY @@ -692,21 +719,23 @@ def merged(schemas: List[Any]) -> Optional[Schema]: else: out_combined = merged( [s for p, s in out_pat.items() if re.search(p, prop_name)] - or [out_add] + or [out_add], + resolver=resolver, ) if prop_name in s_props: s_combined = s_props[prop_name] else: s_combined = merged( [s for p, s in s_pat.items() if re.search(p, prop_name)] - or [s_add] + or [s_add], + resolver=resolver, ) if out_combined is None or s_combined is None: # pragma: no cover # Note that this can only be the case if we were actually going to # use the schema which we attempted to merge, i.e. prop_name was # not in the schema and there were unmergable pattern schemas. return None - m = merged([out_combined, s_combined]) + m = merged([out_combined, s_combined], resolver=resolver) if m is None: return None out_props[prop_name] = m @@ -714,14 +743,17 @@ def merged(schemas: List[Any]) -> Optional[Schema]: # simpler as we merge with either an identical pattern, or additionalProperties. if out_pat or s_pat: for pattern in set(out_pat) | set(s_pat): - m = merged([out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)]) + m = merged( + [out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)], + resolver=resolver, + ) if m is None: # pragma: no cover return None out_pat[pattern] = m out["patternProperties"] = out_pat # Finally, we merge togther the additionalProperties schemas. if out_add or s_add: - m = merged([out_add, s_add]) + m = merged([out_add, s_add], resolver=resolver) if m is None: # pragma: no cover return None out["additionalProperties"] = m @@ -755,7 +787,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: return None if "contains" in out and "contains" in s and out["contains"] != s["contains"]: # If one `contains` schema is a subset of the other, we can discard it. - m = merged([out["contains"], s["contains"]]) + m = merged([out["contains"], s["contains"]], resolver=resolver) if m == out["contains"] or m == s["contains"]: out["contains"] = m s.pop("contains") @@ -785,7 +817,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: v = {"required": v} elif isinstance(sval, list): sval = {"required": sval} - m = merged([v, sval]) + m = merged([v, sval], resolver=resolver) if m is None: return None odeps[k] = m @@ -799,26 +831,27 @@ def merged(schemas: List[Any]) -> Optional[Schema]: [ out.get("additionalItems", TRUTHY), s.get("additionalItems", TRUTHY), - ] + ], + resolver=resolver, ) for a, b in itertools.zip_longest(oitems, sitems): if a is None: a = out.get("additionalItems", TRUTHY) elif b is None: b = s.get("additionalItems", TRUTHY) - out["items"].append(merged([a, b])) + out["items"].append(merged([a, b], resolver=resolver)) elif isinstance(oitems, list): - out["items"] = [merged([x, sitems]) for x in oitems] + out["items"] = [merged([x, sitems], resolver=resolver) for x in oitems] out["additionalItems"] = merged( - [out.get("additionalItems", TRUTHY), sitems] + [out.get("additionalItems", TRUTHY), sitems], resolver=resolver ) elif isinstance(sitems, list): - out["items"] = [merged([x, oitems]) for x in sitems] + out["items"] = [merged([x, oitems], resolver=resolver) for x in sitems] out["additionalItems"] = merged( - [s.get("additionalItems", TRUTHY), oitems] + [s.get("additionalItems", TRUTHY), oitems], resolver=resolver ) else: - out["items"] = merged([oitems, sitems]) + out["items"] = merged([oitems, sitems], resolver=resolver) if out["items"] is None: return None if isinstance(out["items"], list) and None in out["items"]: @@ -842,7 +875,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: # If non-validation keys like `title` or `description` don't match, # that doesn't really matter and we'll just go with first we saw. return None - out = canonicalish(out) + out = canonicalish(out, resolver=resolver) if out == FALSEY: return FALSEY assert isinstance(out, dict) diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index 46cc3e1..658cb42 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -4,6 +4,7 @@ import math import operator import re +from copy import deepcopy from fractions import Fraction from functools import partial from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union @@ -17,7 +18,7 @@ FALSEY, TRUTHY, TYPE_STRINGS, - HypothesisRefResolutionError, + LocalResolver, Schema, canonicalish, get_integer_bounds, @@ -42,11 +43,13 @@ def merged_as_strategies( - schemas: List[Schema], custom_formats: Optional[Dict[str, st.SearchStrategy[str]]] + schemas: List[Schema], + custom_formats: Optional[Dict[str, st.SearchStrategy[str]]], + resolver: LocalResolver, ) -> st.SearchStrategy[JSONType]: assert schemas, "internal error: must pass at least one schema to merge" if len(schemas) == 1: - return from_schema(schemas[0], custom_formats=custom_formats) + return from_schema(schemas[0], custom_formats=custom_formats, resolver=resolver) # Try to merge combinations of strategies. strats = [] combined: Set[str] = set() @@ -56,11 +59,11 @@ def merged_as_strategies( ): if combined.issuperset(group): continue - s = merged([inputs[g] for g in group]) + s = merged([inputs[g] for g in group], resolver=resolver) if s is not None and s != FALSEY: - validators = [make_validator(s) for s in schemas] + validators = [make_validator(s, resolver=resolver) for s in schemas] strats.append( - from_schema(s, custom_formats=custom_formats).filter( + from_schema(s, custom_formats=custom_formats, resolver=resolver).filter( lambda obj: all(v.is_valid(obj) for v in validators) ) ) @@ -72,6 +75,7 @@ def from_schema( schema: Union[bool, Schema], *, custom_formats: Dict[str, st.SearchStrategy[str]] = None, + resolver: Optional[LocalResolver] = None, ) -> st.SearchStrategy[JSONType]: """Take a JSON schema and return a strategy for allowed JSON objects. @@ -79,7 +83,7 @@ def from_schema( everything else in drafts 04, 05, and 07 is fully tested and working. """ try: - return __from_schema(schema, custom_formats=custom_formats) + return __from_schema(schema, custom_formats=custom_formats, resolver=resolver) except Exception as err: error = err @@ -112,13 +116,9 @@ def __from_schema( schema: Union[bool, Schema], *, custom_formats: Dict[str, st.SearchStrategy[str]] = None, + resolver: Optional[LocalResolver] = None, ) -> st.SearchStrategy[JSONType]: - try: - schema = resolve_all_refs(schema) - except RecursionError: - raise HypothesisRefResolutionError( - f"Could not resolve recursive references in schema={schema!r}" - ) from None + schema = resolve_all_refs(schema, resolver=resolver) # We check for _FORMATS_TOKEN to avoid re-validating known good data. if custom_formats is not None and _FORMATS_TOKEN not in custom_formats: assert isinstance(custom_formats, dict) @@ -141,7 +141,10 @@ def __from_schema( } custom_formats[_FORMATS_TOKEN] = None # type: ignore - schema = canonicalish(schema) + if resolver is None: + resolver = LocalResolver.from_schema(deepcopy(schema)) + + schema = canonicalish(schema, resolver) # Boolean objects are special schemata; False rejects all and True accepts all. if schema == FALSEY: return st.nothing() @@ -155,36 +158,49 @@ def __from_schema( assert isinstance(schema, dict) # Now we handle as many validation keywords as we can... + if "$ref" in schema: + ref = schema["$ref"] + + def _recurse() -> st.SearchStrategy[JSONType]: + url, resolved = resolver.resolve(ref) # type: ignore + resolver.push_scope(url) # type: ignore + return __from_schema( + deepcopy(resolved), custom_formats=custom_formats, resolver=resolver + ) + + return st.deferred(_recurse) # Applying subschemata with boolean logic if "not" in schema: not_ = schema.pop("not") assert isinstance(not_, dict) - validator = make_validator(not_).is_valid - return from_schema(schema, custom_formats=custom_formats).filter( - lambda v: not validator(v) - ) + validator = make_validator(not_, resolver=resolver).is_valid + return from_schema( + schema, custom_formats=custom_formats, resolver=resolver + ).filter(lambda v: not validator(v)) if "anyOf" in schema: tmp = schema.copy() ao = tmp.pop("anyOf") assert isinstance(ao, list) - return st.one_of([merged_as_strategies([tmp, s], custom_formats) for s in ao]) + return st.one_of( + [merged_as_strategies([tmp, s], custom_formats, resolver) for s in ao] + ) if "allOf" in schema: tmp = schema.copy() ao = tmp.pop("allOf") assert isinstance(ao, list) - return merged_as_strategies([tmp] + ao, custom_formats) + return merged_as_strategies([tmp] + ao, custom_formats, resolver) if "oneOf" in schema: tmp = schema.copy() oo = tmp.pop("oneOf") assert isinstance(oo, list) - schemas = [merged([tmp, s]) for s in oo] + schemas = [merged([tmp, s], resolver=resolver) for s in oo] return st.one_of( [ - from_schema(s, custom_formats=custom_formats) + from_schema(s, custom_formats=custom_formats, resolver=resolver) for s in schemas if s is not None ] - ).filter(make_validator(schema).is_valid) + ).filter(make_validator(schema, resolver=resolver).is_valid) # Simple special cases if "enum" in schema: assert schema["enum"], "Canonicalises to non-empty list or FALSEY" @@ -195,18 +211,22 @@ def __from_schema( map_: Dict[str, Callable[[Schema], st.SearchStrategy[JSONType]]] = { "null": lambda _: st.none(), "boolean": lambda _: st.booleans(), - "number": number_schema, - "integer": integer_schema, + # Mypy doesn't recognize that `resolver` has the `LocalResolver` type + "number": lambda s: number_schema(s, resolver=resolver), # type: ignore + "integer": lambda s: integer_schema(s, resolver=resolver), # type: ignore "string": partial(string_schema, custom_formats), - "array": partial(array_schema, custom_formats), - "object": partial(object_schema, custom_formats), + "array": partial(array_schema, custom_formats, resolver), + "object": partial(object_schema, custom_formats, resolver), } assert set(map_) == set(TYPE_STRINGS) return st.one_of([map_[t](schema) for t in get_type(schema)]) def _numeric_with_multiplier( - min_value: Optional[float], max_value: Optional[float], schema: Schema + min_value: Optional[float], + max_value: Optional[float], + schema: Schema, + resolver: LocalResolver, ) -> st.SearchStrategy[float]: """Handle numeric schemata containing the multipleOf key.""" multiple_of = schema["multipleOf"] @@ -224,23 +244,23 @@ def _numeric_with_multiplier( return ( st.integers(min_value, max_value) .map(lambda x: x * multiple_of) - .filter(make_validator(schema).is_valid) + .filter(make_validator(schema, resolver=resolver).is_valid) ) -def integer_schema(schema: dict) -> st.SearchStrategy[float]: +def integer_schema(schema: dict, resolver: LocalResolver) -> st.SearchStrategy[float]: """Handle integer schemata.""" min_value, max_value = get_integer_bounds(schema) if "multipleOf" in schema: - return _numeric_with_multiplier(min_value, max_value, schema) + return _numeric_with_multiplier(min_value, max_value, schema, resolver) return st.integers(min_value, max_value) -def number_schema(schema: dict) -> st.SearchStrategy[float]: +def number_schema(schema: dict, resolver: LocalResolver) -> st.SearchStrategy[float]: """Handle numeric schemata.""" min_value, max_value, exclude_min, exclude_max = get_number_bounds(schema) if "multipleOf" in schema: - return _numeric_with_multiplier(min_value, max_value, schema) + return _numeric_with_multiplier(min_value, max_value, schema, resolver) return st.floats( min_value=min_value, max_value=max_value, @@ -422,10 +442,14 @@ def string_schema( def array_schema( - custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict + custom_formats: Dict[str, st.SearchStrategy[str]], + resolver: LocalResolver, + schema: dict, ) -> st.SearchStrategy[List[JSONType]]: """Handle schemata for arrays.""" - _from_schema_ = partial(from_schema, custom_formats=custom_formats) + _from_schema_ = partial( + from_schema, custom_formats=custom_formats, resolver=resolver + ) items = schema.get("items", {}) additional_items = schema.get("additionalItems", {}) min_size = schema.get("minItems", 0) @@ -443,10 +467,14 @@ def array_schema( # allowed to do so. We'll skip the None (unmergable / no contains) cases # below, and let Hypothesis ignore the FALSEY cases for us. if "contains" in schema: - for i, mrgd in enumerate(merged([schema["contains"], s]) for s in items): + for i, mrgd in enumerate( + merged([schema["contains"], s], resolver=resolver) for s in items + ): if mrgd is not None: items_strats[i] |= _from_schema_(mrgd) - contains_additional = merged([schema["contains"], additional_items]) + contains_additional = merged( + [schema["contains"], additional_items], resolver=resolver + ) if contains_additional is not None: additional_items_strat |= _from_schema_(contains_additional) @@ -483,12 +511,17 @@ def not_seen(elem: JSONType) -> bool: items_strat = _from_schema_(items) if "contains" in schema: contains_strat = _from_schema_(schema["contains"]) - if merged([items, schema["contains"]]) != schema["contains"]: + if ( + merged([items, schema["contains"]], resolver=resolver) + != schema["contains"] + ): # We only need this filter if we couldn't merge items in when # canonicalising. Note that for list-items, above, we just skip # the mixed generation in this case (because they tend to be # heterogeneous) and hope it works out anyway. - contains_strat = contains_strat.filter(make_validator(items).is_valid) + contains_strat = contains_strat.filter( + make_validator(items, resolver=resolver).is_valid + ) items_strat |= contains_strat strat = st.lists( @@ -499,12 +532,14 @@ def not_seen(elem: JSONType) -> bool: ) if "contains" not in schema: return strat - contains = make_validator(schema["contains"]).is_valid + contains = make_validator(schema["contains"], resolver=resolver).is_valid return strat.filter(lambda val: any(contains(x) for x in val)) def object_schema( - custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict + custom_formats: Dict[str, st.SearchStrategy[str]], + resolver: LocalResolver, + schema: dict, ) -> st.SearchStrategy[Dict[str, JSONType]]: """Handle a manageable subset of possible schemata for objects.""" required = schema.get("required", []) # required keys @@ -533,13 +568,13 @@ def object_schema( st.sampled_from(sorted(dep_names) + sorted(dep_schemas) + sorted(properties)) if (dep_names or dep_schemas or properties) else st.nothing(), - from_schema(names, custom_formats=custom_formats) + from_schema(names, custom_formats=custom_formats, resolver=resolver) if additional_allowed else st.nothing(), st.one_of([st.from_regex(p) for p in sorted(patterns)]), ) all_names_strategy = st.one_of([s for s in name_strats if not s.is_empty]).filter( - make_validator(names).is_valid + make_validator(names, resolver=resolver).is_valid ) @st.composite # type: ignore @@ -582,12 +617,20 @@ def from_object_schema(draw: Any) -> Any: pattern_schemas.insert(0, properties[key]) if pattern_schemas: - out[key] = draw(merged_as_strategies(pattern_schemas, custom_formats)) + out[key] = draw( + merged_as_strategies( + pattern_schemas, custom_formats, resolver=resolver + ) + ) else: - out[key] = draw(from_schema(additional, custom_formats=custom_formats)) + out[key] = draw( + from_schema( + additional, custom_formats=custom_formats, resolver=resolver, + ) + ) for k, v in dep_schemas.items(): - if k in out and not make_validator(v).is_valid(out): + if k in out and not make_validator(v, resolver=resolver).is_valid(out): out.pop(key) elements.reject() diff --git a/tests/test_canonicalise.py b/tests/test_canonicalise.py index 45f4a03..6d0376e 100644 --- a/tests/test_canonicalise.py +++ b/tests/test_canonicalise.py @@ -20,7 +20,7 @@ def is_valid(instance, schema): - return make_validator(schema).is_valid(instance) + return make_validator(schema, resolver=None).is_valid(instance) @settings(suppress_health_check=[HealthCheck.too_slow], deadline=None) diff --git a/tests/test_from_schema.py b/tests/test_from_schema.py index 4a6e246..f2b9dd8 100644 --- a/tests/test_from_schema.py +++ b/tests/test_from_schema.py @@ -242,16 +242,11 @@ def inner(*args, **kwargs): assert isinstance(name, str) try: f(*args, **kwargs) - assert name not in RECURSIVE_REFS except jsonschema.exceptions.RefResolutionError as err: if ( isinstance(err, HypothesisRefResolutionError) or isinstance(err._cause, HypothesisRefResolutionError) - ) and ( - "does not fetch remote references" in str(err) - or name in RECURSIVE_REFS - and "Could not resolve recursive references" in str(err) - ): + ) and "does not fetch remote references" in str(err): pytest.xfail() raise @@ -425,3 +420,106 @@ def test_allowed_custom_format(num): def test_allowed_unknown_custom_format(string): assert string == "hello world" assert "not registered" not in jsonschema.FormatChecker().checkers + + +@pytest.mark.parametrize( + "schema", + ( + { + "properties": {"foo": {"$ref": "#"}}, + "additionalProperties": False, + "type": "object", + }, + { + "definitions": { + "Node": { + "type": "object", + "properties": { + "children": { + "type": "array", + "items": {"$ref": "#/definitions/Node"}, + "maxItems": 2, + } + }, + "required": ["children"], + "additionalProperties": False, + }, + }, + "$ref": "#/definitions/Node", + }, + # Simplified Open API schema + { + "type": "object", + "required": ["paths"], + "properties": {"paths": {"$ref": "#/definitions/Paths"}}, + "additionalProperties": False, + "definitions": { + "Schema": { + "type": "object", + "properties": {"items": {"$ref": "#/definitions/Schema"}}, + "additionalProperties": False, + }, + "MediaType": { + "type": "object", + "properties": {"schema": {"$ref": "#/definitions/Schema"}}, + "patternProperties": {"^x-": {}}, + "additionalProperties": False, + }, + "Paths": { + "type": "object", + "patternProperties": { + "^\\/": {"$ref": "#/definitions/PathItem"}, + "^x-": {}, + }, + "additionalProperties": False, + }, + "PathItem": { + "type": "object", + "properties": { + "parameters": { + "type": "array", + "items": {"$ref": "#/definitions/Parameter"}, + "uniqueItems": True, + }, + }, + "patternProperties": { + "^(get|put|post|delete|options|head|patch|trace)$": { + "$ref": "#/definitions/Operation" + }, + "^x-": {}, + }, + "additionalProperties": False, + }, + "Operation": { + "type": "object", + "required": ["responses"], + "properties": { + "parameters": { + "type": "array", + "items": {"$ref": "#/definitions/Parameter"}, + "uniqueItems": True, + }, + }, + "additionalProperties": False, + }, + "Parameter": { + "type": "object", + "properties": { + "schema": {"$ref": "#/definitions/Schema"}, + "content": { + "type": "object", + "minProperties": 1, + "maxProperties": 1, + }, + }, + "additionalProperties": False, + }, + }, + }, + ), +) +@given(data=st.data()) +@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]) +def test_recursive_reference(data, schema): + value = data.draw(from_schema(schema)) + jsonschema.validate(value, schema)