From 1a46b8dae91886fcb9cab0db15defb3b6a2acda6 Mon Sep 17 00:00:00 2001 From: Arunav Gupta Date: Wed, 16 Jul 2025 14:54:37 -0400 Subject: [PATCH 1/4] Add completed_cells tracking during notebook execution --- jupyter_scheduler/executors.py | 36 ++++- jupyter_scheduler/models.py | 2 + jupyter_scheduler/orm.py | 1 + .../tests/test_execution_manager.py | 129 +++++++++++++++++- jupyter_scheduler/tests/test_handlers.py | 94 +++++++++++++ jupyter_scheduler/tests/test_orm.py | 91 ++++++++++++ 6 files changed, 349 insertions(+), 4 deletions(-) diff --git a/jupyter_scheduler/executors.py b/jupyter_scheduler/executors.py index 7e1a9974e..1f4c0036f 100644 --- a/jupyter_scheduler/executors.py +++ b/jupyter_scheduler/executors.py @@ -11,12 +11,38 @@ import nbformat from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor -from jupyter_scheduler.models import DescribeJob, JobFeature, Status +from jupyter_scheduler.models import DescribeJob, JobFeature, Status, UpdateJob from jupyter_scheduler.orm import Job, create_session from jupyter_scheduler.parameterize import add_parameters from jupyter_scheduler.utils import get_utc_timestamp +class TrackingExecutePreprocessor(ExecutePreprocessor): + """Custom ExecutePreprocessor that tracks completed cells and updates the database""" + + def __init__(self, db_session, job_id, **kwargs): + super().__init__(**kwargs) + self.db_session = db_session + self.job_id = job_id + + def preprocess_cell(self, cell, resources, index): + """ + Override to track completed cells in the database. + Calls the superclass implementation and then updates the database. + """ + # Call the superclass implementation + cell, resources = super().preprocess_cell(cell, resources, index) + + # Update the database with the current count of completed cells + with self.db_session() as session: + session.query(Job).filter(Job.job_id == self.job_id).update( + {"completed_cells": self.code_cells_executed} + ) + session.commit() + + return cell, resources + + class ExecutionManager(ABC): """Base execution manager. Clients are expected to override this class @@ -132,8 +158,12 @@ def execute(self): nb = add_parameters(nb, job.parameters) staging_dir = os.path.dirname(self.staging_paths["input"]) - ep = ExecutePreprocessor( - kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir + ep = TrackingExecutePreprocessor( + db_session=self.db_session, + job_id=self.job_id, + kernel_name=nb.metadata.kernelspec["name"], + store_widget_state=True, + cwd=staging_dir ) try: diff --git a/jupyter_scheduler/models.py b/jupyter_scheduler/models.py index 38e240e0e..5dc28e1b2 100644 --- a/jupyter_scheduler/models.py +++ b/jupyter_scheduler/models.py @@ -148,6 +148,7 @@ class DescribeJob(BaseModel): downloaded: bool = False package_input_folder: Optional[bool] = None packaged_files: Optional[List[str]] = [] + completed_cells: Optional[int] = None class Config: orm_mode = True @@ -193,6 +194,7 @@ class UpdateJob(BaseModel): status: Optional[Status] = None name: Optional[str] = None compute_type: Optional[str] = None + completed_cells: Optional[int] = None class DeleteJob(BaseModel): diff --git a/jupyter_scheduler/orm.py b/jupyter_scheduler/orm.py index dbbbfad8e..0b380d815 100644 --- a/jupyter_scheduler/orm.py +++ b/jupyter_scheduler/orm.py @@ -103,6 +103,7 @@ class Job(CommonColumns, Base): url = Column(String(256), default=generate_jobs_url) pid = Column(Integer) idempotency_token = Column(String(256)) + completed_cells = Column(Integer) # All new columns added to this table must be nullable to ensure compatibility during database migrations. # Any default values specified for new columns will be ignored during the migration process. diff --git a/jupyter_scheduler/tests/test_execution_manager.py b/jupyter_scheduler/tests/test_execution_manager.py index 66546be38..abc60d1aa 100644 --- a/jupyter_scheduler/tests/test_execution_manager.py +++ b/jupyter_scheduler/tests/test_execution_manager.py @@ -1,10 +1,12 @@ import shutil from pathlib import Path from typing import Tuple +from unittest.mock import MagicMock, patch import pytest +import nbformat -from jupyter_scheduler.executors import DefaultExecutionManager +from jupyter_scheduler.executors import DefaultExecutionManager, TrackingExecutePreprocessor from jupyter_scheduler.orm import Job @@ -58,3 +60,128 @@ def test_add_side_effects_files( job = jp_scheduler_db.query(Job).filter(Job.job_id == job_id).one() assert side_effect_file_name in job.packaged_files + + +@pytest.fixture +def mock_cell(): + """Create a mock notebook cell for testing""" + cell = nbformat.v4.new_code_cell(source="print('test')") + return cell + + +@pytest.fixture +def mock_resources(): + """Create mock resources for testing""" + return {"metadata": {"path": "/test/path"}} + + +def test_tracking_execute_preprocessor_initialization(): + """Test TrackingExecutePreprocessor initialization""" + mock_db_session = MagicMock() + job_id = "test-job-id" + + preprocessor = TrackingExecutePreprocessor( + db_session=mock_db_session, + job_id=job_id, + kernel_name="python3" + ) + + assert preprocessor.db_session == mock_db_session + assert preprocessor.job_id == job_id + assert preprocessor.kernel_name == "python3" + + +def test_tracking_execute_preprocessor_updates_database(mock_cell, mock_resources): + """Test that TrackingExecutePreprocessor updates the database after cell execution""" + mock_db_session = MagicMock() + mock_session_context = MagicMock() + mock_db_session.return_value.__enter__.return_value = mock_session_context + + job_id = "test-job-id" + + with patch.object(TrackingExecutePreprocessor, 'execute_cell') as mock_execute: + with patch.object(TrackingExecutePreprocessor, '_check_assign_resources'): + preprocessor = TrackingExecutePreprocessor( + db_session=mock_db_session, + job_id=job_id, + kernel_name="python3" + ) + + # Mock the code_cells_executed attribute + preprocessor.code_cells_executed = 3 + preprocessor.resources = mock_resources + + # Mock the execute_cell method to return the cell + mock_execute.return_value = mock_cell + + # Call preprocess_cell + result_cell, result_resources = preprocessor.preprocess_cell(mock_cell, mock_resources, 0) + + # Verify the superclass method was called + mock_execute.assert_called_once_with(mock_cell, 0, store_history=True) + + # Verify database update was called + mock_session_context.query.assert_called_once_with(Job) + mock_session_context.query.return_value.filter.return_value.update.assert_called_once_with( + {"completed_cells": 3} + ) + mock_session_context.commit.assert_called_once() + + # Verify return values + assert result_cell == mock_cell + assert result_resources == mock_resources + + +def test_tracking_execute_preprocessor_handles_database_errors(mock_cell, mock_resources): + """Test that TrackingExecutePreprocessor handles database errors gracefully""" + mock_db_session = MagicMock() + mock_session_context = MagicMock() + mock_db_session.return_value.__enter__.return_value = mock_session_context + + # Make the database update raise an exception + mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception("DB Error") + + job_id = "test-job-id" + + with patch.object(TrackingExecutePreprocessor, 'execute_cell') as mock_execute: + with patch.object(TrackingExecutePreprocessor, '_check_assign_resources'): + preprocessor = TrackingExecutePreprocessor( + db_session=mock_db_session, + job_id=job_id, + kernel_name="python3" + ) + + preprocessor.code_cells_executed = 1 + preprocessor.resources = mock_resources + mock_execute.return_value = mock_cell + + # The database error should propagate + with pytest.raises(Exception, match="DB Error"): + preprocessor.preprocess_cell(mock_cell, mock_resources, 0) + + +def test_tracking_execute_preprocessor_uses_correct_job_id(mock_cell, mock_resources): + """Test that TrackingExecutePreprocessor uses the correct job_id in database queries""" + mock_db_session = MagicMock() + mock_session_context = MagicMock() + mock_db_session.return_value.__enter__.return_value = mock_session_context + + job_id = "specific-job-id-123" + + with patch.object(TrackingExecutePreprocessor, 'execute_cell') as mock_execute: + with patch.object(TrackingExecutePreprocessor, '_check_assign_resources'): + preprocessor = TrackingExecutePreprocessor( + db_session=mock_db_session, + job_id=job_id, + kernel_name="python3" + ) + + preprocessor.code_cells_executed = 2 + preprocessor.resources = mock_resources + mock_execute.return_value = mock_cell + + preprocessor.preprocess_cell(mock_cell, mock_resources, 0) + + # Verify the correct job_id is used in the filter + filter_call = mock_session_context.query.return_value.filter.call_args[0][0] + assert str(filter_call).find(job_id) != -1 or filter_call.right.value == job_id diff --git a/jupyter_scheduler/tests/test_handlers.py b/jupyter_scheduler/tests/test_handlers.py index 9e2e4b7ba..f2ab90201 100644 --- a/jupyter_scheduler/tests/test_handlers.py +++ b/jupyter_scheduler/tests/test_handlers.py @@ -131,6 +131,7 @@ async def test_get_jobs_for_single_job(jp_fetch): url="url_a", create_time=1664305872620, update_time=1664305872620, + completed_cells=5, ) response = await jp_fetch("scheduler", "jobs", job_id, method="GET") @@ -140,6 +141,7 @@ async def test_get_jobs_for_single_job(jp_fetch): assert body["job_id"] == job_id assert body["input_filename"] assert body["job_files"] + assert body["completed_cells"] == 5 @pytest.mark.parametrize( @@ -320,6 +322,28 @@ async def test_patch_jobs(jp_fetch): mock_update_job.assert_called_once_with(job_id, UpdateJob(**body)) +async def test_patch_jobs_with_completed_cells(jp_fetch): + with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job: + job_id = "542e0fac-1274-4a78-8340-a850bdb559c8" + body = {"name": "updated job", "completed_cells": 10} + response = await jp_fetch( + "scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body) + ) + assert response.code == 204 + mock_update_job.assert_called_once_with(job_id, UpdateJob(**body)) + + +async def test_patch_jobs_completed_cells_only(jp_fetch): + with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job: + job_id = "542e0fac-1274-4a78-8340-a850bdb559c8" + body = {"completed_cells": 15} + response = await jp_fetch( + "scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body) + ) + assert response.code == 204 + mock_update_job.assert_called_once_with(job_id, UpdateJob(**body)) + + async def test_patch_jobs_for_stop_job(jp_fetch): with patch("jupyter_scheduler.scheduler.Scheduler.stop_job") as mock_stop_job: job_id = "542e0fac-1274-4a78-8340-a850bdb559c8" @@ -677,3 +701,73 @@ async def test_delete_job_definition_for_unexpected_error(jp_fetch): assert expected_http_error( e, 500, "Unexpected error occurred while deleting the job definition." ) + + +# Model validation tests for completed_cells field +def test_describe_job_completed_cells_validation(): + """Test DescribeJob model validation for completed_cells field""" + # Test valid integer values + job_data = { + "name": "test_job", + "input_filename": "test.ipynb", + "runtime_environment_name": "test_env", + "job_id": "test-job-id", + "url": "http://test.com/jobs/test-job-id", + "create_time": 1234567890, + "update_time": 1234567890, + "completed_cells": 5 + } + job = DescribeJob(**job_data) + assert job.completed_cells == 5 + + # Test None value + job_data["completed_cells"] = None + job = DescribeJob(**job_data) + assert job.completed_cells is None + + # Test zero value + job_data["completed_cells"] = 0 + job = DescribeJob(**job_data) + assert job.completed_cells == 0 + + # Test invalid type + job_data["completed_cells"] = "invalid" + with pytest.raises(ValidationError): + DescribeJob(**job_data) + + +def test_update_job_completed_cells_validation(): + """Test UpdateJob model validation for completed_cells field""" + # Test valid integer values + update_data = {"completed_cells": 10} + update_job = UpdateJob(**update_data) + assert update_job.completed_cells == 10 + + # Test None value + update_data = {"completed_cells": None} + update_job = UpdateJob(**update_data) + assert update_job.completed_cells is None + + # Test zero value + update_data = {"completed_cells": 0} + update_job = UpdateJob(**update_data) + assert update_job.completed_cells == 0 + + # Test invalid type + update_data = {"completed_cells": "invalid"} + with pytest.raises(ValidationError): + UpdateJob(**update_data) + + # Test exclude_none behavior + update_data = {"name": "test", "completed_cells": None} + update_job = UpdateJob(**update_data) + job_dict = update_job.dict(exclude_none=True) + assert "completed_cells" not in job_dict + assert job_dict["name"] == "test" + + # Test include completed_cells when not None + update_data = {"name": "test", "completed_cells": 5} + update_job = UpdateJob(**update_data) + job_dict = update_job.dict(exclude_none=True) + assert job_dict["completed_cells"] == 5 + assert job_dict["name"] == "test" diff --git a/jupyter_scheduler/tests/test_orm.py b/jupyter_scheduler/tests/test_orm.py index e2aab07e6..65c04da01 100644 --- a/jupyter_scheduler/tests/test_orm.py +++ b/jupyter_scheduler/tests/test_orm.py @@ -71,3 +71,94 @@ def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_ assert hasattr(updated_job, "new_column") assert updated_job.runtime_environment_name == "abc" assert updated_job.input_filename == "input.ipynb" + + +def test_completed_cells_column_migration(jp_scheduler_db_url): + """Test that the completed_cells column is properly added during migration""" + from jupyter_scheduler.orm import Base, Job, create_tables + from sqlalchemy import create_engine, inspect + from sqlalchemy.orm import sessionmaker + + # Create initial database without completed_cells + engine = create_engine(jp_scheduler_db_url) + + # Create tables with the current schema (which includes completed_cells) + create_tables(db_url=jp_scheduler_db_url, Base=Base) + + # Verify the completed_cells column exists + inspector = inspect(engine) + columns = {col["name"] for col in inspector.get_columns("jobs")} + assert "completed_cells" in columns + + # Verify the column is of correct type (Integer) + completed_cells_column = next(col for col in inspector.get_columns("jobs") if col["name"] == "completed_cells") + assert str(completed_cells_column["type"]).upper() in ["INTEGER", "INT"] + + # Test that we can insert and retrieve completed_cells values + Session = sessionmaker(bind=engine) + session = Session() + + job = Job( + runtime_environment_name="test_env", + input_filename="test.ipynb", + completed_cells=5 + ) + session.add(job) + session.commit() + + # Retrieve and verify + retrieved_job = session.query(Job).filter(Job.job_id == job.job_id).one() + assert retrieved_job.completed_cells == 5 + + # Test null values are handled properly + job_null = Job( + runtime_environment_name="test_env_null", + input_filename="test_null.ipynb", + completed_cells=None + ) + session.add(job_null) + session.commit() + + retrieved_job_null = session.query(Job).filter(Job.job_id == job_null.job_id).one() + assert retrieved_job_null.completed_cells is None + + session.close() + + +def test_completed_cells_column_nullable(jp_scheduler_db_url): + """Test that completed_cells column is nullable for backward compatibility""" + from jupyter_scheduler.orm import Base, Job, create_tables + from sqlalchemy import create_engine, inspect + from sqlalchemy.orm import sessionmaker + + create_tables(db_url=jp_scheduler_db_url, Base=Base) + + engine = create_engine(jp_scheduler_db_url) + inspector = inspect(engine) + + # Find the completed_cells column + completed_cells_column = next( + col for col in inspector.get_columns("jobs") + if col["name"] == "completed_cells" + ) + + # Verify it's nullable + assert completed_cells_column["nullable"] is True + + # Test creating a job without completed_cells + Session = sessionmaker(bind=engine) + session = Session() + + job = Job( + runtime_environment_name="test_env", + input_filename="test.ipynb" + # Note: not setting completed_cells + ) + session.add(job) + session.commit() + + # Verify it defaults to None + retrieved_job = session.query(Job).filter(Job.job_id == job.job_id).one() + assert retrieved_job.completed_cells is None + + session.close() From 70a70a6824187ca728aa29bec6fd68c9bc22d8ee Mon Sep 17 00:00:00 2001 From: Arunav Gupta Date: Wed, 16 Jul 2025 15:40:31 -0400 Subject: [PATCH 2/4] Make completed_cells column nullable --- jupyter_scheduler/orm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jupyter_scheduler/orm.py b/jupyter_scheduler/orm.py index 0b380d815..c3131cd92 100644 --- a/jupyter_scheduler/orm.py +++ b/jupyter_scheduler/orm.py @@ -103,7 +103,7 @@ class Job(CommonColumns, Base): url = Column(String(256), default=generate_jobs_url) pid = Column(Integer) idempotency_token = Column(String(256)) - completed_cells = Column(Integer) + completed_cells = Column(Integer, nullable=True) # All new columns added to this table must be nullable to ensure compatibility during database migrations. # Any default values specified for new columns will be ignored during the migration process. From 48f9d3f296f9ea1433355ce3df6510eeafca1f1a Mon Sep 17 00:00:00 2001 From: Arunav Gupta Date: Tue, 22 Jul 2025 21:01:32 -0400 Subject: [PATCH 3/4] Refactor cell tracking using hook from ExecutePreprocessor --- jupyter_scheduler/executors.py | 53 ++- jupyter_scheduler/models.py | 1 + jupyter_scheduler/scheduler.py | 2 +- .../tests/test_execution_manager.py | 318 ++++++++++++------ 4 files changed, 233 insertions(+), 141 deletions(-) diff --git a/jupyter_scheduler/executors.py b/jupyter_scheduler/executors.py index 1f4c0036f..26a3ab7f4 100644 --- a/jupyter_scheduler/executors.py +++ b/jupyter_scheduler/executors.py @@ -11,38 +11,12 @@ import nbformat from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor -from jupyter_scheduler.models import DescribeJob, JobFeature, Status, UpdateJob +from jupyter_scheduler.models import DescribeJob, JobFeature, Status from jupyter_scheduler.orm import Job, create_session from jupyter_scheduler.parameterize import add_parameters from jupyter_scheduler.utils import get_utc_timestamp -class TrackingExecutePreprocessor(ExecutePreprocessor): - """Custom ExecutePreprocessor that tracks completed cells and updates the database""" - - def __init__(self, db_session, job_id, **kwargs): - super().__init__(**kwargs) - self.db_session = db_session - self.job_id = job_id - - def preprocess_cell(self, cell, resources, index): - """ - Override to track completed cells in the database. - Calls the superclass implementation and then updates the database. - """ - # Call the superclass implementation - cell, resources = super().preprocess_cell(cell, resources, index) - - # Update the database with the current count of completed cells - with self.db_session() as session: - session.query(Job).filter(Job.job_id == self.job_id).update( - {"completed_cells": self.code_cells_executed} - ) - session.commit() - - return cell, resources - - class ExecutionManager(ABC): """Base execution manager. Clients are expected to override this class @@ -158,14 +132,14 @@ def execute(self): nb = add_parameters(nb, job.parameters) staging_dir = os.path.dirname(self.staging_paths["input"]) - ep = TrackingExecutePreprocessor( - db_session=self.db_session, - job_id=self.job_id, - kernel_name=nb.metadata.kernelspec["name"], - store_widget_state=True, - cwd=staging_dir + + ep = ExecutePreprocessor( + kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir ) + if self.supported_features().get(JobFeature.track_cell_execution, False): + ep.on_cell_executed = self.__update_completed_cells_hook(ep) + try: ep.preprocess(nb, {"metadata": {"path": staging_dir}}) except CellExecutionError as e: @@ -174,6 +148,16 @@ def execute(self): self.add_side_effects_files(staging_dir) self.create_output_files(job, nb) + def __update_completed_cells_hook(self, ep: ExecutePreprocessor): + """Returns a hook that runs on every cell execution, regardless of success or failure. Updates the completed_cells for the job.""" + def update_completed_cells(cell, cell_index, execute_reply): + with self.db_session() as session: + session.query(Job).filter(Job.job_id == self.job_id).update( + {"completed_cells": ep.code_cells_executed} + ) + session.commit() + return update_completed_cells + def add_side_effects_files(self, staging_dir: str): """Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files""" input_notebook = os.path.relpath(self.staging_paths["input"]) @@ -203,6 +187,7 @@ def create_output_files(self, job: DescribeJob, notebook_node): with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f: f.write(output) + @classmethod def supported_features(cls) -> Dict[JobFeature, bool]: return { JobFeature.job_name: True, @@ -218,8 +203,10 @@ def supported_features(cls) -> Dict[JobFeature, bool]: JobFeature.output_filename_template: False, JobFeature.stop_job: True, JobFeature.delete_job: True, + JobFeature.track_cell_execution: True, } + @classmethod def validate(cls, input_path: str) -> bool: with open(input_path, encoding="utf-8") as f: nb = nbformat.read(f, as_version=4) diff --git a/jupyter_scheduler/models.py b/jupyter_scheduler/models.py index 5dc28e1b2..85f9d6844 100644 --- a/jupyter_scheduler/models.py +++ b/jupyter_scheduler/models.py @@ -297,3 +297,4 @@ class JobFeature(str, Enum): output_filename_template = "output_filename_template" stop_job = "stop_job" delete_job = "delete_job" + track_cell_execution = "track_cell_execution" diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 867034c60..382fba3f5 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -442,7 +442,7 @@ def create_job(self, model: CreateJob) -> str: raise InputUriError(model.input_uri) input_path = os.path.join(self.root_dir, model.input_uri) - if not self.execution_manager_class.validate(self.execution_manager_class, input_path): + if not self.execution_manager_class.validate(input_path): raise SchedulerError( """There is no kernel associated with the notebook. Please open the notebook, select a kernel, and re-submit the job to execute. diff --git a/jupyter_scheduler/tests/test_execution_manager.py b/jupyter_scheduler/tests/test_execution_manager.py index abc60d1aa..99702f506 100644 --- a/jupyter_scheduler/tests/test_execution_manager.py +++ b/jupyter_scheduler/tests/test_execution_manager.py @@ -4,9 +4,8 @@ from unittest.mock import MagicMock, patch import pytest -import nbformat -from jupyter_scheduler.executors import DefaultExecutionManager, TrackingExecutePreprocessor +from jupyter_scheduler.executors import DefaultExecutionManager from jupyter_scheduler.orm import Job @@ -62,126 +61,231 @@ def test_add_side_effects_files( assert side_effect_file_name in job.packaged_files -@pytest.fixture -def mock_cell(): - """Create a mock notebook cell for testing""" - cell = nbformat.v4.new_code_cell(source="print('test')") - return cell +def test_default_execution_manager_cell_tracking_hook(): + """Test that DefaultExecutionManager sets up on_cell_executed hook when track_cell_execution is supported""" + job_id = "test-job-id" + with patch.object(DefaultExecutionManager, 'model') as mock_model: + with patch('jupyter_scheduler.executors.open', mock=MagicMock()): + with patch('jupyter_scheduler.executors.nbformat.read') as mock_nb_read: + with patch.object(DefaultExecutionManager, 'add_side_effects_files'): + with patch.object(DefaultExecutionManager, 'create_output_files'): + # Mock notebook + mock_nb = MagicMock() + mock_nb.metadata.kernelspec = {"name": "python3"} + mock_nb_read.return_value = mock_nb -@pytest.fixture -def mock_resources(): - """Create mock resources for testing""" - return {"metadata": {"path": "/test/path"}} + # Mock model + mock_model.parameters = None + mock_model.output_formats = [] + # Create manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"} + ) -def test_tracking_execute_preprocessor_initialization(): - """Test TrackingExecutePreprocessor initialization""" - mock_db_session = MagicMock() + # Patch ExecutePreprocessor + with patch('jupyter_scheduler.executors.ExecutePreprocessor') as mock_ep_class: + mock_ep = MagicMock() + mock_ep_class.return_value = mock_ep + + # Execute + manager.execute() + + # Verify ExecutePreprocessor was created + mock_ep_class.assert_called_once() + + # Verify on_cell_executed hook was set + assert hasattr(mock_ep, 'on_cell_executed') + assert mock_ep.on_cell_executed is not None + + +def test_update_completed_cells_hook(): + """Test the __update_completed_cells_hook method""" job_id = "test-job-id" - - preprocessor = TrackingExecutePreprocessor( - db_session=mock_db_session, + + # Create manager + manager = DefaultExecutionManager( job_id=job_id, - kernel_name="python3" + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"} ) - - assert preprocessor.db_session == mock_db_session - assert preprocessor.job_id == job_id - assert preprocessor.kernel_name == "python3" - -def test_tracking_execute_preprocessor_updates_database(mock_cell, mock_resources): - """Test that TrackingExecutePreprocessor updates the database after cell execution""" + # Mock db_session mock_db_session = MagicMock() mock_session_context = MagicMock() mock_db_session.return_value.__enter__.return_value = mock_session_context - + manager._db_session = mock_db_session + + # Mock ExecutePreprocessor + mock_ep = MagicMock() + mock_ep.code_cells_executed = 5 + + # Get the hook function + hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep) + + # Call the hook + mock_cell = MagicMock() + mock_execute_reply = MagicMock() + hook_func(mock_cell, 2, mock_execute_reply) + + # Verify database update was called + mock_session_context.query.assert_called_once_with(Job) + mock_session_context.query.return_value.filter.return_value.update.assert_called_once_with( + {"completed_cells": 5} + ) + mock_session_context.commit.assert_called_once() + + +def test_update_completed_cells_hook_database_error(): + """Test that database errors in the hook are handled""" job_id = "test-job-id" - - with patch.object(TrackingExecutePreprocessor, 'execute_cell') as mock_execute: - with patch.object(TrackingExecutePreprocessor, '_check_assign_resources'): - preprocessor = TrackingExecutePreprocessor( - db_session=mock_db_session, - job_id=job_id, - kernel_name="python3" - ) - - # Mock the code_cells_executed attribute - preprocessor.code_cells_executed = 3 - preprocessor.resources = mock_resources - - # Mock the execute_cell method to return the cell - mock_execute.return_value = mock_cell - - # Call preprocess_cell - result_cell, result_resources = preprocessor.preprocess_cell(mock_cell, mock_resources, 0) - - # Verify the superclass method was called - mock_execute.assert_called_once_with(mock_cell, 0, store_history=True) - - # Verify database update was called - mock_session_context.query.assert_called_once_with(Job) - mock_session_context.query.return_value.filter.return_value.update.assert_called_once_with( - {"completed_cells": 3} - ) - mock_session_context.commit.assert_called_once() - - # Verify return values - assert result_cell == mock_cell - assert result_resources == mock_resources - - -def test_tracking_execute_preprocessor_handles_database_errors(mock_cell, mock_resources): - """Test that TrackingExecutePreprocessor handles database errors gracefully""" + + # Create manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"} + ) + + # Mock db_session with error mock_db_session = MagicMock() mock_session_context = MagicMock() - mock_db_session.return_value.__enter__.return_value = mock_session_context - - # Make the database update raise an exception mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception("DB Error") - - job_id = "test-job-id" - - with patch.object(TrackingExecutePreprocessor, 'execute_cell') as mock_execute: - with patch.object(TrackingExecutePreprocessor, '_check_assign_resources'): - preprocessor = TrackingExecutePreprocessor( - db_session=mock_db_session, - job_id=job_id, - kernel_name="python3" - ) - - preprocessor.code_cells_executed = 1 - preprocessor.resources = mock_resources - mock_execute.return_value = mock_cell - - # The database error should propagate - with pytest.raises(Exception, match="DB Error"): - preprocessor.preprocess_cell(mock_cell, mock_resources, 0) - - -def test_tracking_execute_preprocessor_uses_correct_job_id(mock_cell, mock_resources): - """Test that TrackingExecutePreprocessor uses the correct job_id in database queries""" + mock_db_session.return_value.__enter__.return_value = mock_session_context + manager._db_session = mock_db_session + + # Mock ExecutePreprocessor + mock_ep = MagicMock() + mock_ep.code_cells_executed = 3 + + # Get the hook function + hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep) + + # Call the hook - should raise exception + mock_cell = MagicMock() + mock_execute_reply = MagicMock() + + with pytest.raises(Exception, match="DB Error"): + hook_func(mock_cell, 1, mock_execute_reply) + + +def test_supported_features_includes_track_cell_execution(): + """Test that DefaultExecutionManager supports track_cell_execution feature""" + features = DefaultExecutionManager.supported_features() + + from jupyter_scheduler.models import JobFeature + assert JobFeature.track_cell_execution in features + assert features[JobFeature.track_cell_execution] is True + + +def test_hook_uses_correct_job_id(): + """Test that the hook uses the correct job_id in database queries""" + job_id = "specific-job-id-456" + + # Create manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"} + ) + + # Mock db_session mock_db_session = MagicMock() mock_session_context = MagicMock() mock_db_session.return_value.__enter__.return_value = mock_session_context - - job_id = "specific-job-id-123" - - with patch.object(TrackingExecutePreprocessor, 'execute_cell') as mock_execute: - with patch.object(TrackingExecutePreprocessor, '_check_assign_resources'): - preprocessor = TrackingExecutePreprocessor( - db_session=mock_db_session, - job_id=job_id, - kernel_name="python3" - ) - - preprocessor.code_cells_executed = 2 - preprocessor.resources = mock_resources - mock_execute.return_value = mock_cell - - preprocessor.preprocess_cell(mock_cell, mock_resources, 0) - - # Verify the correct job_id is used in the filter - filter_call = mock_session_context.query.return_value.filter.call_args[0][0] - assert str(filter_call).find(job_id) != -1 or filter_call.right.value == job_id + manager._db_session = mock_db_session + + # Mock ExecutePreprocessor + mock_ep = MagicMock() + mock_ep.code_cells_executed = 7 + + # Get the hook function + hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep) + + # Call the hook + mock_cell = MagicMock() + mock_execute_reply = MagicMock() + hook_func(mock_cell, 3, mock_execute_reply) + + # Verify the correct job_id is used in the filter + # The filter call should contain a condition that matches Job.job_id == job_id + filter_call = mock_session_context.query.return_value.filter.call_args[0][0] + # This is a SQLAlchemy comparison object, so we need to check its properties + assert hasattr(filter_call, 'right') + assert filter_call.right.value == job_id + + +def test_cell_tracking_disabled_when_feature_false(): + """Test that cell tracking hook is not set when track_cell_execution feature is False""" + job_id = "test-job-id" + + # Create a custom execution manager class with track_cell_execution = False + class DisabledTrackingExecutionManager(DefaultExecutionManager): + @classmethod + def supported_features(cls): + features = super().supported_features() + from jupyter_scheduler.models import JobFeature + features[JobFeature.track_cell_execution] = False + return features + + # Create manager with disabled tracking + manager = DisabledTrackingExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"} + ) + + # Mock ExecutePreprocessor and track calls to __update_completed_cells_hook + with patch.object(manager, '_DefaultExecutionManager__update_completed_cells_hook') as mock_hook_method: + with patch.object(DisabledTrackingExecutionManager, 'model') as mock_model: + with patch('jupyter_scheduler.executors.open', mock=MagicMock()): + with patch('jupyter_scheduler.executors.nbformat.read') as mock_nb_read: + with patch.object(DisabledTrackingExecutionManager, 'add_side_effects_files'): + with patch.object(DisabledTrackingExecutionManager, 'create_output_files'): + with patch('jupyter_scheduler.executors.ExecutePreprocessor') as mock_ep_class: + # Mock notebook + mock_nb = MagicMock() + mock_nb.metadata.kernelspec = {"name": "python3"} + mock_nb_read.return_value = mock_nb + + # Mock model + mock_model.parameters = None + mock_model.output_formats = [] + + mock_ep = MagicMock() + mock_ep_class.return_value = mock_ep + + # Execute + manager.execute() + + # Verify ExecutePreprocessor was created + mock_ep_class.assert_called_once() + + # Verify the hook method was NOT called when feature is disabled + mock_hook_method.assert_not_called() + + +def test_disabled_tracking_feature_support(): + """Test that custom execution manager can disable track_cell_execution feature""" + # Create a custom execution manager class with track_cell_execution = False + class DisabledTrackingExecutionManager(DefaultExecutionManager): + @classmethod + def supported_features(cls): + features = super().supported_features() + from jupyter_scheduler.models import JobFeature + features[JobFeature.track_cell_execution] = False + return features + + features = DisabledTrackingExecutionManager.supported_features() + + from jupyter_scheduler.models import JobFeature + assert JobFeature.track_cell_execution in features + assert features[JobFeature.track_cell_execution] is False From 8c32ef2fe59ae4903026fa9f45461f508c694338 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Jul 2025 01:12:09 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jupyter_scheduler/executors.py | 2 + .../tests/test_execution_manager.py | 55 ++++++++++------- jupyter_scheduler/tests/test_handlers.py | 18 +++--- jupyter_scheduler/tests/test_orm.py | 59 +++++++++---------- 4 files changed, 74 insertions(+), 60 deletions(-) diff --git a/jupyter_scheduler/executors.py b/jupyter_scheduler/executors.py index 26a3ab7f4..402cbe000 100644 --- a/jupyter_scheduler/executors.py +++ b/jupyter_scheduler/executors.py @@ -150,12 +150,14 @@ def execute(self): def __update_completed_cells_hook(self, ep: ExecutePreprocessor): """Returns a hook that runs on every cell execution, regardless of success or failure. Updates the completed_cells for the job.""" + def update_completed_cells(cell, cell_index, execute_reply): with self.db_session() as session: session.query(Job).filter(Job.job_id == self.job_id).update( {"completed_cells": ep.code_cells_executed} ) session.commit() + return update_completed_cells def add_side_effects_files(self, staging_dir: str): diff --git a/jupyter_scheduler/tests/test_execution_manager.py b/jupyter_scheduler/tests/test_execution_manager.py index 99702f506..8e73455aa 100644 --- a/jupyter_scheduler/tests/test_execution_manager.py +++ b/jupyter_scheduler/tests/test_execution_manager.py @@ -65,11 +65,11 @@ def test_default_execution_manager_cell_tracking_hook(): """Test that DefaultExecutionManager sets up on_cell_executed hook when track_cell_execution is supported""" job_id = "test-job-id" - with patch.object(DefaultExecutionManager, 'model') as mock_model: - with patch('jupyter_scheduler.executors.open', mock=MagicMock()): - with patch('jupyter_scheduler.executors.nbformat.read') as mock_nb_read: - with patch.object(DefaultExecutionManager, 'add_side_effects_files'): - with patch.object(DefaultExecutionManager, 'create_output_files'): + with patch.object(DefaultExecutionManager, "model") as mock_model: + with patch("jupyter_scheduler.executors.open", mock=MagicMock()): + with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read: + with patch.object(DefaultExecutionManager, "add_side_effects_files"): + with patch.object(DefaultExecutionManager, "create_output_files"): # Mock notebook mock_nb = MagicMock() mock_nb.metadata.kernelspec = {"name": "python3"} @@ -84,11 +84,13 @@ def test_default_execution_manager_cell_tracking_hook(): job_id=job_id, root_dir="/test", db_url="sqlite:///:memory:", - staging_paths={"input": "/test/input.ipynb"} + staging_paths={"input": "/test/input.ipynb"}, ) # Patch ExecutePreprocessor - with patch('jupyter_scheduler.executors.ExecutePreprocessor') as mock_ep_class: + with patch( + "jupyter_scheduler.executors.ExecutePreprocessor" + ) as mock_ep_class: mock_ep = MagicMock() mock_ep_class.return_value = mock_ep @@ -99,7 +101,7 @@ def test_default_execution_manager_cell_tracking_hook(): mock_ep_class.assert_called_once() # Verify on_cell_executed hook was set - assert hasattr(mock_ep, 'on_cell_executed') + assert hasattr(mock_ep, "on_cell_executed") assert mock_ep.on_cell_executed is not None @@ -112,7 +114,7 @@ def test_update_completed_cells_hook(): job_id=job_id, root_dir="/test", db_url="sqlite:///:memory:", - staging_paths={"input": "/test/input.ipynb"} + staging_paths={"input": "/test/input.ipynb"}, ) # Mock db_session @@ -150,13 +152,15 @@ def test_update_completed_cells_hook_database_error(): job_id=job_id, root_dir="/test", db_url="sqlite:///:memory:", - staging_paths={"input": "/test/input.ipynb"} + staging_paths={"input": "/test/input.ipynb"}, ) # Mock db_session with error mock_db_session = MagicMock() mock_session_context = MagicMock() - mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception("DB Error") + mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception( + "DB Error" + ) mock_db_session.return_value.__enter__.return_value = mock_session_context manager._db_session = mock_db_session @@ -180,6 +184,7 @@ def test_supported_features_includes_track_cell_execution(): features = DefaultExecutionManager.supported_features() from jupyter_scheduler.models import JobFeature + assert JobFeature.track_cell_execution in features assert features[JobFeature.track_cell_execution] is True @@ -193,7 +198,7 @@ def test_hook_uses_correct_job_id(): job_id=job_id, root_dir="/test", db_url="sqlite:///:memory:", - staging_paths={"input": "/test/input.ipynb"} + staging_paths={"input": "/test/input.ipynb"}, ) # Mock db_session @@ -218,7 +223,7 @@ def test_hook_uses_correct_job_id(): # The filter call should contain a condition that matches Job.job_id == job_id filter_call = mock_session_context.query.return_value.filter.call_args[0][0] # This is a SQLAlchemy comparison object, so we need to check its properties - assert hasattr(filter_call, 'right') + assert hasattr(filter_call, "right") assert filter_call.right.value == job_id @@ -232,6 +237,7 @@ class DisabledTrackingExecutionManager(DefaultExecutionManager): def supported_features(cls): features = super().supported_features() from jupyter_scheduler.models import JobFeature + features[JobFeature.track_cell_execution] = False return features @@ -240,17 +246,21 @@ def supported_features(cls): job_id=job_id, root_dir="/test", db_url="sqlite:///:memory:", - staging_paths={"input": "/test/input.ipynb"} + staging_paths={"input": "/test/input.ipynb"}, ) # Mock ExecutePreprocessor and track calls to __update_completed_cells_hook - with patch.object(manager, '_DefaultExecutionManager__update_completed_cells_hook') as mock_hook_method: - with patch.object(DisabledTrackingExecutionManager, 'model') as mock_model: - with patch('jupyter_scheduler.executors.open', mock=MagicMock()): - with patch('jupyter_scheduler.executors.nbformat.read') as mock_nb_read: - with patch.object(DisabledTrackingExecutionManager, 'add_side_effects_files'): - with patch.object(DisabledTrackingExecutionManager, 'create_output_files'): - with patch('jupyter_scheduler.executors.ExecutePreprocessor') as mock_ep_class: + with patch.object( + manager, "_DefaultExecutionManager__update_completed_cells_hook" + ) as mock_hook_method: + with patch.object(DisabledTrackingExecutionManager, "model") as mock_model: + with patch("jupyter_scheduler.executors.open", mock=MagicMock()): + with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read: + with patch.object(DisabledTrackingExecutionManager, "add_side_effects_files"): + with patch.object(DisabledTrackingExecutionManager, "create_output_files"): + with patch( + "jupyter_scheduler.executors.ExecutePreprocessor" + ) as mock_ep_class: # Mock notebook mock_nb = MagicMock() mock_nb.metadata.kernelspec = {"name": "python3"} @@ -275,17 +285,20 @@ def supported_features(cls): def test_disabled_tracking_feature_support(): """Test that custom execution manager can disable track_cell_execution feature""" + # Create a custom execution manager class with track_cell_execution = False class DisabledTrackingExecutionManager(DefaultExecutionManager): @classmethod def supported_features(cls): features = super().supported_features() from jupyter_scheduler.models import JobFeature + features[JobFeature.track_cell_execution] = False return features features = DisabledTrackingExecutionManager.supported_features() from jupyter_scheduler.models import JobFeature + assert JobFeature.track_cell_execution in features assert features[JobFeature.track_cell_execution] is False diff --git a/jupyter_scheduler/tests/test_handlers.py b/jupyter_scheduler/tests/test_handlers.py index f2ab90201..3314e8c6a 100644 --- a/jupyter_scheduler/tests/test_handlers.py +++ b/jupyter_scheduler/tests/test_handlers.py @@ -715,21 +715,21 @@ def test_describe_job_completed_cells_validation(): "url": "http://test.com/jobs/test-job-id", "create_time": 1234567890, "update_time": 1234567890, - "completed_cells": 5 + "completed_cells": 5, } job = DescribeJob(**job_data) assert job.completed_cells == 5 - + # Test None value job_data["completed_cells"] = None job = DescribeJob(**job_data) assert job.completed_cells is None - + # Test zero value job_data["completed_cells"] = 0 job = DescribeJob(**job_data) assert job.completed_cells == 0 - + # Test invalid type job_data["completed_cells"] = "invalid" with pytest.raises(ValidationError): @@ -742,29 +742,29 @@ def test_update_job_completed_cells_validation(): update_data = {"completed_cells": 10} update_job = UpdateJob(**update_data) assert update_job.completed_cells == 10 - + # Test None value update_data = {"completed_cells": None} update_job = UpdateJob(**update_data) assert update_job.completed_cells is None - + # Test zero value update_data = {"completed_cells": 0} update_job = UpdateJob(**update_data) assert update_job.completed_cells == 0 - + # Test invalid type update_data = {"completed_cells": "invalid"} with pytest.raises(ValidationError): UpdateJob(**update_data) - + # Test exclude_none behavior update_data = {"name": "test", "completed_cells": None} update_job = UpdateJob(**update_data) job_dict = update_job.dict(exclude_none=True) assert "completed_cells" not in job_dict assert job_dict["name"] == "test" - + # Test include completed_cells when not None update_data = {"name": "test", "completed_cells": 5} update_job = UpdateJob(**update_data) diff --git a/jupyter_scheduler/tests/test_orm.py b/jupyter_scheduler/tests/test_orm.py index 65c04da01..4d0e3c5ec 100644 --- a/jupyter_scheduler/tests/test_orm.py +++ b/jupyter_scheduler/tests/test_orm.py @@ -75,90 +75,89 @@ def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_ def test_completed_cells_column_migration(jp_scheduler_db_url): """Test that the completed_cells column is properly added during migration""" - from jupyter_scheduler.orm import Base, Job, create_tables from sqlalchemy import create_engine, inspect from sqlalchemy.orm import sessionmaker - + + from jupyter_scheduler.orm import Base, Job, create_tables + # Create initial database without completed_cells engine = create_engine(jp_scheduler_db_url) - + # Create tables with the current schema (which includes completed_cells) create_tables(db_url=jp_scheduler_db_url, Base=Base) - + # Verify the completed_cells column exists inspector = inspect(engine) columns = {col["name"] for col in inspector.get_columns("jobs")} assert "completed_cells" in columns - + # Verify the column is of correct type (Integer) - completed_cells_column = next(col for col in inspector.get_columns("jobs") if col["name"] == "completed_cells") + completed_cells_column = next( + col for col in inspector.get_columns("jobs") if col["name"] == "completed_cells" + ) assert str(completed_cells_column["type"]).upper() in ["INTEGER", "INT"] - + # Test that we can insert and retrieve completed_cells values Session = sessionmaker(bind=engine) session = Session() - - job = Job( - runtime_environment_name="test_env", - input_filename="test.ipynb", - completed_cells=5 - ) + + job = Job(runtime_environment_name="test_env", input_filename="test.ipynb", completed_cells=5) session.add(job) session.commit() - + # Retrieve and verify retrieved_job = session.query(Job).filter(Job.job_id == job.job_id).one() assert retrieved_job.completed_cells == 5 - + # Test null values are handled properly job_null = Job( runtime_environment_name="test_env_null", input_filename="test_null.ipynb", - completed_cells=None + completed_cells=None, ) session.add(job_null) session.commit() - + retrieved_job_null = session.query(Job).filter(Job.job_id == job_null.job_id).one() assert retrieved_job_null.completed_cells is None - + session.close() def test_completed_cells_column_nullable(jp_scheduler_db_url): """Test that completed_cells column is nullable for backward compatibility""" - from jupyter_scheduler.orm import Base, Job, create_tables from sqlalchemy import create_engine, inspect from sqlalchemy.orm import sessionmaker - + + from jupyter_scheduler.orm import Base, Job, create_tables + create_tables(db_url=jp_scheduler_db_url, Base=Base) - + engine = create_engine(jp_scheduler_db_url) inspector = inspect(engine) - + # Find the completed_cells column completed_cells_column = next( - col for col in inspector.get_columns("jobs") - if col["name"] == "completed_cells" + col for col in inspector.get_columns("jobs") if col["name"] == "completed_cells" ) - + # Verify it's nullable assert completed_cells_column["nullable"] is True - + # Test creating a job without completed_cells Session = sessionmaker(bind=engine) session = Session() - + job = Job( runtime_environment_name="test_env", - input_filename="test.ipynb" + input_filename="test.ipynb", # Note: not setting completed_cells ) session.add(job) session.commit() - + # Verify it defaults to None retrieved_job = session.query(Job).filter(Job.job_id == job.job_id).one() assert retrieved_job.completed_cells is None - + session.close()