55from sqlalchemy .orm import Session
66
77from fastapi_2fa .api .deps .db import get_db
8- from fastapi_2fa .api .deps .users import get_authenticated_user
8+ from fastapi_2fa .api .deps .users import (get_authenticated_user ,
9+ get_authenticated_user_pre_tfa )
10+ from fastapi_2fa .core import security
911from fastapi_2fa .core .enums import DeviceTypeEnum
10- from fastapi_2fa .core .utils import send_backup_tokens
12+ from fastapi_2fa .core .two_factor_auth import (qr_code_from_key ,
13+ verify_backup_token ,
14+ verify_token )
15+ from fastapi_2fa .core .utils import send_mail_backup_tokens
16+ from fastapi_2fa .crud .backup_token import backup_token_crud
1117from fastapi_2fa .crud .device import device_crud
1218from fastapi_2fa .crud .users import user_crud
1319from fastapi_2fa .models .users import User
1420from fastapi_2fa .schemas .device_schema import DeviceCreate
21+ from fastapi_2fa .schemas .jwt_token_schema import JwtTokenSchema
1522from fastapi_2fa .schemas .user_schema import UserUpdate
1623
1724tfa_router = APIRouter ()
1825
1926
27+ @tfa_router .post (
28+ "/login_tfa" ,
29+ summary = "Verify two factor authentication token" ,
30+ response_model = JwtTokenSchema ,
31+ )
32+ async def login_tfa (
33+ tfa_token : str ,
34+ db : Session = Depends (get_db ),
35+ user : User = Depends (get_authenticated_user_pre_tfa ),
36+ ) -> Any :
37+ if verify_token (user = user , token = tfa_token ):
38+ return JwtTokenSchema (
39+ access_token = security .create_jwt_access_token (user .id ),
40+ refresh_token = security .create_jwt_refresh_token (user .id ),
41+ )
42+
43+ raise HTTPException (
44+ status_code = status .HTTP_403_FORBIDDEN ,
45+ detail = "TOTP token mismatch"
46+ )
47+
48+
49+ @tfa_router .post (
50+ "/recover_tfa" ,
51+ summary = "Checks and consumes one of the user's backup tokens re initializing" ,
52+ response_model = JwtTokenSchema ,
53+ )
54+ async def recover_tfa (
55+ tfa_backup_token : str ,
56+ db : Session = Depends (get_db ),
57+ user : User = Depends (get_authenticated_user_pre_tfa ),
58+ ) -> Any :
59+ if backup_tokens := await backup_token_crud .get_user_backup_tokens (
60+ db = db ,
61+ user = user
62+ ):
63+ matched_bkp_token = verify_backup_token (
64+ backup_tokens = backup_tokens ,
65+ tfa_backup_token = tfa_backup_token
66+ )
67+
68+ if matched_bkp_token :
69+ print ('..consuming backup token' )
70+ await backup_token_crud .remove (
71+ db = db , id = matched_bkp_token .id
72+ )
73+ return JwtTokenSchema (
74+ access_token = security .create_jwt_access_token (user .id ),
75+ refresh_token = security .create_jwt_refresh_token (user .id ),
76+ )
77+
78+ raise HTTPException (
79+ status_code = status .HTTP_403_FORBIDDEN ,
80+ detail = "TOTP backup token not found"
81+ )
82+
83+ # user has elapsed all backup tokens
84+ raise HTTPException (
85+ status_code = status .HTTP_404_NOT_FOUND ,
86+ detail = f"User { user .email } has elapsed his backup tokens, "
87+ "please contact the system administrator"
88+ )
89+
90+
91+ @tfa_router .get (
92+ "/get_my_qrcode" ,
93+ summary = "Returns authenticated user's qr_code "
94+ "if user's device is of type 'code_generator'" ,
95+ responses = {
96+ 200 : {
97+ "content" : {"image/png" : {}},
98+ "description" : "Returns no content or a qr code "
99+ "if tfas is enabled and device_type "
100+ "is 'code_generator'" ,
101+ }
102+ },
103+ )
104+ async def get_my_qrcode (
105+ user : User = Depends (get_authenticated_user ),
106+ ) -> Any :
107+ if (
108+ user .tfa_enabled and
109+ user_crud .device .device_type == DeviceTypeEnum .CODE_GENERATOR
110+ ):
111+ qr_code = qr_code_from_key (
112+ encoded_key = user .device .key ,
113+ user_email = user .email
114+ )
115+ return StreamingResponse (content = qr_code , media_type = "image/png" )
116+
117+ # user has elapsed all backup tokens
118+ raise HTTPException (
119+ status_code = status .HTTP_400_BAD_REQUEST ,
120+ detail = f"User { user .email } has not tfa enabled or "
121+ "has not a 'code_generator' device"
122+ )
123+
124+
20125@tfa_router .put (
21126 "/enable_tfa" ,
22127 summary = "Enable two factor authentication for registered user" ,
@@ -33,8 +138,8 @@ async def enable_tfa(
33138 db : Session = Depends (get_db ),
34139 user : User = Depends (get_authenticated_user ),
35140) -> Any :
36- if not user_crud ( transaction = True ). is_tfa_enabled ( user ) :
37- user = await user_crud .update (
141+ if not user . tfa_enabled :
142+ user = await user_crud ( transaction = True ) .update (
38143 db = db ,
39144 db_obj = user ,
40145 obj_in = UserUpdate (tfa_enabled = True )
@@ -45,7 +150,7 @@ async def enable_tfa(
45150 user = user
46151 )
47152
48- send_backup_tokens (user = user , device = device )
153+ send_mail_backup_tokens (user = user , device = device )
49154
50155 if device .device_type == DeviceTypeEnum .CODE_GENERATOR :
51156 return StreamingResponse (content = qr_code , media_type = "image/png" )
0 commit comments