Skip to content

Commit 92fc26a

Browse files
committed
Add bulk_create
1 parent 3cc670b commit 92fc26a

File tree

3 files changed

+73
-14
lines changed

3 files changed

+73
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# 0.7.6
22

33
- Change internal data types to set to speed up building queries and allow caching
4+
- Add a `bulk_create` method to the model manager.
45

56
# 0.7.5
67

atomdb/sql.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ async def get_or_create(self, **filters) -> TupleType[T, bool]:
837837
"""Get or create a model matching the given criteria
838838
839839
Parameters
840-
----------author__name="Tom",
840+
----------
841841
filters: Dict
842842
The filters to use to retrieve the object
843843
@@ -880,6 +880,47 @@ async def create(self, **state) -> T:
880880
await obj.save(force_insert=True, connection=connection)
881881
return obj
882882

883+
async def bulk_create(self, items: Sequence[T], connection=None) -> Sequence[T]:
884+
"""Perform a bulk create from a sequence of models. This will
885+
populate the primary key of the items as needed but will not pull
886+
any fields that have not been defined. The restored flag will still
887+
be False.
888+
889+
Parameters
890+
----------
891+
items: Sequence[T]
892+
The list of items to create.
893+
connection: Connetion
894+
The connection to use (if None one from the pool will be used)
895+
896+
Returns
897+
-------
898+
items: Sequence[T]
899+
The items passed in. Only postgres will populate the primary keys.
900+
901+
"""
902+
table = self.table
903+
values = [item.__prepare_state_for_db__() for item in items]
904+
async with self.connection(connection) as conn:
905+
# TODO: Properly detect?
906+
postgres = "aiopg" in conn.__class__.__module__
907+
if postgres:
908+
pk_column = table.c[self.model.__pk__]
909+
q = table.insert().returning(pk_column).values(values)
910+
else:
911+
# TODO: Get return value?
912+
q = table.insert().values(values)
913+
914+
result = await conn.execute(q)
915+
if postgres:
916+
cache = self.cache
917+
for r, item in zip(await result.fetchall(), items):
918+
# Don't overwrite if force inserting
919+
if not item._id:
920+
item._id = r[0]
921+
cache[item._id] = item
922+
return items
923+
883924
def __getattr__(self, name: str):
884925
"""All other fields are delegated to the query set"""
885926
qs: SQLQuerySet[T] = SQLQuerySet(proxy=self)
@@ -1925,6 +1966,24 @@ async def load(
19251966
state = await db.fetchone(q, connection=connection)
19261967
await self.__restorestate__(state)
19271968

1969+
def __prepare_state_for_db__(self):
1970+
"""Get the state that should be saved into the database"""
1971+
state = self.__getstate__()
1972+
1973+
# Remove any fields are in the state but should not go into the db
1974+
for f in self.__excluded_fields__:
1975+
state.pop(f, None)
1976+
1977+
# Replace any renamed fields
1978+
for py_name, db_name in self.__renamed_fields__.items():
1979+
state[db_name] = state.pop(py_name)
1980+
1981+
if not self._id:
1982+
# Postgres errors if using None for the pk
1983+
state.pop(self.__pk__, None)
1984+
1985+
return state
1986+
19281987
async def save(
19291988
self: T,
19301989
force_insert: bool = False,
@@ -1955,17 +2014,8 @@ async def save(
19552014
raise ValueError("Cannot use force_insert and force_update together")
19562015

19572016
db = self.objects
1958-
state = self.__getstate__()
1959-
1960-
# Remove any fields are in the state but should not go into the db
1961-
for f in self.__excluded_fields__:
1962-
state.pop(f, None)
1963-
1964-
# Replace any renamed fields
1965-
for py_name, db_name in self.__renamed_fields__.items():
1966-
state[db_name] = state.pop(py_name)
1967-
19682017
table = db.table
2018+
state = self.__prepare_state_for_db__()
19692019
async with db.connection(connection) as conn:
19702020
if force_update or (self._id and not force_insert):
19712021

@@ -1990,9 +2040,6 @@ async def save(
19902040
f"pk={self._id} exist or it has not changed."
19912041
)
19922042
else:
1993-
if not self._id:
1994-
# Postgres errors if using None for the pk
1995-
state.pop(self.__pk__, None)
19962043
q = table.insert().values(**state)
19972044
r = await conn.execute(q)
19982045

tests/test_sql.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,17 @@ async def test_create(db):
956956
await Job.objects.create(name=job.name)
957957

958958

959+
async def test_bulk_create(db):
960+
await reset_tables(User)
961+
assert await User.objects.count() == 0
962+
# TODO: Get the id's of the rows inserted?
963+
users = await User.objects.bulk_create([User(name=f"user-{i}") for i in range(10)])
964+
for u in users:
965+
if not IS_MYSQL:
966+
assert u._id
967+
assert await User.objects.count() == 10
968+
969+
959970
async def test_transaction_rollback(db):
960971
await reset_tables(Job, JobSkill, JobRole)
961972

0 commit comments

Comments
 (0)