Commit f088176

optimize batch insert query for reward accounts (#3836)
* WIP: batch inserts
* add tests
* support batch inserts for sqlite
* use batch insert for state update
* clippy
* cleanup
* lint
* use hashmap
* move test to slow-tests
* fix test attempt1
* test w 20k accounts

---------

Co-authored-by: sveitser <[email protected]>
1 parent 5eb5619 commit f088176

12 files changed: +1106 −331 lines

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

hotshot-query-service/src/data_source/extension.rs

Lines changed: 13 additions & 0 deletions
```diff
@@ -443,6 +443,19 @@ where
             .insert_merkle_nodes(path, traversal_path, block_number)
             .await
     }
+
+    async fn insert_merkle_nodes_batch(
+        &mut self,
+        proofs: Vec<(
+            MerkleProof<State::Entry, State::Key, State::T, ARITY>,
+            Vec<usize>,
+        )>,
+        block_number: u64,
+    ) -> anyhow::Result<()> {
+        self.data_source
+            .insert_merkle_nodes_batch(proofs, block_number)
+            .await
+    }
 }
 
 #[async_trait]
```
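
The new trait method is pure delegation to the wrapped data source, mirroring the existing `insert_merkle_nodes` passthrough above it. A minimal sketch of a call site (the `ds` handle and the `updated_accounts` collection are illustrative assumptions, not part of this diff):

```rust
// Hypothetical caller: instead of one insert_merkle_nodes() round trip per
// proof, gather every (proof, traversal_path) pair produced by a block's
// state update and persist them in a single batched call.
let mut proofs = Vec::new();
for (proof, traversal_path) in updated_accounts {
    proofs.push((proof, traversal_path));
}
ds.insert_merkle_nodes_batch(proofs, block_number).await?;
```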

hotshot-query-service/src/data_source/storage/sql/queries/state.rs

Lines changed: 275 additions & 33 deletions
```diff
@@ -17,6 +17,8 @@ use std::{
     sync::Arc,
 };
 
+#[cfg(not(feature = "embedded-db"))]
+use anyhow::Context;
 use ark_serialize::CanonicalDeserialize;
 use async_trait::async_trait;
 use futures::stream::TryStreamExt;
@@ -341,6 +343,7 @@ impl<Mode: TransactionMode> Transaction<Mode> {
 }
 
 // TODO: create a generic upsert function with retries that returns the column
+#[cfg(feature = "embedded-db")]
 pub(crate) fn build_hash_batch_insert(
     hashes: &[Vec<u8>],
 ) -> QueryResult<(QueryBuilder<'_>, String)> {
@@ -357,6 +360,165 @@ pub(crate) fn build_hash_batch_insert(
     Ok((query, sql))
 }
 
+/// Batch insert hashes using UNNEST for large batches (postgres only).
+/// Returns a map from hash bytes to their database IDs.
+#[cfg(not(feature = "embedded-db"))]
+pub(crate) async fn batch_insert_hashes(
+    hashes: Vec<Vec<u8>>,
+    tx: &mut Transaction<Write>,
+) -> QueryResult<HashMap<Vec<u8>, i32>> {
+    if hashes.is_empty() {
+        return Ok(HashMap::new());
+    }
+
+    // Use UNNEST-based batch insert (more efficient and avoids parameter limits)
+    let sql = "INSERT INTO hash(value) SELECT * FROM UNNEST($1::bytea[]) ON CONFLICT (value) DO \
+               UPDATE SET value = EXCLUDED.value RETURNING value, id";
+
+    let result: HashMap<Vec<u8>, i32> = sqlx::query_as(sql)
+        .bind(&hashes)
+        .fetch(tx.as_mut())
+        .try_collect()
+        .await
+        .map_err(|e| QueryError::Error {
+            message: format!("batch hash insert failed: {e}"),
+        })?;
+
+    Ok(result)
+}
+
+/// Type alias for a merkle proof with its traversal path.
+pub(crate) type ProofWithPath<Entry, Key, T, const ARITY: usize> =
+    (MerkleProof<Entry, Key, T, ARITY>, Vec<usize>);
+
+/// Collects nodes and hashes from merkle proofs.
+/// Returns (nodes, hashes) for batch insertion.
+pub(crate) fn collect_nodes_from_proofs<Entry, Key, T, const ARITY: usize>(
+    proofs: &[ProofWithPath<Entry, Key, T, ARITY>],
+) -> QueryResult<(Vec<NodeWithHashes>, HashSet<Vec<u8>>)>
+where
+    Entry: jf_merkle_tree_compat::Element + serde::Serialize,
+    Key: jf_merkle_tree_compat::Index + serde::Serialize,
+    T: jf_merkle_tree_compat::NodeValue,
+{
+    let mut nodes = Vec::new();
+    let mut hashes = HashSet::new();
+
+    for (proof, traversal_path) in proofs {
+        let pos = &proof.pos;
+        let path = &proof.proof;
+        let mut trav_path = traversal_path.iter().map(|n| *n as i32);
+
+        for node in path.iter() {
+            match node {
+                MerkleNode::Empty => {
+                    let index =
+                        serde_json::to_value(pos.clone()).map_err(|e| QueryError::Error {
+                            message: format!("malformed merkle position: {e}"),
+                        })?;
+                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
+                    nodes.push((
+                        Node {
+                            path: node_path.into(),
+                            idx: Some(index),
+                            ..Default::default()
+                        },
+                        None,
+                        [0_u8; 32].to_vec(),
+                    ));
+                    hashes.insert([0_u8; 32].to_vec());
+                },
+                MerkleNode::ForgettenSubtree { .. } => {
+                    return Err(QueryError::Error {
+                        message: "Node in the Merkle path contains a forgotten subtree".into(),
+                    });
+                },
+                MerkleNode::Leaf { value, pos, elem } => {
+                    let mut leaf_commit = Vec::new();
+                    value.serialize_compressed(&mut leaf_commit).map_err(|e| {
+                        QueryError::Error {
+                            message: format!("malformed merkle leaf commitment: {e}"),
+                        }
+                    })?;
+
+                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
+
+                    let index =
+                        serde_json::to_value(pos.clone()).map_err(|e| QueryError::Error {
+                            message: format!("malformed merkle position: {e}"),
+                        })?;
+                    let entry = serde_json::to_value(elem).map_err(|e| QueryError::Error {
+                        message: format!("malformed merkle element: {e}"),
+                    })?;
+
+                    nodes.push((
+                        Node {
+                            path: node_path.into(),
+                            idx: Some(index),
+                            entry: Some(entry),
+                            ..Default::default()
+                        },
+                        None,
+                        leaf_commit.clone(),
+                    ));
+
+                    hashes.insert(leaf_commit);
+                },
+                MerkleNode::Branch { value, children } => {
+                    let mut branch_hash = Vec::new();
+                    value.serialize_compressed(&mut branch_hash).map_err(|e| {
+                        QueryError::Error {
+                            message: format!("malformed merkle branch hash: {e}"),
+                        }
+                    })?;
+
+                    let mut children_bitvec = BitVec::new();
+                    let mut children_values = Vec::new();
+                    for child in children {
+                        let child = child.as_ref();
+                        match child {
+                            MerkleNode::Empty => {
+                                children_bitvec.push(false);
+                            },
+                            MerkleNode::Branch { value, .. }
+                            | MerkleNode::Leaf { value, .. }
+                            | MerkleNode::ForgettenSubtree { value } => {
+                                let mut hash = Vec::new();
+                                value.serialize_compressed(&mut hash).map_err(|e| {
+                                    QueryError::Error {
+                                        message: format!("malformed merkle node hash: {e}"),
+                                    }
+                                })?;
+
+                                children_values.push(hash);
+                                children_bitvec.push(true);
+                            },
+                        }
+                    }
+
+                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
+                    nodes.push((
+                        Node {
+                            path: node_path.into(),
+                            children: None,
+                            children_bitvec: Some(children_bitvec),
+                            ..Default::default()
+                        },
+                        Some(children_values.clone()),
+                        branch_hash.clone(),
+                    ));
+                    hashes.insert(branch_hash);
+                    hashes.extend(children_values);
+                },
+            }
+
+            trav_path.next();
+        }
+    }
+
+    Ok((nodes, hashes))
+}
+
 // Represents a row in a state table
 #[derive(Debug, Default, Clone)]
 pub(crate) struct Node {
```
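
The `batch_insert_hashes` helper is the postgres-only counterpart to `build_hash_batch_insert`, which is now gated to the sqlite build. Two details of the query are worth spelling out: binding a single `bytea[]` array keeps the statement at one bind parameter regardless of batch size, and the self-assignment in `DO UPDATE SET value = EXCLUDED.value` (rather than `DO NOTHING`) forces `RETURNING` to emit a row for hashes that already exist, so the returned id map is complete. A standalone sketch of the same pattern, assuming a `hash(id SERIAL PRIMARY KEY, value BYTEA UNIQUE)` table and a plain `PgConnection` in place of the crate's `Transaction` wrapper:

```rust
use std::collections::HashMap;

use sqlx::PgConnection;

// Standalone sketch of the UNNEST pattern from batch_insert_hashes. The whole
// batch is bound as one bytea[] parameter, so the statement never approaches
// Postgres's bind-parameter cap the way a multi-row VALUES list would.
async fn insert_hashes_sketch(
    conn: &mut PgConnection,
    hashes: Vec<Vec<u8>>,
) -> sqlx::Result<HashMap<Vec<u8>, i32>> {
    // The self-assignment in DO UPDATE (instead of DO NOTHING) makes
    // conflicting rows visible to RETURNING, so pre-existing hashes also
    // appear in the returned (value, id) pairs.
    let rows: Vec<(Vec<u8>, i32)> = sqlx::query_as(
        "INSERT INTO hash(value) SELECT * FROM UNNEST($1::bytea[]) \
         ON CONFLICT (value) DO UPDATE SET value = EXCLUDED.value \
         RETURNING value, id",
    )
    .bind(&hashes)
    .fetch_all(conn)
    .await?;
    Ok(rows.into_iter().collect())
}
```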
```diff
@@ -369,6 +531,10 @@ pub(crate) struct Node {
     pub(crate) entry: Option<JsonValue>,
 }
 
+/// Type alias for node data with optional children hashes and node hash.
+/// Used during batch collection before database insertion.
+pub(crate) type NodeWithHashes = (Node, Option<Vec<Vec<u8>>>, Vec<u8>);
+
 #[cfg(feature = "embedded-db")]
 impl From<sqlx::sqlite::SqliteRow> for Node {
     fn from(row: sqlx::sqlite::SqliteRow) -> Self {
```
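
`NodeWithHashes` carries each collected `Node` together with the raw hashes it references, because at collection time the `hash_id` foreign keys are not yet known. The call site that ties the pieces together is in a file not rendered on this page; a hedged sketch of how they plausibly compose (the `children` encoding and the `Node::upsert` name are assumptions, not shown in this diff):

```rust
// Hypothetical glue inside insert_merkle_nodes_batch:
// 1. walk all proofs once, collecting nodes plus every referenced hash;
// 2. batch-insert the hashes and get back a hash -> id map;
// 3. rewrite each node's hash_id (and children ids) before the node upsert.
let (nodes_with_hashes, hashes) = collect_nodes_from_proofs(&proofs)?;
let hash_ids = batch_insert_hashes(hashes.into_iter().collect(), tx).await?;

let mut nodes = Vec::with_capacity(nodes_with_hashes.len());
for (mut node, children_hashes, node_hash) in nodes_with_hashes {
    node.hash_id = hash_ids[&node_hash];
    if let Some(children) = children_hashes {
        let ids: Vec<i32> = children.iter().map(|h| hash_ids[h]).collect();
        node.children = Some(serde_json::json!(ids)); // assumed jsonb encoding
    }
    nodes.push(node);
}
Node::upsert(name, nodes, tx).await?;
```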
```diff
@@ -409,40 +575,116 @@ impl Node {
         nodes: impl IntoIterator<Item = Self>,
         tx: &mut Transaction<Write>,
     ) -> anyhow::Result<()> {
-        tx.upsert(
-            name,
-            [
-                "path",
-                "created",
-                "hash_id",
-                "children",
-                "children_bitvec",
-                "idx",
-                "entry",
-            ],
-            ["path", "created"],
-            nodes.into_iter().map(|n| {
-                #[cfg(feature = "embedded-db")]
-                let children_bitvec: Option<String> = n
-                    .children_bitvec
-                    .clone()
-                    .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect());
-
-                #[cfg(not(feature = "embedded-db"))]
-                let children_bitvec = n.children_bitvec.clone();
-
-                (
-                    n.path.clone(),
-                    n.created,
-                    n.hash_id,
-                    n.children.clone(),
-                    children_bitvec,
-                    n.idx.clone(),
-                    n.entry.clone(),
+        let nodes: Vec<_> = nodes.into_iter().collect();
+
+        // Use UNNEST-based batch insert for postgres (more efficient and avoids parameter limits)
+        #[cfg(not(feature = "embedded-db"))]
+        return Self::upsert_batch_unnest(name, nodes, tx).await;
+
+        #[cfg(feature = "embedded-db")]
+        {
+            for node_chunk in nodes.chunks(20) {
+                let rows: Vec<_> = node_chunk
+                    .iter()
+                    .map(|n| {
+                        let children_bitvec: Option<String> = n
+                            .children_bitvec
+                            .clone()
+                            .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect());
+
+                        (
+                            n.path.clone(),
+                            n.created,
+                            n.hash_id,
+                            n.children.clone(),
+                            children_bitvec,
+                            n.idx.clone(),
+                            n.entry.clone(),
+                        )
+                    })
+                    .collect();
+
+                tx.upsert(
+                    name,
+                    [
+                        "path",
+                        "created",
+                        "hash_id",
+                        "children",
+                        "children_bitvec",
+                        "idx",
+                        "entry",
+                    ],
+                    ["path", "created"],
+                    rows,
                 )
-            }),
-        )
-        .await
+                .await?;
+            }
+            Ok(())
+        }
+    }
+
+    #[cfg(not(feature = "embedded-db"))]
+    async fn upsert_batch_unnest(
+        name: &str,
+        nodes: Vec<Self>,
+        tx: &mut Transaction<Write>,
+    ) -> anyhow::Result<()> {
+        if nodes.is_empty() {
+            return Ok(());
+        }
+
+        // Deduplicate nodes by (path, created) - keep the last occurrence
+        // This is required because UNNEST + ON CONFLICT cannot handle duplicates in the same batch
+        let mut deduped = HashMap::new();
+        for node in nodes {
+            deduped.insert((node.path.to_string(), node.created), node);
+        }
+
+        let mut paths = Vec::with_capacity(deduped.len());
+        let mut createds = Vec::with_capacity(deduped.len());
+        let mut hash_ids = Vec::with_capacity(deduped.len());
+        let mut childrens = Vec::with_capacity(deduped.len());
+        let mut children_bitvecs = Vec::with_capacity(deduped.len());
+        let mut idxs = Vec::with_capacity(deduped.len());
+        let mut entries = Vec::with_capacity(deduped.len());
+
+        for node in deduped.into_values() {
+            paths.push(node.path);
+            createds.push(node.created);
+            hash_ids.push(node.hash_id);
+            childrens.push(node.children);
+            children_bitvecs.push(node.children_bitvec);
+            idxs.push(node.idx);
+            entries.push(node.entry);
+        }
+
+        let sql = format!(
+            r#"
+            INSERT INTO "{name}" (path, created, hash_id, children, children_bitvec, idx, entry)
+            SELECT * FROM UNNEST($1::jsonb[], $2::bigint[], $3::int[], $4::jsonb[], $5::bit varying[], $6::jsonb[], $7::jsonb[])
+            ON CONFLICT (path, created) DO UPDATE SET
+                hash_id = EXCLUDED.hash_id,
+                children = EXCLUDED.children,
+                children_bitvec = EXCLUDED.children_bitvec,
+                idx = EXCLUDED.idx,
+                entry = EXCLUDED.entry
+            "#
+        );
+
+        sqlx::query(&sql)
+            .bind(&paths)
+            .bind(&createds)
+            .bind(&hash_ids)
+            .bind(&childrens)
+            .bind(&children_bitvecs)
+            .bind(&idxs)
+            .bind(&entries)
+            .execute(tx.as_mut())
+            .await
+            .context("batch upsert with UNNEST failed")?;
+
+        Ok(())
     }
 }
```
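
Two design notes on the new upsert path. On the sqlite side, rows are chunked into groups of 20, presumably to keep each statement's bind-parameter count small (7 columns × 20 rows = 140 parameters, well under SQLite's historical default limit of 999 bind variables). On the Postgres side, the batch must be deduplicated on the conflict key before binding, because a single `INSERT ... ON CONFLICT DO UPDATE` statement errors out ("cannot affect row a second time") if the same (path, created) pair appears twice. A self-contained toy version of that last-write-wins dedup (the `Row` type is a stand-in, not the real `Node`):

```rust
use std::collections::HashMap;

// Stand-in for the real Node: only the conflict key and one payload column.
#[derive(Debug, Clone, PartialEq)]
struct Row {
    path: String,
    created: i64,
    hash_id: i32,
}

// Keep only the last occurrence per (path, created), mirroring the upsert's
// "later write wins" semantics, so the batched statement never touches the
// same row twice.
fn dedup_last_wins(rows: Vec<Row>) -> Vec<Row> {
    let mut map: HashMap<(String, i64), Row> = HashMap::new();
    for row in rows {
        map.insert((row.path.clone(), row.created), row);
    }
    map.into_values().collect()
}

fn main() {
    let rows = vec![
        Row { path: "0.1".into(), created: 7, hash_id: 1 },
        Row { path: "0.1".into(), created: 7, hash_id: 2 }, // duplicate key
    ];
    let deduped = dedup_last_wins(rows);
    assert_eq!(deduped.len(), 1);
    assert_eq!(deduped[0].hash_id, 2); // last occurrence survived
}
```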
