Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions flake8_async/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import ast
from dataclasses import dataclass
from fnmatch import fnmatch
from typing import TYPE_CHECKING, NamedTuple, TypeVar, Union
from typing import TYPE_CHECKING, Generic, TypeVar, Union

import libcst as cst
import libcst.matchers as m
Expand Down Expand Up @@ -38,6 +38,8 @@
"T_EITHER", bound=Union[Flake8AsyncVisitor, Flake8AsyncVisitor_cst]
)

T_Call = TypeVar("T_Call", bound=Union[cst.Call, ast.Call])


def error_class(error_class: type[T]) -> type[T]:
assert error_class.error_codes
Expand Down Expand Up @@ -289,8 +291,8 @@ def has_exception(node: ast.expr) -> str | None:


@dataclass
class MatchingCall:
node: ast.Call
class MatchingCall(Generic[T_Call]):
node: T_Call
name: str
base: str

Expand All @@ -301,7 +303,7 @@ def __str__(self) -> str:
# convenience function used in a lot of visitors
def get_matching_call(
node: ast.AST, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> MatchingCall | None:
) -> MatchingCall[ast.Call] | None:
if isinstance(base, str):
base = (base,)
if (
Expand All @@ -316,6 +318,23 @@ def get_matching_call(


# ___ CST helpers ___
def get_matching_call_cst(
node: cst.CSTNode, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> MatchingCall[cst.Call] | None:
if isinstance(base, str):
base = (base,)
if (
isinstance(node, cst.Call)
and isinstance(node.func, cst.Attribute)
and node.func.attr.value in names
and isinstance(node.func.value, (cst.Name, cst.Attribute))
):
attr_base = identifier_to_string(node.func.value)
if attr_base is not None and attr_base in base:
return MatchingCall(node, node.func.attr.value, attr_base)
return None


def oneof_names(*names: str):
return m.OneOf(*map(m.Name, names))

Expand All @@ -329,12 +348,6 @@ def list_contains(
yield from (item for item in seq if m.matches(item, matcher))


class AttributeCall(NamedTuple):
node: cst.Call
base: str
function: str


# the custom __or__ in libcst breaks pyright type checking. It's possible to use
# `Union` as a workaround ... except pyupgrade will automatically replace that.
# So we have to resort to specifying one of the base classes.
Expand Down Expand Up @@ -365,7 +378,7 @@ def identifier_to_string(node: cst.CSTNode) -> str | None:

def with_has_call(
node: cst.With, *names: str, base: Iterable[str] | str = ("trio", "anyio")
) -> list[AttributeCall]:
) -> list[MatchingCall[cst.Call]]:
"""Check if a with statement has a matching call, returning a list with matches.

`names` specify the names of functions to match, `base` specifies the
Expand Down Expand Up @@ -396,7 +409,7 @@ def with_has_call(
)
)

res_list: list[AttributeCall] = []
res_list: list[MatchingCall[cst.Call]] = []
for item in node.items:
if res := m.extract(item.item, matcher):
assert isinstance(item.item, cst.Call)
Expand All @@ -405,7 +418,9 @@ def with_has_call(
base_string = identifier_to_string(res["base"])
assert base_string is not None, "subscripts should never get matched"
res_list.append(
AttributeCall(item.item, base_string, res["function"].value)
MatchingCall(
node=item.item, base=base_string, name=res["function"].value
)
)
return res_list

Expand Down
161 changes: 110 additions & 51 deletions flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@

import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider
from libcst.metadata import CodeRange, PositionProvider

from ..base import Statement
from .flake8asyncvisitor import Flake8AsyncVisitor_cst
from .helpers import (
AttributeCall,
MatchingCall,
cancel_scope_names,
disable_codes_by_default,
error_class_cst,
flatten_preserving_comments,
fnmatch_qualified_name_cst,
func_has_decorator,
get_matching_call_cst,
identifier_to_string,
iter_guaranteed_once_cst,
with_has_call,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -374,6 +374,14 @@ def leave_Yield(
disable_codes_by_default("ASYNC910", "ASYNC911", "ASYNC912", "ASYNC913")


@dataclass
class ContextManager:
has_checkpoint: bool | None = None
call: MatchingCall[cst.Call] | None = None
line: int | None = None
column: int | None = None


@error_class_cst
class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors):
error_codes: Mapping[str, str] = {
Expand Down Expand Up @@ -408,8 +416,7 @@ def __init__(self, *args: Any, **kwargs: Any):
self.match_state = MatchState()

# ASYNC100
self.has_checkpoint_stack: list[bool] = []
self.node_dict: dict[cst.With, list[AttributeCall]] = {}
self.has_checkpoint_stack: list[ContextManager] = []
self.taskgroup_has_start_soon: dict[str, bool] = {}

# --exception-suppress-context-manager
Expand All @@ -429,7 +436,11 @@ def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
)

def checkpoint_cancel_point(self) -> None:
self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack)
for cm in reversed(self.has_checkpoint_stack):
if cm.has_checkpoint:
# Everything further down in the stack is already True.
break
cm.has_checkpoint = True
# don't need to look for any .start_soon() calls
self.taskgroup_has_start_soon.clear()

Expand Down Expand Up @@ -705,59 +716,106 @@ def _checkpoint_with(self, node: cst.With, entry: bool):
# missing-checkpoint warning when there might in fact be one (i.e. a false alarm).
def visit_With_body(self, node: cst.With):
self.save_state(node, "taskgroup_has_start_soon", copy=True)
self._checkpoint_with(node, entry=True)

is_suppressing = False

# if this might suppress exceptions, we cannot treat anything inside it as
# checkpointing.
if self._is_exception_suppressing_context_manager(node):
self.save_state(node, "uncheckpointed_statements", copy=True)

if res := (
with_has_call(node, *cancel_scope_names)
or with_has_call(
node, "timeout", "timeout_at", base=("asyncio", "asyncio.timeouts")
)
):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
self.uncheckpointed_statements.add(
ArtificialStatement("with", line, column)
)
self.node_dict[node] = res
self.has_checkpoint_stack.append(False)
else:
self.has_checkpoint_stack.append(True)
for withitem in node.items:
self.has_checkpoint_stack.append(ContextManager())
if get_matching_call_cst(
withitem.item, "open_nursery", "create_task_group"
):
if withitem.asname is not None and isinstance(
withitem.asname.name, cst.Name
):
self.taskgroup_has_start_soon[withitem.asname.name.value] = False
self.checkpoint_schedule_point()
# Technically somebody could set open_nursery or create_task_group as
# suppressing context managers, but we're not add logic for that.
continue

if bool(getattr(node, "asynchronous", False)):
self.checkpoint()

# not a clean function call
if not isinstance(withitem.item, cst.Call) or not isinstance(
withitem.item.func, (cst.Name, cst.Attribute)
):
continue

if (
fnmatch_qualified_name_cst(
(withitem.item.func,),
"contextlib.suppress",
*self.suppress_imported_as,
*self.options.exception_suppress_context_managers,
)
is not None
):
# Don't re-update state if there's several suppressing cm's.
if not is_suppressing:
self.save_state(node, "uncheckpointed_statements", copy=True)
is_suppressing = True
continue

if res := (
get_matching_call_cst(withitem.item, *cancel_scope_names)
or get_matching_call_cst(
withitem.item,
"timeout",
"timeout_at",
base="asyncio",
)
):
# typing issue: https://github.com/Instagram/LibCST/issues/1107
pos = cst.ensure_type(
self.get_metadata(PositionProvider, withitem),
CodeRange,
).start
self.uncheckpointed_statements.add(
ArtificialStatement("withitem", pos.line, pos.column)
)

cm = self.has_checkpoint_stack[-1]
cm.line = pos.line
cm.column = pos.column
cm.call = res
cm.has_checkpoint = False

def leave_With(self, original_node: cst.With, updated_node: cst.With):
# Uses leave_With instead of leave_With_body because we need access to both
# original and updated node
# ASYNC100
if not self.has_checkpoint_stack.pop():
autofix = len(updated_node.items) == 1
for res in self.node_dict[original_node]:
withitems = list(updated_node.items)
for i in reversed(range(len(updated_node.items))):
cm = self.has_checkpoint_stack.pop()
# ASYNC100
if cm.has_checkpoint is False:
res = cm.call
assert res is not None
# bypass 910 & 911's should_autofix logic, which excludes asyncio
# (TODO: and uses self.noautofix ... which I don't remember what it's for)
autofix &= self.error(
res.node, res.base, res.function, error_code="ASYNC100"
) and super().should_autofix(res.node, code="ASYNC100")

if autofix:
return flatten_preserving_comments(updated_node)
# ASYNC912
else:
pos = self.get_metadata( # pyright: ignore
PositionProvider, original_node
).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
s = ArtificialStatement("with", line, column)
if s in self.uncheckpointed_statements:
self.uncheckpointed_statements.remove(s)
for res in self.node_dict[original_node]:
self.error(res.node, error_code="ASYNC912")

self.node_dict.pop(original_node, None)
if self.error(
res.node, res.base, res.name, error_code="ASYNC100"
) and super().should_autofix(res.node, code="ASYNC100"):
if len(withitems) == 1:
# Remove this With node, bypassing later logic.
return flatten_preserving_comments(updated_node)
if i == len(withitems) - 1:
# preserve trailing comma, or remove comma if there was none
withitems[-2] = withitems[-2].with_changes(
comma=withitems[-1].comma
)
withitems.pop(i)

# ASYNC912
elif cm.call is not None:
assert cm.line is not None
assert cm.column is not None
s = ArtificialStatement("withitem", cm.line, cm.column)
if s in self.uncheckpointed_statements:
self.uncheckpointed_statements.remove(s)
self.error(cm.call.node, error_code="ASYNC912")

# if exception-suppressing, restore all uncheckpointed statements from
# before the `with`.
Expand All @@ -767,7 +825,8 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With):
self.uncheckpointed_statements.update(prev_checkpoints)

self._checkpoint_with(original_node, entry=False)
return updated_node

return updated_node.with_changes(items=withitems)

# error if no checkpoint since earlier yield or function entry
def leave_Yield(
Expand Down
29 changes: 29 additions & 0 deletions tests/autofix_files/async100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
# AUTOFIX
# ASYNCIO_NO_ERROR # timeout primitives are named differently in asyncio

import contextlib
import trio


def condition() -> bool:
return False


# error: 5, "trio", "move_on_after"
...

Expand Down Expand Up @@ -214,3 +220,26 @@ async def nursery_exit_blocks_with_start():
async with trio.open_nursery() as n:
with trio.CancelScope():
await n.start(trio.sleep, 0)


async def autofix_multi_withitem():
with open("foo"): # error: 9, "trio", "CancelScope"
...
# error: 8, "trio", "CancelScope"
# error: 8, "trio", "CancelScope"
...

with (
open("") as _, # error: 8, "trio", "fail_after"
):
...

with (
open("") as _, # error: 8, "trio", "move_on_after"
):
...

with (
open("") as f,
):
...
Loading
Loading