Skip to content

Commit 17e55e4

Browse files
author
John Chadwick
committed
Add rule and field value to violations
1 parent 8ea2b64 commit 17e55e4

File tree

4 files changed

+148
-73
lines changed

4 files changed

+148
-73
lines changed

protovalidate/internal/constraints.py

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __getitem__(self, name):
8181
return super().__getitem__(name)
8282

8383

84-
def _msg_to_cel(msg: message.Message) -> dict[str, celtypes.Value]:
84+
def _msg_to_cel(msg: message.Message) -> celtypes.Value:
8585
ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name)
8686
if ctor is not None:
8787
return ctor(msg)
@@ -230,43 +230,56 @@ def _set_path_element_map_key(
230230
raise CompilationError(msg)
231231

232232

233+
class Violation:
234+
"""A singular constraint violation."""
235+
236+
proto: validate_pb2.Violation
237+
field_value: typing.Any
238+
rule_value: typing.Any
239+
240+
def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = None, **kwargs):
241+
self.proto = validate_pb2.Violation(**kwargs)
242+
self.field_value = field_value
243+
self.rule_value = rule_value
244+
245+
233246
class ConstraintContext:
234247
"""The state associated with a single constraint evaluation."""
235248

236-
def __init__(self, fail_fast: bool = False, violations: validate_pb2.Violations = None): # noqa: FBT001, FBT002
249+
def __init__(self, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): # noqa: FBT001, FBT002
237250
self._fail_fast = fail_fast
238251
if violations is None:
239-
violations = validate_pb2.Violations()
252+
violations = []
240253
self._violations = violations
241254

242255
@property
243256
def fail_fast(self) -> bool:
244257
return self._fail_fast
245258

246259
@property
247-
def violations(self) -> validate_pb2.Violations:
260+
def violations(self) -> list[Violation]:
248261
return self._violations
249262

250-
def add(self, violation: validate_pb2.Violation):
251-
self._violations.violations.append(violation)
263+
def add(self, violation: list[Violation]):
264+
self._violations.append(violation)
252265

253266
def add_errors(self, other_ctx):
254-
self._violations.violations.extend(other_ctx.violations.violations)
267+
self._violations.extend(other_ctx.violations)
255268

256269
def add_field_path_element(self, element: validate_pb2.FieldPathElement):
257-
for violation in self._violations.violations:
258-
violation.field.elements.append(element)
270+
for violation in self._violations:
271+
violation.proto.field.elements.append(element)
259272

260273
def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPathElement]):
261-
for violation in self._violations.violations:
262-
violation.rule.elements.extend(elements)
274+
for violation in self._violations:
275+
violation.proto.rule.elements.extend(elements)
263276

264277
@property
265278
def done(self) -> bool:
266279
return self._fail_fast and self.has_errors()
267280

268281
def has_errors(self) -> bool:
269-
return len(self._violations.violations) > 0
282+
return len(self._violations) > 0
270283

271284
def sub_context(self):
272285
return ConstraintContext(self._fail_fast)
@@ -277,55 +290,81 @@ class ConstraintRules:
277290

278291
def validate(self, ctx: ConstraintContext, message: message.Message): # noqa: ARG002
279292
"""Validate the message against the rules in this constraint."""
280-
ctx.add(validate_pb2.Violation(constraint_id="unimplemented", message="Unimplemented"))
293+
ctx.add(Violation(constraint_id="unimplemented", message="Unimplemented"))
294+
295+
296+
class CelRunner:
297+
runner: celpy.Runner
298+
constraint: validate_pb2.Constraint
299+
rule_value: typing.Optional[typing.Any]
300+
rule_cel: typing.Optional[celtypes.Value]
301+
rule_path: typing.Optional[validate_pb2.FieldPath]
302+
303+
def __init__(
304+
self,
305+
*,
306+
runner: celpy.Runner,
307+
constraint: validate_pb2.Constraint,
308+
rule_value: typing.Optional[typing.Any] = None,
309+
rule_cel: typing.Optional[celtypes.Value] = None,
310+
rule_path: typing.Optional[validate_pb2.FieldPath] = None,
311+
):
312+
self.runner = runner
313+
self.constraint = constraint
314+
self.rule_value = rule_value
315+
self.rule_cel = rule_cel
316+
self.rule_path = rule_path
281317

282318

283319
class CelConstraintRules(ConstraintRules):
284320
"""A constraint that has rules written in CEL."""
285321

286-
_runners: list[
287-
tuple[
288-
celpy.Runner,
289-
validate_pb2.Constraint,
290-
typing.Optional[celtypes.Value],
291-
typing.Optional[validate_pb2.FieldPath],
292-
]
293-
]
294-
_rules_cel: celtypes.Value = None
322+
_cel: list[CelRunner]
323+
_rules: typing.Optional[message.Message] = None
324+
_rules_cel: typing.Optional[celtypes.Value] = None
295325

296326
def __init__(self, rules: typing.Optional[message.Message]):
297-
self._runners = []
327+
self._cel = []
298328
if rules is not None:
329+
self._rules = rules
299330
self._rules_cel = _msg_to_cel(rules)
300331

301332
def _validate_cel(
302333
self,
303334
ctx: ConstraintContext,
304-
activation: dict[str, typing.Any],
305335
*,
336+
this_value: typing.Optional[typing.Any] = None,
337+
this_cel: typing.Optional[celtypes.Value] = None,
306338
for_key: bool = False,
307339
):
340+
activation: dict[str, celtypes.Value] = {}
341+
if this_cel is not None:
342+
activation["this"] = this_cel
308343
activation["rules"] = self._rules_cel
309344
activation["now"] = celtypes.TimestampType(datetime.datetime.now(tz=datetime.timezone.utc))
310-
for runner, constraint, rule, rule_path in self._runners:
311-
activation["rule"] = rule
312-
result = runner.evaluate(activation)
345+
for cel in self._cel:
346+
activation["rule"] = cel.rule_cel
347+
result = cel.runner.evaluate(activation)
313348
if isinstance(result, celtypes.BoolType):
314349
if not result:
315350
ctx.add(
316-
validate_pb2.Violation(
317-
rule=rule_path,
318-
constraint_id=constraint.id,
319-
message=constraint.message,
351+
Violation(
352+
field_value=this_value,
353+
rule=cel.rule_path,
354+
rule_value=cel.rule_value,
355+
constraint_id=cel.constraint.id,
356+
message=cel.constraint.message,
320357
for_key=for_key,
321358
),
322359
)
323360
elif isinstance(result, celtypes.StringType):
324361
if result:
325362
ctx.add(
326-
validate_pb2.Violation(
327-
rule=rule_path,
328-
constraint_id=constraint.id,
363+
Violation(
364+
field_value=this_value,
365+
rule=cel.rule_path,
366+
rule_value=cel.rule_value,
367+
constraint_id=cel.constraint.id,
329368
message=result,
330369
for_key=for_key,
331370
),
@@ -339,19 +378,32 @@ def add_rule(
339378
funcs: dict[str, celpy.CELFunction],
340379
rules: validate_pb2.Constraint,
341380
*,
342-
rule: typing.Optional[celtypes.Value] = None,
381+
rule_field: typing.Optional[descriptor.FieldDescriptor] = None,
343382
rule_path: typing.Optional[validate_pb2.FieldPath] = None,
344383
):
345384
ast = env.compile(rules.expression)
346385
prog = env.program(ast, functions=funcs)
347-
self._runners.append((prog, rules, rule, rule_path))
386+
rule_value = None
387+
rule_cel = None
388+
if rule_field is not None and self._rules is not None:
389+
rule_value = _proto_message_get_field(self._rules, rule_field)
390+
rule_cel = _field_to_cel(self._rules, rule_field)
391+
self._cel.append(
392+
CelRunner(
393+
runner=prog,
394+
constraint=rules,
395+
rule_value=rule_value,
396+
rule_cel=rule_cel,
397+
rule_path=rule_path,
398+
)
399+
)
348400

349401

350402
class MessageConstraintRules(CelConstraintRules):
351403
"""Message-level rules."""
352404

353405
def validate(self, ctx: ConstraintContext, message: message.Message):
354-
self._validate_cel(ctx, {"this": _msg_to_cel(message)})
406+
self._validate_cel(ctx, this_cel=_msg_to_cel(message))
355407

356408

357409
def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_name: typing.Optional[str] = None):
@@ -445,7 +497,7 @@ def __init__(
445497
env,
446498
funcs,
447499
cel,
448-
rule=_field_to_cel(rules, list_field),
500+
rule_field=list_field,
449501
rule_path=validate_pb2.FieldPath(
450502
elements=[
451503
_field_to_element(list_field),
@@ -465,13 +517,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
465517
if _is_empty_field(message, self._field):
466518
if self._required:
467519
ctx.add(
468-
validate_pb2.Violation(
520+
Violation(
469521
field=validate_pb2.FieldPath(
470522
elements=[
471523
_field_to_element(self._field),
472524
],
473525
),
474526
rule=FieldConstraintRules._required_rule_path,
527+
rule_value=self._required,
475528
constraint_id="required",
476529
message="value is required",
477530
),
@@ -485,15 +538,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
485538
return
486539
sub_ctx = ctx.sub_context()
487540
self._validate_value(sub_ctx, val)
488-
self._validate_cel(sub_ctx, {"this": cel_val})
541+
self._validate_cel(sub_ctx, this_value=_proto_message_get_field(message, self._field), this_cel=cel_val)
489542
if sub_ctx.has_errors():
490543
element = _field_to_element(self._field)
491544
sub_ctx.add_field_path_element(element)
492545
ctx.add_errors(sub_ctx)
493546

494547
def validate_item(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
495548
self._validate_value(ctx, val, for_key=for_key)
496-
self._validate_cel(ctx, {"this": _scalar_field_value_to_cel(val, self._field)}, for_key=for_key)
549+
self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key)
497550

498551
def _validate_value(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
499552
pass
@@ -546,17 +599,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
546599
if len(self._in) > 0:
547600
if value.type_url not in self._in:
548601
ctx.add(
549-
validate_pb2.Violation(
602+
Violation(
550603
rule=AnyConstraintRules._in_rule_path,
604+
rule_value=self._in,
551605
constraint_id="any.in",
552606
message="type URL must be in the allow list",
553607
for_key=for_key,
554608
)
555609
)
556610
if value.type_url in self._not_in:
557611
ctx.add(
558-
validate_pb2.Violation(
612+
Violation(
559613
rule=AnyConstraintRules._not_in_rule_path,
614+
rule_value=self._not_in,
560615
constraint_id="any.not_in",
561616
message="type URL must not be in the block list",
562617
for_key=for_key,
@@ -603,13 +658,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
603658
value = getattr(message, self._field.name)
604659
if value not in self._field.enum_type.values_by_number:
605660
ctx.add(
606-
validate_pb2.Violation(
661+
Violation(
607662
field=validate_pb2.FieldPath(
608663
elements=[
609664
_field_to_element(self._field),
610665
],
611666
),
612667
rule=EnumConstraintRules._defined_only_rule_path,
668+
rule_value=self._defined_only,
613669
constraint_id="enum.defined_only",
614670
message="value must be one of the defined enum values",
615671
),
@@ -742,7 +798,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
742798
if not message.WhichOneof(self._oneof.name):
743799
if self.required:
744800
ctx.add(
745-
validate_pb2.Violation(
801+
Violation(
746802
field=validate_pb2.FieldPath(
747803
elements=[_oneof_to_element(self._oneof)],
748804
),

protovalidate/validator.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing
16+
1517
from google.protobuf import message
1618

1719
from buf.validate import validate_pb2 # type: ignore
@@ -20,6 +22,7 @@
2022

2123
CompilationError = _constraints.CompilationError
2224
Violations = validate_pb2.Violations
25+
Violation = _constraints.Violation
2326

2427

2528
class Validator:
@@ -54,7 +57,7 @@ def validate(
5457
ValidationError: If the message is invalid.
5558
"""
5659
violations = self.collect_violations(message, fail_fast=fail_fast)
57-
if violations.violations:
60+
if len(violations) > 0:
5861
msg = f"invalid {message.DESCRIPTOR.name}"
5962
raise ValidationError(msg, violations)
6063

@@ -63,8 +66,8 @@ def collect_violations(
6366
message: message.Message,
6467
*,
6568
fail_fast: bool = False,
66-
into: validate_pb2.Violations = None,
67-
) -> validate_pb2.Violations:
69+
into: typing.Optional[list[Violation]] = None,
70+
) -> list[Violation]:
6871
"""
6972
Validates the given message against the static constraints defined in
7073
the message's descriptor. Compared to validate, collect_violations is
@@ -84,12 +87,12 @@ def collect_violations(
8487
constraint.validate(ctx, message)
8588
if ctx.done:
8689
break
87-
for violation in ctx.violations.violations:
88-
if violation.HasField("field"):
89-
violation.field.elements.reverse()
90-
if violation.HasField("rule"):
91-
violation.rule.elements.reverse()
92-
violation.field_path = field_path.string(violation.field)
90+
for violation in ctx.violations:
91+
if violation.proto.HasField("field"):
92+
violation.proto.field.elements.reverse()
93+
if violation.proto.HasField("rule"):
94+
violation.proto.rule.elements.reverse()
95+
violation.proto.field_path = field_path.string(violation.proto.field)
9396
return ctx.violations
9497

9598

@@ -98,15 +101,25 @@ class ValidationError(ValueError):
98101
An error raised when a message fails to validate.
99102
"""
100103

101-
violations: validate_pb2.Violations
104+
_violations: list[Violation]
102105

103-
def __init__(self, msg: str, violations: validate_pb2.Violations):
106+
def __init__(self, msg: str, violations: list[Violations]):
104107
super().__init__(msg)
105-
self.violations = violations
108+
self._violations = violations
109+
110+
def to_proto(self) -> validate_pb2.Violations:
111+
"""
112+
Provides the Protobuf form of the validation errors.
113+
"""
114+
result = validate_pb2.Violations()
115+
for violation in self._violations:
116+
result.violations.append(violation.proto)
117+
return result
106118

107-
def errors(self) -> list[validate_pb2.Violation]:
119+
@property
120+
def violations(self) -> list[Violation]:
108121
"""
109-
Returns the validation errors as a simple Python list, rather than the
122+
Provides the validation errors as a simple Python list, rather than the
110123
Protobuf-specific collection type used by Violations.
111124
"""
112-
return list(self.violations.violations)
125+
return self._violations

0 commit comments

Comments
 (0)