1
+ import functools
1
2
import inspect
2
3
import sys
3
4
@@ -96,18 +97,36 @@ def stop_twisted_greenlet():
96
97
_instances .gr_twisted .switch ()
97
98
98
99
99
- def is_coroutine (something ):
100
- if ASYNC_AWAIT :
101
- return asyncio .iscoroutine (something )
100
+ class _CoroutineWrapper :
101
+ def __init__ (self , coroutine , mark ):
102
+ self .coroutine = coroutine
103
+ self .mark = mark
102
104
103
- return False
104
105
106
+ def _marked_async_fixture (mark ):
107
+ def fixture (* args , ** kwargs ):
108
+ def marker (f ):
109
+ @functools .wraps (f )
110
+ def w (* args , ** kwargs ):
111
+ return _CoroutineWrapper (
112
+ coroutine = f (* args , ** kwargs ),
113
+ mark = mark ,
114
+ )
115
+
116
+ return w
117
+
118
+ def decorator (f ):
119
+ result = pytest .fixture (* args , ** kwargs )(marker (f ))
120
+
121
+ return result
122
+
123
+ return decorator
124
+
125
+ return fixture
105
126
106
- def is_async_generator (something ):
107
- if ASYNC_GENERATORS :
108
- return inspect .isasyncgen (something )
109
127
110
- return False
128
+ async_fixture = _marked_async_fixture ('async_fixture' )
129
+ async_yield_fixture = _marked_async_fixture ('async_yield_fixture' )
111
130
112
131
113
132
@defer .inlineCallbacks
@@ -122,20 +141,23 @@ def _pytest_pyfunc_call(pyfuncitem):
122
141
testargs = {}
123
142
for arg in pyfuncitem ._fixtureinfo .argnames :
124
143
something = funcargs [arg ]
125
- if is_coroutine (something ):
126
- something = yield defer .ensureDeferred (something )
127
- elif is_async_generator (something ):
128
- async_generators .append ((arg , something ))
129
- something = yield defer .ensureDeferred (
130
- something .__anext__ (),
131
- )
144
+ if isinstance (something , _CoroutineWrapper ):
145
+ if something .mark == 'async_fixture' :
146
+ something = yield defer .ensureDeferred (
147
+ something .coroutine
148
+ )
149
+ elif something .mark == 'async_yield_fixture' :
150
+ async_generators .append ((arg , something ))
151
+ something = yield defer .ensureDeferred (
152
+ something .coroutine .__anext__ (),
153
+ )
132
154
testargs [arg ] = something
133
155
else :
134
156
testargs = funcargs
135
157
result = yield testfunction (** testargs )
136
158
137
159
async_generator_deferreds = [
138
- (arg , defer .ensureDeferred (g .__anext__ ()))
160
+ (arg , defer .ensureDeferred (g .coroutine . __anext__ ()))
139
161
for arg , g in async_generators
140
162
]
141
163
0 commit comments