Skip to content

Commit e58da41

Browse files
HyeockJinKimclaudelablup-octodog
authored
feat(BA-3802): Migrate to SQLAlchemy 2.0 with comprehensive type safety improvements (#7880)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: octodog <mu001@lablup.com>
1 parent 5ffa9bd commit e58da41

File tree

220 files changed

+4501
-3180
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

220 files changed

+4501
-3180
lines changed

changes/7880.deps.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Migrate to SQLAlchemy 2.0 with comprehensive type safety improvements

docs/manager/graphql-reference/supergraph.graphql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3892,7 +3892,7 @@ type ModelRevision implements Node
38923892
modelRuntimeConfig: ModelRuntimeConfig!
38933893

38943894
"""Model data mount configuration."""
3895-
modelMountConfig: ModelMountConfig!
3895+
modelMountConfig: ModelMountConfig
38963896

38973897
"""Additional volume folder mounts."""
38983898
extraMounts: ExtraVFolderMountConnection!
@@ -6379,7 +6379,7 @@ type Route implements Node
63796379
trafficRatio: Float!
63806380

63816381
"""The timestamp when the route was created."""
6382-
createdAt: DateTime!
6382+
createdAt: DateTime
63836383

63846384
"""Error data if the route is in a failed state."""
63856385
errorData: JSONString

docs/manager/graphql-reference/v2-schema.graphql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,7 +1464,7 @@ type ModelRevision implements Node {
14641464
modelRuntimeConfig: ModelRuntimeConfig!
14651465

14661466
"""Model data mount configuration."""
1467-
modelMountConfig: ModelMountConfig!
1467+
modelMountConfig: ModelMountConfig
14681468

14691469
"""Additional volume folder mounts."""
14701470
extraMounts: ExtraVFolderMountConnection!
@@ -2410,7 +2410,7 @@ type Route implements Node {
24102410
trafficRatio: Float!
24112411

24122412
"""The timestamp when the route was created."""
2413-
createdAt: DateTime!
2413+
createdAt: DateTime
24142414

24152415
"""Error data if the route is in a failed state."""
24162416
errorData: JSONString

pants.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ pants_ignore = [
2626
"/tools/pants-plugins",
2727
"/wheelhouse/*/",
2828
"/wheelhouse/*.whl",
29+
# Alembic migration files - legacy SQLAlchemy 1.x style, excluded from type checking
30+
"/src/ai/backend/manager/models/alembic/versions/",
31+
"/src/ai/backend/account_manager/models/alembic/versions/",
32+
"/src/ai/backend/appproxy/coordinator/models/alembic/versions/",
2933
]
3034
build_file_prelude_globs = ["tools/build-macros.py"]
3135

@@ -125,6 +129,7 @@ report = ["xml", "console"]
125129

126130
[mypy]
127131
install_from_resolve = "mypy"
132+
config = "pyproject.toml"
128133

129134
[setuptools]
130135
install_from_resolve = "setuptools"

python.lock

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// "Jinja2~=3.1.6",
1313
// "PyJWT~=2.10.1",
1414
// "PyYAML~=6.0",
15-
// "SQLAlchemy[postgresql_asyncpg]~=1.4.54",
15+
// "SQLAlchemy[postgresql_asyncpg]~=2.0.45",
1616
// "aioboto3~=15.0.0",
1717
// "aiodataloader~=0.4.2",
1818
// "aiodns==3.2",
@@ -3282,13 +3282,13 @@
32823282
"artifacts": [
32833283
{
32843284
"algorithm": "sha256",
3285-
"hash": "3671a94fd25e62f5f2f554f5e95389c2294d89822378a5f2dd24353e1494a9e0",
3286-
"url": "https://files.pythonhosted.org/packages/bb/f5/fddaec430367be9d62a7ed125530e133bfd4a1c0350fe221149ee0f2b526/jupyter_client-8.7.0-py3-none-any.whl"
3285+
"hash": "f93a5b99c5e23a507b773d3a1136bd6e16c67883ccdbd9a829b0bbdb98cd7d7a",
3286+
"url": "https://files.pythonhosted.org/packages/2d/0b/ceb7694d864abc0a047649aec263878acb9f792e1fec3e676f22dc9015e3/jupyter_client-8.8.0-py3-none-any.whl"
32873287
},
32883288
{
32893289
"algorithm": "sha256",
3290-
"hash": "3357212d9cbe01209e59190f67a3a7e1f387a4f4e88d1e0433ad84d7b262531d",
3291-
"url": "https://files.pythonhosted.org/packages/a6/27/d10de45e8ad4ce872372c4a3a37b7b35b6b064f6f023a5c14ffcced4d59d/jupyter_client-8.7.0.tar.gz"
3290+
"hash": "d556811419a4f2d96c869af34e854e3f059b7cc2d6d01a9cd9c85c267691be3e",
3291+
"url": "https://files.pythonhosted.org/packages/05/e4/ba649102a3bc3fbca54e7239fb924fd434c766f855693d86de0b1f2bec81/jupyter_client-8.8.0.tar.gz"
32923292
}
32933293
],
32943294
"project_name": "jupyter-client",
@@ -3298,8 +3298,10 @@
32983298
"ipykernel; extra == \"docs\"",
32993299
"ipykernel>=6.14; extra == \"test\"",
33003300
"jupyter-core>=5.1",
3301-
"mypy; extra == \"test\"",
3301+
"msgpack; extra == \"test\"",
3302+
"mypy; platform_python_implementation != \"PyPy\" and extra == \"test\"",
33023303
"myst-parser; extra == \"docs\"",
3304+
"orjson; extra == \"orjson\"",
33033305
"paramiko; sys_platform == \"win32\" and extra == \"test\"",
33043306
"pre-commit; extra == \"test\"",
33053307
"pydata-sphinx-theme; extra == \"docs\"",
@@ -3317,7 +3319,7 @@
33173319
"traitlets>=5.3"
33183320
],
33193321
"requires_python": ">=3.10",
3320-
"version": "8.7.0"
3322+
"version": "8.8.0"
33213323
},
33223324
{
33233325
"artifacts": [
@@ -6203,52 +6205,82 @@
62036205
"artifacts": [
62046206
{
62056207
"algorithm": "sha256",
6206-
"hash": "4470fbed088c35dc20b78a39aaf4ae54fe81790c783b3264872a0224f437c31a",
6207-
"url": "https://files.pythonhosted.org/packages/ce/af/20290b55d469e873cba9d41c0206ab5461ff49d759989b3fe65010f9d265/sqlalchemy-1.4.54.tar.gz"
6208+
"hash": "5225a288e4c8cc2308dbdd874edad6e7d0fd38eac1e9e5f23503425c8eee20d0",
6209+
"url": "https://files.pythonhosted.org/packages/bf/e1/3ccb13c643399d22289c6a9786c1a91e3dcbb68bce4beb44926ac2c557bf/sqlalchemy-2.0.45-py3-none-any.whl"
6210+
},
6211+
{
6212+
"algorithm": "sha256",
6213+
"hash": "672c45cae53ba88e0dad74b9027dddd09ef6f441e927786b05bec75d949fbb2e",
6214+
"url": "https://files.pythonhosted.org/packages/0e/50/80a8d080ac7d3d321e5e5d420c9a522b0aa770ec7013ea91f9a8b7d36e4a/sqlalchemy-2.0.45-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl"
6215+
},
6216+
{
6217+
"algorithm": "sha256",
6218+
"hash": "83d7009f40ce619d483d26ac1b757dfe3167b39921379a8bd1b596cf02dab4a6",
6219+
"url": "https://files.pythonhosted.org/packages/3d/8d/bb40a5d10e7a5f2195f235c0b2f2c79b0bf6e8f00c0c223130a4fbd2db09/sqlalchemy-2.0.45-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl"
6220+
},
6221+
{
6222+
"algorithm": "sha256",
6223+
"hash": "fe187fc31a54d7fd90352f34e8c008cf3ad5d064d08fedd3de2e8df83eb4a1cf",
6224+
"url": "https://files.pythonhosted.org/packages/6a/c8/7cc5221b47a54edc72a0140a1efa56e0a2730eefa4058d7ed0b4c4357ff8/sqlalchemy-2.0.45-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl"
6225+
},
6226+
{
6227+
"algorithm": "sha256",
6228+
"hash": "9c6378449e0940476577047150fd09e242529b761dc887c9808a9a937fe990c8",
6229+
"url": "https://files.pythonhosted.org/packages/74/04/891b5c2e9f83589de202e7abaf24cd4e4fa59e1837d64d528829ad6cc107/sqlalchemy-2.0.45-cp313-cp313-musllinux_1_2_x86_64.whl"
6230+
},
6231+
{
6232+
"algorithm": "sha256",
6233+
"hash": "d8a2ca754e5415cde2b656c27900b19d50ba076aa05ce66e2207623d3fe41f5a",
6234+
"url": "https://files.pythonhosted.org/packages/75/a5/346128b0464886f036c039ea287b7332a410aa2d3fb0bb5d404cb8861635/sqlalchemy-2.0.45-cp313-cp313t-musllinux_1_2_x86_64.whl"
6235+
},
6236+
{
6237+
"algorithm": "sha256",
6238+
"hash": "1632a4bda8d2d25703fdad6363058d882541bdaaee0e5e3ddfa0cd3229efce88",
6239+
"url": "https://files.pythonhosted.org/packages/be/f9/5e4491e5ccf42f5d9cfc663741d261b3e6e1683ae7812114e7636409fcc6/sqlalchemy-2.0.45.tar.gz"
6240+
},
6241+
{
6242+
"algorithm": "sha256",
6243+
"hash": "470daea2c1ce73910f08caf10575676a37159a6d16c4da33d0033546bddebc9b",
6244+
"url": "https://files.pythonhosted.org/packages/da/4c/13dab31266fc9904f7609a5dc308a2432a066141d65b857760c3bef97e69/sqlalchemy-2.0.45-cp313-cp313-musllinux_1_2_aarch64.whl"
62086245
}
62096246
],
62106247
"project_name": "sqlalchemy",
62116248
"requires_dists": [
6212-
"aiomysql>=0.2.0; python_version >= \"3\" and extra == \"aiomysql\"",
6213-
"aiosqlite; python_version >= \"3\" and extra == \"aiosqlite\"",
6214-
"asyncmy!=0.2.4,>=0.2.3; python_version >= \"3\" and extra == \"asyncmy\"",
6215-
"asyncpg; python_version >= \"3\" and extra == \"postgresql-asyncpg\"",
6216-
"asyncpg; python_version >= \"3\" and extra == \"postgresql-asyncpg\"",
6217-
"cx_oracle<8,>=7; python_version < \"3\" and extra == \"oracle\"",
6218-
"cx_oracle>=7; python_version >= \"3\" and extra == \"oracle\"",
6219-
"greenlet!=0.4.17; python_version >= \"3\" and (platform_machine == \"aarch64\" or (platform_machine == \"ppc64le\" or (platform_machine == \"x86_64\" or (platform_machine == \"amd64\" or (platform_machine == \"AMD64\" or (platform_machine == \"win32\" or platform_machine == \"WIN32\"))))))",
6220-
"greenlet!=0.4.17; python_version >= \"3\" and extra == \"aiomysql\"",
6221-
"greenlet!=0.4.17; python_version >= \"3\" and extra == \"aiosqlite\"",
6222-
"greenlet!=0.4.17; python_version >= \"3\" and extra == \"asyncio\"",
6223-
"greenlet!=0.4.17; python_version >= \"3\" and extra == \"asyncmy\"",
6224-
"greenlet!=0.4.17; python_version >= \"3\" and extra == \"postgresql-asyncpg\"",
6225-
"greenlet!=0.4.17; python_version >= \"3\" and extra == \"postgresql-asyncpg\"",
6249+
"aiomysql>=0.2.0; extra == \"aiomysql\"",
6250+
"aioodbc; extra == \"aioodbc\"",
6251+
"aiosqlite; extra == \"aiosqlite\"",
6252+
"asyncmy!=0.2.4,!=0.2.6,>=0.2.3; extra == \"asyncmy\"",
6253+
"asyncpg; extra == \"postgresql-asyncpg\"",
6254+
"cx_oracle>=8; extra == \"oracle\"",
6255+
"greenlet>=1; extra == \"aiomysql\"",
6256+
"greenlet>=1; extra == \"aioodbc\"",
6257+
"greenlet>=1; extra == \"aiosqlite\"",
6258+
"greenlet>=1; extra == \"asyncio\"",
6259+
"greenlet>=1; extra == \"asyncmy\"",
6260+
"greenlet>=1; extra == \"postgresql-asyncpg\"",
6261+
"greenlet>=1; platform_machine == \"aarch64\" or (platform_machine == \"ppc64le\" or (platform_machine == \"x86_64\" or (platform_machine == \"amd64\" or (platform_machine == \"AMD64\" or (platform_machine == \"win32\" or platform_machine == \"WIN32\")))))",
62266262
"importlib-metadata; python_version < \"3.8\"",
6227-
"mariadb!=1.1.2,>=1.0.1; python_version >= \"3\" and extra == \"mariadb-connector\"",
6228-
"mariadb!=1.1.2,>=1.0.1; python_version >= \"3\" and extra == \"mariadb-connector\"",
6229-
"mypy>=0.910; python_version >= \"3\" and extra == \"mypy\"",
6230-
"mysql-connector-python; extra == \"mysql-connector\"",
6263+
"mariadb!=1.1.10,!=1.1.2,!=1.1.5,>=1.0.1; extra == \"mariadb-connector\"",
6264+
"mypy>=0.910; extra == \"mypy\"",
62316265
"mysql-connector-python; extra == \"mysql-connector\"",
6232-
"mysqlclient<2,>=1.4.0; python_version < \"3\" and extra == \"mysql\"",
6233-
"mysqlclient>=1.4.0; python_version >= \"3\" and extra == \"mysql\"",
6234-
"pg8000!=1.29.0,>=1.16.6; python_version >= \"3\" and extra == \"postgresql-pg8000\"",
6235-
"pg8000!=1.29.0,>=1.16.6; python_version >= \"3\" and extra == \"postgresql-pg8000\"",
6266+
"mysqlclient>=1.4.0; extra == \"mysql\"",
6267+
"oracledb>=1.0.1; extra == \"oracle-oracledb\"",
6268+
"pg8000>=1.29.1; extra == \"postgresql-pg8000\"",
62366269
"psycopg2-binary; extra == \"postgresql-psycopg2binary\"",
62376270
"psycopg2>=2.7; extra == \"postgresql\"",
62386271
"psycopg2cffi; extra == \"postgresql-psycopg2cffi\"",
6272+
"psycopg>=3.0.7; extra == \"postgresql-psycopg\"",
6273+
"psycopg[binary]>=3.0.7; extra == \"postgresql-psycopgbinary\"",
62396274
"pymssql; extra == \"mssql-pymssql\"",
6240-
"pymssql; extra == \"mssql-pymssql\"",
6241-
"pymysql; python_version >= \"3\" and extra == \"pymysql\"",
6242-
"pymysql<1; python_version < \"3\" and extra == \"pymysql\"",
6275+
"pymysql; extra == \"pymysql\"",
62436276
"pyodbc; extra == \"mssql\"",
62446277
"pyodbc; extra == \"mssql-pyodbc\"",
6245-
"pyodbc; extra == \"mssql-pyodbc\"",
6246-
"sqlalchemy2-stubs; extra == \"mypy\"",
6247-
"sqlcipher3_binary; python_version >= \"3\" and extra == \"sqlcipher\"",
6278+
"sqlcipher3_binary; extra == \"sqlcipher\"",
6279+
"typing-extensions>=4.6.0",
62486280
"typing_extensions!=3.10.0.1; extra == \"aiosqlite\""
62496281
],
6250-
"requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7",
6251-
"version": "1.4.54"
6282+
"requires_python": ">=3.7",
6283+
"version": "2.0.45"
62526284
},
62536285
{
62546286
"artifacts": [
@@ -7651,7 +7683,7 @@
76517683
"Jinja2~=3.1.6",
76527684
"PyJWT~=2.10.1",
76537685
"PyYAML~=6.0",
7654-
"SQLAlchemy[postgresql_asyncpg]~=1.4.54",
7686+
"SQLAlchemy[postgresql_asyncpg]~=2.0.45",
76557687
"aioboto3~=15.0.0",
76567688
"aiodataloader~=0.4.2",
76577689
"aiodns==3.2",

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ huggingface-hub~=0.34.3
7878
redis[hiredis]~=7.1.0
7979
rich~=13.6
8080
ruamel.yaml~=0.18.10
81-
SQLAlchemy[postgresql_asyncpg]~=1.4.54
81+
SQLAlchemy[postgresql_asyncpg]~=2.0.45
8282
setproctitle~=1.3.5
8383
setuptools~=80.0.0
8484
strawberry-graphql~=0.278.0

src/ai/backend/account_manager/models/alembic/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def do_run_migrations(connection: Connection) -> None:
6262

6363
async def run_async_migrations() -> None:
6464
connectable = async_engine_from_config(
65-
config.get_section(config.config_ini_section),
65+
config.get_section(config.config_ini_section) or {},
6666
prefix="sqlalchemy.",
6767
poolclass=pool.NullPool,
6868
)

src/ai/backend/account_manager/models/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,17 @@ def process_bind_param(
117117

118118
def process_result_value(
119119
self,
120-
value: str,
120+
value: Any | None,
121121
dialect: Dialect,
122122
) -> T_StrEnum | None:
123123
return self._enum_cls(value) if value is not None else None
124124

125-
def copy(self, **kw) -> type[Self]:
126-
return StrEnumType(self._enum_cls, **self._opts)
125+
def copy(self, **kw) -> Self:
126+
return StrEnumType(self._enum_cls, **self._opts) # type: ignore[return-value]
127127

128128
@property
129-
def python_type(self) -> T_StrEnum:
130-
return self._enum_class
129+
def python_type(self) -> type[T_StrEnum]:
130+
return self._enum_cls
131131

132132

133133
class PasswordColumn(TypeDecorator):

src/ai/backend/account_manager/models/userprofile.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,50 @@
1+
from datetime import datetime
2+
from uuid import UUID
3+
14
import sqlalchemy as sa
2-
from sqlalchemy.orm import relationship
5+
from sqlalchemy.orm import Mapped, mapped_column, relationship
36

47
from ai.backend.account_manager.types import UserRole, UserStatus
58
from ai.backend.account_manager.utils import verify_password
69

7-
from .base import GUID, Base, IDColumn, PasswordColumn, StrEnumType
10+
from .base import GUID, Base, PasswordColumn, StrEnumType
811

912
__all__: tuple[str, ...] = ("UserProfileRow",)
1013

1114

1215
class UserProfileRow(Base):
1316
__tablename__ = "user_profiles"
14-
id = IDColumn()
15-
user_id = sa.Column("user_id", GUID, nullable=False)
16-
username = sa.Column("username", sa.String(length=64), index=True, nullable=False, unique=True)
17-
email = sa.Column("email", sa.String(length=64), index=True, nullable=False)
18-
password = sa.Column("password", PasswordColumn(), nullable=False)
19-
need_password_change = sa.Column("need_password_change", sa.Boolean, server_default=sa.false())
20-
password_changed_at = sa.Column(
17+
id: Mapped[UUID] = mapped_column(
18+
GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")
19+
)
20+
user_id: Mapped[UUID] = mapped_column("user_id", GUID, nullable=False)
21+
username: Mapped[str] = mapped_column(
22+
"username", sa.String(length=64), index=True, nullable=False, unique=True
23+
)
24+
email: Mapped[str] = mapped_column("email", sa.String(length=64), index=True, nullable=False)
25+
password: Mapped[str] = mapped_column("password", PasswordColumn(), nullable=False)
26+
need_password_change: Mapped[bool | None] = mapped_column(
27+
"need_password_change", sa.Boolean, server_default=sa.false()
28+
)
29+
password_changed_at: Mapped[datetime | None] = mapped_column(
2130
"password_changed_at",
2231
sa.DateTime(timezone=True),
2332
server_default=sa.func.now(),
2433
)
25-
full_name = sa.Column("full_name", sa.String(length=64))
26-
description = sa.Column("description", sa.String(length=500))
34+
full_name: Mapped[str | None] = mapped_column("full_name", sa.String(length=64))
35+
description: Mapped[str | None] = mapped_column("description", sa.String(length=500))
2736

28-
role = sa.Column(
37+
role: Mapped[UserRole] = mapped_column(
2938
"role",
3039
StrEnumType(UserRole),
3140
default=UserRole.USER,
3241
server_default=UserRole.USER.value,
3342
nullable=False,
3443
)
35-
status = sa.Column(
44+
status: Mapped[UserStatus] = mapped_column(
3645
"status", StrEnumType(UserStatus), server_default=UserStatus.ACTIVE.value, nullable=False
3746
)
38-
status_info = sa.Column("status_info", sa.Unicode(), nullable=True)
47+
status_info: Mapped[str | None] = mapped_column("status_info", sa.Unicode(), nullable=True)
3948

4049
created_at = sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now())
4150
modified_at = sa.Column(

src/ai/backend/account_manager/models/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
1818
from sqlalchemy.ext.asyncio import AsyncEngine as SAEngine
1919
from sqlalchemy.ext.asyncio import AsyncSession as SASession
20-
from sqlalchemy.orm import sessionmaker
20+
from sqlalchemy.ext.asyncio import async_sessionmaker
2121
from tenacity import (
2222
AsyncRetrying,
2323
RetryError,
@@ -53,8 +53,8 @@ def __init__(self, *args, **kwargs) -> None:
5353
super().__init__(*args, **kwargs)
5454
self._readonly_txn_count = 0
5555
self._generic_txn_count = 0
56-
self._sess_factory = sessionmaker(self, expire_on_commit=False, class_=SASession)
57-
self._readonly_sess_factory = sessionmaker(self, class_=SASession)
56+
self._sess_factory = async_sessionmaker(self, expire_on_commit=False)
57+
self._readonly_sess_factory = async_sessionmaker(self)
5858

5959
def _check_generic_txn_cnt(self) -> None:
6060
if (
@@ -229,6 +229,8 @@ async def connect_database(
229229
async with version_check_db.begin() as conn:
230230
result = await conn.execute(sa.text("show server_version"))
231231
version_str = result.scalar()
232+
if version_str is None:
233+
raise RuntimeError("Failed to get PostgreSQL version")
232234
major, minor, *_ = map(int, version_str.partition(" ")[0].split("."))
233235
if (major, minor) < (11, 0):
234236
pgsql_connect_opts["server_settings"].pop("jit")

0 commit comments

Comments
 (0)