Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/_pytest/mark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import TYPE_CHECKING

from .expression import Expression
from .expression import ParseError
from .structures import _HiddenParam
from .structures import EMPTY_PARAMETERSET_OPTION
from .structures import get_empty_parameterset_mark
Expand Down Expand Up @@ -274,8 +273,10 @@ def deselect_by_mark(items: list[Item], config: Config) -> None:
def _parse_expression(expr: str, exc_message: str) -> Expression:
try:
return Expression.compile(expr)
except ParseError as e:
raise UsageError(f"{exc_message}: {expr}: {e}") from None
except SyntaxError as e:
raise UsageError(
f"{exc_message}: {e.text}: at column {e.offset}: {e.msg}"
) from None


def pytest_collection_modifyitems(items: list[Item], config: Config) -> None:
Expand Down
112 changes: 67 additions & 45 deletions src/_pytest/mark/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

- Empty expression evaluates to False.
- ident evaluates to True or False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
- ident with parentheses and keyword arguments evaluates to True or False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
"""

from __future__ import annotations
Expand All @@ -31,6 +31,8 @@
import keyword
import re
import types
from typing import Final
from typing import final
from typing import Literal
from typing import NoReturn
from typing import overload
Expand All @@ -39,10 +41,13 @@

__all__ = [
"Expression",
"ParseError",
"ExpressionMatcher",
]


FILE_NAME: Final = "<pytest match expression>"


class TokenType(enum.Enum):
LPAREN = "left parenthesis"
RPAREN = "right parenthesis"
Expand All @@ -64,25 +69,11 @@ class Token:
pos: int


class ParseError(Exception):
"""The expression contains invalid syntax.

:param column: The column in the line where the error occurred (1-based).
:param message: A description of the error.
"""

def __init__(self, column: int, message: str) -> None:
self.column = column
self.message = message

def __str__(self) -> str:
return f"at column {self.column}: {self.message}"


class Scanner:
__slots__ = ("current", "tokens")
__slots__ = ("current", "input", "tokens")

def __init__(self, input: str) -> None:
self.input = input
self.tokens = self.lex(input)
self.current = next(self.tokens)

Expand All @@ -106,15 +97,15 @@ def lex(self, input: str) -> Iterator[Token]:
elif (quote_char := input[pos]) in ("'", '"'):
end_quote_pos = input.find(quote_char, pos + 1)
if end_quote_pos == -1:
raise ParseError(
pos + 1,
raise SyntaxError(
f'closing quote "{quote_char}" is missing',
(FILE_NAME, 1, pos + 1, input),
)
value = input[pos : end_quote_pos + 1]
if (backslash_pos := input.find("\\")) != -1:
raise ParseError(
backslash_pos + 1,
raise SyntaxError(
r'escaping with "\" not supported in marker expression',
(FILE_NAME, 1, backslash_pos + 1, input),
)
yield Token(TokenType.STRING, value, pos)
pos += len(value)
Expand All @@ -132,9 +123,9 @@ def lex(self, input: str) -> Iterator[Token]:
yield Token(TokenType.IDENT, value, pos)
pos += len(value)
else:
raise ParseError(
pos + 1,
raise SyntaxError(
f'unexpected character "{input[pos]}"',
(FILE_NAME, 1, pos + 1, input),
)
yield Token(TokenType.EOF, "", pos)

Expand All @@ -157,12 +148,12 @@ def accept(self, type: TokenType, *, reject: bool = False) -> Token | None:
return None

def reject(self, expected: Sequence[TokenType]) -> NoReturn:
raise ParseError(
self.current.pos + 1,
raise SyntaxError(
"expected {}; got {}".format(
" OR ".join(type.value for type in expected),
self.current.type.value,
),
(FILE_NAME, 1, self.current.pos + 1, self.input),
)


Expand Down Expand Up @@ -223,14 +214,14 @@ def not_expr(s: Scanner) -> ast.expr:
def single_kwarg(s: Scanner) -> ast.keyword:
keyword_name = s.accept(TokenType.IDENT, reject=True)
if not keyword_name.value.isidentifier():
raise ParseError(
keyword_name.pos + 1,
raise SyntaxError(
f"not a valid python identifier {keyword_name.value}",
(FILE_NAME, 1, keyword_name.pos + 1, s.input),
)
if keyword.iskeyword(keyword_name.value):
raise ParseError(
keyword_name.pos + 1,
raise SyntaxError(
f"unexpected reserved python keyword `{keyword_name.value}`",
(FILE_NAME, 1, keyword_name.pos + 1, s.input),
)
s.accept(TokenType.EQUAL, reject=True)

Expand All @@ -245,9 +236,9 @@ def single_kwarg(s: Scanner) -> ast.keyword:
elif value_token.value in BUILTIN_MATCHERS:
value = BUILTIN_MATCHERS[value_token.value]
else:
raise ParseError(
value_token.pos + 1,
raise SyntaxError(
f'unexpected character/s "{value_token.value}"',
(FILE_NAME, 1, value_token.pos + 1, s.input),
)

ret = ast.keyword(keyword_name.value, ast.Constant(value))
Expand All @@ -261,13 +252,36 @@ def all_kwargs(s: Scanner) -> list[ast.keyword]:
return ret


class MatcherCall(Protocol):
class ExpressionMatcher(Protocol):
"""A callable which, given an identifier and optional kwargs, should return
whether it matches in an :class:`Expression` evaluation.

Should be prepared to handle arbitrary strings as input.

If no kwargs are provided, the expression is of the form `foo`.
If kwargs are provided, the expression is of the form `foo(a=1, b=True, c="s")`.

If the expression is not supported (e.g. don't want to accept the kwargs
syntax variant), should raise :class:`~pytest.UsageError`.

Example::

def matcher(name: str, /, **kwargs: str | int | bool | None) -> bool:
# Match `cat`.
if name == "cat" and not kwargs:
return True
# Match `dog(barks=True)`.
if name == "dog" and kwargs == {"barks": True}:
return True
return False
"""

def __call__(self, name: str, /, **kwargs: str | int | bool | None) -> bool: ...


@dataclasses.dataclass
class MatcherNameAdapter:
matcher: MatcherCall
matcher: ExpressionMatcher
name: str

def __bool__(self) -> bool:
Expand All @@ -280,7 +294,7 @@ def __call__(self, **kwargs: str | int | bool | None) -> bool:
class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
"""Adapts a matcher function to a locals mapping as required by eval()."""

def __init__(self, matcher: MatcherCall) -> None:
def __init__(self, matcher: ExpressionMatcher) -> None:
self.matcher = matcher

def __getitem__(self, key: str) -> MatcherNameAdapter:
Expand All @@ -293,39 +307,47 @@ def __len__(self) -> int:
raise NotImplementedError()


@final
class Expression:
"""A compiled match expression as used by -k and -m.

The expression can be evaluated against different matchers.
"""

__slots__ = ("code",)
__slots__ = ("_code", "input")

def __init__(self, code: types.CodeType) -> None:
self.code = code
def __init__(self, input: str, code: types.CodeType) -> None:
#: The original input line, as a string.
self.input: Final = input
self._code: Final = code

@classmethod
def compile(cls, input: str) -> Expression:
"""Compile a match expression.

:param input: The input expression - one line.

:raises SyntaxError: If the expression is malformed.
"""
astexpr = expression(Scanner(input))
code: types.CodeType = compile(
code = compile(
astexpr,
filename="<pytest match expression>",
mode="eval",
)
return Expression(code)
return Expression(input, code)

def evaluate(self, matcher: MatcherCall) -> bool:
def evaluate(self, matcher: ExpressionMatcher) -> bool:
"""Evaluate the match expression.

:param matcher:
Given an identifier, should return whether it matches or not.
Should be prepared to handle arbitrary strings as input.
A callback which determines whether an identifier matches or not.
See the :class:`ExpressionMatcher` protocol for details and example.

:returns: Whether the expression matches or not.

:raises UsageError:
If the matcher doesn't support the expression. Cannot happen if the
matcher supports all expressions.
"""
ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
return ret
return bool(eval(self._code, {"__builtins__": {}}, MatcherAdapter(matcher)))
52 changes: 29 additions & 23 deletions testing/test_mark_expression.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
from __future__ import annotations

from collections.abc import Callable
from typing import cast

from _pytest.mark import MarkMatcher
from _pytest.mark.expression import Expression
from _pytest.mark.expression import MatcherCall
from _pytest.mark.expression import ParseError
from _pytest.mark.expression import ExpressionMatcher
import pytest


def evaluate(input: str, matcher: Callable[[str], bool]) -> bool:
return Expression.compile(input).evaluate(cast(MatcherCall, matcher))
def evaluate(input: str, matcher: ExpressionMatcher) -> bool:
return Expression.compile(input).evaluate(matcher)


def test_empty_is_false() -> None:
assert not evaluate("", lambda ident: False)
assert not evaluate("", lambda ident: True)
assert not evaluate(" ", lambda ident: False)
assert not evaluate("\t", lambda ident: False)
assert not evaluate("", lambda ident, /, **kwargs: False)
assert not evaluate("", lambda ident, /, **kwargs: True)
assert not evaluate(" ", lambda ident, /, **kwargs: False)
assert not evaluate("\t", lambda ident, /, **kwargs: False)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -51,7 +47,9 @@ def test_empty_is_false() -> None:
),
)
def test_basic(expr: str, expected: bool) -> None:
matcher = {"true": True, "false": False}.__getitem__
def matcher(name: str, /, **kwargs: str | int | bool | None) -> bool:
return {"true": True, "false": False}[name]

assert evaluate(expr, matcher) is expected


Expand All @@ -67,7 +65,9 @@ def test_basic(expr: str, expected: bool) -> None:
),
)
def test_syntax_oddities(expr: str, expected: bool) -> None:
matcher = {"true": True, "false": False}.__getitem__
def matcher(name: str, /, **kwargs: str | int | bool | None) -> bool:
return {"true": True, "false": False}[name]

assert evaluate(expr, matcher) is expected


Expand All @@ -77,11 +77,13 @@ def test_backslash_not_treated_specially() -> None:
user will never need to insert a literal newline, only \n (two chars). So
mark expressions themselves do not support escaping, instead they treat
backslashes as regular identifier characters."""
matcher = {r"\nfoo\n"}.__contains__

def matcher(name: str, /, **kwargs: str | int | bool | None) -> bool:
return {r"\nfoo\n"}.__contains__(name)

assert evaluate(r"\nfoo\n", matcher)
assert not evaluate(r"foo", matcher)
with pytest.raises(ParseError):
with pytest.raises(SyntaxError):
evaluate("\nfoo\n", matcher)


Expand Down Expand Up @@ -134,10 +136,10 @@ def test_backslash_not_treated_specially() -> None:
),
)
def test_syntax_errors(expr: str, column: int, message: str) -> None:
with pytest.raises(ParseError) as excinfo:
evaluate(expr, lambda ident: True)
assert excinfo.value.column == column
assert excinfo.value.message == message
with pytest.raises(SyntaxError) as excinfo:
evaluate(expr, lambda ident, /, **kwargs: True)
assert excinfo.value.offset == column
assert excinfo.value.msg == message


@pytest.mark.parametrize(
Expand Down Expand Up @@ -172,7 +174,10 @@ def test_syntax_errors(expr: str, column: int, message: str) -> None:
),
)
def test_valid_idents(ident: str) -> None:
assert evaluate(ident, {ident: True}.__getitem__)
def matcher(name: str, /, **kwargs: str | int | bool | None) -> bool:
return name == ident

assert evaluate(ident, matcher)


@pytest.mark.parametrize(
Expand All @@ -198,13 +203,14 @@ def test_valid_idents(ident: str) -> None:
),
)
def test_invalid_idents(ident: str) -> None:
with pytest.raises(ParseError):
evaluate(ident, lambda ident: True)
with pytest.raises(SyntaxError):
evaluate(ident, lambda ident, /, **kwargs: True)


@pytest.mark.parametrize(
"expr, expected_error_msg",
(
("mark()", "expected identifier; got right parenthesis"),
("mark(True=False)", "unexpected reserved python keyword `True`"),
("mark(def=False)", "unexpected reserved python keyword `def`"),
("mark(class=False)", "unexpected reserved python keyword `class`"),
Expand Down Expand Up @@ -234,7 +240,7 @@ def test_invalid_idents(ident: str) -> None:
def test_invalid_kwarg_name_or_value(
expr: str, expected_error_msg: str, mark_matcher: MarkMatcher
) -> None:
with pytest.raises(ParseError, match=expected_error_msg):
with pytest.raises(SyntaxError, match=expected_error_msg):
assert evaluate(expr, mark_matcher)


Expand Down