Skip to content

Commit 74b7a87

Browse files
authored
lint: use type narrowing for orm.object_session() (#17582)
Since `object_session(...)` can return `None`, we need to help mypy have confidence that the session is indeed there. One way to narrow the type is to assert that it's not `None`, and therefore should not alert about `union-attr` problems. Includes a couple of renames, and variable extraction where relevant. Signed-off-by: Mike Fiedler <[email protected]>
1 parent 1a5b4fe commit 74b7a87

File tree

9 files changed

+96
-40
lines changed

9 files changed

+96
-40
lines changed

tests/unit/utils/db/test_orm.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import pytest
14+
15+
from sqlalchemy.orm import object_session
16+
17+
from warehouse.db import Model
18+
from warehouse.utils.db.orm import NoSessionError, orm_session_from_obj
19+
20+
21+
def test_orm_session_from_obj_raises_with_no_session():
22+
23+
class FakeObject(Model):
24+
__tablename__ = "fake_object"
25+
26+
obj = FakeObject()
27+
# Confirm that the object does not have a session with the built-in
28+
assert object_session(obj) is None
29+
30+
with pytest.raises(NoSessionError):
31+
orm_session_from_obj(obj)

warehouse/accounts/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from warehouse.observations.models import HasObservations, HasObservers, ObservationKind
4141
from warehouse.sitemap.models import SitemapMixin
4242
from warehouse.utils.attrs import make_repr
43+
from warehouse.utils.db import orm_session_from_obj
4344
from warehouse.utils.db.types import TZDateTime, bool_false, datetime_now
4445

4546
if TYPE_CHECKING:
@@ -236,7 +237,7 @@ def has_primary_verified_email(self):
236237

237238
@property
238239
def recent_events(self):
239-
session = orm.object_session(self)
240+
session = orm_session_from_obj(self)
240241
last_ninety = datetime.datetime.now() - datetime.timedelta(days=90)
241242
return (
242243
session.query(User.Event)

warehouse/cache/origin/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
from itertools import chain
1818

19-
from sqlalchemy.orm.session import Session
20-
2119
from warehouse import db
2220
from warehouse.cache.origin.derivers import html_cache_deriver
2321
from warehouse.cache.origin.interfaces import IOriginCache
22+
from warehouse.utils.db import orm_session_from_obj
2423

2524

2625
@db.listens_for(db.Session, "after_flush")
@@ -139,7 +138,7 @@ def register_origin_cache_keys(config, klass, cache_keys=None, purge_keys=None):
139138

140139
def receive_set(attribute, config, target):
141140
cache_keys = config.registry["cache_keys"]
142-
session = Session.object_session(target)
141+
session = orm_session_from_obj(target)
143142
purges = session.info.setdefault("warehouse.cache.origin.purges", set())
144143
key_maker = cache_keys[attribute]
145144
keys = key_maker(target).purge

warehouse/email/ses/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from sqlalchemy.dialects.postgresql import JSONB, UUID as PG_UUID
2121
from sqlalchemy.ext.mutable import MutableDict
2222
from sqlalchemy.orm import Mapped, mapped_column
23-
from sqlalchemy.orm.session import object_session
2423

2524
from warehouse import db
2625
from warehouse.accounts.models import Email as EmailAddress, UnverifyReasons
26+
from warehouse.utils.db import orm_session_from_obj
2727
from warehouse.utils.db.types import bool_false, datetime_now
2828

2929
MAX_TRANSIENT_BOUNCES = 5
@@ -217,9 +217,9 @@ def _get_email(self):
217217
if self._email_message.missing:
218218
return
219219

220-
db = object_session(self._email_message)
220+
session = orm_session_from_obj(self._email_message)
221221
email = (
222-
db.query(EmailAddress)
222+
session.query(EmailAddress)
223223
.filter(EmailAddress.email == self._email_message.to)
224224
.first()
225225
)

warehouse/legacy/api/xmlrpc/cache/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from pyramid.exceptions import ConfigurationError
1616
from sqlalchemy.orm.base import NO_VALUE
17-
from sqlalchemy.orm.session import Session
1817
from urllib3.util import parse_url
1918

2019
from warehouse import db
@@ -23,6 +22,7 @@
2322
from warehouse.legacy.api.xmlrpc.cache.fncache import RedisLru
2423
from warehouse.legacy.api.xmlrpc.cache.interfaces import IXMLRPCCache
2524
from warehouse.legacy.api.xmlrpc.cache.services import NullXMLRPCCache, RedisXMLRPCCache
25+
from warehouse.utils.db import orm_session_from_obj
2626

2727
__all__ = ["RedisLru"]
2828

@@ -32,7 +32,7 @@
3232

3333
def receive_set(attribute, config, target):
3434
cache_keys = config.registry["cache_keys"]
35-
session = Session.object_session(target)
35+
session = orm_session_from_obj(target)
3636
purges = session.info.setdefault("warehouse.legacy.api.xmlrpc.cache.purges", set())
3737
key_maker = cache_keys[attribute]
3838
keys = key_maker(target).purge

warehouse/organizations/models.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from warehouse.authnz import Permissions
4444
from warehouse.events.models import HasEvents
4545
from warehouse.utils.attrs import make_repr
46+
from warehouse.utils.db import orm_session_from_obj
4647
from warehouse.utils.db.types import TZDateTime, bool_false, datetime_now
4748

4849
if typing.TYPE_CHECKING:
@@ -332,22 +333,17 @@ class Organization(OrganizationMixin, HasEvents, db.Model):
332333
@property
333334
def owners(self):
334335
"""Return all users who are owners of the organization."""
336+
session = orm_session_from_obj(self)
335337
owner_roles = (
336-
orm.object_session(self)
337-
.query(User.id)
338+
session.query(User.id)
338339
.join(OrganizationRole.user)
339340
.filter(
340341
OrganizationRole.role_name == OrganizationRoleType.Owner,
341342
OrganizationRole.organization == self,
342343
)
343344
.subquery()
344345
)
345-
return (
346-
orm.object_session(self)
347-
.query(User)
348-
.join(owner_roles, User.id == owner_roles.c.id)
349-
.all()
350-
)
346+
return session.query(User).join(owner_roles, User.id == owner_roles.c.id).all()
351347

352348
def record_event(self, *, tag, request: Request = None, additional=None):
353349
"""Record organization name in events in case organization is ever deleted."""
@@ -358,7 +354,7 @@ def record_event(self, *, tag, request: Request = None, additional=None):
358354
)
359355

360356
def __acl__(self):
361-
session = orm.object_session(self)
357+
session = orm_session_from_obj(self)
362358

363359
acls = [
364360
(

warehouse/packaging/models.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
from warehouse.sitemap.models import SitemapMixin
8080
from warehouse.utils import dotted_navigator, wheel
8181
from warehouse.utils.attrs import make_repr
82+
from warehouse.utils.db import orm_session_from_obj
8283
from warehouse.utils.db.types import bool_false, bool_true, datetime_now
8384

8485
if typing.TYPE_CHECKING:
@@ -257,7 +258,7 @@ class Project(SitemapMixin, HasEvents, HasObservations, db.Model):
257258
)
258259

259260
def __getitem__(self, version):
260-
session = orm.object_session(self)
261+
session = orm_session_from_obj(self)
261262
canonical_version = packaging.utils.canonicalize_version(version)
262263

263264
try:
@@ -288,7 +289,7 @@ def __getitem__(self, version):
288289
raise KeyError from None
289290

290291
def __acl__(self):
291-
session = orm.object_session(self)
292+
session = orm_session_from_obj(self)
292293
acls = [
293294
# TODO: Similar to `warehouse.accounts.models.User.__acl__`, we express the
294295
# permissions here in terms of the permissions that the user has on
@@ -417,42 +418,36 @@ def documentation_url(self):
417418
@property
418419
def owners(self):
419420
"""Return all users who are owners of the project."""
421+
session = orm_session_from_obj(self)
420422
owner_roles = (
421-
orm.object_session(self)
422-
.query(User.id)
423+
session.query(User.id)
423424
.join(Role.user)
424425
.filter(Role.role_name == "Owner", Role.project == self)
425426
.subquery()
426427
)
427-
return (
428-
orm.object_session(self)
429-
.query(User)
430-
.join(owner_roles, User.id == owner_roles.c.id)
431-
.all()
432-
)
428+
return session.query(User).join(owner_roles, User.id == owner_roles.c.id).all()
433429

434430
@property
435431
def maintainers(self):
436432
"""Return all users who are maintainers of the project."""
433+
session = orm_session_from_obj(self)
437434
maintainer_roles = (
438-
orm.object_session(self)
439-
.query(User.id)
435+
session.query(User.id)
440436
.join(Role.user)
441437
.filter(Role.role_name == "Maintainer", Role.project == self)
442438
.subquery()
443439
)
444440
return (
445-
orm.object_session(self)
446-
.query(User)
441+
session.query(User)
447442
.join(maintainer_roles, User.id == maintainer_roles.c.id)
448443
.all()
449444
)
450445

451446
@property
452447
def all_versions(self):
448+
session = orm_session_from_obj(self)
453449
return (
454-
orm.object_session(self)
455-
.query(
450+
session.query(
456451
Release.version,
457452
Release.created,
458453
Release.is_prerelease,
@@ -466,9 +461,9 @@ def all_versions(self):
466461

467462
@property
468463
def latest_version(self):
464+
session = orm_session_from_obj(self)
469465
return (
470-
orm.object_session(self)
471-
.query(Release.version, Release.created, Release.is_prerelease)
466+
session.query(Release.version, Release.created, Release.is_prerelease)
472467
.filter(Release.project == self, Release.yanked.is_(False))
473468
.order_by(Release.is_prerelease.nullslast(), Release._pypi_ordering.desc())
474469
.first()
@@ -477,7 +472,7 @@ def latest_version(self):
477472
@property
478473
def active_releases(self):
479474
return (
480-
orm.object_session(self)
475+
orm_session_from_obj(self)
481476
.query(Release)
482477
.filter(Release.project == self, Release.yanked.is_(False))
483478
.order_by(Release._pypi_ordering.desc())
@@ -487,7 +482,7 @@ def active_releases(self):
487482
@property
488483
def yanked_releases(self):
489484
return (
490-
orm.object_session(self)
485+
orm_session_from_obj(self)
491486
.query(Release)
492487
.filter(Release.project == self, Release.yanked.is_(True))
493488
.order_by(Release._pypi_ordering.desc())
@@ -747,7 +742,7 @@ def __table_args__(cls): # noqa
747742
uploaded_via: Mapped[str | None]
748743

749744
def __getitem__(self, filename: str) -> File:
750-
session: orm.Session = orm.object_session(self) # type: ignore[assignment]
745+
session = orm_session_from_obj(self)
751746

752747
try:
753748
return (

warehouse/utils/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13+
from warehouse.utils.db.orm import orm_session_from_obj
1314
from warehouse.utils.db.query_printer import print_query
1415

15-
__all__ = ["print_query"]
16+
__all__ = ["orm_session_from_obj", "print_query"]

warehouse/utils/db/orm.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
"""ORM utilities."""
14+
15+
from sqlalchemy.orm import Session, object_session
16+
17+
18+
class NoSessionError(Exception):
19+
"""Raised when there is no active SQLAlchemy session"""
20+
21+
22+
def orm_session_from_obj(obj) -> Session:
23+
"""
24+
Returns the session from the ORM object.
25+
26+
Adds guard, but it should never happen.
27+
The guard helps with type hinting, as the object_session function
28+
returns Optional[Session] type.
29+
"""
30+
session = object_session(obj)
31+
if not session:
32+
raise NoSessionError("Object does not have a session")
33+
return session

0 commit comments

Comments
 (0)