Skip to content

Commit 207e33b

Browse files
committed
feat: add the 'get_current_user' dependency
1 parent 9705484 commit 207e33b

File tree

8 files changed

+127
-154
lines changed

8 files changed

+127
-154
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ def vip_roles_and_article_update(request: Request):
145145

146146
```
147147

148-
### 依赖项
148+
### 依赖项(推荐)
149149

150-
- 推荐场景: 路由集合,FastAPI应用
150+
- 推荐场景: 单个路由,路由集合,FastAPI应用.
151151

152152
```python
153153
from fastapi import Depends
@@ -156,11 +156,10 @@ from fastapi_user_auth.auth import Auth
156156
from fastapi_user_auth.auth.models import User
157157

158158

159-
# 路由参数依赖项
160-
@app.get("/auth/admin_roles_depend_1")
161-
def admin_roles(request: Request,
162-
auth_result: Tuple[Auth, User] = Depends(auth.requires('admin')())):
163-
return request.user
159+
# 路由参数依赖项, 推荐使用此方式
160+
@app.get("/auth/admin_roles_depend_1")
161+
def admin_roles(user: User = Depends(auth.get_current_user)):
162+
return user # or request.user
164163

165164

166165
# 路径操作装饰器依赖项
@@ -200,6 +199,7 @@ from fastapi_user_auth.auth.models import User
200199

201200

202201
async def get_request_user(request: Request) -> Optional[User]:
202+
# user= await auth.get_current_user(request)
203203
if await auth.requires('admin', response=False)(request):
204204
return request.user
205205
else:

fastapi_user_auth/admin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fastapi_user_auth.auth.models import BaseUser, User, Group, Permission, Role
1212
from fastapi_user_auth.auth.schemas import UserLoginOut
1313
from pydantic import BaseModel
14-
from sqlalchemy import insert, update
14+
from sqlalchemy import insert, update, select
1515
from starlette import status
1616
from starlette.requests import Request
1717
from starlette.responses import Response
@@ -126,10 +126,10 @@ async def handle(
126126
**kwargs
127127
) -> BaseApiOut[BaseModel]: # self.schema_submit_out
128128
auth: Auth = request.auth
129-
user = await auth.get_user_by_username(data.username)
129+
user = await auth.db.scalar(select(self.user_model).where(self.user_model.username == data.username))
130130
if user:
131131
return BaseApiOut(status = -1, msg = _('Username has been registered!'), data = None)
132-
user = await auth.get_user_by_whereclause(self.user_model.email == data.email)
132+
user = await auth.db.scalar(select(self.user_model).where(self.user_model.email == data.email))
133133
if user:
134134
return BaseApiOut(status = -2, msg = _('Email has been registered!'), data = None)
135135
user = self.user_model.parse_obj(data)

fastapi_user_auth/auth/auth.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Coroutine
66
from typing import Type, Any, TypeVar, Optional, Sequence, Tuple, Union, Callable, Generic
77

8-
from fastapi import FastAPI, HTTPException, Depends, Form
8+
from fastapi import FastAPI, HTTPException, Depends, Form, params
99
from fastapi.security import OAuth2PasswordBearer
1010
from fastapi.security.utils import get_authorization_scheme_param
1111
from fastapi_amis_admin.crud.base import RouterMixin
@@ -15,6 +15,7 @@
1515
from fastapi_amis_admin.utils.translation import i18n as _
1616
from passlib.context import CryptContext
1717
from pydantic import BaseModel, SecretStr
18+
from sqlalchemy.ext.asyncio import AsyncSession
1819
from sqlalchemy.orm import Session
1920
from sqlalchemy_database import AsyncDatabase, Database
2021
from sqlmodel import select
@@ -46,16 +47,7 @@ def get_user_token(request: Request) -> Optional[str]:
4647
return token
4748

4849
async def authenticate(self, request: Request) -> Tuple["Auth", Optional[_UserModelT]]:
49-
if request.scope.get('auth'): # 防止重复授权
50-
return request.scope.get('auth'), request.scope.get('user')
51-
request.scope["auth"], request.scope["user"] = self.auth, None
52-
token = self.get_user_token(request)
53-
if not token:
54-
return self.auth, None
55-
token_data = await self.token_store.read_token(token)
56-
if token_data is not None:
57-
request.scope["user"]: _UserModelT = await self.auth.get_user_by_username(token_data.username)
58-
return request.auth, request.user
50+
return self.auth, await self.auth.get_current_user(request)
5951

6052
def attach_middleware(self, app: FastAPI):
6153
app.add_middleware(AuthenticationMiddleware, backend = self) # 添加auth中间件
@@ -78,21 +70,34 @@ def __init__(
7870
self.backend = self.backend or AuthBackend(self, token_store or DbTokenStore(self.db))
7971
self.pwd_context = pwd_context
8072

81-
async def get_user_by_username(self, username: str) -> Optional[_UserModelT]:
82-
return await self.get_user_by_whereclause(self.user_model.username == username)
83-
84-
async def get_user_by_whereclause(self, *whereclause: Any) -> Optional[_UserModelT]:
85-
return await self.db.async_scalar(select(self.user_model).where(*whereclause))
86-
8773
async def authenticate_user(self, username: str, password: Union[str, SecretStr]) -> Optional[_UserModelT]:
88-
user = await self.get_user_by_username(username)
74+
user = await self.db.async_scalar(select(self.user_model).where(self.user_model.username == username))
8975
if user:
9076
pwd = password.get_secret_value() if isinstance(password, SecretStr) else password
9177
pwd2 = user.password.get_secret_value() if isinstance(user.password, SecretStr) else user.password
9278
if self.pwd_context.verify(pwd, pwd2): # 用户存在 且 密码验证通过
9379
return user
9480
return None
9581

82+
@cached_property
83+
def get_current_user(self):
84+
async def _get_current_user(
85+
request: Request,
86+
session: Union[Session, AsyncSession, None] = Depends(self.db.session_generator)
87+
) -> Optional[_UserModelT]:
88+
if request.scope.get('auth'): # 防止重复授权
89+
return request.scope.get('user')
90+
request.scope["auth"], request.scope["user"] = self, None
91+
token = self.backend.get_user_token(request)
92+
if not token:
93+
return None
94+
token_data = await self.backend.token_store.read_token(token)
95+
if token_data is not None:
96+
request.scope["user"]: _UserModelT = await self.db.async_get(self.user_model, token_data.id, session = session)
97+
return request.user
98+
99+
return _get_current_user
100+
96101
def requires(
97102
self,
98103
roles: Union[str, Sequence[str]] = None,
@@ -103,21 +108,22 @@ def requires(
103108
response: Union[bool, Response] = None,
104109
) -> Callable: # sourcery no-metrics
105110

106-
async def has_requires(conn: HTTPConnection) -> bool:
107-
# todo websocket support
108-
await self.backend.authenticate(conn) # type:ignore
109-
if not conn.user:
110-
return False
111-
return await self.db.async_run_sync(
112-
conn.user.has_requires,
111+
async def has_requires(user: _UserModelT) -> bool:
112+
return user and await self.db.async_run_sync(
113+
user.has_requires,
113114
roles = roles,
114115
groups = groups,
115116
permissions = permissions,
116117
is_session = True
117118
)
118119

119-
async def depend(request: Request) -> Union[bool, Response]:
120-
if not await has_requires(request):
120+
async def depend(
121+
request: Request,
122+
user: _UserModelT = Depends(self.get_current_user),
123+
) -> Union[bool, Response]:
124+
if isinstance(user, params.Depends):
125+
user = await self.get_current_user(request)
126+
if not await has_requires(user):
121127
if response is not None:
122128
return response
123129
code, headers = status_code, {}
@@ -150,7 +156,8 @@ async def websocket_wrapper(
150156
) -> None:
151157
websocket = kwargs.get("websocket", args[idx] if args else None)
152158
assert isinstance(websocket, WebSocket)
153-
if not await has_requires(websocket):
159+
user = await self.get_current_user(websocket) # type: ignore
160+
if not await has_requires(user):
154161
await websocket.close()
155162
else:
156163
await func(*args, **kwargs)

tests/test_auth/conftest.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,41 @@
1515
from fastapi_user_auth.auth.models import User
1616
from tests.conftest import async_db, sync_db
1717

18-
19-
@pytest.fixture(params=[async_db, sync_db])
18+
@pytest.fixture(params = [async_db, sync_db])
2019
async def db(request) -> Union[Database, AsyncDatabase]:
2120
database = request.param
22-
await database.async_run_sync(SQLModel.metadata.create_all, is_session=False)
21+
await database.async_run_sync(SQLModel.metadata.create_all, is_session = False)
2322
yield database
24-
await database.async_run_sync(SQLModel.metadata.drop_all, is_session=False)
25-
23+
await database.async_run_sync(SQLModel.metadata.drop_all, is_session = False)
2624

2725
app = FastAPI()
2826
# 创建auth实例
29-
auth = Auth(db=async_db)
27+
auth = Auth(db = async_db)
3028
# 注册auth基础路由
31-
auth_router = AuthRouter(auth=auth)
29+
auth_router = AuthRouter(auth = auth)
3230
app.include_router(auth_router.router)
3331

34-
3532
class UserClient:
33+
3634
def __init__(self, client: TestClient = None, user: User = None) -> None:
3735
self.client: TestClient = client or TestClient(app)
3836
self.user: User = user
3937

40-
4138
def get_login_client(username: str = None, password: str = None) -> UserClient:
4239
client = TestClient(app)
4340
if not username or not password:
4441
return UserClient()
45-
response = client.post('/auth/gettoken',
46-
data={'username': username, 'password': password},
47-
headers={"Content-Type": "application/x-www-form-urlencoded"})
42+
response = client.post(
43+
'/auth/gettoken',
44+
data = {'username': username, 'password': password},
45+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
46+
)
4847
data = response.json()
4948
assert data['data']['access_token']
5049
user = User.parse_obj(data['data'])
5150
assert user.is_active
5251
assert user.username == username
53-
return UserClient(client=client, user=user)
54-
52+
return UserClient(client = client, user = user)
5553

5654
@pytest.fixture
5755
def logins(request) -> UserClient:
@@ -64,52 +62,54 @@ def logins(request) -> UserClient:
6462
user = user_data.get(request.param) or {}
6563
return get_login_client(**user)
6664

67-
68-
@pytest.fixture(scope="session")
65+
@pytest.fixture(scope = "session")
6966
async def prepare_database() -> AsyncGenerator[None, None]:
70-
await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session=False)
67+
await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session = False)
7168
yield
72-
await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session=False)
73-
69+
await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session = False)
7470

75-
@pytest.fixture(scope="session")
71+
@pytest.fixture(scope = "session")
7672
def event_loop():
7773
loop = asyncio.get_event_loop_policy().new_event_loop()
7874
yield loop
7975
loop.close()
8076

81-
82-
@pytest_asyncio.fixture(scope="session", autouse=True)
77+
@pytest_asyncio.fixture(scope = "session", autouse = True)
8378
async def fake_users(prepare_database):
8479
await auth.db.async_run_sync(create_fake_users)
8580

86-
8781
# noinspection PyTypeChecker
8882
def create_fake_users(session: Session):
8983
# init permission
90-
admin_perm = Permission(key='admin', name='admin permission')
91-
vip_perm = Permission(key='vip', name='vip permission')
92-
test_perm = Permission(key='test', name='test permission')
84+
admin_perm = Permission(key = 'admin', name = 'admin permission')
85+
vip_perm = Permission(key = 'vip', name = 'vip permission')
86+
test_perm = Permission(key = 'test', name = 'test permission')
9387
session.add_all([admin_perm, vip_perm, test_perm])
9488
session.flush([admin_perm, vip_perm, test_perm])
9589
# init role
96-
admin_role = Role(key='admin', name='admin role', permissions=[admin_perm])
97-
vip_role = Role(key='vip', name='vip role', permissions=[vip_perm])
98-
test_role = Role(key='test', name='test role', permissions=[test_perm])
90+
admin_role = Role(key = 'admin', name = 'admin role', permissions = [admin_perm])
91+
vip_role = Role(key = 'vip', name = 'vip role', permissions = [vip_perm])
92+
test_role = Role(key = 'test', name = 'test role', permissions = [test_perm])
9993
session.add_all([admin_role, vip_role, test_role])
10094
session.flush([admin_role, vip_role, test_role])
10195
# init group
102-
admin_group = Group(key='admin', name='admin group', roles=[admin_role])
103-
vip_group = Group(key='vip', name='vip group', roles=[vip_role])
104-
test_group = Group(key='test', name='test group', roles=[test_role])
96+
admin_group = Group(key = 'admin', name = 'admin group', roles = [admin_role])
97+
vip_group = Group(key = 'vip', name = 'vip group', roles = [vip_role])
98+
test_group = Group(key = 'test', name = 'test group', roles = [test_role])
10599
session.add_all([admin_group, vip_group, test_group])
106100
session.flush([admin_group, vip_group, test_group])
107101
# init user
108-
admin_user = User(username='admin', password=auth.pwd_context.hash('admin'), email='[email protected]',
109-
roles=[admin_role], groups=[admin_group])
110-
vip_user = User(username='vip', password=auth.pwd_context.hash('vip'), email='[email protected]', roles=[vip_role],
111-
groups=[vip_group])
112-
test_user = User(username='test', password=auth.pwd_context.hash('test'), email='[email protected]', roles=[test_role],
113-
groups=[test_group])
102+
admin_user = User(
103+
username = 'admin', password = auth.pwd_context.hash('admin'), email = '[email protected]',
104+
roles = [admin_role], groups = [admin_group]
105+
)
106+
vip_user = User(
107+
username = 'vip', password = auth.pwd_context.hash('vip'), email = '[email protected]', roles = [vip_role],
108+
groups = [vip_group]
109+
)
110+
test_user = User(
111+
username = 'test', password = auth.pwd_context.hash('test'), email = '[email protected]', roles = [test_role],
112+
groups = [test_group]
113+
)
114114
session.add_all([admin_user, vip_user, test_user])
115115
session.flush([admin_user, vip_user, test_user])

tests/test_auth/test_auth.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from fastapi_user_auth.auth.models import User
55
from tests.test_auth.conftest import auth
66

7-
87
async def test_create_role_user():
98
user = await auth.create_role_user('admin2')
109
assert user.username == 'admin2'
@@ -14,7 +13,6 @@ async def test_create_role_user():
1413
role = result.roles[0]
1514
assert role.key == 'admin2'
1615

17-
1816
async def test_authenticate_user():
1917
# error
2018
user = await auth.authenticate_user('admin', 'admin1')
@@ -23,16 +21,3 @@ async def test_authenticate_user():
2321
# admin
2422
user = await auth.authenticate_user('admin', 'admin')
2523
assert user.username == 'admin'
26-
27-
28-
async def test_get_user_by_whereclause():
29-
user = await auth.get_user_by_whereclause(User.id == 1)
30-
assert user.username == 'admin'
31-
32-
user = await auth.get_user_by_whereclause(User.username == 'admin')
33-
assert user.username == 'admin'
34-
35-
36-
async def test_get_user_by_username():
37-
user = await auth.get_user_by_username('admin')
38-
assert user.username == 'admin'

tests/test_auth/test_auth_mount.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,34 @@
1111
app.mount('/subapp2', subapp2)
1212
auth.backend.attach_middleware(subapp2)
1313

14-
subapp3 = FastAPI(dependencies=[Depends(auth.requires('admin')())])
14+
subapp3 = FastAPI(dependencies = [Depends(auth.requires('admin')())])
1515
app.mount('/subapp3', subapp3)
1616

17-
1817
# auth decorator
1918
@subapp1.get("/auth/user")
2019
@auth.requires()
2120
def user(request: Request):
2221
return request.user
2322

24-
2523
@subapp2.get("/auth/user")
2624
def user_2(request: Request):
2725
if request.user:
2826
return request.user
2927
else:
30-
raise HTTPException(status_code=403)
31-
28+
raise HTTPException(status_code = 403)
3229

3330
@subapp3.get("/auth/user")
3431
@auth.requires()
3532
def user_3(request: Request):
3633
return request.user
3734

38-
3935
path_admin_auth = {
4036
"/subapp1/auth/user",
4137
"/subapp2/auth/user",
4238
"/subapp3/auth/user",
4339
}
4440

45-
46-
@pytest.mark.parametrize("logins", ['admin'], indirect=True)
41+
@pytest.mark.parametrize("logins", ['admin'], indirect = True)
4742
@pytest.mark.parametrize("path", list(path_admin_auth))
4843
def test_admin_auth(logins: UserClient, path):
4944
response = logins.client.get(path)

0 commit comments

Comments
 (0)