Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions backend/beets_flask/database/models/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, List, Mapping, Self, Sequence, TypedDict
from typing import Mapping, Self, Sequence
from uuid import uuid4

from beets.importer import ImportTask, library
from sqlalchemy import Index, LargeBinary, select
import pytz
from sqlalchemy import LargeBinary, select
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Session,
mapped_column,
reconstructor,
registry,
)
from sqlalchemy.sql import func
Expand All @@ -24,6 +25,7 @@ class Base(DeclarativeBase):
registry = registry(type_annotation_map={bytes: LargeBinary})

id: Mapped[str] = mapped_column(primary_key=True)

created_at: Mapped[datetime] = mapped_column(default=func.now(), index=True)
updated_at: Mapped[datetime] = mapped_column(
default=func.now(), onupdate=func.now()
Expand Down Expand Up @@ -96,3 +98,13 @@ def exist_all_ids(cls, ids: list[str], session: Session) -> bool:

def to_dict(self) -> Mapping:
return {c.name: getattr(self, c.name) for c in self.__table__.columns}

@reconstructor
def _sqlalchemy_reconstructor(self):
# Set timezone info for created_at and updated_at
# Seems a bit hacky but is the only way to ensure that
# datetime objects are timezone-aware after deserialization
if self.created_at and self.created_at.tzinfo is None:
self.created_at = self.created_at.replace(tzinfo=pytz.UTC)
if self.updated_at and self.updated_at.tzinfo is None:
self.updated_at = self.updated_at.replace(tzinfo=pytz.UTC)
48 changes: 48 additions & 0 deletions backend/tests/unit/test_database/test_dates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import datetime
from pathlib import Path

import pytz

from beets_flask.database.models import SessionStateInDb
from beets_flask.importer.session import SessionState
from tests.mixins.database import IsolatedDBMixin


class TestDates(IsolatedDBMixin):
"""Test that dates are set correctly in the database for SessionStateInDb objects.

This test checks that the created_at and updated_at fields are set to the current
time in UTC when a SessionStateInDb object is created, and that they are
deserialized correctly from the database.
"""

def test_dates(self, db_session_factory, tmpdir_factory):
# Create a new session state in db
state = SessionState(Path(tmpdir_factory.mktemp("dates_test")))

state_in_db = SessionStateInDb.from_live_state(state)
with db_session_factory() as s:
s.add(state_in_db)
s.commit()

# Check that the dates are set
with db_session_factory() as s:
state_in_db = s.query(SessionStateInDb).filter_by(id=state_in_db.id).one()
assert state_in_db.created_at is not None
assert state_in_db.updated_at is not None

# Check that the dates are deserialized correctly
assert isinstance(state_in_db.created_at, datetime.datetime)
assert isinstance(state_in_db.updated_at, datetime.datetime)

# Check that the timezone is UTC
assert state_in_db.created_at.tzinfo is not None
assert state_in_db.updated_at.tzinfo is not None
assert state_in_db.created_at.tzinfo == pytz.UTC
assert state_in_db.updated_at.tzinfo == pytz.UTC

# Should be approximately equal to current local time
now = datetime.datetime.now().astimezone()

assert abs((state_in_db.created_at.astimezone() - now).total_seconds()) < 5
assert abs((state_in_db.updated_at.astimezone() - now).total_seconds()) < 5
Loading