Skip to content

Commit 2b7ff87

Browse files
authored
Upgrade to SQLA v2 (#1250)
1 parent 22c4cb1 commit 2b7ff87

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1017
-1093
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
make install
7676
- name: Run tests
7777
run: |
78-
make test
78+
PYTHONASYNCIODEBUG=1 make test
7979
8080
doc:
8181
name: Documentation

demos/blog/aiohttpdemo_blog/db.py

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,83 @@
1-
from datetime import datetime as dt
2-
3-
import asyncpgsa
4-
from sqlalchemy import (
5-
MetaData, Table, Column, ForeignKey,
6-
Integer, String, DateTime
1+
from datetime import datetime
2+
3+
from sqlalchemy import ForeignKey, String
4+
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
5+
from sqlalchemy.orm import (
6+
DeclarativeBase,
7+
Mapped,
8+
mapped_column,
9+
relationship,
10+
selectinload,
711
)
812
from sqlalchemy.sql import select
913

10-
metadata = MetaData()
1114

15+
class Base(DeclarativeBase):
16+
pass
1217

13-
users = Table(
14-
'users', metadata,
1518

16-
Column('id', Integer, primary_key=True),
17-
Column('username', String(64), nullable=False, unique=True),
18-
Column('email', String(120)),
19-
Column('password_hash', String(128), nullable=False)
20-
)
19+
class Users(Base):
20+
__tablename__ = "users"
2121

22+
id: Mapped[int] = mapped_column(primary_key=True)
23+
username: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
24+
email: Mapped[str] = mapped_column(String(120))
25+
password_hash: Mapped[str] = mapped_column(String(128), nullable=False)
2226

23-
posts = Table(
24-
'posts', metadata,
27+
posts: Mapped[list["Posts"]] = relationship(
28+
back_populates="user", lazy="raise_on_sql"
29+
)
2530

26-
Column('id', Integer, primary_key=True),
27-
Column('body', String(140)),
28-
Column('timestamp', DateTime, index=True, default=dt.utcnow),
2931

30-
Column('user_id',
31-
Integer,
32-
ForeignKey('users.id'))
33-
)
32+
class Posts(Base):
33+
__tablename__ = "posts"
34+
35+
id: Mapped[int] = mapped_column(primary_key=True)
36+
body: Mapped[str] = mapped_column(String(140))
37+
timestamp: Mapped[datetime] = mapped_column(index=True, default=datetime.utcnow)
38+
39+
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
40+
user: Mapped[Users] = relationship(back_populates="posts", lazy="raise_on_sql")
3441

3542

3643
async def init_db(app):
37-
dsn = construct_db_url(app['config']['database'])
38-
pool = await asyncpgsa.create_pool(dsn=dsn)
39-
app['db_pool'] = pool
40-
return pool
44+
dsn = construct_db_url(app["config"]["database"])
45+
engine = create_async_engine(dsn)
46+
app["db_pool"] = async_sessionmaker(engine)
47+
48+
yield
49+
50+
await engine.dispose()
4151

4252

4353
def construct_db_url(config):
44-
DSN = "postgresql://{user}:{password}@{host}:{port}/{database}"
54+
DSN = "postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
4555
return DSN.format(
46-
user=config['DB_USER'],
47-
password=config['DB_PASS'],
48-
database=config['DB_NAME'],
49-
host=config['DB_HOST'],
50-
port=config['DB_PORT'],
56+
user=config["DB_USER"],
57+
password=config["DB_PASS"],
58+
database=config["DB_NAME"],
59+
host=config["DB_HOST"],
60+
port=config["DB_PORT"],
5161
)
5262

5363

54-
async def get_user_by_name(conn, username):
55-
result = await conn.fetchrow(
56-
users
57-
.select()
58-
.where(users.c.username == username)
59-
)
64+
async def get_user_by_name(sess, username):
65+
result = await sess.scalar(select(Users).where(Users.username == username))
6066
return result
6167

6268

63-
async def get_users(conn):
64-
records = await conn.fetch(
65-
users.select().order_by(users.c.id)
66-
)
67-
return records
69+
async def get_users(sess):
70+
records = await sess.scalars(select(Users).order_by(Users.id))
71+
return records.all()
6872

6973

70-
async def get_posts(conn):
71-
records = await conn.fetch(
72-
posts.select().order_by(posts.c.id)
73-
)
74-
return records
74+
async def get_posts(sess):
75+
records = await sess.scalars(select(Posts).order_by(Posts.id))
76+
return records.all()
7577

7678

77-
async def get_posts_with_joined_users(conn):
78-
j = posts.join(users, posts.c.user_id == users.c.id)
79-
stmt = select(
80-
[posts, users.c.username]).select_from(j).order_by(posts.c.timestamp)
81-
records = await conn.fetch(stmt)
82-
return records
83-
84-
85-
async def create_post(conn, post_body, user_id):
86-
stmt = posts.insert().values(body=post_body, user_id=user_id)
87-
await conn.execute(stmt)
79+
async def get_posts_with_joined_users(sess):
80+
records = await sess.scalars(
81+
select(Posts).options(selectinload(Posts.user)).order_by(Posts.timestamp)
82+
)
83+
return records.all()

demos/blog/aiohttpdemo_blog/db_auth.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44

55

66
class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
7-
8-
def __init__(self, db_pool):
9-
self.db_pool = db_pool
7+
def __init__(self, app):
8+
self.app = app
109

1110
async def authorized_userid(self, identity):
12-
async with self.db_pool.acquire() as conn:
13-
user = await db.get_user_by_name(conn, identity)
11+
async with self.app["db_pool"]() as sess:
12+
user = await db.get_user_by_name(sess, identity)
1413
if user:
1514
return identity
16-
1715
return None
1816

1917
async def permits(self, identity, permission, context=None):

demos/blog/aiohttpdemo_blog/forms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ async def validate_login_form(conn, form):
1616

1717
if not user:
1818
return 'Invalid username'
19-
if not check_password_hash(password, user['password_hash']):
19+
if not check_password_hash(password, user.password_hash):
2020
return 'Invalid password'
2121

2222
return None

demos/blog/aiohttpdemo_blog/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ async def init_app(config):
4242

4343
setup_routes(app)
4444

45-
db_pool = await init_db(app)
45+
app.cleanup_ctx.append(init_db)
4646

4747
redis = await setup_redis(app)
48+
app.on_shutdown.append(lambda _: redis.aclose())
4849
setup_session(app, RedisStorage(redis))
4950

5051
# needs to be after session setup because of `current_user_ctx_processor`
@@ -57,7 +58,7 @@ async def init_app(config):
5758
setup_security(
5859
app,
5960
SessionIdentityPolicy(),
60-
DBAuthorizationPolicy(db_pool)
61+
DBAuthorizationPolicy(app)
6162
)
6263

6364
log.debug(app['config'])

demos/blog/aiohttpdemo_blog/templates/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
{% block content %}
44
<h1>Hi, {{ user.username }}!</h1>
55
{% for post in posts %}
6-
<div><p>[{{ post.timestamp.strftime('%Y-%m-%d %H:%M:%S') }}] {{ post.username }} posted: <b>{{ post.body }}</b></p></div>
6+
<div><p>[{{ post.timestamp.strftime('%Y-%m-%d %H:%M:%S') }}] {{ post.user.username }} posted: <b>{{ post.body }}</b></p></div>
77
{% endfor %}
88
{% endblock %}

demos/blog/aiohttpdemo_blog/views.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ async def index(request):
1717
if not username:
1818
raise redirect(request.app.router, 'login')
1919

20-
async with request.app['db_pool'].acquire() as conn:
21-
current_user = await db.get_user_by_name(conn, username)
22-
posts = await db.get_posts_with_joined_users(conn)
20+
async with request.app['db_pool']() as sess:
21+
current_user = await db.get_user_by_name(sess, username)
22+
posts = await db.get_posts_with_joined_users(sess)
2323

2424
return {'user': current_user, 'posts': posts}
2525

@@ -33,16 +33,16 @@ async def login(request):
3333
if request.method == 'POST':
3434
form = await request.post()
3535

36-
async with request.app['db_pool'].acquire() as conn:
37-
error = await validate_login_form(conn, form)
36+
async with request.app['db_pool']() as sess:
37+
error = await validate_login_form(sess, form)
3838

3939
if error:
4040
return {'error': error}
4141
else:
4242
response = redirect(request.app.router, 'index')
4343

44-
user = await db.get_user_by_name(conn, form['username'])
45-
await remember(request, response, user['username'])
44+
user = await db.get_user_by_name(sess, form['username'])
45+
await remember(request, response, user.username)
4646

4747
raise response
4848

@@ -64,9 +64,9 @@ async def create_post(request):
6464
if request.method == 'POST':
6565
form = await request.post()
6666

67-
async with request.app['db_pool'].acquire() as conn:
68-
current_user = await db.get_user_by_name(conn, username)
69-
await db.create_post(conn, form['body'], current_user['id'])
67+
async with request.app['db_pool'].begin() as sess:
68+
current_user = await db.get_user_by_name(sess, username)
69+
sess.add(db.Posts(body=form["body"], user_id=current_user.id))
7070
raise redirect(request.app.router, 'index')
7171

7272
return {}

0 commit comments

Comments
 (0)