@@ -81,7 +81,7 @@ def __getitem__(self, name):
8181 return super ().__getitem__ (name )
8282
8383
84- def _msg_to_cel (msg : message .Message ) -> dict [ str , celtypes .Value ] :
84+ def _msg_to_cel (msg : message .Message ) -> celtypes .Value :
8585 ctor = _MSG_TYPE_URL_TO_CTOR .get (msg .DESCRIPTOR .full_name )
8686 if ctor is not None :
8787 return ctor (msg )
@@ -230,43 +230,56 @@ def _set_path_element_map_key(
230230 raise CompilationError (msg )
231231
232232
233+ class Violation :
234+ """A singular constraint violation."""
235+
236+ proto : validate_pb2 .Violation
237+ field_value : typing .Any
238+ rule_value : typing .Any
239+
240+ def __init__ (self , * , field_value : typing .Any = None , rule_value : typing .Any = None , ** kwargs ):
241+ self .proto = validate_pb2 .Violation (** kwargs )
242+ self .field_value = field_value
243+ self .rule_value = rule_value
244+
245+
233246class ConstraintContext :
234247 """The state associated with a single constraint evaluation."""
235248
236- def __init__ (self , fail_fast : bool = False , violations : validate_pb2 . Violations = None ): # noqa: FBT001, FBT002
249+ def __init__ (self , fail_fast : bool = False , violations : typing . Optional [ list [ Violation ]] = None ): # noqa: FBT001, FBT002
237250 self ._fail_fast = fail_fast
238251 if violations is None :
239- violations = validate_pb2 . Violations ()
252+ violations = []
240253 self ._violations = violations
241254
242255 @property
243256 def fail_fast (self ) -> bool :
244257 return self ._fail_fast
245258
246259 @property
247- def violations (self ) -> validate_pb2 . Violations :
260+ def violations (self ) -> list [ Violation ] :
248261 return self ._violations
249262
250- def add (self , violation : validate_pb2 . Violation ):
251- self ._violations .violations . append (violation )
263+ def add (self , violation : list [ Violation ] ):
264+ self ._violations .append (violation )
252265
253266 def add_errors (self , other_ctx ):
254- self ._violations .violations . extend (other_ctx . violations .violations )
267+ self ._violations .extend (other_ctx .violations )
255268
256269 def add_field_path_element (self , element : validate_pb2 .FieldPathElement ):
257- for violation in self ._violations . violations :
258- violation .field .elements .append (element )
270+ for violation in self ._violations :
271+ violation .proto . field .elements .append (element )
259272
260273 def add_rule_path_elements (self , elements : typing .Iterable [validate_pb2 .FieldPathElement ]):
261- for violation in self ._violations . violations :
262- violation .rule .elements .extend (elements )
274+ for violation in self ._violations :
275+ violation .proto . rule .elements .extend (elements )
263276
264277 @property
265278 def done (self ) -> bool :
266279 return self ._fail_fast and self .has_errors ()
267280
268281 def has_errors (self ) -> bool :
269- return len (self ._violations . violations ) > 0
282+ return len (self ._violations ) > 0
270283
271284 def sub_context (self ):
272285 return ConstraintContext (self ._fail_fast )
@@ -277,55 +290,81 @@ class ConstraintRules:
277290
278291 def validate (self , ctx : ConstraintContext , message : message .Message ): # noqa: ARG002
279292 """Validate the message against the rules in this constraint."""
280- ctx .add (validate_pb2 .Violation (constraint_id = "unimplemented" , message = "Unimplemented" ))
293+ ctx .add (Violation (constraint_id = "unimplemented" , message = "Unimplemented" ))
294+
295+
296+ class CelRunner :
297+ runner : celpy .Runner
298+ constraint : validate_pb2 .Constraint
299+ rule_value : typing .Optional [typing .Any ]
300+ rule_cel : typing .Optional [celtypes .Value ]
301+ rule_path : typing .Optional [validate_pb2 .FieldPath ]
302+
303+ def __init__ (
304+ self ,
305+ * ,
306+ runner : celpy .Runner ,
307+ constraint : validate_pb2 .Constraint ,
308+ rule_value : typing .Optional [typing .Any ] = None ,
309+ rule_cel : typing .Optional [celtypes .Value ] = None ,
310+ rule_path : typing .Optional [validate_pb2 .FieldPath ] = None ,
311+ ):
312+ self .runner = runner
313+ self .constraint = constraint
314+ self .rule_value = rule_value
315+ self .rule_cel = rule_cel
316+ self .rule_path = rule_path
281317
282318
283319class CelConstraintRules (ConstraintRules ):
284320 """A constraint that has rules written in CEL."""
285321
286- _runners : list [
287- tuple [
288- celpy .Runner ,
289- validate_pb2 .Constraint ,
290- typing .Optional [celtypes .Value ],
291- typing .Optional [validate_pb2 .FieldPath ],
292- ]
293- ]
294- _rules_cel : celtypes .Value = None
322+ _cel : list [CelRunner ]
323+ _rules : typing .Optional [message .Message ] = None
324+ _rules_cel : typing .Optional [celtypes .Value ] = None
295325
296326 def __init__ (self , rules : typing .Optional [message .Message ]):
297- self ._runners = []
327+ self ._cel = []
298328 if rules is not None :
329+ self ._rules = rules
299330 self ._rules_cel = _msg_to_cel (rules )
300331
301332 def _validate_cel (
302333 self ,
303334 ctx : ConstraintContext ,
304- activation : dict [str , typing .Any ],
305335 * ,
336+ this_value : typing .Optional [typing .Any ] = None ,
337+ this_cel : typing .Optional [celtypes .Value ] = None ,
306338 for_key : bool = False ,
307339 ):
340+ activation : dict [str , celtypes .Value ] = {}
341+ if this_cel is not None :
342+ activation ["this" ] = this_cel
308343 activation ["rules" ] = self ._rules_cel
309344 activation ["now" ] = celtypes .TimestampType (datetime .datetime .now (tz = datetime .timezone .utc ))
310- for runner , constraint , rule , rule_path in self ._runners :
311- activation ["rule" ] = rule
312- result = runner .evaluate (activation )
345+ for cel in self ._cel :
346+ activation ["rule" ] = cel . rule_cel
347+ result = cel . runner .evaluate (activation )
313348 if isinstance (result , celtypes .BoolType ):
314349 if not result :
315350 ctx .add (
316- validate_pb2 .Violation (
317- rule = rule_path ,
318- constraint_id = constraint .id ,
319- message = constraint .message ,
351+ Violation (
352+ field_value = this_value ,
353+ rule = cel .rule_path ,
354+ rule_value = cel .rule_value ,
355+ constraint_id = cel .constraint .id ,
356+ message = cel .constraint .message ,
320357 for_key = for_key ,
321358 ),
322359 )
323360 elif isinstance (result , celtypes .StringType ):
324361 if result :
325362 ctx .add (
326- validate_pb2 .Violation (
327- rule = rule_path ,
328- constraint_id = constraint .id ,
363+ Violation (
364+ field_value = this_value ,
365+ rule = cel .rule_path ,
366+ rule_value = cel .rule_value ,
367+ constraint_id = cel .constraint .id ,
329368 message = result ,
330369 for_key = for_key ,
331370 ),
@@ -339,19 +378,32 @@ def add_rule(
339378 funcs : dict [str , celpy .CELFunction ],
340379 rules : validate_pb2 .Constraint ,
341380 * ,
342- rule : typing .Optional [celtypes . Value ] = None ,
381+ rule_field : typing .Optional [descriptor . FieldDescriptor ] = None ,
343382 rule_path : typing .Optional [validate_pb2 .FieldPath ] = None ,
344383 ):
345384 ast = env .compile (rules .expression )
346385 prog = env .program (ast , functions = funcs )
347- self ._runners .append ((prog , rules , rule , rule_path ))
386+ rule_value = None
387+ rule_cel = None
388+ if rule_field is not None and self ._rules is not None :
389+ rule_value = _proto_message_get_field (self ._rules , rule_field )
390+ rule_cel = _field_to_cel (self ._rules , rule_field )
391+ self ._cel .append (
392+ CelRunner (
393+ runner = prog ,
394+ constraint = rules ,
395+ rule_value = rule_value ,
396+ rule_cel = rule_cel ,
397+ rule_path = rule_path ,
398+ )
399+ )
348400
349401
350402class MessageConstraintRules (CelConstraintRules ):
351403 """Message-level rules."""
352404
353405 def validate (self , ctx : ConstraintContext , message : message .Message ):
354- self ._validate_cel (ctx , { "this" : _msg_to_cel (message )} )
406+ self ._validate_cel (ctx , this_cel = _msg_to_cel (message ))
355407
356408
357409def check_field_type (field : descriptor .FieldDescriptor , expected : int , wrapper_name : typing .Optional [str ] = None ):
@@ -445,7 +497,7 @@ def __init__(
445497 env ,
446498 funcs ,
447499 cel ,
448- rule = _field_to_cel ( rules , list_field ) ,
500+ rule_field = list_field ,
449501 rule_path = validate_pb2 .FieldPath (
450502 elements = [
451503 _field_to_element (list_field ),
@@ -465,13 +517,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
465517 if _is_empty_field (message , self ._field ):
466518 if self ._required :
467519 ctx .add (
468- validate_pb2 . Violation (
520+ Violation (
469521 field = validate_pb2 .FieldPath (
470522 elements = [
471523 _field_to_element (self ._field ),
472524 ],
473525 ),
474526 rule = FieldConstraintRules ._required_rule_path ,
527+ rule_value = self ._required ,
475528 constraint_id = "required" ,
476529 message = "value is required" ,
477530 ),
@@ -485,15 +538,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
485538 return
486539 sub_ctx = ctx .sub_context ()
487540 self ._validate_value (sub_ctx , val )
488- self ._validate_cel (sub_ctx , { "this" : cel_val } )
541+ self ._validate_cel (sub_ctx , this_value = _proto_message_get_field ( message , self . _field ), this_cel = cel_val )
489542 if sub_ctx .has_errors ():
490543 element = _field_to_element (self ._field )
491544 sub_ctx .add_field_path_element (element )
492545 ctx .add_errors (sub_ctx )
493546
494547 def validate_item (self , ctx : ConstraintContext , val : typing .Any , * , for_key : bool = False ):
495548 self ._validate_value (ctx , val , for_key = for_key )
496- self ._validate_cel (ctx , { "this" : _scalar_field_value_to_cel (val , self ._field )} , for_key = for_key )
549+ self ._validate_cel (ctx , this_value = val , this_cel = _scalar_field_value_to_cel (val , self ._field ), for_key = for_key )
497550
498551 def _validate_value (self , ctx : ConstraintContext , val : typing .Any , * , for_key : bool = False ):
499552 pass
@@ -546,17 +599,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
546599 if len (self ._in ) > 0 :
547600 if value .type_url not in self ._in :
548601 ctx .add (
549- validate_pb2 . Violation (
602+ Violation (
550603 rule = AnyConstraintRules ._in_rule_path ,
604+ rule_value = self ._in ,
551605 constraint_id = "any.in" ,
552606 message = "type URL must be in the allow list" ,
553607 for_key = for_key ,
554608 )
555609 )
556610 if value .type_url in self ._not_in :
557611 ctx .add (
558- validate_pb2 . Violation (
612+ Violation (
559613 rule = AnyConstraintRules ._not_in_rule_path ,
614+ rule_value = self ._not_in ,
560615 constraint_id = "any.not_in" ,
561616 message = "type URL must not be in the block list" ,
562617 for_key = for_key ,
@@ -603,13 +658,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
603658 value = getattr (message , self ._field .name )
604659 if value not in self ._field .enum_type .values_by_number :
605660 ctx .add (
606- validate_pb2 . Violation (
661+ Violation (
607662 field = validate_pb2 .FieldPath (
608663 elements = [
609664 _field_to_element (self ._field ),
610665 ],
611666 ),
612667 rule = EnumConstraintRules ._defined_only_rule_path ,
668+ rule_value = self ._defined_only ,
613669 constraint_id = "enum.defined_only" ,
614670 message = "value must be one of the defined enum values" ,
615671 ),
@@ -742,7 +798,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
742798 if not message .WhichOneof (self ._oneof .name ):
743799 if self .required :
744800 ctx .add (
745- validate_pb2 . Violation (
801+ Violation (
746802 field = validate_pb2 .FieldPath (
747803 elements = [_oneof_to_element (self ._oneof )],
748804 ),
0 commit comments