|
| 1 | +import ast |
| 2 | +import inspect |
| 3 | +from typing import List |
| 4 | + |
| 5 | +import pytest |
| 6 | +from django.db import models |
| 7 | + |
| 8 | +from graphql_api.types import query # Import the Ariadne QueryType with resolvers |
| 9 | + |
| 10 | +# List of known synchronous functions that should be wrapped in sync_to_async in async functions |
| 11 | +SYNC_FUNCTIONS_TO_WRAP = {"get", "filter", "create", "update", "delete", "sleep"} |
| 12 | + |
| 13 | + |
| 14 | +def is_function_wrapped_in_sync_to_async(node: ast.Call) -> bool: |
| 15 | + """ |
| 16 | + Check if a function call is wrapped in sync_to_async. |
| 17 | + """ |
| 18 | + return ( |
| 19 | + isinstance(node, ast.Call) |
| 20 | + and isinstance(node.func, ast.Name) |
| 21 | + and node.func.id == "sync_to_async" |
| 22 | + ) |
| 23 | + |
| 24 | + |
| 25 | +def is_django_model_method_call(node: ast.Attribute) -> bool: |
| 26 | + """ |
| 27 | + Check if a given node is a method call on a Django model. |
| 28 | + """ |
| 29 | + if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): |
| 30 | + model_name = node.value.id |
| 31 | + # Attempt to get the actual model class from the globals if available |
| 32 | + model_class = globals().get(model_name) |
| 33 | + # Check if it’s a Django model class |
| 34 | + return inspect.isclass(model_class) and issubclass(model_class, models.Model) |
| 35 | + return False |
| 36 | + |
| 37 | + |
| 38 | +def find_unwrapped_sync_calls_in_function(func) -> List[str]: |
| 39 | + """ |
| 40 | + Parse a function to find sync calls on Django models that aren't wrapped in sync_to_async. |
| 41 | + """ |
| 42 | + unwrapped_sync_calls = [] |
| 43 | + |
| 44 | + # Get the AST of the function's source code |
| 45 | + try: |
| 46 | + source = inspect.getsource(func) |
| 47 | + except (OSError, TypeError, IndentationError): |
| 48 | + print(f"Could not retrieve or parse source for {func.__name__}. Skipping.") |
| 49 | + return unwrapped_sync_calls # Return empty list if source can't be accessed |
| 50 | + |
| 51 | + tree = ast.parse(source) |
| 52 | + |
| 53 | + # Traverse the AST to find unwrapped Django model method calls |
| 54 | + for node in ast.walk(tree): |
| 55 | + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): |
| 56 | + # Check if this is a Django model method call that isn’t wrapped in sync_to_async |
| 57 | + if is_django_model_method_call( |
| 58 | + node.func |
| 59 | + ) and not is_function_wrapped_in_sync_to_async(node): |
| 60 | + unwrapped_sync_calls.append(node.func.attr) |
| 61 | + return unwrapped_sync_calls |
| 62 | + |
| 63 | + |
| 64 | +# Retrieve only Ariadne resolvers for testing |
| 65 | +def get_ariadne_resolvers(): |
| 66 | + """ |
| 67 | + Retrieve all Ariadne resolvers from a specified QueryType or MutationType. |
| 68 | + """ |
| 69 | + # Ensure `query` is a QueryType or MutationType instance |
| 70 | + if hasattr(query, "_resolvers") and isinstance(query._resolvers, dict): |
| 71 | + return list(query._resolvers.values()) |
| 72 | + else: |
| 73 | + print( |
| 74 | + "The object `query` does not contain _resolvers. Ensure it is a QueryType or MutationType." |
| 75 | + ) |
| 76 | + return [] |
| 77 | + |
| 78 | + |
| 79 | +# Only test Ariadne resolvers |
| 80 | +ariadne_resolvers_to_test = get_ariadne_resolvers() |
| 81 | + |
| 82 | + |
| 83 | +@pytest.mark.parametrize("func", ariadne_resolvers_to_test) |
| 84 | +def test_functions_have_wrapped_sync_calls(func): |
| 85 | + unwrapped_calls = find_unwrapped_sync_calls_in_function(func) |
| 86 | + assert not unwrapped_calls, ( |
| 87 | + f"The following Django model method calls are missing `sync_to_async` in {func.__name__}: " |
| 88 | + f"{', '.join(unwrapped_calls)}" |
| 89 | + ) |
0 commit comments