Skip to content

Commit 8c0288c

Browse files
authored
fix(torii-grpc): models filter for historical (#29)
* fix(torii-grpc): models filter for historical * include having clause in historical * fix historical * fmt
1 parent 51378b2 commit 8c0288c

File tree

2 files changed

+69
-48
lines changed

2 files changed

+69
-48
lines changed

bin/torii/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ mod cli;
2626
async fn main() -> anyhow::Result<()> {
2727
// Set the global tracing subscriber
2828
let filter_layer =
29-
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info,torii=info")); // Adjust default filter if needed
29+
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("torii=info"));
3030

3131
let indicatif_layer = IndicatifLayer::new();
3232

crates/grpc/server/src/lib.rs

Lines changed: 68 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ impl DojoWorld {
221221
table: &str,
222222
model_relation_table: &str,
223223
where_clause: &str,
224+
having_clause: &str,
224225
mut bind_values: Vec<String>,
225226
pagination: Pagination,
226227
) -> Result<Page<proto::types::Entity>, Error> {
@@ -236,87 +237,105 @@ impl DojoWorld {
236237
conditions.push(where_clause.to_string());
237238
}
238239

240+
let order_direction = match pagination.direction {
241+
PaginationDirection::Forward => "ASC",
242+
PaginationDirection::Backward => "DESC",
243+
};
244+
239245
// Add cursor condition if present
240246
if let Some(ref cursor) = pagination.cursor {
241-
match pagination.direction {
242-
PaginationDirection::Forward => {
243-
conditions.push(format!("{table}.event_id >= ?"));
244-
}
245-
PaginationDirection::Backward => {
246-
conditions.push(format!("{table}.event_id <= ?"));
247-
}
248-
}
249-
bind_values.push(
250-
String::from_utf8(
251-
BASE64_STANDARD_NO_PAD
252-
.decode(cursor)
253-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
254-
)
255-
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
256-
);
247+
let decoded_cursor = String::from_utf8(
248+
BASE64_STANDARD_NO_PAD
249+
.decode(cursor)
250+
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?,
251+
)
252+
.map_err(|e| QueryError::InvalidCursor(e.to_string()))?;
253+
254+
let operator = match pagination.direction {
255+
PaginationDirection::Forward => ">=",
256+
PaginationDirection::Backward => "<=",
257+
};
258+
conditions.push(format!("{table}.event_id {operator} ?"));
259+
bind_values.push(decoded_cursor);
257260
}
258261

259-
let where_clause = if !conditions.is_empty() {
262+
let where_sql = if !conditions.is_empty() {
260263
format!("WHERE {}", conditions.join(" AND "))
261264
} else {
262265
String::new()
263266
};
264267

265-
let order_direction = match pagination.direction {
266-
PaginationDirection::Forward => "ASC",
267-
PaginationDirection::Backward => "DESC",
268-
};
268+
let limit = pagination.limit.unwrap_or(100);
269+
let query_limit = limit + 1;
269270

270-
let query = format!(
271+
let query_str = format!(
271272
"SELECT {table}.id, {table}.data, {table}.model_id, {table}.event_id, \
272273
group_concat({model_relation_table}.model_id) as model_ids
273274
FROM {table}
274275
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
275-
{where_clause}
276+
{where_sql}
276277
GROUP BY {table}.event_id
278+
HAVING {having_clause}
277279
ORDER BY {table}.event_id {order_direction}
278280
LIMIT ?
279281
"
280282
);
281283

282-
let mut query = sqlx::query_as(&query);
284+
let mut query = sqlx::query_as(&query_str);
283285
for value in bind_values {
284286
query = query.bind(value);
285287
}
286-
query = query.bind(pagination.limit.unwrap_or(100) + 1);
288+
query = query.bind(query_limit);
287289

288290
let db_entities: Vec<(String, String, String, String, String)> =
289291
query.fetch_all(&self.pool).await?;
290292

291-
let mut entities = Vec::new();
292-
for (id, data, model_id, _, _) in &db_entities[..db_entities.len().saturating_sub(1)] {
293-
let hashed_keys = Felt::from_str(id)
294-
.map_err(ParseError::FromStr)?
295-
.to_bytes_be()
296-
.to_vec();
297-
let model = self
298-
.model_cache
299-
.model(&Felt::from_str(model_id).map_err(ParseError::FromStr)?)
300-
.await?;
301-
let mut schema = model.schema;
302-
schema.from_json_value(serde_json::from_str(data).map_err(ParseError::FromJsonStr)?)?;
303-
304-
entities.push(proto::types::Entity {
305-
hashed_keys,
306-
models: vec![schema.as_struct().unwrap().clone().into()],
307-
});
308-
}
293+
let has_more = db_entities.len() == query_limit as usize;
294+
let results_to_take = if has_more {
295+
limit as usize
296+
} else {
297+
db_entities.len()
298+
};
299+
300+
let entities = db_entities
301+
.iter()
302+
.take(results_to_take)
303+
.map(|(id, data, model_id, _, _)| async {
304+
let hashed_keys = Felt::from_str(id)
305+
.map_err(ParseError::FromStr)?
306+
.to_bytes_be()
307+
.to_vec();
308+
let model = self
309+
.model_cache
310+
.model(&Felt::from_str(model_id).map_err(ParseError::FromStr)?)
311+
.await?;
312+
let mut schema = model.schema;
313+
schema.from_json_value(
314+
serde_json::from_str(data).map_err(ParseError::FromJsonStr)?,
315+
)?;
316+
317+
Ok::<_, Error>(proto::types::Entity {
318+
hashed_keys,
319+
models: vec![schema.as_struct().unwrap().clone().into()],
320+
})
321+
})
322+
// Collect the futures into a Vec
323+
.collect::<Vec<_>>();
309324

310-
// Get the next cursor from the last item's event_id if we fetched an extra one
311-
let next_cursor = if db_entities.len() > entities.len() {
312-
Some(db_entities.last().unwrap().3.clone()) // event_id is at index 3
325+
// Execute all the async mapping operations concurrently
326+
let entities: Vec<proto::types::Entity> = futures::future::try_join_all(entities).await?;
327+
328+
let next_cursor = if has_more {
329+
db_entities
330+
.last()
331+
.map(|(_, _, _, event_id, _)| BASE64_STANDARD_NO_PAD.encode(event_id))
313332
} else {
314333
None
315334
};
316335

317336
Ok(Page {
318337
items: entities,
319-
next_cursor: next_cursor.map(|cursor| BASE64_STANDARD_NO_PAD.encode(cursor)),
338+
next_cursor,
320339
})
321340
}
322341

@@ -377,6 +396,7 @@ impl DojoWorld {
377396
table,
378397
model_relation_table,
379398
&where_clause,
399+
&having_clause,
380400
bind_values,
381401
pagination,
382402
)
@@ -472,6 +492,7 @@ impl DojoWorld {
472492
table,
473493
model_relation_table,
474494
&where_clause,
495+
&having_clause,
475496
bind_values,
476497
pagination,
477498
)

0 commit comments

Comments
 (0)