diff --git a/gen/tests/example/v1/validations_pb2.py b/gen/tests/example/v1/validations_pb2.py index 29138880..6c88b1ca 100644 --- a/gen/tests/example/v1/validations_pb2.py +++ b/gen/tests/example/v1/validations_pb2.py @@ -40,7 +40,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"tests/example/v1/validations.proto\x12\x10tests.example.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1fgoogle/protobuf/timestamp.proto\")\n\x0c\x44oubleFinite\x12\x19\n\x03val\x18\x01 \x01(\x01\x42\x07\xbaH\x04\x12\x02@\x01R\x03val\";\n\x0eSFixed64ExLTGT\x12)\n\x03val\x18\x01 \x01(\x10\x42\x17\xbaH\x14\x62\x12\x11\x00\x00\x00\x00\x00\x00\x00\x00!\n\x00\x00\x00\x00\x00\x00\x00R\x03val\")\n\x0cTestOneofMsg\x12\x19\n\x03val\x18\x01 \x01(\x08\x42\x07\xbaH\x04j\x02\x08\x01R\x03val\"q\n\x05Oneof\x12\x1a\n\x01x\x18\x01 \x01(\tB\n\xbaH\x07r\x05:\x03\x66ooH\x00R\x01x\x12\x17\n\x01y\x18\x02 \x01(\x05\x42\x07\xbaH\x04\x1a\x02 \x00H\x00R\x01y\x12.\n\x01z\x18\x03 \x01(\x0b\x32\x1e.tests.example.v1.TestOneofMsgH\x00R\x01zB\x03\n\x01o\"[\n\x12ProtovalidateOneof\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x0b\xbaH\x08\"\x06\n\x01\x61\n\x01\x62\"e\n\x1aProtovalidateOneofRequired\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\r\xbaH\n\"\x08\n\x01\x61\n\x01\x62\x10\x01\"p\n\"ProtovalidateOneofUnknownFieldName\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x10\xbaH\r\"\x0b\n\x01\x61\n\x01\x62\n\x03xxx\"H\n\x0eTimestampGTNow\x12\x36\n\x03val\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xbaH\x05\xb2\x01\x02@\x01R\x03val\"\x87\x01\n\tMapMinMax\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32$.tests.example.v1.MapMinMax.ValEntryB\n\xbaH\x07\x9a\x01\x04\x08\x02\x10\x04R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\x08R\x05value:\x02\x38\x01\"\x85\x01\n\x07MapKeys\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32\".tests.example.v1.MapKeys.ValEntryB\x0c\xbaH\t\x9a\x01\x06\"\x04\x42\x02\x10\x00R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\x12R\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\"\n\x05\x45mbed\x12\x19\n\x03val\x18\x01 \x01(\x03\x42\x07\xbaH\x04\"\x02 \x00R\x03val\"K\n\x11RepeatedEmbedSkip\x12\x36\n\x03val\x18\x01 \x03(\x0b\x32\x17.tests.example.v1.EmbedB\x0b\xbaH\x08\x92\x01\x05\"\x03\xd8\x01\x03R\x03valB\x8a\x01\n\x14\x63om.tests.example.v1B\x10ValidationsProtoP\x01\xa2\x02\x03TEX\xaa\x02\x10Tests.Example.V1\xca\x02\x10Tests\\Example\\V1\xe2\x02\x1cTests\\Example\\V1\\GPBMetadata\xea\x02\x12Tests::Example::V1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"tests/example/v1/validations.proto\x12\x10tests.example.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"T\n\x13MultipleValidations\x12 \n\x05title\x18\x01 \x01(\tB\n\xbaH\x07r\x05:\x03\x66ooR\x05title\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xbaH\x04r\x02\x10\x05R\x04name\")\n\x0c\x44oubleFinite\x12\x19\n\x03val\x18\x01 \x01(\x01\x42\x07\xbaH\x04\x12\x02@\x01R\x03val\";\n\x0eSFixed64ExLTGT\x12)\n\x03val\x18\x01 \x01(\x10\x42\x17\xbaH\x14\x62\x12\x11\x00\x00\x00\x00\x00\x00\x00\x00!\n\x00\x00\x00\x00\x00\x00\x00R\x03val\")\n\x0cTestOneofMsg\x12\x19\n\x03val\x18\x01 \x01(\x08\x42\x07\xbaH\x04j\x02\x08\x01R\x03val\"q\n\x05Oneof\x12\x1a\n\x01x\x18\x01 \x01(\tB\n\xbaH\x07r\x05:\x03\x66ooH\x00R\x01x\x12\x17\n\x01y\x18\x02 \x01(\x05\x42\x07\xbaH\x04\x1a\x02 \x00H\x00R\x01y\x12.\n\x01z\x18\x03 \x01(\x0b\x32\x1e.tests.example.v1.TestOneofMsgH\x00R\x01zB\x03\n\x01o\"[\n\x12ProtovalidateOneof\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x0b\xbaH\x08\"\x06\n\x01\x61\n\x01\x62\"e\n\x1aProtovalidateOneofRequired\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\r\xbaH\n\"\x08\n\x01\x61\n\x01\x62\x10\x01\"p\n\"ProtovalidateOneofUnknownFieldName\x12\x0c\n\x01\x61\x18\x01 \x01(\tR\x01\x61\x12\x0c\n\x01\x62\x18\x02 \x01(\tR\x01\x62\x12\x1c\n\tunrelated\x18\x03 \x01(\x08R\tunrelated:\x10\xbaH\r\"\x0b\n\x01\x61\n\x01\x62\n\x03xxx\"H\n\x0eTimestampGTNow\x12\x36\n\x03val\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xbaH\x05\xb2\x01\x02@\x01R\x03val\"\x87\x01\n\tMapMinMax\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32$.tests.example.v1.MapMinMax.ValEntryB\n\xbaH\x07\x9a\x01\x04\x08\x02\x10\x04R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\x08R\x05value:\x02\x38\x01\"\x85\x01\n\x07MapKeys\x12\x42\n\x03val\x18\x01 \x03(\x0b\x32\".tests.example.v1.MapKeys.ValEntryB\x0c\xbaH\t\x9a\x01\x06\"\x04\x42\x02\x10\x00R\x03val\x1a\x36\n\x08ValEntry\x12\x10\n\x03key\x18\x01 \x01(\x12R\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\"\n\x05\x45mbed\x12\x19\n\x03val\x18\x01 \x01(\x03\x42\x07\xbaH\x04\"\x02 \x00R\x03val\"K\n\x11RepeatedEmbedSkip\x12\x36\n\x03val\x18\x01 \x03(\x0b\x32\x17.tests.example.v1.EmbedB\x0b\xbaH\x08\x92\x01\x05\"\x03\xd8\x01\x03R\x03valB\x8a\x01\n\x14\x63om.tests.example.v1B\x10ValidationsProtoP\x01\xa2\x02\x03TEX\xaa\x02\x10Tests.Example.V1\xca\x02\x10Tests\\Example\\V1\xe2\x02\x1cTests\\Example\\V1\\GPBMetadata\xea\x02\x12Tests::Example::V1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -48,6 +48,10 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\024com.tests.example.v1B\020ValidationsProtoP\001\242\002\003TEX\252\002\020Tests.Example.V1\312\002\020Tests\\Example\\V1\342\002\034Tests\\Example\\V1\\GPBMetadata\352\002\022Tests::Example::V1' + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['title']._loaded_options = None + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['title']._serialized_options = b'\272H\007r\005:\003foo' + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['name']._loaded_options = None + _globals['_MULTIPLEVALIDATIONS'].fields_by_name['name']._serialized_options = b'\272H\004r\002\020\005' _globals['_DOUBLEFINITE'].fields_by_name['val']._loaded_options = None _globals['_DOUBLEFINITE'].fields_by_name['val']._serialized_options = b'\272H\004\022\002@\001' _globals['_SFIXED64EXLTGT'].fields_by_name['val']._loaded_options = None @@ -78,32 +82,34 @@ _globals['_EMBED'].fields_by_name['val']._serialized_options = b'\272H\004\"\002 \000' _globals['_REPEATEDEMBEDSKIP'].fields_by_name['val']._loaded_options = None _globals['_REPEATEDEMBEDSKIP'].fields_by_name['val']._serialized_options = b'\272H\010\222\001\005\"\003\330\001\003' - _globals['_DOUBLEFINITE']._serialized_start=118 - _globals['_DOUBLEFINITE']._serialized_end=159 - _globals['_SFIXED64EXLTGT']._serialized_start=161 - _globals['_SFIXED64EXLTGT']._serialized_end=220 - _globals['_TESTONEOFMSG']._serialized_start=222 - _globals['_TESTONEOFMSG']._serialized_end=263 - _globals['_ONEOF']._serialized_start=265 - _globals['_ONEOF']._serialized_end=378 - _globals['_PROTOVALIDATEONEOF']._serialized_start=380 - _globals['_PROTOVALIDATEONEOF']._serialized_end=471 - _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_start=473 - _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_end=574 - _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_start=576 - _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_end=688 - _globals['_TIMESTAMPGTNOW']._serialized_start=690 - _globals['_TIMESTAMPGTNOW']._serialized_end=762 - _globals['_MAPMINMAX']._serialized_start=765 - _globals['_MAPMINMAX']._serialized_end=900 - _globals['_MAPMINMAX_VALENTRY']._serialized_start=846 - _globals['_MAPMINMAX_VALENTRY']._serialized_end=900 - _globals['_MAPKEYS']._serialized_start=903 - _globals['_MAPKEYS']._serialized_end=1036 - _globals['_MAPKEYS_VALENTRY']._serialized_start=982 - _globals['_MAPKEYS_VALENTRY']._serialized_end=1036 - _globals['_EMBED']._serialized_start=1038 - _globals['_EMBED']._serialized_end=1072 - _globals['_REPEATEDEMBEDSKIP']._serialized_start=1074 - _globals['_REPEATEDEMBEDSKIP']._serialized_end=1149 + _globals['_MULTIPLEVALIDATIONS']._serialized_start=118 + _globals['_MULTIPLEVALIDATIONS']._serialized_end=202 + _globals['_DOUBLEFINITE']._serialized_start=204 + _globals['_DOUBLEFINITE']._serialized_end=245 + _globals['_SFIXED64EXLTGT']._serialized_start=247 + _globals['_SFIXED64EXLTGT']._serialized_end=306 + _globals['_TESTONEOFMSG']._serialized_start=308 + _globals['_TESTONEOFMSG']._serialized_end=349 + _globals['_ONEOF']._serialized_start=351 + _globals['_ONEOF']._serialized_end=464 + _globals['_PROTOVALIDATEONEOF']._serialized_start=466 + _globals['_PROTOVALIDATEONEOF']._serialized_end=557 + _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_start=559 + _globals['_PROTOVALIDATEONEOFREQUIRED']._serialized_end=660 + _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_start=662 + _globals['_PROTOVALIDATEONEOFUNKNOWNFIELDNAME']._serialized_end=774 + _globals['_TIMESTAMPGTNOW']._serialized_start=776 + _globals['_TIMESTAMPGTNOW']._serialized_end=848 + _globals['_MAPMINMAX']._serialized_start=851 + _globals['_MAPMINMAX']._serialized_end=986 + _globals['_MAPMINMAX_VALENTRY']._serialized_start=932 + _globals['_MAPMINMAX_VALENTRY']._serialized_end=986 + _globals['_MAPKEYS']._serialized_start=989 + _globals['_MAPKEYS']._serialized_end=1122 + _globals['_MAPKEYS_VALENTRY']._serialized_start=1068 + _globals['_MAPKEYS_VALENTRY']._serialized_end=1122 + _globals['_EMBED']._serialized_start=1124 + _globals['_EMBED']._serialized_end=1158 + _globals['_REPEATEDEMBEDSKIP']._serialized_start=1160 + _globals['_REPEATEDEMBEDSKIP']._serialized_end=1235 # @@protoc_insertion_point(module_scope) diff --git a/gen/tests/example/v1/validations_pb2.pyi b/gen/tests/example/v1/validations_pb2.pyi index 76de322e..e255c703 100644 --- a/gen/tests/example/v1/validations_pb2.pyi +++ b/gen/tests/example/v1/validations_pb2.pyi @@ -22,6 +22,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor +class MultipleValidations(_message.Message): + __slots__ = ("title", "name") + TITLE_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + title: str + name: str + def __init__(self, title: _Optional[str] = ..., name: _Optional[str] = ...) -> None: ... + class DoubleFinite(_message.Message): __slots__ = ("val",) VAL_FIELD_NUMBER: _ClassVar[int] diff --git a/proto/tests/example/v1/validations.proto b/proto/tests/example/v1/validations.proto index 4affacc1..7e51bdee 100644 --- a/proto/tests/example/v1/validations.proto +++ b/proto/tests/example/v1/validations.proto @@ -19,6 +19,11 @@ package tests.example.v1; import "buf/validate/validate.proto"; import "google/protobuf/timestamp.proto"; +message MultipleValidations { + string title = 1 [(buf.validate.field).string.prefix = "foo"]; + string name = 2 [(buf.validate.field).string.min_len = 5]; +} + message DoubleFinite { double val = 1 [(buf.validate.field).double.finite = true]; } diff --git a/protovalidate/__init__.py b/protovalidate/__init__.py index 0f82fccb..1c078426 100644 --- a/protovalidate/__init__.py +++ b/protovalidate/__init__.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from protovalidate import validator +from protovalidate import config, validator +Config = config.Config Validator = validator.Validator CompilationError = validator.CompilationError ValidationError = validator.ValidationError Violations = validator.Violations -_validator = Validator() -validate = _validator.validate -collect_violations = _validator.collect_violations +_default_validator = Validator() +validate = _default_validator.validate +collect_violations = _default_validator.collect_violations -__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] +__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"] diff --git a/protovalidate/config.py b/protovalidate/config.py new file mode 100644 index 00000000..1e21683b --- /dev/null +++ b/protovalidate/config.py @@ -0,0 +1,26 @@ +# Copyright 2023-2025 Buf Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class Config: + """A class for holding configuration values for validation. + + Attributes: + fail_fast (bool): If true, validation will stop after the first violation. Defaults to False. + """ + + fail_fast: bool = False diff --git a/protovalidate/internal/rules.py b/protovalidate/internal/rules.py index 7f726412..989abf96 100644 --- a/protovalidate/internal/rules.py +++ b/protovalidate/internal/rules.py @@ -21,6 +21,7 @@ from google.protobuf import any_pb2, descriptor, message, message_factory from buf.validate import validate_pb2 # type: ignore +from protovalidate.config import Config from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has @@ -266,15 +267,17 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N class RuleContext: """The state associated with a single rule evaluation.""" - def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): - self._fail_fast = fail_fast + _cfg: Config + + def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None): + self._cfg = config if violations is None: violations = [] self._violations = violations @property def fail_fast(self) -> bool: - return self._fail_fast + return self._cfg.fail_fast @property def violations(self) -> list[Violation]: @@ -296,13 +299,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat @property def done(self) -> bool: - return self._fail_fast and self.has_errors() + return self.fail_fast and self.has_errors() def has_errors(self) -> bool: return len(self._violations) > 0 def sub_context(self): - return RuleContext(fail_fast=self._fail_fast) + return RuleContext(config=self._cfg) class Rules: diff --git a/protovalidate/validator.py b/protovalidate/validator.py index b53b3e4f..30d3fd2d 100644 --- a/protovalidate/validator.py +++ b/protovalidate/validator.py @@ -17,6 +17,7 @@ from google.protobuf import message from buf.validate import validate_pb2 # type: ignore +from protovalidate.config import Config from protovalidate.internal import extra_func from protovalidate.internal import rules as _rules @@ -35,15 +36,15 @@ class Validator: """ _factory: _rules.RuleFactory + _cfg: Config - def __init__(self): + def __init__(self, config=None): self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS) + self._cfg = config if config is not None else Config() def validate( self, message: message.Message, - *, - fail_fast: bool = False, ): """ Validates the given message against the static rules defined in @@ -51,12 +52,12 @@ def validate( Parameters: message: The message to validate. - fail_fast: If true, validation will stop after the first violation. Raises: CompilationError: If the static rules could not be compiled. - ValidationError: If the message is invalid. + ValidationError: If the message is invalid. The violations raised as part of this error should + always be equal to the list of violations returned by `collect_violations`. """ - violations = self.collect_violations(message, fail_fast=fail_fast) + violations = self.collect_violations(message) if len(violations) > 0: msg = f"invalid {message.DESCRIPTOR.name}" raise ValidationError(msg, violations) @@ -65,24 +66,25 @@ def collect_violations( self, message: message.Message, *, - fail_fast: bool = False, into: typing.Optional[list[Violation]] = None, ) -> list[Violation]: """ Validates the given message against the static rules defined in - the message's descriptor. Compared to validate, collect_violations is - faster but puts the burden of raising an appropriate exception on the - caller. + the message's descriptor. Compared to `validate`, `collect_violations` simply + returns the violations as a list and puts the burden of raising an appropriate + exception on the caller. + + The violations returned from this method should always be equal to the violations + raised as part of the ValidationError in the call to `validate`. Parameters: message: The message to validate. - fail_fast: If true, validation will stop after the first violation. into: If provided, any violations will be appended to the Violations object and the same object will be returned. Raises: CompilationError: If the static rules could not be compiled. """ - ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into) + ctx = _rules.RuleContext(config=self._cfg, violations=into) for rule in self._factory.get(message.DESCRIPTOR): rule.validate(ctx, message) if ctx.done: diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 00000000..71f33af7 --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,23 @@ +# Copyright 2023-2025 Buf Technologies, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from protovalidate import Config + + +class TestConfig(unittest.TestCase): + def test_defaults(self): + cfg = Config() + self.assertFalse(cfg.fail_fast) diff --git a/tests/validate_test.py b/tests/validate_test.py index 6e52c947..792aeb08 100644 --- a/tests/validate_test.py +++ b/tests/validate_test.py @@ -14,107 +14,262 @@ import unittest +from google.protobuf import message + import protovalidate from gen.tests.example.v1 import validations_pb2 +from protovalidate.config import Config +from protovalidate.internal import rules + + +def get_default_validator(): + """Returns a default validator created in all available ways + + This allows testing for validators created via: + - module-level singleton + - instantiated class with no config + - instantiated class with config + """ + return [ + ("module singleton", protovalidate), + ("no config", protovalidate.Validator()), + ("with default config", protovalidate.Validator(Config())), + ] + + +class TestCollectViolations(unittest.TestCase): + """Test class for testing message validations. + A validator can be created via various ways: + - a module-level singleton, which returns a default validator + - instantiating the Validator class with no config, which returns a default validator + - instantiating the Validator class with a config + + In addition, the API for validating a message allows for two approaches: + - via a call to `validate`, which will raise a ValidationError if validation fails + - via a call to `collect_violations`, which will not raise an error and instead return a list of violations. + + Unless otherwise noted, each test in this class tests against a validator created via all 3 methods and tests + validation using both approaches. + """ -class TestValidate(unittest.TestCase): def test_ninf(self): msg = validations_pb2.DoubleFinite() msg.val = float("-inf") - violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations), 1) - self.assertEqual(violations[0].proto.rule_id, "double.finite") - self.assertEqual(violations[0].field_value, msg.val) - self.assertEqual(violations[0].rule_value, True) + + expected_violation = rules.Violation() + expected_violation.proto.message = "value must be finite" + expected_violation.proto.rule_id = "double.finite" + expected_violation.field_value = msg.val + expected_violation.rule_value = True + + self._run_invalid_tests(msg, [expected_violation]) def test_map_key(self): msg = validations_pb2.MapKeys() msg.val[1] = "a" - violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations), 1) - self.assertEqual(violations[0].proto.for_key, True) - self.assertEqual(violations[0].field_value, 1) - self.assertEqual(violations[0].rule_value, 0) - def test_sfixed64(self): + expected_violation = rules.Violation() + expected_violation.proto.message = "value must be less than 0" + expected_violation.proto.rule_id = "sint64.lt" + expected_violation.proto.for_key = True + expected_violation.field_value = 1 + expected_violation.rule_value = 0 + + self._run_invalid_tests(msg, [expected_violation]) + + def test_sfixed64_valid(self): msg = validations_pb2.SFixed64ExLTGT(val=11) - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - self.assertEqual(len(violations), 0) + self._run_valid_tests(msg) def test_oneofs(self): + msg = validations_pb2.Oneof() + msg.y = 123 + + self._run_valid_tests(msg) + + def test_collect_violations_into(self): msg1 = validations_pb2.Oneof() msg1.y = 123 - protovalidate.validate(msg1) msg2 = validations_pb2.Oneof() msg2.z.val = True - protovalidate.validate(msg2) - violations = protovalidate.collect_violations(msg1) - protovalidate.collect_violations(msg2, into=violations) - assert len(violations) == 0 + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test collect_violations into + violations = v.collect_violations(msg1) + v.collect_violations(msg2, into=violations) + self.assertEqual(len(violations), 0) def test_protovalidate_oneof_valid(self): msg = validations_pb2.ProtovalidateOneof() msg.a = "A" - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - assert len(violations) == 0 + + self._run_valid_tests(msg) def test_protovalidate_oneof_violation(self): msg = validations_pb2.ProtovalidateOneof() msg.a = "A" msg.b = "B" - with self.assertRaises(protovalidate.ValidationError) as cm: - protovalidate.validate(msg) - e = cm.exception - assert str(e) == "invalid ProtovalidateOneof" - assert len(e.violations) == 1 - assert e.to_proto().violations[0].message == "only one of a, b can be set" + + expected_violation = rules.Violation() + expected_violation.proto.message = "only one of a, b can be set" + expected_violation.proto.rule_id = "message.oneof" + + self._run_invalid_tests(msg, [expected_violation]) def test_protovalidate_oneof_required_violation(self): msg = validations_pb2.ProtovalidateOneofRequired() - with self.assertRaises(protovalidate.ValidationError) as cm: - protovalidate.validate(msg) - e = cm.exception - assert str(e) == "invalid ProtovalidateOneofRequired" - assert len(e.violations) == 1 - assert e.to_proto().violations[0].message == "one of a, b must be set" + + expected_violation = rules.Violation() + expected_violation.proto.message = "one of a, b must be set" + expected_violation.proto.rule_id = "message.oneof" + + self._run_invalid_tests(msg, [expected_violation]) def test_protovalidate_oneof_unknown_field_name(self): + """Tests that a compilation error is thrown when specifying a oneof rule with an invalid field name""" msg = validations_pb2.ProtovalidateOneofUnknownFieldName() - with self.assertRaises(protovalidate.CompilationError) as cm: - protovalidate.validate(msg) - assert ( - str(cm.exception) == 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName' + + self._run_compilation_error_tests( + msg, 'field "xxx" not found in message tests.example.v1.ProtovalidateOneofUnknownFieldName' ) def test_repeated(self): msg = validations_pb2.RepeatedEmbedSkip() msg.val.add(val=-1) - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - assert len(violations) == 0 + self._run_valid_tests(msg) def test_maps(self): msg = validations_pb2.MapMinMax() - with self.assertRaises(protovalidate.ValidationError) as cm: - protovalidate.validate(msg) - e = cm.exception - assert len(e.violations) == 1 - assert len(e.to_proto().violations) == 1 - assert str(e) == "invalid MapMinMax" - violations = protovalidate.collect_violations(msg) - assert len(violations) == 1 + expected_violation = rules.Violation() + expected_violation.proto.message = "map must be at least 2 entries" + expected_violation.proto.rule_id = "map.min_pairs" + expected_violation.field_value = {} + expected_violation.rule_value = 2 + + self._run_invalid_tests(msg, [expected_violation]) def test_timestamp(self): msg = validations_pb2.TimestampGTNow() - protovalidate.validate(msg) - violations = protovalidate.collect_violations(msg) - assert len(violations) == 0 + self._run_valid_tests(msg) + + def test_multiple_validations(self): + """Test that a message with multiple violations correctly returns all of them.""" + msg = validations_pb2.MultipleValidations() + msg.title = "bar" + msg.name = "blah" + + expected_violation1 = rules.Violation() + expected_violation1.proto.message = "value does not have prefix `foo`" + expected_violation1.proto.rule_id = "string.prefix" + expected_violation1.field_value = msg.title + expected_violation1.rule_value = "foo" + + expected_violation2 = rules.Violation() + expected_violation2.proto.message = "value length must be at least 5 characters" + expected_violation2.proto.rule_id = "string.min_len" + expected_violation2.field_value = msg.name + expected_violation2.rule_value = 5 + + self._run_invalid_tests(msg, [expected_violation1, expected_violation2]) + + def test_fail_fast(self): + """Test that fail fast correctly fails on first violation + + Note this does not use a default validator, but instead uses one with a custom config + so that fail_fast can be set to True. + """ + msg = validations_pb2.MultipleValidations() + msg.title = "bar" + msg.name = "blah" + + expected_violation = rules.Violation() + expected_violation.proto.message = "value does not have prefix `foo`" + expected_violation.proto.rule_id = "string.prefix" + expected_violation.field_value = msg.title + expected_violation.rule_value = "foo" + + cfg = Config(fail_fast=True) + validator = protovalidate.Validator(config=cfg) + + # Test validate + with self.assertRaises(protovalidate.ValidationError) as cm: + validator.validate(msg) + e = cm.exception + self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") + self._compare_violations(e.violations, [expected_violation]) + + # Test collect_violations + violations = validator.collect_violations(msg) + self._compare_violations(violations, [expected_violation]) + + def _run_valid_tests(self, msg: message.Message): + """A helper function for testing successful validation on a given message + + The tests are run using validators created via all possible methods and + validation is done via a call to `validate` as well as a call to `collect_violations`. + """ + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test validate + try: + v.validate(msg) + except Exception: + self.fail(f"[{label}]: unexpected validation failure") + + # Test collect_violations + violations = v.collect_violations(msg) + self.assertEqual(len(violations), 0) + + def _run_invalid_tests(self, msg: message.Message, expected: list[rules.Violation]): + """A helper function for testing unsuccessful validation on a given message + + The tests are run using validators created via all possible methods and + validation is done via a call to `validate` as well as a call to `collect_violations`. + """ + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test validate + with self.assertRaises(protovalidate.ValidationError) as cm: + v.validate(msg) + e = cm.exception + self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}") + self._compare_violations(e.violations, expected) + + # Test collect_violations + violations = v.collect_violations(msg) + self._compare_violations(violations, expected) + + def _run_compilation_error_tests(self, msg: message.Message, expected: str): + """A helper function for testing compilation errors when validating. + + The tests are run using validators created via all possible methods and + validation is done via a call to `validate` as well as a call to `collect_violations`. + """ + for label, v in get_default_validator(): + with self.subTest(label=label): + # Test validate + with self.assertRaises(protovalidate.CompilationError) as vce: + v.validate(msg) + self.assertEqual(str(vce.exception), expected) + + # Test collect_violations + with self.assertRaises(protovalidate.CompilationError) as cvce: + v.collect_violations(msg) + self.assertEqual(str(cvce.exception), expected) + + def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None: + """Compares two lists of violations. The violations are expected to be in the expected order also.""" + self.assertEqual(len(actual), len(expected)) + for a, e in zip(actual, expected): + self.assertEqual(a.proto.message, e.proto.message) + self.assertEqual(a.proto.rule_id, e.proto.rule_id) + self.assertEqual(a.proto.for_key, e.proto.for_key) + self.assertEqual(a.field_value, e.field_value) + self.assertEqual(a.rule_value, e.rule_value)