Skip to content

Commit fe861c1

Browse files
authored
Added the timezone to dates when deserializing from the database. (#108)
1 parent fb7af6a commit fe861c1

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

backend/beets_flask/database/models/base.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from __future__ import annotations
22

33
from datetime import datetime
4-
from typing import Any, List, Mapping, Self, Sequence, TypedDict
4+
from typing import Mapping, Self, Sequence
55
from uuid import uuid4
66

7-
from beets.importer import ImportTask, library
8-
from sqlalchemy import Index, LargeBinary, select
7+
import pytz
8+
from sqlalchemy import LargeBinary, select
99
from sqlalchemy.orm import (
1010
DeclarativeBase,
1111
Mapped,
1212
Session,
1313
mapped_column,
14+
reconstructor,
1415
registry,
1516
)
1617
from sqlalchemy.sql import func
@@ -24,6 +25,7 @@ class Base(DeclarativeBase):
2425
registry = registry(type_annotation_map={bytes: LargeBinary})
2526

2627
id: Mapped[str] = mapped_column(primary_key=True)
28+
2729
created_at: Mapped[datetime] = mapped_column(default=func.now(), index=True)
2830
updated_at: Mapped[datetime] = mapped_column(
2931
default=func.now(), onupdate=func.now()
@@ -96,3 +98,13 @@ def exist_all_ids(cls, ids: list[str], session: Session) -> bool:
9698

9799
def to_dict(self) -> Mapping:
98100
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
101+
102+
@reconstructor
103+
def _sqlalchemy_reconstructor(self):
104+
# Set timezone info for created_at and updated_at
105+
# Seems a bit hacky but is the only way to ensure that
106+
# datetime objects are timezone-aware after deserialization
107+
if self.created_at and self.created_at.tzinfo is None:
108+
self.created_at = self.created_at.replace(tzinfo=pytz.UTC)
109+
if self.updated_at and self.updated_at.tzinfo is None:
110+
self.updated_at = self.updated_at.replace(tzinfo=pytz.UTC)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import datetime
2+
from pathlib import Path
3+
4+
import pytz
5+
6+
from beets_flask.database.models import SessionStateInDb
7+
from beets_flask.importer.session import SessionState
8+
from tests.mixins.database import IsolatedDBMixin
9+
10+
11+
class TestDates(IsolatedDBMixin):
12+
"""Test that dates are set correctly in the database for SessionStateInDb objects.
13+
14+
This test checks that the created_at and updated_at fields are set to the current
15+
time in UTC when a SessionStateInDb object is created, and that they are
16+
deserialized correctly from the database.
17+
"""
18+
19+
def test_dates(self, db_session_factory, tmpdir_factory):
20+
# Create a new session state in db
21+
state = SessionState(Path(tmpdir_factory.mktemp("dates_test")))
22+
23+
state_in_db = SessionStateInDb.from_live_state(state)
24+
with db_session_factory() as s:
25+
s.add(state_in_db)
26+
s.commit()
27+
28+
# Check that the dates are set
29+
with db_session_factory() as s:
30+
state_in_db = s.query(SessionStateInDb).filter_by(id=state_in_db.id).one()
31+
assert state_in_db.created_at is not None
32+
assert state_in_db.updated_at is not None
33+
34+
# Check that the dates are deserialized correctly
35+
assert isinstance(state_in_db.created_at, datetime.datetime)
36+
assert isinstance(state_in_db.updated_at, datetime.datetime)
37+
38+
# Check that the timezone is UTC
39+
assert state_in_db.created_at.tzinfo is not None
40+
assert state_in_db.updated_at.tzinfo is not None
41+
assert state_in_db.created_at.tzinfo == pytz.UTC
42+
assert state_in_db.updated_at.tzinfo == pytz.UTC
43+
44+
# Should be approximately equal to current local time
45+
now = datetime.datetime.now().astimezone()
46+
47+
assert abs((state_in_db.created_at.astimezone() - now).total_seconds()) < 5
48+
assert abs((state_in_db.updated_at.astimezone() - now).total_seconds()) < 5

0 commit comments

Comments
 (0)