Skip to content

Commit 2f8ea92

Browse files
committed
Merge branch 'master' of github.com:kingbuzzman/factory_boy
2 parents 9dba520 + d111d1a commit 2f8ea92

File tree

7 files changed

+109
-54
lines changed

7 files changed

+109
-54
lines changed

docs/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ ChangeLog
1010

1111
- :issue:`366`: Add :class:`factory.django.Password` to generate Django :class:`~django.contrib.auth.models.User`
1212
passwords.
13+
- :issue:`304`: Add :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session_factory` to dynamically
14+
create sessions for use by the :class:`~factory.alchemy.SQLAlchemyModelFactory`.
1315
- Add support for Django 3.2
1416
- Add support for Django 4.0
1517
- Add support for Python 3.10

docs/orms.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,25 @@ To work, this class needs an `SQLAlchemy`_ session object affected to the :attr:
369369
SQLAlchemy session to use to communicate with the database when creating
370370
an object through this :class:`SQLAlchemyModelFactory`.
371371

372+
.. attribute:: sqlalchemy_session_factory
373+
374+
.. versionadded:: 3.3.0
375+
376+
:class:`~collections.abc.Callable` returning a :class:`~sqlalchemy.orm.Session` instance to use to communicate
377+
with the database. You can either provide the session through this attribute, or through
378+
:attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session`, but not both at the same time.
379+
380+
.. code-block:: python
381+
382+
from . import common
383+
384+
class UserFactory(factory.alchemy.SQLAlchemyModelFactory):
385+
class Meta:
386+
model = User
387+
sqlalchemy_session_factory = lambda: common.Session()
388+
389+
username = 'john'
390+
372391
.. attribute:: sqlalchemy_session_persistence
373392

374393
Control the action taken by ``sqlalchemy_session`` at the end of a create call.

docs/recipes.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ simply use a :class:`factory.Iterator` on the chosen queryset:
5252
5353
language = factory.Iterator(models.Language.objects.all())
5454
55-
Here, ``models.Language.objects.all()`` won't be evaluated until the
56-
first call to ``UserFactory``; thus avoiding DB queries at import time.
55+
Here, ``models.Language.objects.all()`` is a
56+
:class:`~django.db.models.query.QuerySet` and will only hit the database when
57+
``factory_boy`` starts iterating on it, i.e on the first call to
58+
``UserFactory``; thus avoiding DB queries at import time.
5759

5860

5961
Reverse dependencies (reverse ForeignKey)

factory/alchemy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,18 @@ def _check_sqlalchemy_session_persistence(self, meta, value):
2222
(meta, VALID_SESSION_PERSISTENCE_TYPES, value)
2323
)
2424

25+
@staticmethod
26+
def _check_has_sqlalchemy_session_set(meta, value):
27+
if value and meta.sqlalchemy_session:
28+
raise RuntimeError("Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both")
29+
2530
def _build_default_options(self):
2631
return super()._build_default_options() + [
2732
base.OptionDefault('sqlalchemy_get_or_create', (), inherit=True),
2833
base.OptionDefault('sqlalchemy_session', None, inherit=True),
34+
base.OptionDefault(
35+
'sqlalchemy_session_factory', None, inherit=True, checker=self._check_has_sqlalchemy_session_set
36+
),
2937
base.OptionDefault(
3038
'sqlalchemy_session_persistence',
3139
None,
@@ -90,6 +98,10 @@ def _get_or_create(cls, model_class, session, args, kwargs):
9098
@classmethod
9199
def _create(cls, model_class, *args, **kwargs):
92100
"""Create an instance of the model, and save it to the database."""
101+
session_factory = cls._meta.sqlalchemy_session_factory
102+
if session_factory:
103+
cls._meta.sqlalchemy_session = session_factory()
104+
93105
session = cls._meta.sqlalchemy_session
94106

95107
if session is None:

factory/django.py

Lines changed: 34 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
2626

27-
DJANGO_22 = Version('2.2') <= Version(django_version) < Version('3.0')
27+
DJANGO_22 = Version(django_version) < Version('3.0')
2828

2929
_LAZY_LOADS = {}
3030

@@ -205,9 +205,18 @@ def create_batch(cls, size, **kwargs):
205205

206206
@classmethod
207207
def _refresh_database_pks(cls, model_cls, objs):
208+
"""
209+
Before Django 3.0, there is an issue when bulk_insert.
210+
211+
The issue is that if you create an instance of a model,
212+
and reference it in another unsaved instance of a model.
213+
When you create the instance of the first one, the pk/id
214+
is never updated on the sub model that referenced the first.
215+
"""
208216
if not DJANGO_22:
209217
return
210-
fields = [f for f in model_cls._meta.get_fields() if isinstance(f, models.fields.related.ForeignObject)]
218+
fields = [f for f in model_cls._meta.get_fields()
219+
if isinstance(f, models.fields.related.ForeignObject)]
211220
if not fields:
212221
return
213222
for obj in objs:
@@ -217,17 +226,13 @@ def _refresh_database_pks(cls, model_cls, objs):
217226
@classmethod
218227
def _bulk_create(cls, size, **kwargs):
219228
models_to_create = cls.build_batch(size, **kwargs)
220-
collector = Collector(cls._meta.database)
229+
collector = DependencyInsertOrderCollector()
221230
collector.collect(cls, models_to_create)
222231
collector.sort()
223232
for model_cls, objs in collector.data.items():
224233
manager = cls._get_manager(model_cls)
225-
for instance in objs:
226-
models.signals.pre_save.send(model_cls, instance=instance, created=False)
227234
cls._refresh_database_pks(model_cls, objs)
228235
manager.bulk_create(objs)
229-
for instance in objs:
230-
models.signals.post_save.send(model_cls, instance=instance, created=True)
231236
return models_to_create
232237

233238
@classmethod
@@ -334,29 +339,20 @@ def _make_data(self, params):
334339
return thumb_io.getvalue()
335340

336341

337-
class Collector:
338-
def __init__(self, using):
339-
self.using = using
342+
class DependencyInsertOrderCollector:
343+
def __init__(self):
340344
# Initially, {model: {instances}}, later values become lists.
341345
self.data = defaultdict(list)
342-
# {model: {(field, value): {instances}}}
343-
self.field_updates = defaultdict(functools.partial(defaultdict, set))
344-
# {model: {field: {instances}}}
345-
self.restricted_objects = defaultdict(functools.partial(defaultdict, set))
346-
# fast_deletes is a list of queryset-likes that can be deleted without
347-
# fetching the objects into memory.
348-
self.fast_deletes = []
349-
350346
# Tracks deletion-order dependency for databases without transactions
351347
# or ability to defer constraint checks. Only concrete model classes
352348
# should be included, as the dependencies exist only between actual
353349
# database tables; proxy models are represented here by their concrete
354350
# parent.
355351
self.dependencies = defaultdict(set) # {model: {models}}
356352

357-
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
353+
def add(self, objs, source=None, nullable=False):
358354
"""
359-
Add 'objs' to the collection of objects to be deleted. If the call is
355+
Add 'objs' to the collection of objects to be inserted in order. If the call is
360356
the result of a cascade, 'source' should be the model that caused it,
361357
and 'nullable' should be set to True if the relation can be null.
362358
Return a list of all objects that were not already collected.
@@ -372,21 +368,15 @@ def add(self, objs, source=None, nullable=False, reverse_dependency=False):
372368
continue
373369
if id(obj) not in lookup:
374370
new_objs.append(obj)
375-
# import ipdb; ipdb.sset_trace()
376371
instances.extend(new_objs)
377372
# Nullable relationships can be ignored -- they are nulled out before
378373
# deleting, and therefore do not affect the order in which objects have
379374
# to be deleted.
380375
if source is not None and not nullable:
381-
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
382-
# if not nullable:
383-
# import ipdb; ipdb.sset_trace()
384-
# self.add_dependency(source, model, reverse_dependency=reverse_dependency)
376+
self.add_dependency(source, model)
385377
return new_objs
386378

387-
def add_dependency(self, model, dependency, reverse_dependency=False):
388-
if reverse_dependency:
389-
model, dependency = dependency, model
379+
def add_dependency(self, model, dependency):
390380
self.dependencies[model._meta.concrete_model].add(
391381
dependency._meta.concrete_model
392382
)
@@ -398,11 +388,6 @@ def collect(
398388
objs,
399389
source=None,
400390
nullable=False,
401-
collect_related=True,
402-
source_attr=None,
403-
reverse_dependency=False,
404-
keep_parents=False,
405-
fail_on_restricted=True,
406391
):
407392
"""
408393
Add 'objs' to the collection of objects to be deleted as well as all
@@ -412,10 +397,6 @@ def collect(
412397
If the call is the result of a cascade, 'source' should be the model
413398
that caused it and 'nullable' should be set to True, if the relation
414399
can be null.
415-
If 'reverse_dependency' is True, 'source' will be deleted before the
416-
current model, rather than after. (Needed for cascading to parent
417-
models, the one case in which the cascade follows the forwards
418-
direction of an FK rather than the reverse direction.)
419400
If 'keep_parents' is True, data of parent model's will be not deleted.
420401
If 'fail_on_restricted' is False, error won't be raised even if it's
421402
prohibited to delete such objects due to RESTRICT, that defers
@@ -424,47 +405,47 @@ def collect(
424405
can be deleted.
425406
"""
426407
new_objs = self.add(
427-
objs, source, nullable, reverse_dependency=reverse_dependency
408+
objs, source, nullable
428409
)
429410
if not new_objs:
430411
return
431412

432-
# import ipdb; ipdb.sset_trace()
433413
model = new_objs[0].__class__
434414

435-
def get_candidate_relations(opts):
436-
# The candidate relations are the ones that come from N-1 and 1-1 relations.
437-
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
438-
return (
439-
f
440-
for f in opts.get_fields(include_hidden=True)
441-
if isinstance(f, models.ForeignKey)
442-
)
415+
# The candidate relations are the ones that come from N-1 and 1-1 relations.
416+
candidate_relations = (
417+
f for f in model._meta.get_fields(include_hidden=True)
418+
if isinstance(f, models.ForeignKey)
419+
)
443420

444421
collected_objs = []
445-
for field in get_candidate_relations(model._meta):
422+
for field in candidate_relations:
446423
for obj in new_objs:
447424
val = getattr(obj, field.name)
448425
if isinstance(val, models.Model):
449426
collected_objs.append(val)
450427

451-
for name, _ in factory_cls._meta.post_declarations.as_dict().items():
452-
428+
for name, in factory_cls._meta.post_declarations.as_dict().keys():
453429
for obj in new_objs:
454430
val = getattr(obj, name, None)
455431
if isinstance(val, models.Model):
456432
collected_objs.append(val)
457433

458434
if collected_objs:
459435
new_objs = self.collect(
460-
factory_cls=factory_cls, objs=collected_objs, source=model, reverse_dependency=False
436+
factory_cls=factory_cls, objs=collected_objs, source=model
461437
)
462438

463439
def sort(self):
440+
"""
441+
Sort the model instances by the least dependecies to the most dependencies.
442+
443+
We want to insert the models with no dependencies first, and continue inserting
444+
using the models that the higher models depend on.
445+
"""
464446
sorted_models = []
465447
concrete_models = set()
466448
models = list(self.data)
467-
# import ipdb; ipdb.sset_trace()
468449
while len(sorted_models) < len(models):
469450
found = False
470451
for model in models:
@@ -476,6 +457,7 @@ def sort(self):
476457
concrete_models.add(model._meta.concrete_model)
477458
found = True
478459
if not found:
460+
logger.debug('dependency order could not be determined')
479461
return
480462
self.data = {model: self.data[model] for model in sorted_models}
481463

tests/test_alchemy.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,34 @@ def test_build_does_not_raises_exception_when_no_session_was_set(self):
264264
self.assertEqual(inst1.id, 1)
265265

266266

267+
class SQLAlchemySessionFactoryTestCase(unittest.TestCase):
268+
269+
def test_create_get_session_from_sqlalchemy_session_factory(self):
270+
class SessionGetterFactory(SQLAlchemyModelFactory):
271+
class Meta:
272+
model = models.StandardModel
273+
sqlalchemy_session = None
274+
sqlalchemy_session_factory = lambda: models.session
275+
276+
id = factory.Sequence(lambda n: n)
277+
278+
SessionGetterFactory.create()
279+
self.assertEqual(SessionGetterFactory._meta.sqlalchemy_session, models.session)
280+
# Reuse the session obtained from sqlalchemy_session_factory.
281+
SessionGetterFactory.create()
282+
283+
def test_create_raise_exception_sqlalchemy_session_factory_not_callable(self):
284+
message = "^Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both$"
285+
with self.assertRaisesRegex(RuntimeError, message):
286+
class SessionAndGetterFactory(SQLAlchemyModelFactory):
287+
class Meta:
288+
model = models.StandardModel
289+
sqlalchemy_session = models.session
290+
sqlalchemy_session_factory = lambda: models.session
291+
292+
id = factory.Sequence(lambda n: n)
293+
294+
267295
class NameConflictTests(unittest.TestCase):
268296
"""Regression test for `TypeError: _save() got multiple values for argument 'session'`
269297

tests/test_django.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,16 @@ class Meta:
175175
level_2 = factory.SubFactory(Level2Factory)
176176

177177

178+
class DependencyInsertOrderCollector(django_test.TestCase):
179+
180+
def test_empty(self):
181+
collector = factory.django.DependencyInsertOrderCollector()
182+
collector.collect(Level2Factory, [])
183+
collector.sort()
184+
185+
self.assertEqual(collector.data, {})
186+
187+
178188
@unittest.skipIf(SKIP_BULK_INSERT, "bulk insert not supported by current db.")
179189
class DjangoBulkInsert(django_test.TestCase):
180190

0 commit comments

Comments
 (0)