|
4 | 4 | // SPDX-License-Identifier: AGPL-3.0-only |
5 | 5 | // Please see LICENSE in the repository root for full details. |
6 | 6 |
|
| 7 | +use std::sync::Arc; |
| 8 | + |
7 | 9 | use axum::{ |
8 | 10 | extract::{Path, Query, State}, |
9 | 11 | response::{IntoResponse, Redirect}, |
10 | 12 | }; |
11 | 13 | use hyper::StatusCode; |
12 | 14 | use mas_axum_utils::{cookies::CookieJar, record_error}; |
13 | | -use mas_data_model::UpstreamOAuthProvider; |
| 15 | +use mas_data_model::{UpstreamOAuthProvider, oauth2::LoginHint}; |
| 16 | +use mas_matrix::HomeserverConnection; |
14 | 17 | use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; |
15 | 18 | use mas_router::{PostAuthAction, UrlBuilder}; |
16 | 19 | use mas_storage::{ |
@@ -66,6 +69,7 @@ pub(crate) async fn get( |
66 | 69 | cookie_jar: CookieJar, |
67 | 70 | Path(provider_id): Path<Ulid>, |
68 | 71 | Query(query): Query<OptionalPostAuthAction>, |
| 72 | + State(homeserver): State<Arc<dyn HomeserverConnection>>, |
69 | 73 | ) -> Result<impl IntoResponse, RouteError> { |
70 | 74 | let provider = repo |
71 | 75 | .upstream_oauth_provider() |
@@ -96,13 +100,11 @@ pub(crate) async fn get( |
96 | 100 | // sees fit |
97 | 101 | if provider.forward_login_hint { |
98 | 102 | if let Some(PostAuthAction::ContinueAuthorizationGrant { id }) = &query.post_auth_action { |
99 | | - if let Some(login_hint) = repo |
100 | | - .oauth2_authorization_grant() |
101 | | - .lookup(*id) |
102 | | - .await? |
103 | | - .and_then(|grant| grant.login_hint) |
104 | | - { |
105 | | - data = data.with_login_hint(login_hint); |
| 103 | + if let Some(grant) = repo.oauth2_authorization_grant().lookup(*id).await? { |
| 104 | + match grant.parse_login_hint(homeserver.homeserver()) { |
| 105 | + LoginHint::MXID(mxid) => data = data.with_login_hint(mxid.to_string()), |
| 106 | + LoginHint::None => (), |
| 107 | + } |
106 | 108 | } |
107 | 109 | } |
108 | 110 | } |
|
0 commit comments