Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from jupyter_scheduler.managers import SQLAlchemyDatabaseManager
from jupyter_scheduler.orm import Base
from jupyter_scheduler.scheduler import Scheduler
from jupyter_scheduler.tests.mocks import MockEnvironmentManager
Expand Down Expand Up @@ -59,6 +60,8 @@ def jp_scheduler(jp_scheduler_db_url, jp_scheduler_root_dir, jp_scheduler_db):
db_url=jp_scheduler_db_url,
root_dir=str(jp_scheduler_root_dir),
environments_manager=MockEnvironmentManager(),
database_manager=SQLAlchemyDatabaseManager(),
database_manager_class="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
)


Expand Down
23 changes: 21 additions & 2 deletions jupyter_scheduler/executors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import io
import os
import shutil
Expand Down Expand Up @@ -29,12 +30,30 @@ class ExecutionManager(ABC):
_model = None
_db_session = None

def __init__(
    self,
    job_id: str,
    root_dir: str,
    db_url: str,
    staging_paths: Dict[str, str],
    database_manager_class=None,
):
    """Store the job context and build the database manager.

    Args:
        job_id: Stored on the instance as ``job_id``.
        root_dir: Stored on the instance as ``root_dir``.
        db_url: Database URL later passed to ``create_session``.
        staging_paths: Mapping stored on the instance as ``staging_paths``.
        database_manager_class: Specification of the database manager to
            instantiate (a dotted import path). Optional so that callers
            written against the earlier four-argument signature keep
            working; ``None`` selects the default SQLAlchemy-backed manager.

    Raises:
        ValueError: If ``database_manager_class`` cannot be resolved.
    """
    if database_manager_class is None:
        # Same default the extension's configurable trait uses.
        database_manager_class = "jupyter_scheduler.managers.SQLAlchemyDatabaseManager"

    self.job_id = job_id
    self.staging_paths = staging_paths
    self.root_dir = root_dir
    self.db_url = db_url

    self.database_manager = self._create_database_manager(database_manager_class)

def _create_database_manager(self, database_manager_class):
    """Instantiate the configured database manager.

    Args:
        database_manager_class: Either a class, which is instantiated
            directly, or a dotted import path such as
            ``"package.module.ClassName"``, which is imported first.

    Returns:
        A new instance created by calling the resolved class with no
        arguments.

    Raises:
        ValueError: If the value is neither a class nor a resolvable dotted
            path. The underlying error is attached as ``__cause__``.
    """
    # Callers may hold either form: the extension forwards its ``Type`` trait
    # value (which it also calls directly to build an instance), while tests
    # pass a dotted string.
    if isinstance(database_manager_class, type):
        return database_manager_class()

    try:
        module_name, class_name = database_manager_class.rsplit(".", 1)
        module = importlib.import_module(module_name)
        manager_cls = getattr(module, class_name)
        return manager_cls()
    except (ValueError, ImportError, AttributeError) as e:
        # Chain the original error so the real import/lookup failure stays
        # visible in the traceback.
        raise ValueError(
            f"Invalid database_manager_class '{database_manager_class}': {e}"
        ) from e

@property
def model(self):
if self._model is None:
Expand All @@ -46,7 +65,7 @@ def model(self):
@property
def db_session(self):
    """Return the cached session object, creating it on first access.

    The value comes from ``create_session`` called with this instance's
    ``db_url`` and ``database_manager`` and is stored in ``_db_session``.
    """
    if self._db_session is None:
        # Only the manager-aware call is kept; the earlier single-argument
        # call omitted the database manager.
        self._db_session = create_session(self.db_url, self.database_manager)

    return self._db_session

Expand Down
12 changes: 11 additions & 1 deletion jupyter_scheduler/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ class SchedulerApp(ExtensionApp):
def _db_url_default(self):
return f"sqlite:///{jupyter_data_dir()}/scheduler.sqlite"

# Configurable database manager class. initialize_settings() instantiates it
# with no arguments and passes the instance to create_tables() and to the
# scheduler; the trait value itself is also forwarded to the scheduler.
database_manager_class = Type(
default_value="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
klass="jupyter_scheduler.managers.DatabaseManager",
config=True,
help=_i18n("Database manager class for custom database backends."),
)

environment_manager_class = Type(
default_value="jupyter_scheduler.environments.CondaEnvironmentManager",
klass="jupyter_scheduler.environments.EnvironmentManager",
Expand All @@ -69,7 +76,8 @@ def _db_url_default(self):
def initialize_settings(self):
super().initialize_settings()

create_tables(self.db_url, self.drop_tables)
database_manager = self.database_manager_class()
create_tables(self.db_url, self.drop_tables, database_manager=database_manager)

environments_manager = self.environment_manager_class()

Expand All @@ -78,6 +86,8 @@ def initialize_settings(self):
environments_manager=environments_manager,
db_url=self.db_url,
config=self.config,
database_manager=database_manager,
database_manager_class=self.database_manager_class,
)

job_files_manager = self.job_files_manager_class(scheduler=scheduler)
Expand Down
17 changes: 17 additions & 0 deletions jupyter_scheduler/job_files_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ def generate_filepaths(self):
output_filepath = os.path.join(self.output_dir, self.output_filenames[output_format])
if not os.path.exists(output_filepath) or self.redownload:
yield input_filepath, output_filepath

if self.staging_paths:
staging_dir = os.path.dirname(next(iter(self.staging_paths.values())))
if os.path.exists(staging_dir):
explicit_files = set()
for output_format in output_formats:
if output_format in self.staging_paths:
explicit_files.add(os.path.basename(self.staging_paths[output_format]))

for file_name in os.listdir(staging_dir):
file_path = os.path.join(staging_dir, file_name)
if os.path.isfile(file_path) and file_name not in explicit_files:
input_filepath = file_path
output_filepath = os.path.join(self.output_dir, file_name)
if not os.path.exists(output_filepath) or self.redownload:
yield input_filepath, output_filepath

if self.include_staging_files:
staging_dir = os.path.dirname(self.staging_paths["input"])
for file_relative_path in self.output_filenames["files"]:
Expand Down
66 changes: 66 additions & 0 deletions jupyter_scheduler/managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from sqlite3 import OperationalError

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from jupyter_scheduler.orm import Base as DefaultBase
from jupyter_scheduler.orm import update_db_schema


class DatabaseManager(ABC):
    """Base class for database managers.

    Database managers handle database operations for jupyter-scheduler.
    Subclasses can implement custom storage backends (K8s, Redis, etc.)
    while maintaining compatibility with the scheduler's session interface.
    """

    @abstractmethod
    def create_session(self, db_url: str):
        """Create a session factory for the given database.

        Callers invoke the returned object with no arguments whenever they
        need a session (``Session = create_session(url)`` followed by
        ``session = Session()``), so implementations must return a factory
        rather than an already-open session.

        Args:
            db_url: Database URL (e.g., "k8s://namespace", "redis://localhost")

        Returns:
            A zero-argument callable producing session objects compatible
            with the SQLAlchemy session interface.
        """

    @abstractmethod
    def create_tables(self, db_url: str, drop_tables: bool = False, Base=None):
        """Create database tables/schema.

        Args:
            db_url: Database URL
            drop_tables: Whether to drop existing tables first
            Base: SQLAlchemy Base for custom schemas (tests)
        """


class SQLAlchemyDatabaseManager(DatabaseManager):
    """Default database manager using SQLAlchemy."""

    def create_session(self, db_url: str):
        """Create a SQLAlchemy session factory.

        Args:
            db_url: Database URL passed to ``create_engine``.

        Returns:
            A ``sessionmaker`` bound to a newly created engine; call it to
            obtain a session.
        """
        engine = create_engine(db_url, echo=False)
        return sessionmaker(bind=engine)

    def create_tables(self, db_url: str, drop_tables: bool = False, Base=None):
        """Create database tables using SQLAlchemy.

        Runs ``update_db_schema`` first, optionally drops the tables of
        ``Base``, and always finishes with ``create_all``.

        Args:
            db_url: Database URL passed to ``create_engine``.
            drop_tables: Whether to drop existing tables first.
            Base: Declarative base whose metadata is used; defaults to the
                jupyter-scheduler ORM base.
        """
        # Imported locally because the module-level ``OperationalError`` is
        # the sqlite3 driver's exception. SQLAlchemy re-raises driver errors
        # as its own ``sqlalchemy.exc`` types, so the sqlite3 class alone
        # never matches what ``drop_all`` raises.
        from sqlalchemy.exc import OperationalError as SQLAlchemyOperationalError

        if Base is None:
            Base = DefaultBase

        engine = create_engine(db_url)
        update_db_schema(engine, Base)

        try:
            if drop_tables:
                Base.metadata.drop_all(engine)
        except (OperationalError, SQLAlchemyOperationalError):
            # Dropping is best-effort: a failed drop must not prevent the
            # tables from being (re)created below.
            pass
        finally:
            Base.metadata.create_all(engine)
20 changes: 4 additions & 16 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,21 +146,9 @@ def update_db_schema(engine, Base):
connection.execute(alter_statement)


def create_tables(db_url, drop_tables=False, Base=Base, *, database_manager=None):
    """Create the tables described by ``Base``, optionally dropping them first.

    Args:
        db_url: Database URL forwarded to the database manager.
        drop_tables: Whether existing tables should be dropped first.
        Base: Declarative base forwarded to the database manager.
        database_manager: Object whose
            ``create_tables(db_url, drop_tables, Base)`` does the work.
            Optional so that calls written against the earlier
            ``create_tables(db_url, drop_tables, Base)`` signature keep
            working; when omitted, the default SQLAlchemy manager is used.
    """
    if database_manager is None:
        # Imported lazily: jupyter_scheduler.managers imports this module at
        # load time, so a top-level import here would be circular.
        from jupyter_scheduler.managers import SQLAlchemyDatabaseManager

        database_manager = SQLAlchemyDatabaseManager()

    database_manager.create_tables(db_url, drop_tables, Base)


def create_session(db_url, database_manager=None):
    """Return a session factory for ``db_url``.

    Args:
        db_url: Database URL forwarded to the database manager.
        database_manager: Object whose ``create_session(db_url)`` builds the
            factory. Optional so that the earlier ``create_session(db_url)``
            call form keeps working and so that a scheduler constructed
            without a manager does not fail; when omitted, the default
            SQLAlchemy manager is used.

    Returns:
        Whatever the manager's ``create_session`` returns.
    """
    if database_manager is None:
        # Imported lazily: jupyter_scheduler.managers imports this module at
        # load time, so a top-level import here would be circular.
        from jupyter_scheduler.managers import SQLAlchemyDatabaseManager

        database_manager = SQLAlchemyDatabaseManager()

    return database_manager.create_session(db_url)
7 changes: 6 additions & 1 deletion jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,19 +405,23 @@ def __init__(
environments_manager: Type[EnvironmentManager],
db_url: str,
config=None,
database_manager=None,
database_manager_class=None,
**kwargs,
):
super().__init__(
root_dir=root_dir, environments_manager=environments_manager, config=config, **kwargs
)
self.db_url = db_url
self.database_manager = database_manager
self.database_manager_class = database_manager_class
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)

@property
def db_session(self):
    """Return the cached session object, creating it on first access.

    The value comes from ``create_session`` called with this scheduler's
    ``db_url`` and ``database_manager`` and is stored in ``_db_session``.
    """
    if not self._db_session:
        # Only the manager-aware call is kept; the earlier single-argument
        # call omitted the database manager.
        self._db_session = create_session(self.db_url, self.database_manager)

    return self._db_session

Expand Down Expand Up @@ -492,6 +496,7 @@ def create_job(self, model: CreateJob) -> str:
staging_paths=staging_paths,
root_dir=self.root_dir,
db_url=self.db_url,
database_manager_class=self.database_manager_class,
).process
)
p.start()
Expand Down
1 change: 1 addition & 0 deletions jupyter_scheduler/tests/test_execution_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_add_side_effects_files(
root_dir=jp_scheduler_root_dir,
db_url=jp_scheduler_db_url,
staging_paths={"input": staged_notebook_file_path},
database_manager_class="jupyter_scheduler.managers.SQLAlchemyDatabaseManager",
)
manager.add_side_effects_files(staged_notebook_dir)

Expand Down
21 changes: 16 additions & 5 deletions jupyter_scheduler/tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@


@pytest.fixture
def initial_db(jp_scheduler_db_url) -> tuple[Type[DeclarativeMeta], sessionmaker, str]:
def database_manager():
    """Provide a fresh default SQLAlchemy database manager for the ORM tests."""
    from jupyter_scheduler.managers import SQLAlchemyDatabaseManager

    manager = SQLAlchemyDatabaseManager()
    return manager


@pytest.fixture
def initial_db(
jp_scheduler_db_url, database_manager
) -> tuple[Type[DeclarativeMeta], sessionmaker, str]:
TestBase = declarative_base()

class MockInitialJob(TestBase):
Expand All @@ -24,9 +33,9 @@ class MockInitialJob(TestBase):

initial_job = MockInitialJob(runtime_environment_name="abc", input_filename="input.ipynb")

create_tables(db_url=jp_scheduler_db_url, Base=TestBase)
create_tables(db_url=jp_scheduler_db_url, Base=TestBase, database_manager=database_manager)

Session = create_session(jp_scheduler_db_url)
Session = create_session(jp_scheduler_db_url, database_manager)
session = Session()

session.add(initial_job)
Expand All @@ -52,7 +61,9 @@ class MockUpdatedJob(TestBase):
return MockUpdatedJob


def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_job_model):
def test_create_tables_with_new_column(
jp_scheduler_db_url, initial_db, updated_job_model, database_manager
):
TestBase, Session, initial_job_id = initial_db

session = Session()
Expand All @@ -61,7 +72,7 @@ def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_
session.close()

JobModel = updated_job_model
create_tables(db_url=jp_scheduler_db_url, Base=TestBase)
create_tables(db_url=jp_scheduler_db_url, Base=TestBase, database_manager=database_manager)

session = Session()
updated_columns = {col["name"] for col in inspect(session.bind).get_columns("jobs")}
Expand Down
Loading