2323from buf .validate import validate_pb2 # type: ignore
2424from protovalidate .internal .cel_field_presence import InterpretedRunner , in_has
2525
26- # Convenience to stringify the type names for error messages
27- FIELD_TYPE_NAMES = {v : k for k , v in vars (descriptor .FieldDescriptor ).items () if k .startswith ("TYPE_" )}
28-
2926
3027class CompilationError (Exception ):
3128 pass
@@ -61,59 +58,54 @@ def unwrap(msg: message.Message) -> celtypes.Value:
6158}
6259
6360
64- class MessageType (celtypes .MapType ):
65- msg : message .Message
66- desc : descriptor .Descriptor
67-
68- def __init__ (self , msg : message .Message ):
69- super ().__init__ ()
70- self .msg = msg
71- self .desc = msg .DESCRIPTOR
72- field : descriptor .FieldDescriptor
73- for field in self .desc .fields :
74- if field .containing_oneof is not None and not self .msg .HasField (field .name ):
75- continue
76- self [field .name ] = field_to_cel (self .msg , field )
77-
78- def __getitem__ (self , name ):
79- field = self .desc .fields_by_name [name ]
80- if field .has_presence and not self .msg .HasField (name ):
81- if in_has ():
82- raise KeyError ()
83- else :
84- return _zero_value (field )
85- return super ().__getitem__ (name )
86-
87-
8861def _msg_to_cel (msg : message .Message ) -> celtypes .Value :
8962 ctor = _MSG_TYPE_URL_TO_CTOR .get (msg .DESCRIPTOR .full_name )
9063 if ctor is not None :
9164 return ctor (msg )
9265 return MessageType (msg )
9366
9467
95- _TYPE_TO_CTOR : dict [str , typing .Callable [..., celtypes .Value ]] = {
96- descriptor .FieldDescriptor .TYPE_MESSAGE : _msg_to_cel ,
97- descriptor .FieldDescriptor .TYPE_GROUP : _msg_to_cel ,
98- descriptor .FieldDescriptor .TYPE_ENUM : celtypes .IntType ,
99- descriptor .FieldDescriptor .TYPE_BOOL : celtypes .BoolType ,
100- descriptor .FieldDescriptor .TYPE_BYTES : celtypes .BytesType ,
101- descriptor .FieldDescriptor .TYPE_STRING : celtypes .StringType ,
102- descriptor .FieldDescriptor .TYPE_FLOAT : celtypes .DoubleType ,
103- descriptor .FieldDescriptor .TYPE_DOUBLE : celtypes .DoubleType ,
104- descriptor .FieldDescriptor .TYPE_INT32 : celtypes .IntType ,
105- descriptor .FieldDescriptor .TYPE_INT64 : celtypes .IntType ,
106- descriptor .FieldDescriptor .TYPE_UINT32 : celtypes .UintType ,
107- descriptor .FieldDescriptor .TYPE_UINT64 : celtypes .UintType ,
108- descriptor .FieldDescriptor .TYPE_SINT32 : celtypes .IntType ,
109- descriptor .FieldDescriptor .TYPE_SINT64 : celtypes .IntType ,
110- descriptor .FieldDescriptor .TYPE_FIXED32 : celtypes .UintType ,
111- descriptor .FieldDescriptor .TYPE_FIXED64 : celtypes .UintType ,
112- descriptor .FieldDescriptor .TYPE_SFIXED32 : celtypes .IntType ,
113- descriptor .FieldDescriptor .TYPE_SFIXED64 : celtypes .IntType ,
68+ class FieldDescMetadata (typing .TypedDict ):
69+ name : str
70+ ctor : typing .Callable [..., celtypes .Value ]
71+
72+
73+ _FIELD_DESC_METADATA_MAP : dict [typing .Any , FieldDescMetadata ] = {
74+ descriptor .FieldDescriptor .TYPE_MESSAGE : {"name" : "message" , "ctor" : _msg_to_cel },
75+ descriptor .FieldDescriptor .TYPE_GROUP : {"name" : "group" , "ctor" : _msg_to_cel },
76+ descriptor .FieldDescriptor .TYPE_ENUM : {"name" : "enum" , "ctor" : celtypes .IntType },
77+ descriptor .FieldDescriptor .TYPE_BOOL : {"name" : "bool" , "ctor" : celtypes .BoolType },
78+ descriptor .FieldDescriptor .TYPE_BYTES : {"name" : "bytes" , "ctor" : celtypes .BytesType },
79+ descriptor .FieldDescriptor .TYPE_STRING : {"name" : "string" , "ctor" : celtypes .StringType },
80+ descriptor .FieldDescriptor .TYPE_FLOAT : {"name" : "float" , "ctor" : celtypes .DoubleType },
81+ descriptor .FieldDescriptor .TYPE_DOUBLE : {"name" : "double" , "ctor" : celtypes .DoubleType },
82+ descriptor .FieldDescriptor .TYPE_INT32 : {"name" : "int32" , "ctor" : celtypes .IntType },
83+ descriptor .FieldDescriptor .TYPE_INT64 : {"name" : "int64" , "ctor" : celtypes .IntType },
84+ descriptor .FieldDescriptor .TYPE_SINT32 : {"name" : "sint32" , "ctor" : celtypes .IntType },
85+ descriptor .FieldDescriptor .TYPE_SINT64 : {"name" : "sint64" , "ctor" : celtypes .IntType },
86+ descriptor .FieldDescriptor .TYPE_SFIXED32 : {"name" : "sfixed32" , "ctor" : celtypes .IntType },
87+ descriptor .FieldDescriptor .TYPE_SFIXED64 : {"name" : "sfixed64" , "ctor" : celtypes .IntType },
88+ descriptor .FieldDescriptor .TYPE_UINT32 : {"name" : "uint32" , "ctor" : celtypes .UintType },
89+ descriptor .FieldDescriptor .TYPE_UINT64 : {"name" : "uint64" , "ctor" : celtypes .UintType },
90+ descriptor .FieldDescriptor .TYPE_FIXED32 : {"name" : "fixed32" , "ctor" : celtypes .UintType },
91+ descriptor .FieldDescriptor .TYPE_FIXED64 : {"name" : "fixed64" , "ctor" : celtypes .UintType },
11492}
11593
11694
95+ def _get_type_name (fd : typing .Any ) -> str :
96+ md = _FIELD_DESC_METADATA_MAP .get (fd )
97+ if md is None :
98+ return "unknown"
99+ return md ["name" ]
100+
101+
102+ def _get_type_ctor (fd : typing .Any ) -> typing .Optional [typing .Callable [..., celtypes .Value ]]:
103+ md = _FIELD_DESC_METADATA_MAP .get (fd )
104+ if md is None :
105+ return None
106+ return md ["ctor" ]
107+
108+
117109def _proto_message_has_field (msg : message .Message , field : descriptor .FieldDescriptor ) -> typing .Any :
118110 if field .is_extension :
119111 return msg .HasExtension (field ) # type: ignore
@@ -129,7 +121,7 @@ def _proto_message_get_field(msg: message.Message, field: descriptor.FieldDescri
129121
130122
131123def _scalar_field_value_to_cel (val : typing .Any , field : descriptor .FieldDescriptor ) -> celtypes .Value :
132- ctor = _TYPE_TO_CTOR . get (field .type )
124+ ctor = _get_type_ctor (field .type )
133125 if ctor is None :
134126 msg = "unknown field type"
135127 raise CompilationError (msg )
@@ -234,6 +226,30 @@ def _set_path_element_map_key(
234226 raise CompilationError (msg )
235227
236228
229+ class MessageType (celtypes .MapType ):
230+ msg : message .Message
231+ desc : descriptor .Descriptor
232+
233+ def __init__ (self , msg : message .Message ):
234+ super ().__init__ ()
235+ self .msg = msg
236+ self .desc = msg .DESCRIPTOR
237+ field : descriptor .FieldDescriptor
238+ for field in self .desc .fields :
239+ if field .containing_oneof is not None and not self .msg .HasField (field .name ):
240+ continue
241+ self [field .name ] = field_to_cel (self .msg , field )
242+
243+ def __getitem__ (self , name ):
244+ field = self .desc .fields_by_name [name ]
245+ if field .has_presence and not self .msg .HasField (name ):
246+ if in_has ():
247+ raise KeyError ()
248+ else :
249+ return _zero_value (field )
250+ return super ().__getitem__ (name )
251+
252+
237253class Violation :
238254 """A singular rule violation."""
239255
@@ -400,14 +416,14 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n
400416 if field .type != expected and (
401417 field .type != descriptor .FieldDescriptor .TYPE_MESSAGE or field .message_type .full_name != wrapper_name
402418 ):
403- field_type_str = FIELD_TYPE_NAMES [ field .type ]
419+ field_type_str = _get_type_name ( field .type )
404420 if expected == 0 :
405421 if wrapper_name is not None :
406422 expected_type_str = wrapper_name
407423 else :
408- expected_type_str = FIELD_TYPE_NAMES [ descriptor .FieldDescriptor .TYPE_MESSAGE ]
424+ expected_type_str = _get_type_name ( descriptor .FieldDescriptor .TYPE_MESSAGE )
409425 else :
410- expected_type_str = FIELD_TYPE_NAMES [ expected ]
426+ expected_type_str = _get_type_name ( expected )
411427 msg = f"field { field .name } has type { field_type_str } but expected { expected_type_str } "
412428 raise CompilationError (msg )
413429
0 commit comments