Skip to content

Commit 0dcaff3

Browse files
committed
fix get_bind for polymorphic models
1 parent f93f339 commit 0dcaff3

File tree

3 files changed

+61
-10
lines changed

3 files changed

+61
-10
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Unreleased
55

66
- Show helpful errors when mistakenly using multiple ``SQLAlchemy`` instances for the
77
same app, or without calling ``init_app``. :pr:`1151`
8+
- Fix issue with getting the engine associated with a model that uses polymorphic
9+
table inheritance. :issue:`1155`
810

911

1012
Version 3.0.2

src/flask_sqlalchemy/session.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def get_bind( # type: ignore[override]
3838
"""Select an engine based on the ``bind_key`` of the metadata associated with
3939
the model or table being queried. If no bind key is set, uses the default bind.
4040
41+
.. versionchanged:: 3.0.3
42+
Fix finding the bind for a joined inheritance model.
43+
4144
.. versionchanged:: 3.0
4245
The implementation more closely matches the base SQLAlchemy implementation.
4346
@@ -47,6 +50,8 @@ def get_bind( # type: ignore[override]
4750
if bind is not None:
4851
return bind
4952

53+
engines = self._db.engines
54+
5055
if mapper is not None:
5156
try:
5257
mapper = sa.inspect(mapper)
@@ -56,26 +61,42 @@ def get_bind( # type: ignore[override]
5661

5762
raise
5863

59-
clause = mapper.persist_selectable
64+
engine = _clause_to_engine(mapper.local_table, engines)
6065

61-
engines = self._db.engines
62-
63-
if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info:
64-
key = clause.metadata.info["bind_key"]
66+
if engine is not None:
67+
return engine
6568

66-
if key not in engines:
67-
raise sa.exc.UnboundExecutionError(
68-
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
69-
)
69+
if clause is not None:
70+
engine = _clause_to_engine(clause, engines)
7071

71-
return engines[key]
72+
if engine is not None:
73+
return engine
7274

7375
if None in engines:
7476
return engines[None]
7577

7678
return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs)
7779

7880

81+
def _clause_to_engine(
82+
clause: t.Any | None, engines: t.Mapping[str | None, sa.engine.Engine]
83+
) -> sa.engine.Engine | None:
84+
"""If the clause is a table, return the engine associated with the table's
85+
metadata's bind key.
86+
"""
87+
if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info:
88+
key = clause.metadata.info["bind_key"]
89+
90+
if key not in engines:
91+
raise sa.exc.UnboundExecutionError(
92+
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
93+
)
94+
95+
return engines[key]
96+
97+
return None
98+
99+
79100
def _app_ctx_id() -> int:
80101
"""Get the id of the current Flask application context for the session scope."""
81102
return id(app_ctx._get_current_object()) # type: ignore[attr-defined]

tests/test_session.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,31 @@ class Post(db.Model):
6464

6565
assert db.session.get_bind(mapper=User) is db.engine
6666
assert db.session.get_bind(mapper=Post) is db.engines["a"]
67+
68+
69+
@pytest.mark.usefixtures("app_ctx")
70+
def test_get_bind_inheritance(app: Flask) -> None:
71+
app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"}
72+
db = SQLAlchemy(app)
73+
74+
class User(db.Model):
75+
__bind_key__ = "a"
76+
id = sa.Column(sa.Integer, primary_key=True)
77+
type = sa.Column(sa.String, nullable=False)
78+
79+
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}
80+
81+
class Admin(User):
82+
id = sa.Column( # type: ignore[assignment]
83+
sa.ForeignKey(User.id), primary_key=True
84+
)
85+
org = sa.Column(sa.String, nullable=False)
86+
87+
__mapper_args__ = {"polymorphic_identity": "admin"}
88+
89+
db.create_all()
90+
db.session.add(Admin(org="pallets"))
91+
db.session.commit()
92+
admin = db.session.execute(db.select(Admin)).scalar_one()
93+
db.session.expire(admin)
94+
assert admin.org == "pallets"

0 commit comments

Comments
 (0)