77from functools import partial
88from typing import (
99 TYPE_CHECKING ,
10+ AbstractSet ,
1011 Any ,
1112 Callable ,
1213 ContextManager ,
1617 Literal ,
1718 Optional ,
1819 Protocol ,
20+ Sequence ,
1921 Tuple ,
2022 Union ,
2123)
@@ -119,6 +121,56 @@ def django_db_createdb(request: pytest.FixtureRequest) -> bool:
119121 return create_db
120122
121123
124+ def _get_databases_for_test (test : pytest .Item ) -> tuple [Iterable [str ], bool ]:
125+ """Get the database aliases that need to be setup for a test, and whether
126+ they need to be serialized."""
127+ from django .db import DEFAULT_DB_ALIAS , connections
128+ from django .test import TransactionTestCase
129+
130+ test_cls = getattr (test , "cls" , None )
131+ if test_cls and issubclass (test_cls , TransactionTestCase ):
132+ serialized_rollback = getattr (test , "serialized_rollback" , False )
133+ databases = getattr (test , "databases" , None )
134+ else :
135+ fixtures = getattr (test , "fixturenames" , ())
136+ marker_db = test .get_closest_marker ("django_db" )
137+ if marker_db :
138+ (
139+ transaction ,
140+ reset_sequences ,
141+ databases ,
142+ serialized_rollback ,
143+ available_apps ,
144+ ) = validate_django_db (marker_db )
145+ elif "db" in fixtures or "transactional_db" in fixtures or "live_server" in fixtures :
146+ serialized_rollback = "django_db_serialized_rollback" in fixtures
147+ databases = None
148+ else :
149+ return (), False
150+ if databases is None :
151+ return (DEFAULT_DB_ALIAS ,), serialized_rollback
152+ elif databases == "__all__" :
153+ return connections , serialized_rollback
154+ else :
155+ return databases , serialized_rollback
156+
157+
158+ def _get_databases_for_setup (
159+ items : Sequence [pytest .Item ],
160+ ) -> tuple [AbstractSet [str ], AbstractSet [str ]]:
161+ """Get the database aliases that need to be setup, and the subset that needs
162+ to be serialized."""
163+ # Code derived from django.test.utils.DiscoverRunner.get_databases().
164+ aliases : set [str ] = set ()
165+ serialized_aliases : set [str ] = set ()
166+ for test in items :
167+ databases , serialized_rollback = _get_databases_for_test (test )
168+ aliases .update (databases )
169+ if serialized_rollback :
170+ serialized_aliases .update (databases )
171+ return aliases , serialized_aliases
172+
173+
122174@pytest .fixture (scope = "session" )
123175def django_db_setup (
124176 request : pytest .FixtureRequest ,
@@ -140,10 +192,14 @@ def django_db_setup(
140192 if django_db_keepdb and not django_db_createdb :
141193 setup_databases_args ["keepdb" ] = True
142194
195+ aliases , serialized_aliases = _get_databases_for_setup (request .session .items )
196+
143197 with django_db_blocker .unblock ():
144198 db_cfg = setup_databases (
145199 verbosity = request .config .option .verbose ,
146200 interactive = False ,
201+ aliases = aliases ,
202+ serialized_aliases = serialized_aliases ,
147203 ** setup_databases_args ,
148204 )
149205
0 commit comments