77from functools import partial
88from typing import (
99 TYPE_CHECKING ,
10+ AbstractSet ,
1011 Any ,
1112 Callable ,
1213 ContextManager ,
1617 Literal ,
1718 Optional ,
1819 Protocol ,
20+ Sequence ,
1921 Tuple ,
2022 Union ,
2123)
@@ -119,6 +121,54 @@ def django_db_createdb(request: pytest.FixtureRequest) -> bool:
119121 return create_db
120122
121123
124+ def _get_databases_for_test (test : pytest .Item ) -> tuple [Iterable [str ], bool ]:
125+ """Get the database aliases that need to be setup for a test, and whether
126+ they need to be serialized."""
127+ from django .db import DEFAULT_DB_ALIAS , connections
128+ from django .test import TransactionTestCase
129+
130+ test_cls = getattr (test , "cls" , None )
131+ if test_cls and issubclass (test_cls , TransactionTestCase ):
132+ serialized_rollback = getattr (test , "serialized_rollback" , False )
133+ databases = getattr (test , "databases" , None )
134+ else :
135+ fixtures = getattr (test , "fixturenames" , ())
136+ marker_db = test .get_closest_marker ("django_db" )
137+ if marker_db :
138+ (
139+ transaction ,
140+ reset_sequences ,
141+ databases ,
142+ serialized_rollback ,
143+ available_apps ,
144+ ) = validate_django_db (marker_db )
145+ elif "db" in fixtures or "transactional_db" in fixtures or "live_server" in fixtures :
146+ serialized_rollback = "django_db_serialized_rollback" in fixtures
147+ databases = None
148+ else :
149+ return (), False
150+ if databases is None :
151+ return (DEFAULT_DB_ALIAS ,), serialized_rollback
152+ elif databases == "__all__" :
153+ return connections , serialized_rollback
154+ else :
155+ return databases , serialized_rollback
156+
157+
158+ def _get_databases_for_setup (items : Sequence [pytest .Item ]) -> tuple [AbstractSet [str ], AbstractSet [str ]]:
159+ """Get the database aliases that need to be setup, and the subset that needs
160+ to be serialized."""
161+ # Code derived from django.test.utils.DiscoverRunner.get_databases().
162+ aliases : set [str ] = set ()
163+ serialized_aliases : set [str ] = set ()
164+ for test in items :
165+ databases , serialized_rollback = _get_databases_for_test (test )
166+ aliases .update (databases )
167+ if serialized_rollback :
168+ serialized_aliases .update (databases )
169+ return aliases , serialized_aliases
170+
171+
122172@pytest .fixture (scope = "session" )
123173def django_db_setup (
124174 request : pytest .FixtureRequest ,
@@ -140,10 +190,14 @@ def django_db_setup(
140190 if django_db_keepdb and not django_db_createdb :
141191 setup_databases_args ["keepdb" ] = True
142192
193+ aliases , serialized_aliases = _get_databases_for_setup (request .session .items )
194+
143195 with django_db_blocker .unblock ():
144196 db_cfg = setup_databases (
145197 verbosity = request .config .option .verbose ,
146198 interactive = False ,
199+ aliases = aliases ,
200+ serialized_aliases = serialized_aliases ,
147201 ** setup_databases_args ,
148202 )
149203
0 commit comments