1
+ import contextvars
1
2
from typing import Optional
2
3
3
4
import html5lib
4
5
from asgiref .local import Local
5
6
from django .http import HttpResponse
6
- from django .test import Client , RequestFactory , TestCase , TransactionTestCase
7
+ from django .test import (
8
+ AsyncClient ,
9
+ AsyncRequestFactory ,
10
+ Client ,
11
+ RequestFactory ,
12
+ TestCase ,
13
+ TransactionTestCase ,
14
+ )
7
15
8
16
from debug_toolbar .panels import Panel
9
17
from debug_toolbar .toolbar import DebugToolbar
10
18
19
+ data_contextvar = contextvars .ContextVar ("djdt_toolbar_test_client" )
20
+
11
21
12
22
class ToolbarTestClient (Client ):
13
23
def request (self , ** request ):
@@ -29,11 +39,35 @@ def handle_toolbar_created(sender, toolbar=None, **kwargs):
29
39
return response
30
40
31
41
42
+ class AsyncToolbarTestClient (AsyncClient ):
43
+ async def request (self , ** request ):
44
+ # Use a thread/async task context-local variable to guard against a
45
+ # concurrent _created signal from a different thread/task.
46
+ # In cases testsuite will have both regular and async tests or
47
+ # multiple async tests running in an eventloop making async_client calls.
48
+ data_contextvar .set (None )
49
+
50
+ def handle_toolbar_created (sender , toolbar = None , ** kwargs ):
51
+ data_contextvar .set (toolbar )
52
+
53
+ DebugToolbar ._created .connect (handle_toolbar_created )
54
+ try :
55
+ response = await super ().request (** request )
56
+ finally :
57
+ DebugToolbar ._created .disconnect (handle_toolbar_created )
58
+ response .toolbar = data_contextvar .get ()
59
+
60
+ return response
61
+
62
+
32
63
rf = RequestFactory ()
64
+ arf = AsyncRequestFactory ()
33
65
34
66
35
67
class BaseMixin :
68
+ _is_async = False
36
69
client_class = ToolbarTestClient
70
+ async_client_class = AsyncToolbarTestClient
37
71
38
72
panel : Optional [Panel ] = None
39
73
panel_id = None
@@ -42,7 +76,11 @@ def setUp(self):
42
76
super ().setUp ()
43
77
self ._get_response = lambda request : HttpResponse ()
44
78
self .request = rf .get ("/" )
45
- self .toolbar = DebugToolbar (self .request , self .get_response )
79
+ if self ._is_async :
80
+ self .request = arf .get ("/" )
81
+ self .toolbar = DebugToolbar (self .request , self .get_response_async )
82
+ else :
83
+ self .toolbar = DebugToolbar (self .request , self .get_response )
46
84
self .toolbar .stats = {}
47
85
48
86
if self .panel_id :
@@ -59,6 +97,9 @@ def tearDown(self):
59
97
def get_response (self , request ):
60
98
return self ._get_response (request )
61
99
100
+ async def get_response_async (self , request ):
101
+ return self ._get_response (request )
102
+
62
103
def assertValidHTML (self , content ):
63
104
parser = html5lib .HTMLParser ()
64
105
parser .parseFragment (content )
0 commit comments