@@ -44,29 +44,18 @@ impl AuthInfo {
4444 }
4545}
4646
47- #[ derive( Deserialize , Debug , PartialEq , Eq ) ]
47+ #[ derive( Deserialize , Debug ) ]
48+ #[ serde( rename_all = "lowercase" ) ]
4849pub ( crate ) enum FlowType {
4950 Enrollment ,
5051 Mfa ,
5152}
5253
53- impl std:: str:: FromStr for FlowType {
54- type Err = ( ) ;
55-
56- fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
57- match s. to_lowercase ( ) . as_str ( ) {
58- "enrollment" => Ok ( FlowType :: Enrollment ) ,
59- "mfa" => Ok ( FlowType :: Mfa ) ,
60- _ => Err ( ( ) ) ,
61- }
62- }
63- }
64-
6554#[ derive( Deserialize , Debug ) ]
66- struct RequestData {
55+ pub ( crate ) struct RequestData {
6756 state : Option < String > ,
6857 #[ serde( rename = "type" ) ]
69- flow_type : String ,
58+ flow_type : FlowType ,
7059}
7160
7261/// Request external OAuth2/OpenID provider details from Defguard Core.
@@ -79,13 +68,8 @@ async fn auth_info(
7968) -> Result < ( PrivateCookieJar , Json < AuthInfo > ) , ApiError > {
8069 debug ! ( "Getting auth info for OAuth2/OpenID login" ) ;
8170
82- let flow_type = request_data
83- . flow_type
84- . parse :: < FlowType > ( )
85- . map_err ( |( ) | ApiError :: BadRequest ( "Invalid flow type" . into ( ) ) ) ?;
86-
8771 let request = AuthInfoRequest {
88- redirect_url : state. callback_url ( & flow_type) . to_string ( ) ,
72+ redirect_url : state. callback_url ( & request_data . flow_type ) . to_string ( ) ,
8973 state : request_data. state ,
9074 } ;
9175
@@ -127,7 +111,7 @@ pub(super) struct AuthenticationResponse {
127111 pub ( super ) code : String ,
128112 pub ( super ) state : String ,
129113 #[ serde( rename = "type" ) ]
130- pub ( super ) flow_type : String ,
114+ pub ( super ) flow_type : FlowType ,
131115}
132116
133117#[ derive( Serialize ) ]
@@ -143,15 +127,13 @@ async fn auth_callback(
143127 mut private_cookies : PrivateCookieJar ,
144128 Json ( payload) : Json < AuthenticationResponse > ,
145129) -> Result < ( PrivateCookieJar , Json < CallbackResponseData > ) , ApiError > {
146- let flow_type = payload
147- . flow_type
148- . parse :: < FlowType > ( )
149- . map_err ( |( ) | ApiError :: BadRequest ( "Invalid flow type" . into ( ) ) ) ?;
150-
151- if flow_type != FlowType :: Enrollment {
152- return Err ( ApiError :: BadRequest (
153- "Invalid flow type for OpenID enrollment callback" . into ( ) ,
154- ) ) ;
130+ match payload. flow_type {
131+ FlowType :: Enrollment => ( ) ,
132+ FlowType :: Mfa => {
133+ return Err ( ApiError :: BadRequest (
134+ "Invalid flow type for OpenID enrollment callback" . into ( ) ,
135+ ) ) ;
136+ }
155137 }
156138
157139 let nonce = private_cookies
@@ -176,7 +158,7 @@ async fn auth_callback(
176158 let request = AuthCallbackRequest {
177159 code : payload. code ,
178160 nonce,
179- callback_url : state. callback_url ( & flow_type) . to_string ( ) ,
161+ callback_url : state. callback_url ( & payload . flow_type ) . to_string ( ) ,
180162 } ;
181163
182164 let rx = state
0 commit comments