diff --git a/backend/beets_flask/database/models/base.py b/backend/beets_flask/database/models/base.py index fc488ed2..0e961019 100644 --- a/backend/beets_flask/database/models/base.py +++ b/backend/beets_flask/database/models/base.py @@ -1,16 +1,17 @@ from __future__ import annotations from datetime import datetime -from typing import Any, List, Mapping, Self, Sequence, TypedDict +from typing import Mapping, Self, Sequence from uuid import uuid4 -from beets.importer import ImportTask, library -from sqlalchemy import Index, LargeBinary, select +import pytz +from sqlalchemy import LargeBinary, select from sqlalchemy.orm import ( DeclarativeBase, Mapped, Session, mapped_column, + reconstructor, registry, ) from sqlalchemy.sql import func @@ -24,6 +25,7 @@ class Base(DeclarativeBase): registry = registry(type_annotation_map={bytes: LargeBinary}) id: Mapped[str] = mapped_column(primary_key=True) + created_at: Mapped[datetime] = mapped_column(default=func.now(), index=True) updated_at: Mapped[datetime] = mapped_column( default=func.now(), onupdate=func.now() @@ -96,3 +98,13 @@ def exist_all_ids(cls, ids: list[str], session: Session) -> bool: def to_dict(self) -> Mapping: return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + @reconstructor + def _sqlalchemy_reconstructor(self): + # Set timezone info for created_at and updated_at + # Seems a bit hacky but is the only way to ensure that + # datetime objects are timezone-aware after deserialization + if self.created_at and self.created_at.tzinfo is None: + self.created_at = self.created_at.replace(tzinfo=pytz.UTC) + if self.updated_at and self.updated_at.tzinfo is None: + self.updated_at = self.updated_at.replace(tzinfo=pytz.UTC) diff --git a/backend/tests/unit/test_database/test_dates.py b/backend/tests/unit/test_database/test_dates.py new file mode 100644 index 00000000..c1931ca4 --- /dev/null +++ b/backend/tests/unit/test_database/test_dates.py @@ -0,0 +1,48 @@ +import datetime +from pathlib import Path + +import pytz + +from beets_flask.database.models import SessionStateInDb +from beets_flask.importer.session import SessionState +from tests.mixins.database import IsolatedDBMixin + + +class TestDates(IsolatedDBMixin): + """Test that dates are set correctly in the database for SessionStateInDb objects. + + This test checks that the created_at and updated_at fields are set to the current + time in UTC when a SessionStateInDb object is created, and that they are + deserialized correctly from the database. + """ + + def test_dates(self, db_session_factory, tmpdir_factory): + # Create a new session state in db + state = SessionState(Path(tmpdir_factory.mktemp("dates_test"))) + + state_in_db = SessionStateInDb.from_live_state(state) + with db_session_factory() as s: + s.add(state_in_db) + s.commit() + + # Check that the dates are set + with db_session_factory() as s: + state_in_db = s.query(SessionStateInDb).filter_by(id=state_in_db.id).one() + assert state_in_db.created_at is not None + assert state_in_db.updated_at is not None + + # Check that the dates are deserialized correctly + assert isinstance(state_in_db.created_at, datetime.datetime) + assert isinstance(state_in_db.updated_at, datetime.datetime) + + # Check that the timezone is UTC + assert state_in_db.created_at.tzinfo is not None + assert state_in_db.updated_at.tzinfo is not None + assert state_in_db.created_at.tzinfo == pytz.UTC + assert state_in_db.updated_at.tzinfo == pytz.UTC + + # Should be approximately equal to current local time + now = datetime.datetime.now().astimezone() + + assert abs((state_in_db.created_at.astimezone() - now).total_seconds()) < 5 + assert abs((state_in_db.updated_at.astimezone() - now).total_seconds()) < 5