|
1 | 1 | from typing import Any |
2 | 2 |
|
3 | | -from fastapi import APIRouter, Depends |
| 3 | +from fastapi import APIRouter, Depends, HTTPException, Response, status |
| 4 | +from fastapi.responses import StreamingResponse |
4 | 5 | from sqlalchemy.orm import Session |
5 | 6 |
|
6 | 7 | from fastapi_2fa.api.deps.db import get_db |
7 | 8 | from fastapi_2fa.api.deps.users import get_authenticated_user |
| 9 | +from fastapi_2fa.core.enums import DeviceTypeEnum |
| 10 | +from fastapi_2fa.core.utils import send_backup_tokens |
8 | 11 | from fastapi_2fa.crud.device import device_crud |
9 | 12 | from fastapi_2fa.crud.users import user_crud |
10 | 13 | from fastapi_2fa.models.users import User |
11 | 14 | from fastapi_2fa.schemas.device_schema import DeviceCreate |
12 | | -from fastapi_2fa.schemas.user_schema import UserOut, UserUpdate |
| 15 | +from fastapi_2fa.schemas.user_schema import UserUpdate |
13 | 16 |
|
14 | 17 | tfa_router = APIRouter() |
15 | 18 |
|
16 | 19 |
|
17 | 20 | @tfa_router.put( |
18 | 21 | "/enable_tfa", |
19 | 22 | summary="Enable two factor authentication for registered user", |
20 | | - response_model=UserOut, |
| 23 | + responses={ |
| 24 | + 200: { |
| 25 | + "content": {"image/png": {}}, |
| 26 | + "description": "Returns no content or a qr code " |
| 27 | + "if device_type is 'code_generator'", |
| 28 | + } |
| 29 | + }, |
21 | 30 | ) |
22 | 31 | async def enable_tfa( |
23 | 32 | device: DeviceCreate, |
24 | 33 | db: Session = Depends(get_db), |
25 | 34 | user: User = Depends(get_authenticated_user), |
26 | 35 | ) -> Any: |
27 | | - if not user_crud.is_tfa_enabled(user): |
28 | | - async with db.begin_nested(): |
29 | | - user = await user_crud.update( |
30 | | - db=db, |
31 | | - db_obj=user, |
32 | | - obj_in=UserUpdate(tfa_enabled=True) |
33 | | - ) |
34 | | - await device_crud.create( |
35 | | - db=db, |
36 | | - device=device, |
37 | | - user=user |
38 | | - ) |
39 | | - await db.refresh(user) |
40 | | - return user |
| 36 | + if not user_crud(transaction=True).is_tfa_enabled(user): |
| 37 | + user = await user_crud.update( |
| 38 | + db=db, |
| 39 | + db_obj=user, |
| 40 | + obj_in=UserUpdate(tfa_enabled=True) |
| 41 | + ) |
| 42 | + device, qr_code = await device_crud(transaction=True).create( |
| 43 | + db=db, |
| 44 | + device=device, |
| 45 | + user=user |
| 46 | + ) |
| 47 | + |
| 48 | + send_backup_tokens(user=user, device=device) |
| 49 | + |
| 50 | + if device.device_type == DeviceTypeEnum.CODE_GENERATOR: |
| 51 | + return StreamingResponse(content=qr_code, media_type="image/png") |
| 52 | + |
| 53 | + return Response(status_code=status.HTTP_200_OK) |
| 54 | + |
| 55 | + raise HTTPException( |
| 56 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 57 | + detail='Two factor authentication already ' |
| 58 | + f'active for user {user.email}' |
| 59 | + ) |
0 commit comments