Skip to content

Commit 96e5e2a

Browse files
author
Steve Ayers
committed
Format
1 parent e8b4305 commit 96e5e2a

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

protovalidate/internal/rules.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from buf.validate import validate_pb2 # type: ignore
2424
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
2525

26+
# Convenience to stringify the type names for error messages
27+
FIELD_TYPE_NAMES = {v: k for k, v in vars(descriptor.FieldDescriptor).items() if k.startswith("TYPE_")}
28+
2629

2730
class CompilationError(Exception):
2831
pass
@@ -397,7 +400,15 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n
397400
if field.type != expected and (
398401
field.type != descriptor.FieldDescriptor.TYPE_MESSAGE or field.message_type.full_name != wrapper_name
399402
):
400-
msg = f"field {field.name} has type {field.type} but expected {expected}"
403+
field_type_str = FIELD_TYPE_NAMES[field.type]
404+
if expected == 0:
405+
if wrapper_name is not None:
406+
expected_type_str = wrapper_name
407+
else:
408+
expected_type_str = FIELD_TYPE_NAMES[descriptor.FieldDescriptor.TYPE_MESSAGE]
409+
else:
410+
expected_type_str = FIELD_TYPE_NAMES[expected]
411+
msg = f"field {field.name} has type {field_type_str} but expected {expected_type_str}"
401412
raise CompilationError(msg)
402413

403414

@@ -821,6 +832,7 @@ def _new_scalar_field_rule(
821832
if field_level.ignore == validate_pb2.IGNORE_ALWAYS:
822833
return None
823834
type_case = field_level.WhichOneof("type")
835+
# print(f"type case is {type_case}")
824836
if type_case is None:
825837
result = FieldRules(self._env, self._funcs, field, field_level, for_items=for_items)
826838
return result
@@ -929,7 +941,7 @@ def _new_scalar_field_rule(
929941
result = FieldRules(self._env, self._funcs, field, field_level, for_items=for_items)
930942
return result
931943
elif type_case == "any":
932-
check_field_type(field, descriptor.FieldDescriptor.TYPE_MESSAGE, "google.protobuf.Any")
944+
check_field_type(field, 0, "google.protobuf.Any")
933945
result = AnyRules(self._env, self._funcs, field, field_level)
934946
return result
935947

0 commit comments

Comments
 (0)