diff --git a/jupyter_scheduler/executors.py b/jupyter_scheduler/executors.py index 7e1a9974..402cbe00 100644 --- a/jupyter_scheduler/executors.py +++ b/jupyter_scheduler/executors.py @@ -132,10 +132,14 @@ def execute(self): nb = add_parameters(nb, job.parameters) staging_dir = os.path.dirname(self.staging_paths["input"]) + ep = ExecutePreprocessor( kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir ) + if self.supported_features().get(JobFeature.track_cell_execution, False): + ep.on_cell_executed = self.__update_completed_cells_hook(ep) + try: ep.preprocess(nb, {"metadata": {"path": staging_dir}}) except CellExecutionError as e: @@ -144,6 +148,18 @@ def execute(self): self.add_side_effects_files(staging_dir) self.create_output_files(job, nb) + def __update_completed_cells_hook(self, ep: ExecutePreprocessor): + """Returns a hook that runs on every cell execution, regardless of success or failure. Updates the completed_cells for the job.""" + + def update_completed_cells(cell, cell_index, execute_reply): + with self.db_session() as session: + session.query(Job).filter(Job.job_id == self.job_id).update( + {"completed_cells": ep.code_cells_executed} + ) + session.commit() + + return update_completed_cells + def add_side_effects_files(self, staging_dir: str): """Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files""" input_notebook = os.path.relpath(self.staging_paths["input"]) @@ -173,6 +189,7 @@ def create_output_files(self, job: DescribeJob, notebook_node): with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f: f.write(output) + @classmethod def supported_features(cls) -> Dict[JobFeature, bool]: return { JobFeature.job_name: True, @@ -188,8 +205,10 @@ def supported_features(cls) -> Dict[JobFeature, bool]: JobFeature.output_filename_template: False, JobFeature.stop_job: True, JobFeature.delete_job: True, + 
JobFeature.track_cell_execution: True, } + @classmethod def validate(cls, input_path: str) -> bool: with open(input_path, encoding="utf-8") as f: nb = nbformat.read(f, as_version=4) diff --git a/jupyter_scheduler/models.py b/jupyter_scheduler/models.py index 38e240e0..85f9d684 100644 --- a/jupyter_scheduler/models.py +++ b/jupyter_scheduler/models.py @@ -148,6 +148,7 @@ class DescribeJob(BaseModel): downloaded: bool = False package_input_folder: Optional[bool] = None packaged_files: Optional[List[str]] = [] + completed_cells: Optional[int] = None class Config: orm_mode = True @@ -193,6 +194,7 @@ class UpdateJob(BaseModel): status: Optional[Status] = None name: Optional[str] = None compute_type: Optional[str] = None + completed_cells: Optional[int] = None class DeleteJob(BaseModel): @@ -295,3 +297,4 @@ class JobFeature(str, Enum): output_filename_template = "output_filename_template" stop_job = "stop_job" delete_job = "delete_job" + track_cell_execution = "track_cell_execution" diff --git a/jupyter_scheduler/orm.py b/jupyter_scheduler/orm.py index dbbbfad8..c3131cd9 100644 --- a/jupyter_scheduler/orm.py +++ b/jupyter_scheduler/orm.py @@ -103,6 +103,7 @@ class Job(CommonColumns, Base): url = Column(String(256), default=generate_jobs_url) pid = Column(Integer) idempotency_token = Column(String(256)) + completed_cells = Column(Integer, nullable=True) # All new columns added to this table must be nullable to ensure compatibility during database migrations. # Any default values specified for new columns will be ignored during the migration process. 
diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 867034c6..382fba3f 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -442,7 +442,7 @@ def create_job(self, model: CreateJob) -> str: raise InputUriError(model.input_uri) input_path = os.path.join(self.root_dir, model.input_uri) - if not self.execution_manager_class.validate(self.execution_manager_class, input_path): + if not self.execution_manager_class.validate(input_path): raise SchedulerError( """There is no kernel associated with the notebook. Please open the notebook, select a kernel, and re-submit the job to execute. diff --git a/jupyter_scheduler/tests/test_execution_manager.py b/jupyter_scheduler/tests/test_execution_manager.py index 66546be3..8e73455a 100644 --- a/jupyter_scheduler/tests/test_execution_manager.py +++ b/jupyter_scheduler/tests/test_execution_manager.py @@ -1,6 +1,7 @@ import shutil from pathlib import Path from typing import Tuple +from unittest.mock import MagicMock, patch import pytest @@ -58,3 +59,246 @@ def test_add_side_effects_files( job = jp_scheduler_db.query(Job).filter(Job.job_id == job_id).one() assert side_effect_file_name in job.packaged_files + + +def test_default_execution_manager_cell_tracking_hook(): + """Test that DefaultExecutionManager sets up on_cell_executed hook when track_cell_execution is supported""" + job_id = "test-job-id" + + with patch.object(DefaultExecutionManager, "model") as mock_model: + with patch("jupyter_scheduler.executors.open", mock=MagicMock()): + with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read: + with patch.object(DefaultExecutionManager, "add_side_effects_files"): + with patch.object(DefaultExecutionManager, "create_output_files"): + # Mock notebook + mock_nb = MagicMock() + mock_nb.metadata.kernelspec = {"name": "python3"} + mock_nb_read.return_value = mock_nb + + # Mock model + mock_model.parameters = None + mock_model.output_formats = [] + + # Create 
manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"}, + ) + + # Patch ExecutePreprocessor + with patch( + "jupyter_scheduler.executors.ExecutePreprocessor" + ) as mock_ep_class: + mock_ep = MagicMock() + mock_ep_class.return_value = mock_ep + + # Execute + manager.execute() + + # Verify ExecutePreprocessor was created + mock_ep_class.assert_called_once() + + # Verify on_cell_executed hook was set + assert hasattr(mock_ep, "on_cell_executed") + assert mock_ep.on_cell_executed is not None + + +def test_update_completed_cells_hook(): + """Test the __update_completed_cells_hook method""" + job_id = "test-job-id" + + # Create manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"}, + ) + + # Mock db_session + mock_db_session = MagicMock() + mock_session_context = MagicMock() + mock_db_session.return_value.__enter__.return_value = mock_session_context + manager._db_session = mock_db_session + + # Mock ExecutePreprocessor + mock_ep = MagicMock() + mock_ep.code_cells_executed = 5 + + # Get the hook function + hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep) + + # Call the hook + mock_cell = MagicMock() + mock_execute_reply = MagicMock() + hook_func(mock_cell, 2, mock_execute_reply) + + # Verify database update was called + mock_session_context.query.assert_called_once_with(Job) + mock_session_context.query.return_value.filter.return_value.update.assert_called_once_with( + {"completed_cells": 5} + ) + mock_session_context.commit.assert_called_once() + + +def test_update_completed_cells_hook_database_error(): + """Test that database errors raised inside the hook propagate to the caller""" + job_id = "test-job-id" + + # Create manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + 
staging_paths={"input": "/test/input.ipynb"}, + ) + + # Mock db_session with error + mock_db_session = MagicMock() + mock_session_context = MagicMock() + mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception( + "DB Error" + ) + mock_db_session.return_value.__enter__.return_value = mock_session_context + manager._db_session = mock_db_session + + # Mock ExecutePreprocessor + mock_ep = MagicMock() + mock_ep.code_cells_executed = 3 + + # Get the hook function + hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep) + + # Call the hook - should raise exception + mock_cell = MagicMock() + mock_execute_reply = MagicMock() + + with pytest.raises(Exception, match="DB Error"): + hook_func(mock_cell, 1, mock_execute_reply) + + +def test_supported_features_includes_track_cell_execution(): + """Test that DefaultExecutionManager supports track_cell_execution feature""" + features = DefaultExecutionManager.supported_features() + + from jupyter_scheduler.models import JobFeature + + assert JobFeature.track_cell_execution in features + assert features[JobFeature.track_cell_execution] is True + + +def test_hook_uses_correct_job_id(): + """Test that the hook uses the correct job_id in database queries""" + job_id = "specific-job-id-456" + + # Create manager + manager = DefaultExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"}, + ) + + # Mock db_session + mock_db_session = MagicMock() + mock_session_context = MagicMock() + mock_db_session.return_value.__enter__.return_value = mock_session_context + manager._db_session = mock_db_session + + # Mock ExecutePreprocessor + mock_ep = MagicMock() + mock_ep.code_cells_executed = 7 + + # Get the hook function + hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep) + + # Call the hook + mock_cell = MagicMock() + mock_execute_reply = MagicMock() + hook_func(mock_cell, 3, 
mock_execute_reply) + + # Verify the correct job_id is used in the filter + # The filter call should contain a condition that matches Job.job_id == job_id + filter_call = mock_session_context.query.return_value.filter.call_args[0][0] + # This is a SQLAlchemy comparison object, so we need to check its properties + assert hasattr(filter_call, "right") + assert filter_call.right.value == job_id + + +def test_cell_tracking_disabled_when_feature_false(): + """Test that cell tracking hook is not set when track_cell_execution feature is False""" + job_id = "test-job-id" + + # Create a custom execution manager class with track_cell_execution = False + class DisabledTrackingExecutionManager(DefaultExecutionManager): + @classmethod + def supported_features(cls): + features = super().supported_features() + from jupyter_scheduler.models import JobFeature + + features[JobFeature.track_cell_execution] = False + return features + + # Create manager with disabled tracking + manager = DisabledTrackingExecutionManager( + job_id=job_id, + root_dir="/test", + db_url="sqlite:///:memory:", + staging_paths={"input": "/test/input.ipynb"}, + ) + + # Mock ExecutePreprocessor and track calls to __update_completed_cells_hook + with patch.object( + manager, "_DefaultExecutionManager__update_completed_cells_hook" + ) as mock_hook_method: + with patch.object(DisabledTrackingExecutionManager, "model") as mock_model: + with patch("jupyter_scheduler.executors.open", mock=MagicMock()): + with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read: + with patch.object(DisabledTrackingExecutionManager, "add_side_effects_files"): + with patch.object(DisabledTrackingExecutionManager, "create_output_files"): + with patch( + "jupyter_scheduler.executors.ExecutePreprocessor" + ) as mock_ep_class: + # Mock notebook + mock_nb = MagicMock() + mock_nb.metadata.kernelspec = {"name": "python3"} + mock_nb_read.return_value = mock_nb + + # Mock model + mock_model.parameters = None + 
mock_model.output_formats = [] + + mock_ep = MagicMock() + mock_ep_class.return_value = mock_ep + + # Execute + manager.execute() + + # Verify ExecutePreprocessor was created + mock_ep_class.assert_called_once() + + # Verify the hook method was NOT called when feature is disabled + mock_hook_method.assert_not_called() + + +def test_disabled_tracking_feature_support(): + """Test that custom execution manager can disable track_cell_execution feature""" + + # Create a custom execution manager class with track_cell_execution = False + class DisabledTrackingExecutionManager(DefaultExecutionManager): + @classmethod + def supported_features(cls): + features = super().supported_features() + from jupyter_scheduler.models import JobFeature + + features[JobFeature.track_cell_execution] = False + return features + + features = DisabledTrackingExecutionManager.supported_features() + + from jupyter_scheduler.models import JobFeature + + assert JobFeature.track_cell_execution in features + assert features[JobFeature.track_cell_execution] is False diff --git a/jupyter_scheduler/tests/test_handlers.py b/jupyter_scheduler/tests/test_handlers.py index 9e2e4b7b..3314e8c6 100644 --- a/jupyter_scheduler/tests/test_handlers.py +++ b/jupyter_scheduler/tests/test_handlers.py @@ -131,6 +131,7 @@ async def test_get_jobs_for_single_job(jp_fetch): url="url_a", create_time=1664305872620, update_time=1664305872620, + completed_cells=5, ) response = await jp_fetch("scheduler", "jobs", job_id, method="GET") @@ -140,6 +141,7 @@ async def test_get_jobs_for_single_job(jp_fetch): assert body["job_id"] == job_id assert body["input_filename"] assert body["job_files"] + assert body["completed_cells"] == 5 @pytest.mark.parametrize( @@ -320,6 +322,28 @@ async def test_patch_jobs(jp_fetch): mock_update_job.assert_called_once_with(job_id, UpdateJob(**body)) +async def test_patch_jobs_with_completed_cells(jp_fetch): + with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job: + job_id 
= "542e0fac-1274-4a78-8340-a850bdb559c8" + body = {"name": "updated job", "completed_cells": 10} + response = await jp_fetch( + "scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body) + ) + assert response.code == 204 + mock_update_job.assert_called_once_with(job_id, UpdateJob(**body)) + + +async def test_patch_jobs_completed_cells_only(jp_fetch): + with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job: + job_id = "542e0fac-1274-4a78-8340-a850bdb559c8" + body = {"completed_cells": 15} + response = await jp_fetch( + "scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body) + ) + assert response.code == 204 + mock_update_job.assert_called_once_with(job_id, UpdateJob(**body)) + + async def test_patch_jobs_for_stop_job(jp_fetch): with patch("jupyter_scheduler.scheduler.Scheduler.stop_job") as mock_stop_job: job_id = "542e0fac-1274-4a78-8340-a850bdb559c8" @@ -677,3 +701,73 @@ async def test_delete_job_definition_for_unexpected_error(jp_fetch): assert expected_http_error( e, 500, "Unexpected error occurred while deleting the job definition." 
) + + +# Model validation tests for completed_cells field +def test_describe_job_completed_cells_validation(): + """Test DescribeJob model validation for completed_cells field""" + # Test valid integer values + job_data = { + "name": "test_job", + "input_filename": "test.ipynb", + "runtime_environment_name": "test_env", + "job_id": "test-job-id", + "url": "http://test.com/jobs/test-job-id", + "create_time": 1234567890, + "update_time": 1234567890, + "completed_cells": 5, + } + job = DescribeJob(**job_data) + assert job.completed_cells == 5 + + # Test None value + job_data["completed_cells"] = None + job = DescribeJob(**job_data) + assert job.completed_cells is None + + # Test zero value + job_data["completed_cells"] = 0 + job = DescribeJob(**job_data) + assert job.completed_cells == 0 + + # Test invalid type + job_data["completed_cells"] = "invalid" + with pytest.raises(ValidationError): + DescribeJob(**job_data) + + +def test_update_job_completed_cells_validation(): + """Test UpdateJob model validation for completed_cells field""" + # Test valid integer values + update_data = {"completed_cells": 10} + update_job = UpdateJob(**update_data) + assert update_job.completed_cells == 10 + + # Test None value + update_data = {"completed_cells": None} + update_job = UpdateJob(**update_data) + assert update_job.completed_cells is None + + # Test zero value + update_data = {"completed_cells": 0} + update_job = UpdateJob(**update_data) + assert update_job.completed_cells == 0 + + # Test invalid type + update_data = {"completed_cells": "invalid"} + with pytest.raises(ValidationError): + UpdateJob(**update_data) + + # Test exclude_none behavior + update_data = {"name": "test", "completed_cells": None} + update_job = UpdateJob(**update_data) + job_dict = update_job.dict(exclude_none=True) + assert "completed_cells" not in job_dict + assert job_dict["name"] == "test" + + # Test include completed_cells when not None + update_data = {"name": "test", "completed_cells": 5} + 
update_job = UpdateJob(**update_data) + job_dict = update_job.dict(exclude_none=True) + assert job_dict["completed_cells"] == 5 + assert job_dict["name"] == "test" diff --git a/jupyter_scheduler/tests/test_orm.py b/jupyter_scheduler/tests/test_orm.py index e2aab07e..4d0e3c5e 100644 --- a/jupyter_scheduler/tests/test_orm.py +++ b/jupyter_scheduler/tests/test_orm.py @@ -71,3 +71,93 @@ def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_ assert hasattr(updated_job, "new_column") assert updated_job.runtime_environment_name == "abc" assert updated_job.input_filename == "input.ipynb" + + +def test_completed_cells_column_migration(jp_scheduler_db_url): + """Test that the completed_cells column is properly added during migration""" + from sqlalchemy import create_engine, inspect + from sqlalchemy.orm import sessionmaker + + from jupyter_scheduler.orm import Base, Job, create_tables + + # Engine used below to inspect the schema and open a session + engine = create_engine(jp_scheduler_db_url) + + # Create tables with the current schema (which includes completed_cells) + create_tables(db_url=jp_scheduler_db_url, Base=Base) + + # Verify the completed_cells column exists + inspector = inspect(engine) + columns = {col["name"] for col in inspector.get_columns("jobs")} + assert "completed_cells" in columns + + # Verify the column is of correct type (Integer) + completed_cells_column = next( + col for col in inspector.get_columns("jobs") if col["name"] == "completed_cells" + ) + assert str(completed_cells_column["type"]).upper() in ["INTEGER", "INT"] + + # Test that we can insert and retrieve completed_cells values + Session = sessionmaker(bind=engine) + session = Session() + + job = Job(runtime_environment_name="test_env", input_filename="test.ipynb", completed_cells=5) + session.add(job) + session.commit() + + # Retrieve and verify + retrieved_job = session.query(Job).filter(Job.job_id == job.job_id).one() + assert retrieved_job.completed_cells == 5 + + # Test 
null values are handled properly + job_null = Job( + runtime_environment_name="test_env_null", + input_filename="test_null.ipynb", + completed_cells=None, + ) + session.add(job_null) + session.commit() + + retrieved_job_null = session.query(Job).filter(Job.job_id == job_null.job_id).one() + assert retrieved_job_null.completed_cells is None + + session.close() + + +def test_completed_cells_column_nullable(jp_scheduler_db_url): + """Test that completed_cells column is nullable for backward compatibility""" + from sqlalchemy import create_engine, inspect + from sqlalchemy.orm import sessionmaker + + from jupyter_scheduler.orm import Base, Job, create_tables + + create_tables(db_url=jp_scheduler_db_url, Base=Base) + + engine = create_engine(jp_scheduler_db_url) + inspector = inspect(engine) + + # Find the completed_cells column + completed_cells_column = next( + col for col in inspector.get_columns("jobs") if col["name"] == "completed_cells" + ) + + # Verify it's nullable + assert completed_cells_column["nullable"] is True + + # Test creating a job without completed_cells + Session = sessionmaker(bind=engine) + session = Session() + + job = Job( + runtime_environment_name="test_env", + input_filename="test.ipynb", + # Note: not setting completed_cells + ) + session.add(job) + session.commit() + + # Verify it defaults to None + retrieved_job = session.query(Job).filter(Job.job_id == job.job_id).one() + assert retrieved_job.completed_cells is None + + session.close()