Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gen/tests/example/v1/validations_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions gen/tests/example/v1/validations_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions proto/tests/example/v1/validations.proto
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ message MapKeys {
// Embed is a small message intended to be nested in other test messages.
// Its only field is constrained by a protovalidate rule.
message Embed {
// val must be strictly greater than 0 (int64.gt = 0).
int64 val = 1 [(buf.validate.field).int64.gt = 0];
}

// RepeatedEmbedSkip holds a repeated list of Embed messages whose per-item
// rules are marked IGNORE_ALWAYS.
// NOTE(review): presumably items that violate Embed's own `val > 0` rule are
// therefore expected to pass validation; confirm against protovalidate's
// `ignore` semantics.
message RepeatedEmbedSkip {
// The ignore option is applied to the items of the list (repeated.items).
repeated Embed val = 1 [(buf.validate.field).repeated.items.ignore = IGNORE_ALWAYS];
}

// InvalidRESyntax constrains a string with a regex that uses the \z anchor.
// NOTE(review): the message name suggests the pattern is expected to be
// rejected by an engine that does not accept \z (the Python tests in this
// change describe \z as RE2-only); confirm the intended behavior per runtime.
message InvalidRESyntax {
// The proto string literal "^\\z" is the regex ^\z.
string value = 1 [(buf.validate.field).string.pattern = "^\\z"];
}
5 changes: 5 additions & 0 deletions protovalidate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
Expand All @@ -21,6 +22,10 @@ class Config:

Attributes:
fail_fast (bool): If true, validation will stop after the first violation. Defaults to False.
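regex_matches_func (Optional[Callable[[str, str], bool]]): An optional regex matcher, called as
`func(text, pattern)` and returning True on a match. If specified, it is used to evaluate regex
matches instead of this library's built-in `matches` logic. Defaults to None.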
"""

fail_fast: bool = False

regex_matches_func: Optional[Callable[[str, str], bool]] = None
31 changes: 23 additions & 8 deletions protovalidate/internal/extra_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import celpy
from celpy import celtypes

from protovalidate.config import Config
from protovalidate.internal import string_format
from protovalidate.internal.matches import cel_matches
from protovalidate.internal.matches import matches as protovalidate_matches
from protovalidate.internal.rules import MessageType, field_to_cel

# See https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address
Expand Down Expand Up @@ -1554,14 +1555,31 @@ def __peek(self, char: str) -> bool:
return self._index < len(self._string) and self._string[self._index] == char


def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]:
# For now, ignoring the type.
string_fmt = string_format.StringFormat(locale) # type: ignore
def get_matches_func(matcher: typing.Optional[typing.Callable[[str, str], bool]]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

honestly haven't kept up with type hints in Python but should we be using collections.abc.Callable instead, since typing.Callable seems to be a deprecated alias as of 3.9?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice. Thanks. Wonder why the linter didn't catch that. 🤷 . Fixed.

if matcher is None:
matcher = protovalidate_matches

def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result:
    """CEL-facing `matches`: validate operand types, then delegate to `matcher`.

    Both operands must be CEL strings. The text operand is checked first, so
    it is the one reported when both are invalid. The boolean returned by
    `matcher` is wrapped in a CEL BoolType.

    Raises:
        celpy.CELEvalError: If either operand is not a celtypes.StringType.
    """
    if isinstance(text, celtypes.StringType):
        if isinstance(pattern, celtypes.StringType):
            return celtypes.BoolType(matcher(text, pattern))
        complaint = "invalid argument for pattern, expected string"
    else:
        complaint = "invalid argument for text, expected string"
    raise celpy.CELEvalError(complaint)

return cel_matches


def make_extra_funcs(config: Config) -> dict[str, celpy.CELFunction]:
string_fmt = string_format.StringFormat()
return {
# Missing standard functions
"format": string_fmt.format,
# Overridden standard functions
"matches": cel_matches,
"matches": get_matches_func(config.regex_matches_func),
# protovalidate specific functions
"getField": cel_get_field,
"isNan": cel_is_nan,
Expand All @@ -1575,6 +1593,3 @@ def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]:
"isHostAndPort": cel_is_host_and_port,
"unique": cel_unique,
}


EXTRA_FUNCS = make_extra_funcs("en_US")
22 changes: 11 additions & 11 deletions protovalidate/internal/matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import re

import celpy
from celpy import celtypes

# Patterns that are supported in Python's re package and not in re2.
# RE2: https://github.com/google/re2/wiki/syntax
Expand All @@ -30,10 +29,11 @@
r"\\u[0-9a-fA-F]{4}", # UTF-16 code-unit
r"\\0(?!\d)", # NUL
r"\[\\b.*\]", # Backspace eg: [\b]
r"\\Z", # End of text (only lowercase z is supported in re2)
]


def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result:
def matches(text: str, pattern: str) -> bool:
"""Return True if the given pattern matches text. False otherwise.

CEL uses RE2 syntax which diverges from Python re in various ways. Ideally, we
Expand All @@ -43,14 +43,13 @@ def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result:

Instead of foisting this issue on users, we instead mimic re2 syntax by failing
to compile the regex for patterns not compatible with re2.
"""
if not isinstance(text, celtypes.StringType):
msg = "invalid argument for text, expected string"
raise celpy.CELEvalError(msg)
if not isinstance(pattern, celtypes.StringType):
msg = "invalid argument for pattern, expected string"
raise celpy.CELEvalError(msg)

Users can choose to override this behavior by providing their own custom matches
function via the Config.

Raises:
celpy.CELEvalError: If pattern contains invalid re2 syntax, or if Python's re raises re.error while searching text with the pattern.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're also raising this on re.error on the second re.search, do we want to document when that will go wrong?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, done.

"""
# Simulate re2 by failing on any patterns not compatible with re2 syntax
for invalid_pattern in invalid_patterns:
r = re.search(invalid_pattern, pattern)
Expand All @@ -61,6 +60,7 @@ def cel_matches(text: celtypes.Value, pattern: celtypes.Value) -> celpy.Result:
try:
m = re.search(pattern, text)
except re.error as ex:
return celpy.CELEvalError("match error", ex.__class__, ex.args)
msg = "match error"
raise celpy.CELEvalError(msg, ex.__class__, ex.args) from ex

return celtypes.BoolType(m is not None)
return m is not None
3 changes: 1 addition & 2 deletions protovalidate/internal/string_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
class StringFormat:
"""An implementation of string.format() in CEL."""

def __init__(self, locale: str):
self.locale = locale
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

locale was not used anywhere. We can add it back when/if it's needed.

def __init__(self):
self.fmt = None

def format(self, fmt: celtypes.Value, args: celtypes.Value) -> celpy.Result:
Expand Down
3 changes: 2 additions & 1 deletion protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class Validator:
_cfg: Config

def __init__(self, config=None):
self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS)
self._cfg = config if config is not None else Config()
funcs = extra_func.make_extra_funcs(self._cfg)
self._factory = _rules.RuleFactory(funcs)

def validate(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ class TestConfig(unittest.TestCase):
def test_defaults(self):
    """A Config built with no arguments has fail_fast off and no custom regex matcher."""
    default_cfg = Config()
    self.assertFalse(default_cfg.fail_fast)
    self.assertIsNone(default_cfg.regex_matches_func)
5 changes: 3 additions & 2 deletions tests/format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from gen.cel.expr import eval_pb2
from gen.cel.expr.conformance.test import simple_pb2
from protovalidate.config import Config
from protovalidate.internal import extra_func
from protovalidate.internal.cel_field_presence import InterpretedRunner

Expand Down Expand Up @@ -108,7 +109,7 @@ def test_format_successes(self):
if test.name in skipped_tests:
continue
ast = self._env.compile(test.expr)
prog = self._env.program(ast, functions=extra_func.EXTRA_FUNCS)
prog = self._env.program(ast, functions=extra_func.make_extra_funcs(Config()))

bindings = build_variables(test.bindings)
# Ideally we should use pytest parametrize instead of subtests, but
Expand All @@ -132,7 +133,7 @@ def test_format_errors(self):
if test.name in skipped_error_tests:
continue
ast = self._env.compile(test.expr)
prog = self._env.program(ast, functions=extra_func.EXTRA_FUNCS)
prog = self._env.program(ast, functions=extra_func.make_extra_funcs(Config()))

bindings = build_variables(test.bindings)
# Ideally we should use pytest parametrize instead of subtests, but
Expand Down
11 changes: 5 additions & 6 deletions tests/matches_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import unittest

import celpy
from celpy import celtypes

from protovalidate.internal import extra_func
from protovalidate.internal.matches import matches

invalid_patterns = [
r"\1",
Expand All @@ -30,15 +29,15 @@
r"\u0041",
r"\0 \01 \0a \012",
r"[\b]",
r"^\Z",
]


class TestMatches(unittest.TestCase):
def test_invalid_re2_syntax(self):
for pattern in invalid_patterns:
cel_pattern = celtypes.StringType(pattern)
try:
extra_func.cel_matches(celtypes.StringType("test"), cel_pattern)
self.fail(f"expected an error on pattern {cel_pattern}")
matches("test", pattern)
self.fail(f"expected an error on pattern {pattern}")
except celpy.CELEvalError as e:
self.assertEqual(str(e), f"error evaluating pattern {cel_pattern}, invalid RE2 syntax")
self.assertEqual(str(e), f"error evaluating pattern {pattern}, invalid RE2 syntax")
42 changes: 40 additions & 2 deletions tests/validate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import unittest

import celpy
from google.protobuf import message

import protovalidate
Expand Down Expand Up @@ -209,6 +211,42 @@ def test_fail_fast(self):
violations = validator.collect_violations(msg)
self._compare_violations(violations, [expected_violation])

def test_custom_matcher(self):
r"""Tests usage of the custom regex_matches_func in the config

A bit of a contrived example, but this exercises the code path
for specifying a custom regex matches function when writing regex rules.

Usage of the pattern \z is not supported in Python's re engine, only \Z is supported.
However, the inverse is true with re2 (\Z is _not_ supported and \z is supported).

This test shows using a custom matcher that converts any re2-compliant usage of \z
to \Z so that Python's re engine can execute it.
"""
msg = validations_pb2.InvalidRESyntax()

def matcher(text: str, pattern: str) -> bool:
    r"""Search `text` for `pattern` with Python's re, translating RE2's \z anchor.

    Only a real \z escape (a backslash that is not itself escaped, followed
    by "z") is rewritten to Python's \Z. Literal "z" characters and escaped
    backslashes are left untouched, so a pattern such as "^z+\z" keeps its
    meaning.

    Args:
        text: The string to search.
        pattern: An RE2-style regular expression.

    Returns:
        True if the translated pattern matches anywhere in text, else False.

    Raises:
        celpy.CELEvalError: If Python's re rejects the translated pattern.
    """
    # The lookbehind plus the even run of backslashes captured in group 1
    # guarantee that the backslash before "z" starts an escape, rather than
    # being the second half of an escaped backslash ("\\z").
    translated = re.sub(r"(?<!\\)((?:\\\\)*)\\z", r"\1\\Z", pattern)
    try:
        found = re.search(translated, text)
    except re.error as ex:
        err = "match error"
        raise celpy.CELEvalError(err, ex.__class__, ex.args) from ex
    return found is not None

cfg = Config(regex_matches_func=matcher)
validator = protovalidate.Validator(config=cfg)

# Test validate
try:
validator.validate(msg)
except Exception:
self.fail("unexpected validation failure")

# Test collect_violations
violations = validator.collect_violations(msg)
self.assertEqual(len(violations), 0)

def _run_valid_tests(self, msg: message.Message):
"""A helper function for testing successful validation on a given message

Expand Down Expand Up @@ -257,12 +295,12 @@ def _run_compilation_error_tests(self, msg: message.Message, expected: str):
# Test validate
with self.assertRaises(protovalidate.CompilationError) as vce:
v.validate(msg)
self.assertEqual(str(vce.exception), expected)
self.assertEqual(str(vce.exception), expected)

# Test collect_violations
with self.assertRaises(protovalidate.CompilationError) as cvce:
v.collect_violations(msg)
self.assertEqual(str(cvce.exception), expected)
self.assertEqual(str(cvce.exception), expected)

def _compare_violations(self, actual: list[rules.Violation], expected: list[rules.Violation]) -> None:
"""Compares two lists of violations. The violations are expected to be in the expected order also."""
Expand Down
Loading