Skip to content

Commit 6bf3acb

Browse files
authored
feat(torii-grpc): flate2 sql cursors (#32)
* feat(torii-grpc): flate2 sql cursors * fmt * lock * fix having * use urlsafe base64 no pad * f
1 parent 1850d4b commit 6bf3acb

File tree

4 files changed

+73
-53
lines changed

4 files changed

+73
-53
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/grpc/server/src/lib.rs

Lines changed: 19 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ use std::str::FromStr;
1111
use std::sync::Arc;
1212
use std::time::Duration;
1313

14-
use base64::prelude::BASE64_STANDARD_NO_PAD;
15-
use base64::Engine;
1614
use crypto_bigint::U256;
1715
use dojo_types::naming::compute_selector_from_tag;
1816
use dojo_types::primitive::Primitive;
@@ -44,7 +42,7 @@ use torii_proto::error::ProtoError;
4442
use torii_sqlite::cache::ModelCache;
4543
use torii_sqlite::constants::SQL_DEFAULT_LIMIT;
4644
use torii_sqlite::error::{ParseError, QueryError};
47-
use torii_sqlite::model::{fetch_entities, map_row_to_ty};
45+
use torii_sqlite::model::{decode_cursor, encode_cursor, fetch_entities, map_row_to_ty};
4846
use torii_sqlite::types::{Page, Pagination, PaginationDirection, Token, TokenBalance};
4947
use torii_sqlite::utils::u256_to_sql_string;
5048
use tower_http::cors::{AllowOrigin, CorsLayer};
@@ -245,12 +243,7 @@ impl DojoWorld {
245243

246244
// Add cursor condition if present
247245
if let Some(ref cursor) = pagination.cursor {
248-
let decoded_cursor = String::from_utf8(
249-
BASE64_STANDARD_NO_PAD
250-
.decode(cursor)
251-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
252-
)
253-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?;
246+
let decoded_cursor = decode_cursor(cursor)?;
254247

255248
let operator = match pagination.direction {
256249
PaginationDirection::Forward => ">=",
@@ -260,7 +253,7 @@ impl DojoWorld {
260253
bind_values.push(decoded_cursor);
261254
}
262255

263-
let where_sql = if !conditions.is_empty() {
256+
let where_clause = if !conditions.is_empty() {
264257
format!("WHERE {}", conditions.join(" AND "))
265258
} else {
266259
String::new()
@@ -274,12 +267,17 @@ impl DojoWorld {
274267
group_concat({model_relation_table}.model_id) as model_ids
275268
FROM {table}
276269
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
277-
{where_sql}
270+
{where_clause}
278271
GROUP BY {table}.event_id
279-
HAVING {having_clause}
272+
{}
280273
ORDER BY {table}.event_id {order_direction}
281274
LIMIT ?
282-
"
275+
",
276+
if !having_clause.is_empty() {
277+
format!("HAVING {}", having_clause)
278+
} else {
279+
String::new()
280+
}
283281
);
284282

285283
let mut query = sqlx::query_as(&query_str);
@@ -329,7 +327,8 @@ impl DojoWorld {
329327
let next_cursor = if has_more {
330328
db_entities
331329
.last()
332-
.map(|(_, _, _, event_id, _)| BASE64_STANDARD_NO_PAD.encode(event_id))
330+
.map(|(_, _, _, event_id, _)| encode_cursor(&event_id))
331+
.transpose()?
333332
} else {
334333
None
335334
};
@@ -767,14 +766,7 @@ impl DojoWorld {
767766
}
768767

769768
if let Some(cursor) = cursor {
770-
bind_values.push(
771-
String::from_utf8(
772-
BASE64_STANDARD_NO_PAD
773-
.decode(cursor)
774-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
775-
)
776-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
777-
);
769+
bind_values.push(decode_cursor(&cursor)?);
778770
conditions.push("id >= ?".to_string());
779771
}
780772

@@ -792,7 +784,7 @@ impl DojoWorld {
792784

793785
let mut tokens: Vec<Token> = query.fetch_all(&self.pool).await?;
794786
let next_cursor = if tokens.len() > limit.unwrap_or(SQL_DEFAULT_LIMIT as u32) as usize {
795-
BASE64_STANDARD_NO_PAD.encode(tokens.pop().unwrap().id.to_string().as_bytes())
787+
encode_cursor(&tokens.pop().unwrap().id)?
796788
} else {
797789
String::new()
798790
};
@@ -838,14 +830,7 @@ impl DojoWorld {
838830
}
839831

840832
if let Some(cursor) = cursor {
841-
bind_values.push(
842-
String::from_utf8(
843-
BASE64_STANDARD_NO_PAD
844-
.decode(cursor)
845-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
846-
)
847-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
848-
);
833+
bind_values.push(decode_cursor(&cursor)?);
849834
conditions.push("id >= ?".to_string());
850835
}
851836

@@ -863,7 +848,7 @@ impl DojoWorld {
863848

864849
let mut balances: Vec<TokenBalance> = query.fetch_all(&self.pool).await?;
865850
let next_cursor = if balances.len() > limit.unwrap_or(SQL_DEFAULT_LIMIT as u32) as usize {
866-
BASE64_STANDARD_NO_PAD.encode(balances.pop().unwrap().id.to_string().as_bytes())
851+
encode_cursor(&balances.pop().unwrap().id)?
867852
} else {
868853
String::new()
869854
};
@@ -996,14 +981,7 @@ impl DojoWorld {
996981

997982
if !query.cursor.is_empty() {
998983
conditions.push("id >= ?");
999-
bind_values.push(
1000-
String::from_utf8(
1001-
BASE64_STANDARD_NO_PAD
1002-
.decode(query.cursor.clone())
1003-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
1004-
)
1005-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
1006-
);
984+
bind_values.push(decode_cursor(&query.cursor)?);
1007985
}
1008986

1009987
let mut events_query = r#"
@@ -1027,7 +1005,7 @@ impl DojoWorld {
10271005
row_events.fetch_all(&self.pool).await?;
10281006

10291007
let next_cursor = if row_events.len() > (limit - 1) as usize {
1030-
BASE64_STANDARD_NO_PAD.encode(row_events.pop().unwrap().0.to_string().as_bytes())
1008+
encode_cursor(&row_events.pop().unwrap().0)?
10311009
} else {
10321010
String::new()
10331011
};

crates/sqlite/sqlite/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ tokio = { version = "1.32.0", features = [ "macros", "sync" ], default-features
3939
ipfs-api-backend-hyper.workspace = true
4040
tokio-util.workspace = true
4141
tracing.workspace = true
42+
flate2.workspace = true
4243

4344
[dev-dependencies]
4445
dojo-test-utils.workspace = true

crates/sqlite/sqlite/src/model.rs

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
2+
use base64::Engine;
3+
use flate2::read::DeflateDecoder;
4+
use flate2::write::DeflateEncoder;
5+
use flate2::Compression;
16
use std::collections::HashSet;
7+
use std::io::prelude::*;
28
use std::str::FromStr;
39

410
use async_trait::async_trait;
5-
use base64::engine::general_purpose;
6-
use base64::Engine;
711
use crypto_bigint::U256;
812
use dojo_types::primitive::{Primitive, PrimitiveError};
913
use dojo_types::schema::Ty;
@@ -386,14 +390,11 @@ pub async fn fetch_entities(
386390
.cursor
387391
.as_ref()
388392
.map(|cursor_str| {
389-
let decoded = general_purpose::STANDARD_NO_PAD
390-
.decode(cursor_str)
391-
.map_err(|e| Error::QueryError(QueryError::InvalidCursor(e.to_string())))?;
392-
String::from_utf8(decoded)
393-
.map_err(|e| Error::QueryError(QueryError::InvalidCursor(e.to_string())))
394-
.map(|s| s.split('/').map(|s| s.to_string()).collect())
393+
let decompressed_str = decode_cursor(cursor_str)?;
394+
Ok(decompressed_str.split('/').map(|s| s.to_string()).collect())
395395
})
396-
.transpose()?;
396+
.transpose()
397+
.map_err(|e: Error| Error::QueryError(QueryError::InvalidCursor(e.to_string())))?;
397398

398399
// Build cursor conditions
399400
let (cursor_conditions, cursor_binds) =
@@ -476,9 +477,8 @@ pub async fn fetch_entities(
476477
// Replace generation of next cursor to only when there are more pages
477478
if has_more_pages {
478479
if let Some(last_row) = all_rows.last() {
479-
let cursor_values = build_cursor_values(&pagination, last_row)?;
480-
next_cursor =
481-
Some(general_purpose::STANDARD_NO_PAD.encode(cursor_values.join("/").as_bytes()));
480+
let cursor_values_str = build_cursor_values(&pagination, last_row)?.join("/");
481+
next_cursor = Some(encode_cursor(&cursor_values_str)?);
482482
}
483483
}
484484

@@ -602,3 +602,43 @@ fn build_cursor_values(pagination: &Pagination, row: &SqliteRow) -> Result<Vec<S
602602
Ok(values)
603603
}
604604
}
605+
606+
/// Compresses a string using Deflate and then encodes it using Base64 (no padding).
607+
pub fn encode_cursor(value: &str) -> Result<String, Error> {
608+
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
609+
encoder.write_all(value.as_bytes()).map_err(|e| {
610+
Error::QueryError(QueryError::InvalidCursor(format!(
611+
"Cursor compression error: {}",
612+
e
613+
)))
614+
})?;
615+
let compressed_bytes = encoder.finish().map_err(|e| {
616+
Error::QueryError(QueryError::InvalidCursor(format!(
617+
"Cursor compression finish error: {}",
618+
e
619+
)))
620+
})?;
621+
622+
Ok(BASE64_URL_SAFE_NO_PAD.encode(&compressed_bytes))
623+
}
624+
625+
/// Decodes a Base64 (no padding) string and then decompresses it using Deflate.
626+
pub fn decode_cursor(encoded_cursor: &str) -> Result<String, Error> {
627+
let compressed_cursor_bytes = BASE64_URL_SAFE_NO_PAD.decode(encoded_cursor).map_err(|e| {
628+
Error::QueryError(QueryError::InvalidCursor(format!(
629+
"Base64 decode error: {}",
630+
e
631+
)))
632+
})?;
633+
634+
let mut decoder = DeflateDecoder::new(&compressed_cursor_bytes[..]);
635+
let mut decompressed_str = String::new();
636+
decoder.read_to_string(&mut decompressed_str).map_err(|e| {
637+
Error::QueryError(QueryError::InvalidCursor(format!(
638+
"Decompression error: {}",
639+
e
640+
)))
641+
})?;
642+
643+
Ok(decompressed_str)
644+
}

0 commit comments

Comments
 (0)