@@ -63,16 +63,19 @@ pub async fn insert_channel(
63
63
}
64
64
65
65
mod list_channels {
66
+ use futures:: { pin_mut, TryStreamExt } ;
66
67
use primitives:: {
67
68
sentry:: { channel_list:: ChannelListResponse , Pagination } ,
68
69
ChainId , Channel , ValidatorId ,
69
70
} ;
71
+ use tokio_postgres:: { types:: ToSql , Row } ;
70
72
71
73
use crate :: db:: { DbPool , PoolError , TotalCount } ;
72
74
73
75
/// Lists the `Channel`s in `ASC` order.
74
- /// This makes sure that if a new `Channel` is added
75
- // while we are scrolling through the pages it will not alter the `Channel`s ordering
76
+ ///
77
+ /// This makes sure that if a new [`Channel`] is added
78
+ /// while we are scrolling through the pages it will not alter the [`Channel`]s ordering.
76
79
pub async fn list_channels (
77
80
pool : & DbPool ,
78
81
skip : u64 ,
@@ -81,47 +84,63 @@ mod list_channels {
81
84
chains : & [ ChainId ] ,
82
85
) -> Result < ChannelListResponse , PoolError > {
83
86
let client = pool. get ( ) . await ?;
87
+
84
88
let mut where_clauses = vec ! [ ] ;
89
+ let mut params: Vec < Box < ( dyn ToSql + Send + Sync ) > > = vec ! [ ] ;
90
+ let mut params_total: Vec < Box < ( dyn ToSql + Send + Sync ) > > = vec ! [ ] ;
91
+
85
92
if !chains. is_empty ( ) {
86
- where_clauses. push ( format ! (
87
- "chain_id IN ({})" ,
88
- chains
89
- . iter( )
90
- . map( |id| id. to_u32( ) . to_string( ) )
91
- . collect:: <Vec <String >>( )
92
- . join( "," )
93
- ) ) ;
93
+ let ( chain_params, chain_params_total) : (
94
+ Vec < Box < dyn ToSql + Send + Sync > > ,
95
+ Vec < Box < dyn ToSql + Send + Sync > > ,
96
+ ) = chains
97
+ . iter ( )
98
+ . map ( |chain_id| ( Box :: new ( * chain_id) as _ , Box :: new ( * chain_id) as _ ) )
99
+ . unzip ( ) ;
100
+
101
+ // prepare the query parameters, they are 1-indexed!
102
+ let params_prepared = ( 1 ..=chain_params. len ( ) )
103
+ . map ( |param_num| format ! ( "${param_num}" ) )
104
+ . collect :: < Vec < _ > > ( )
105
+ . join ( "," ) ;
106
+
107
+ params. extend ( chain_params) ;
108
+ params_total. extend ( chain_params_total) ;
109
+
110
+ where_clauses. push ( format ! ( "chain_id IN ({})" , params_prepared) ) ;
94
111
}
95
112
96
- // To understand why we use Order by, see Postgres Documentation: https://www.postgresql.org/docs/8.1/queries-limit.html
97
- let rows = match validator {
113
+ match validator {
98
114
Some ( validator) => {
99
- where_clauses. push ( "(leader = $1 OR follower = $1)" . to_string ( ) ) ;
100
-
101
- let statement = format ! ( "SELECT leader, follower, guardian, token, nonce, created FROM channels WHERE {} ORDER BY created ASC LIMIT {} OFFSET {}" ,
102
- where_clauses. join( " AND " ) , limit, skip) ;
103
- let stmt = client. prepare ( & statement) . await ?;
104
-
105
- client. query ( & stmt, & [ & validator. to_string ( ) ] ) . await ?
106
- }
107
- None => {
108
- let statement = if !where_clauses. is_empty ( ) {
109
- format ! ( "SELECT id, leader, follower, guardian, token, nonce, created FROM channels WHERE {} ORDER BY created ASC LIMIT {} OFFSET {}" ,
110
- where_clauses. join( " AND " ) , limit, skip)
111
- } else {
112
- format ! ( "SELECT id, leader, follower, guardian, token, nonce, created FROM channels ORDER BY created ASC LIMIT {} OFFSET {}" ,
113
- limit, skip)
114
- } ;
115
-
116
- let stmt = client. prepare ( & statement) . await ?;
117
-
118
- client. query ( & stmt, & [ ] ) . await ?
115
+ // params are 1-indexed
116
+ where_clauses. push ( format ! (
117
+ "(leader = ${validator_param} OR follower = ${validator_param})" ,
118
+ validator_param = params. len( ) + 1
119
+ ) ) ;
120
+ // then add the new param to the list!
121
+ params. push ( Box :: new ( validator) as _ ) ;
122
+ params_total. push ( Box :: new ( validator) as _ ) ;
119
123
}
124
+ _ => { }
125
+ }
126
+
127
+ // To understand why we use Order by, see Postgres Documentation: https://www.postgresql.org/docs/8.1/queries-limit.html
128
+ let statement = if !where_clauses. is_empty ( ) {
129
+ format ! ( "SELECT id, leader, follower, guardian, token, nonce, created FROM channels WHERE {} ORDER BY created ASC LIMIT {} OFFSET {}" ,
130
+ where_clauses. join( " AND " ) , limit, skip)
131
+ } else {
132
+ format ! ( "SELECT id, leader, follower, guardian, token, nonce, created FROM channels ORDER BY created ASC LIMIT {} OFFSET {}" ,
133
+ limit, skip)
120
134
} ;
121
135
136
+ let stmt = client. prepare ( & statement) . await ?;
137
+
138
+ let rows: Vec < Row > = client. query_raw ( & stmt, params) . await ?. try_collect ( ) . await ?;
139
+
122
140
let channels = rows. iter ( ) . map ( Channel :: from) . collect ( ) ;
123
141
124
- let total_count = list_channels_total_count ( pool, validator, chains) . await ?;
142
+ let total_count = list_channels_total_count ( pool, ( where_clauses, params_total) ) . await ?;
143
+
125
144
// fast ceil for total_pages
126
145
let total_pages = if total_count == 0 {
127
146
1
@@ -140,57 +159,38 @@ mod list_channels {
140
159
141
160
async fn list_channels_total_count < ' a > (
142
161
pool : & DbPool ,
143
- validator : Option < ValidatorId > ,
144
- chains : & [ ChainId ] ,
162
+ ( where_clauses, params) : ( Vec < String > , Vec < Box < dyn ToSql + Send + Sync > > ) ,
145
163
) -> Result < u64 , PoolError > {
146
164
let client = pool. get ( ) . await ?;
147
165
148
- let mut where_clauses = vec ! [ ] ;
149
- if !chains. is_empty ( ) {
150
- where_clauses. push ( format ! (
151
- "chain_id IN ({})" ,
152
- chains
153
- . iter( )
154
- . map( |id| id. to_u32( ) . to_string( ) )
155
- . collect:: <Vec <String >>( )
156
- . join( "," )
157
- ) ) ;
158
- }
159
-
160
- let row = match validator {
161
- Some ( validator) => {
162
- where_clauses. push ( "(leader = $1 OR follower = $1)" . to_string ( ) ) ;
166
+ let statement = if !where_clauses. is_empty ( ) {
167
+ format ! (
168
+ "SELECT COUNT(id)::varchar FROM channels WHERE {}" ,
169
+ where_clauses. join( " AND " )
170
+ )
171
+ } else {
172
+ format ! ( "SELECT COUNT(id)::varchar FROM channels" )
173
+ } ;
163
174
164
- let statement = format ! (
165
- "SELECT COUNT(id)::varchar FROM channels WHERE {}" ,
166
- where_clauses. join( " AND " )
167
- ) ;
168
- let stmt = client. prepare ( & statement) . await ?;
175
+ let stmt = client. prepare ( & statement) . await ?;
169
176
170
- client. query_one ( & stmt, & [ & validator. to_string ( ) ] ) . await ?
171
- }
172
- None => {
173
- let statement = if !where_clauses. is_empty ( ) {
174
- format ! (
175
- "SELECT COUNT(id)::varchar FROM channels WHERE {}" ,
176
- where_clauses. join( " AND " )
177
- )
178
- } else {
179
- "SELECT COUNT(id)::varchar FROM channels" . to_string ( )
180
- } ;
181
- let stmt = client. prepare ( & statement) . await ?;
182
-
183
- client. query_one ( & stmt, & [ ] ) . await ?
184
- }
185
- } ;
177
+ let stream = client. query_raw ( & stmt, params) . await ?;
178
+ pin_mut ! ( stream) ;
179
+ let row = stream
180
+ . try_next ( )
181
+ . await ?
182
+ . expect ( "Query should always return exactly 1 row!" ) ;
186
183
187
184
Ok ( row. get :: < _ , TotalCount > ( 0 ) . 0 )
188
185
}
189
186
}
190
187
191
188
#[ cfg( test) ]
192
189
mod test {
193
- use primitives:: { config:: GANACHE_CONFIG , test_util:: DUMMY_CAMPAIGN } ;
190
+ use adapter:: ethereum:: test_util:: { GANACHE_1 , GANACHE_INFO_1 } ;
191
+ use primitives:: {
192
+ config:: GANACHE_CONFIG , sentry:: Pagination , test_util:: DUMMY_CAMPAIGN , ChainOf , Channel ,
193
+ } ;
194
194
195
195
use crate :: db:: {
196
196
insert_channel,
@@ -206,18 +206,40 @@ mod test {
206
206
. await
207
207
. expect ( "Should setup migrations" ) ;
208
208
209
- let channel_chain = GANACHE_CONFIG
209
+ let channel_1337 = GANACHE_CONFIG
210
210
. find_chain_of ( DUMMY_CAMPAIGN . channel . token )
211
- . expect ( "Channel token should be whitelisted in config!" ) ;
212
- let channel_context = channel_chain. with_channel ( DUMMY_CAMPAIGN . channel ) ;
211
+ . expect ( "Channel token should be whitelisted in config!" )
212
+ . with_channel ( DUMMY_CAMPAIGN . channel ) ;
213
+
214
+ let channel_1 = {
215
+ let token_info = GANACHE_INFO_1 . tokens [ "Mocked TOKEN 1" ] . clone ( ) ;
216
+
217
+ let channel_1 = Channel {
218
+ token : token_info. address ,
219
+ ..DUMMY_CAMPAIGN . channel
220
+ } ;
213
221
214
- let actual_channel = {
215
- let insert = insert_channel ( & database. pool , & channel_context)
222
+ ChainOf :: new ( GANACHE_1 . clone ( ) , token_info) . with_channel ( channel_1)
223
+ } ;
224
+
225
+ assert_ne ! (
226
+ channel_1337. chain. chain_id, channel_1. chain. chain_id,
227
+ "The two channels should be on different Chains!"
228
+ ) ;
229
+
230
+ // Insert channel on chain #1
231
+ insert_channel ( & database. pool , & channel_1)
232
+ . await
233
+ . expect ( "Should insert Channel" ) ;
234
+
235
+ // try to insert the same channel twice
236
+ let actual_channel_1337 = {
237
+ let insert = insert_channel ( & database. pool , & channel_1337)
216
238
. await
217
239
. expect ( "Should insert Channel" ) ;
218
240
219
241
// once inserted, the channel should only be returned by the function
220
- let only_select = insert_channel ( & database. pool , & channel_context )
242
+ let only_select = insert_channel ( & database. pool , & channel_1337 )
221
243
. await
222
244
. expect ( "Should run insert with RETURNING on the Channel" ) ;
223
245
@@ -226,17 +248,46 @@ mod test {
226
248
only_select
227
249
} ;
228
250
229
- let response = list_channels (
230
- & database. pool ,
231
- 0 ,
232
- 10 ,
233
- None ,
234
- & [ channel_context. chain . chain_id ] ,
235
- )
236
- . await
237
- . expect ( "Should list Channels" ) ;
251
+ // List Channels with Chain #1337
252
+ {
253
+ // Check the response using only that channel's ChainId
254
+ let response =
255
+ list_channels ( & database. pool , 0 , 10 , None , & [ channel_1337. chain . chain_id ] )
256
+ . await
257
+ . expect ( "Should list Channels" ) ;
258
+
259
+ assert_eq ! ( 1 , response. channels. len( ) ) ;
260
+ assert_eq ! (
261
+ response. channels[ 0 ] , actual_channel_1337,
262
+ "Only the single channel of Chain #1337 should be returned"
263
+ ) ;
264
+ }
238
265
239
- assert_eq ! ( 1 , response. channels. len( ) ) ;
240
- assert_eq ! ( DUMMY_CAMPAIGN . channel, actual_channel) ;
266
+ // Cist channels with Chain #1 and Chain #1337
267
+ {
268
+ let response = list_channels (
269
+ & database. pool ,
270
+ 0 ,
271
+ 10 ,
272
+ None ,
273
+ & [ channel_1337. chain . chain_id , channel_1. chain . chain_id ] ,
274
+ )
275
+ . await
276
+ . expect ( "Should list Channels" ) ;
277
+
278
+ assert_eq ! ( 2 , response. channels. len( ) ) ;
279
+ assert_eq ! (
280
+ Pagination {
281
+ total_pages: 1 ,
282
+ page: 0 ,
283
+ } ,
284
+ response. pagination
285
+ ) ;
286
+ pretty_assertions:: assert_eq!(
287
+ response. channels,
288
+ vec![ channel_1. context, actual_channel_1337] ,
289
+ "All channels in ASC order should be returned"
290
+ ) ;
291
+ }
241
292
}
242
293
}
0 commit comments