Skip to content

Commit c6e76b8

Browse files
committed
update typing for sqlalchemy 2
1 parent 3b9965a commit c6e76b8

File tree

11 files changed

+73
-43
lines changed

11 files changed

+73
-43
lines changed

pdm.lock

Lines changed: 1 addition & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ coverage = [
4747
mypy = [
4848
"mypy",
4949
"pytest",
50-
"sqlalchemy[mypy]",
50+
"sqlalchemy",
5151
]
5252
docs = [
5353
"sphinx",

src/flask_sqlalchemy/extension.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import sqlalchemy.event
99
import sqlalchemy.exc
1010
import sqlalchemy.orm
11-
import sqlalchemy.pool
1211
from flask import abort
1312
from flask import current_app
1413
from flask import Flask
@@ -22,6 +21,7 @@
2221
from .query import Query
2322
from .session import _app_ctx_id
2423
from .session import Session
24+
from .table import _Table
2525

2626

2727
class SQLAlchemy:
@@ -126,8 +126,8 @@ def __init__(
126126
*,
127127
metadata: sa.MetaData | None = None,
128128
session_options: dict[str, t.Any] | None = None,
129-
query_class: t.Type[Query] = Query,
130-
model_class: t.Type[Model] | sa.orm.DeclarativeMeta = Model,
129+
query_class: type[Query] = Query,
130+
model_class: type[Model] | sa.orm.DeclarativeMeta = Model,
131131
engine_options: dict[str, t.Any] | None = None,
132132
add_models_to_shell: bool = True,
133133
):
@@ -336,7 +336,9 @@ def init_app(self, app: Flask) -> None:
336336

337337
track_modifications._listen(self.session)
338338

339-
def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_session:
339+
def _make_scoped_session(
340+
self, options: dict[str, t.Any]
341+
) -> sa.orm.scoped_session[Session]:
340342
"""Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory
341343
from :meth:`_make_session_factory`. The result is available as :attr:`session`.
342344
@@ -363,7 +365,7 @@ def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_sessi
363365

364366
def _make_session_factory(
365367
self, options: dict[str, t.Any]
366-
) -> sa.orm.sessionmaker[Session]: # type: ignore[type-var]
368+
) -> sa.orm.sessionmaker[Session]:
367369
"""Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by
368370
:meth:`_make_scoped_session`.
369371
@@ -438,7 +440,7 @@ def _make_metadata(self, bind_key: str | None) -> sa.MetaData:
438440
self.metadatas[bind_key] = metadata
439441
return metadata
440442

441-
def _make_table_class(self) -> t.Type[sa.Table]:
443+
def _make_table_class(self) -> type[_Table]:
442444
"""Create a SQLAlchemy :class:`sqlalchemy.schema.Table` class that chooses a
443445
metadata automatically based on the ``bind_key``. The result is available as
444446
:attr:`Table`.
@@ -450,7 +452,7 @@ def _make_table_class(self) -> t.Type[sa.Table]:
450452
.. versionadded:: 3.0
451453
"""
452454

453-
class Table(sa.Table):
455+
class Table(_Table):
454456
def __new__(
455457
cls, *args: t.Any, bind_key: str | None = None, **kwargs: t.Any
456458
) -> Table:
@@ -475,13 +477,13 @@ def __new__(
475477
bind_key = kwargs["info"].get("bind_key")
476478

477479
metadata = self._make_metadata(bind_key)
478-
return super().__new__(cls, args[0], metadata, *args[1:], **kwargs)
480+
return super().__new__(cls, *[args[0], metadata, *args[1:]], **kwargs)
479481

480482
return Table
481483

482484
def _make_declarative_base(
483-
self, model: t.Type[Model] | sa.orm.DeclarativeMeta
484-
) -> t.Type[t.Any]:
485+
self, model: type[Model] | sa.orm.DeclarativeMeta
486+
) -> type[t.Any]:
485487
"""Create a SQLAlchemy declarative model class. The result is available as
486488
:attr:`Model`.
487489
@@ -728,7 +730,7 @@ def get_binds(self) -> dict[sa.Table, sa.engine.Engine]:
728730
}
729731

730732
def get_or_404(
731-
self, entity: t.Type[t.Any], ident: t.Any, *, description: str | None = None
733+
self, entity: type[t.Any], ident: t.Any, *, description: str | None = None
732734
) -> t.Any:
733735
"""Like :meth:`session.get() <sqlalchemy.orm.Session.get>` but aborts with a
734736
``404 Not Found`` error instead of returning ``None``.
@@ -747,7 +749,7 @@ def get_or_404(
747749
return value
748750

749751
def first_or_404(
750-
self, statement: sa.sql.Select, *, description: str | None = None
752+
self, statement: sa.sql.Select[t.Any], *, description: str | None = None
751753
) -> t.Any:
752754
"""Like :meth:`Result.scalar() <sqlalchemy.engine.Result.scalar>`, but aborts
753755
with a ``404 Not Found`` error instead of returning ``None``.
@@ -765,7 +767,7 @@ def first_or_404(
765767
return value
766768

767769
def one_or_404(
768-
self, statement: sa.sql.Select, *, description: str | None = None
770+
self, statement: sa.sql.Select[t.Any], *, description: str | None = None
769771
) -> t.Any:
770772
"""Like :meth:`Result.scalar_one() <sqlalchemy.engine.Result.scalar_one>`,
771773
but aborts with a ``404 Not Found`` error instead of raising ``NoResultFound``
@@ -783,7 +785,7 @@ def one_or_404(
783785

784786
def paginate(
785787
self,
786-
select: sa.sql.Select,
788+
select: sa.sql.Select[t.Any],
787789
*,
788790
page: int | None = None,
789791
per_page: int | None = None,
@@ -971,7 +973,8 @@ def _relation(
971973
"""
972974
# Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``.
973975
self._set_rel_query(kwargs)
974-
return sa.orm.relation(*args, **kwargs)
976+
f = sa.orm.relation # type: ignore[attr-defined]
977+
return f(*args, **kwargs) # type: ignore[no-any-return]
975978

976979
def __getattr__(self, name: str) -> t.Any:
977980
if name == "db":

src/flask_sqlalchemy/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ class _QueryProperty:
1919
"""
2020

2121
@t.overload
22-
def __get__(self, obj: None, cls: t.Type[Model]) -> Query:
22+
def __get__(self, obj: None, cls: type[Model]) -> Query:
2323
...
2424

2525
@t.overload
26-
def __get__(self, obj: Model, cls: t.Type[Model]) -> Query:
26+
def __get__(self, obj: Model, cls: type[Model]) -> Query:
2727
...
2828

29-
def __get__(self, obj: Model | None, cls: t.Type[Model]) -> Query:
29+
def __get__(self, obj: Model | None, cls: type[Model]) -> Query:
3030
return cls.query_class(
3131
cls, session=cls.__fsa__.session() # type: ignore[arg-type]
3232
)
@@ -47,7 +47,7 @@ class Model:
4747
:meta private:
4848
"""
4949

50-
query_class: t.ClassVar[t.Type[Query]] = Query
50+
query_class: t.ClassVar[type[Query]] = Query
5151
"""Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which
5252
defaults to :class:`.Query`.
5353
"""
@@ -63,6 +63,7 @@ class Model:
6363

6464
def __repr__(self) -> str:
6565
state = sa.inspect(self)
66+
assert state is not None
6667

6768
if state.transient:
6869
pk = f"(transient {id(self)})"

src/flask_sqlalchemy/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None:
2828
self._db = db
2929
self._model_changes: dict[object, tuple[t.Any, str]] = {}
3030

31-
def get_bind( # type: ignore[override]
31+
def get_bind(
3232
self,
3333
mapper: t.Any | None = None,
3434
clause: t.Any | None = None,

src/flask_sqlalchemy/table.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
import sqlalchemy as sa
6+
import sqlalchemy.sql.schema as sa_sql_schema
7+
8+
9+
class _Table(sa.Table):
10+
@t.overload
11+
def __init__(
12+
self,
13+
name: str,
14+
*args: sa_sql_schema.SchemaItem,
15+
bind_key: str | None = None,
16+
**kwargs: t.Any,
17+
) -> None:
18+
...
19+
20+
@t.overload
21+
def __init__(
22+
self,
23+
name: str,
24+
metadata: sa.MetaData,
25+
*args: sa_sql_schema.SchemaItem,
26+
**kwargs: t.Any,
27+
) -> None:
28+
...
29+
30+
@t.overload
31+
def __init__(
32+
self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any
33+
) -> None:
34+
...
35+
36+
def __init__(
37+
self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any
38+
) -> None:
39+
super().__init__(name, *args, **kwargs) # type: ignore[arg-type]

src/flask_sqlalchemy/track_modifications.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"""
3030

3131

32-
def _listen(session: sa.orm.scoped_session) -> None:
32+
def _listen(session: sa.orm.scoped_session[Session]) -> None:
3333
sa.event.listen(session, "before_flush", _record_ops, named=True)
3434
sa.event.listen(session, "before_commit", _record_ops, named=True)
3535
sa.event.listen(session, "before_commit", _before_commit)

tests/test_legacy_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class Parent(db.Model):
8282

8383
class Child(db.Model):
8484
id = sa.Column(sa.Integer, primary_key=True)
85-
parent_id = sa.Column(sa.ForeignKey(Parent.id))
85+
parent_id = sa.Column(sa.ForeignKey(Parent.id)) # type: ignore[var-annotated]
8686
parent2 = db.relationship(
8787
Parent,
8888
backref=db.backref("children2", lazy="dynamic", viewonly=True),
@@ -109,7 +109,7 @@ class Parent(db.Model):
109109

110110
class Child(db.Model):
111111
id = sa.Column(sa.Integer, primary_key=True)
112-
parent_id = sa.Column(sa.ForeignKey(Parent.id))
112+
parent_id = sa.Column(sa.ForeignKey(Parent.id)) # type: ignore[var-annotated]
113113
parent2 = db.relationship(
114114
Parent,
115115
backref=db.backref("children2", lazy="dynamic", viewonly=True),

tests/test_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_custom_metadata() -> None:
2727

2828
def test_metadata_from_custom_model() -> None:
2929
base = sa.orm.declarative_base(cls=Model, metaclass=DefaultMeta)
30-
metadata = base.metadata # type: ignore[attr-defined]
30+
metadata = base.metadata
3131
db = SQLAlchemy(model_class=base)
3232
assert db.Model.metadata is metadata
3333
assert db.Model.metadata is db.metadata

tests/test_model_name.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_mixin_attr(db: SQLAlchemy) -> None:
110110
"""
111111

112112
class Mixin:
113-
@sa.orm.declared_attr
113+
@sa.orm.declared_attr # type: ignore[arg-type]
114114
def __tablename__(cls) -> str: # noqa: B902
115115
return cls.__name__.upper() # type: ignore[attr-defined,no-any-return]
116116

@@ -210,7 +210,7 @@ class class_property:
210210
def __init__(self, f: t.Callable[..., t.Any]) -> None:
211211
self.f = f
212212

213-
def __get__(self, instance: t.Any, owner: t.Type[t.Any]) -> t.Any:
213+
def __get__(self, instance: t.Any, owner: type[t.Any]) -> t.Any:
214214
return self.f(owner)
215215

216216
class Duck(db.Model):
@@ -221,7 +221,7 @@ class ns:
221221
floats = False
222222

223223
class Witch(Duck):
224-
@sa.orm.declared_attr
224+
@sa.orm.declared_attr # type: ignore[arg-type]
225225
def is_duck(self) -> None:
226226
# declared attrs will be accessed during mapper configuration,
227227
# but make sure they're not accessed before that

0 commit comments

Comments
 (0)