Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions protovalidate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from protovalidate import config, validator
from protovalidate import validator

Config = config.Config
Validator = validator.Validator
CompilationError = validator.CompilationError
ValidationError = validator.ValidationError
Expand All @@ -24,4 +23,4 @@
validate = _default_validator.validate
collect_violations = _default_validator.collect_violations

__all__ = ["CompilationError", "Config", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
__all__ = ["CompilationError", "ValidationError", "Validator", "Violations", "collect_violations", "validate"]
26 changes: 0 additions & 26 deletions protovalidate/config.py

This file was deleted.

18 changes: 5 additions & 13 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory, timestamp_pb2

from buf.validate import validate_pb2
from protovalidate.config import Config
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has


Expand Down Expand Up @@ -266,18 +265,11 @@ def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = N
class RuleContext:
"""The state associated with a single rule evaluation."""

_cfg: Config
_violations: list[Violation]

def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
self._cfg = config
if violations is None:
violations = []
self._violations = violations

@property
def fail_fast(self) -> bool:
return self._cfg.fail_fast
def __init__(self, *, fail_fast: bool = False):
self._fail_fast = fail_fast
self._violations = []

@property
def violations(self) -> list[Violation]:
Expand All @@ -299,13 +291,13 @@ def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPat

@property
def done(self) -> bool:
return self.fail_fast and self.has_errors()
return self._fail_fast and self.has_errors()

def has_errors(self) -> bool:
return len(self._violations) > 0

def sub_context(self) -> "RuleContext":
return RuleContext(config=self._cfg)
return RuleContext(fail_fast=self._fail_fast)


class Rules:
Expand Down
22 changes: 7 additions & 15 deletions protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from google.protobuf import message

from buf.validate import validate_pb2
from protovalidate.config import Config
from protovalidate.internal import extra_func
from protovalidate.internal import rules as _rules

Expand All @@ -36,29 +33,25 @@ class Validator:
"""

_factory: _rules.RuleFactory
_cfg: Config

def __init__(self, config: typing.Optional[Config] = None):
self._cfg = config if config is not None else Config()
def __init__(self):
funcs = extra_func.make_extra_funcs()
self._factory = _rules.RuleFactory(funcs)

def validate(
self,
message: message.Message,
):
def validate(self, message: message.Message, *, fail_fast: bool = False):
"""
Validates the given message against the static rules defined in
the message's descriptor.

Parameters:
message: The message to validate.
fail_fast: If true, validation will stop after the first iteration.
Raises:
CompilationError: If the static rules could not be compiled.
ValidationError: If the message is invalid. The violations raised as part of this error should
always be equal to the list of violations returned by `collect_violations`.
"""
violations = self.collect_violations(message)
violations = self.collect_violations(message, fail_fast=fail_fast)
if len(violations) > 0:
msg = f"invalid {message.DESCRIPTOR.name}"
raise ValidationError(msg, violations)
Expand All @@ -67,7 +60,7 @@ def collect_violations(
self,
message: message.Message,
*,
into: typing.Optional[list[Violation]] = None,
fail_fast: bool = False,
) -> list[Violation]:
"""
Validates the given message against the static rules defined in
Expand All @@ -80,12 +73,11 @@ def collect_violations(

Parameters:
message: The message to validate.
into: If provided, any violations will be appended to the
Violations object and the same object will be returned.
fail_fast: If true, validation will stop after the first iteration.
Raises:
CompilationError: If the static rules could not be compiled.
"""
ctx = _rules.RuleContext(config=self._cfg, violations=into)
ctx = _rules.RuleContext(fail_fast=fail_fast)
for rule in self._factory.get(message.DESCRIPTOR):
rule.validate(ctx, message)
if ctx.done:
Expand Down
23 changes: 0 additions & 23 deletions test/test_config.py

This file was deleted.

37 changes: 7 additions & 30 deletions test/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import protovalidate
from gen.tests.example.v1 import validations_pb2
from protovalidate.config import Config
from protovalidate.internal import rules


Expand All @@ -27,13 +26,11 @@ def get_default_validator():

This allows testing for validators created via:
- module-level singleton
- instantiated class with no config
- instantiated class with config
- instantiated class
"""
return [
("module singleton", protovalidate),
("no config", protovalidate.Validator()),
("with default config", protovalidate.Validator(Config())),
("instantiated class", protovalidate.Validator()),
]


Expand All @@ -42,8 +39,7 @@ class TestCollectViolations(unittest.TestCase):

A validator can be created via various ways:
- a module-level singleton, which returns a default validator
- instantiating the Validator class with no config, which returns a default validator
- instantiating the Validator class with a config
- instantiating the Validator class

In addition, the API for validating a message allows for two approaches:
- via a call to `validate`, which will raise a ValidationError if validation fails
Expand Down Expand Up @@ -89,20 +85,6 @@ def test_oneofs(self):

self._run_valid_tests(msg)

def test_collect_violations_into(self):
msg1 = validations_pb2.Oneof()
msg1.y = 123

msg2 = validations_pb2.Oneof()
msg2.z.val = True

for label, v in get_default_validator():
with self.subTest(label=label):
# Test collect_violations into
violations = v.collect_violations(msg1)
v.collect_violations(msg2, into=violations)
self.assertEqual(len(violations), 0)

def test_protovalidate_oneof_valid(self):
msg = validations_pb2.ProtovalidateOneof()
msg.a = "A"
Expand Down Expand Up @@ -188,11 +170,7 @@ def test_concatenated_values(self):
self._run_valid_tests(msg)

def test_fail_fast(self):
"""Test that fail fast correctly fails on first violation

Note this does not use a default validator, but instead uses one with a custom config
so that fail_fast can be set to True.
"""
"""Test that fail fast correctly fails on first violation"""
msg = validations_pb2.MultipleValidations()
msg.title = "bar"
msg.name = "blah"
Expand All @@ -203,18 +181,17 @@ def test_fail_fast(self):
expected_violation.field_value = msg.title
expected_violation.rule_value = "foo"

cfg = Config(fail_fast=True)
validator = protovalidate.Validator(config=cfg)
validator = protovalidate.Validator()

# Test validate
with self.assertRaises(protovalidate.ValidationError) as cm:
validator.validate(msg)
validator.validate(msg, fail_fast=True)
e = cm.exception
self.assertEqual(str(e), f"invalid {msg.DESCRIPTOR.name}")
self._compare_violations(e.violations, [expected_violation])

# Test collect_violations
violations = validator.collect_violations(msg)
violations = validator.collect_violations(msg, fail_fast=True)
self._compare_violations(violations, [expected_violation])

def _run_valid_tests(self, msg: message.Message):
Expand Down