diff --git a/protovalidate/internal/rules.py b/protovalidate/internal/rules.py index 8ce6975f..87edeb42 100644 --- a/protovalidate/internal/rules.py +++ b/protovalidate/internal/rules.py @@ -155,7 +155,7 @@ def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescripto def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value: - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: # type: ignore if field.message_type is not None and field.message_type.GetOptions().map_entry: return _map_field_value_to_cel(val, field) return _repeated_field_value_to_cel(val, field) @@ -165,7 +165,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool: if field.has_presence: return not _proto_message_has_field(msg, field) - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: # type: ignore return len(_proto_message_get_field(msg, field)) == 0 return _proto_message_get_field(msg, field) == field.default_value @@ -194,7 +194,7 @@ def _map_field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) - def field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value: - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: # type: ignore return _repeated_field_to_cel(msg, field) elif field.message_type is not None and not _proto_message_has_field(msg, field): return None @@ -484,19 +484,15 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n def _is_map(field: descriptor.FieldDescriptor): - return ( - field.label == descriptor.FieldDescriptor.LABEL_REPEATED - and field.message_type is not None - and field.message_type.GetOptions().map_entry - ) + return field.is_repeated and field.message_type is not None and field.message_type.GetOptions().map_entry # type: ignore def _is_list(field: descriptor.FieldDescriptor): - return field.label == descriptor.FieldDescriptor.LABEL_REPEATED and not _is_map(field) + return field.is_repeated and not _is_map(field) # type: ignore def _zero_value(field: descriptor.FieldDescriptor): - if field.message_type is not None and field.label != descriptor.FieldDescriptor.LABEL_REPEATED: + if field.message_type is not None and not field.is_repeated: # type: ignore return _field_value_to_cel(message_factory.GetMessageClass(field.message_type)(), field) else: return _field_value_to_cel(field.default_value, field) @@ -1003,7 +999,7 @@ def _new_field_rule( field: descriptor.FieldDescriptor, rules: validate_pb2.FieldRules, ) -> FieldRules: - if field.label != descriptor.FieldDescriptor.LABEL_REPEATED: + if not field.is_repeated: # type: ignore return self._new_scalar_field_rule(field, rules) if field.message_type is not None and field.message_type.GetOptions().map_entry: key_rules = None @@ -1057,7 +1053,7 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]: if value_field.type != descriptor.FieldDescriptor.TYPE_MESSAGE: continue result.append(MapValMsgRule(self, field, key_field, value_field)) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + elif field.is_repeated: result.append(RepeatedMsgRule(self, field)) else: result.append(SubMsgRule(self, field))