Skip to content

Commit 5b55901

Browse files
Add database schema update and database migration logic (#520)
* add db migration logic and a test for it * make Job and JobDefinition records extendable * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make updated_job_model a fixture * add return types to test_orm fixtures * refactor update_db_schema logic into a separate function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make initial_db return a tuple * improve naming clarity * remove a level of indentation in update_db_schema * Ignore nullability and default values during the db migration, document the fact via comments * improve update_db_schema according to review comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d690ac8 commit 5b55901

File tree

3 files changed

+113
-4
lines changed

3 files changed

+113
-4
lines changed

conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
from sqlalchemy import create_engine
5-
from sqlalchemy.orm import Session, sessionmaker
5+
from sqlalchemy.orm import sessionmaker
66

77
from jupyter_scheduler.orm import Base
88
from jupyter_scheduler.scheduler import Scheduler

jupyter_scheduler/orm.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import json
2-
import os
32
from sqlite3 import OperationalError
43
from uuid import uuid4
54

65
import sqlalchemy.types as types
7-
from sqlalchemy import Boolean, Column, Integer, String, create_engine
6+
from sqlalchemy import Boolean, Column, Integer, String, create_engine, inspect
87
from sqlalchemy.orm import declarative_base, declarative_mixin, registry, sessionmaker
8+
from sqlalchemy.sql import text
99

1010
from jupyter_scheduler.models import EmailNotifications, Status
1111
from jupyter_scheduler.utils import get_utc_timestamp
@@ -85,12 +85,15 @@ class CommonColumns:
8585
output_filename_template = Column(String(256))
8686
update_time = Column(Integer, default=get_utc_timestamp, onupdate=get_utc_timestamp)
8787
create_time = Column(Integer, default=get_utc_timestamp)
88+
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
89+
# Any default values specified for new columns will be ignored during the migration process.
8890
package_input_folder = Column(Boolean)
8991
packaged_files = Column(JsonType, default=[])
9092

9193

9294
class Job(CommonColumns, Base):
9395
__tablename__ = "jobs"
96+
__table_args__ = {"extend_existing": True}
9497
job_id = Column(String(36), primary_key=True, default=generate_uuid)
9598
job_definition_id = Column(String(36))
9699
status = Column(String(64), default=Status.STOPPED)
@@ -100,20 +103,53 @@ class Job(CommonColumns, Base):
100103
url = Column(String(256), default=generate_jobs_url)
101104
pid = Column(Integer)
102105
idempotency_token = Column(String(256))
106+
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
107+
# Any default values specified for new columns will be ignored during the migration process.
103108

104109

105110
class JobDefinition(CommonColumns, Base):
    """ORM model mapped to the ``job_definitions`` table.

    Combines the columns declared on ``CommonColumns`` with the
    definition-specific columns below (schedule, timezone, url, active).
    """

    __tablename__ = "job_definitions"
    # "extend_existing" lets a later declaration of this table add columns to
    # the Table already registered in the metadata instead of raising.
    __table_args__ = {"extend_existing": True}
    # Primary key; populated by generate_uuid when no value is supplied.
    job_definition_id = Column(String(36), primary_key=True, default=generate_uuid)
    schedule = Column(String(256))
    timezone = Column(String(36))
    url = Column(String(256), default=generate_job_definitions_url)
    # NOTE(review): CommonColumns declares create_time with the same type and
    # default; this redeclaration looks redundant -- confirm before removing.
    create_time = Column(Integer, default=get_utc_timestamp)
    active = Column(Boolean, default=True)
    # All new columns added to this table must be nullable to ensure compatibility during database migrations.
    # Any default values specified for new columns will be ignored during the migration process.
113121

114122

115-
def create_tables(db_url, drop_tables=False):
123+
def update_db_schema(engine, Base):
    """Add columns declared on the models but missing from the database.

    For every table registered on ``Base.metadata`` that already exists in
    the database, the model's columns are compared with the columns reported
    by the inspector, and an ``ALTER TABLE ... ADD COLUMN`` statement is
    issued for each column the database lacks. Tables that do not exist in
    the database are skipped.

    New columns are always added as ``NULL``-able with no default: the
    model's ``nullable`` and ``default`` settings are deliberately ignored so
    that existing rows remain valid.

    Args:
        engine: SQLAlchemy engine connected to the database to migrate.
        Base: Declarative base whose metadata describes the desired schema.
    """
    inspector = inspect(engine)
    # Let the dialect quote identifiers so reserved words, mixed-case names
    # and schema-qualified tables produce valid DDL.
    preparer = engine.dialect.identifier_preparer
    alter_statements = []

    for table in Base.metadata.tables.values():
        if not inspector.has_table(table.name, schema=table.schema):
            continue
        existing_names = {
            col["name"] for col in inspector.get_columns(table.name, schema=table.schema)
        }

        for column in table.columns:
            # Compare database-level names; a column's key in ``table.c`` may
            # differ from the name stored in the database.
            if column.name in existing_names:
                continue
            column_type = column.type.compile(dialect=engine.dialect)
            alter_statements.append(
                text(
                    f"ALTER TABLE {preparer.format_table(table)} "
                    f"ADD COLUMN {preparer.format_column(column)} {column_type} NULL"
                )
            )

    if not alter_statements:
        return

    # engine.begin() commits on success and rolls back on error. A plain
    # engine.connect() does not autocommit in SQLAlchemy 2.0, so the ALTER
    # statements could be discarded on backends with transactional DDL.
    with engine.begin() as connection:
        for alter_statement in alter_statements:
            connection.execute(alter_statement)
147+
148+
149+
def create_tables(db_url, drop_tables=False, Base=Base):
116150
engine = create_engine(db_url)
151+
update_db_schema(engine, Base)
152+
117153
try:
118154
if drop_tables:
119155
Base.metadata.drop_all(engine)

jupyter_scheduler/tests/test_orm.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Type
2+
3+
import pytest
4+
from sqlalchemy import Column, Integer, String, inspect
5+
from sqlalchemy.orm import DeclarativeMeta, sessionmaker
6+
7+
from jupyter_scheduler.orm import (
8+
create_session,
9+
create_tables,
10+
declarative_base,
11+
generate_uuid,
12+
)
13+
14+
15+
@pytest.fixture
def initial_db(jp_scheduler_db_url) -> tuple[Type[DeclarativeMeta], sessionmaker, str]:
    """Create a database holding the pre-migration ``jobs`` table and one row.

    A fresh declarative base is built with a three-column ``jobs`` model, the
    tables are created through ``create_tables``, and a single job is
    inserted.

    Returns:
        A ``(declarative base, session factory, inserted job id)`` tuple.
    """
    TestBase = declarative_base()

    class MockInitialJob(TestBase):
        __tablename__ = "jobs"
        job_id = Column(String(36), primary_key=True, default=generate_uuid)
        runtime_environment_name = Column(String(256), nullable=False)
        input_filename = Column(String(256), nullable=False)

    create_tables(db_url=jp_scheduler_db_url, Base=TestBase)
    Session = create_session(jp_scheduler_db_url)

    seed_job = MockInitialJob(runtime_environment_name="abc", input_filename="input.ipynb")
    db_session = Session()
    db_session.add(seed_job)
    db_session.commit()
    # Read the generated key before closing, while the instance is attached.
    seeded_job_id = seed_job.job_id
    db_session.close()

    return TestBase, Session, seeded_job_id
38+
39+
40+
@pytest.fixture
def updated_job_model(initial_db) -> Type[DeclarativeMeta]:
    """Return a ``jobs`` model that adds ``new_column`` to the initial schema.

    The model is declared on the same declarative base that ``initial_db``
    produced, with ``extend_existing`` so it can redefine the already
    registered ``jobs`` table.
    """
    base_cls, *_ = initial_db

    class MockUpdatedJob(base_cls):
        __tablename__ = "jobs"
        __table_args__ = {"extend_existing": True}
        job_id = Column(String(36), primary_key=True, default=generate_uuid)
        runtime_environment_name = Column(String(256), nullable=False)
        input_filename = Column(String(256), nullable=False)
        # The only column absent from the initial table.
        new_column = Column("new_column", Integer)

    return MockUpdatedJob
53+
54+
55+
def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_job_model):
    """Check that ``create_tables`` adds a newly declared column in place.

    The ``jobs`` table starts without ``new_column``. After ``create_tables``
    runs against the updated model, the column must exist, the previously
    inserted row must keep its data, and its value for the new column must be
    NULL.
    """
    TestBase, Session, initial_job_id = initial_db

    session = Session()
    try:
        initial_columns = {col["name"] for col in inspect(session.bind).get_columns("jobs")}
    finally:
        # Close even when inspection fails so no connection leaks between tests.
        session.close()
    assert "new_column" not in initial_columns

    create_tables(db_url=jp_scheduler_db_url, Base=TestBase)

    session = Session()
    try:
        updated_columns = {col["name"] for col in inspect(session.bind).get_columns("jobs")}
        assert "new_column" in updated_columns

        updated_job = (
            session.query(updated_job_model)
            .filter(updated_job_model.job_id == initial_job_id)
            .one()
        )
        # hasattr() is always true for a mapped attribute; assert the migrated
        # value instead: pre-existing rows get NULL in the added column.
        assert updated_job.new_column is None
        assert updated_job.runtime_environment_name == "abc"
        assert updated_job.input_filename == "input.ipynb"
    finally:
        session.close()

0 commit comments

Comments
 (0)