Skip to content

Commit a6353d1

Browse files
author
Steve Ayers
committed
Switch to a config
1 parent 1e33057 commit a6353d1

File tree

3 files changed

+204
-63
lines changed

3 files changed

+204
-63
lines changed

protovalidate/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
ValidationError = validator.ValidationError
2020
Violations = validator.Violations
2121

22-
_validator = Validator()
23-
validate = _validator.validate
24-
collect_violations = _validator.collect_violations
22+
_default_validator = Validator()
23+
validate = _default_validator.validate
24+
collect_violations = _default_validator.collect_violations
2525

2626
__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]

protovalidate/validator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def validate(
5454
message: The message to validate.
5555
Raises:
5656
CompilationError: If the static rules could not be compiled.
57-
ValidationError: If the message is invalid.
57+
ValidationError: If the message is invalid. The violations raised as part of this error should
58+
always be equal to the list of violations returned by `collect_violations`.
5859
"""
5960
violations = self.collect_violations(message)
6061
if len(violations) > 0:
@@ -69,9 +70,12 @@ def collect_violations(
6970
) -> list[Violation]:
7071
"""
7172
Validates the given message against the static rules defined in
72-
the message's descriptor. Compared to validate, collect_violations is
73-
faster but puts the burden of raising an appropriate exception on the
74-
caller.
73+
the message's descriptor. Compared to `validate`, `collect_violations` simply
74+
returns the violations as a list and puts the burden of raising an appropriate
75+
exception on the caller.
76+
77+
The violations returned from this method should always be equal to the violations
78+
raised as part of the ValidationError in the call to `validate`.
7579
7680
Parameters:
7781
message: The message to validate.

tests/validate_test.py

Lines changed: 193 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,122 +14,259 @@
1414

1515
import unittest
1616

17+
from google.protobuf import message
18+
1719
import protovalidate
1820
from gen.tests.example.v1 import validations_pb2
19-
from protovalidate.internal import config
21+
from protovalidate.internal import config, rules
22+
23+
24+
def get_default_validator():
25+
"""Returns a default validator created in all available ways
26+
27+
This allows testing for validators created via:
28+
- module-level singleton
29+
- instantiated class with no config
30+
- instantiated class with config
31+
"""
32+
return [
33+
("module singleton", protovalidate),
34+
("no config", protovalidate.Validator()),
35+
("with default config", protovalidate.Validator(config.Config())),
36+
]
37+
38+
39+
class TestCollectViolations(unittest.TestCase):
40+
"""Test class for testing message validations.
2041
42+
A validator can be created via various ways:
43+
- a module-level singleton, which returns a default validator
44+
- instantiating the Validator class with no config, which returns a default validator
45+
- instantiating the Validator class with a config
46+
47+
In addition, the API for validating a message allows for two approaches:
48+
- via a call to `validate`, which will raise a ValidationError if validation fails
49+
- via a call to `collect_violations`, which will not raise an error and instead return a list of violations.
50+
51+
Unless otherwise noted, each test in this class tests against a validator created via all 3 methods and tests
52+
validation using both approaches.
53+
"""
2154

22-
class TestValidate(unittest.TestCase):
2355
def test_ninf(self):
2456
msg = validations_pb2.DoubleFinite()
2557
msg.val = float("-inf")
26-
violations = protovalidate.collect_violations(msg)
27-
self.assertEqual(len(violations), 1)
28-
self.assertEqual(violations[0].proto.rule_id, "double.finite")
29-
self.assertEqual(violations[0].field_value, msg.val)
30-
self.assertEqual(violations[0].rule_value, True)
58+
59+
expected_violation = rules.Violation()
60+
expected_violation.proto.message = "value must be finite"
61+
expected_violation.proto.rule_id = "double.finite"
62+
expected_violation.field_value = msg.val
63+
expected_violation.rule_value = True
64+
65+
self._run_invalid_tests(msg, [expected_violation])
3166

3267
def test_map_key(self):
3368
msg = validations_pb2.MapKeys()
3469
msg.val[1] = "a"
35-
violations = protovalidate.collect_violations(msg)
36-
self.assertEqual(len(violations), 1)
37-
self.assertEqual(violations[0].proto.for_key, True)
38-
self.assertEqual(violations[0].field_value, 1)
39-
self.assertEqual(violations[0].rule_value, 0)
4070

41-
def test_sfixed64(self):
71+
expected_violation = rules.Violation()
72+
expected_violation.proto.message = "value must be less than 0"
73+
expected_violation.proto.rule_id = "sint64.lt"
74+
expected_violation.proto.for_key = True
75+
expected_violation.field_value = 1
76+
expected_violation.rule_value = 0
77+
78+
self._run_invalid_tests(msg, [expected_violation])
79+
80+
def test_sfixed64_valid(self):
4281
msg = validations_pb2.SFixed64ExLTGT(val=11)
43-
protovalidate.validate(msg)
4482

45-
violations = protovalidate.collect_violations(msg)
46-
self.assertEqual(len(violations), 0)
83+
self._run_valid_tests(msg)
4784

4885
def test_oneofs(self):
86+
msg = validations_pb2.Oneof()
87+
msg.y = 123
88+
89+
self._run_valid_tests(msg)
90+
91+
def test_collect_violations_into(self):
4992
msg1 = validations_pb2.Oneof()
5093
msg1.y = 123
51-
protovalidate.validate(msg1)
5294

5395
msg2 = validations_pb2.Oneof()
5496
msg2.z.val = True
55-
protovalidate.validate(msg2)
5697

57-
violations = protovalidate.collect_violations(msg1)
58-
protovalidate.collect_violations(msg2, into=violations)
59-
assert len(violations) == 0
98+
for label, v in get_default_validator():
99+
with self.subTest(label=label):
100+
# Test collect_violations into
101+
violations = v.collect_violations(msg1)
102+
v.collect_violations(msg2, into=violations)
103+
self.assertEqual(len(violations), 0)
60104

61105
def test_protovalidate_oneof_valid(self):
62106
msg = validations_pb2.ProtovalidateOneof()
63107
msg.a = "A"
64-
protovalidate.validate(msg)
65-
violations = protovalidate.collect_violations(msg)
66-
assert len(violations) == 0
108+
109+
self._run_valid_tests(msg)
67110

68111
def test_protovalidate_oneof_violation(self):
69112
msg = validations_pb2.ProtovalidateOneof()
70113
msg.a = "A"
71114
msg.b = "B"
72-
with self.assertRaises(protovalidate.ValidationError) as cm:
73-
protovalidate.validate(msg)
74-
e = cm.exception
75-
assert str(e) == "invalid ProtovalidateOneof"
76-
assert len(e.violations) == 1
77-
assert e.to_proto().violations[0].message == "only one of a, b can be set"
115+
116+
expected_violation = rules.Violation()
117+
expected_violation.proto.message = "only one of a, b can be set"
118+
expected_violation.proto.rule_id = "message.oneof"
119+
120+
self._run_invalid_tests(msg, [expected_violation])
78121

79122
def test_protovalidate_oneof_required_violation(self):
80123
msg = validations_pb2.ProtovalidateOneofRequired()
81-
with self.assertRaises(protovalidate.ValidationError) as cm:
82-
protovalidate.validate(msg)
83-
e = cm.exception
84-
assert str(e) == "invalid ProtovalidateOneofRequired"
85-
assert len(e.violations) == 1
86-
assert e.to_proto().violations[0].message == "one of a, b must be set"
124+
125+
expected_violation = rules.Violation()
126+
expected_violation.proto.message = "one of a, b must be set"
127+
expected_violation.proto.rule_id = "message.oneof"
128+
129+
self._run_invalid_tests(msg, [expected_violation])
87130

88131
def test_protovalidate_oneof_unknown_field_name(self):
132+
"""Tests that a compilation error is thrown when specifying a oneof rule with an invalid field name"""
89133
msg = validations_pb2.ProtovalidateOneofUnknownFieldName()
90-
with self.assertRaises(protovalidate.CompilationError) as cm:
91-
protovalidate.validate(msg)
92-
assert (
93-
str(cm.exception) == 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName'
134+
135+
self._run_compilation_error_tests(
136+
msg, 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName'
94137
)
95138

96139
def test_repeated(self):
97140
msg = validations_pb2.RepeatedEmbedSkip()
98141
msg.val.add(val=-1)
99-
protovalidate.validate(msg)
100142

101-
violations = protovalidate.collect_violations(msg)
102-
assert len(violations) == 0
143+
self._run_valid_tests(msg)
103144

104145
def test_maps(self):
105146
msg = validations_pb2.MapMinMax()
106-
with self.assertRaises(protovalidate.ValidationError) as cm:
107-
protovalidate.validate(msg)
108-
e = cm.exception
109-
assert len(e.violations) == 1
110-
assert len(e.to_proto().violations) == 1
111-
assert str(e) == "invalid MapMinMax"
112147

113-
violations = protovalidate.collect_violations(msg)
114-
assert len(violations) == 1
148+
expected_violation = rules.Violation()
149+
expected_violation.proto.message = "map must be at least 2 entries"
150+
expected_violation.proto.rule_id = "map.min_pairs"
151+
expected_violation.field_value = {}
152+
expected_violation.rule_value = 2
153+
154+
self._run_invalid_tests(msg, [expected_violation])
115155

116156
def test_timestamp(self):
117157
msg = validations_pb2.TimestampGTNow()
118-
violations = protovalidate.collect_violations(msg)
119-
assert len(violations) == 0
158+
159+
self._run_valid_tests(msg)
120160

121161
def test_multiple_validations(self):
162+
"""Test that a message with multiple violations correctly returns all of them."""
122163
msg = validations_pb2.MultipleValidations()
123164
msg.title = "bar"
124165
msg.name = "blah"
125-
violations = protovalidate.collect_violations(msg)
126-
assert len(violations) == 2
166+
167+
expected_violation1 = rules.Violation()
168+
expected_violation1.proto.message = "value does not have prefix `foo`"
169+
expected_violation1.proto.rule_id = "string.prefix"
170+
expected_violation1.field_value = msg.title
171+
expected_violation1.rule_value = "foo"
172+
173+
expected_violation2 = rules.Violation()
174+
expected_violation2.proto.message = "value length must be at least 5 characters"
175+
expected_violation2.proto.rule_id = "string.min_len"
176+
expected_violation2.field_value = msg.name
177+
expected_violation2.rule_value = 5
178+
179+
self._run_invalid_tests(msg, [expected_violation1, expected_violation2])
127180

128181
def test_fail_fast(self):
182+
"""Test that fail fast correctly fails on first violation
183+
184+
Note this does not use a default validator, but instead uses one with a custom config
185+
so that fail_fast can be set to True.
186+
"""
129187
msg = validations_pb2.MultipleValidations()
130188
msg.title = "bar"
131189
msg.name = "blah"
190+
191+
expected_violation = rules.Violation()
192+
expected_violation.proto.message = "value does not have prefix `foo`"
193+
expected_violation.proto.rule_id = "string.prefix"
194+
expected_violation.field_value = msg.title
195+
expected_violation.rule_value = "foo"
196+
132197
cfg = config.Config(fail_fast=True)
133198
validator = protovalidate.Validator(config=cfg)
199+
200+
# Test validate
201+
with self.assertRaises(protovalidate.ValidationError) as cm:
202+
validator.validate(msg)
203+
e = cm.exception
204+
self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}")
205+
self._compare_violations(e.violations, [expected_violation])
206+
207+
# Test collect_violations
134208
violations = validator.collect_violations(msg)
135-
assert len(violations) == 1
209+
self._compare_violations(violations, [expected_violation])
210+
211+
def _run_valid_tests(self, msg: message.Message):
212+
"""A helper function for testing successful validation on a given message
213+
214+
The tests are run using validators created via all possible methods and
215+
validation is done via a call to `validate` as well as a call to `collect_violations`.
216+
"""
217+
for label, v in get_default_validator():
218+
with self.subTest(label=label):
219+
# Test validate
220+
try:
221+
v.validate(msg)
222+
except Exception:
223+
self.fail(f"[{label}]: unexpected validation failure")
224+
225+
# Test collect_violations
226+
violations = v.collect_violations(msg)
227+
self.assertEqual(len(violations), 0)
228+
229+
def _run_invalid_tests(self, msg: message.Message, expected: list[rules.Violation]):
230+
"""A helper function for testing unsuccessful validation on a given message
231+
232+
The tests are run using validators created via all possible methods and
233+
validation is done via a call to `validate` as well as a call to `collect_violations`.
234+
"""
235+
for label, v in get_default_validator():
236+
with self.subTest(label=label):
237+
# Test validate
238+
with self.assertRaises(protovalidate.ValidationError) as cm:
239+
v.validate(msg)
240+
e = cm.exception
241+
self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}")
242+
self._compare_violations(e.violations, expected)
243+
244+
# Test collect_violations
245+
violations = v.collect_violations(msg)
246+
self._compare_violations(violations, expected)
247+
248+
def _run_compilation_error_tests(self, msg: message.Message, expected: str):
249+
"""A helper function for testing compilation errors when validating.
250+
251+
The tests are run using validators created via all possible methods and
252+
validation is done via a call to `validate` as well as a call to `collect_violations`.
253+
"""
254+
for label, v in get_default_validator():
255+
with self.subTest(label=label):
256+
with self.assertRaises(protovalidate.CompilationError) as cvce:
257+
v.collect_violations(msg)
258+
assert str(cvce.exception) == expected
259+
260+
with self.assertRaises(protovalidate.CompilationError) as vce:
261+
v.validate(msg)
262+
assert str(vce.exception) == expected
263+
264+
def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None:
265+
"""Compares two lists of violations. The violations are expected to be in the expected order also."""
266+
self.assertEqual(len(actual), len(expected))
267+
for a, e in zip(actual, expected):
268+
self.assertEqual(a.proto.message, e.proto.message)
269+
self.assertEqual(a.proto.rule_id, e.proto.rule_id)
270+
self.assertEqual(a.proto.for_key, e.proto.for_key)
271+
self.assertEqual(a.field_value, e.field_value)
272+
self.assertEqual(a.rule_value, e.rule_value)

0 commit comments

Comments
 (0)