1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import dataclasses
1516import datetime
1617import typing
1718
@@ -81,7 +82,7 @@ def __getitem__(self, name):
8182 return super ().__getitem__ (name )
8283
8384
84- def _msg_to_cel (msg : message .Message ) -> dict [ str , celtypes .Value ] :
85+ def _msg_to_cel (msg : message .Message ) -> celtypes .Value :
8586 ctor = _MSG_TYPE_URL_TO_CTOR .get (msg .DESCRIPTOR .full_name )
8687 if ctor is not None :
8788 return ctor (msg )
@@ -230,43 +231,56 @@ def _set_path_element_map_key(
230231 raise CompilationError (msg )
231232
232233
234+ class Violation :
235+ """A singular constraint violation."""
236+
237+ proto : validate_pb2 .Violation
238+ field_value : typing .Any
239+ rule_value : typing .Any
240+
241+ def __init__ (self , * , field_value : typing .Any = None , rule_value : typing .Any = None , ** kwargs ):
242+ self .proto = validate_pb2 .Violation (** kwargs )
243+ self .field_value = field_value
244+ self .rule_value = rule_value
245+
246+
233247class ConstraintContext :
234248 """The state associated with a single constraint evaluation."""
235249
236- def __init__ (self , fail_fast : bool = False , violations : validate_pb2 . Violations = None ): # noqa: FBT001, FBT002
250+ def __init__ (self , fail_fast : bool = False , violations : typing . Optional [ list [ Violation ]] = None ): # noqa: FBT001, FBT002
237251 self ._fail_fast = fail_fast
238252 if violations is None :
239- violations = validate_pb2 . Violations ()
253+ violations = []
240254 self ._violations = violations
241255
242256 @property
243257 def fail_fast (self ) -> bool :
244258 return self ._fail_fast
245259
246260 @property
247- def violations (self ) -> validate_pb2 . Violations :
261+ def violations (self ) -> list [ Violation ] :
248262 return self ._violations
249263
250- def add (self , violation : validate_pb2 . Violation ):
251- self ._violations .violations . append (violation )
264+ def add (self , violation : Violation ):
265+ self ._violations .append (violation )
252266
253267 def add_errors (self , other_ctx ):
254- self ._violations .violations . extend (other_ctx . violations .violations )
268+ self ._violations .extend (other_ctx .violations )
255269
256270 def add_field_path_element (self , element : validate_pb2 .FieldPathElement ):
257- for violation in self ._violations . violations :
258- violation .field .elements .append (element )
271+ for violation in self ._violations :
272+ violation .proto . field .elements .append (element )
259273
260274 def add_rule_path_elements (self , elements : typing .Iterable [validate_pb2 .FieldPathElement ]):
261- for violation in self ._violations . violations :
262- violation .rule .elements .extend (elements )
275+ for violation in self ._violations :
276+ violation .proto . rule .elements .extend (elements )
263277
264278 @property
265279 def done (self ) -> bool :
266280 return self ._fail_fast and self .has_errors ()
267281
268282 def has_errors (self ) -> bool :
269- return len (self ._violations . violations ) > 0
283+ return len (self ._violations ) > 0
270284
271285 def sub_context (self ):
272286 return ConstraintContext (self ._fail_fast )
@@ -277,55 +291,67 @@ class ConstraintRules:
277291
278292 def validate (self , ctx : ConstraintContext , message : message .Message ): # noqa: ARG002
279293 """Validate the message against the rules in this constraint."""
280- ctx .add (validate_pb2 .Violation (constraint_id = "unimplemented" , message = "Unimplemented" ))
294+ ctx .add (Violation (constraint_id = "unimplemented" , message = "Unimplemented" ))
295+
296+
297+ @dataclasses .dataclass
298+ class CelRunner :
299+ runner : celpy .Runner
300+ constraint : validate_pb2 .Constraint
301+ rule_value : typing .Optional [typing .Any ] = None
302+ rule_cel : typing .Optional [celtypes .Value ] = None
303+ rule_path : typing .Optional [validate_pb2 .FieldPath ] = None
281304
282305
283306class CelConstraintRules (ConstraintRules ):
284307 """A constraint that has rules written in CEL."""
285308
286- _runners : list [
287- tuple [
288- celpy .Runner ,
289- validate_pb2 .Constraint ,
290- typing .Optional [celtypes .Value ],
291- typing .Optional [validate_pb2 .FieldPath ],
292- ]
293- ]
294- _rules_cel : celtypes .Value = None
309+ _cel : list [CelRunner ]
310+ _rules : typing .Optional [message .Message ] = None
311+ _rules_cel : typing .Optional [celtypes .Value ] = None
295312
296313 def __init__ (self , rules : typing .Optional [message .Message ]):
297- self ._runners = []
314+ self ._cel = []
298315 if rules is not None :
316+ self ._rules = rules
299317 self ._rules_cel = _msg_to_cel (rules )
300318
301319 def _validate_cel (
302320 self ,
303321 ctx : ConstraintContext ,
304- activation : dict [str , typing .Any ],
305322 * ,
323+ this_value : typing .Optional [typing .Any ] = None ,
324+ this_cel : typing .Optional [celtypes .Value ] = None ,
306325 for_key : bool = False ,
307326 ):
327+ activation : dict [str , celtypes .Value ] = {}
328+ if this_cel is not None :
329+ activation ["this" ] = this_cel
308330 activation ["rules" ] = self ._rules_cel
309331 activation ["now" ] = celtypes .TimestampType (datetime .datetime .now (tz = datetime .timezone .utc ))
310- for runner , constraint , rule , rule_path in self ._runners :
311- activation ["rule" ] = rule
312- result = runner .evaluate (activation )
332+ for cel in self ._cel :
333+ activation ["rule" ] = cel . rule_cel
334+ result = cel . runner .evaluate (activation )
313335 if isinstance (result , celtypes .BoolType ):
314336 if not result :
315337 ctx .add (
316- validate_pb2 .Violation (
317- rule = rule_path ,
318- constraint_id = constraint .id ,
319- message = constraint .message ,
338+ Violation (
339+ field_value = this_value ,
340+ rule = cel .rule_path ,
341+ rule_value = cel .rule_value ,
342+ constraint_id = cel .constraint .id ,
343+ message = cel .constraint .message ,
320344 for_key = for_key ,
321345 ),
322346 )
323347 elif isinstance (result , celtypes .StringType ):
324348 if result :
325349 ctx .add (
326- validate_pb2 .Violation (
327- rule = rule_path ,
328- constraint_id = constraint .id ,
350+ Violation (
351+ field_value = this_value ,
352+ rule = cel .rule_path ,
353+ rule_value = cel .rule_value ,
354+ constraint_id = cel .constraint .id ,
329355 message = result ,
330356 for_key = for_key ,
331357 ),
@@ -339,19 +365,32 @@ def add_rule(
339365 funcs : dict [str , celpy .CELFunction ],
340366 rules : validate_pb2 .Constraint ,
341367 * ,
342- rule : typing .Optional [celtypes . Value ] = None ,
368+ rule_field : typing .Optional [descriptor . FieldDescriptor ] = None ,
343369 rule_path : typing .Optional [validate_pb2 .FieldPath ] = None ,
344370 ):
345371 ast = env .compile (rules .expression )
346372 prog = env .program (ast , functions = funcs )
347- self ._runners .append ((prog , rules , rule , rule_path ))
373+ rule_value = None
374+ rule_cel = None
375+ if rule_field is not None and self ._rules is not None :
376+ rule_value = _proto_message_get_field (self ._rules , rule_field )
377+ rule_cel = _field_to_cel (self ._rules , rule_field )
378+ self ._cel .append (
379+ CelRunner (
380+ runner = prog ,
381+ constraint = rules ,
382+ rule_value = rule_value ,
383+ rule_cel = rule_cel ,
384+ rule_path = rule_path ,
385+ )
386+ )
348387
349388
350389class MessageConstraintRules (CelConstraintRules ):
351390 """Message-level rules."""
352391
353392 def validate (self , ctx : ConstraintContext , message : message .Message ):
354- self ._validate_cel (ctx , { "this" : _msg_to_cel (message )} )
393+ self ._validate_cel (ctx , this_cel = _msg_to_cel (message ))
355394
356395
357396def check_field_type (field : descriptor .FieldDescriptor , expected : int , wrapper_name : typing .Optional [str ] = None ):
@@ -445,7 +484,7 @@ def __init__(
445484 env ,
446485 funcs ,
447486 cel ,
448- rule = _field_to_cel ( rules , list_field ) ,
487+ rule_field = list_field ,
449488 rule_path = validate_pb2 .FieldPath (
450489 elements = [
451490 _field_to_element (list_field ),
@@ -465,13 +504,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
465504 if _is_empty_field (message , self ._field ):
466505 if self ._required :
467506 ctx .add (
468- validate_pb2 . Violation (
507+ Violation (
469508 field = validate_pb2 .FieldPath (
470509 elements = [
471510 _field_to_element (self ._field ),
472511 ],
473512 ),
474513 rule = FieldConstraintRules ._required_rule_path ,
514+ rule_value = self ._required ,
475515 constraint_id = "required" ,
476516 message = "value is required" ,
477517 ),
@@ -485,15 +525,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
485525 return
486526 sub_ctx = ctx .sub_context ()
487527 self ._validate_value (sub_ctx , val )
488- self ._validate_cel (sub_ctx , { "this" : cel_val } )
528+ self ._validate_cel (sub_ctx , this_value = _proto_message_get_field ( message , self . _field ), this_cel = cel_val )
489529 if sub_ctx .has_errors ():
490530 element = _field_to_element (self ._field )
491531 sub_ctx .add_field_path_element (element )
492532 ctx .add_errors (sub_ctx )
493533
494534 def validate_item (self , ctx : ConstraintContext , val : typing .Any , * , for_key : bool = False ):
495535 self ._validate_value (ctx , val , for_key = for_key )
496- self ._validate_cel (ctx , { "this" : _scalar_field_value_to_cel (val , self ._field )} , for_key = for_key )
536+ self ._validate_cel (ctx , this_value = val , this_cel = _scalar_field_value_to_cel (val , self ._field ), for_key = for_key )
497537
498538 def _validate_value (self , ctx : ConstraintContext , val : typing .Any , * , for_key : bool = False ):
499539 pass
@@ -546,17 +586,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
546586 if len (self ._in ) > 0 :
547587 if value .type_url not in self ._in :
548588 ctx .add (
549- validate_pb2 . Violation (
589+ Violation (
550590 rule = AnyConstraintRules ._in_rule_path ,
591+ rule_value = self ._in ,
551592 constraint_id = "any.in" ,
552593 message = "type URL must be in the allow list" ,
553594 for_key = for_key ,
554595 )
555596 )
556597 if value .type_url in self ._not_in :
557598 ctx .add (
558- validate_pb2 . Violation (
599+ Violation (
559600 rule = AnyConstraintRules ._not_in_rule_path ,
601+ rule_value = self ._not_in ,
560602 constraint_id = "any.not_in" ,
561603 message = "type URL must not be in the block list" ,
562604 for_key = for_key ,
@@ -603,13 +645,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
603645 value = getattr (message , self ._field .name )
604646 if value not in self ._field .enum_type .values_by_number :
605647 ctx .add (
606- validate_pb2 . Violation (
648+ Violation (
607649 field = validate_pb2 .FieldPath (
608650 elements = [
609651 _field_to_element (self ._field ),
610652 ],
611653 ),
612654 rule = EnumConstraintRules ._defined_only_rule_path ,
655+ rule_value = self ._defined_only ,
613656 constraint_id = "enum.defined_only" ,
614657 message = "value must be one of the defined enum values" ,
615658 ),
@@ -742,7 +785,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
742785 if not message .WhichOneof (self ._oneof .name ):
743786 if self .required :
744787 ctx .add (
745- validate_pb2 . Violation (
788+ Violation (
746789 field = validate_pb2 .FieldPath (
747790 elements = [_oneof_to_element (self ._oneof )],
748791 ),
0 commit comments