Skip to content

Commit 700aeaa

Browse files
committed
fix generator close for py3.13
1 parent 08a9158 commit 700aeaa

File tree

9 files changed

+38
-28
lines changed

9 files changed

+38
-28
lines changed

google/cloud/firestore_v1/async_aggregation.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ async def get(
8888
timeout=timeout,
8989
explain_options=explain_options,
9090
)
91-
result = [aggregation async for aggregation in stream_result]
92-
93-
if explain_options is None:
94-
explain_metrics = None
95-
else:
96-
explain_metrics = await stream_result.get_explain_metrics()
91+
try:
92+
result = [aggregation async for aggregation in stream_result]
93+
94+
if explain_options is None:
95+
explain_metrics = None
96+
else:
97+
explain_metrics = await stream_result.get_explain_metrics()
98+
finally:
99+
await stream_result.aclose()
97100

98101
return QueryResultsList(result, explain_options, explain_metrics)
99102

google/cloud/firestore_v1/async_query.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,17 @@ async def get(
231231
timeout=timeout,
232232
explain_options=explain_options,
233233
)
234-
result_list = [d async for d in result]
235-
if is_limited_to_last:
236-
result_list = list(reversed(result_list))
237-
238-
if explain_options is None:
239-
explain_metrics = None
240-
else:
241-
explain_metrics = await result.get_explain_metrics()
234+
try:
235+
result_list = [d async for d in result]
236+
if is_limited_to_last:
237+
result_list = list(reversed(result_list))
238+
239+
if explain_options is None:
240+
explain_metrics = None
241+
else:
242+
explain_metrics = await result.get_explain_metrics()
243+
finally:
244+
await result.aclose()
242245

243246
return QueryResultsList(result_list, explain_options, explain_metrics)
244247

google/cloud/firestore_v1/async_vector_query.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,15 @@ async def get(
9191
timeout=timeout,
9292
explain_options=explain_options,
9393
)
94-
result = [snapshot async for snapshot in stream_result]
94+
try:
95+
result = [snapshot async for snapshot in stream_result]
9596

96-
if explain_options is None:
97-
explain_metrics = None
98-
else:
99-
explain_metrics = await stream_result.get_explain_metrics()
97+
if explain_options is None:
98+
explain_metrics = None
99+
else:
100+
explain_metrics = await stream_result.get_explain_metrics()
101+
finally:
102+
await stream_result.aclose()
100103

101104
return QueryResultsList(result, explain_options, explain_metrics)
102105

@@ -151,7 +154,6 @@ async def _make_stream(
151154
metadata=self._client._rpc_metadata,
152155
**kwargs,
153156
)
154-
155157
async for response in response_iterator:
156158
if self._nested_query._all_descendants:
157159
snapshot = _collection_group_query_response_to_snapshot(

tests/unit/v1/test__helpers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,17 +2538,12 @@ async def __call__(self, *args, **kwargs):
25382538
return super(AsyncMock, self).__call__(*args, **kwargs)
25392539

25402540

2541-
class AsyncIter:
2541+
async def AsyncIter(items):
25422542
"""Utility to help recreate the effect of an async generator. Useful when
25432543
you need to mock a system that requires `async for`.
25442544
"""
2545-
2546-
def __init__(self, items):
2547-
self.items = items
2548-
2549-
async def __aiter__(self):
2550-
for i in self.items:
2551-
yield i
2545+
for i in items:
2546+
yield i
25522547

25532548

25542549
def _value_pb(**kwargs):

tests/unit/v1/test_async_aggregation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,7 @@ async def _async_aggregation_query_stream_helper(
592592
assert r.alias == aggregation_result.alias
593593
assert r.value == aggregation_result.value
594594
results.append(result)
595+
await returned.aclose()
595596
assert len(results) == len(results_list)
596597

597598
if explain_options is None:

tests/unit/v1/test_async_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None):
371371
assert isinstance(stream_response, AsyncStreamGenerator)
372372

373373
returned = [x async for x in stream_response]
374+
await stream_response.aclose()
374375
assert len(returned) == 1
375376
snapshot = returned[0]
376377
assert snapshot.reference._path == ("dee", "sleep")

tests/unit/v1/test_async_stream_generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ async def test_async_stream_generator_explain_metrics_explain_options_analyze_tr
192192
"index_entries_scanned": "index_entries_scanned",
193193
}
194194
assert actual_explain_metrics.execution_stats.debug_stats == expected_debug_stats
195+
await inst.aclose()
195196

196197

197198
@pytest.mark.asyncio
@@ -230,6 +231,7 @@ async def test_async_stream_generator_explain_metrics_explain_options_analyze_fa
230231
}
231232
}
232233
]
234+
await inst.aclose()
233235

234236

235237
@pytest.mark.asyncio
@@ -242,6 +244,7 @@ async def test_async_stream_generator_explain_metrics_missing_explain_options_an
242244
query_profile.QueryExplainError, match="Did not receive explain_metrics"
243245
):
244246
await inst.get_explain_metrics()
247+
await inst.aclose()
245248

246249

247250
@pytest.mark.asyncio

tests/unit/v1/test_async_transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ async def _get_w_query_helper(retry=None, timeout=None, explain_options=None):
442442
metadata=client._rpc_metadata,
443443
**kwargs,
444444
)
445+
await returned_generator.aclose()
445446

446447

447448
@pytest.mark.asyncio

tests/unit/v1/test_async_vector_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ async def _async_vector_query_stream_helper(
595595
assert isinstance(returned, AsyncStreamGenerator)
596596

597597
results_list = [item async for item in returned]
598+
await returned.aclose()
598599
assert len(results_list) == 1
599600
assert results_list[0].to_dict() == data
600601

0 commit comments

Comments
 (0)