Skip to content

Commit d7d2047

Browse files
Upgrade demo to SQLAlchemy 2 ORM (#639)
1 parent 2033f41 commit d7d2047

File tree

11 files changed

+173
-315
lines changed

11 files changed

+173
-315
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Setup Python
2323
uses: actions/setup-python@v4
2424
with:
25-
python-version: 3.9
25+
python-version: '3.10'
2626
cache: 'pip'
2727
cache-dependency-path: '**/requirements*.txt'
2828
- name: Install dependencies

demo/database_auth/__main__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from aiohttp.web import run_app
2+
3+
from .main import init_app
4+
5+
if __name__ == "__main__":
6+
run_app(init_app())

demo/database_auth/db.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,37 @@
11
import sqlalchemy as sa
2+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
23

34

4-
metadata = sa.MetaData()
5+
class Base(DeclarativeBase):
6+
metadata = sa.MetaData(naming_convention={
7+
"ix": "ix_%(column_0_label)s",
8+
"uq": "uq_%(table_name)s_%(column_0_name)s",
9+
"ck": "ck_%(table_name)s_%(column_0_name)s",
10+
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
11+
"pk": "pk_%(table_name)s"
12+
})
513

614

7-
users = sa.Table(
8-
'users', metadata,
9-
sa.Column('id', sa.Integer, nullable=False),
10-
sa.Column('login', sa.String(256), nullable=False),
11-
sa.Column('passwd', sa.String(256), nullable=False),
12-
sa.Column('is_superuser', sa.Boolean, nullable=False,
13-
server_default='FALSE'),
14-
sa.Column('disabled', sa.Boolean, nullable=False,
15-
server_default='FALSE'),
15+
class User(Base):
16+
"""A user and their credentials."""
1617

17-
# indices
18-
sa.PrimaryKeyConstraint('id', name='user_pkey'),
19-
sa.UniqueConstraint('login', name='user_login_key'),
20-
)
18+
__tablename__ = "users"
2119

20+
id: Mapped[int] = mapped_column(primary_key=True)
21+
username: Mapped[str] = mapped_column(sa.String(256), unique=True, index=True)
22+
password: Mapped[str] = mapped_column(sa.String(256))
23+
is_superuser: Mapped[bool] = mapped_column(
24+
default=False, server_default=sa.sql.expression.false())
25+
disabled: Mapped[bool] = mapped_column(
26+
default=False, server_default=sa.sql.expression.false())
27+
permissions = relationship("Permission", cascade="all, delete")
2228

23-
permissions = sa.Table(
24-
'permissions', metadata,
25-
sa.Column('id', sa.Integer, nullable=False),
26-
sa.Column('user_id', sa.Integer, nullable=False),
27-
sa.Column('perm_name', sa.String(64), nullable=False),
2829

29-
# indices
30-
sa.PrimaryKeyConstraint('id', name='permission_pkey'),
31-
sa.ForeignKeyConstraint(['user_id'], [users.c.id],
32-
name='user_permission_fkey',
33-
ondelete='CASCADE'),
34-
)
30+
class Permission(Base):
31+
"""A permission that grants a user access to something."""
32+
33+
__tablename__ = "permissions"
34+
35+
user_id: Mapped[int] = mapped_column(
36+
sa.ForeignKey(User.id, ondelete="CASCADE"), primary_key=True)
37+
name: Mapped[str] = mapped_column(sa.String(64), primary_key=True)

demo/database_auth/db_auth.py

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,52 @@
11
from enum import Enum
2-
from typing import Any, Optional, Union
32

43
import sqlalchemy as sa
54
from passlib.hash import sha256_crypt
5+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
6+
from sqlalchemy.orm import selectinload
67

78
from aiohttp_security.abc import AbstractAuthorizationPolicy
8-
from . import db
9+
from .db import User
10+
11+
12+
def _where_authorized(identity: str) -> tuple[sa.sql.ColumnElement[bool], ...]:
13+
return (User.username == identity, ~User.disabled)
914

1015

1116
class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
12-
def __init__(self, dbengine: Any):
13-
self.dbengine = dbengine
14-
15-
async def authorized_userid(self, identity: str) -> Optional[str]:
16-
async with self.dbengine.acquire() as conn:
17-
where = sa.and_(db.users.c.login == identity,
18-
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
19-
query = db.users.count().where(where)
20-
ret = await conn.scalar(query)
21-
if ret:
22-
return identity
23-
else:
24-
return None
25-
26-
async def permits(self, identity: Optional[str], permission: Union[str, Enum],
27-
context: None = None) -> bool:
28-
async with self.dbengine.acquire() as conn:
29-
where = sa.and_(db.users.c.login == identity,
30-
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
31-
query = db.users.select().where(where)
32-
ret = await conn.execute(query)
33-
user = await ret.fetchone()
34-
if user is not None:
35-
user_id = user[0]
36-
is_superuser = user[3]
37-
if is_superuser:
38-
return True
39-
40-
where = db.permissions.c.user_id == user_id
41-
query = db.permissions.select().where(where)
42-
ret = await conn.execute(query)
43-
result = await ret.fetchall()
44-
if ret is not None:
45-
for record in result:
46-
if record.perm_name == permission:
47-
return True
17+
def __init__(self, dbsession: async_sessionmaker[AsyncSession]):
18+
self.dbsession = dbsession
19+
20+
async def authorized_userid(self, identity: str) -> str | None:
21+
where = _where_authorized(identity)
22+
async with self.dbsession() as sess:
23+
user_id = await sess.scalar(sa.select(User.id).where(*where))
24+
return str(user_id) if user_id else None
25+
26+
async def permits(self, identity: str | None, permission: str | Enum,
27+
context: dict[str, object] | None = None) -> bool:
28+
if identity is None:
29+
return False
4830

31+
where = _where_authorized(identity)
32+
stmt = sa.select(User).options(selectinload(User.permissions)).where(*where)
33+
async with self.dbsession() as sess:
34+
user = await sess.scalar(stmt)
35+
36+
if user is None:
4937
return False
38+
if user.is_superuser:
39+
return True
40+
return any(p.name == permission for p in user.permissions)
41+
42+
43+
async def check_credentials(db_session: async_sessionmaker[AsyncSession],
44+
username: str, password: str) -> bool:
45+
where = _where_authorized(username)
46+
async with db_session() as sess:
47+
hashed_pw = await sess.scalar(sa.select(User.password).where(*where))
5048

49+
if hashed_pw is None:
50+
return False
5151

52-
async def check_credentials(db_engine: Any, username: str, password: str) -> bool:
53-
async with db_engine.acquire() as conn:
54-
where = sa.and_(db.users.c.login == username,
55-
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
56-
query = db.users.select().where(where)
57-
ret = await conn.execute(query)
58-
user = await ret.fetchone()
59-
if user is not None:
60-
hashed = user[2]
61-
return sha256_crypt.verify(password, hashed)
62-
return False
52+
return sha256_crypt.verify(password, hashed_pw)

demo/database_auth/handlers.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .db_auth import check_credentials
99

1010

11-
class Web(object):
11+
class Web:
1212
index_template = dedent("""
1313
<!doctype html>
1414
<head></head>
@@ -32,20 +32,18 @@ async def index(self, request: web.Request) -> web.Response:
3232
message='Hello, {username}!'.format(username=username))
3333
else:
3434
template = self.index_template.format(message='You need to login')
35-
response = web.Response(body=template.encode())
36-
return response
35+
return web.Response(text=template, content_type="text/html")
3736

3837
async def login(self, request: web.Request) -> NoReturn:
3938
invalid_resp = web.HTTPUnauthorized(body=b"Invalid username/password combination")
4039
form = await request.post()
4140
login = form.get('login')
4241
password = form.get('password')
43-
db_engine = request.app["db_engine"]
4442

4543
if not (isinstance(login, str) and isinstance(password, str)):
4644
raise invalid_resp
4745

48-
if await check_credentials(db_engine, login, password):
46+
if await check_credentials(request.app["db_session"], login, password):
4947
response = web.HTTPFound("/")
5048
await remember(request, response, login)
5149
raise response
@@ -54,26 +52,22 @@ async def login(self, request: web.Request) -> NoReturn:
5452

5553
async def logout(self, request: web.Request) -> web.Response:
5654
await check_authorized(request)
57-
response = web.Response(body=b'You have been logged out')
55+
response = web.Response(text="You have been logged out")
5856
await forget(request, response)
5957
return response
6058

6159
async def internal_page(self, request: web.Request) -> web.Response:
6260
await check_permission(request, 'public')
63-
response = web.Response(
64-
body=b'This page is visible for all registered users')
65-
return response
61+
return web.Response(text="This page is visible for all registered users")
6662

6763
async def protected_page(self, request: web.Request) -> web.Response:
6864
await check_permission(request, 'protected')
69-
response = web.Response(body=b'You are on protected page')
70-
return response
65+
return web.Response(text="You are on protected page")
7166

7267
def configure(self, app: web.Application) -> None:
7368
router = app.router
7469
router.add_route('GET', '/', self.index, name='index')
7570
router.add_route('POST', '/login', self.login, name='login')
7671
router.add_route('GET', '/logout', self.logout, name='logout')
7772
router.add_route('GET', '/public', self.internal_page, name='public')
78-
router.add_route('GET', '/protected', self.protected_page,
79-
name='protected')
73+
router.add_route('GET', '/protected', self.protected_page, name='protected')

demo/database_auth/main.py

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,48 @@
1-
import asyncio
2-
from typing import Tuple
3-
41
from aiohttp import web
5-
from aiohttp_session import setup as setup_session
6-
from aiohttp_session.redis_storage import RedisStorage
7-
from aiopg.sa import create_engine
8-
from aioredis import create_pool # type: ignore[attr-defined]
2+
from aiohttp_session import SimpleCookieStorage, setup as setup_session
3+
from sqlalchemy.ext.asyncio import (AsyncEngine, AsyncSession, async_sessionmaker,
4+
create_async_engine)
95

106
from aiohttp_security import SessionIdentityPolicy
117
from aiohttp_security import setup as setup_security
8+
from .db import Base, Permission, User
129
from .db_auth import DBAuthorizationPolicy
1310
from .handlers import Web
1411

1512

16-
async def init(loop: asyncio.AbstractEventLoop) -> Tuple[asyncio.Server, web.Application,
17-
web.Server]:
18-
redis_pool = await create_pool(('localhost', 6379))
19-
db_engine = await create_engine( # type: ignore[no-untyped-call] # noqa: S106
20-
user="aiohttp_security", password="aiohttp_security",
21-
database="aiohttp_security", host="127.0.0.1")
13+
async def init_db(db_engine: AsyncEngine, db_session: async_sessionmaker[AsyncSession]) -> None:
14+
"""Initialise DB with sample data."""
15+
async with db_engine.begin() as conn:
16+
await conn.run_sync(Base.metadata.create_all)
17+
async with db_session.begin() as sess:
18+
pw = "$5$rounds=535000$2kqN9fxCY6Xt5/pi$tVnh0xX87g/IsnOSuorZG608CZDFbWIWBr58ay6S4pD"
19+
sess.add(User(username="admin", password=pw, is_superuser=True))
20+
moderator = User(username="moderator", password=pw)
21+
user = User(username="user", password=pw)
22+
sess.add(moderator)
23+
sess.add(user)
24+
async with db_session.begin() as sess:
25+
sess.add(Permission(user_id=moderator.id, name="protected"))
26+
sess.add(Permission(user_id=moderator.id, name="public"))
27+
sess.add(Permission(user_id=user.id, name="public"))
28+
29+
30+
async def init_app() -> web.Application:
2231
app = web.Application()
23-
app["db_engine"] = db_engine
24-
setup_session(app, RedisStorage(redis_pool))
25-
setup_security(app,
26-
SessionIdentityPolicy(),
27-
DBAuthorizationPolicy(db_engine))
28-
29-
web_handlers = Web()
30-
web_handlers.configure(app)
3132

32-
handler = app.make_handler()
33-
srv = await loop.create_server(handler, '127.0.0.1', 8080)
34-
print('Server started at http://127.0.0.1:8080')
35-
return srv, app, handler
33+
db_engine = create_async_engine("sqlite+aiosqlite:///:memory:")
34+
app["db_session"] = async_sessionmaker(db_engine, expire_on_commit=False)
3635

36+
await init_db(db_engine, app["db_session"])
3737

38-
async def finalize(srv: asyncio.Server, app: web.Application, handler: web.Server) -> None:
39-
sock = srv.sockets[0]
40-
app.loop.remove_reader(sock.fileno())
41-
sock.close()
42-
43-
await handler.shutdown(1.0)
44-
srv.close()
45-
await srv.wait_closed()
46-
await app.cleanup()
38+
setup_session(app, SimpleCookieStorage())
39+
setup_security(app, SessionIdentityPolicy(), DBAuthorizationPolicy(app["db_session"]))
4740

41+
web_handlers = Web()
42+
web_handlers.configure(app)
4843

49-
def main() -> None:
50-
loop = asyncio.get_event_loop()
51-
srv, app, handler = loop.run_until_complete(init(loop))
52-
try:
53-
loop.run_forever()
54-
except KeyboardInterrupt:
55-
loop.run_until_complete((finalize(srv, app, handler)))
44+
return app
5645

5746

58-
if __name__ == '__main__':
59-
main()
47+
if __name__ == "__main__":
48+
web.run_app(init_app())

demo/database_auth/sql/init_db.sql

Lines changed: 0 additions & 5 deletions
This file was deleted.

demo/database_auth/sql/sample_data.sql

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)