Skip to content
This repository was archived by the owner on Jun 13, 2025. It is now read-only.

Commit d377eff

Browse files
experiments
1 parent 39aba43 commit d377eff

File tree

4 files changed

+141
-7
lines changed

4 files changed

+141
-7
lines changed

async_orm_checker.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# async_orm_checker.py
2+
3+
from mypy.nodes import CallExpr, FuncDef, MemberExpr
4+
from mypy.plugin import FunctionContext, Plugin
5+
from mypy.traverser import TraverserVisitor
6+
from mypy.types import Instance
7+
8+
# ORM method names that should not be used in async functions
9+
SYNC_ORM_METHODS = {"get", "filter", "create", "update", "delete"}
10+
11+
12+
class SyncORMInAsyncChecker(TraverserVisitor):
13+
def __init__(self, plugin, ctx: FunctionContext):
14+
super().__init__()
15+
self.plugin = plugin
16+
self.ctx = ctx
17+
18+
def visit_call_expr(self, expr: CallExpr):
19+
# Check if the function call is accessing an ORM method
20+
if isinstance(expr.callee, MemberExpr):
21+
method_name = expr.callee.name
22+
if method_name in SYNC_ORM_METHODS:
23+
# Check if this is a method on a Django model instance
24+
if isinstance(expr.callee.expr, Instance):
25+
if (
26+
"django.db.models.base.Model"
27+
in expr.callee.expr.type.type.fullname
28+
):
29+
# Trigger an error if a sync ORM method is used in an async context
30+
self.plugin.fail(
31+
f"Sync ORM method '{method_name}' used in async function; wrap in sync_to_async",
32+
expr,
33+
)
34+
super().visit_call_expr(expr)
35+
36+
37+
class AsyncORMPlugin(Plugin):
38+
def get_function_hook(self, fullname: str):
39+
# Only run this check on async functions
40+
def wrapper(ctx: FunctionContext):
41+
func_def = ctx.context
42+
if isinstance(func_def, FuncDef) and func_def.is_async:
43+
# Traverse the async function to look for sync ORM calls
44+
checker = SyncORMInAsyncChecker(self, ctx)
45+
func_def.accept(checker)
46+
return ctx.default_return_type
47+
48+
return wrapper
49+
50+
51+
def plugin(version: str):
52+
return AsyncORMPlugin

graphql_api/tests/test_temp.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import ast
2+
import inspect
3+
from typing import List
4+
5+
import pytest
6+
from django.db import models
7+
8+
from graphql_api.types import query # Import the Ariadne QueryType with resolvers
9+
10+
# List of known synchronous functions that should be wrapped in sync_to_async in async functions
11+
SYNC_FUNCTIONS_TO_WRAP = {"get", "filter", "create", "update", "delete", "sleep"}
12+
13+
14+
def is_function_wrapped_in_sync_to_async(node: ast.Call) -> bool:
15+
"""
16+
Check if a function call is wrapped in sync_to_async.
17+
"""
18+
return (
19+
isinstance(node, ast.Call)
20+
and isinstance(node.func, ast.Name)
21+
and node.func.id == "sync_to_async"
22+
)
23+
24+
25+
def is_django_model_method_call(node: ast.Attribute) -> bool:
26+
"""
27+
Check if a given node is a method call on a Django model.
28+
"""
29+
if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
30+
model_name = node.value.id
31+
# Attempt to get the actual model class from the globals if available
32+
model_class = globals().get(model_name)
33+
# Check if it’s a Django model class
34+
return inspect.isclass(model_class) and issubclass(model_class, models.Model)
35+
return False
36+
37+
38+
def find_unwrapped_sync_calls_in_function(func) -> List[str]:
39+
"""
40+
Parse a function to find sync calls on Django models that aren't wrapped in sync_to_async.
41+
"""
42+
unwrapped_sync_calls = []
43+
44+
# Get the AST of the function's source code
45+
try:
46+
source = inspect.getsource(func)
47+
except (OSError, TypeError, IndentationError):
48+
print(f"Could not retrieve or parse source for {func.__name__}. Skipping.")
49+
return unwrapped_sync_calls # Return empty list if source can't be accessed
50+
51+
tree = ast.parse(source)
52+
53+
# Traverse the AST to find unwrapped Django model method calls
54+
for node in ast.walk(tree):
55+
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
56+
# Check if this is a Django model method call that isn’t wrapped in sync_to_async
57+
if is_django_model_method_call(
58+
node.func
59+
) and not is_function_wrapped_in_sync_to_async(node):
60+
unwrapped_sync_calls.append(node.func.attr)
61+
return unwrapped_sync_calls
62+
63+
64+
# Retrieve only Ariadne resolvers for testing
65+
def get_ariadne_resolvers():
66+
"""
67+
Retrieve all Ariadne resolvers from a specified QueryType or MutationType.
68+
"""
69+
# Ensure `query` is a QueryType or MutationType instance
70+
if hasattr(query, "_resolvers") and isinstance(query._resolvers, dict):
71+
return list(query._resolvers.values())
72+
else:
73+
print(
74+
"The object `query` does not contain _resolvers. Ensure it is a QueryType or MutationType."
75+
)
76+
return []
77+
78+
79+
# Only test Ariadne resolvers
80+
ariadne_resolvers_to_test = get_ariadne_resolvers()
81+
82+
83+
@pytest.mark.parametrize("func", ariadne_resolvers_to_test)
84+
def test_functions_have_wrapped_sync_calls(func):
85+
unwrapped_calls = find_unwrapped_sync_calls_in_function(func)
86+
assert not unwrapped_calls, (
87+
f"The following Django model method calls are missing `sync_to_async` in {func.__name__}: "
88+
f"{', '.join(unwrapped_calls)}"
89+
)

graphql_api/types/account/account.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def resolve_name(account: Account, info: GraphQLResolveInfo) -> str:
2121

2222

2323
@account_bindable.field("oktaConfig")
24-
@sync_to_async
2524
def resolve_okta_config(
2625
account: Account, info: GraphQLResolveInfo
2726
) -> OktaSettings | None:

mypy.ini

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,6 @@ ignore_missing_imports = True
66
disable_error_code = attr-defined,import-untyped,name-defined
77
follow_imports = silent
88
warn_no_return = False
9-
plugins = mypy_django_plugin.main
109

1110
[mypy-*.tests.*]
1211
disallow_untyped_defs = False
13-
14-
# Enable async ORM checks to catch synchronous ORM calls in async functions
15-
[mypy.plugins.django-stubs]
16-
enable-async-orm-checks = True
17-
django_settings_module = "codecov.settings_dev"

0 commit comments

Comments
 (0)