diff --git a/docs/changelog.rst b/docs/changelog.rst index 1d299702..e45308e3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,11 @@ Changelog `CalVer, YY.month.patch `_ +25.4.3 +====== +- :ref:`ASYNC100 ` can now autofix ``with`` statements with multiple items. +- Fixed a bug where multiple ``with`` items would not interact, leading to ASYNC100 and ASYNC9xx false alarms. https://github.com/python-trio/flake8-async/issues/156 + 25.4.2 ====== - Add :ref:`ASYNC125 ` constant-absolute-deadline diff --git a/docs/usage.rst b/docs/usage.rst index 429408e3..c93834cd 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -33,7 +33,7 @@ adding the following to your ``.pre-commit-config.yaml``: minimum_pre_commit_version: '2.9.0' repos: - repo: https://github.com/python-trio/flake8-async - rev: 25.4.2 + rev: 25.4.3 hooks: - id: flake8-async # args: ["--enable=ASYNC100,ASYNC112", "--disable=", "--autofix=ASYNC"] diff --git a/flake8_async/__init__.py b/flake8_async/__init__.py index a477ce30..862a0d12 100644 --- a/flake8_async/__init__.py +++ b/flake8_async/__init__.py @@ -38,7 +38,7 @@ # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "25.4.2" +__version__ = "25.4.3" # taken from https://github.com/Zac-HD/shed diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 2203e2c4..361755ce 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -8,7 +8,7 @@ import ast from dataclasses import dataclass from fnmatch import fnmatch -from typing import TYPE_CHECKING, NamedTuple, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar, Union import libcst as cst import libcst.matchers as m @@ -38,6 +38,8 @@ "T_EITHER", bound=Union[Flake8AsyncVisitor, Flake8AsyncVisitor_cst] ) +T_Call = TypeVar("T_Call", bound=Union[cst.Call, ast.Call]) + def error_class(error_class: type[T]) -> type[T]: assert error_class.error_codes @@ -289,8 +291,8 @@ def has_exception(node: ast.expr) -> str | None: @dataclass -class MatchingCall: - node: ast.Call +class MatchingCall(Generic[T_Call]): + node: T_Call name: str base: str @@ -301,7 +303,7 @@ def __str__(self) -> str: # convenience function used in a lot of visitors def get_matching_call( node: ast.AST, *names: str, base: Iterable[str] = ("trio", "anyio") -) -> MatchingCall | None: +) -> MatchingCall[ast.Call] | None: if isinstance(base, str): base = (base,) if ( @@ -316,6 +318,23 @@ def get_matching_call( # ___ CST helpers ___ +def get_matching_call_cst( + node: cst.CSTNode, *names: str, base: Iterable[str] = ("trio", "anyio") +) -> MatchingCall[cst.Call] | None: + if isinstance(base, str): + base = (base,) + if ( + isinstance(node, cst.Call) + and isinstance(node.func, cst.Attribute) + and node.func.attr.value in names + and isinstance(node.func.value, (cst.Name, cst.Attribute)) + ): + attr_base = identifier_to_string(node.func.value) + if attr_base is not None and attr_base in base: + return MatchingCall(node, node.func.attr.value, attr_base) + return None + + def oneof_names(*names: str): return m.OneOf(*map(m.Name, names)) @@ -329,12 +348,6 @@ def list_contains( yield from (item for item in seq if m.matches(item, matcher)) -class AttributeCall(NamedTuple): - node: cst.Call - base: str - function: str - - # the custom __or__ in libcst breaks pyright type checking. It's possible to use # `Union` as a workaround ... except pyupgrade will automatically replace that. # So we have to resort to specifying one of the base classes. @@ -365,7 +378,7 @@ def identifier_to_string(node: cst.CSTNode) -> str | None: def with_has_call( node: cst.With, *names: str, base: Iterable[str] | str = ("trio", "anyio") -) -> list[AttributeCall]: +) -> list[MatchingCall[cst.Call]]: """Check if a with statement has a matching call, returning a list with matches. `names` specify the names of functions to match, `base` specifies the @@ -396,7 +409,7 @@ def with_has_call( ) ) - res_list: list[AttributeCall] = [] + res_list: list[MatchingCall[cst.Call]] = [] for item in node.items: if res := m.extract(item.item, matcher): assert isinstance(item.item, cst.Call) @@ -405,7 +418,9 @@ def with_has_call( base_string = identifier_to_string(res["base"]) assert base_string is not None, "subscripts should never get matched" res_list.append( - AttributeCall(item.item, base_string, res["function"].value) + MatchingCall( + node=item.item, base=base_string, name=res["function"].value + ) ) return res_list diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 43cba5b4..433c3757 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -19,21 +19,21 @@ import libcst as cst import libcst.matchers as m -from libcst.metadata import PositionProvider +from libcst.metadata import CodeRange, PositionProvider from ..base import Statement from .flake8asyncvisitor import Flake8AsyncVisitor_cst from .helpers import ( - AttributeCall, + MatchingCall, cancel_scope_names, disable_codes_by_default, error_class_cst, flatten_preserving_comments, fnmatch_qualified_name_cst, func_has_decorator, + get_matching_call_cst, identifier_to_string, iter_guaranteed_once_cst, - with_has_call, ) if TYPE_CHECKING: @@ -374,6 +374,14 @@ def leave_Yield( disable_codes_by_default("ASYNC910", "ASYNC911", "ASYNC912", "ASYNC913") +@dataclass +class ContextManager: + has_checkpoint: bool | None = None + call: MatchingCall[cst.Call] | None = None + line: int | None = None + column: int | None = None + + @error_class_cst class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): error_codes: Mapping[str, str] = { @@ -408,8 +416,7 @@ def __init__(self, *args: Any, **kwargs: Any): self.match_state = MatchState() # ASYNC100 - self.has_checkpoint_stack: list[bool] = [] - self.node_dict: dict[cst.With, list[AttributeCall]] = {} + self.has_checkpoint_stack: list[ContextManager] = [] self.taskgroup_has_start_soon: dict[str, bool] = {} # --exception-suppress-context-manager @@ -429,7 +436,11 @@ def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: ) def checkpoint_cancel_point(self) -> None: - self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) + for cm in reversed(self.has_checkpoint_stack): + if cm.has_checkpoint: + # Everything further down in the stack is already True. + break + cm.has_checkpoint = True # don't need to look for any .start_soon() calls self.taskgroup_has_start_soon.clear() @@ -705,59 +716,106 @@ def _checkpoint_with(self, node: cst.With, entry: bool): # missing-checkpoint warning when there might in fact be one (i.e. a false alarm). def visit_With_body(self, node: cst.With): self.save_state(node, "taskgroup_has_start_soon", copy=True) - self._checkpoint_with(node, entry=True) + + is_suppressing = False # if this might suppress exceptions, we cannot treat anything inside it as # checkpointing. if self._is_exception_suppressing_context_manager(node): self.save_state(node, "uncheckpointed_statements", copy=True) - if res := ( - with_has_call(node, *cancel_scope_names) - or with_has_call( - node, "timeout", "timeout_at", base=("asyncio", "asyncio.timeouts") - ) - ): - pos = self.get_metadata(PositionProvider, node).start # pyright: ignore - line: int = pos.line # pyright: ignore - column: int = pos.column # pyright: ignore - self.uncheckpointed_statements.add( - ArtificialStatement("with", line, column) - ) - self.node_dict[node] = res - self.has_checkpoint_stack.append(False) - else: - self.has_checkpoint_stack.append(True) + for withitem in node.items: + self.has_checkpoint_stack.append(ContextManager()) + if get_matching_call_cst( + withitem.item, "open_nursery", "create_task_group" + ): + if withitem.asname is not None and isinstance( + withitem.asname.name, cst.Name + ): + self.taskgroup_has_start_soon[withitem.asname.name.value] = False + self.checkpoint_schedule_point() + # Technically somebody could set open_nursery or create_task_group as + # suppressing context managers, but we're not add logic for that. + continue + + if bool(getattr(node, "asynchronous", False)): + self.checkpoint() + + # not a clean function call + if not isinstance(withitem.item, cst.Call) or not isinstance( + withitem.item.func, (cst.Name, cst.Attribute) + ): + continue + + if ( + fnmatch_qualified_name_cst( + (withitem.item.func,), + "contextlib.suppress", + *self.suppress_imported_as, + *self.options.exception_suppress_context_managers, + ) + is not None + ): + # Don't re-update state if there's several suppressing cm's. + if not is_suppressing: + self.save_state(node, "uncheckpointed_statements", copy=True) + is_suppressing = True + continue + + if res := ( + get_matching_call_cst(withitem.item, *cancel_scope_names) + or get_matching_call_cst( + withitem.item, + "timeout", + "timeout_at", + base="asyncio", + ) + ): + # typing issue: https://github.com/Instagram/LibCST/issues/1107 + pos = cst.ensure_type( + self.get_metadata(PositionProvider, withitem), + CodeRange, + ).start + self.uncheckpointed_statements.add( + ArtificialStatement("withitem", pos.line, pos.column) + ) + + cm = self.has_checkpoint_stack[-1] + cm.line = pos.line + cm.column = pos.column + cm.call = res + cm.has_checkpoint = False def leave_With(self, original_node: cst.With, updated_node: cst.With): - # Uses leave_With instead of leave_With_body because we need access to both - # original and updated node - # ASYNC100 - if not self.has_checkpoint_stack.pop(): - autofix = len(updated_node.items) == 1 - for res in self.node_dict[original_node]: + withitems = list(updated_node.items) + for i in reversed(range(len(updated_node.items))): + cm = self.has_checkpoint_stack.pop() + # ASYNC100 + if cm.has_checkpoint is False: + res = cm.call + assert res is not None # bypass 910 & 911's should_autofix logic, which excludes asyncio - # (TODO: and uses self.noautofix ... which I don't remember what it's for) - autofix &= self.error( - res.node, res.base, res.function, error_code="ASYNC100" - ) and super().should_autofix(res.node, code="ASYNC100") - - if autofix: - return flatten_preserving_comments(updated_node) - # ASYNC912 - else: - pos = self.get_metadata( # pyright: ignore - PositionProvider, original_node - ).start # pyright: ignore - line: int = pos.line # pyright: ignore - column: int = pos.column # pyright: ignore - s = ArtificialStatement("with", line, column) - if s in self.uncheckpointed_statements: - self.uncheckpointed_statements.remove(s) - for res in self.node_dict[original_node]: - self.error(res.node, error_code="ASYNC912") - - self.node_dict.pop(original_node, None) + if self.error( + res.node, res.base, res.name, error_code="ASYNC100" + ) and super().should_autofix(res.node, code="ASYNC100"): + if len(withitems) == 1: + # Remove this With node, bypassing later logic. + return flatten_preserving_comments(updated_node) + if i == len(withitems) - 1: + # preserve trailing comma, or remove comma if there was none + withitems[-2] = withitems[-2].with_changes( + comma=withitems[-1].comma + ) + withitems.pop(i) + + # ASYNC912 + elif cm.call is not None: + assert cm.line is not None + assert cm.column is not None + s = ArtificialStatement("withitem", cm.line, cm.column) + if s in self.uncheckpointed_statements: + self.uncheckpointed_statements.remove(s) + self.error(cm.call.node, error_code="ASYNC912") # if exception-suppressing, restore all uncheckpointed statements from # before the `with`. @@ -767,7 +825,8 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): self.uncheckpointed_statements.update(prev_checkpoints) self._checkpoint_with(original_node, entry=False) - return updated_node + + return updated_node.with_changes(items=withitems) # error if no checkpoint since earlier yield or function entry def leave_Yield( diff --git a/tests/autofix_files/async100.py b/tests/autofix_files/async100.py index 2c081cdc..99b0dd7e 100644 --- a/tests/autofix_files/async100.py +++ b/tests/autofix_files/async100.py @@ -2,8 +2,14 @@ # AUTOFIX # ASYNCIO_NO_ERROR # timeout primitives are named differently in asyncio +import contextlib import trio + +def condition() -> bool: + return False + + # error: 5, "trio", "move_on_after" ... @@ -214,3 +220,29 @@ async def nursery_exit_blocks_with_start(): async with trio.open_nursery() as n: with trio.CancelScope(): await n.start(trio.sleep, 0) + + +async def autofix_multi_withitem(): + with open("foo"): # error: 9, "trio", "CancelScope" + ... + + # this one is completely removed + # error: 8, "trio", "CancelScope" + # error: 8, "trio", "CancelScope" + ... + + # these keep the `open` + with ( + open("aa") as _, # error: 8, "trio", "fail_after" + ): + ... + + with ( + open("bb") as _, # error: 8, "trio", "move_on_after" + ): + ... + + with ( + open("cc") as f, + ): + ... diff --git a/tests/autofix_files/async100.py.diff b/tests/autofix_files/async100.py.diff index a6095a2b..6119acf8 100644 --- a/tests/autofix_files/async100.py.diff +++ b/tests/autofix_files/async100.py.diff @@ -1,8 +1,8 @@ --- +++ @@ x,24 x,24 @@ + return False - import trio -with trio.move_on_after(10): # error: 5, "trio", "move_on_after" - ... @@ -189,3 +189,42 @@ # async100 does not consider *redundant* cancel scopes +@@ x,32 x,26 @@ + + + async def autofix_multi_withitem(): +- with trio.CancelScope(), open("foo"): # error: 9, "trio", "CancelScope" ++ with open("foo"): # error: 9, "trio", "CancelScope" + ... + + # this one is completely removed +- with ( +- trio.CancelScope(), # error: 8, "trio", "CancelScope" +- trio.CancelScope(), # error: 8, "trio", "CancelScope" +- ): +- ... ++ # error: 8, "trio", "CancelScope" ++ # error: 8, "trio", "CancelScope" ++ ... + + # these keep the `open` + with ( +- open("aa") as _, +- trio.fail_after(10), # error: 8, "trio", "fail_after" ++ open("aa") as _, # error: 8, "trio", "fail_after" + ): + ... + + with ( +- trio.fail_after(5), # error: 8, "trio", "fail_after" +- open("bb") as _, +- trio.move_on_after(5), # error: 8, "trio", "move_on_after" ++ open("bb") as _, # error: 8, "trio", "move_on_after" + ): + ... + + with ( +- trio.move_on_after(10), # error: 8, "trio", "move_on_after" + open("cc") as f, + ): + ... diff --git a/tests/autofix_files/async100_asyncio.py b/tests/autofix_files/async100_asyncio.py index cfd0121e..98866654 100644 --- a/tests/autofix_files/async100_asyncio.py +++ b/tests/autofix_files/async100_asyncio.py @@ -2,12 +2,11 @@ # ANYIO_NO_ERROR # BASE_LIBRARY asyncio -# timeout[_at] re-exported in the main asyncio namespace in py3.11 +# asyncio.timeout[_at] added in py3.11 # mypy: disable-error-code=attr-defined # AUTOFIX import asyncio -import asyncio.timeouts async def foo(): @@ -16,7 +15,17 @@ async def foo(): # error: 9, "asyncio", "timeout" ... - # error: 9, "asyncio.timeouts", "timeout_at" - ... - # error: 9, "asyncio.timeouts", "timeout" - ... + +# this is technically only a problem with asyncio, since timeout primitives in trio/anyio +# are sync cm's +async def multi_withitem(): + with open("foo"): # error: 9, "asyncio", "timeout" + ... + with open("foo"): # error: 22, "asyncio", "timeout" + ... + # retain explicit trailing comma (?) + with ( + open("foo"), + open("bar"), # error: 8, "asyncio", "timeout" + ): + ... diff --git a/tests/autofix_files/async100_asyncio.py.diff b/tests/autofix_files/async100_asyncio.py.diff index f083238a..a80806b2 100644 --- a/tests/autofix_files/async100_asyncio.py.diff +++ b/tests/autofix_files/async100_asyncio.py.diff @@ -1,6 +1,6 @@ --- +++ -@@ x,12 x,12 @@ +@@ x,23 x,22 @@ async def foo(): @@ -13,11 +13,21 @@ + # error: 9, "asyncio", "timeout" + ... -- with asyncio.timeouts.timeout_at(10): # error: 9, "asyncio.timeouts", "timeout_at" -- ... -- with asyncio.timeouts.timeout(10): # error: 9, "asyncio.timeouts", "timeout" -- ... -+ # error: 9, "asyncio.timeouts", "timeout_at" -+ ... -+ # error: 9, "asyncio.timeouts", "timeout" -+ ... + + # this is technically only a problem with asyncio, since timeout primitives in trio/anyio + # are sync cm's + async def multi_withitem(): +- with asyncio.timeout(10), open("foo"): # error: 9, "asyncio", "timeout" ++ with open("foo"): # error: 9, "asyncio", "timeout" + ... +- with open("foo"), asyncio.timeout(10): # error: 22, "asyncio", "timeout" ++ with open("foo"): # error: 22, "asyncio", "timeout" + ... + # retain explicit trailing comma (?) + with ( + open("foo"), +- open("bar"), +- asyncio.timeout(10), # error: 8, "asyncio", "timeout" ++ open("bar"), # error: 8, "asyncio", "timeout" + ): + ... diff --git a/tests/eval_files/async100.py b/tests/eval_files/async100.py index c51a1261..809347be 100644 --- a/tests/eval_files/async100.py +++ b/tests/eval_files/async100.py @@ -2,8 +2,14 @@ # AUTOFIX # ASYNCIO_NO_ERROR # timeout primitives are named differently in asyncio +import contextlib import trio + +def condition() -> bool: + return False + + with trio.move_on_after(10): # error: 5, "trio", "move_on_after" ... @@ -214,3 +220,35 @@ async def nursery_exit_blocks_with_start(): async with trio.open_nursery() as n: with trio.CancelScope(): await n.start(trio.sleep, 0) + + +async def autofix_multi_withitem(): + with trio.CancelScope(), open("foo"): # error: 9, "trio", "CancelScope" + ... + + # this one is completely removed + with ( + trio.CancelScope(), # error: 8, "trio", "CancelScope" + trio.CancelScope(), # error: 8, "trio", "CancelScope" + ): + ... + + # these keep the `open` + with ( + open("aa") as _, + trio.fail_after(10), # error: 8, "trio", "fail_after" + ): + ... + + with ( + trio.fail_after(5), # error: 8, "trio", "fail_after" + open("bb") as _, + trio.move_on_after(5), # error: 8, "trio", "move_on_after" + ): + ... + + with ( + trio.move_on_after(10), # error: 8, "trio", "move_on_after" + open("cc") as f, + ): + ... diff --git a/tests/eval_files/async100_asyncio.py b/tests/eval_files/async100_asyncio.py index 494803ab..5d2754c5 100644 --- a/tests/eval_files/async100_asyncio.py +++ b/tests/eval_files/async100_asyncio.py @@ -2,12 +2,11 @@ # ANYIO_NO_ERROR # BASE_LIBRARY asyncio -# timeout[_at] re-exported in the main asyncio namespace in py3.11 +# asyncio.timeout[_at] added in py3.11 # mypy: disable-error-code=attr-defined # AUTOFIX import asyncio -import asyncio.timeouts async def foo(): @@ -16,7 +15,18 @@ async def foo(): with asyncio.timeout(10): # error: 9, "asyncio", "timeout" ... - with asyncio.timeouts.timeout_at(10): # error: 9, "asyncio.timeouts", "timeout_at" + +# this is technically only a problem with asyncio, since timeout primitives in trio/anyio +# are sync cm's +async def multi_withitem(): + with asyncio.timeout(10), open("foo"): # error: 9, "asyncio", "timeout" + ... + with open("foo"), asyncio.timeout(10): # error: 22, "asyncio", "timeout" ... - with asyncio.timeouts.timeout(10): # error: 9, "asyncio.timeouts", "timeout" + # retain explicit trailing comma (?) + with ( + open("foo"), + open("bar"), + asyncio.timeout(10), # error: 8, "asyncio", "timeout" + ): ... diff --git a/tests/eval_files/async100_asyncio_noautofix.py b/tests/eval_files/async100_asyncio_noautofix.py new file mode 100644 index 00000000..f640d415 --- /dev/null +++ b/tests/eval_files/async100_asyncio_noautofix.py @@ -0,0 +1,21 @@ +# TRIO_NO_ERROR +# ANYIO_NO_ERROR +# BASE_LIBRARY asyncio + +# We remove the last timeout, but don't re-evaluate the whole with statement, +# so the test still raises an error. +# NOAUTOFIX + +# asyncio.timeout[_at] added in py3.11 +# mypy: disable-error-code=attr-defined + +import asyncio + + +async def multi_withitem(): + async with asyncio.timeout( + 10 + ), asyncio.timeout_at( # error: 7, "asyncio", "timeout_at" + 10 + ): + ... diff --git a/tests/eval_files/async100_noautofix.py b/tests/eval_files/async100_noautofix.py deleted file mode 100644 index 6da4ed67..00000000 --- a/tests/eval_files/async100_noautofix.py +++ /dev/null @@ -1,25 +0,0 @@ -# ASYNCIO_NO_ERROR - no asyncio.move_on_after -import trio - - -# Doesn't autofix With's with multiple withitems -async def function_name2(): - with ( - open("") as _, - trio.fail_after(10), # error: 8, "trio", "fail_after" - ): - ... - - with ( - trio.fail_after(5), # error: 8, "trio", "fail_after" - open("") as _, - trio.move_on_after(5), # error: 8, "trio", "move_on_after" - ): - ... - - -with ( - trio.move_on_after(10), # error: 4, "trio", "move_on_after" - open("") as f, -): - ... diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py index ef9200bf..3dcdd986 100644 --- a/tests/eval_files/async912_asyncio.py +++ b/tests/eval_files/async912_asyncio.py @@ -6,7 +6,7 @@ # ASYNC100 supports autofix, but ASYNC912 doesn't, so we must run with NOAUTOFIX # NOAUTOFIX -# timeout[_at] re-exported in the main asyncio namespace in py3.11 +# asyncio.timeout[_at] added in py3.11 # mypy: disable-error-code=attr-defined import asyncio @@ -14,27 +14,16 @@ from typing import Any -def bar() -> bool: +def bar() -> Any: return False -def customWrapper(a: object) -> object: ... - - async def foo(): # async100 async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout" ... async with asyncio.timeout_at(10): # ASYNC100: 15, "asyncio", "timeout_at" ... - async with asyncio.timeouts.timeout( # ASYNC100: 15, "asyncio.timeouts", "timeout" - 10 - ): - ... - async with asyncio.timeouts.timeout_at( # ASYNC100: 15, "asyncio.timeouts", "timeout_at" - 10 - ): - ... # no errors async with asyncio.timeout(10): @@ -50,10 +39,10 @@ async def foo(): if bar(): await foo() - async with asyncio.timeouts.timeout(10): # ASYNC912: 15 - if bar(): - await foo() - async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15 + # multiple withitems + async with asyncio.timeout(10), bar(): + ... + async with bar(), asyncio.timeout(10): # ASYNC912: 22 if bar(): await foo()