|
1 | 1 | use crate::db::DbPool;
|
2 | 2 | use bb8::RunError;
|
3 |
| -use bb8_postgres::tokio_postgres::{ |
4 |
| - types::{accepts, FromSql, Json, ToSql, Type}, |
5 |
| - Row, |
6 |
| -}; |
7 |
| -use chrono::{DateTime, Utc}; |
8 | 3 | use primitives::{Channel, ChannelId, ValidatorId};
|
9 |
| -use serde::Deserialize; |
10 |
| -use std::error::Error; |
11 | 4 | use std::str::FromStr;
|
12 | 5 |
|
13 |
| -struct TotalCount(pub u64); |
14 |
| -impl<'a> FromSql<'a> for TotalCount { |
15 |
| - fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> { |
16 |
| - let str_slice = <&str as FromSql>::from_sql(ty, raw)?; |
17 |
| - |
18 |
| - Ok(Self(u64::from_str(str_slice)?)) |
19 |
| - } |
20 |
| - |
21 |
| - // Use a varchar or text, since otherwise `int8` fails deserialization |
22 |
| - accepts!(VARCHAR, TEXT); |
23 |
| -} |
| 6 | +pub use list_channels::{list_channels, ListChannels}; |
24 | 7 |
|
25 | 8 | pub async fn get_channel_by_id(
|
26 | 9 | pool: &DbPool,
|
@@ -87,111 +70,121 @@ pub async fn insert_channel(
|
87 | 70 | .await
|
88 | 71 | }
|
89 | 72 |
|
90 |
| -#[derive(Debug, Deserialize)] |
91 |
| -pub struct ListChannels { |
92 |
| - pub total_count: u64, |
93 |
| - pub channels: Vec<Channel>, |
94 |
| -} |
| 73 | +mod list_channels { |
| 74 | + use crate::db::DbPool; |
| 75 | + use bb8::RunError; |
| 76 | + use bb8_postgres::tokio_postgres::types::{accepts, FromSql, ToSql, Type}; |
| 77 | + use chrono::{DateTime, Utc}; |
| 78 | + use primitives::{Channel, ValidatorId}; |
| 79 | + use std::error::Error; |
| 80 | + use std::str::FromStr; |
| 81 | + |
| 82 | + struct TotalCount(pub u64); |
| 83 | + impl<'a> FromSql<'a> for TotalCount { |
| 84 | + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> { |
| 85 | + let str_slice = <&str as FromSql>::from_sql(ty, raw)?; |
| 86 | + |
| 87 | + Ok(Self(u64::from_str(str_slice)?)) |
| 88 | + } |
| 89 | + |
| 90 | + // Use a varchar or text, since otherwise `int8` fails deserialization |
| 91 | + accepts!(VARCHAR, TEXT); |
| 92 | + } |
95 | 93 |
|
96 |
| -impl From<&Row> for ListChannels { |
97 |
| - fn from(row: &Row) -> Self { |
98 |
| - let total_count = row.get::<_, TotalCount>(0).0; |
99 |
| - let channels = row.get::<_, Json<Vec<Channel>>>(1).0; |
| 94 | + #[derive(Debug)] |
| 95 | + pub struct ListChannels { |
| 96 | + pub total_count: u64, |
| 97 | + pub channels: Vec<Channel>, |
| 98 | + } |
100 | 99 |
|
101 |
| - Self { |
| 100 | + pub async fn list_channels( |
| 101 | + pool: &DbPool, |
| 102 | + skip: u64, |
| 103 | + limit: u32, |
| 104 | + creator: &Option<String>, |
| 105 | + validator: &Option<ValidatorId>, |
| 106 | + valid_until_ge: &DateTime<Utc>, |
| 107 | + ) -> Result<ListChannels, RunError<bb8_postgres::tokio_postgres::Error>> { |
| 108 | + let validator = validator.as_ref().map(|validator_id| { |
| 109 | + serde_json::Value::from_str(&format!(r#"[{{"id": "{}"}}]"#, validator_id)) |
| 110 | + .expect("Not a valid json") |
| 111 | + }); |
| 112 | + let (where_clauses, params) = |
| 113 | + channel_list_query_params(creator, validator.as_ref(), valid_until_ge); |
| 114 | + let total_count_params = (where_clauses.clone(), params.clone()); |
| 115 | + |
| 116 | + let channels = pool |
| 117 | + .run(move |connection| { |
| 118 | + async move { |
| 119 | + // To understand why we use Order by, see Postgres Documentation: https://www.postgresql.org/docs/8.1/queries-limit.html |
| 120 | + let statement = format!("SELECT id, creator, deposit_asset, deposit_amount, valid_until, spec FROM channels WHERE {} ORDER BY spec->>'created' DESC LIMIT {} OFFSET {}", where_clauses.join(" AND "), limit, skip); |
| 121 | + match connection.prepare(&statement).await { |
| 122 | + Ok(stmt) => { |
| 123 | + match connection.query(&stmt, params.as_slice()).await { |
| 124 | + Ok(rows) => { |
| 125 | + let channels = rows.iter().map(Channel::from).collect(); |
| 126 | + |
| 127 | + Ok((channels, connection)) |
| 128 | + }, |
| 129 | + Err(e) => Err((e, connection)), |
| 130 | + } |
| 131 | + }, |
| 132 | + Err(e) => Err((e, connection)), |
| 133 | + } |
| 134 | + } |
| 135 | + }) |
| 136 | + .await?; |
| 137 | + |
| 138 | + Ok(ListChannels { |
| 139 | + total_count: list_channels_total_count( |
| 140 | + &pool, |
| 141 | + (&total_count_params.0, total_count_params.1), |
| 142 | + ) |
| 143 | + .await?, |
102 | 144 | channels,
|
103 |
| - total_count, |
104 |
| - } |
| 145 | + }) |
105 | 146 | }
|
106 |
| -} |
107 | 147 |
|
108 |
| -pub async fn list_channels( |
109 |
| - pool: &DbPool, |
110 |
| - skip: u64, |
111 |
| - limit: u32, |
112 |
| - creator: &Option<String>, |
113 |
| - validator: &Option<ValidatorId>, |
114 |
| - valid_until_ge: &DateTime<Utc>, |
115 |
| -) -> Result<ListChannels, RunError<bb8_postgres::tokio_postgres::Error>> { |
116 |
| - let validator = validator.as_ref().map(|validator_id| { |
117 |
| - serde_json::Value::from_str(&format!(r#"[{{"id": "{}"}}]"#, validator_id)) |
118 |
| - .expect("Not a valid json") |
119 |
| - }); |
120 |
| - let (where_clauses, params) = |
121 |
| - channel_list_query_params(creator, validator.as_ref(), valid_until_ge); |
122 |
| - let total_count_params = (where_clauses.clone(), params.clone()); |
123 |
| - |
124 |
| - let channels = pool |
125 |
| - .run(move |connection| { |
| 148 | + async fn list_channels_total_count<'a>( |
| 149 | + pool: &DbPool, |
| 150 | + (where_clauses, params): (&'a [String], Vec<&'a (dyn ToSql + Sync)>), |
| 151 | + ) -> Result<u64, RunError<bb8_postgres::tokio_postgres::Error>> { |
| 152 | + pool.run(move |connection| { |
126 | 153 | async move {
|
127 |
| - // To understand why we use Order by, see Postgres Documentation: https://www.postgresql.org/docs/8.1/queries-limit.html |
128 |
| - let statement = format!("SELECT id, creator, deposit_asset, deposit_amount, valid_until, spec FROM channels WHERE {} ORDER BY spec->>'created' DESC LIMIT {} OFFSET {}", where_clauses.join(" AND "), limit, skip); |
| 154 | + let statement = format!( |
| 155 | + "SELECT COUNT(id)::varchar FROM channels WHERE {}", |
| 156 | + where_clauses.join(" AND ") |
| 157 | + ); |
129 | 158 | match connection.prepare(&statement).await {
|
130 |
| - Ok(stmt) => { |
131 |
| - match connection.query(&stmt, params.as_slice()).await { |
132 |
| - Ok(rows) => { |
133 |
| - let channels = rows.iter().map(Channel::from).collect(); |
134 |
| - |
135 |
| - Ok((channels, connection)) |
136 |
| - }, |
137 |
| - Err(e) => Err((e, connection)), |
138 |
| - } |
| 159 | + Ok(stmt) => match connection.query_one(&stmt, params.as_slice()).await { |
| 160 | + Ok(row) => Ok((row.get::<_, TotalCount>(0).0, connection)), |
| 161 | + Err(e) => Err((e, connection)), |
139 | 162 | },
|
140 | 163 | Err(e) => Err((e, connection)),
|
141 | 164 | }
|
142 | 165 | }
|
143 | 166 | })
|
144 |
| - .await?; |
145 |
| - |
146 |
| - Ok(ListChannels { |
147 |
| - total_count: list_channels_total_count( |
148 |
| - &pool, |
149 |
| - (&total_count_params.0, total_count_params.1), |
150 |
| - ) |
151 |
| - .await?, |
152 |
| - channels, |
153 |
| - }) |
154 |
| -} |
| 167 | + .await |
| 168 | + } |
155 | 169 |
|
156 |
| -async fn list_channels_total_count<'a>( |
157 |
| - pool: &DbPool, |
158 |
| - (where_clauses, params): (&'a [String], Vec<&'a (dyn ToSql + Sync)>), |
159 |
| -) -> Result<u64, RunError<bb8_postgres::tokio_postgres::Error>> { |
160 |
| - pool.run(move |connection| { |
161 |
| - async move { |
162 |
| - let statement = format!( |
163 |
| - "SELECT COUNT(id)::varchar FROM channels WHERE {}", |
164 |
| - where_clauses.join(" AND ") |
165 |
| - ); |
166 |
| - match connection.prepare(&statement).await { |
167 |
| - Ok(stmt) => match connection.query_one(&stmt, params.as_slice()).await { |
168 |
| - Ok(row) => Ok((row.get::<_, TotalCount>(0).0, connection)), |
169 |
| - Err(e) => Err((e, connection)), |
170 |
| - }, |
171 |
| - Err(e) => Err((e, connection)), |
172 |
| - } |
| 170 | + fn channel_list_query_params<'a>( |
| 171 | + creator: &'a Option<String>, |
| 172 | + validator: Option<&'a serde_json::Value>, |
| 173 | + valid_until_ge: &'a DateTime<Utc>, |
| 174 | + ) -> (Vec<String>, Vec<&'a (dyn ToSql + Sync)>) { |
| 175 | + let mut where_clauses = vec!["valid_until >= $1".to_string()]; |
| 176 | + let mut params: Vec<&(dyn ToSql + Sync)> = vec![valid_until_ge]; |
| 177 | + |
| 178 | + if let Some(creator) = creator { |
| 179 | + where_clauses.push(format!("creator = ${}", params.len() + 1)); |
| 180 | + params.push(creator); |
173 | 181 | }
|
174 |
| - }) |
175 |
| - .await |
176 |
| -} |
177 | 182 |
|
178 |
| -fn channel_list_query_params<'a>( |
179 |
| - creator: &'a Option<String>, |
180 |
| - validator: Option<&'a serde_json::Value>, |
181 |
| - valid_until_ge: &'a DateTime<Utc>, |
182 |
| -) -> (Vec<String>, Vec<&'a (dyn ToSql + Sync)>) { |
183 |
| - let mut where_clauses = vec!["valid_until >= $1".to_string()]; |
184 |
| - let mut params: Vec<&(dyn ToSql + Sync)> = vec![valid_until_ge]; |
185 |
| - |
186 |
| - if let Some(creator) = creator { |
187 |
| - where_clauses.push(format!("creator = ${}", params.len() + 1)); |
188 |
| - params.push(creator); |
189 |
| - } |
| 183 | + if let Some(validator) = validator { |
| 184 | + where_clauses.push(format!("spec->'validators' @> ${}", params.len() + 1)); |
| 185 | + params.push(validator); |
| 186 | + } |
190 | 187 |
|
191 |
| - if let Some(validator) = validator { |
192 |
| - where_clauses.push(format!("spec->'validators' @> ${}", params.len() + 1)); |
193 |
| - params.push(validator); |
| 188 | + (where_clauses, params) |
194 | 189 | }
|
195 |
| - |
196 |
| - (where_clauses, params) |
197 | 190 | }
|
0 commit comments