Skip to content

ASYNC912: timeout/cancelscope with only conditional checkpoints #242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flake8_async/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ class AttributeCall(NamedTuple):
def with_has_call(
node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> list[AttributeCall]:
if isinstance(base, str):
base = (base,)
res_list: list[AttributeCall] = []
for item in node.items:
if res := m.extract(
Expand Down
26 changes: 25 additions & 1 deletion flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from ..base import Statement
from .flake8asyncvisitor import Flake8AsyncVisitor_cst
from .helpers import (
cancel_scope_names,
disabled_by_default,
error_class_cst,
fnmatch_qualified_name_cst,
func_has_decorator,
iter_guaranteed_once_cst,
with_has_call,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -243,6 +245,10 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors):
"{0} from async iterable with no guaranteed checkpoint since {1.name} "
"on line {1.lineno}."
),
"ASYNC912": (
"CancelScope with no guaranteed checkpoint. This makes it potentially "
"impossible to cancel."
),
}

def __init__(self, *args: Any, **kwargs: Any):
Expand Down Expand Up @@ -420,8 +426,26 @@ def leave_Await(
def visit_With_body(self, node: cst.With):
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
if with_has_call(node, *cancel_scope_names) or with_has_call(
node, "timeout", "timeout_at", base="asyncio"
):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
self.uncheckpointed_statements.add(Statement("with", line, column))
# self.uncheckpointed_statements.add(res[0])

def leave_With_body(self, node: cst.With):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
s = Statement("with", line, column)
if s in self.uncheckpointed_statements:
self.error(node, error_code="ASYNC912")
self.uncheckpointed_statements.remove(s)

leave_With_body = visit_With_body
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()

# error if no checkpoint since earlier yield or function entry
def leave_Yield(
Expand Down
17 changes: 17 additions & 0 deletions tests/autofix_files/async91x_autofix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# ARG --enable=ASYNC910,ASYNC911

from typing import Any

import trio


Expand Down Expand Up @@ -124,3 +125,19 @@ async def async_func(): ...
break
[... for i in range(5)]
return


# TODO: issue 240
async def livelocks():
while True:
...


# this will autofix 910 by adding a checkpoint outside the loop
async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
while bar():
try:
await trio.sleep("1") # type: ignore[arg-type]
except ValueError:
...
await trio.lowlevel.checkpoint()
13 changes: 5 additions & 8 deletions tests/autofix_files/async91x_autofix.py.diff
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
---
+++
@@ x,6 x,7 @@
# ARG --enable=ASYNC910,ASYNC911

from typing import Any
+import trio


def bar() -> Any: ...
@@ x,30 x,38 @@

async def foo1(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
Expand Down Expand Up @@ -78,3 +70,8 @@
yield # ASYNC911: 8, "yield", Statement("function definition", lineno-2) # ASYNC911: 8, "yield", Statement("yield", lineno)

async def bar():
@@ x,3 x,4 @@
await trio.sleep("1") # type: ignore[arg-type]
except ValueError:
...
+ await trio.lowlevel.checkpoint()
108 changes: 108 additions & 0 deletions tests/eval_files/async912.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ASYNCIO_NO_ERROR
import trio


async def foo():
with trio.move_on_after(0.1): # error: 4
...
with trio.move_on_at(0.1): # error: 4
...
with trio.fail_after(0.1): # error: 4
...
with trio.fail_at(0.1): # error: 4
...
with trio.CancelScope(0.1): # error: 4
...

with open(""):
...

with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()

with trio.move_on_after(0.1): # error: 4
with trio.move_on_after(0.1): # error: 8
...

with trio.move_on_after(0.1): # TODO: should probably raise an error?
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()

with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()

with trio.move_on_after(0.1):
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
await trio.lowlevel.checkpoint()

# TODO: should probably raise the error at the call, rather than at the with statement
# fmt: off
with ( # error: 4
# a
# b
trio.move_on_after(0.1)
# c
):
...

with ( # error: 4
open(""),
trio.move_on_at(5),
open(""),
):
...
# fmt: on

# TODO: only raises one error currently, can make it raise 2(?)
with ( # error: 4
trio.move_on_after(0.1),
trio.fail_at(5),
):
...


# TODO: issue #240
async def livelocks():
with trio.move_on_after(0.1): # should error
while True:
try:
await trio.sleep("1") # type: ignore
except TypeError:
pass


def condition() -> bool:
return True


async def livelocks_2():
with trio.move_on_after(0.1): # error: 4
while condition():
try:
await trio.sleep("1") # type: ignore
except TypeError:
pass


# TODO: add --async912-context-managers=
async def livelocks_3():
import contextlib

with trio.move_on_after(0.1): # should error
while True:
with contextlib.suppress(TypeError):
await trio.sleep("1") # type: ignore


# raises an error...?
with trio.move_on_after(10): # error: 0
...


# completely sync function ... is this something we care about?
def sync_func():
with trio.move_on_after(10):
...
19 changes: 19 additions & 0 deletions tests/eval_files/async912_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# BASE_LIBRARY asyncio
# ANYIO_NO_ERROR
# TRIO_NO_ERROR

# timeout[_at] added in py3.11
# mypy: disable-error-code=attr-defined

import asyncio


async def foo():
async with asyncio.timeout(10): # error: 4
...
async with asyncio.timeout(10):
await foo()
async with asyncio.timeout_at(10): # error: 4
...
async with asyncio.timeout_at(10):
await foo()
17 changes: 17 additions & 0 deletions tests/eval_files/async91x_autofix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from typing import Any

import trio


def bar() -> Any: ...

Expand Down Expand Up @@ -109,3 +111,18 @@ async def async_func(): ...
break
[... for i in range(5)]
return


# TODO: issue 240
async def livelocks():
while True:
...


# this will autofix 910 by adding a checkpoint outside the loop
async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
while bar():
try:
await trio.sleep("1") # type: ignore[arg-type]
except ValueError:
...
3 changes: 2 additions & 1 deletion tests/test_flake8_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def _parse_eval_file(
"ASYNC116",
"ASYNC117",
"ASYNC118",
"ASYNC912",
}


Expand Down Expand Up @@ -479,7 +480,7 @@ def visit_AsyncFor(self, node: ast.AsyncFor):
return self.replace_async(node, ast.For, node.target, node.iter)


@pytest.mark.parametrize(("test", "path"), test_files)
@pytest.mark.parametrize(("test", "path"), test_files, ids=[f[0] for f in test_files])
def test_noerror_on_sync_code(test: str, path: Path):
if any(e in test for e in error_codes_ignored_when_checking_transformed_sync_code):
return
Expand Down