Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flake8_async/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ class AttributeCall(NamedTuple):
def with_has_call(
node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> list[AttributeCall]:
if isinstance(base, str):
base = (base,)
res_list: list[AttributeCall] = []
for item in node.items:
if res := m.extract(
Expand Down
26 changes: 25 additions & 1 deletion flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from ..base import Statement
from .flake8asyncvisitor import Flake8AsyncVisitor_cst
from .helpers import (
cancel_scope_names,
disabled_by_default,
error_class_cst,
fnmatch_qualified_name_cst,
func_has_decorator,
iter_guaranteed_once_cst,
with_has_call,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -243,6 +245,10 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors):
"{0} from async iterable with no guaranteed checkpoint since {1.name} "
"on line {1.lineno}."
),
"ASYNC912": (
"CancelScope with no guaranteed checkpoint. This makes it potentially "
"impossible to cancel."
),
}

def __init__(self, *args: Any, **kwargs: Any):
Expand Down Expand Up @@ -420,8 +426,26 @@ def leave_Await(
def visit_With_body(self, node: cst.With):
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
if with_has_call(node, *cancel_scope_names) or with_has_call(
node, "timeout", "timeout_at", base="asyncio"
):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
self.uncheckpointed_statements.add(Statement("with", line, column))
# self.uncheckpointed_statements.add(res[0])

def leave_With_body(self, node: cst.With):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
s = Statement("with", line, column)
if s in self.uncheckpointed_statements:
self.error(node, error_code="ASYNC912")
self.uncheckpointed_statements.remove(s)

leave_With_body = visit_With_body
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()

# error if no checkpoint since earlier yield or function entry
def leave_Yield(
Expand Down
17 changes: 17 additions & 0 deletions tests/autofix_files/async91x_autofix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# ARG --enable=ASYNC910,ASYNC911

from typing import Any

import trio


Expand Down Expand Up @@ -124,3 +125,19 @@ async def async_func(): ...
break
[... for i in range(5)]
return


# TODO: issue 240
async def livelocks():
while True:
...


# this will autofix 910 by adding a checkpoint outside the loop
async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
while bar():
try:
await trio.sleep("1") # type: ignore[arg-type]
except ValueError:
...
await trio.lowlevel.checkpoint()
13 changes: 5 additions & 8 deletions tests/autofix_files/async91x_autofix.py.diff
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
---
+++
@@ x,6 x,7 @@
# ARG --enable=ASYNC910,ASYNC911

from typing import Any
+import trio


def bar() -> Any: ...
@@ x,30 x,38 @@

async def foo1(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
Expand Down Expand Up @@ -78,3 +70,8 @@
yield # ASYNC911: 8, "yield", Statement("function definition", lineno-2) # ASYNC911: 8, "yield", Statement("yield", lineno)

async def bar():
@@ x,3 x,4 @@
await trio.sleep("1") # type: ignore[arg-type]
except ValueError:
...
+ await trio.lowlevel.checkpoint()
108 changes: 108 additions & 0 deletions tests/eval_files/async912.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ASYNCIO_NO_ERROR
import trio


async def foo():
with trio.move_on_after(0.1): # error: 4
...
with trio.move_on_at(0.1): # error: 4
...
with trio.fail_after(0.1): # error: 4
...
with trio.fail_at(0.1): # error: 4
...
with trio.CancelScope(0.1): # error: 4
...

with open(""):
...

with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()

with trio.move_on_after(0.1): # error: 4
with trio.move_on_after(0.1): # error: 8
...

with trio.move_on_after(0.1): # TODO: should probably raise an error?
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()

with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()

with trio.move_on_after(0.1):
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
await trio.lowlevel.checkpoint()

# TODO: should probably raise the error at the call, rather than at the with statement
# fmt: off
with ( # error: 4
# a
# b
trio.move_on_after(0.1)
# c
):
...

with ( # error: 4
open(""),
trio.move_on_at(5),
open(""),
):
...
# fmt: on

# TODO: only raises one error currently, can make it raise 2(?)
with ( # error: 4
trio.move_on_after(0.1),
trio.fail_at(5),
):
...


# TODO: issue #240
async def livelocks():
with trio.move_on_after(0.1): # should error
while True:
try:
await trio.sleep("1") # type: ignore
except TypeError:
pass


def condition() -> bool:
return True


async def livelocks_2():
with trio.move_on_after(0.1): # error: 4
while condition():
try:
await trio.sleep("1") # type: ignore
except TypeError:
pass


# TODO: add --async912-context-managers=
async def livelocks_3():
import contextlib

with trio.move_on_after(0.1): # should error
while True:
with contextlib.suppress(TypeError):
await trio.sleep("1") # type: ignore


# raises an error...?
with trio.move_on_after(10): # error: 0
...


# completely sync function ... is this something we care about?
def sync_func():
with trio.move_on_after(10):
...
19 changes: 19 additions & 0 deletions tests/eval_files/async912_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# BASE_LIBRARY asyncio
# ANYIO_NO_ERROR
# TRIO_NO_ERROR

# timeout[_at] added in py3.11
# mypy: disable-error-code=attr-defined

import asyncio


async def foo():
async with asyncio.timeout(10): # error: 4
...
async with asyncio.timeout(10):
await foo()
async with asyncio.timeout_at(10): # error: 4
...
async with asyncio.timeout_at(10):
await foo()
17 changes: 17 additions & 0 deletions tests/eval_files/async91x_autofix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from typing import Any

import trio


def bar() -> Any: ...

Expand Down Expand Up @@ -109,3 +111,18 @@ async def async_func(): ...
break
[... for i in range(5)]
return


# TODO: issue 240
async def livelocks():
while True:
...


# this will autofix 910 by adding a checkpoint outside the loop
async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
while bar():
try:
await trio.sleep("1") # type: ignore[arg-type]
except ValueError:
...
3 changes: 2 additions & 1 deletion tests/test_flake8_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def _parse_eval_file(
"ASYNC116",
"ASYNC117",
"ASYNC118",
"ASYNC912",
}


Expand Down Expand Up @@ -479,7 +480,7 @@ def visit_AsyncFor(self, node: ast.AsyncFor):
return self.replace_async(node, ast.For, node.target, node.iter)


@pytest.mark.parametrize(("test", "path"), test_files)
@pytest.mark.parametrize(("test", "path"), test_files, ids=[f[0] for f in test_files])
def test_noerror_on_sync_code(test: str, path: Path):
if any(e in test for e in error_codes_ignored_when_checking_transformed_sync_code):
return
Expand Down