Skip to content

Commit e7f8d38

Browse files
committed
add DatabaseManager class
1 parent 576fecf commit e7f8d38

File tree

8 files changed

+115
-25
lines changed

8 files changed

+115
-25
lines changed

conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sqlalchemy import create_engine
55
from sqlalchemy.orm import sessionmaker
66

7+
from jupyter_scheduler.managers import SQLAlchemyDatabaseManager
78
from jupyter_scheduler.orm import Base
89
from jupyter_scheduler.scheduler import Scheduler
910
from jupyter_scheduler.tests.mocks import MockEnvironmentManager
@@ -59,6 +60,8 @@ def jp_scheduler(jp_scheduler_db_url, jp_scheduler_root_dir, jp_scheduler_db):
5960
db_url=jp_scheduler_db_url,
6061
root_dir=str(jp_scheduler_root_dir),
6162
environments_manager=MockEnvironmentManager(),
63+
database_manager=SQLAlchemyDatabaseManager(),
64+
database_manager_class="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
6265
)
6366

6467

jupyter_scheduler/executors.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib
12
import io
23
import os
34
import shutil
@@ -29,12 +30,24 @@ class ExecutionManager(ABC):
2930
_model = None
3031
_db_session = None
3132

32-
def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
33+
def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str], database_manager_class):
3334
self.job_id = job_id
3435
self.staging_paths = staging_paths
3536
self.root_dir = root_dir
3637
self.db_url = db_url
38+
39+
self.database_manager = self._create_database_manager(database_manager_class)
3740

41+
42+
def _create_database_manager(self, database_manager_class):
43+
try:
44+
module_name, class_name = database_manager_class.rsplit('.', 1)
45+
module = importlib.import_module(module_name)
46+
DatabaseManagerClass = getattr(module, class_name)
47+
return DatabaseManagerClass()
48+
except (ValueError, ImportError, AttributeError) as e:
49+
raise ValueError(f"Invalid database_manager_class '{database_manager_class}': {e}")
50+
3851
@property
3952
def model(self):
4053
if self._model is None:
@@ -46,7 +59,7 @@ def model(self):
4659
@property
4760
def db_session(self):
4861
if self._db_session is None:
49-
self._db_session = create_session(self.db_url)
62+
self._db_session = create_session(self.db_url, self.database_manager)
5063

5164
return self._db_session
5265

jupyter_scheduler/extension.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ class SchedulerApp(ExtensionApp):
4545
def _db_url_default(self):
4646
return f"sqlite:///{jupyter_data_dir()}/scheduler.sqlite"
4747

48+
database_manager_class = Type(
49+
default_value="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
50+
klass="jupyter_scheduler.managers.DatabaseManager",
51+
config=True,
52+
help=_i18n("Database manager class for custom database backends."),
53+
)
54+
4855
environment_manager_class = Type(
4956
default_value="jupyter_scheduler.environments.CondaEnvironmentManager",
5057
klass="jupyter_scheduler.environments.EnvironmentManager",
@@ -69,7 +76,8 @@ def _db_url_default(self):
6976
def initialize_settings(self):
7077
super().initialize_settings()
7178

72-
create_tables(self.db_url, self.drop_tables)
79+
database_manager = self.database_manager_class()
80+
create_tables(self.db_url, self.drop_tables, database_manager=database_manager)
7381

7482
environments_manager = self.environment_manager_class()
7583

@@ -78,6 +86,8 @@ def initialize_settings(self):
7886
environments_manager=environments_manager,
7987
db_url=self.db_url,
8088
config=self.config,
89+
database_manager=database_manager,
90+
database_manager_class=self.database_manager_class,
8191
)
8292

8393
job_files_manager = self.job_files_manager_class(scheduler=scheduler)

jupyter_scheduler/managers.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from abc import ABC, abstractmethod
2+
from sqlite3 import OperationalError
3+
from sqlalchemy import create_engine
4+
from sqlalchemy.orm import sessionmaker
5+
6+
from jupyter_scheduler.orm import Base as DefaultBase, update_db_schema
7+
8+
9+
class DatabaseManager(ABC):
10+
"""Base class for database managers.
11+
12+
Database managers handle database operations for jupyter-scheduler.
13+
Subclasses can implement custom storage backends (K8s, Redis, etc.)
14+
while maintaining compatibility with the scheduler's session interface.
15+
"""
16+
17+
@abstractmethod
18+
def create_session(self, db_url: str):
19+
"""Create a database session.
20+
21+
Args:
22+
db_url: Database URL (e.g., "k8s://namespace", "redis://localhost")
23+
24+
Returns:
25+
Session object compatible with SQLAlchemy session interface
26+
"""
27+
pass
28+
29+
@abstractmethod
30+
def create_tables(self, db_url: str, drop_tables: bool = False, Base=None):
31+
"""Create database tables/schema.
32+
33+
Args:
34+
db_url: Database URL
35+
drop_tables: Whether to drop existing tables first
36+
Base: SQLAlchemy Base for custom schemas (tests)
37+
"""
38+
pass
39+
40+
41+
class SQLAlchemyDatabaseManager(DatabaseManager):
42+
"""Default database manager using SQLAlchemy."""
43+
44+
def create_session(self, db_url: str):
45+
"""Create SQLAlchemy session factory."""
46+
engine = create_engine(db_url, echo=False)
47+
Session = sessionmaker(bind=engine)
48+
return Session
49+
50+
def create_tables(self, db_url: str, drop_tables: bool = False, Base=None):
51+
"""Create database tables using SQLAlchemy."""
52+
if Base is None:
53+
Base = DefaultBase
54+
55+
engine = create_engine(db_url)
56+
update_db_schema(engine, Base)
57+
58+
try:
59+
if drop_tables:
60+
Base.metadata.drop_all(engine)
61+
except OperationalError:
62+
pass
63+
finally:
64+
Base.metadata.create_all(engine)

jupyter_scheduler/orm.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,9 @@ def update_db_schema(engine, Base):
146146
connection.execute(alter_statement)
147147

148148

149-
def create_tables(db_url, drop_tables=False, Base=Base):
150-
engine = create_engine(db_url)
151-
update_db_schema(engine, Base)
149+
def create_tables(db_url, drop_tables=False, Base=Base, *, database_manager):
150+
database_manager.create_tables(db_url, drop_tables, Base)
152151

153-
try:
154-
if drop_tables:
155-
Base.metadata.drop_all(engine)
156-
except OperationalError:
157-
pass
158-
finally:
159-
Base.metadata.create_all(engine)
160152

161-
162-
def create_session(db_url):
163-
engine = create_engine(db_url, echo=False)
164-
Session = sessionmaker(bind=engine)
165-
166-
return Session
153+
def create_session(db_url, database_manager):
154+
return database_manager.create_session(db_url)

jupyter_scheduler/scheduler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,19 +405,23 @@ def __init__(
405405
environments_manager: Type[EnvironmentManager],
406406
db_url: str,
407407
config=None,
408+
database_manager=None,
409+
database_manager_class=None,
408410
**kwargs,
409411
):
410412
super().__init__(
411413
root_dir=root_dir, environments_manager=environments_manager, config=config, **kwargs
412414
)
413415
self.db_url = db_url
416+
self.database_manager = database_manager
417+
self.database_manager_class = database_manager_class
414418
if self.task_runner_class:
415419
self.task_runner = self.task_runner_class(scheduler=self, config=config)
416420

417421
@property
418422
def db_session(self):
419423
if not self._db_session:
420-
self._db_session = create_session(self.db_url)
424+
self._db_session = create_session(self.db_url, self.database_manager)
421425

422426
return self._db_session
423427

@@ -492,6 +496,7 @@ def create_job(self, model: CreateJob) -> str:
492496
staging_paths=staging_paths,
493497
root_dir=self.root_dir,
494498
db_url=self.db_url,
499+
database_manager_class=self.database_manager_class,
495500
).process
496501
)
497502
p.start()

jupyter_scheduler/tests/test_execution_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_add_side_effects_files(
5353
root_dir=jp_scheduler_root_dir,
5454
db_url=jp_scheduler_db_url,
5555
staging_paths={"input": staged_notebook_file_path},
56+
database_manager_class="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
5657
)
5758
manager.add_side_effects_files(staged_notebook_dir)
5859

jupyter_scheduler/tests/test_orm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313

1414

1515
@pytest.fixture
16-
def initial_db(jp_scheduler_db_url) -> tuple[Type[DeclarativeMeta], sessionmaker, str]:
16+
def database_manager():
17+
from jupyter_scheduler.managers import SQLAlchemyDatabaseManager
18+
return SQLAlchemyDatabaseManager()
19+
20+
21+
@pytest.fixture
22+
def initial_db(jp_scheduler_db_url, database_manager) -> tuple[Type[DeclarativeMeta], sessionmaker, str]:
1723
TestBase = declarative_base()
1824

1925
class MockInitialJob(TestBase):
@@ -24,9 +30,9 @@ class MockInitialJob(TestBase):
2430

2531
initial_job = MockInitialJob(runtime_environment_name="abc", input_filename="input.ipynb")
2632

27-
create_tables(db_url=jp_scheduler_db_url, Base=TestBase)
33+
create_tables(db_url=jp_scheduler_db_url, Base=TestBase, database_manager=database_manager)
2834

29-
Session = create_session(jp_scheduler_db_url)
35+
Session = create_session(jp_scheduler_db_url, database_manager)
3036
session = Session()
3137

3238
session.add(initial_job)
@@ -52,7 +58,7 @@ class MockUpdatedJob(TestBase):
5258
return MockUpdatedJob
5359

5460

55-
def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_job_model):
61+
def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_job_model, database_manager):
5662
TestBase, Session, initial_job_id = initial_db
5763

5864
session = Session()
@@ -61,7 +67,7 @@ def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_
6167
session.close()
6268

6369
JobModel = updated_job_model
64-
create_tables(db_url=jp_scheduler_db_url, Base=TestBase)
70+
create_tables(db_url=jp_scheduler_db_url, Base=TestBase, database_manager=database_manager)
6571

6672
session = Session()
6773
updated_columns = {col["name"] for col in inspect(session.bind).get_columns("jobs")}

0 commit comments

Comments
 (0)