295 changes: 294 additions & 1 deletion cot/src/db.rs
@@ -27,7 +27,7 @@ use mockall::automock;
use query::Query;
pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy};
use sea_query::{
Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr,
ColumnRef, Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr,
};
use sea_query_binder::{SqlxBinder, SqlxValues};
use sqlx::{Type, TypeInfo};
@@ -84,6 +84,22 @@ pub enum DatabaseError {
/// Error when a unique constraint is violated in the database.
#[error("{ERROR_PREFIX} unique constraint violation")]
UniqueViolation,
/// A single model has more fields than the database parameter limit.
#[error(
"{ERROR_PREFIX} model has {field_count} fields which exceeds the database parameter limit \
of {limit}"
)]
BulkCreateModelTooLarge {
/// The number of non-auto fields in the model.
field_count: usize,
/// The database parameter limit.
limit: usize,
},
/// Attempted bulk create with a model that only contains Auto fields.
#[error(
"{ERROR_PREFIX} calling bulk_create with a model that only contains auto/default fields is unsupported"
)]
BulkCreateNoValueColumns,
}
impl_into_cot_error!(DatabaseError, INTERNAL_SERVER_ERROR);

@@ -247,6 +263,39 @@ pub trait Model: Sized + Send + 'static {
db.update(self).await?;
Ok(())
}

/// Bulk inserts multiple model instances into the database.
///
/// This method is significantly faster than calling [`Self::insert`]
/// multiple times, as it combines the instances into a single SQL
/// `INSERT` statement with multiple value sets, splitting into several
/// statements only when the database's parameter limit would be exceeded.
///
/// # Examples
///
/// ```ignore
/// let mut todos = vec![
/// TodoItem { id: Auto::auto(), title: "Task 1".into() },
/// TodoItem { id: Auto::auto(), title: "Task 2".into() },
/// TodoItem { id: Auto::auto(), title: "Task 3".into() },
/// ];
///
/// TodoItem::bulk_create(&db, &mut todos).await?;
///
/// // After insertion, all todos have populated IDs
/// assert!(todos[0].id.is_fixed());
/// ```
///
/// # Errors
///
/// Returns an error if:
/// - The database connection fails
/// - A unique constraint is violated
/// - A single model has more fields than the database parameter limit
/// - The model contains only auto/default fields
async fn bulk_create<DB: DatabaseBackend>(db: &DB, instances: &mut [Self]) -> Result<()> {
db.bulk_insert(instances).await?;
Ok(())
}
}

/// An identifier structure that holds table or column name as a string.
@@ -965,6 +1014,215 @@ impl Database {
Ok(())
}

/// Bulk inserts multiple rows into the database.
///
/// # Errors
///
/// This method can return an error if the rows could not be inserted into
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
pub async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
let span = span!(Level::TRACE, "bulk_insert", table = %T::TABLE_NAME, count = data.len());

Self::bulk_insert_impl(self, data, false)
.instrument(span)
.await
}

/// Bulk inserts multiple rows into the database, or updates them if they
/// already exist.
///
/// # Errors
///
/// This method can return an error if the rows could not be inserted into
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
pub async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
let span = span!(
Level::TRACE,
"bulk_insert_or_update",
table = %T::TABLE_NAME,
count = data.len()
);

Self::bulk_insert_impl(self, data, true)
.instrument(span)
.await
}

async fn bulk_insert_impl<T: Model>(&self, data: &mut [T], update: bool) -> Result<()> {
// TODO: add transactions when implemented

if data.is_empty() {
return Ok(());
}

let max_params = match self.inner {
// https://sqlite.org/limits.html#max_variable_number
// Assuming SQLite >= 3.32.0 (2020-05-22)
#[cfg(feature = "sqlite")]
DatabaseImpl::Sqlite(_) => 32766,
// https://www.postgresql.org/docs/18/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BIND
// The number of parameter format codes is Int16
#[cfg(feature = "postgres")]
DatabaseImpl::Postgres(_) => 65535,
// https://dev.mysql.com/doc/dev/mysql-server/9.5.0/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
// The number of parameters returned in the COM_STMT_PREPARE_OK packet is int<2>
#[cfg(feature = "mysql")]
DatabaseImpl::MySql(_) => 65535,
};
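// These limits cap the number of bind parameters in a single prepared
// statement; they determine how many rows can be packed into one INSERT below.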

let column_identifiers: Vec<_> = T::COLUMNS
.iter()
.map(|column| Identifier::from(column.name.as_str()))
.collect();
let value_indices: Vec<_> = (0..T::COLUMNS.len()).collect();

// Determine which columns are Auto vs Value by examining the first instance
// (all instances have the same structure, since they are all of the same
// Model type)
let first_values = data[0]
.get_values(&value_indices)
.into_iter()
.map(ToDbFieldValue::to_db_field_value)
.collect::<Vec<_>>();

let mut auto_col_ids = Vec::new();
let mut auto_col_identifiers = Vec::new();
let mut value_column_indices = Vec::new();
let mut value_identifiers = Vec::new();

for (index, (identifier, value)) in
std::iter::zip(column_identifiers.iter(), first_values.iter()).enumerate()
{
match value {
DbFieldValue::Auto => {
auto_col_ids.push(index);
auto_col_identifiers.push((*identifier).into_column_ref());
}
DbFieldValue::Value(_) => {
value_column_indices.push(index);
value_identifiers.push(*identifier);
}
}
}
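// Columns with concrete values become bind parameters in the INSERT;
// auto columns (e.g. auto-incremented primary keys) are skipped here and
// read back from the database after the insert, via RETURNING or
// LAST_INSERT_ID() depending on the backend.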

let num_value_fields = value_identifiers.len();

if num_value_fields > max_params {
return Err(DatabaseError::BulkCreateModelTooLarge {
field_count: num_value_fields,
limit: max_params,
});
}

let batch_size = if num_value_fields > 0 {
max_params / num_value_fields
} else {
return Err(DatabaseError::BulkCreateNoValueColumns);
};
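// For example, a model with 5 non-auto columns on SQLite (32,766 parameters)
// yields a batch size of 32_766 / 5 = 6_553 rows, so every generated
// statement stays under the backend's bind-parameter limit.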

for chunk in data.chunks_mut(batch_size) {
self.bulk_insert_chunk(
chunk,
update,
&value_identifiers,
&value_column_indices,
&auto_col_ids,
&auto_col_identifiers,
)
.await?;
}

Ok(())
}

async fn bulk_insert_chunk<T: Model>(
&self,
chunk: &mut [T],
update: bool,
value_identifiers: &[Identifier],
value_column_indices: &[usize],
auto_col_ids: &[usize],
auto_col_identifiers: &[ColumnRef],
) -> Result<()> {
let mut insert_statement = sea_query::Query::insert()
.into_table(T::TABLE_NAME)
.columns(value_identifiers.iter().copied())
.to_owned();

// Add values for each instance in the chunk
for instance in chunk.iter() {
let values = instance.get_values(value_column_indices);
let db_values: Vec<_> = values
.into_iter()
.map(|v| match v.to_db_field_value() {
DbFieldValue::Value(val) => val,
DbFieldValue::Auto => {
panic!("Expected Value but found Auto in bulk insert")
}
})
.map(SimpleExpr::Value)
.collect();

assert!(!db_values.is_empty(), "expected at least 1 value field");
Copilot AI commented on Nov 29, 2025:

This assertion should never trigger given the check at line 1122 that returns an error if num_value_fields == 0. However, a failing assert! panics at runtime, even in release builds, rather than surfacing a recoverable error.

Since this is a safety-critical validation, consider using debug_assert! if this is truly unreachable, or returning a proper error if there's any scenario where this could occur:

Suggested change
assert!(!db_values.is_empty(), "expected at least 1 value field");
if db_values.is_empty() {
    return Err(DatabaseError::BulkCreateNoValueColumns);
}
insert_statement.values(db_values)?;
}

if update {
insert_statement.on_conflict(
OnConflict::column(T::PRIMARY_KEY_NAME)
.update_columns(value_identifiers.iter().copied())
.to_owned(),
);
}
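// With `update` set, rows that conflict on the primary key have all of their
// value columns overwritten, turning the INSERT into an upsert.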

if auto_col_ids.is_empty() {
self.execute_statement(&insert_statement).await?;
} else if self.supports_returning() {
// PostgreSQL/SQLite: Use RETURNING clause
insert_statement.returning(ReturningClause::Columns(auto_col_identifiers.to_vec()));

let rows = self.fetch_all(&insert_statement).await?;

for (instance, row) in chunk.iter_mut().zip(rows) {
instance.update_from_db(row, auto_col_ids)?;
}
} else {
// MySQL: Use LAST_INSERT_ID() and fetch rows
let result = self.execute_statement(&insert_statement).await?;
let first_id = result
.last_inserted_row_id
.expect("expected last inserted row ID if RETURNING clause is not supported");

// Fetch the inserted rows using a SELECT query
// Note: This assumes IDs are consecutive, which is generally safe for
// auto_increment but could fail with concurrent inserts
let query = sea_query::Query::select()
.from(T::TABLE_NAME)
.columns(auto_col_identifiers.iter().cloned())
.and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).gte(first_id).and(
sea_query::Expr::col(T::PRIMARY_KEY_NAME).lt(first_id + chunk.len() as u64),
))
.order_by(T::PRIMARY_KEY_NAME, sea_query::Order::Asc)
.to_owned();
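// The rows come back ordered by primary key, so they pair up positionally
// with `chunk` in the zip below (relying on the consecutive-ID assumption
// noted above).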

let rows = self.fetch_all(&query).await?;

for (instance, row) in chunk.iter_mut().zip(rows) {
instance.update_from_db(row, auto_col_ids)?;
}
}

if update {
trace!(count = chunk.len(), "Inserted or updated rows");
} else {
trace!(count = chunk.len(), "Inserted rows");
}

Ok(())
}

/// Executes the given query and returns the results converted to the model
/// type.
///
@@ -1271,6 +1529,25 @@ pub trait DatabaseBackend: Send + Sync {
/// there was a problem with the database connection.
async fn update<T: Model>(&self, data: &mut T) -> Result<()>;

/// Bulk inserts multiple rows into the database.
///
/// # Errors
///
/// This method can return an error if the rows could not be inserted into
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()>;

/// Bulk inserts multiple rows into the database, or updates existing rows
/// if they already exist.
///
/// # Errors
///
/// This method can return an error if the rows could not be inserted into
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()>;

/// Executes a query and returns the results converted to the model type.
///
/// # Errors
@@ -1339,6 +1616,14 @@ impl DatabaseBackend for Database {
Database::update(self, data).await
}

async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
Database::bulk_insert(self, data).await
}

async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
Database::bulk_insert_or_update(self, data).await
}

async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
Database::query(self, query).await
}
@@ -1370,6 +1655,14 @@ impl DatabaseBackend for std::sync::Arc<Database> {
Database::update(self, data).await
}

async fn bulk_insert<T: Model>(&self, data: &mut [T]) -> Result<()> {
Database::bulk_insert(self, data).await
}

async fn bulk_insert_or_update<T: Model>(&self, data: &mut [T]) -> Result<()> {
Database::bulk_insert_or_update(self, data).await
}

async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
Database::query(self, query).await
}