Skip to content

Commit bd00fcb

Browse files
committed
gh-125862: Improve context decorator support for generators and async functions
1 parent fa43a1e commit bd00fcb

File tree

3 files changed

+168
-5
lines changed

3 files changed

+168
-5
lines changed

Lib/contextlib.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
"""Utilities for with-statement contexts. See PEP 343."""
2+
from inspect import isasyncgenfunction, iscoroutinefunction, \
3+
isgeneratorfunction
4+
25
import abc
36
import os
47
import sys
@@ -79,11 +82,32 @@ def _recreate_cm(self):
7982
return self
8083

8184
def __call__(self, func):
82-
@wraps(func)
8385
def inner(*args, **kwds):
8486
with self._recreate_cm():
8587
return func(*args, **kwds)
86-
return inner
88+
89+
def gen_inner(*args, **kwds):
90+
with self._recreate_cm():
91+
yield from func(*args, **kwds)
92+
93+
async def async_inner(*args, **kwds):
94+
with self._recreate_cm():
95+
return await func(*args, **kwds)
96+
97+
async def asyncgen_inner(*args, **kwds):
98+
with self._recreate_cm():
99+
async for value in func(*args, **kwds):
100+
yield value
101+
102+
wrapper = wraps(func)
103+
if isasyncgenfunction(func):
104+
return wrapper(asyncgen_inner)
105+
elif iscoroutinefunction(func):
106+
return wrapper(async_inner)
107+
elif isgeneratorfunction(func):
108+
return wrapper(gen_inner)
109+
else:
110+
return wrapper(inner)
87111

88112

89113
class AsyncContextDecorator(object):
@@ -95,11 +119,33 @@ def _recreate_cm(self):
95119
return self
96120

97121
def __call__(self, func):
98-
@wraps(func)
99122
async def inner(*args, **kwds):
123+
async with self._recreate_cm():
124+
return func(*args, **kwds)
125+
126+
async def gen_inner(*args, **kwds):
127+
async with self._recreate_cm():
128+
for value in func(*args, **kwds):
129+
yield value
130+
131+
async def async_inner(*args, **kwds):
100132
async with self._recreate_cm():
101133
return await func(*args, **kwds)
102-
return inner
134+
135+
async def asyncgen_inner(*args, **kwds):
136+
async with self._recreate_cm():
137+
async for value in func(*args, **kwds):
138+
yield value
139+
140+
wrapper = wraps(func)
141+
if isasyncgenfunction(func):
142+
return wrapper(asyncgen_inner)
143+
elif iscoroutinefunction(func):
144+
return wrapper(async_inner)
145+
elif isgeneratorfunction(func):
146+
return wrapper(gen_inner)
147+
else:
148+
return wrapper(inner)
103149

104150

105151
class _GeneratorContextManagerBase:

Lib/test/test_contextlib.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Unit tests for contextlib.py, and other context managers."""
2-
2+
import asyncio
33
import io
44
import os
55
import sys
@@ -680,6 +680,66 @@ def test(x):
680680
self.assertEqual(state, [1, 'something else', 999])
681681

682682

683+
def test_contextmanager_decorate_generator_function(self):
684+
@contextmanager
685+
def woohoo(y):
686+
state.append(y)
687+
yield
688+
state.append(999)
689+
690+
state = []
691+
@woohoo(1)
692+
def test(x):
693+
self.assertEqual(state, [1])
694+
state.append(x)
695+
yield
696+
state.append("second item")
697+
698+
for _ in test("something"):
699+
self.assertEqual(state, [1, "something"])
700+
self.assertEqual(state, [1, "something", "second item", 999])
701+
702+
703+
def test_contextmanager_decorate_coroutine_function(self):
704+
@contextmanager
705+
def woohoo(y):
706+
state.append(y)
707+
yield
708+
state.append(999)
709+
710+
state = []
711+
@woohoo(1)
712+
async def test(x):
713+
self.assertEqual(state, [1])
714+
state.append(x)
715+
716+
asyncio.run(test('something'))
717+
self.assertEqual(state, [1, 'something', 999])
718+
719+
720+
def test_contextmanager_decorate_asyncgen_function(self):
721+
@contextmanager
722+
def woohoo(y):
723+
state.append(y)
724+
yield
725+
state.append(999)
726+
727+
state = []
728+
@woohoo(1)
729+
async def test(x):
730+
self.assertEqual(state, [1])
731+
state.append(x)
732+
yield
733+
state.append("second item")
734+
735+
async def run_test():
736+
async for _ in test("something"):
737+
self.assertEqual(state, [1, "something"])
738+
739+
asyncio.run(run_test())
740+
self.assertEqual(state, [1, 'something', "second item", 999])
741+
742+
683743
class TestBaseExitStack:
684744
exit_stack = None
685745

Lib/test/test_contextlib_async.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,63 @@ async def test():
402402
await test()
403403
self.assertFalse(entered)
404404

405+
@_async_test
406+
async def test_decorator_decorate_sync_function(self):
407+
@asynccontextmanager
408+
async def context():
409+
state.append(1)
410+
yield
411+
state.append(999)
412+
413+
state = []
414+
@context()
415+
def test(x):
416+
self.assertEqual(state, [1])
417+
state.append(x)
418+
419+
await test("something")
420+
self.assertEqual(state, [1, "something", 999])
421+
422+
@_async_test
423+
async def test_decorator_decorate_generator_function(self):
424+
@asynccontextmanager
425+
async def context():
426+
state.append(1)
427+
yield
428+
state.append(999)
429+
430+
state = []
431+
@context()
432+
def test(x):
433+
self.assertEqual(state, [1])
434+
state.append(x)
435+
yield
436+
state.append("second item")
437+
438+
async for _ in test("something"):
439+
self.assertEqual(state, [1, "something"])
440+
self.assertEqual(state, [1, "something", "second item", 999])
441+
442+
@_async_test
443+
async def test_decorator_decorate_asyncgen_function(self):
444+
@asynccontextmanager
445+
async def context():
446+
state.append(1)
447+
yield
448+
state.append(999)
449+
450+
state = []
451+
@context()
452+
async def test(x):
453+
self.assertEqual(state, [1])
454+
state.append(x)
455+
yield
456+
state.append("second item")
457+
458+
async for _ in test("something"):
459+
self.assertEqual(state, [1, "something"])
460+
self.assertEqual(state, [1, "something", "second item", 999])
461+
405462
@_async_test
406463
async def test_decorator_with_exception(self):
407464
entered = False

0 commit comments

Comments
 (0)