Skip to content

Commit ff339a1

Browse files
committed
feat(db): add bulk_create
1 parent 6dc14ac commit ff339a1

File tree

2 files changed

+448
-1
lines changed

2 files changed

+448
-1
lines changed

cot/src/db.rs

Lines changed: 294 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use mockall::automock;
2727
use query::Query;
2828
pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy};
2929
use sea_query::{
30-
Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr,
30+
ColumnRef, Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr,
3131
};
3232
use sea_query_binder::{SqlxBinder, SqlxValues};
3333
use sqlx::{Type, TypeInfo};
@@ -84,6 +84,22 @@ pub enum DatabaseError {
8484
/// Error when a unique constraint is violated in the database.
8585
#[error("{ERROR_PREFIX} unique constraint violation")]
8686
UniqueViolation,
87+
/// Single model has more fields than database parameter limit.
88+
#[error(
89+
"{ERROR_PREFIX} model has {field_count} fields which exceeds the database parameter limit \
90+
of {limit}"
91+
)]
92+
BulkCreateModelTooLarge {
93+
/// The number of fields in the model.
94+
field_count: usize,
95+
/// The database parameter limit.
96+
limit: usize,
97+
},
98+
/// Attempted bulk create with a model that only contains Auto fields.
99+
#[error(
100+
"{ERROR_PREFIX} calling bulk_create with a model that only contains auto/default fields is unsupported"
101+
)]
102+
BulkCreateNoValueColumns,
87103
}
88104
impl_into_cot_error!(DatabaseError, INTERNAL_SERVER_ERROR);
89105

@@ -247,6 +263,39 @@ pub trait Model: Sized + Send + 'static {
247263
db.update(self).await?;
248264
Ok(())
249265
}
266+
267+
/// Bulk insert multiple model instances to the database in a single query.
268+
///
269+
/// This method is significantly faster than calling [`Self::insert`]
270+
/// multiple times, as it combines all instances into a single SQL
271+
/// `INSERT` statement with multiple value sets.
272+
///
273+
/// # Examples
274+
///
275+
/// ```ignore
276+
/// let mut todos = vec![
277+
/// TodoItem { id: Auto::auto(), title: "Task 1".into() },
278+
/// TodoItem { id: Auto::auto(), title: "Task 2".into() },
279+
/// TodoItem { id: Auto::auto(), title: "Task 3".into() },
280+
/// ];
281+
///
282+
/// TodoItem::bulk_create(&db, &mut todos).await?;
283+
///
284+
/// // After insertion, all todos have populated IDs
285+
/// assert!(todos[0].id.is_fixed());
286+
/// ```
287+
///
288+
/// # Errors
289+
///
290+
/// Returns error if:
291+
/// - Database connection fails
292+
/// - Unique constraint is violated
293+
/// - Empty slice is provided
294+
/// - Single model has more fields than the database parameter limit
295+
async fn bulk_create<DB: DatabaseBackend>(db: &DB, instances: &mut [Self]) -> Result<()> {
296+
db.bulk_insert(instances).await?;
297+
Ok(())
298+
}
250299
}
251300

252301
/// An identifier structure that holds table or column name as a string.
@@ -965,6 +1014,215 @@ impl Database {
9651014
Ok(())
9661015
}
9671016

1017+
/// Bulk inserts multiple rows into the database.
1018+
///
1019+
/// # Errors
1020+
///
1021+
/// This method can return an error if the rows could not be inserted into
1022+
/// the database, for instance because the migrations haven't been
1023+
/// applied, or there was a problem with the database connection.
1024+
pub async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
1025+
let span = span!(Level::TRACE, "bulk_insert", table = %T::TABLE_NAME, count = data.len());
1026+
1027+
Self::bulk_insert_impl(self, data, false)
1028+
.instrument(span)
1029+
.await
1030+
}
1031+
1032+
/// Bulk inserts multiple rows into the database, or updates them if they
1033+
/// already exist.
1034+
///
1035+
/// # Errors
1036+
///
1037+
/// This method can return an error if the rows could not be inserted into
1038+
/// the database, for instance because the migrations haven't been
1039+
/// applied, or there was a problem with the database connection.
1040+
pub async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
1041+
let span = span!(
1042+
Level::TRACE,
1043+
"bulk_insert_or_update",
1044+
table = %T::TABLE_NAME,
1045+
count = data.len()
1046+
);
1047+
1048+
Self::bulk_insert_impl(self, data, true)
1049+
.instrument(span)
1050+
.await
1051+
}
1052+
1053+
async fn bulk_insert_impl<T: Model>(&self, data: &mut [T], update: bool) -> Result<()> {
1054+
// TODO: add transactions when implemented
1055+
1056+
if data.is_empty() {
1057+
return Ok(());
1058+
}
1059+
1060+
let max_params = match self.inner {
1061+
// https://sqlite.org/limits.html#max_variable_number
1062+
// Assuming SQLite > 3.32.0 (2020-05-22)
1063+
#[cfg(feature = "sqlite")]
1064+
DatabaseImpl::Sqlite(_) => 32766,
1065+
// https://www.postgresql.org/docs/18/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BIND
1066+
// The number of parameter format codes is Int16
1067+
#[cfg(feature = "postgres")]
1068+
DatabaseImpl::Postgres(_) => 65535,
1069+
// https://dev.mysql.com/doc/dev/mysql-server/9.5.0/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
1070+
// The number of parameter returned in the COM_STMT_PREPARE_OK packet is int<2>
1071+
#[cfg(feature = "mysql")]
1072+
DatabaseImpl::MySql(_) => 65535,
1073+
};
1074+
1075+
let column_identifiers: Vec<_> = T::COLUMNS
1076+
.iter()
1077+
.map(|column| Identifier::from(column.name.as_str()))
1078+
.collect();
1079+
let value_indices: Vec<_> = (0..T::COLUMNS.len()).collect();
1080+
1081+
// Determine which columns are Auto vs Value by examining the first instance
1082+
// (all instances should have the same structure, as we are inserting a single
1083+
// Model)
1084+
let first_values = data[0]
1085+
.get_values(&value_indices)
1086+
.into_iter()
1087+
.map(ToDbFieldValue::to_db_field_value)
1088+
.collect::<Vec<_>>();
1089+
1090+
let mut auto_col_ids = Vec::new();
1091+
let mut auto_col_identifiers = Vec::new();
1092+
let mut value_column_indices = Vec::new();
1093+
let mut value_identifiers = Vec::new();
1094+
1095+
for (index, (identifier, value)) in
1096+
std::iter::zip(column_identifiers.iter(), first_values.iter()).enumerate()
1097+
{
1098+
match value {
1099+
DbFieldValue::Auto => {
1100+
auto_col_ids.push(index);
1101+
auto_col_identifiers.push((*identifier).into_column_ref());
1102+
}
1103+
DbFieldValue::Value(_) => {
1104+
value_column_indices.push(index);
1105+
value_identifiers.push(*identifier);
1106+
}
1107+
}
1108+
}
1109+
1110+
let num_value_fields = value_identifiers.len();
1111+
1112+
if num_value_fields > max_params {
1113+
return Err(DatabaseError::BulkCreateModelTooLarge {
1114+
field_count: num_value_fields,
1115+
limit: max_params,
1116+
});
1117+
}
1118+
1119+
let batch_size = if num_value_fields > 0 {
1120+
max_params / num_value_fields
1121+
} else {
1122+
return Err(DatabaseError::BulkCreateNoValueColumns);
1123+
};
1124+
1125+
for chunk in data.chunks_mut(batch_size) {
1126+
self.bulk_insert_chunk(
1127+
chunk,
1128+
update,
1129+
&value_identifiers,
1130+
&value_column_indices,
1131+
&auto_col_ids,
1132+
&auto_col_identifiers,
1133+
)
1134+
.await?;
1135+
}
1136+
1137+
Ok(())
1138+
}
1139+
1140+
async fn bulk_insert_chunk<T: Model>(
1141+
&self,
1142+
chunk: &mut [T],
1143+
update: bool,
1144+
value_identifiers: &[Identifier],
1145+
value_column_indices: &[usize],
1146+
auto_col_ids: &[usize],
1147+
auto_col_identifiers: &[ColumnRef],
1148+
) -> Result<()> {
1149+
let mut insert_statement = sea_query::Query::insert()
1150+
.into_table(T::TABLE_NAME)
1151+
.columns(value_identifiers.iter().copied())
1152+
.to_owned();
1153+
1154+
// Add values for each instance in the chunk
1155+
for instance in chunk.iter() {
1156+
let values = instance.get_values(value_column_indices);
1157+
let db_values: Vec<_> = values
1158+
.into_iter()
1159+
.map(|v| match v.to_db_field_value() {
1160+
DbFieldValue::Value(val) => val,
1161+
DbFieldValue::Auto => {
1162+
panic!("Expected Value but found Auto in bulk insert")
1163+
}
1164+
})
1165+
.map(SimpleExpr::Value)
1166+
.collect();
1167+
1168+
assert!(!db_values.is_empty(), "expected at least 1 value field");
1169+
insert_statement.values(db_values)?;
1170+
}
1171+
1172+
if update {
1173+
insert_statement.on_conflict(
1174+
OnConflict::column(T::PRIMARY_KEY_NAME)
1175+
.update_columns(value_identifiers.iter().copied())
1176+
.to_owned(),
1177+
);
1178+
}
1179+
1180+
if auto_col_ids.is_empty() {
1181+
self.execute_statement(&insert_statement).await?;
1182+
} else if self.supports_returning() {
1183+
// PostgreSQL/SQLite: Use RETURNING clause
1184+
insert_statement.returning(ReturningClause::Columns(auto_col_identifiers.to_vec()));
1185+
1186+
let rows = self.fetch_all(&insert_statement).await?;
1187+
1188+
for (instance, row) in chunk.iter_mut().zip(rows) {
1189+
instance.update_from_db(row, auto_col_ids)?;
1190+
}
1191+
} else {
1192+
// MySQL: Use LAST_INSERT_ID() and fetch rows
1193+
let result = self.execute_statement(&insert_statement).await?;
1194+
let first_id = result
1195+
.last_inserted_row_id
1196+
.expect("expected last inserted row ID if RETURNING clause is not supported");
1197+
1198+
// Fetch the inserted rows using a SELECT query
1199+
// Note: This assumes IDs are consecutive, which is generally safe for
1200+
// auto_increment but could fail with concurrent inserts
1201+
let query = sea_query::Query::select()
1202+
.from(T::TABLE_NAME)
1203+
.columns(auto_col_identifiers.iter().cloned())
1204+
.and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).gte(first_id).and(
1205+
sea_query::Expr::col(T::PRIMARY_KEY_NAME).lt(first_id + chunk.len() as u64),
1206+
))
1207+
.order_by(T::PRIMARY_KEY_NAME, sea_query::Order::Asc)
1208+
.to_owned();
1209+
1210+
let rows = self.fetch_all(&query).await?;
1211+
1212+
for (instance, row) in chunk.iter_mut().zip(rows) {
1213+
instance.update_from_db(row, auto_col_ids)?;
1214+
}
1215+
}
1216+
1217+
if update {
1218+
trace!(count = chunk.len(), "Inserted or updated rows");
1219+
} else {
1220+
trace!(count = chunk.len(), "Inserted rows");
1221+
}
1222+
1223+
Ok(())
1224+
}
1225+
9681226
/// Executes the given query and returns the results converted to the model
9691227
/// type.
9701228
///
@@ -1271,6 +1529,25 @@ pub trait DatabaseBackend: Send + Sync {
12711529
/// there was a problem with the database connection.
12721530
async fn update<T: Model>(&self, data: &mut T) -> Result<()>;
12731531

1532+
/// Bulk inserts multiple rows into the database.
1533+
///
1534+
/// # Errors
1535+
///
1536+
/// This method can return an error if the rows could not be inserted into
1537+
/// the database, for instance because the migrations haven't been
1538+
/// applied, or there was a problem with the database connection.
1539+
async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()>;
1540+
1541+
/// Bulk inserts multiple rows into the database, or updates existing rows
1542+
/// if they already exist.
1543+
///
1544+
/// # Errors
1545+
///
1546+
/// This method can return an error if the rows could not be inserted into
1547+
/// the database, for instance because the migrations haven't been
1548+
/// applied, or there was a problem with the database connection.
1549+
async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()>;
1550+
12741551
/// Executes a query and returns the results converted to the model type.
12751552
///
12761553
/// # Errors
@@ -1339,6 +1616,14 @@ impl DatabaseBackend for Database {
13391616
Database::update(self, data).await
13401617
}
13411618

1619+
async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
1620+
Database::bulk_insert(self, data).await
1621+
}
1622+
1623+
async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
1624+
Database::bulk_insert_or_update(self, data).await
1625+
}
1626+
13421627
async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
13431628
Database::query(self, query).await
13441629
}
@@ -1370,6 +1655,14 @@ impl DatabaseBackend for std::sync::Arc<Database> {
13701655
Database::update(self, data).await
13711656
}
13721657

1658+
async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
1659+
Database::bulk_insert(self, data).await
1660+
}
1661+
1662+
async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
1663+
Database::bulk_insert_or_update(self, data).await
1664+
}
1665+
13731666
async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
13741667
Database::query(self, query).await
13751668
}

0 commit comments

Comments
 (0)