3
3
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4
4
// Please see LICENSE files in the repository root for full details.
5
5
6
- use aide:: { NoApi , OperationIo , transform:: TransformOperation } ;
7
- use axum:: { Json , response:: IntoResponse } ;
6
+ use std:: sync:: Arc ;
7
+
8
+ use aide:: { OperationIo , transform:: TransformOperation } ;
9
+ use axum:: { Json , extract:: State , response:: IntoResponse } ;
8
10
use hyper:: StatusCode ;
9
11
use mas_axum_utils:: record_error;
10
- use mas_storage:: {
11
- BoxRng ,
12
- queue:: { QueueJobRepositoryExt as _, ReactivateUserJob } ,
13
- } ;
14
- use tracing:: info;
12
+ use mas_matrix:: HomeserverConnection ;
15
13
use ulid:: Ulid ;
16
14
17
15
use crate :: {
@@ -30,6 +28,9 @@ pub enum RouteError {
30
28
#[ error( transparent) ]
31
29
Internal ( Box < dyn std:: error:: Error + Send + Sync + ' static > ) ,
32
30
31
+ #[ error( transparent) ]
32
+ Homeserver ( anyhow:: Error ) ,
33
+
33
34
#[ error( "User ID {0} not found" ) ]
34
35
NotFound ( Ulid ) ,
35
36
}
@@ -39,9 +40,9 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
39
40
impl IntoResponse for RouteError {
40
41
fn into_response ( self ) -> axum:: response:: Response {
41
42
let error = ErrorResponse :: from_error ( & self ) ;
42
- let sentry_event_id = record_error ! ( self , Self :: Internal ( _) ) ;
43
+ let sentry_event_id = record_error ! ( self , Self :: Internal ( _) | Self :: Homeserver ( _ ) ) ;
43
44
let status = match self {
44
- Self :: Internal ( _) => StatusCode :: INTERNAL_SERVER_ERROR ,
45
+ Self :: Internal ( _) | Self :: Homeserver ( _ ) => StatusCode :: INTERNAL_SERVER_ERROR ,
45
46
Self :: NotFound ( _) => StatusCode :: NOT_FOUND ,
46
47
} ;
47
48
( status, sentry_event_id, Json ( error) ) . into_response ( )
@@ -69,10 +70,8 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
69
70
70
71
#[ tracing:: instrument( name = "handler.admin.v1.users.reactivate" , skip_all) ]
71
72
pub async fn handler (
72
- CallContext {
73
- mut repo, clock, ..
74
- } : CallContext ,
75
- NoApi ( mut rng) : NoApi < BoxRng > ,
73
+ CallContext { mut repo, .. } : CallContext ,
74
+ State ( homeserver) : State < Arc < dyn HomeserverConnection > > ,
76
75
id : UlidPathParam ,
77
76
) -> Result < Json < SingleResponse < User > > , RouteError > {
78
77
let id = * id;
@@ -82,10 +81,15 @@ pub async fn handler(
82
81
. await ?
83
82
. ok_or ( RouteError :: NotFound ( id) ) ?;
84
83
85
- info ! ( %user. id, "Scheduling reactivation of user" ) ;
86
- repo. queue_job ( )
87
- . schedule_job ( & mut rng, & clock, ReactivateUserJob :: new ( & user, false ) )
88
- . await ?;
84
+ // Call the homeserver synchronously to reactivate the user
85
+ let mxid = homeserver. mxid ( & user. username ) ;
86
+ homeserver
87
+ . reactivate_user ( & mxid)
88
+ . await
89
+ . map_err ( RouteError :: Homeserver ) ?;
90
+
91
+ // Now reactivate the user in our database
92
+ let user = repo. user ( ) . reactivate ( user) . await ?;
89
93
90
94
repo. save ( ) . await ?;
91
95
@@ -100,7 +104,7 @@ mod tests {
100
104
use hyper:: { Request , StatusCode } ;
101
105
use mas_matrix:: { HomeserverConnection , ProvisionRequest } ;
102
106
use mas_storage:: { Clock , RepositoryAccess , user:: UserRepository } ;
103
- use sqlx:: { PgPool , types :: Json } ;
107
+ use sqlx:: PgPool ;
104
108
105
109
use crate :: test_utils:: { RequestBuilderExt , ResponseExt , TestState , setup} ;
106
110
@@ -150,18 +154,10 @@ mod tests {
150
154
body[ "data" ] [ "attributes" ] [ "locked_at" ] ,
151
155
serde_json:: json!( state. clock. now( ) )
152
156
) ;
153
- // TODO: have test coverage on deactivated_at timestamp
154
-
155
- // It should have scheduled a reactivation job for the user
156
- // XXX: we don't have a good way to look for the reactivation job
157
- let job: Json < serde_json:: Value > = sqlx:: query_scalar (
158
- "SELECT payload FROM queue_jobs WHERE queue_name = 'reactivate-user'" ,
159
- )
160
- . fetch_one ( & pool)
161
- . await
162
- . expect ( "Reactivation job to be scheduled" ) ;
163
- assert_eq ! ( job[ "user_id" ] , serde_json:: json!( user. id) ) ;
164
- assert_eq ! ( job[ "unlock" ] , serde_json:: Value :: Bool ( false ) ) ;
157
+ assert_eq ! (
158
+ body[ "data" ] [ "attributes" ] [ "deactivated_at" ] ,
159
+ serde_json:: Value :: Null ,
160
+ ) ;
165
161
}
166
162
167
163
#[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
@@ -178,6 +174,14 @@ mod tests {
178
174
. unwrap ( ) ;
179
175
repo. save ( ) . await . unwrap ( ) ;
180
176
177
+ // Provision the user on the homeserver
178
+ let mxid = state. homeserver_connection . mxid ( & user. username ) ;
179
+ state
180
+ . homeserver_connection
181
+ . provision_user ( & ProvisionRequest :: new ( & mxid, & user. sub ) )
182
+ . await
183
+ . unwrap ( ) ;
184
+
181
185
let request = Request :: post ( format ! ( "/api/admin/v1/users/{}/reactivate" , user. id) )
182
186
. bearer ( & token)
183
187
. empty ( ) ;
@@ -189,18 +193,10 @@ mod tests {
189
193
body[ "data" ] [ "attributes" ] [ "locked_at" ] ,
190
194
serde_json:: Value :: Null
191
195
) ;
192
- // TODO: have test coverage on deactivated_at timestamp
193
-
194
- // It should have scheduled a reactivation job for the user
195
- // XXX: we don't have a good way to look for the reactivation job
196
- let job: Json < serde_json:: Value > = sqlx:: query_scalar (
197
- "SELECT payload FROM queue_jobs WHERE queue_name = 'reactivate-user'" ,
198
- )
199
- . fetch_one ( & pool)
200
- . await
201
- . expect ( "Reactivation job to be scheduled" ) ;
202
- assert_eq ! ( job[ "user_id" ] , serde_json:: json!( user. id) ) ;
203
- assert_eq ! ( job[ "unlock" ] , serde_json:: Value :: Bool ( false ) ) ;
196
+ assert_eq ! (
197
+ body[ "data" ] [ "attributes" ] [ "deactivated_at" ] ,
198
+ serde_json:: Value :: Null
199
+ ) ;
204
200
}
205
201
206
202
#[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
0 commit comments