Skip to content

Commit b85ae7c

Browse files
committed
More tests for MapAsyncIterator
1 parent 70e0d85 commit b85ae7c

File tree

3 files changed

+91
-19
lines changed

3 files changed

+91
-19
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ a query language for APIs created by Facebook.
1313

1414
The current version 1.0.0rc2 of GraphQL-core-next is up-to-date with GraphQL.js
1515
version 14.0.0rc2. All parts of the API are covered by an extensive test
16-
suite of currently 1561 unit tests.
16+
suite of currently 1562 unit tests.
1717

1818

1919
## Documentation

graphql/subscription/map_async_iterator.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,45 @@ def __init__(self, iterable: AsyncIterable, callback: Callable,
1919
self.iterator = iterable.__aiter__()
2020
self.callback = callback
2121
self.reject_callback = reject_callback
22-
self.error = None
22+
self.stop = False
2323

2424
def __aiter__(self):
2525
return self
2626

2727
async def __anext__(self):
28-
if self.error is not None:
29-
raise self.error
28+
if self.stop:
29+
raise StopAsyncIteration
3030
try:
3131
value = await self.iterator.__anext__()
3232
except Exception as error:
3333
if not self.reject_callback or isinstance(error, (
3434
StopAsyncIteration, GeneratorExit)):
3535
raise
36-
if self.error is not None:
37-
raise self.error
3836
result = self.reject_callback(error)
3937
else:
40-
if self.error is not None:
41-
raise self.error
4238
result = self.callback(value)
4339
if isawaitable(result):
4440
result = await result
45-
if self.error is not None:
46-
raise self.error
4741
return result
4842

4943
async def athrow(self, type_, value=None, traceback=None):
50-
if self.error:
44+
if self.stop:
5145
return
5246
athrow = getattr(self.iterator, 'athrow', None)
5347
if athrow:
5448
await athrow(type_, value, traceback)
5549
else:
56-
error = type_
57-
if value is not None:
58-
error = error(value)
59-
if traceback is not None:
60-
error = error.with_traceback(traceback)
61-
self.error = error
50+
self.stop = True
51+
if value is None:
52+
if traceback is None:
53+
raise type_
54+
value = type_()
55+
if traceback is not None:
56+
value = value.with_traceback(traceback)
57+
raise value
6258

6359
async def aclose(self):
64-
if self.error:
60+
if self.stop:
6561
return
6662
aclose = getattr(self.iterator, 'aclose', None)
6763
if aclose:
@@ -70,4 +66,4 @@ async def aclose(self):
7066
except RuntimeError:
7167
pass
7268
else:
73-
self.error = StopAsyncIteration
69+
self.stop = True

tests/subscription/test_map_async_iterator.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
from pytest import mark, raises
24

35
from graphql.subscription.map_async_iterator import MapAsyncIterator
@@ -171,3 +173,77 @@ async def source():
171173

172174
with raises(StopAsyncIteration):
173175
await anext(doubles)
176+
177+
@mark.asyncio
178+
async def can_use_simple_iterator_instead_of_generator():
179+
async def source():
180+
yield 1
181+
yield 2
182+
yield 3
183+
184+
class Source:
185+
def __init__(self):
186+
self.counter = 0
187+
188+
def __aiter__(self):
189+
return self
190+
191+
async def __anext__(self):
192+
self.counter += 1
193+
if self.counter > 3:
194+
raise StopAsyncIteration
195+
return self.counter
196+
197+
for iterator in source, Source:
198+
doubles = MapAsyncIterator(iterator(), lambda x: x + x)
199+
200+
await doubles.aclose()
201+
202+
with raises(StopAsyncIteration):
203+
await anext(doubles)
204+
205+
doubles = MapAsyncIterator(iterator(), lambda x: x + x)
206+
207+
assert await anext(doubles) == 2
208+
assert await anext(doubles) == 4
209+
assert await anext(doubles) == 6
210+
211+
with raises(StopAsyncIteration):
212+
await anext(doubles)
213+
214+
doubles = MapAsyncIterator(iterator(), lambda x: x + x)
215+
216+
assert await anext(doubles) == 2
217+
assert await anext(doubles) == 4
218+
219+
# Throw error
220+
with raises(RuntimeError) as exc_info:
221+
await doubles.athrow(RuntimeError('ouch'))
222+
223+
assert str(exc_info.value) == 'ouch'
224+
225+
with raises(StopAsyncIteration):
226+
await anext(doubles)
227+
with raises(StopAsyncIteration):
228+
await anext(doubles)
229+
230+
await doubles.athrow(RuntimeError('no more ouch'))
231+
232+
with raises(StopAsyncIteration):
233+
await anext(doubles)
234+
235+
await doubles.aclose()
236+
237+
doubles = MapAsyncIterator(iterator(), lambda x: x + x)
238+
239+
assert await anext(doubles) == 2
240+
assert await anext(doubles) == 4
241+
242+
try:
243+
raise ValueError('bad')
244+
except ValueError:
245+
tb = sys.exc_info()[2]
246+
247+
# Throw error
248+
with raises(ValueError):
249+
await doubles.athrow(ValueError, None, tb)

0 commit comments

Comments
 (0)