Skip to content

Commit 0014064

Browse files
committed
Add validate hook
1 parent ff9c75b commit 0014064

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

jsonschema/_typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,12 @@ def __call__(
2727
[referencing.jsonschema.Schema],
2828
Iterable[tuple[str, Any]],
2929
]
30+
31+
class ValidateHook(Protocol):
32+
def __call__(
33+
self,
34+
is_valid: bool,
35+
instance: Any,
36+
schema: referencing.jsonschema.Schema,
37+
) -> None:
38+
...

jsonschema/protocols.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# therefore, only import at type-checking time (to avoid circular references),
1515
# but use `jsonschema` for any types which will otherwise not be resolvable
1616
if TYPE_CHECKING:
17-
from collections.abc import Iterable, Mapping
17+
from collections.abc import Iterable, Mapping, Sequence
1818

1919
import referencing.jsonschema
2020

@@ -102,6 +102,8 @@ class Validator(Protocol):
102102
#: A function which given a schema returns its ID.
103103
ID_OF: _typing.id_of
104104

105+
VALIDATE_HOOKS: ClassVar[Sequence]
106+
105107
#: The schema that will be used to validate instances
106108
schema: Mapping | bool
107109

jsonschema/validators.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def create(
147147
applicable_validators: _typing.ApplicableValidators = methodcaller(
148148
"items",
149149
),
150+
validate_hooks: Sequence[_typing.ValidateHook] = (),
150151
):
151152
"""
152153
Create a new validator class.
@@ -207,6 +208,16 @@ def create(
207208
implement similar behavior, you can typically ignore this argument
208209
and leave it at its default.
209210
211+
validate_hooks:
212+
213+
A list of callables, will be called after validate.
214+
215+
Each callable should take 4 arguments:
216+
217+
1. is valid or not
218+
2. the instance
219+
3. the schema
220+
210221
Returns:
211222
212223
a new `jsonschema.protocols.Validator` class
@@ -220,6 +231,10 @@ def create(
220231
default=referencing.Specification.OPAQUE,
221232
)
222233

234+
def _call_validate_hooks(is_valid, instance, schema):
235+
for hook in validate_hooks:
236+
hook(is_valid, instance, schema)
237+
223238
@define
224239
class Validator:
225240

@@ -228,6 +243,7 @@ class Validator:
228243
TYPE_CHECKER = type_checker
229244
FORMAT_CHECKER = format_checker_arg
230245
ID_OF = staticmethod(id_of)
246+
VALIDATE_HOOKS = list(validate_hooks) # noqa: RUF012
231247

232248
_APPLICABLE_VALIDATORS = applicable_validators
233249
_validators = field(init=False, repr=False, eq=False)
@@ -368,6 +384,7 @@ def iter_errors(self, instance, _schema=None):
368384
_schema, validators = self.schema, self._validators
369385

370386
if _schema is True:
387+
_call_validate_hooks(True, instance, _schema)
371388
return
372389
elif _schema is False:
373390
yield exceptions.ValidationError(
@@ -377,8 +394,10 @@ def iter_errors(self, instance, _schema=None):
377394
instance=instance,
378395
schema=_schema,
379396
)
397+
_call_validate_hooks(False, instance, _schema)
380398
return
381399

400+
is_valid = True
382401
for validator, k, v in validators:
383402
errors = validator(self, v, instance, _schema) or ()
384403
for error in errors:
@@ -392,7 +411,9 @@ def iter_errors(self, instance, _schema=None):
392411
)
393412
if k not in {"if", "$ref"}:
394413
error.schema_path.appendleft(k)
414+
is_valid = False
395415
yield error
416+
_call_validate_hooks(is_valid, instance, _schema)
396417

397418
def descend(
398419
self,
@@ -403,6 +424,7 @@ def descend(
403424
resolver=None,
404425
):
405426
if schema is True:
427+
_call_validate_hooks(True, instance, schema)
406428
return
407429
elif schema is False:
408430
yield exceptions.ValidationError(
@@ -412,6 +434,7 @@ def descend(
412434
instance=instance,
413435
schema=schema,
414436
)
437+
_call_validate_hooks(False, instance, schema)
415438
return
416439

417440
if self._ref_resolver is not None:
@@ -423,6 +446,7 @@ def descend(
423446
)
424447
evolved = self.evolve(schema=schema, _resolver=resolver)
425448

449+
is_valid = True
426450
for k, v in applicable_validators(schema):
427451
validator = evolved.VALIDATORS.get(k)
428452
if validator is None:
@@ -444,10 +468,15 @@ def descend(
444468
error.path.appendleft(path)
445469
if schema_path is not None:
446470
error.schema_path.appendleft(schema_path)
471+
is_valid = False
447472
yield error
473+
_call_validate_hooks(is_valid, instance, schema)
448474

449-
def validate(self, *args, **kwargs):
450-
for error in self.iter_errors(*args, **kwargs):
475+
def validate(self, instance, _schema=None):
476+
for error in self.iter_errors(instance, _schema):
477+
if _schema is None:
478+
_schema = self.schema
479+
_call_validate_hooks(False, instance, _schema)
451480
raise error
452481

453482
def is_type(self, instance, type):
@@ -498,6 +527,8 @@ def is_valid(self, instance, _schema=None):
498527
self = self.evolve(schema=_schema)
499528

500529
error = next(self.iter_errors(instance), None)
530+
if error is not None:
531+
_call_validate_hooks(False, instance, self.schema)
501532
return error is None
502533

503534
evolve_fields = [
@@ -520,6 +551,7 @@ def extend(
520551
version=None,
521552
type_checker=None,
522553
format_checker=None,
554+
validate_hooks=(),
523555
):
524556
"""
525557
Create a new validator class by extending an existing one.
@@ -565,6 +597,12 @@ def extend(
565597
If unprovided, the format checker of the extended
566598
`jsonschema.protocols.Validator` will be carried along.
567599
600+
validate_hooks (collections.abc.Sequence):
601+
602+
a list of new validate hooks to extend with, whose
603+
structure is as in `create`.
604+
605+
568606
Returns:
569607
570608
a new `jsonschema.protocols.Validator` class extending the one
@@ -584,6 +622,9 @@ def extend(
584622
all_validators = dict(validator.VALIDATORS)
585623
all_validators.update(validators)
586624

625+
all_validate_hooks = list(validator.VALIDATE_HOOKS)
626+
all_validate_hooks.extend(validate_hooks)
627+
587628
if type_checker is None:
588629
type_checker = validator.TYPE_CHECKER
589630
if format_checker is None:
@@ -596,6 +637,7 @@ def extend(
596637
format_checker=format_checker,
597638
id_of=validator.ID_OF,
598639
applicable_validators=validator._APPLICABLE_VALIDATORS,
640+
validate_hooks=all_validate_hooks,
599641
)
600642

601643

0 commit comments

Comments
 (0)