Skip to content

Commit b7e1cfa

Browse files
author
Steve Ayers
committed
Feedback
1 parent 08bd69f commit b7e1cfa

File tree

1 file changed

+66
-50
lines changed

1 file changed

+66
-50
lines changed

protovalidate/internal/rules.py

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
from buf.validate import validate_pb2 # type: ignore
2424
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
2525

26-
# Convenience to stringify the type names for error messages
27-
FIELD_TYPE_NAMES = {v: k for k, v in vars(descriptor.FieldDescriptor).items() if k.startswith("TYPE_")}
28-
2926

3027
class CompilationError(Exception):
3128
pass
@@ -61,59 +58,54 @@ def unwrap(msg: message.Message) -> celtypes.Value:
6158
}
6259

6360

64-
class MessageType(celtypes.MapType):
65-
msg: message.Message
66-
desc: descriptor.Descriptor
67-
68-
def __init__(self, msg: message.Message):
69-
super().__init__()
70-
self.msg = msg
71-
self.desc = msg.DESCRIPTOR
72-
field: descriptor.FieldDescriptor
73-
for field in self.desc.fields:
74-
if field.containing_oneof is not None and not self.msg.HasField(field.name):
75-
continue
76-
self[field.name] = field_to_cel(self.msg, field)
77-
78-
def __getitem__(self, name):
79-
field = self.desc.fields_by_name[name]
80-
if field.has_presence and not self.msg.HasField(name):
81-
if in_has():
82-
raise KeyError()
83-
else:
84-
return _zero_value(field)
85-
return super().__getitem__(name)
86-
87-
8861
def _msg_to_cel(msg: message.Message) -> celtypes.Value:
8962
ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name)
9063
if ctor is not None:
9164
return ctor(msg)
9265
return MessageType(msg)
9366

9467

95-
_TYPE_TO_CTOR: dict[str, typing.Callable[..., celtypes.Value]] = {
96-
descriptor.FieldDescriptor.TYPE_MESSAGE: _msg_to_cel,
97-
descriptor.FieldDescriptor.TYPE_GROUP: _msg_to_cel,
98-
descriptor.FieldDescriptor.TYPE_ENUM: celtypes.IntType,
99-
descriptor.FieldDescriptor.TYPE_BOOL: celtypes.BoolType,
100-
descriptor.FieldDescriptor.TYPE_BYTES: celtypes.BytesType,
101-
descriptor.FieldDescriptor.TYPE_STRING: celtypes.StringType,
102-
descriptor.FieldDescriptor.TYPE_FLOAT: celtypes.DoubleType,
103-
descriptor.FieldDescriptor.TYPE_DOUBLE: celtypes.DoubleType,
104-
descriptor.FieldDescriptor.TYPE_INT32: celtypes.IntType,
105-
descriptor.FieldDescriptor.TYPE_INT64: celtypes.IntType,
106-
descriptor.FieldDescriptor.TYPE_UINT32: celtypes.UintType,
107-
descriptor.FieldDescriptor.TYPE_UINT64: celtypes.UintType,
108-
descriptor.FieldDescriptor.TYPE_SINT32: celtypes.IntType,
109-
descriptor.FieldDescriptor.TYPE_SINT64: celtypes.IntType,
110-
descriptor.FieldDescriptor.TYPE_FIXED32: celtypes.UintType,
111-
descriptor.FieldDescriptor.TYPE_FIXED64: celtypes.UintType,
112-
descriptor.FieldDescriptor.TYPE_SFIXED32: celtypes.IntType,
113-
descriptor.FieldDescriptor.TYPE_SFIXED64: celtypes.IntType,
68+
class FieldDescMetadata(typing.TypedDict):
69+
name: str
70+
ctor: typing.Callable[..., celtypes.Value]
71+
72+
73+
_FIELD_DESC_METADATA_MAP: dict[typing.Any, FieldDescMetadata] = {
74+
descriptor.FieldDescriptor.TYPE_MESSAGE: {"name": "message", "ctor": _msg_to_cel},
75+
descriptor.FieldDescriptor.TYPE_GROUP: {"name": "group", "ctor": _msg_to_cel},
76+
descriptor.FieldDescriptor.TYPE_ENUM: {"name": "enum", "ctor": celtypes.IntType},
77+
descriptor.FieldDescriptor.TYPE_BOOL: {"name": "bool", "ctor": celtypes.BoolType},
78+
descriptor.FieldDescriptor.TYPE_BYTES: {"name": "bytes", "ctor": celtypes.BytesType},
79+
descriptor.FieldDescriptor.TYPE_STRING: {"name": "string", "ctor": celtypes.StringType},
80+
descriptor.FieldDescriptor.TYPE_FLOAT: {"name": "float", "ctor": celtypes.DoubleType},
81+
descriptor.FieldDescriptor.TYPE_DOUBLE: {"name": "double", "ctor": celtypes.DoubleType},
82+
descriptor.FieldDescriptor.TYPE_INT32: {"name": "int32", "ctor": celtypes.IntType},
83+
descriptor.FieldDescriptor.TYPE_INT64: {"name": "int64", "ctor": celtypes.IntType},
84+
descriptor.FieldDescriptor.TYPE_SINT32: {"name": "sint32", "ctor": celtypes.IntType},
85+
descriptor.FieldDescriptor.TYPE_SINT64: {"name": "sint64", "ctor": celtypes.IntType},
86+
descriptor.FieldDescriptor.TYPE_SFIXED32: {"name": "sfixed32", "ctor": celtypes.IntType},
87+
descriptor.FieldDescriptor.TYPE_SFIXED64: {"name": "sfixed64", "ctor": celtypes.IntType},
88+
descriptor.FieldDescriptor.TYPE_UINT32: {"name": "uint32", "ctor": celtypes.UintType},
89+
descriptor.FieldDescriptor.TYPE_UINT64: {"name": "uint64", "ctor": celtypes.UintType},
90+
descriptor.FieldDescriptor.TYPE_FIXED32: {"name": "fixed32", "ctor": celtypes.UintType},
91+
descriptor.FieldDescriptor.TYPE_FIXED64: {"name": "fixed64", "ctor": celtypes.UintType},
11492
}
11593

11694

95+
def _get_type_name(fd: typing.Any) -> str:
96+
md = _FIELD_DESC_METADATA_MAP.get(fd)
97+
if md is None:
98+
return "unknown"
99+
return md["name"]
100+
101+
102+
def _get_type_ctor(fd: typing.Any) -> typing.Optional[typing.Callable[..., celtypes.Value]]:
103+
md = _FIELD_DESC_METADATA_MAP.get(fd)
104+
if md is None:
105+
return None
106+
return md["ctor"]
107+
108+
117109
def _proto_message_has_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any:
118110
if field.is_extension:
119111
return msg.HasExtension(field) # type: ignore
@@ -129,7 +121,7 @@ def _proto_message_get_field(msg: message.Message, field: descriptor.FieldDescri
129121

130122

131123
def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
132-
ctor = _TYPE_TO_CTOR.get(field.type)
124+
ctor = _get_type_ctor(field.type)
133125
if ctor is None:
134126
msg = "unknown field type"
135127
raise CompilationError(msg)
@@ -234,6 +226,30 @@ def _set_path_element_map_key(
234226
raise CompilationError(msg)
235227

236228

229+
class MessageType(celtypes.MapType):
230+
msg: message.Message
231+
desc: descriptor.Descriptor
232+
233+
def __init__(self, msg: message.Message):
234+
super().__init__()
235+
self.msg = msg
236+
self.desc = msg.DESCRIPTOR
237+
field: descriptor.FieldDescriptor
238+
for field in self.desc.fields:
239+
if field.containing_oneof is not None and not self.msg.HasField(field.name):
240+
continue
241+
self[field.name] = field_to_cel(self.msg, field)
242+
243+
def __getitem__(self, name):
244+
field = self.desc.fields_by_name[name]
245+
if field.has_presence and not self.msg.HasField(name):
246+
if in_has():
247+
raise KeyError()
248+
else:
249+
return _zero_value(field)
250+
return super().__getitem__(name)
251+
252+
237253
class Violation:
238254
"""A singular rule violation."""
239255

@@ -400,14 +416,14 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n
400416
if field.type != expected and (
401417
field.type != descriptor.FieldDescriptor.TYPE_MESSAGE or field.message_type.full_name != wrapper_name
402418
):
403-
field_type_str = FIELD_TYPE_NAMES[field.type]
419+
field_type_str = _get_type_name(field.type)
404420
if expected == 0:
405421
if wrapper_name is not None:
406422
expected_type_str = wrapper_name
407423
else:
408-
expected_type_str = FIELD_TYPE_NAMES[descriptor.FieldDescriptor.TYPE_MESSAGE]
424+
expected_type_str = _get_type_name(descriptor.FieldDescriptor.TYPE_MESSAGE)
409425
else:
410-
expected_type_str = FIELD_TYPE_NAMES[expected]
426+
expected_type_str = _get_type_name(expected)
411427
msg = f"field {field.name} has type {field_type_str} but expected {expected_type_str}"
412428
raise CompilationError(msg)
413429

0 commit comments

Comments
 (0)