diff --git a/jsonschema/_typing.py b/jsonschema/_typing.py index 1d091d70c..2824d51f4 100644 --- a/jsonschema/_typing.py +++ b/jsonschema/_typing.py @@ -27,3 +27,12 @@ def __call__( [referencing.jsonschema.Schema], Iterable[tuple[str, Any]], ] + +class ValidateHook(Protocol): + def __call__( + self, + is_valid: bool, + instance: Any, + schema: referencing.jsonschema.Schema, + ) -> None: + ... diff --git a/jsonschema/protocols.py b/jsonschema/protocols.py index 0fd993eec..4cffc2b15 100644 --- a/jsonschema/protocols.py +++ b/jsonschema/protocols.py @@ -14,7 +14,7 @@ # therefore, only import at type-checking time (to avoid circular references), # but use `jsonschema` for any types which will otherwise not be resolvable if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Iterable, Mapping, Sequence import referencing.jsonschema @@ -102,6 +102,8 @@ class Validator(Protocol): #: A function which given a schema returns its ID. ID_OF: _typing.id_of + VALIDATE_HOOKS: ClassVar[Sequence] + #: The schema that will be used to validate instances schema: Mapping | bool diff --git a/jsonschema/validators.py b/jsonschema/validators.py index b8ca3bd45..7a375c144 100644 --- a/jsonschema/validators.py +++ b/jsonschema/validators.py @@ -147,6 +147,7 @@ def create( applicable_validators: _typing.ApplicableValidators = methodcaller( "items", ), + validate_hooks: Sequence[_typing.ValidateHook] = (), ): """ Create a new validator class. @@ -207,6 +208,16 @@ def create( implement similar behavior, you can typically ignore this argument and leave it at its default. + validate_hooks: + + A list of callables, will be called after validate. + + Each callable should take 4 arguments: + + 1. is valid or not + 2. the instance + 3. the schema + Returns: a new `jsonschema.protocols.Validator` class @@ -220,6 +231,10 @@ def create( default=referencing.Specification.OPAQUE, ) + def _call_validate_hooks(is_valid, instance, schema): + for hook in validate_hooks: + hook(is_valid, instance, schema) + @define class Validator: @@ -228,6 +243,7 @@ class Validator: TYPE_CHECKER = type_checker FORMAT_CHECKER = format_checker_arg ID_OF = staticmethod(id_of) + VALIDATE_HOOKS = list(validate_hooks) # noqa: RUF012 _APPLICABLE_VALIDATORS = applicable_validators _validators = field(init=False, repr=False, eq=False) @@ -368,6 +384,7 @@ def iter_errors(self, instance, _schema=None): _schema, validators = self.schema, self._validators if _schema is True: + _call_validate_hooks(True, instance, _schema) return elif _schema is False: yield exceptions.ValidationError( @@ -377,8 +394,10 @@ def iter_errors(self, instance, _schema=None): instance=instance, schema=_schema, ) + _call_validate_hooks(False, instance, _schema) return + is_valid = True for validator, k, v in validators: errors = validator(self, v, instance, _schema) or () for error in errors: @@ -392,7 +411,9 @@ def iter_errors(self, instance, _schema=None): ) if k not in {"if", "$ref"}: error.schema_path.appendleft(k) + is_valid = False yield error + _call_validate_hooks(is_valid, instance, _schema) def descend( self, @@ -403,6 +424,7 @@ def descend( resolver=None, ): if schema is True: + _call_validate_hooks(True, instance, schema) return elif schema is False: yield exceptions.ValidationError( @@ -412,6 +434,7 @@ def descend( instance=instance, schema=schema, ) + _call_validate_hooks(False, instance, schema) return if self._ref_resolver is not None: @@ -423,6 +446,7 @@ def descend( ) evolved = self.evolve(schema=schema, _resolver=resolver) + is_valid = True for k, v in applicable_validators(schema): validator = evolved.VALIDATORS.get(k) if validator is None: @@ -444,10 +468,15 @@ def descend( error.path.appendleft(path) if schema_path is not None: error.schema_path.appendleft(schema_path) + is_valid = False yield error + _call_validate_hooks(is_valid, instance, schema) - def validate(self, *args, **kwargs): - for error in self.iter_errors(*args, **kwargs): + def validate(self, instance, _schema=None): + for error in self.iter_errors(instance, _schema): + if _schema is None: + _schema = self.schema + _call_validate_hooks(False, instance, _schema) raise error def is_type(self, instance, type): @@ -498,6 +527,8 @@ def is_valid(self, instance, _schema=None): self = self.evolve(schema=_schema) error = next(self.iter_errors(instance), None) + if error is not None: + _call_validate_hooks(False, instance, self.schema) return error is None evolve_fields = [ @@ -520,6 +551,7 @@ def extend( version=None, type_checker=None, format_checker=None, + validate_hooks=(), ): """ Create a new validator class by extending an existing one. @@ -565,6 +597,12 @@ def extend( If unprovided, the format checker of the extended `jsonschema.protocols.Validator` will be carried along. + validate_hooks (collections.abc.Sequence): + + a list of new validate hooks to extend with, whose + structure is as in `create`. + + Returns: a new `jsonschema.protocols.Validator` class extending the one @@ -584,6 +622,9 @@ def extend( all_validators = dict(validator.VALIDATORS) all_validators.update(validators) + all_validate_hooks = list(validator.VALIDATE_HOOKS) + all_validate_hooks.extend(validate_hooks) + if type_checker is None: type_checker = validator.TYPE_CHECKER if format_checker is None: @@ -596,6 +637,7 @@ def extend( format_checker=format_checker, id_of=validator.ID_OF, applicable_validators=validator._APPLICABLE_VALIDATORS, + validate_hooks=all_validate_hooks, )