Skip to content

Commit eb170b9

Browse files
jdimmermandavidism
authored andcommitted
get_or_404 passes kwargs to session.get
1 parent 02e1857 commit eb170b9

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Version 3.1.0
44
Unreleased
55

66
- Remove previously deprecated code.
7+
- Pass extra keyword arguments from ``get_or_404`` to ``session.get``. :issue:`1149`
78

89

910
Version 3.0.3

src/flask_sqlalchemy/extension.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,18 +597,27 @@ def engine(self) -> sa.engine.Engine:
597597
return self.engines[None]
598598

599599
def get_or_404(
600-
self, entity: t.Type[t.Any], ident: t.Any, *, description: str | None = None
600+
self,
601+
entity: t.Type[t.Any],
602+
ident: t.Any,
603+
*,
604+
description: str | None = None,
605+
**kwargs: t.Any,
601606
) -> t.Any:
602607
"""Like :meth:`session.get() <sqlalchemy.orm.Session.get>` but aborts with a
603608
``404 Not Found`` error instead of returning ``None``.
604609
605610
:param entity: The model class to query.
606611
:param ident: The primary key to query.
607612
:param description: A custom message to show on the error page.
613+
:param kwargs: Extra arguments passed to ``session.get()``.
614+
615+
.. versionchanged:: 3.1
616+
Pass extra keyword arguments to ``session.get()``.
608617
609618
.. versionadded:: 3.0
610619
"""
611-
value = self.session.get(entity, ident)
620+
value = self.session.get(entity, ident, **kwargs)
612621

613622
if value is None:
614623
abort(404, description=description)

tests/test_extension_object.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
import pytest
6+
import sqlalchemy as sa
7+
from flask import Flask
8+
from sqlalchemy.orm import joinedload
9+
from werkzeug.exceptions import NotFound
10+
11+
from flask_sqlalchemy import SQLAlchemy
12+
from flask_sqlalchemy.record_queries import get_recorded_queries
13+
14+
15+
@pytest.mark.usefixtures("app_ctx")
16+
def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None:
17+
item = Todo()
18+
db.session.add(item)
19+
db.session.commit()
20+
assert db.get_or_404(Todo, 1) is item
21+
22+
with pytest.raises(NotFound):
23+
db.get_or_404(Todo, 2)
24+
25+
26+
def test_get_or_404_kwargs(app: Flask) -> None:
27+
app.config["SQLALCHEMY_RECORD_QUERIES"] = True
28+
db = SQLAlchemy(app)
29+
30+
class User(db.Model):
31+
id = sa.Column(db.Integer, primary_key=True)
32+
33+
class Todo(db.Model):
34+
id = sa.Column(sa.Integer, primary_key=True)
35+
user_id = sa.Column(sa.ForeignKey(User.id))
36+
user = db.relationship(User)
37+
38+
with app.app_context():
39+
db.create_all()
40+
db.session.add(Todo(user=User()))
41+
db.session.commit()
42+
43+
with app.app_context():
44+
item = db.get_or_404(Todo, 1, options=[joinedload(Todo.user)])
45+
assert item.user.id == 1
46+
# one query with join, no second query when accessing relationship
47+
assert len(get_recorded_queries()) == 1

0 commit comments

Comments
 (0)