Skip to content

Commit 3a89561

Browse files
committed
rename some ref for corrent type hint
1 parent 41dccc7 commit 3a89561

File tree

12 files changed

+60
-53
lines changed

12 files changed

+60
-53
lines changed

src/flask_sqlalchemy/cli.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
from flask import current_app
66

77

8+
if t.TYPE_CHECKING:
9+
from .extension import SQLAlchemy as SQLAlchemy
10+
else:
11+
SQLAlchemy = t.Any
12+
13+
814
def add_models_to_shell() -> dict[str, t.Any]:
915
"""Registered with :meth:`~flask.Flask.shell_context_processor` if
1016
``add_models_to_shell`` is enabled. Adds the ``db`` instance and all model classes
1117
to ``flask shell``.
1218
"""
13-
db = current_app.extensions["sqlalchemy"]
19+
db: SQLAlchemy = current_app.extensions["sqlalchemy"]
1420
out = {m.class_.__name__: m.class_ for m in db.Model._sa_registry.mappers}
1521
out["db"] = db
1622
return out

src/flask_sqlalchemy/extension.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
metadata: sa.MetaData | None = None,
128128
session_options: dict[str, t.Any] | None = None,
129129
query_class: type[Query] = Query,
130-
model_class: type[Model] | sa.orm.DeclarativeMeta = Model,
130+
model_class: type[Model] | sqlalchemy.orm.DeclarativeMeta = Model,
131131
engine_options: dict[str, t.Any] | None = None,
132132
add_models_to_shell: bool = True,
133133
):
@@ -338,7 +338,7 @@ def init_app(self, app: Flask) -> None:
338338

339339
def _make_scoped_session(
340340
self, options: dict[str, t.Any]
341-
) -> sa.orm.scoped_session[Session]:
341+
) -> sqlalchemy.orm.scoped_session[Session]:
342342
"""Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory
343343
from :meth:`_make_session_factory`. The result is available as :attr:`session`.
344344
@@ -361,11 +361,11 @@ def _make_scoped_session(
361361
"""
362362
scope = options.pop("scopefunc", _app_ctx_id)
363363
factory = self._make_session_factory(options)
364-
return sa.orm.scoped_session(factory, scope)
364+
return sqlalchemy.orm.scoped_session(factory, scope)
365365

366366
def _make_session_factory(
367367
self, options: dict[str, t.Any]
368-
) -> sa.orm.sessionmaker[Session]:
368+
) -> sqlalchemy.orm.sessionmaker[Session]:
369369
"""Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by
370370
:meth:`_make_scoped_session`.
371371
@@ -388,7 +388,7 @@ def _make_session_factory(
388388
"""
389389
options.setdefault("class_", Session)
390390
options.setdefault("query_cls", self.Query)
391-
return sa.orm.sessionmaker(db=self, **options)
391+
return sqlalchemy.orm.sessionmaker(db=self, **options)
392392

393393
def _teardown_commit(self, exc: BaseException | None) -> None:
394394
"""Commit the session at the end of the request if there was not an unhandled
@@ -482,7 +482,7 @@ def __new__(
482482
return Table
483483

484484
def _make_declarative_base(
485-
self, model: type[Model] | sa.orm.DeclarativeMeta
485+
self, model: type[Model] | sqlalchemy.orm.DeclarativeMeta
486486
) -> type[t.Any]:
487487
"""Create a SQLAlchemy declarative model class. The result is available as
488488
:attr:`Model`.
@@ -503,9 +503,9 @@ def _make_declarative_base(
503503
.. versionchanged:: 2.3
504504
``model`` can be an already created declarative model class.
505505
"""
506-
if not isinstance(model, sa.orm.DeclarativeMeta):
506+
if not isinstance(model, sqlalchemy.orm.DeclarativeMeta):
507507
metadata = self._make_metadata(None)
508-
model = sa.orm.declarative_base(
508+
model = sqlalchemy.orm.declarative_base(
509509
metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta
510510
)
511511

@@ -780,7 +780,7 @@ def one_or_404(
780780
"""
781781
try:
782782
return self.session.execute(statement).scalar_one()
783-
except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound):
783+
except (sqlalchemy.exc.NoResultFound, sqlalchemy.exc.MultipleResultsFound):
784784
abort(404, description=description)
785785

786786
def paginate(
@@ -859,7 +859,7 @@ def _call_for_binds(
859859
if key is None:
860860
message = f"'SQLALCHEMY_DATABASE_URI' config is not set. {message}"
861861

862-
raise sa.exc.UnboundExecutionError(message) from None
862+
raise sqlalchemy.exc.UnboundExecutionError(message) from None
863863

864864
metadata = self.metadatas[key]
865865
getattr(metadata, op_name)(bind=engine)
@@ -936,31 +936,31 @@ def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None:
936936

937937
def relationship(
938938
self, *args: t.Any, **kwargs: t.Any
939-
) -> sa.orm.RelationshipProperty[t.Any]:
939+
) -> sqlalchemy.orm.RelationshipProperty[t.Any]:
940940
"""A :func:`sqlalchemy.orm.relationship` that applies this extension's
941941
:attr:`Query` class for dynamic relationships and backrefs.
942942
943943
.. versionchanged:: 3.0
944944
The :attr:`Query` class is set on ``backref``.
945945
"""
946946
self._set_rel_query(kwargs)
947-
return sa.orm.relationship(*args, **kwargs)
947+
return sqlalchemy.orm.relationship(*args, **kwargs)
948948

949949
def dynamic_loader(
950950
self, argument: t.Any, **kwargs: t.Any
951-
) -> sa.orm.RelationshipProperty[t.Any]:
951+
) -> sqlalchemy.orm.RelationshipProperty[t.Any]:
952952
"""A :func:`sqlalchemy.orm.dynamic_loader` that applies this extension's
953953
:attr:`Query` class for relationships and backrefs.
954954
955955
.. versionchanged:: 3.0
956956
The :attr:`Query` class is set on ``backref``.
957957
"""
958958
self._set_rel_query(kwargs)
959-
return sa.orm.dynamic_loader(argument, **kwargs)
959+
return sqlalchemy.orm.dynamic_loader(argument, **kwargs)
960960

961961
def _relation(
962962
self, *args: t.Any, **kwargs: t.Any
963-
) -> sa.orm.RelationshipProperty[t.Any]:
963+
) -> sqlalchemy.orm.RelationshipProperty[t.Any]:
964964
"""A :func:`sqlalchemy.orm.relationship` that applies this extension's
965965
:attr:`Query` class for dynamic relationships and backrefs.
966966
@@ -973,7 +973,7 @@ def _relation(
973973
"""
974974
# Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``.
975975
self._set_rel_query(kwargs)
976-
f = sa.orm.relation # type: ignore[attr-defined]
976+
f = sqlalchemy.orm.relation # type: ignore[attr-defined]
977977
return f(*args, **kwargs) # type: ignore[no-any-return]
978978

979979
def __getattr__(self, name: str) -> t.Any:
@@ -993,12 +993,12 @@ def __getattr__(self, name: str) -> t.Any:
993993
return self._relation
994994

995995
if name == "event":
996-
return sa.event
996+
return sqlalchemy.event
997997

998998
if name.startswith("_"):
999999
raise AttributeError(name)
10001000

1001-
for mod in (sa, sa.orm):
1001+
for mod in (sa, sqlalchemy.orm):
10021002
if hasattr(mod, name):
10031003
return getattr(mod, name)
10041004

src/flask_sqlalchemy/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,21 +182,21 @@ def should_set_tablename(cls: type) -> bool:
182182
joined-table inheritance. If no primary key is found, the name will be unset.
183183
"""
184184
if cls.__dict__.get("__abstract__", False) or not any(
185-
isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]
185+
isinstance(b, sqlalchemy.orm.DeclarativeMeta) for b in cls.__mro__[1:]
186186
):
187187
return False
188188

189189
for base in cls.__mro__:
190190
if "__tablename__" not in base.__dict__:
191191
continue
192192

193-
if isinstance(base.__dict__["__tablename__"], sa.orm.declared_attr):
193+
if isinstance(base.__dict__["__tablename__"], sqlalchemy.orm.declared_attr):
194194
return False
195195

196196
return not (
197197
base is cls
198198
or base.__dict__.get("__abstract__", False)
199-
or not isinstance(base, sa.orm.DeclarativeMeta)
199+
or not isinstance(base, sqlalchemy.orm.DeclarativeMeta)
200200
)
201201

202202
return True
@@ -208,7 +208,7 @@ def camel_to_snake_case(name: str) -> str:
208208
return name.lower().lstrip("_")
209209

210210

211-
class DefaultMeta(BindMetaMixin, NameMetaMixin, sa.orm.DeclarativeMeta):
211+
class DefaultMeta(BindMetaMixin, NameMetaMixin, sqlalchemy.orm.DeclarativeMeta):
212212
"""SQLAlchemy declarative metaclass that provides ``__bind_key__`` and
213213
``__tablename__`` support.
214214
"""

src/flask_sqlalchemy/pagination.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def _query_items(self) -> list[t.Any]:
336336

337337
def _query_count(self) -> int:
338338
select = self._query_args["select"]
339-
sub = select.options(sa.orm.lazyload("*")).order_by(None).subquery()
339+
sub = select.options(sqlalchemy.orm.lazyload("*")).order_by(None).subquery()
340340
session = self._query_args["session"]
341341
out = session.execute(sa.select(sa.func.count()).select_from(sub)).scalar()
342342
return out # type: ignore[no-any-return]

src/flask_sqlalchemy/query.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import typing as t
44

5-
import sqlalchemy as sa
65
import sqlalchemy.exc
76
import sqlalchemy.orm
87
from flask import abort
@@ -11,7 +10,7 @@
1110
from .pagination import QueryPagination
1211

1312

14-
class Query(sa.orm.Query): # type: ignore[type-arg]
13+
class Query(sqlalchemy.orm.Query[t.Any]):
1514
"""SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods
1615
useful for querying in a web application.
1716
@@ -58,7 +57,7 @@ def one_or_404(self, description: str | None = None) -> t.Any:
5857
"""
5958
try:
6059
return self.one()
61-
except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound):
60+
except (sqlalchemy.exc.NoResultFound, sqlalchemy.exc.MultipleResultsFound):
6261
abort(404, description=description)
6362

6463
def paginate(

src/flask_sqlalchemy/record_queries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def __getitem__(self, key: int) -> object:
9696

9797

9898
def _listen(engine: sa.engine.Engine) -> None:
99-
sa.event.listen(engine, "before_cursor_execute", _record_start, named=True)
100-
sa.event.listen(engine, "after_cursor_execute", _record_end, named=True)
99+
sqlalchemy.event.listen(engine, "before_cursor_execute", _record_start, named=True)
100+
sqlalchemy.event.listen(engine, "after_cursor_execute", _record_end, named=True)
101101

102102

103103
def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None:

src/flask_sqlalchemy/session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .extension import SQLAlchemy
1212

1313

14-
class Session(sa.orm.Session):
14+
class Session(sqlalchemy.orm.Session):
1515
"""A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to
1616
use based on the bind key associated with the metadata associated with the thing
1717
being queried.
@@ -55,9 +55,9 @@ def get_bind(
5555
if mapper is not None:
5656
try:
5757
mapper = sa.inspect(mapper)
58-
except sa.exc.NoInspectionAvailable as e:
58+
except sqlalchemy.exc.NoInspectionAvailable as e:
5959
if isinstance(mapper, type):
60-
raise sa.orm.exc.UnmappedClassError(mapper) from e
60+
raise sqlalchemy.orm.exc.UnmappedClassError(mapper) from e
6161

6262
raise
6363

@@ -88,7 +88,7 @@ def _clause_to_engine(
8888
key = clause.metadata.info["bind_key"]
8989

9090
if key not in engines:
91-
raise sa.exc.UnboundExecutionError(
91+
raise sqlalchemy.exc.UnboundExecutionError(
9292
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
9393
)
9494

src/flask_sqlalchemy/track_modifications.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929
"""
3030

3131

32-
def _listen(session: sa.orm.scoped_session[Session]) -> None:
33-
sa.event.listen(session, "before_flush", _record_ops, named=True)
34-
sa.event.listen(session, "before_commit", _record_ops, named=True)
35-
sa.event.listen(session, "before_commit", _before_commit)
36-
sa.event.listen(session, "after_commit", _after_commit)
37-
sa.event.listen(session, "after_rollback", _after_rollback)
32+
def _listen(session: sqlalchemy.orm.scoped_session[Session]) -> None:
33+
sqlalchemy.event.listen(session, "before_flush", _record_ops, named=True)
34+
sqlalchemy.event.listen(session, "before_commit", _record_ops, named=True)
35+
sqlalchemy.event.listen(session, "before_commit", _before_commit)
36+
sqlalchemy.event.listen(session, "after_commit", _after_commit)
37+
sqlalchemy.event.listen(session, "after_rollback", _after_rollback)
3838

3939

4040
def _record_ops(session: Session, **kwargs: t.Any) -> None:

tests/test_legacy_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
@pytest.fixture(autouse=True)
1717
def ignore_query_warning() -> t.Generator[None, None, None]:
18-
if hasattr(sa.exc, "LegacyAPIWarning"):
18+
if hasattr(sqlalchemy.exc, "LegacyAPIWarning"):
1919
with warnings.catch_warnings():
20-
exc = sa.exc.LegacyAPIWarning
20+
exc = sqlalchemy.exc.LegacyAPIWarning
2121
warnings.simplefilter("ignore", exc)
2222
yield
2323
else:

tests/test_metadata.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ def test_custom_metadata() -> None:
2626

2727

2828
def test_metadata_from_custom_model() -> None:
29-
base = sa.orm.declarative_base(cls=Model, metaclass=DefaultMeta)
29+
base = sqlalchemy.orm.declarative_base(cls=Model, metaclass=DefaultMeta)
3030
metadata = base.metadata
3131
db = SQLAlchemy(model_class=base)
3232
assert db.Model.metadata is metadata
3333
assert db.Model.metadata is db.metadata
3434

3535

3636
def test_custom_metadata_overrides_custom_model() -> None:
37-
base = sa.orm.declarative_base(cls=Model, metaclass=DefaultMeta)
37+
base = sqlalchemy.orm.declarative_base(cls=Model, metaclass=DefaultMeta)
3838
metadata = sa.MetaData()
3939
db = SQLAlchemy(model_class=base, metadata=metadata)
4040
assert db.Model.metadata is metadata
@@ -69,21 +69,21 @@ class Post(db.Model):
6969
__bind_key__ = "a"
7070
id = sa.Column(sa.Integer, primary_key=True)
7171

72-
with pytest.raises(sa.exc.OperationalError):
72+
with pytest.raises(sqlalchemy.exc.OperationalError):
7373
db.session.execute(sa.select(User)).scalars()
7474

75-
with pytest.raises(sa.exc.OperationalError):
75+
with pytest.raises(sqlalchemy.exc.OperationalError):
7676
db.session.execute(sa.select(Post)).scalars()
7777

7878
db.create_all()
7979
db.session.execute(sa.select(User)).scalars()
8080
db.session.execute(sa.select(Post)).scalars()
8181
db.drop_all()
8282

83-
with pytest.raises(sa.exc.OperationalError):
83+
with pytest.raises(sqlalchemy.exc.OperationalError):
8484
db.session.execute(sa.select(User)).scalars()
8585

86-
with pytest.raises(sa.exc.OperationalError):
86+
with pytest.raises(sqlalchemy.exc.OperationalError):
8787
db.session.execute(sa.select(Post)).scalars()
8888

8989

@@ -103,7 +103,7 @@ class Post(db.Model):
103103
db.create_all(bind_key=bind_key)
104104
db.session.execute(sa.select(Post)).scalars()
105105

106-
with pytest.raises(sa.exc.OperationalError):
106+
with pytest.raises(sqlalchemy.exc.OperationalError):
107107
db.session.execute(sa.select(User)).scalars()
108108

109109

0 commit comments

Comments
 (0)