diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 5b646fabca0225..9fcbc237ca2c5a 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -1,4 +1,7 @@ """Utilities for with-statement contexts. See PEP 343.""" +from inspect import isasyncgenfunction, iscoroutinefunction, \ + isgeneratorfunction + import abc import os import sys @@ -79,11 +82,32 @@ def _recreate_cm(self): return self def __call__(self, func): - @wraps(func) def inner(*args, **kwds): with self._recreate_cm(): return func(*args, **kwds) - return inner + + def gen_inner(*args, **kwds): + with self._recreate_cm(): + yield from func(*args, **kwds) + + async def async_inner(*args, **kwds): + with self._recreate_cm(): + return await func(*args, **kwds) + + async def asyncgen_inner(*args, **kwds): + with self._recreate_cm(): + async for value in func(*args, **kwds): + yield value + + wrapper = wraps(func) + if isasyncgenfunction(func): + return wrapper(asyncgen_inner) + elif iscoroutinefunction(func): + return wrapper(async_inner) + elif isgeneratorfunction(func): + return wrapper(gen_inner) + else: + return wrapper(inner) class AsyncContextDecorator(object): @@ -95,11 +119,33 @@ def _recreate_cm(self): return self def __call__(self, func): - @wraps(func) async def inner(*args, **kwds): + async with self._recreate_cm(): + return func(*args, **kwds) + + async def gen_inner(*args, **kwds): + async with self._recreate_cm(): + for value in func(*args, **kwds): + yield value + + async def async_inner(*args, **kwds): async with self._recreate_cm(): return await func(*args, **kwds) - return inner + + async def asyncgen_inner(*args, **kwds): + async with self._recreate_cm(): + async for value in func(*args, **kwds): + yield value + + wrapper = wraps(func) + if isasyncgenfunction(func): + return wrapper(asyncgen_inner) + elif iscoroutinefunction(func): + return wrapper(async_inner) + elif isgeneratorfunction(func): + return wrapper(gen_inner) + else: + return wrapper(inner) class _GeneratorContextManagerBase: diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 6a3329fa5aaace..d4ee315ff8a97c 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -1,5 +1,4 @@ """Unit tests for contextlib.py, and other context managers.""" - import io import os import sys @@ -680,6 +679,74 @@ def test(x): self.assertEqual(state, [1, 'something else', 999]) + def test_contextmanager_decorate_generator_function(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + for _ in test("something"): + self.assertEqual(state, [1, "something"]) + self.assertEqual(state, [1, "something", "second item", 999]) + + + def test_contextmanager_decorate_coroutine_function(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + async def test(x): + self.assertEqual(state, [1]) + state.append(x) + + coro = test('something') + with self.assertRaises(StopIteration): + coro.send(None) + + self.assertEqual(state, [1, 'something', 999]) + + + def test_contextmanager_decorate_asyncgen_function(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + async def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + async def run_test(): + async for _ in test("something"): + self.assertEqual(state, [1, "something"]) + + agen = test('something') + with self.assertRaises(StopIteration): + agen.asend(None).send(None) + with self.assertRaises(StopAsyncIteration): + agen.asend(None).send(None) + + self.assertEqual(state, [1, 'something', "second item", 999]) + + class TestBaseExitStack: exit_stack = None diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py index dcd0072037950e..0d1ce9a3bd6be7 100644 --- a/Lib/test/test_contextlib_async.py +++ b/Lib/test/test_contextlib_async.py @@ -402,6 +402,63 @@ async def test(): await test() self.assertFalse(entered) + @_async_test + async def test_decorator_decorate_sync_function(self): + @asynccontextmanager + async def context(): + state.append(1) + yield + state.append(999) + + state = [] + @context() + def test(x): + self.assertEqual(state, [1]) + state.append(x) + + await test("something") + self.assertEqual(state, [1, "something", 999]) + + @_async_test + async def test_decorator_decorate_generator_function(self): + @asynccontextmanager + async def context(): + state.append(1) + yield + state.append(999) + + state = [] + @context() + def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + async for _ in test("something"): + self.assertEqual(state, [1, "something"]) + self.assertEqual(state, [1, "something", "second item", 999]) + + @_async_test + async def test_decorator_decorate_asyncgen_function(self): + @asynccontextmanager + async def context(): + state.append(1) + yield + state.append(999) + + state = [] + @context() + async def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + async for _ in test("something"): + self.assertEqual(state, [1, "something"]) + self.assertEqual(state, [1, "something", "second item", 999]) + @_async_test async def test_decorator_with_exception(self): entered = False diff --git a/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst b/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst new file mode 100644 index 00000000000000..2e7b654bf07840 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst @@ -0,0 +1 @@ +Improved ``@contextmanager`` and ``@asynccontextmanager`` to work correctly with generators, coroutine functions and async generators when the wrapped callables are used as decorators