|
2 | 2 | import os
|
3 | 3 | from contextlib import contextmanager
|
4 | 4 | from functools import partial
|
5 |
| -from typing import ( |
6 |
| - Any, Callable, Generator, Iterable, List, Optional, Tuple, Union, |
7 |
| -) |
| 5 | +from typing import Any, Generator, Iterable, List, Optional, Tuple, Union |
8 | 6 |
|
9 | 7 | import pytest
|
10 | 8 |
|
11 | 9 | from . import live_server_helper
|
12 | 10 | from .django_compat import is_django_unittest
|
13 |
| -from .lazy_django import get_django_version, skip_if_no_django |
| 11 | +from .lazy_django import skip_if_no_django |
14 | 12 |
|
15 | 13 |
|
16 | 14 | TYPE_CHECKING = False
|
@@ -216,12 +214,12 @@ class PytestDjangoTestCase(test_case_class): # type: ignore[misc,valid-type]
|
216 | 214 | @classmethod
|
217 | 215 | def setUpClass(cls) -> None:
|
218 | 216 | super(django.test.TestCase, cls).setUpClass()
|
219 |
| - if (3, 2) <= VERSION < (4, 1): |
| 217 | + if VERSION < (4, 1): |
220 | 218 | django.db.transaction.Atomic._ensure_durability = False
|
221 | 219 |
|
222 | 220 | @classmethod
|
223 | 221 | def tearDownClass(cls) -> None:
|
224 |
| - if (3, 2) <= VERSION < (4, 1): |
| 222 | + if VERSION < (4, 1): |
225 | 223 | django.db.transaction.Atomic._ensure_durability = True
|
226 | 224 | super(django.test.TestCase, cls).tearDownClass()
|
227 | 225 |
|
@@ -616,36 +614,8 @@ def django_assert_max_num_queries(pytestconfig):
|
616 | 614 | return partial(_assert_num_queries, pytestconfig, exact=False)
|
617 | 615 |
|
618 | 616 |
|
619 |
| -@contextmanager |
620 |
| -def _capture_on_commit_callbacks( |
621 |
| - *, |
622 |
| - using: Optional[str] = None, |
623 |
| - execute: bool = False |
624 |
| -): |
625 |
| - from django.db import DEFAULT_DB_ALIAS, connections |
626 |
| - from django.test import TestCase |
627 |
| - |
628 |
| - if using is None: |
629 |
| - using = DEFAULT_DB_ALIAS |
630 |
| - |
631 |
| - # Polyfill of Django code as of Django 3.2. |
632 |
| - if get_django_version() < (3, 2): |
633 |
| - callbacks: List[Callable[[], Any]] = [] |
634 |
| - start_count = len(connections[using].run_on_commit) |
635 |
| - try: |
636 |
| - yield callbacks |
637 |
| - finally: |
638 |
| - run_on_commit = connections[using].run_on_commit[start_count:] |
639 |
| - callbacks[:] = [func for sids, func in run_on_commit] |
640 |
| - if execute: |
641 |
| - for callback in callbacks: |
642 |
| - callback() |
643 |
| - |
644 |
| - else: |
645 |
| - with TestCase.captureOnCommitCallbacks(using=using, execute=execute) as callbacks: |
646 |
| - yield callbacks |
647 |
| - |
648 |
| - |
649 | 617 | @pytest.fixture(scope="function")
|
650 | 618 | def django_capture_on_commit_callbacks():
|
651 |
| - return _capture_on_commit_callbacks |
| 619 | + from django.test import TestCase |
| 620 | + |
| 621 | + return TestCase.captureOnCommitCallbacks |
0 commit comments