Skip to content

Commit 9e19238

Browse files
Fixed clause_to_engine bug not finding the right db bind key (#1211)
Co-authored-by: Pamela Fox <[email protected]>
1 parent e5397f5 commit 9e19238

File tree

3 files changed

+78
-3
lines changed

3 files changed

+78
-3
lines changed

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Unreleased
88
- Bump minimum version of SQLAlchemy to 2.0.16.
99
- Remove previously deprecated code.
1010
- Pass extra keyword arguments from ``get_or_404`` to ``session.get``. :issue:`1149`
11+
- Fix bug with finding right bind key for clause statements. :issue:`1211`
1112

1213

1314
Version 3.0.5

src/flask_sqlalchemy/session.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,22 @@ def get_bind(
7979

8080

8181
def _clause_to_engine(
82-
clause: t.Any | None, engines: t.Mapping[str | None, sa.engine.Engine]
82+
clause: sa.ClauseElement | None,
83+
engines: t.Mapping[str | None, sa.engine.Engine],
8384
) -> sa.engine.Engine | None:
8485
"""If the clause is a table, return the engine associated with the table's
8586
metadata's bind key.
8687
"""
87-
if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info:
88-
key = clause.metadata.info["bind_key"]
88+
table = None
89+
90+
if clause is not None:
91+
if isinstance(clause, sa.Table):
92+
table = clause
93+
elif isinstance(clause, sa.UpdateBase) and isinstance(clause.table, sa.Table):
94+
table = clause.table
95+
96+
if table is not None and "bind_key" in table.metadata.info:
97+
key = table.metadata.info["bind_key"]
8998

9099
if key not in engines:
91100
raise sa_exc.UnboundExecutionError(

tests/test_session.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,68 @@ class Admin(User): # type: ignore[no-redef]
140140
admin = db.session.execute(db.select(Admin)).scalar_one()
141141
db.session.expire(admin)
142142
assert admin.org == "pallets"
143+
144+
145+
@pytest.mark.usefixtures("app_ctx")
146+
def test_session_multiple_dbs(app: Flask, model_class: t.Any) -> None:
147+
app.config["SQLALCHEMY_BINDS"] = {"db1": "sqlite:///"}
148+
db = SQLAlchemy(app, model_class=model_class)
149+
150+
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):
151+
152+
class User(db.Model):
153+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
154+
sa.Integer, primary_key=True, init=False
155+
)
156+
name: sa_orm.Mapped[str] = sa_orm.mapped_column(
157+
sa.String(50), nullable=False, init=False
158+
)
159+
160+
class Product(db.Model):
161+
__bind_key__ = "db1"
162+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
163+
sa.Integer, primary_key=True, init=False
164+
)
165+
name: sa_orm.Mapped[str] = sa_orm.mapped_column(
166+
sa.String(50), nullable=False, init=False
167+
)
168+
169+
elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
170+
171+
class User(db.Model): # type: ignore[no-redef]
172+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
173+
name: sa_orm.Mapped[str] = sa_orm.mapped_column(
174+
sa.String(50), nullable=False
175+
)
176+
177+
class Product(db.Model): # type: ignore[no-redef]
178+
__bind_key__ = "db1"
179+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True)
180+
name: sa_orm.Mapped[str] = sa_orm.mapped_column(
181+
sa.String(50), nullable=False
182+
)
183+
184+
else:
185+
186+
class User(db.Model): # type: ignore[no-redef]
187+
id = sa.Column(sa.Integer, primary_key=True)
188+
name = sa.Column(sa.String(50), nullable=False)
189+
190+
class Product(db.Model): # type: ignore[no-redef]
191+
__bind_key__ = "db1"
192+
id = sa.Column(sa.Integer, primary_key=True)
193+
name = sa.Column(sa.String(50), nullable=False)
194+
195+
db.create_all()
196+
197+
db.session.execute(User.__table__.insert(), [{"name": "User1"}, {"name": "User2"}])
198+
db.session.commit()
199+
users = db.session.execute(db.select(User)).scalars().all()
200+
assert len(users) == 2
201+
202+
db.session.execute(
203+
Product.__table__.insert(), [{"name": "Product1"}, {"name": "Product2"}]
204+
)
205+
db.session.commit()
206+
products = db.session.execute(db.select(Product)).scalars().all()
207+
assert len(products) == 2

0 commit comments

Comments
 (0)