Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 65 additions & 63 deletions gen/buf/validate/validate_pb2.py

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions gen/buf/validate/validate_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 27 additions & 15 deletions gen/tests/example/v1/validations_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions gen/tests/example/v1/validations_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions proto/tests/example/v1/validations.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,34 @@ message Oneof {
}
}

message ProtovalidateOneof {
string a = 1;
string b = 2;
bool unrelated = 3;
option (buf.validate.message).oneof = {
fields: ["a", "b"]
};
}

message ProtovalidateOneofRequired {
string a = 1;
string b = 2;
bool unrelated = 3;
option (buf.validate.message).oneof = {
fields: ["a", "b"]
required: true
};
}

message ProtovalidateOneofUnknownFieldName {
string a = 1;
string b = 2;
bool unrelated = 3;
option (buf.validate.message).oneof = {
fields: ["a", "b", "xxx"]
};
}

message TimestampGTNow {
google.protobuf.Timestamp val = 1 [(buf.validate.field).timestamp.gt_now = true];
}
Expand Down
59 changes: 56 additions & 3 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,62 @@ def add_rule(
)


class MessageOneofRule(Rules):
"""Validates a single buf.validate.MessageOneofRule given via the message option (buf.validate.message).oneof"""

def __init__(self, fields: list[descriptor.FieldDescriptor], *, required: bool):
self._fields = fields
self._required = required

def validate(self, ctx: RuleContext, msg: message.Message):
num_set_fields = sum(1 for field in self._fields if not _is_empty_field(msg, field))
if num_set_fields > 1:
ctx.add(
Violation(
rule_id="message.oneof",
message=f"only one of {', '.join([field.name for field in self._fields])} can be set",
)
)
if self._required and num_set_fields == 0:
ctx.add(
Violation(
rule_id="message.oneof",
message=f"one of {', '.join([field.name for field in self._fields])} must be set",
)
)


class MessageRules(CelRules):
"""Message-level rules."""

_oneofs: list[MessageOneofRule]

def __init__(self, rules: typing.Optional[message.Message], desc: descriptor.Descriptor):
super().__init__(rules)
self._oneofs = []
self._desc = desc

def validate(self, ctx: RuleContext, message: message.Message):
self._validate_cel(ctx, this_cel=_msg_to_cel(message))
if ctx.done:
return
for oneof in self._oneofs:
oneof.validate(ctx, message)
if ctx.done:
return

def add_oneof(
self,
rule: validate_pb2.MessageOneofRule,
):
fields = []
for name in rule.fields:
if name in self._desc.fields_by_name:
fields.append(self._desc.fields_by_name[name])
else:
msg = f'field "{name}" not found in message {self._desc.full_name}'
raise CompilationError(msg)
self._oneofs.append(MessageOneofRule(fields, required=rule.required))


def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_name: typing.Optional[str] = None):
Expand Down Expand Up @@ -832,8 +883,10 @@ def get(self, descriptor: descriptor.Descriptor) -> list[Rules]:
raise result
return result

def _new_message_rule(self, rules: validate_pb2.MessageRules) -> MessageRules:
result = MessageRules(rules)
def _new_message_rule(self, rules: validate_pb2.MessageRules, desc: descriptor.Descriptor) -> MessageRules:
result = MessageRules(rules, desc)
for oneof in rules.oneof:
result.add_oneof(oneof)
for cel in rules.cel:
result.add_rule(self._env, self._funcs, cel)
return result
Expand Down Expand Up @@ -989,7 +1042,7 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
message_level = desc.GetOptions().Extensions[validate_pb2.message]
if message_level.disabled:
return []
if rule := self._new_message_rule(message_level):
if rule := self._new_message_rule(message_level, desc):
result.append(rule)

for oneof in desc.oneofs:
Expand Down
45 changes: 40 additions & 5 deletions tests/validate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,41 @@ def test_oneofs(self):
protovalidate.collect_violations(msg2, into=violations)
assert len(violations) == 0

def test_protovalidate_oneof_valid(self):
msg = validations_pb2.ProtovalidateOneof()
msg.a = "A"
protovalidate.validate(msg)
violations = protovalidate.collect_violations(msg)
assert len(violations) == 0

def test_protovalidate_oneof_violation(self):
msg = validations_pb2.ProtovalidateOneof()
msg.a = "A"
msg.b = "B"
with self.assertRaises(protovalidate.ValidationError) as cm:
protovalidate.validate(msg)
e = cm.exception
assert str(e) == "invalid ProtovalidateOneof"
assert len(e.violations) == 1
assert e.to_proto().violations[0].message == "only one of a, b can be set"

def test_protovalidate_oneof_required_violation(self):
msg = validations_pb2.ProtovalidateOneofRequired()
with self.assertRaises(protovalidate.ValidationError) as cm:
protovalidate.validate(msg)
e = cm.exception
assert str(e) == "invalid ProtovalidateOneofRequired"
assert len(e.violations) == 1
assert e.to_proto().violations[0].message == "one of a, b must be set"

def test_protovalidate_oneof_unknown_field_name(self):
msg = validations_pb2.ProtovalidateOneofUnknownFieldName()
with self.assertRaises(protovalidate.CompilationError) as cm:
protovalidate.validate(msg)
assert (
str(cm.exception) == 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName'
)

def test_repeated(self):
msg = validations_pb2.RepeatedEmbedSkip()
msg.val.add(val=-1)
Expand All @@ -67,12 +102,12 @@ def test_repeated(self):

def test_maps(self):
msg = validations_pb2.MapMinMax()
try:
with self.assertRaises(protovalidate.ValidationError) as cm:
protovalidate.validate(msg)
except protovalidate.ValidationError as e:
assert len(e.violations) == 1
assert len(e.to_proto().violations) == 1
assert str(e) == "invalid MapMinMax"
e = cm.exception
assert len(e.violations) == 1
assert len(e.to_proto().violations) == 1
assert str(e) == "invalid MapMinMax"

violations = protovalidate.collect_violations(msg)
assert len(violations) == 1
Expand Down
Loading