From 1f62639d5cf0ece82221e0b43068e5ab42e9fe09 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 11 Jun 2025 16:23:33 -0400 Subject: [PATCH 1/9] Fix flags --- protovalidate/internal/extra_func.py | 82 +++++++++++++++++++++++++++- tests/matches_test.py | 50 +++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 tests/matches_test.py diff --git a/protovalidate/internal/extra_func.py b/protovalidate/internal/extra_func.py index ae2b02c0..c0d4912c 100644 --- a/protovalidate/internal/extra_func.py +++ b/protovalidate/internal/extra_func.py @@ -14,6 +14,9 @@ import math import re +import sys +from functools import reduce +import operator import typing from urllib import parse as urlparse @@ -1553,13 +1556,90 @@ def __peek(self, char: str) -> bool: return self._index < len(self._string) and self._string[self._index] == char +# Patterns that are supported in Python's re package and not in re2. +# RE2: https://github.com/google/re2/wiki/syntax +invalid_patterns = [ + r"\\[1-9]", # backreference + r"\\k<\w+>", # backreference + r"\(\?\=", # lookahead + r"\(\?\!", # negative lookahead + r"\(\?\<\=", # lookbehind + r"\(\?\<\!", # negative lookbehind + r"\\c[A-Z]", # control character + r"\\u[0-9a-fA-F]{4}", # UTF-16 code-unit + r"\\0(?!\d)", # NUL + r"\[\\b.*\]", # Backspace eg: [\b] +] + +flag_pattern = re.compile(r"^\(\?(?P[ims\-]+)\)"); + +flag_mapping = { + "a": re.A, + "i": re.I, + "l": re.L, + "m": re.M, +} + +def flags_from_letters(letters: str) -> int: + return reduce(operator.or_, (flag_mapping[c] for c in letters if c in flag_mapping), 0) + +def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result: + if not isinstance(text, celtypes.StringType): + msg = "invalid argument for text, expected string" + raise celpy.CELEvalError(msg) + if not isinstance(pattern, celtypes.StringType): + msg = "invalid argument for pattern, expected string" + raise celpy.CELEvalError(msg) + + for invalid_pattern in invalid_patterns: + r = re.search(invalid_pattern, pattern) + if r is not None: + msg = f"error evaluating pattern {pattern}, invalid RE2 syntax" + raise celpy.CELEvalError(msg) + # CEL uses RE2 syntax which is a subset of Python re except for + # the flags and the ability to change the flags mid sequence. + # + # The conformance tests use flags at the very beginning of the sequence, which + # is likely the most common place where this rare feature will be used. + # + # Instead of importing an RE2 engine to be able to support this niche, we + # can instead just check for the flags at the very beginning and apply them. + # + # Unsupported flags and flags mid sequence will fail to compile the regex. + # + # Users can choose to override this function and provide an RE2 engine if they really need to. + flags = "" + flag_matches = re.match(flag_pattern, pattern) + pattern_str = pattern + if flag_matches is not None: + ms = flag_matches.groupdict() + flagsies = ms["flags"] + for fl in flagsies: + if fl == "-": + continue + flags += fl + + pattern_str = pattern[len(flag_matches[0]):] + flags_enums = flags_from_letters(flags) + + expresh = re.compile(pattern_str, flags=flags_enums) + + try: + m = re.search(expresh, text) + except re.error as ex: + return celpy.CELEvalError("match error", ex.__class__, ex.args) + + return celtypes.BoolType(m is not None) + + def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]: - # TODO(#257): Fix types and add tests for StringFormat. # For now, ignoring the type. string_fmt = string_format.StringFormat(locale) # type: ignore return { # Missing standard functions "format": string_fmt.format, + # Overridden standard functions + "matches": cel_matches, # protovalidate specific functions "getField": cel_get_field, "isNan": cel_is_nan, diff --git a/tests/matches_test.py b/tests/matches_test.py new file mode 100644 index 00000000..a4274499 --- /dev/null +++ b/tests/matches_test.py @@ -0,0 +1,50 @@ +# Copyright 2023-2025 Buf Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import celpy +from celpy import celtypes + +from protovalidate.internal import extra_func + +invalid_patterns = [ + r"\1", + r"\k", + r"Jack(?=Sprat)", + "Jack(?!Sprat)", + "(?<=Sprat)Jack", + "(? None: + result = extra_func.cel_matches(celtypes.StringType("!@#$%^&*()"), celtypes.StringType("(?i)^[a-z0-9]+$")) + self.assertFalse(result) + + From 3bb71cb3ecc089963cdfa0c56d8fc24ac930f0e2 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 11 Jun 2025 16:39:40 -0400 Subject: [PATCH 2/9] Tests --- protovalidate/internal/extra_func.py | 63 +++++++++++++--------------- tests/matches_test.py | 2 - 2 files changed, 30 insertions(+), 35 deletions(-) diff --git a/protovalidate/internal/extra_func.py b/protovalidate/internal/extra_func.py index c0d4912c..14f9b18c 100644 --- a/protovalidate/internal/extra_func.py +++ b/protovalidate/internal/extra_func.py @@ -13,11 +13,10 @@ # limitations under the License. import math -import re -import sys -from functools import reduce import operator +import re import typing +from functools import reduce from urllib import parse as urlparse import celpy @@ -1571,17 +1570,14 @@ def __peek(self, char: str) -> bool: r"\[\\b.*\]", # Backspace eg: [\b] ] -flag_pattern = re.compile(r"^\(\?(?P[ims\-]+)\)"); - +flag_pattern = re.compile(r"^\(\?(?P[ims\-]+)\)") flag_mapping = { - "a": re.A, - "i": re.I, - "l": re.L, - "m": re.M, + "a": re.A, + "i": re.I, + "l": re.L, + "m": re.M, } -def flags_from_letters(letters: str) -> int: - return reduce(operator.or_, (flag_mapping[c] for c in letters if c in flag_mapping), 0) def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result: if not isinstance(text, celtypes.StringType): @@ -1591,41 +1587,42 @@ def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result: msg = "invalid argument for pattern, expected string" raise celpy.CELEvalError(msg) + # Simulate re2 by failing on any patterns not compatible with re2 syntax for invalid_pattern in invalid_patterns: r = re.search(invalid_pattern, pattern) if r is not None: msg = f"error evaluating pattern {pattern}, invalid RE2 syntax" raise celpy.CELEvalError(msg) - # CEL uses RE2 syntax which is a subset of Python re except for - # the flags and the ability to change the flags mid sequence. - # - # The conformance tests use flags at the very beginning of the sequence, which - # is likely the most common place where this rare feature will be used. - # - # Instead of importing an RE2 engine to be able to support this niche, we - # can instead just check for the flags at the very beginning and apply them. - # - # Unsupported flags and flags mid sequence will fail to compile the regex. - # - # Users can choose to override this function and provide an RE2 engine if they really need to. + + # CEL uses RE2 syntax which is a subset of Python re except for + # the flags and the ability to change the flags mid sequence. + # + # The conformance tests use flags at the very beginning of the sequence, which + # is likely the most common place where this rare feature will be used. + # + # Instead of importing an RE2 engine to be able to support this niche, we + # can instead just check for the flags at the very beginning and apply them. + # + # Unsupported flags and flags mid sequence will fail to compile the regex. + # + # Users can choose to override this function and provide an RE2 engine if they really need to. flags = "" flag_matches = re.match(flag_pattern, pattern) - pattern_str = pattern if flag_matches is not None: - ms = flag_matches.groupdict() - flagsies = ms["flags"] - for fl in flagsies: + flag_group = flag_matches.groupdict()["flags"] + for fl in flag_group: + # Flag removal, don't include it in the output if fl == "-": continue flags += fl - - pattern_str = pattern[len(flag_matches[0]):] - flags_enums = flags_from_letters(flags) - - expresh = re.compile(pattern_str, flags=flags_enums) + pattern_str = pattern[len(flag_matches[0]) :] + flags_enums = reduce(operator.or_, (flag_mapping[c] for c in flags if c in flag_mapping), 0) + exp = re.compile(pattern_str, flags=flags_enums) + else: + exp = re.compile(pattern) try: - m = re.search(expresh, text) + m = re.search(exp, text) except re.error as ex: return celpy.CELEvalError("match error", ex.__class__, ex.args) diff --git a/tests/matches_test.py b/tests/matches_test.py index a4274499..9708b741 100644 --- a/tests/matches_test.py +++ b/tests/matches_test.py @@ -46,5 +46,3 @@ def test_invalid_re2_syntax(self): def test_flags(self) -> None: result = extra_func.cel_matches(celtypes.StringType("!@#$%^&*()"), celtypes.StringType("(?i)^[a-z0-9]+$")) self.assertFalse(result) - - From 11717408e3490e7c594f5c31343abf997a9718a7 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Thu, 12 Jun 2025 11:17:52 -0400 Subject: [PATCH 3/9] Simulate re2 --- protovalidate/internal/extra_func.py | 77 +------------------ protovalidate/internal/matches.py | 106 +++++++++++++++++++++++++++ tests/matches_test.py | 10 ++- 3 files changed, 115 insertions(+), 78 deletions(-) create mode 100644 protovalidate/internal/matches.py diff --git a/protovalidate/internal/extra_func.py b/protovalidate/internal/extra_func.py index 14f9b18c..442cdb77 100644 --- a/protovalidate/internal/extra_func.py +++ b/protovalidate/internal/extra_func.py @@ -13,16 +13,15 @@ # limitations under the License. import math -import operator import re import typing -from functools import reduce from urllib import parse as urlparse import celpy from celpy import celtypes from protovalidate.internal import string_format +from protovalidate.internal.matches import cel_matches from protovalidate.internal.rules import MessageType, field_to_cel # See https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address @@ -1555,80 +1554,6 @@ def __peek(self, char: str) -> bool: return self._index < len(self._string) and self._string[self._index] == char -# Patterns that are supported in Python's re package and not in re2. -# RE2: https://github.com/google/re2/wiki/syntax -invalid_patterns = [ - r"\\[1-9]", # backreference - r"\\k<\w+>", # backreference - r"\(\?\=", # lookahead - r"\(\?\!", # negative lookahead - r"\(\?\<\=", # lookbehind - r"\(\?\<\!", # negative lookbehind - r"\\c[A-Z]", # control character - r"\\u[0-9a-fA-F]{4}", # UTF-16 code-unit - r"\\0(?!\d)", # NUL - r"\[\\b.*\]", # Backspace eg: [\b] -] - -flag_pattern = re.compile(r"^\(\?(?P[ims\-]+)\)") -flag_mapping = { - "a": re.A, - "i": re.I, - "l": re.L, - "m": re.M, -} - - -def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result: - if not isinstance(text, celtypes.StringType): - msg = "invalid argument for text, expected string" - raise celpy.CELEvalError(msg) - if not isinstance(pattern, celtypes.StringType): - msg = "invalid argument for pattern, expected string" - raise celpy.CELEvalError(msg) - - # Simulate re2 by failing on any patterns not compatible with re2 syntax - for invalid_pattern in invalid_patterns: - r = re.search(invalid_pattern, pattern) - if r is not None: - msg = f"error evaluating pattern {pattern}, invalid RE2 syntax" - raise celpy.CELEvalError(msg) - - # CEL uses RE2 syntax which is a subset of Python re except for - # the flags and the ability to change the flags mid sequence. - # - # The conformance tests use flags at the very beginning of the sequence, which - # is likely the most common place where this rare feature will be used. - # - # Instead of importing an RE2 engine to be able to support this niche, we - # can instead just check for the flags at the very beginning and apply them. - # - # Unsupported flags and flags mid sequence will fail to compile the regex. - # - # Users can choose to override this function and provide an RE2 engine if they really need to. - flags = "" - flag_matches = re.match(flag_pattern, pattern) - if flag_matches is not None: - flag_group = flag_matches.groupdict()["flags"] - for fl in flag_group: - # Flag removal, don't include it in the output - if fl == "-": - continue - flags += fl - pattern_str = pattern[len(flag_matches[0]) :] - flags_enums = reduce(operator.or_, (flag_mapping[c] for c in flags if c in flag_mapping), 0) - exp = re.compile(pattern_str, flags=flags_enums) - else: - exp = re.compile(pattern) - - try: - m = re.search(exp, text) - except re.error as ex: - return celpy.CELEvalError("match error", ex.__class__, ex.args) - - return celtypes.BoolType(m is not None) - - def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]: # For now, ignoring the type. string_fmt = string_format.StringFormat(locale) # type: ignore diff --git a/protovalidate/internal/matches.py b/protovalidate/internal/matches.py new file mode 100644 index 00000000..54ab0e07 --- /dev/null +++ b/protovalidate/internal/matches.py @@ -0,0 +1,106 @@ +# Copyright 2023-2025 Buf Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +import re +from functools import reduce + +import celpy +from celpy import celtypes + +# Patterns that are supported in Python's re package and not in re2. +# RE2: https://github.com/google/re2/wiki/syntax +invalid_patterns = [ + r"\\[1-9]", # backreference + r"\\k<\w+>", # backreference + r"\(\?\=", # lookahead + r"\(\?\!", # negative lookahead + r"\(\?\<\=", # lookbehind + r"\(\?\<\!", # negative lookbehind + r"\\c[A-Z]", # control character + r"\\u[0-9a-fA-F]{4}", # UTF-16 code-unit + r"\\0(?!\d)", # NUL + r"\[\\b.*\]", # Backspace eg: [\b] +] + +# Regex for searching a regex pattern for flags. +flag_pattern = re.compile(r"^\(\?(?P[ims\-]+)\)") + +# See https://docs.python.org/3/library/re.html#flags +flag_mapping = { + "a": re.A, + "i": re.I, + "L": re.L, + "m": re.M, + "s": re.S, + "u": re.U, + "x": re.X, +} + + +def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result: + """Return True if the given pattern matches text. False otherwise. + + CEL uses RE2 syntax which diverges from Python re in various ways. Ideally, we + would use the google-re2 package, which is an extra dep in celpy, but at press + time it does not provide a pre-built binary for the latest version of Python (3.13) + which means those using this version will run into many issues. + + Instead of foisting this issue on users, we instead mimic re2 syntax by failing + to compile the regex for patterns not compatible with re2. + + If users really want a pure re2 engine, they can provide their own via a config + parameter when creating a validator. + """ + if not isinstance(text, celtypes.StringType): + msg = "invalid argument for text, expected string" + raise celpy.CELEvalError(msg) + if not isinstance(pattern, celtypes.StringType): + msg = "invalid argument for pattern, expected string" + raise celpy.CELEvalError(msg) + + # Simulate re2 by failing on any patterns not compatible with re2 syntax + for invalid_pattern in invalid_patterns: + r = re.search(invalid_pattern, pattern) + if r is not None: + msg = f"error evaluating pattern {pattern}, invalid RE2 syntax" + raise celpy.CELEvalError(msg) + # The conformance tests use flags at the very beginning of the sequence, which + # is likely the most common place where this rare feature will be used. + # + # So we check for the flags at the very beginning and if present, apply them + # using Python re enums. + flags = "" + flag_matches = re.match(flag_pattern, pattern) + if flag_matches is not None: + flag_group = flag_matches.groupdict()["flags"] + for fl in flag_group: + # Flag removal, don't include it in the output + if fl == "-": + continue + flags += fl + # Grab the rest of the expression minus the flags + pattern_str = pattern[len(flag_matches[0]) :] + # Convert a string of flags (i.e. aiLm) into the actual re.A, re.I enums + flags_enums = reduce(operator.or_, (flag_mapping[c] for c in flags if c in flag_mapping), 0) + exp = re.compile(pattern_str, flags=flags_enums) + else: + exp = re.compile(pattern) + + try: + m = re.search(exp, text) + except re.error as ex: + return celpy.CELEvalError("match error", ex.__class__, ex.args) + + return celtypes.BoolType(m is not None) diff --git a/tests/matches_test.py b/tests/matches_test.py index 9708b741..87818cca 100644 --- a/tests/matches_test.py +++ b/tests/matches_test.py @@ -44,5 +44,11 @@ def test_invalid_re2_syntax(self): self.assertEqual(str(e), f"error evaluating pattern {cel_pattern}, invalid RE2 syntax") def test_flags(self) -> None: - result = extra_func.cel_matches(celtypes.StringType("!@#$%^&*()"), celtypes.StringType("(?i)^[a-z0-9]+$")) - self.assertFalse(result) + self.assertTrue(extra_func.cel_matches(celtypes.StringType("foobar"), celtypes.StringType("(?i:foo)(?-i:bar)"))) + self.assertTrue(extra_func.cel_matches(celtypes.StringType("FOObar"), celtypes.StringType("(?i:foo)(?-i:bar)"))) + self.assertFalse( + extra_func.cel_matches(celtypes.StringType("fooBAR"), celtypes.StringType("(?i:foo)(?-i:bar)")) + ) + self.assertFalse( + extra_func.cel_matches(celtypes.StringType("FOOBAR"), celtypes.StringType("(?i:foo)(?-i:bar)")) + ) From 63a2a594af5a42305aadc553acbfc8ee1d2a9ed9 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Thu, 12 Jun 2025 11:21:37 -0400 Subject: [PATCH 4/9] Remove comment for now --- protovalidate/internal/matches.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/protovalidate/internal/matches.py b/protovalidate/internal/matches.py index 54ab0e07..9c785597 100644 --- a/protovalidate/internal/matches.py +++ b/protovalidate/internal/matches.py @@ -59,9 +59,6 @@ def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result: Instead of foisting this issue on users, we instead mimic re2 syntax by failing to compile the regex for patterns not compatible with re2. - - If users really want a pure re2 engine, they can provide their own via a config - parameter when creating a validator. """ if not isinstance(text, celtypes.StringType): msg = "invalid argument for text, expected string" From 1e330570664d5937c1ebd142fb323db2eef2ebbc Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 16 Jun 2025 13:45:23 -0400 Subject: [PATCH 5/9] Config --- gen/tests/example/v1/validations_pb2.py | 64 +++++++++++++----------- gen/tests/example/v1/validations_pb2.pyi | 8 +++ proto/tests/example/v1/validations.proto | 5 ++ protovalidate/internal/config.py | 26 ++++++++++ protovalidate/internal/rules.py | 13 +++-- protovalidate/validator.py | 14 +++--- tests/config_test.py | 23 +++++++++ tests/validate_test.py | 19 ++++++- 8 files changed, 128 insertions(+), 44 deletions(-) create mode 100644 protovalidate/internal/config.py create mode 100644 tests/config_test.py diff --git a/gen/tests/example/v1/validations_pb2.py b/gen/tests/example/v1/validations_pb2.py index 29138880..6c88b1ca 100644 --- a/gen/tests/example/v1/validations_pb2.py +++ b/gen/tests/example/v1/validations_pb2.py @@ -40,7 +40,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"tests/example/v1/validations.proto\x12\x10tests.example.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1fgoogle/protobuf/timestamp.proto\")\n\x0c\x44oubleFinite\x12\x19\n\x03val\x18\x01 \x01(\x01\x42\x07\xbaH\x04\x12\x02@\x01R\x03val\";\n\x0eSFixed64ExLTGT\x12)\n\x03val\x18\x01 \x01(\x10\x42\x17\xbaH\x14\x62\x12\x11\x00\x00\x00\x00\x00\x00\x00\x00!\n\x00\x00\x00\x00\x00\x00\x00R\x03val\")\n\x0cTestOneofMsg\x12\x19\n\x03val\x18\x01 \x01(\x08\x42\x07\xbaH\x04j\x02\x08\x01R\x03val\"q\n\x05Oneof\x12\x1a\n\x01x\x18\x01 \x01(\tB\n\xbaH\x07r\x05:\x03\x66ooH\x00R\x01x\x12\x17\n\x01y\x18\x02 \x01(\x05\x42\x07\xbaH\x04\x1a\x02 \x00H\x00R\x01y\x12.\n\x01z\x18\x03 \x01(\x0b\x32\x1e.tests.example.v1.TestOneofMsgH\x00R\x01zB\x03\n\x01o\"[\n\x12ProtovalidateOneof\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x0b\xbaH\x08\"\x06\n\x01\x61\n\x01\x62\"e\n\x1aProtovalidateOneofRequired\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\r\xbaH\n\"\x08\n\x01\x61\n\x01\x62\x10\x01\"p\n\"ProtovalidateOneofUnknownFieldName\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x10\xbaH\r\"\x0b\n\x01\x61\n\x01\x62\n\x03xxx\"H\n\x0eTimestampGTNow\x12\x36\n\x03val\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xbaH\x05\xb2\x01\x02@\x01R\x03val\"\x87\x01\n\tMapMinMax\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32$.tests.example.v1.MapMinMax.ValEntryB\n\xbaH\x07\x9a\x01\x04\x08\x02\x10\x04R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\x08R\x05value:\x02\x38\x01\"\x85\x01\n\x07MapKeys\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32\".tests.example.v1.MapKeys.ValEntryB\x0c\xbaH\t\x9a\x01\x06\"\x04\x42\x02\x10\x00R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\x12R\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\"\n\x05\x45mbed\x12\x19\n\x03val\x18\x01 \x01(\x03\x42\x07\xbaH\x04\"\x02 \x00R\x03val\"K\n\x11RepeatedEmbedSkip\x12\x36\n\x03val\x18\x01 \x03(\x0b\x32\x17.tests.example.v1.EmbedB\x0b\xbaH\x08\x92\x01\x05\"\x03\xd8\x01\x03R\x03valB\x8a\x01\n\x14\x63om.tests.example.v1B\x10ValidationsProtoP\x01\xa2\x02\x03TEX\xaa\x02\x10Tests.Example.V1\xca\x02\x10Tests\\Example\\V1\xe2\x02\x1cTests\\Example\\V1\\GPBMetadata\xea\x02\x12Tests::Example::V1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"tests/example/v1/validations.proto\x12\x10tests.example.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"T\n\x13MultipleValidations\x12 \n\x05title\x18\x01 \x01(\tB\n\xbaH\x07r\x05:\x03\x66ooR\x05title\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xbaH\x04r\x02\x10\x05R\x04name\")\n\x0c\x44oubleFinite\x12\x19\n\x03val\x18\x01 \x01(\x01\x42\x07\xbaH\x04\x12\x02@\x01R\x03val\";\n\x0eSFixed64ExLTGT\x12)\n\x03val\x18\x01 \x01(\x10\x42\x17\xbaH\x14\x62\x12\x11\x00\x00\x00\x00\x00\x00\x00\x00!\n\x00\x00\x00\x00\x00\x00\x00R\x03val\")\n\x0cTestOneofMsg\x12\x19\n\x03val\x18\x01 \x01(\x08\x42\x07\xbaH\x04j\x02\x08\x01R\x03val\"q\n\x05Oneof\x12\x1a\n\x01x\x18\x01 \x01(\tB\n\xbaH\x07r\x05:\x03\x66ooH\x00R\x01x\x12\x17\n\x01y\x18\x02 \x01(\x05\x42\x07\xbaH\x04\x1a\x02 \x00H\x00R\x01y\x12.\n\x01z\x18\x03 \x01(\x0b\x32\x1e.tests.example.v1.TestOneofMsgH\x00R\x01zB\x03\n\x01o\"[\n\x12ProtovalidateOneof\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x0b\xbaH\x08\"\x06\n\x01\x61\n\x01\x62\"e\n\x1aProtovalidateOneofRequired\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\r\xbaH\n\"\x08\n\x01\x61\n\x01\x62\x10\x01\"p\n\"ProtovalidateOneofUnknownFieldName\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x10\xbaH\r\"\x0b\n\x01\x61\n\x01\x62\n\x03xxx\"H\n\x0eTimestampGTNow\x12\x36\n\x03val\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xbaH\x05\xb2\x01\x02@\x01R\x03val\"\x87\x01\n\tMapMinMax\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32$.tests.example.v1.MapMinMax.ValEntryB\n\xbaH\x07\x9a\x01\x04\x08\x02\x10\x04R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\x08R\x05value:\x02\x38\x01\"\x85\x01\n\x07MapKeys\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32\".tests.example.v1.MapKeys.ValEntryB\x0c\xbaH\t\x9a\x01\x06\"\x04\x42\x02\x10\x00R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\x12R\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\"\n\x05\x45mbed\x12\x19\n\x03val\x18\x01 \x01(\x03\x42\x07\xbaH\x04\"\x02 \x00R\x03val\"K\n\x11RepeatedEmbedSkip\x12\x36\n\x03val\x18\x01 \x03(\x0b\x32\x17.tests.example.v1.EmbedB\x0b\xbaH\x08\x92\x01\x05\"\x03\xd8\x01\x03R\x03valB\x8a\x01\n\x14\x63om.tests.example.v1B\x10ValidationsProtoP\x01\xa2\x02\x03TEX\xaa\x02\x10Tests.Example.V1\xca\x02\x10Tests\\Example\\V1\xe2\x02\x1cTests\\Example\\V1\\GPBMetadata\xea\x02\x12Tests::Example::V1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -48,6 +48,10 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\024com.tests.example.v1B\020ValidationsProtoP\001\242\002\003TEX\252\002\020Tests.Example.V1\312\002\020Tests\\Example\\V1\342\002\034Tests\\Example\\V1\\GPBMetadata\352\002\022Tests::Example::V1' + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['title']._loaded_options = None + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['title']._serialized_options = b'\272H\007r\005:\003foo' + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['name']._loaded_options = None + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['name']._serialized_options = b'\272H\004r\002\020\005' _globals['_DOUBLEFINITE'].fields_by_name['val']._loaded_options = None _globals['_DOUBLEFINITE'].fields_by_name['val']._serialized_options = b'\272H\004\022\002@\001' _globals['_SFIXED64EXLTGT'].fields_by_name['val']._loaded_options = None @@ -78,32 +82,34 @@ _globals['_EMBED'].fields_by_name['val']._serialized_options = b'\272H\004\"\002 \000' _globals['_REPEATEDEMBEDSKIP'].fields_by_name['val']._loaded_options = None _globals['_REPEATEDEMBEDSKIP'].fields_by_name['val']._serialized_options = b'\272H\010\222\001\005\"\003\330\001\003' - _globals['_DOUBLEFINITE']._serialized_start=118 - _globals['_DOUBLEFINITE']._serialized_end=159 - _globals['_SFIXED64EXLTGT']._serialized_start=161 - _globals['_SFIXED64EXLTGT']._serialized_end=220 - _globals['_TESTONEOFMSG']._serialized_start=222 - _globals['_TESTONEOFMSG']._serialized_end=263 - _globals['_ONEOF']._serialized_start=265 - _globals['_ONEOF']._serialized_end=378 - _globals['_PROTOVALIDATEONEOF']._serialized_start=380 - _globals['_PROTOVALIDATEONEOF']._serialized_end=471 - _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_start=473 - _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_end=574 - _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_start=576 - _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_end=688 - _globals['_TIMESTAMPGTNOW']._serialized_start=690 - _globals['_TIMESTAMPGTNOW']._serialized_end=762 - _globals['_MAPMINMAX']._serialized_start=765 - _globals['_MAPMINMAX']._serialized_end=900 - _globals['_MAPMINMAX_VALENTRY']._serialized_start=846 - _globals['_MAPMINMAX_VALENTRY']._serialized_end=900 - _globals['_MAPKEYS']._serialized_start=903 - _globals['_MAPKEYS']._serialized_end=1036 - _globals['_MAPKEYS_VALENTRY']._serialized_start=982 - _globals['_MAPKEYS_VALENTRY']._serialized_end=1036 - _globals['_EMBED']._serialized_start=1038 - _globals['_EMBED']._serialized_end=1072 - _globals['_REPEATEDEMBEDSKIP']._serialized_start=1074 - _globals['_REPEATEDEMBEDSKIP']._serialized_end=1149 + _globals['_MULTIPLEVALIDATIONS']._serialized_start=118 + _globals['_MULTIPLEVALIDATIONS']._serialized_end=202 + _globals['_DOUBLEFINITE']._serialized_start=204 + _globals['_DOUBLEFINITE']._serialized_end=245 + _globals['_SFIXED64EXLTGT']._serialized_start=247 + _globals['_SFIXED64EXLTGT']._serialized_end=306 + _globals['_TESTONEOFMSG']._serialized_start=308 + _globals['_TESTONEOFMSG']._serialized_end=349 + _globals['_ONEOF']._serialized_start=351 + _globals['_ONEOF']._serialized_end=464 + _globals['_PROTOVALIDATEONEOF']._serialized_start=466 + _globals['_PROTOVALIDATEONEOF']._serialized_end=557 + _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_start=559 + _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_end=660 + _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_start=662 + _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_end=774 + _globals['_TIMESTAMPGTNOW']._serialized_start=776 + _globals['_TIMESTAMPGTNOW']._serialized_end=848 + _globals['_MAPMINMAX']._serialized_start=851 + _globals['_MAPMINMAX']._serialized_end=986 + _globals['_MAPMINMAX_VALENTRY']._serialized_start=932 + _globals['_MAPMINMAX_VALENTRY']._serialized_end=986 + _globals['_MAPKEYS']._serialized_start=989 + _globals['_MAPKEYS']._serialized_end=1122 + _globals['_MAPKEYS_VALENTRY']._serialized_start=1068 + _globals['_MAPKEYS_VALENTRY']._serialized_end=1122 + _globals['_EMBED']._serialized_start=1124 + _globals['_EMBED']._serialized_end=1158 + _globals['_REPEATEDEMBEDSKIP']._serialized_start=1160 + _globals['_REPEATEDEMBEDSKIP']._serialized_end=1235 # @@protoc_insertion_point(module_scope) diff --git a/gen/tests/example/v1/validations_pb2.pyi b/gen/tests/example/v1/validations_pb2.pyi index 76de322e..e255c703 100644 --- a/gen/tests/example/v1/validations_pb2.pyi +++ b/gen/tests/example/v1/validations_pb2.pyi @@ -22,6 +22,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor +class MultipleValidations(_message.Message): + __slots__ = ("title", "name") + TITLE_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + title: str + name: str + def __init__(self, title: _Optional[str] = ..., name: _Optional[str] = ...) -> None: ... + class DoubleFinite(_message.Message): __slots__ = ("val",) VAL_FIELD_NUMBER: _ClassVar[int] diff --git a/proto/tests/example/v1/validations.proto b/proto/tests/example/v1/validations.proto index 4affacc1..7e51bdee 100644 --- a/proto/tests/example/v1/validations.proto +++ b/proto/tests/example/v1/validations.proto @@ -19,6 +19,11 @@ package tests.example.v1; import "buf/validate/validate.proto"; import "google/protobuf/timestamp.proto"; +message MultipleValidations { + string title = 1 [(buf.validate.field).string.prefix = "foo"]; + string name = 2 [(buf.validate.field).string.min_len = 5]; +} + message DoubleFinite { double val = 1 [(buf.validate.field).double.finite = true]; } diff --git a/protovalidate/internal/config.py b/protovalidate/internal/config.py new file mode 100644 index 00000000..1e21683b --- /dev/null +++ b/protovalidate/internal/config.py @@ -0,0 +1,26 @@ +# Copyright 2023-2025 Buf Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class Config: + """A class for holding configuration values for validation. + + Attributes: + fail_fast (bool): If true, validation will stop after the first violation. Defaults to False. + """ + + fail_fast: bool = False diff --git a/protovalidate/internal/rules.py b/protovalidate/internal/rules.py index 7f726412..91d18b6c 100644 --- a/protovalidate/internal/rules.py +++ b/protovalidate/internal/rules.py @@ -21,6 +21,7 @@ from google.protobuf import any_pb2, descriptor, message, message_factory from buf.validate import validate_pb2 # type: ignore +from protovalidate.internal import config as _config from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has @@ -266,15 +267,17 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N class RuleContext: """The state associated with a single rule evaluation.""" - def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): - self._fail_fast = fail_fast + _cfg: _config.Config + + def __init__(self, *, config: _config.Config, violations: typing.Optional[list[Violation]] = None): + self._cfg = config if violations is None: violations = [] self._violations = violations @property def fail_fast(self) -> bool: - return self._fail_fast + return self._cfg.fail_fast @property def violations(self) -> list[Violation]: @@ -296,13 +299,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat @property def done(self) -> bool: - return self._fail_fast and self.has_errors() + return self.fail_fast and self.has_errors() def has_errors(self) -> bool: return len(self._violations) > 0 def sub_context(self): - return RuleContext(fail_fast=self._fail_fast) + return RuleContext(config=self._cfg) class Rules: diff --git a/protovalidate/validator.py b/protovalidate/validator.py index b53b3e4f..ac2b511b 100644 --- a/protovalidate/validator.py +++ b/protovalidate/validator.py @@ -17,6 +17,7 @@ from google.protobuf import message from buf.validate import validate_pb2 # type: ignore +from protovalidate.internal import config as _config from protovalidate.internal import extra_func from protovalidate.internal import rules as _rules @@ -35,15 +36,15 @@ class Validator: """ _factory: _rules.RuleFactory + _cfg: _config.Config - def __init__(self): + def __init__(self, config=None): self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS) + self._cfg = config if config is not None else _config.Config() def validate( self, message: message.Message, - *, - fail_fast: bool = False, ): """ Validates the given message against the static rules defined in @@ -51,12 +52,11 @@ def validate( Parameters: message: The message to validate. - fail_fast: If true, validation will stop after the first violation. Raises: CompilationError: If the static rules could not be compiled. ValidationError: If the message is invalid. """ - violations = self.collect_violations(message, fail_fast=fail_fast) + violations = self.collect_violations(message) if len(violations) > 0: msg = f"invalid {message.DESCRIPTOR.name}" raise ValidationError(msg, violations) @@ -65,7 +65,6 @@ def collect_violations( self, message: message.Message, *, - fail_fast: bool = False, into: typing.Optional[list[Violation]] = None, ) -> list[Violation]: """ @@ -76,13 +75,12 @@ def collect_violations( Parameters: message: The message to validate. - fail_fast: If true, validation will stop after the first violation. into: If provided, any violations will be appended to the Violations object and the same object will be returned. Raises: CompilationError: If the static rules could not be compiled. """ - ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into) + ctx = _rules.RuleContext(config=self._cfg, violations=into) for rule in self._factory.get(message.DESCRIPTOR): rule.validate(ctx, message) if ctx.done: diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 00000000..4a7d5e26 --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,23 @@ +# Copyright 2023-2025 Buf Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from protovalidate.internal import config + + +class TestConfig(unittest.TestCase): + def test_defaults(self): + cfg = config.Config() + self.assertFalse(cfg.fail_fast) diff --git a/tests/validate_test.py b/tests/validate_test.py index 6e52c947..196bbb7f 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -16,6 +16,7 @@ import protovalidate from gen.tests.example.v1 import validations_pb2 +from protovalidate.internal import config class TestValidate(unittest.TestCase): @@ -114,7 +115,21 @@ def test_maps(self): def test_timestamp(self): msg = validations_pb2.TimestampGTNow() - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) assert len(violations) == 0 + + def test_multiple_validations(self): + msg = validations_pb2.MultipleValidations() + msg.title = "bar" + msg.name = "blah" + violations = protovalidate.collect_violations(msg) + assert len(violations) == 2 + + def test_fail_fast(self): + msg = validations_pb2.MultipleValidations() + msg.title = "bar" + msg.name = "blah" + cfg = config.Config(fail_fast=True) + validator = protovalidate.Validator(config=cfg) + violations = validator.collect_violations(msg) + assert len(violations) == 1 From a6353d17c1c5ba04ac771681feebf220925d8ec6 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 17 Jun 2025 13:24:32 -0400 Subject: [PATCH 6/9] Switch to a config --- protovalidate/__init__.py | 6 +- protovalidate/validator.py | 12 +- tests/validate_test.py | 249 ++++++++++++++++++++++++++++--------- 3 files changed, 204 insertions(+), 63 deletions(-) diff --git a/protovalidate/__init__.py b/protovalidate/__init__.py index 0f82fccb..2ce8261b 100644 --- a/protovalidate/__init__.py +++ b/protovalidate/__init__.py @@ -19,8 +19,8 @@ ValidationError = validator.ValidationError Violations = validator.Violations -_validator = Validator() -validate = _validator.validate -collect_violations = _validator.collect_violations +_default_validator = Validator() +validate = _default_validator.validate +collect_violations = _default_validator.collect_violations __all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] diff --git a/protovalidate/validator.py b/protovalidate/validator.py index ac2b511b..a80b7bc6 100644 --- a/protovalidate/validator.py +++ b/protovalidate/validator.py @@ -54,7 +54,8 @@ def validate( message: The message to validate. Raises: CompilationError: If the static rules could not be compiled. - ValidationError: If the message is invalid. + ValidationError: If the message is invalid. The violations raised as part of this error should + always be equal to the list of violations returned by `collect_violations`. """ violations = self.collect_violations(message) if len(violations) > 0: @@ -69,9 +70,12 @@ def collect_violations( ) -> list[Violation]: """ Validates the given message against the static rules defined in - the message's descriptor. Compared to validate, collect_violations is - faster but puts the burden of raising an appropriate exception on the - caller. + the message's descriptor. Compared to `validate`, `collect_violations` simply + returns the violations as a list and puts the burden of raising an appropriate + exception on the caller. + + The violations returned from this method should always be equal to the violations + raised as part of the ValidationError in the call to `validate`. Parameters: message: The message to validate. diff --git a/tests/validate_test.py b/tests/validate_test.py index 196bbb7f..73b1e927 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -14,122 +14,259 @@ import unittest +from google.protobuf import message + import protovalidate from gen.tests.example.v1 import validations_pb2 -from protovalidate.internal import config +from protovalidate.internal import config, rules + + +def get_default_validator(): + """Returns a default validator created in all available ways + + This allows testing for validators created via: + - module-level singleton + - instantiated class with no config + - instantiated class with config + """ + return [ + ("module singleton", protovalidate), + ("no config", protovalidate.Validator()), + ("with default config", protovalidate.Validator(config.Config())), + ] + + +class TestCollectViolations(unittest.TestCase): + """Test class for testing message validations. + A validator can be created via various ways: + - a module-level singleton, which returns a default validator + - instantiating the Validator class with no config, which returns a default validator + - instantiating the Validator class with a config + + In addition, the API for validating a message allows for two approaches: + - via a call to `validate`, which will raise a ValidationError if validation fails + - via a call to `collect_violations`, which will not raise an error and instead return a list of violations. + + Unless otherwise noted, each test in this class tests against a validator created via all 3 methods and tests + validation using both approaches. + """ -class TestValidate(unittest.TestCase): def test_ninf(self): msg = validations_pb2.DoubleFinite() msg.val = float("-inf") - violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations), 1) - self.assertEqual(violations[0].proto.rule_id, "double.finite") - self.assertEqual(violations[0].field_value, msg.val) - self.assertEqual(violations[0].rule_value, True) + + expected_violation = rules.Violation() + expected_violation.proto.message = "value must be finite" + expected_violation.proto.rule_id = "double.finite" + expected_violation.field_value = msg.val + expected_violation.rule_value = True + + self._run_invalid_tests(msg, [expected_violation]) def test_map_key(self): msg = validations_pb2.MapKeys() msg.val[1] = "a" - violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations), 1) - self.assertEqual(violations[0].proto.for_key, True) - self.assertEqual(violations[0].field_value, 1) - self.assertEqual(violations[0].rule_value, 0) - def test_sfixed64(self): + expected_violation = rules.Violation() + expected_violation.proto.message = "value must be less than 0" + expected_violation.proto.rule_id = "sint64.lt" + expected_violation.proto.for_key = True + expected_violation.field_value = 1 + expected_violation.rule_value = 0 + + self._run_invalid_tests(msg, [expected_violation]) + + def test_sfixed64_valid(self): msg = validations_pb2.SFixed64ExLTGT(val=11) - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations), 0) + self._run_valid_tests(msg) def test_oneofs(self): + msg = validations_pb2.Oneof() + msg.y = 123 + + self._run_valid_tests(msg) + + def test_collect_violations_into(self): msg1 = validations_pb2.Oneof() msg1.y = 123 - protovalidate.validate(msg1) msg2 = validations_pb2.Oneof() msg2.z.val = True - protovalidate.validate(msg2) - violations = protovalidate.collect_violations(msg1) - protovalidate.collect_violations(msg2, into=violations) - assert len(violations) == 0 + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test collect_violations into + violations = v.collect_violations(msg1) + v.collect_violations(msg2, into=violations) + self.assertEqual(len(violations), 0) def test_protovalidate_oneof_valid(self): msg = validations_pb2.ProtovalidateOneof() msg.a = "A" - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - assert len(violations) == 0 + + self._run_valid_tests(msg) def test_protovalidate_oneof_violation(self): msg = validations_pb2.ProtovalidateOneof() msg.a = "A" msg.b = "B" - with self.assertRaises(protovalidate.ValidationError) as cm: - protovalidate.validate(msg) - e = cm.exception - assert str(e) == "invalid ProtovalidateOneof" - assert len(e.violations) == 1 - assert e.to_proto().violations[0].message == "only one of a, b can be set" + + expected_violation = rules.Violation() + expected_violation.proto.message = "only one of a, b can be set" + expected_violation.proto.rule_id = "message.oneof" + + self._run_invalid_tests(msg, [expected_violation]) def test_protovalidate_oneof_required_violation(self): msg = validations_pb2.ProtovalidateOneofRequired() - with self.assertRaises(protovalidate.ValidationError) as cm: - protovalidate.validate(msg) - e = cm.exception - assert str(e) == "invalid ProtovalidateOneofRequired" - assert len(e.violations) == 1 - assert e.to_proto().violations[0].message == "one of a, b must be set" + + expected_violation = rules.Violation() + expected_violation.proto.message = "one of a, b must be set" + expected_violation.proto.rule_id = "message.oneof" + + self._run_invalid_tests(msg, [expected_violation]) def test_protovalidate_oneof_unknown_field_name(self): + """Tests that a compilation error is thrown when specifying a oneof rule with an invalid field name""" msg = validations_pb2.ProtovalidateOneofUnknownFieldName() - with self.assertRaises(protovalidate.CompilationError) as cm: - protovalidate.validate(msg) - assert ( - str(cm.exception) == 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName' + + self._run_compilation_error_tests( + msg, 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName' ) def test_repeated(self): msg = validations_pb2.RepeatedEmbedSkip() msg.val.add(val=-1) - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - assert len(violations) == 0 + self._run_valid_tests(msg) def test_maps(self): msg = validations_pb2.MapMinMax() - with self.assertRaises(protovalidate.ValidationError) as cm: - protovalidate.validate(msg) - e = cm.exception - assert len(e.violations) == 1 - assert len(e.to_proto().violations) == 1 - assert str(e) == "invalid MapMinMax" - violations = protovalidate.collect_violations(msg) - assert len(violations) == 1 + expected_violation = rules.Violation() + expected_violation.proto.message = "map must be at least 2 entries" + expected_violation.proto.rule_id = "map.min_pairs" + expected_violation.field_value = {} + expected_violation.rule_value = 2 + + self._run_invalid_tests(msg, [expected_violation]) def test_timestamp(self): msg = validations_pb2.TimestampGTNow() - violations = protovalidate.collect_violations(msg) - assert len(violations) == 0 + + self._run_valid_tests(msg) def test_multiple_validations(self): + """Test that a message with multiple violations correctly returns all of them.""" msg = validations_pb2.MultipleValidations() msg.title = "bar" msg.name = "blah" - violations = protovalidate.collect_violations(msg) - assert len(violations) == 2 + + expected_violation1 = rules.Violation() + expected_violation1.proto.message = "value does not have prefix `foo`" + expected_violation1.proto.rule_id = "string.prefix" + expected_violation1.field_value = msg.title + expected_violation1.rule_value = "foo" + + expected_violation2 = rules.Violation() + expected_violation2.proto.message = "value length must be at least 5 characters" + expected_violation2.proto.rule_id = "string.min_len" + expected_violation2.field_value = msg.name + expected_violation2.rule_value = 5 + + self._run_invalid_tests(msg, [expected_violation1, expected_violation2]) def test_fail_fast(self): + """Test that fail fast correctly fails on first violation + + Note this does not use a default validator, but instead uses one with a custom config + so that fail_fast can be set to True. + """ msg = validations_pb2.MultipleValidations() msg.title = "bar" msg.name = "blah" + + expected_violation = rules.Violation() + expected_violation.proto.message = "value does not have prefix `foo`" + expected_violation.proto.rule_id = "string.prefix" + expected_violation.field_value = msg.title + expected_violation.rule_value = "foo" + cfg = config.Config(fail_fast=True) validator = protovalidate.Validator(config=cfg) + + # Test validate + with self.assertRaises(protovalidate.ValidationError) as cm: + validator.validate(msg) + e = cm.exception + self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") + self._compare_violations(e.violations, [expected_violation]) + + # Test collect_violations violations = validator.collect_violations(msg) - assert len(violations) == 1 + self._compare_violations(violations, [expected_violation]) + + def _run_valid_tests(self, msg: message.Message): + """A helper function for testing successful validation on a given message + + The tests are run using validators created via all possible methods and + validation is done via a call to `validate` as well as a call to `collect_violations`. + """ + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test validate + try: + v.validate(msg) + except Exception: + self.fail(f"[{label}]: unexpected validation failure") + + # Test collect_violations + violations = v.collect_violations(msg) + self.assertEqual(len(violations), 0) + + def _run_invalid_tests(self, msg: message.Message, expected: list[rules.Violation]): + """A helper function for testing unsuccessful validation on a given message + + The tests are run using validators created via all possible methods and + validation is done via a call to `validate` as well as a call to `collect_violations`. + """ + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test validate + with self.assertRaises(protovalidate.ValidationError) as cm: + v.validate(msg) + e = cm.exception + self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") + self._compare_violations(e.violations, expected) + + # Test collect_violations + violations = v.collect_violations(msg) + self._compare_violations(violations, expected) + + def _run_compilation_error_tests(self, msg: message.Message, expected: str): + """A helper function for testing compilation errors when validating. + + The tests are run using validators created via all possible methods and + validation is done via a call to `validate` as well as a call to `collect_violations`. + """ + for label, v in get_default_validator(): + with self.subTest(label=label): + with self.assertRaises(protovalidate.CompilationError) as cvce: + v.collect_violations(msg) + assert str(cvce.exception) == expected + + with self.assertRaises(protovalidate.CompilationError) as vce: + v.validate(msg) + assert str(vce.exception) == expected + + def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None: + """Compares two lists of violations. The violations are expected to be in the expected order also.""" + self.assertEqual(len(actual), len(expected)) + for a, e in zip(actual, expected): + self.assertEqual(a.proto.message, e.proto.message) + self.assertEqual(a.proto.rule_id, e.proto.rule_id) + self.assertEqual(a.proto.for_key, e.proto.for_key) + self.assertEqual(a.field_value, e.field_value) + self.assertEqual(a.rule_value, e.rule_value) From 3287b4f1b0a7398b2d2b70309cd340d97d15b66f Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 17 Jun 2025 13:28:52 -0400 Subject: [PATCH 7/9] Comments --- tests/validate_test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/validate_test.py b/tests/validate_test.py index 73b1e927..927de8b2 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -253,14 +253,16 @@ def _run_compilation_error_tests(self, msg: message.Message, expected: str): """ for label, v in get_default_validator(): with self.subTest(label=label): - with self.assertRaises(protovalidate.CompilationError) as cvce: - v.collect_violations(msg) - assert str(cvce.exception) == expected - + # Test validate with self.assertRaises(protovalidate.CompilationError) as vce: v.validate(msg) assert str(vce.exception) == expected + # Test collect_violations + with self.assertRaises(protovalidate.CompilationError) as cvce: + v.collect_violations(msg) + assert str(cvce.exception) == expected + def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None: """Compares two lists of violations. The violations are expected to be in the expected order also.""" self.assertEqual(len(actual), len(expected)) From b957343126fd262cf6cfb764396247e6ec18c36f Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 17 Jun 2025 13:53:07 -0400 Subject: [PATCH 8/9] Better import path --- protovalidate/{internal => }/config.py | 0 protovalidate/internal/rules.py | 6 +++--- protovalidate/validator.py | 6 +++--- tests/config_test.py | 4 ++-- tests/validate_test.py | 7 ++++--- 5 files changed, 12 insertions(+), 11 deletions(-) rename protovalidate/{internal => }/config.py (100%) diff --git a/protovalidate/internal/config.py b/protovalidate/config.py similarity index 100% rename from protovalidate/internal/config.py rename to protovalidate/config.py diff --git a/protovalidate/internal/rules.py b/protovalidate/internal/rules.py index 91d18b6c..989abf96 100644 --- a/protovalidate/internal/rules.py +++ b/protovalidate/internal/rules.py @@ -21,7 +21,7 @@ from google.protobuf import any_pb2, descriptor, message, message_factory from buf.validate import validate_pb2 # type: ignore -from protovalidate.internal import config as _config +from protovalidate.config import Config from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has @@ -267,9 +267,9 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N class RuleContext: """The state associated with a single rule evaluation.""" - _cfg: _config.Config + _cfg: Config - def __init__(self, *, config: _config.Config, violations: typing.Optional[list[Violation]] = None): + def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None): self._cfg = config if violations is None: violations = [] diff --git a/protovalidate/validator.py b/protovalidate/validator.py index a80b7bc6..30d3fd2d 100644 --- a/protovalidate/validator.py +++ b/protovalidate/validator.py @@ -17,7 +17,7 @@ from google.protobuf import message from buf.validate import validate_pb2 # type: ignore -from protovalidate.internal import config as _config +from protovalidate.config import Config from protovalidate.internal import extra_func from protovalidate.internal import rules as _rules @@ -36,11 +36,11 @@ class Validator: """ _factory: _rules.RuleFactory - _cfg: _config.Config + _cfg: Config def __init__(self, config=None): self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS) - self._cfg = config if config is not None else _config.Config() + self._cfg = config if config is not None else Config() def validate( self, diff --git a/tests/config_test.py b/tests/config_test.py index 4a7d5e26..82738b91 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -14,10 +14,10 @@ import unittest -from protovalidate.internal import config +from protovalidate.config import Config class TestConfig(unittest.TestCase): def test_defaults(self): - cfg = config.Config() + cfg = Config() self.assertFalse(cfg.fail_fast) diff --git a/tests/validate_test.py b/tests/validate_test.py index 927de8b2..d52bb500 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -18,7 +18,8 @@ import protovalidate from gen.tests.example.v1 import validations_pb2 -from protovalidate.internal import config, rules +from protovalidate.config import Config +from protovalidate.internal import rules def get_default_validator(): @@ -32,7 +33,7 @@ def get_default_validator(): return [ ("module singleton", protovalidate), ("no config", protovalidate.Validator()), - ("with default config", protovalidate.Validator(config.Config())), + ("with default config", protovalidate.Validator(Config())), ] @@ -194,7 +195,7 @@ def test_fail_fast(self): expected_violation.field_value = msg.title expected_violation.rule_value = "foo" - cfg = config.Config(fail_fast=True) + cfg = Config(fail_fast=True) validator = protovalidate.Validator(config=cfg) # Test validate From ce219217702e755eb72a4c72f65b4a7531299e27 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 17 Jun 2025 14:48:48 -0400 Subject: [PATCH 9/9] Add config to module init.py --- protovalidate/__init__.py | 5 +++-- tests/config_test.py | 2 +- tests/validate_test.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/protovalidate/__init__.py b/protovalidate/__init__.py index 2ce8261b..1c078426 100644 --- a/protovalidate/__init__.py +++ b/protovalidate/__init__.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from protovalidate import validator +from protovalidate import config, validator +Config = config.Config Validator = validator.Validator CompilationError = validator.CompilationError ValidationError = validator.ValidationError @@ -23,4 +24,4 @@ validate = _default_validator.validate collect_violations = _default_validator.collect_violations -__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] +__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] diff --git a/tests/config_test.py b/tests/config_test.py index 82738b91..71f33af7 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -14,7 +14,7 @@ import unittest -from protovalidate.config import Config +from protovalidate import Config class TestConfig(unittest.TestCase): diff --git a/tests/validate_test.py b/tests/validate_test.py index d52bb500..792aeb08 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -257,12 +257,12 @@ def _run_compilation_error_tests(self, msg: message.Message, expected: str): # Test validate with self.assertRaises(protovalidate.CompilationError) as vce: v.validate(msg) - assert str(vce.exception) == expected + self.assertEqual(str(vce.exception), expected) # Test collect_violations with self.assertRaises(protovalidate.CompilationError) as cvce: v.collect_violations(msg) - assert str(cvce.exception) == expected + self.assertEqual(str(cvce.exception), expected) def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None: """Compares two lists of violations. The violations are expected to be in the expected order also."""