Skip to content

Commit 0c7db6c

Browse files
authored
Add the ability to specify a config to validators (#323)
This adds the ability to specify a `Config` data class to validators when creating them. Prior to this PR, validators could be created one of two ways: * via a module-level singleton. ```python import protovalidate protovalidate.validate(msg) ``` * via explicit instantiation of the validator. i.e. ```python import protovalidate validator = protovalidate.Validator() validator.validate(msg) ``` With this PR, it is now possible to create a validator with a `Config` object, specifying various options (only `fail_fast` for now) for configuring a validator. To use a config, users must explicitly instantiate a validator and specify their config. They cannot use the module singleton as this only allows for a default validator. For example: ```python import protovalidate from protovalidate.config import Config cfg = Config(fail_fast=True) validator = protovalidate.Validator(config=cfg) validator.validate(msg) ``` **This PR is a breaking change**. As a result of the above, the `fail_fast` parameters have been removed from the signatures for `validate` and `collect_violations`. Users should instead use the above config method to specify `fail_fast`. This also adds a bit more depth to the `validate_test` unit tests by testing creation of a default validator via the module and via explicit instantiation. It also tests that the violations returned by `collect_violations` and raises as part of the exception returned from `validate` are the same.
1 parent 13ab1ae commit 0c7db6c

File tree

9 files changed

+333
-104
lines changed

9 files changed

+333
-104
lines changed

gen/tests/example/v1/validations_pb2.py

Lines changed: 35 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gen/tests/example/v1/validations_pb2.pyi

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

proto/tests/example/v1/validations.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ package tests.example.v1;
1919
import "buf/validate/validate.proto";
2020
import "google/protobuf/timestamp.proto";
2121

22+
message MultipleValidations {
23+
string title = 1 [(buf.validate.field).string.prefix = "foo"];
24+
string name = 2 [(buf.validate.field).string.min_len = 5];
25+
}
26+
2227
message DoubleFinite {
2328
double val = 1 [(buf.validate.field).double.finite = true];
2429
}

protovalidate/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from protovalidate import validator
15+
from protovalidate import config, validator
1616

17+
Config = config.Config
1718
Validator = validator.Validator
1819
CompilationError = validator.CompilationError
1920
ValidationError = validator.ValidationError
2021
Violations = validator.Violations
2122

22-
_validator = Validator()
23-
validate = _validator.validate
24-
collect_violations = _validator.collect_violations
23+
_default_validator = Validator()
24+
validate = _default_validator.validate
25+
collect_violations = _default_validator.collect_violations
2526

26-
__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
27+
__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]

protovalidate/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2023-2025 Buf Technologies, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
17+
18+
@dataclass
19+
class Config:
20+
"""A class for holding configuration values for validation.
21+
22+
Attributes:
23+
fail_fast (bool): If true, validation will stop after the first violation. Defaults to False.
24+
"""
25+
26+
fail_fast: bool = False

protovalidate/internal/rules.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.protobuf import any_pb2, descriptor, message, message_factory
2222

2323
from buf.validate import validate_pb2 # type: ignore
24+
from protovalidate.config import Config
2425
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
2526

2627

@@ -266,15 +267,17 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N
266267
class RuleContext:
267268
"""The state associated with a single rule evaluation."""
268269

269-
def __init__(self, *, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None):
270-
self._fail_fast = fail_fast
270+
_cfg: Config
271+
272+
def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
273+
self._cfg = config
271274
if violations is None:
272275
violations = []
273276
self._violations = violations
274277

275278
@property
276279
def fail_fast(self) -> bool:
277-
return self._fail_fast
280+
return self._cfg.fail_fast
278281

279282
@property
280283
def violations(self) -> list[Violation]:
@@ -296,13 +299,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat
296299

297300
@property
298301
def done(self) -> bool:
299-
return self._fail_fast and self.has_errors()
302+
return self.fail_fast and self.has_errors()
300303

301304
def has_errors(self) -> bool:
302305
return len(self._violations) > 0
303306

304307
def sub_context(self):
305-
return RuleContext(fail_fast=self._fail_fast)
308+
return RuleContext(config=self._cfg)
306309

307310

308311
class Rules:

protovalidate/validator.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.protobuf import message
1818

1919
from buf.validate import validate_pb2 # type: ignore
20+
from protovalidate.config import Config
2021
from protovalidate.internal import extra_func
2122
from protovalidate.internal import rules as _rules
2223

@@ -35,28 +36,28 @@ class Validator:
3536
"""
3637

3738
_factory: _rules.RuleFactory
39+
_cfg: Config
3840

39-
def __init__(self):
41+
def __init__(self, config=None):
4042
self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS)
43+
self._cfg = config if config is not None else Config()
4144

4245
def validate(
4346
self,
4447
message: message.Message,
45-
*,
46-
fail_fast: bool = False,
4748
):
4849
"""
4950
Validates the given message against the static rules defined in
5051
the message's descriptor.
5152
5253
Parameters:
5354
message: The message to validate.
54-
fail_fast: If true, validation will stop after the first violation.
5555
Raises:
5656
CompilationError: If the static rules could not be compiled.
57-
ValidationError: If the message is invalid.
57+
ValidationError: If the message is invalid. The violations raised as part of this error should
58+
always be equal to the list of violations returned by `collect_violations`.
5859
"""
59-
violations = self.collect_violations(message, fail_fast=fail_fast)
60+
violations = self.collect_violations(message)
6061
if len(violations) > 0:
6162
msg = f"invalid {message.DESCRIPTOR.name}"
6263
raise ValidationError(msg, violations)
@@ -65,24 +66,25 @@ def collect_violations(
6566
self,
6667
message: message.Message,
6768
*,
68-
fail_fast: bool = False,
6969
into: typing.Optional[list[Violation]] = None,
7070
) -> list[Violation]:
7171
"""
7272
Validates the given message against the static rules defined in
73-
the message's descriptor. Compared to validate, collect_violations is
74-
faster but puts the burden of raising an appropriate exception on the
75-
caller.
73+
the message's descriptor. Compared to `validate`, `collect_violations` simply
74+
returns the violations as a list and puts the burden of raising an appropriate
75+
exception on the caller.
76+
77+
The violations returned from this method should always be equal to the violations
78+
raised as part of the ValidationError in the call to `validate`.
7679
7780
Parameters:
7881
message: The message to validate.
79-
fail_fast: If true, validation will stop after the first violation.
8082
into: If provided, any violations will be appended to the
8183
Violations object and the same object will be returned.
8284
Raises:
8385
CompilationError: If the static rules could not be compiled.
8486
"""
85-
ctx = _rules.RuleContext(fail_fast=fail_fast, violations=into)
87+
ctx = _rules.RuleContext(config=self._cfg, violations=into)
8688
for rule in self._factory.get(message.DESCRIPTOR):
8789
rule.validate(ctx, message)
8890
if ctx.done:

tests/config_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2023-2025 Buf Technologies, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from protovalidate import Config
18+
19+
20+
class TestConfig(unittest.TestCase):
21+
def test_defaults(self):
22+
cfg = Config()
23+
self.assertFalse(cfg.fail_fast)

0 commit comments

Comments
 (0)