diff --git a/protovalidate/internal/constraints.py b/protovalidate/internal/constraints.py index 2519108d..48b1ca4d 100644 --- a/protovalidate/internal/constraints.py +++ b/protovalidate/internal/constraints.py @@ -40,7 +40,7 @@ def make_timestamp(msg: message.Message) -> celtypes.TimestampType: def unwrap(msg: message.Message) -> celtypes.Value: - return _field_to_cel(msg, msg.DESCRIPTOR.fields_by_name["value"]) + return field_to_cel(msg, msg.DESCRIPTOR.fields_by_name["value"]) _MSG_TYPE_URL_TO_CTOR: dict[str, typing.Callable[..., celtypes.Value]] = { @@ -70,7 +70,7 @@ def __init__(self, msg: message.Message): for field in self.desc.fields: if field.containing_oneof is not None and not self.msg.HasField(field.name): continue - self[field.name] = _field_to_cel(self.msg, field) + self[field.name] = field_to_cel(self.msg, field) def __getitem__(self, name): field = self.desc.fields_by_name[name] @@ -175,7 +175,7 @@ def _map_field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) - return _map_field_value_to_cel(_proto_message_get_field(msg, field), field) -def _field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value: +def field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value: if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: return _repeated_field_to_cel(msg, field) elif field.message_type is not None and not _proto_message_has_field(msg, field): @@ -374,7 +374,7 @@ def add_rule( rule_cel = None if rule_field is not None and self._rules is not None: rule_value = _proto_message_get_field(self._rules, rule_field) - rule_cel = _field_to_cel(self._rules, rule_field) + rule_cel = field_to_cel(self._rules, rule_field) self._cel.append( CelRunner( runner=prog, diff --git a/protovalidate/internal/extra_func.py b/protovalidate/internal/extra_func.py index eca6e1b5..aa9c5fe9 100644 --- a/protovalidate/internal/extra_func.py +++ b/protovalidate/internal/extra_func.py @@ -21,6 +21,7 @@ from celpy import celtypes from protovalidate.internal import string_format +from protovalidate.internal.constraints import MessageType, field_to_cel # See https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address _email_regex = re.compile( @@ -28,6 +29,19 @@ ) +def cel_get_field(message: celtypes.Value, field_name: celtypes.Value) -> celpy.Result: + if not isinstance(message, MessageType): + msg = "invalid argument, expected message" + raise celpy.CELEvalError(msg) + if not isinstance(field_name, celtypes.StringType): + msg = "invalid argument, expected string" + raise celpy.CELEvalError(msg) + if field_name not in message.desc.fields_by_name: + msg = f"no such field: {field_name}" + raise celpy.CELEvalError(msg) + return field_to_cel(message.msg, message.desc.fields_by_name[field_name]) + + def cel_is_ip(val: celtypes.Value, ver: typing.Optional[celtypes.Value] = None) -> celpy.Result: """Return True if the string is an IPv4 or IPv6 address, optionally limited to a specific version. @@ -1545,6 +1559,7 @@ def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]: # Missing standard functions "format": string_fmt.format, # protovalidate specific functions + "getField": cel_get_field, "isNan": cel_is_nan, "isInf": cel_is_inf, "isIp": cel_is_ip,