1
- use primitives:: { Channel , ChannelId } ;
1
+ use primitives:: { Channel , ChannelId , ChainOf } ;
2
2
3
3
pub use list_channels:: list_channels;
4
4
@@ -26,17 +26,19 @@ pub async fn get_channel_by_id(
26
26
/// This call should never trigger a `SqlState::UNIQUE_VIOLATION`
27
27
///
28
28
/// ```sql
29
- /// INSERT INTO channels (id, leader, follower, guardian, token, nonce, created)
30
- /// VALUES ($1, $2, $3, $4, $5, $6, NOW())
29
+ /// INSERT INTO channels (id, leader, follower, guardian, token, nonce, chain_id, created)
30
+ /// VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
31
31
/// ON CONFLICT ON CONSTRAINT channels_pkey DO UPDATE SET created=EXCLUDED.created
32
32
/// RETURNING leader, follower, guardian, token, nonce
33
33
/// ```
34
- pub async fn insert_channel ( pool : & DbPool , channel : Channel ) -> Result < Channel , PoolError > {
34
+ pub async fn insert_channel ( pool : & DbPool , channel_chain : & ChainOf < Channel > ) -> Result < Channel , PoolError > {
35
35
let client = pool. get ( ) . await ?;
36
+ let chain_id = channel_chain. chain . chain_id ;
37
+ let channel = channel_chain. context ;
36
38
37
39
// We use `EXCLUDED.created` in order to have to DO UPDATE otherwise it does not return the fields
38
40
// when there is a CONFLICT
39
- let stmt = client. prepare ( "INSERT INTO channels (id, leader, follower, guardian, token, nonce, created) VALUES ($1, $2, $3, $4, $5, $6, NOW())
41
+ let stmt = client. prepare ( "INSERT INTO channels (id, leader, follower, guardian, token, nonce, chain_id, created) VALUES ($1, $2, $3, $4, $5, $6, $7 , NOW())
40
42
ON CONFLICT ON CONSTRAINT channels_pkey DO UPDATE SET created=EXCLUDED.created RETURNING leader, follower, guardian, token, nonce" ) . await ?;
41
43
42
44
let row = client
@@ -49,6 +51,7 @@ pub async fn insert_channel(pool: &DbPool, channel: Channel) -> Result<Channel,
49
51
& channel. guardian ,
50
52
& channel. token ,
51
53
& channel. nonce ,
54
+ & chain_id,
52
55
] ,
53
56
)
54
57
. await ?;
@@ -58,8 +61,8 @@ pub async fn insert_channel(pool: &DbPool, channel: Channel) -> Result<Channel,
58
61
59
62
mod list_channels {
60
63
use primitives:: {
61
- sentry:: { channel_list:: ChannelListResponse , Pagination } ,
62
- Channel , ValidatorId ,
64
+ sentry:: { channel_list:: { ChannelListResponse , ChannelListQuery } , Pagination } ,
65
+ Channel ,
63
66
} ;
64
67
65
68
use crate :: db:: { DbPool , PoolError , TotalCount } ;
@@ -71,24 +74,36 @@ mod list_channels {
71
74
pool : & DbPool ,
72
75
skip : u64 ,
73
76
limit : u32 ,
74
- validator : Option < ValidatorId > ,
77
+ query : & ChannelListQuery ,
75
78
) -> Result < ChannelListResponse , PoolError > {
76
79
let client = pool. get ( ) . await ?;
80
+ let mut where_clauses = vec ! [ ] ;
81
+ if !query. chains . is_empty ( ) {
82
+ where_clauses. push ( format ! (
83
+ "chain_id IN ({})" ,
84
+ query
85
+ . chains
86
+ . iter( )
87
+ . map( |id| id. to_u32( ) . to_string( ) )
88
+ . collect:: <Vec <String >>( )
89
+ . join( "," )
90
+ ) ) ;
91
+ }
77
92
78
93
// To understand why we use Order by, see Postgres Documentation: https://www.postgresql.org/docs/8.1/queries-limit.html
79
- let rows = match validator {
94
+ let rows = match query . validator {
80
95
Some ( validator) => {
81
- let where_clause = "(leader = $1 OR follower = $1)" . to_string ( ) ;
96
+ where_clauses . push ( "(leader = $1 OR follower = $1)" . to_string ( ) ) ;
82
97
83
98
let statement = format ! ( "SELECT leader, follower, guardian, token, nonce, created FROM channels WHERE {} ORDER BY created ASC LIMIT {} OFFSET {}" ,
84
- where_clause , limit, skip) ;
99
+ where_clauses . join ( " AND " ) , limit, skip) ;
85
100
let stmt = client. prepare ( & statement) . await ?;
86
101
87
102
client. query ( & stmt, & [ & validator. to_string ( ) ] ) . await ?
88
103
}
89
104
None => {
90
- let statement = format ! ( "SELECT id, leader, follower, guardian, token, nonce, created FROM channels ORDER BY created ASC LIMIT {} OFFSET {}" ,
91
- limit, skip) ;
105
+ let statement = format ! ( "SELECT id, leader, follower, guardian, token, nonce, created FROM channels WHERE {} ORDER BY created ASC LIMIT {} OFFSET {}" ,
106
+ where_clauses . join ( " AND " ) , limit, skip) ;
92
107
let stmt = client. prepare ( & statement) . await ?;
93
108
94
109
client. query ( & stmt, & [ ] ) . await ?
@@ -97,7 +112,7 @@ mod list_channels {
97
112
98
113
let channels = rows. iter ( ) . map ( Channel :: from) . collect ( ) ;
99
114
100
- let total_count = list_channels_total_count ( pool, validator ) . await ?;
115
+ let total_count = list_channels_total_count ( pool, query ) . await ?;
101
116
102
117
// fast ceil for total_pages
103
118
let total_pages = if total_count == 0 {
@@ -117,25 +132,38 @@ mod list_channels {
117
132
118
133
async fn list_channels_total_count < ' a > (
119
134
pool : & DbPool ,
120
- validator : Option < ValidatorId > ,
135
+ query : & ChannelListQuery ,
121
136
) -> Result < u64 , PoolError > {
122
137
let client = pool. get ( ) . await ?;
123
138
124
- let row = match validator {
139
+ let mut where_clauses = vec ! [ ] ;
140
+ if !query. chains . is_empty ( ) {
141
+ where_clauses. push ( format ! (
142
+ "chain_id IN ({})" ,
143
+ query
144
+ . chains
145
+ . iter( )
146
+ . map( |id| id. to_u32( ) . to_string( ) )
147
+ . collect:: <Vec <String >>( )
148
+ . join( "," )
149
+ ) ) ;
150
+ }
151
+
152
+ let row = match query. validator {
125
153
Some ( validator) => {
126
- let where_clause = "(leader = $1 OR follower = $1)" . to_string ( ) ;
154
+ where_clauses . push ( "(leader = $1 OR follower = $1)" . to_string ( ) ) ;
127
155
128
156
let statement = format ! (
129
157
"SELECT COUNT(id)::varchar FROM channels WHERE {}" ,
130
- where_clause
158
+ where_clauses . join ( " AND " )
131
159
) ;
132
160
let stmt = client. prepare ( & statement) . await ?;
133
161
134
162
client. query_one ( & stmt, & [ & validator. to_string ( ) ] ) . await ?
135
163
}
136
164
None => {
137
- let statement = "SELECT COUNT(id)::varchar FROM channels" ;
138
- let stmt = client. prepare ( statement) . await ?;
165
+ let statement = format ! ( "SELECT COUNT(id)::varchar FROM channels WHERE {}" , where_clauses . join ( " AND " ) ) ;
166
+ let stmt = client. prepare ( & statement) . await ?;
139
167
140
168
client. query_one ( & stmt, & [ ] ) . await ?
141
169
}
@@ -147,29 +175,36 @@ mod list_channels {
147
175
148
176
#[ cfg( test) ]
149
177
mod test {
150
- use primitives:: test_util:: DUMMY_CAMPAIGN ;
178
+ use primitives:: { test_util:: DUMMY_CAMPAIGN , sentry :: channel_list :: ChannelListQuery , ChainId } ;
151
179
152
- use crate :: db:: {
180
+ use crate :: { db:: {
153
181
insert_channel,
154
182
tests_postgres:: { setup_test_migrations, DATABASE_POOL } ,
155
- } ;
183
+ } , test_util :: setup_dummy_app } ;
156
184
157
185
use super :: list_channels:: list_channels;
158
186
159
187
#[ tokio:: test]
160
188
async fn insert_and_list_channels_return_channels ( ) {
189
+ let app = setup_dummy_app ( ) . await ;
161
190
let database = DATABASE_POOL . get ( ) . await . expect ( "Should get database" ) ;
162
191
setup_test_migrations ( database. pool . clone ( ) )
163
192
. await
164
193
. expect ( "Should setup migrations" ) ;
165
194
195
+ let channel_chain = app
196
+ . config
197
+ . find_chain_of ( DUMMY_CAMPAIGN . channel . token )
198
+ . expect ( "Channel token should be whitelisted in config!" ) ;
199
+ let channel_context = channel_chain. with_channel ( DUMMY_CAMPAIGN . channel ) ;
200
+
166
201
let actual_channel = {
167
- let insert = insert_channel ( & database. pool , DUMMY_CAMPAIGN . channel )
202
+ let insert = insert_channel ( & database. pool , & channel_context )
168
203
. await
169
204
. expect ( "Should insert Channel" ) ;
170
205
171
206
// once inserted, the channel should only be returned by the function
172
- let only_select = insert_channel ( & database. pool , DUMMY_CAMPAIGN . channel )
207
+ let only_select = insert_channel ( & database. pool , & channel_context )
173
208
. await
174
209
. expect ( "Should run insert with RETURNING on the Channel" ) ;
175
210
@@ -178,7 +213,13 @@ mod test {
178
213
only_select
179
214
} ;
180
215
181
- let response = list_channels ( & database. pool , 0 , 10 , None )
216
+ let query = ChannelListQuery {
217
+ page : 0 ,
218
+ validator : None ,
219
+ chains : vec ! [ channel_context. chain. chain_id] ,
220
+ } ;
221
+
222
+ let response = list_channels ( & database. pool , 0 , 10 , & query)
182
223
. await
183
224
. expect ( "Should list Channels" ) ;
184
225
0 commit comments