|
14 | 14 |
|
15 | 15 | import unittest |
16 | 16 |
|
| 17 | +from google.protobuf import message |
| 18 | + |
17 | 19 | import protovalidate |
18 | 20 | from gen.tests.example.v1 import validations_pb2 |
19 | | -from protovalidate.internal import config |
| 21 | +from protovalidate.internal import config, rules |
| 22 | + |
| 23 | + |
| 24 | +def get_default_validator(): |
| 25 | + """Returns a default validator created in all available ways |
| 26 | +
|
| 27 | + This allows testing for validators created via: |
| 28 | + - module-level singleton |
| 29 | + - instantiated class with no config |
| 30 | + - instantiated class with config |
| 31 | + """ |
| 32 | + return [ |
| 33 | + ("module singleton", protovalidate), |
| 34 | + ("no config", protovalidate.Validator()), |
| 35 | + ("with default config", protovalidate.Validator(config.Config())), |
| 36 | + ] |
| 37 | + |
| 38 | + |
| 39 | +class TestCollectViolations(unittest.TestCase): |
| 40 | + """Test class for testing message validations. |
20 | 41 |
|
| 42 | + A validator can be created via various ways: |
| 43 | + - a module-level singleton, which returns a default validator |
| 44 | + - instantiating the Validator class with no config, which returns a default validator |
| 45 | + - instantiating the Validator class with a config |
| 46 | +
|
| 47 | + In addition, the API for validating a message allows for two approaches: |
| 48 | + - via a call to `validate`, which will raise a ValidationError if validation fails |
| 49 | + - via a call to `collect_violations`, which will not raise an error and instead return a list of violations. |
| 50 | +
|
| 51 | + Unless otherwise noted, each test in this class tests against a validator created via all 3 methods and tests |
| 52 | + validation using both approaches. |
| 53 | + """ |
21 | 54 |
|
22 | | -class TestValidate(unittest.TestCase): |
23 | 55 | def test_ninf(self): |
24 | 56 | msg = validations_pb2.DoubleFinite() |
25 | 57 | msg.val = float("-inf") |
26 | | - violations = protovalidate.collect_violations(msg) |
27 | | - self.assertEqual(len(violations), 1) |
28 | | - self.assertEqual(violations[0].proto.rule_id, "double.finite") |
29 | | - self.assertEqual(violations[0].field_value, msg.val) |
30 | | - self.assertEqual(violations[0].rule_value, True) |
| 58 | + |
| 59 | + expected_violation = rules.Violation() |
| 60 | + expected_violation.proto.message = "value must be finite" |
| 61 | + expected_violation.proto.rule_id = "double.finite" |
| 62 | + expected_violation.field_value = msg.val |
| 63 | + expected_violation.rule_value = True |
| 64 | + |
| 65 | + self._run_invalid_tests(msg, [expected_violation]) |
31 | 66 |
|
32 | 67 | def test_map_key(self): |
33 | 68 | msg = validations_pb2.MapKeys() |
34 | 69 | msg.val[1] = "a" |
35 | | - violations = protovalidate.collect_violations(msg) |
36 | | - self.assertEqual(len(violations), 1) |
37 | | - self.assertEqual(violations[0].proto.for_key, True) |
38 | | - self.assertEqual(violations[0].field_value, 1) |
39 | | - self.assertEqual(violations[0].rule_value, 0) |
40 | 70 |
|
41 | | - def test_sfixed64(self): |
| 71 | + expected_violation = rules.Violation() |
| 72 | + expected_violation.proto.message = "value must be less than 0" |
| 73 | + expected_violation.proto.rule_id = "sint64.lt" |
| 74 | + expected_violation.proto.for_key = True |
| 75 | + expected_violation.field_value = 1 |
| 76 | + expected_violation.rule_value = 0 |
| 77 | + |
| 78 | + self._run_invalid_tests(msg, [expected_violation]) |
| 79 | + |
| 80 | + def test_sfixed64_valid(self): |
42 | 81 | msg = validations_pb2.SFixed64ExLTGT(val=11) |
43 | | - protovalidate.validate(msg) |
44 | 82 |
|
45 | | - violations = protovalidate.collect_violations(msg) |
46 | | - self.assertEqual(len(violations), 0) |
| 83 | + self._run_valid_tests(msg) |
47 | 84 |
|
48 | 85 | def test_oneofs(self): |
| 86 | + msg = validations_pb2.Oneof() |
| 87 | + msg.y = 123 |
| 88 | + |
| 89 | + self._run_valid_tests(msg) |
| 90 | + |
| 91 | + def test_collect_violations_into(self): |
49 | 92 | msg1 = validations_pb2.Oneof() |
50 | 93 | msg1.y = 123 |
51 | | - protovalidate.validate(msg1) |
52 | 94 |
|
53 | 95 | msg2 = validations_pb2.Oneof() |
54 | 96 | msg2.z.val = True |
55 | | - protovalidate.validate(msg2) |
56 | 97 |
|
57 | | - violations = protovalidate.collect_violations(msg1) |
58 | | - protovalidate.collect_violations(msg2, into=violations) |
59 | | - assert len(violations) == 0 |
| 98 | + for label, v in get_default_validator(): |
| 99 | + with self.subTest(label=label): |
| 100 | + # Test collect_violations into |
| 101 | + violations = v.collect_violations(msg1) |
| 102 | + v.collect_violations(msg2, into=violations) |
| 103 | + self.assertEqual(len(violations), 0) |
60 | 104 |
|
61 | 105 | def test_protovalidate_oneof_valid(self): |
62 | 106 | msg = validations_pb2.ProtovalidateOneof() |
63 | 107 | msg.a = "A" |
64 | | - protovalidate.validate(msg) |
65 | | - violations = protovalidate.collect_violations(msg) |
66 | | - assert len(violations) == 0 |
| 108 | + |
| 109 | + self._run_valid_tests(msg) |
67 | 110 |
|
68 | 111 | def test_protovalidate_oneof_violation(self): |
69 | 112 | msg = validations_pb2.ProtovalidateOneof() |
70 | 113 | msg.a = "A" |
71 | 114 | msg.b = "B" |
72 | | - with self.assertRaises(protovalidate.ValidationError) as cm: |
73 | | - protovalidate.validate(msg) |
74 | | - e = cm.exception |
75 | | - assert str(e) == "invalid ProtovalidateOneof" |
76 | | - assert len(e.violations) == 1 |
77 | | - assert e.to_proto().violations[0].message == "only one of a, b can be set" |
| 115 | + |
| 116 | + expected_violation = rules.Violation() |
| 117 | + expected_violation.proto.message = "only one of a, b can be set" |
| 118 | + expected_violation.proto.rule_id = "message.oneof" |
| 119 | + |
| 120 | + self._run_invalid_tests(msg, [expected_violation]) |
78 | 121 |
|
79 | 122 | def test_protovalidate_oneof_required_violation(self): |
80 | 123 | msg = validations_pb2.ProtovalidateOneofRequired() |
81 | | - with self.assertRaises(protovalidate.ValidationError) as cm: |
82 | | - protovalidate.validate(msg) |
83 | | - e = cm.exception |
84 | | - assert str(e) == "invalid ProtovalidateOneofRequired" |
85 | | - assert len(e.violations) == 1 |
86 | | - assert e.to_proto().violations[0].message == "one of a, b must be set" |
| 124 | + |
| 125 | + expected_violation = rules.Violation() |
| 126 | + expected_violation.proto.message = "one of a, b must be set" |
| 127 | + expected_violation.proto.rule_id = "message.oneof" |
| 128 | + |
| 129 | + self._run_invalid_tests(msg, [expected_violation]) |
87 | 130 |
|
88 | 131 | def test_protovalidate_oneof_unknown_field_name(self): |
| 132 | + """Tests that a compilation error is thrown when specifying a oneof rule with an invalid field name""" |
89 | 133 | msg = validations_pb2.ProtovalidateOneofUnknownFieldName() |
90 | | - with self.assertRaises(protovalidate.CompilationError) as cm: |
91 | | - protovalidate.validate(msg) |
92 | | - assert ( |
93 | | - str(cm.exception) == 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName' |
| 134 | + |
| 135 | + self._run_compilation_error_tests( |
| 136 | + msg, 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName' |
94 | 137 | ) |
95 | 138 |
|
96 | 139 | def test_repeated(self): |
97 | 140 | msg = validations_pb2.RepeatedEmbedSkip() |
98 | 141 | msg.val.add(val=-1) |
99 | | - protovalidate.validate(msg) |
100 | 142 |
|
101 | | - violations = protovalidate.collect_violations(msg) |
102 | | - assert len(violations) == 0 |
| 143 | + self._run_valid_tests(msg) |
103 | 144 |
|
104 | 145 | def test_maps(self): |
105 | 146 | msg = validations_pb2.MapMinMax() |
106 | | - with self.assertRaises(protovalidate.ValidationError) as cm: |
107 | | - protovalidate.validate(msg) |
108 | | - e = cm.exception |
109 | | - assert len(e.violations) == 1 |
110 | | - assert len(e.to_proto().violations) == 1 |
111 | | - assert str(e) == "invalid MapMinMax" |
112 | 147 |
|
113 | | - violations = protovalidate.collect_violations(msg) |
114 | | - assert len(violations) == 1 |
| 148 | + expected_violation = rules.Violation() |
| 149 | + expected_violation.proto.message = "map must be at least 2 entries" |
| 150 | + expected_violation.proto.rule_id = "map.min_pairs" |
| 151 | + expected_violation.field_value = {} |
| 152 | + expected_violation.rule_value = 2 |
| 153 | + |
| 154 | + self._run_invalid_tests(msg, [expected_violation]) |
115 | 155 |
|
116 | 156 | def test_timestamp(self): |
117 | 157 | msg = validations_pb2.TimestampGTNow() |
118 | | - violations = protovalidate.collect_violations(msg) |
119 | | - assert len(violations) == 0 |
| 158 | + |
| 159 | + self._run_valid_tests(msg) |
120 | 160 |
|
121 | 161 | def test_multiple_validations(self): |
| 162 | + """Test that a message with multiple violations correctly returns all of them.""" |
122 | 163 | msg = validations_pb2.MultipleValidations() |
123 | 164 | msg.title = "bar" |
124 | 165 | msg.name = "blah" |
125 | | - violations = protovalidate.collect_violations(msg) |
126 | | - assert len(violations) == 2 |
| 166 | + |
| 167 | + expected_violation1 = rules.Violation() |
| 168 | + expected_violation1.proto.message = "value does not have prefix `foo`" |
| 169 | + expected_violation1.proto.rule_id = "string.prefix" |
| 170 | + expected_violation1.field_value = msg.title |
| 171 | + expected_violation1.rule_value = "foo" |
| 172 | + |
| 173 | + expected_violation2 = rules.Violation() |
| 174 | + expected_violation2.proto.message = "value length must be at least 5 characters" |
| 175 | + expected_violation2.proto.rule_id = "string.min_len" |
| 176 | + expected_violation2.field_value = msg.name |
| 177 | + expected_violation2.rule_value = 5 |
| 178 | + |
| 179 | + self._run_invalid_tests(msg, [expected_violation1, expected_violation2]) |
127 | 180 |
|
128 | 181 | def test_fail_fast(self): |
| 182 | + """Test that fail fast correctly fails on first violation |
| 183 | +
|
| 184 | + Note this does not use a default validator, but instead uses one with a custom config |
| 185 | + so that fail_fast can be set to True. |
| 186 | + """ |
129 | 187 | msg = validations_pb2.MultipleValidations() |
130 | 188 | msg.title = "bar" |
131 | 189 | msg.name = "blah" |
| 190 | + |
| 191 | + expected_violation = rules.Violation() |
| 192 | + expected_violation.proto.message = "value does not have prefix `foo`" |
| 193 | + expected_violation.proto.rule_id = "string.prefix" |
| 194 | + expected_violation.field_value = msg.title |
| 195 | + expected_violation.rule_value = "foo" |
| 196 | + |
132 | 197 | cfg = config.Config(fail_fast=True) |
133 | 198 | validator = protovalidate.Validator(config=cfg) |
| 199 | + |
| 200 | + # Test validate |
| 201 | + with self.assertRaises(protovalidate.ValidationError) as cm: |
| 202 | + validator.validate(msg) |
| 203 | + e = cm.exception |
| 204 | + self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") |
| 205 | + self._compare_violations(e.violations, [expected_violation]) |
| 206 | + |
| 207 | + # Test collect_violations |
134 | 208 | violations = validator.collect_violations(msg) |
135 | | - assert len(violations) == 1 |
| 209 | + self._compare_violations(violations, [expected_violation]) |
| 210 | + |
| 211 | + def _run_valid_tests(self, msg: message.Message): |
| 212 | + """A helper function for testing successful validation on a given message |
| 213 | +
|
| 214 | + The tests are run using validators created via all possible methods and |
| 215 | + validation is done via a call to `validate` as well as a call to `collect_violations`. |
| 216 | + """ |
| 217 | + for label, v in get_default_validator(): |
| 218 | + with self.subTest(label=label): |
| 219 | + # Test validate |
| 220 | + try: |
| 221 | + v.validate(msg) |
| 222 | + except Exception: |
| 223 | + self.fail(f"[{label}]: unexpected validation failure") |
| 224 | + |
| 225 | + # Test collect_violations |
| 226 | + violations = v.collect_violations(msg) |
| 227 | + self.assertEqual(len(violations), 0) |
| 228 | + |
| 229 | + def _run_invalid_tests(self, msg: message.Message, expected: list[rules.Violation]): |
| 230 | + """A helper function for testing unsuccessful validation on a given message |
| 231 | +
|
| 232 | + The tests are run using validators created via all possible methods and |
| 233 | + validation is done via a call to `validate` as well as a call to `collect_violations`. |
| 234 | + """ |
| 235 | + for label, v in get_default_validator(): |
| 236 | + with self.subTest(label=label): |
| 237 | + # Test validate |
| 238 | + with self.assertRaises(protovalidate.ValidationError) as cm: |
| 239 | + v.validate(msg) |
| 240 | + e = cm.exception |
| 241 | + self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") |
| 242 | + self._compare_violations(e.violations, expected) |
| 243 | + |
| 244 | + # Test collect_violations |
| 245 | + violations = v.collect_violations(msg) |
| 246 | + self._compare_violations(violations, expected) |
| 247 | + |
| 248 | + def _run_compilation_error_tests(self, msg: message.Message, expected: str): |
| 249 | + """A helper function for testing compilation errors when validating. |
| 250 | +
|
| 251 | + The tests are run using validators created via all possible methods and |
| 252 | + validation is done via a call to `validate` as well as a call to `collect_violations`. |
| 253 | + """ |
| 254 | + for label, v in get_default_validator(): |
| 255 | + with self.subTest(label=label): |
| 256 | + with self.assertRaises(protovalidate.CompilationError) as cvce: |
| 257 | + v.collect_violations(msg) |
| 258 | + assert str(cvce.exception) == expected |
| 259 | + |
| 260 | + with self.assertRaises(protovalidate.CompilationError) as vce: |
| 261 | + v.validate(msg) |
| 262 | + assert str(vce.exception) == expected |
| 263 | + |
| 264 | + def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None: |
| 265 | + """Compares two lists of violations. The violations are expected to be in the expected order also.""" |
| 266 | + self.assertEqual(len(actual), len(expected)) |
| 267 | + for a, e in zip(actual, expected): |
| 268 | + self.assertEqual(a.proto.message, e.proto.message) |
| 269 | + self.assertEqual(a.proto.rule_id, e.proto.rule_id) |
| 270 | + self.assertEqual(a.proto.for_key, e.proto.for_key) |
| 271 | + self.assertEqual(a.field_value, e.field_value) |
| 272 | + self.assertEqual(a.rule_value, e.rule_value) |
0 commit comments