Skip to content

Add cell execution tracking during notebook execution #587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions jupyter_scheduler/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,14 @@ def execute(self):
nb = add_parameters(nb, job.parameters)

staging_dir = os.path.dirname(self.staging_paths["input"])

ep = ExecutePreprocessor(
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
)

if self.supported_features().get(JobFeature.track_cell_execution, False):
ep.on_cell_executed = self.__update_completed_cells_hook(ep)

try:
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
except CellExecutionError as e:
Expand All @@ -144,6 +148,18 @@ def execute(self):
self.add_side_effects_files(staging_dir)
self.create_output_files(job, nb)

def __update_completed_cells_hook(self, ep: ExecutePreprocessor):
"""Returns a hook that runs on every cell execution, regardless of success or failure. Updates the completed_cells for the job."""

def update_completed_cells(cell, cell_index, execute_reply):
with self.db_session() as session:
session.query(Job).filter(Job.job_id == self.job_id).update(
{"completed_cells": ep.code_cells_executed}
)
session.commit()

return update_completed_cells

def add_side_effects_files(self, staging_dir: str):
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
input_notebook = os.path.relpath(self.staging_paths["input"])
Expand Down Expand Up @@ -173,6 +189,7 @@ def create_output_files(self, job: DescribeJob, notebook_node):
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
f.write(output)

@classmethod
def supported_features(cls) -> Dict[JobFeature, bool]:
return {
JobFeature.job_name: True,
Expand All @@ -188,8 +205,10 @@ def supported_features(cls) -> Dict[JobFeature, bool]:
JobFeature.output_filename_template: False,
JobFeature.stop_job: True,
JobFeature.delete_job: True,
JobFeature.track_cell_execution: True,
}

@classmethod
def validate(cls, input_path: str) -> bool:
with open(input_path, encoding="utf-8") as f:
nb = nbformat.read(f, as_version=4)
Expand Down
3 changes: 3 additions & 0 deletions jupyter_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class DescribeJob(BaseModel):
downloaded: bool = False
package_input_folder: Optional[bool] = None
packaged_files: Optional[List[str]] = []
completed_cells: Optional[int] = None

class Config:
orm_mode = True
Expand Down Expand Up @@ -193,6 +194,7 @@ class UpdateJob(BaseModel):
status: Optional[Status] = None
name: Optional[str] = None
compute_type: Optional[str] = None
completed_cells: Optional[int] = None


class DeleteJob(BaseModel):
Expand Down Expand Up @@ -295,3 +297,4 @@ class JobFeature(str, Enum):
output_filename_template = "output_filename_template"
stop_job = "stop_job"
delete_job = "delete_job"
track_cell_execution = "track_cell_execution"
1 change: 1 addition & 0 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Job(CommonColumns, Base):
url = Column(String(256), default=generate_jobs_url)
pid = Column(Integer)
idempotency_token = Column(String(256))
completed_cells = Column(Integer, nullable=True)
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
# Any default values specified for new columns will be ignored during the migration process.

Expand Down
2 changes: 1 addition & 1 deletion jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def create_job(self, model: CreateJob) -> str:
raise InputUriError(model.input_uri)

input_path = os.path.join(self.root_dir, model.input_uri)
if not self.execution_manager_class.validate(self.execution_manager_class, input_path):
if not self.execution_manager_class.validate(input_path):
raise SchedulerError(
"""There is no kernel associated with the notebook. Please open
the notebook, select a kernel, and re-submit the job to execute.
Expand Down
244 changes: 244 additions & 0 deletions jupyter_scheduler/tests/test_execution_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import shutil
from pathlib import Path
from typing import Tuple
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -58,3 +59,246 @@ def test_add_side_effects_files(

job = jp_scheduler_db.query(Job).filter(Job.job_id == job_id).one()
assert side_effect_file_name in job.packaged_files


def test_default_execution_manager_cell_tracking_hook():
"""Test that DefaultExecutionManager sets up on_cell_executed hook when track_cell_execution is supported"""
job_id = "test-job-id"

with patch.object(DefaultExecutionManager, "model") as mock_model:
with patch("jupyter_scheduler.executors.open", mock=MagicMock()):
with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read:
with patch.object(DefaultExecutionManager, "add_side_effects_files"):
with patch.object(DefaultExecutionManager, "create_output_files"):
# Mock notebook
mock_nb = MagicMock()
mock_nb.metadata.kernelspec = {"name": "python3"}
mock_nb_read.return_value = mock_nb

# Mock model
mock_model.parameters = None
mock_model.output_formats = []

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Patch ExecutePreprocessor
with patch(
"jupyter_scheduler.executors.ExecutePreprocessor"
) as mock_ep_class:
mock_ep = MagicMock()
mock_ep_class.return_value = mock_ep

# Execute
manager.execute()

# Verify ExecutePreprocessor was created
mock_ep_class.assert_called_once()

# Verify on_cell_executed hook was set
assert hasattr(mock_ep, "on_cell_executed")
assert mock_ep.on_cell_executed is not None


def test_update_completed_cells_hook():
"""Test the __update_completed_cells_hook method"""
job_id = "test-job-id"

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock db_session
mock_db_session = MagicMock()
mock_session_context = MagicMock()
mock_db_session.return_value.__enter__.return_value = mock_session_context
manager._db_session = mock_db_session

# Mock ExecutePreprocessor
mock_ep = MagicMock()
mock_ep.code_cells_executed = 5

# Get the hook function
hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep)

# Call the hook
mock_cell = MagicMock()
mock_execute_reply = MagicMock()
hook_func(mock_cell, 2, mock_execute_reply)

# Verify database update was called
mock_session_context.query.assert_called_once_with(Job)
mock_session_context.query.return_value.filter.return_value.update.assert_called_once_with(
{"completed_cells": 5}
)
mock_session_context.commit.assert_called_once()


def test_update_completed_cells_hook_database_error():
"""Test that database errors in the hook are handled"""
job_id = "test-job-id"

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock db_session with error
mock_db_session = MagicMock()
mock_session_context = MagicMock()
mock_session_context.query.return_value.filter.return_value.update.side_effect = Exception(
"DB Error"
)
mock_db_session.return_value.__enter__.return_value = mock_session_context
manager._db_session = mock_db_session

# Mock ExecutePreprocessor
mock_ep = MagicMock()
mock_ep.code_cells_executed = 3

# Get the hook function
hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep)

# Call the hook - should raise exception
mock_cell = MagicMock()
mock_execute_reply = MagicMock()

with pytest.raises(Exception, match="DB Error"):
hook_func(mock_cell, 1, mock_execute_reply)


def test_supported_features_includes_track_cell_execution():
"""Test that DefaultExecutionManager supports track_cell_execution feature"""
features = DefaultExecutionManager.supported_features()

from jupyter_scheduler.models import JobFeature

assert JobFeature.track_cell_execution in features
assert features[JobFeature.track_cell_execution] is True


def test_hook_uses_correct_job_id():
"""Test that the hook uses the correct job_id in database queries"""
job_id = "specific-job-id-456"

# Create manager
manager = DefaultExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock db_session
mock_db_session = MagicMock()
mock_session_context = MagicMock()
mock_db_session.return_value.__enter__.return_value = mock_session_context
manager._db_session = mock_db_session

# Mock ExecutePreprocessor
mock_ep = MagicMock()
mock_ep.code_cells_executed = 7

# Get the hook function
hook_func = manager._DefaultExecutionManager__update_completed_cells_hook(mock_ep)

# Call the hook
mock_cell = MagicMock()
mock_execute_reply = MagicMock()
hook_func(mock_cell, 3, mock_execute_reply)

# Verify the correct job_id is used in the filter
# The filter call should contain a condition that matches Job.job_id == job_id
filter_call = mock_session_context.query.return_value.filter.call_args[0][0]
# This is a SQLAlchemy comparison object, so we need to check its properties
assert hasattr(filter_call, "right")
assert filter_call.right.value == job_id


def test_cell_tracking_disabled_when_feature_false():
"""Test that cell tracking hook is not set when track_cell_execution feature is False"""
job_id = "test-job-id"

# Create a custom execution manager class with track_cell_execution = False
class DisabledTrackingExecutionManager(DefaultExecutionManager):
@classmethod
def supported_features(cls):
features = super().supported_features()
from jupyter_scheduler.models import JobFeature

features[JobFeature.track_cell_execution] = False
return features

# Create manager with disabled tracking
manager = DisabledTrackingExecutionManager(
job_id=job_id,
root_dir="/test",
db_url="sqlite:///:memory:",
staging_paths={"input": "/test/input.ipynb"},
)

# Mock ExecutePreprocessor and track calls to __update_completed_cells_hook
with patch.object(
manager, "_DefaultExecutionManager__update_completed_cells_hook"
) as mock_hook_method:
with patch.object(DisabledTrackingExecutionManager, "model") as mock_model:
with patch("jupyter_scheduler.executors.open", mock=MagicMock()):
with patch("jupyter_scheduler.executors.nbformat.read") as mock_nb_read:
with patch.object(DisabledTrackingExecutionManager, "add_side_effects_files"):
with patch.object(DisabledTrackingExecutionManager, "create_output_files"):
with patch(
"jupyter_scheduler.executors.ExecutePreprocessor"
) as mock_ep_class:
# Mock notebook
mock_nb = MagicMock()
mock_nb.metadata.kernelspec = {"name": "python3"}
mock_nb_read.return_value = mock_nb

# Mock model
mock_model.parameters = None
mock_model.output_formats = []

mock_ep = MagicMock()
mock_ep_class.return_value = mock_ep

# Execute
manager.execute()

# Verify ExecutePreprocessor was created
mock_ep_class.assert_called_once()

# Verify the hook method was NOT called when feature is disabled
mock_hook_method.assert_not_called()


def test_disabled_tracking_feature_support():
"""Test that custom execution manager can disable track_cell_execution feature"""

# Create a custom execution manager class with track_cell_execution = False
class DisabledTrackingExecutionManager(DefaultExecutionManager):
@classmethod
def supported_features(cls):
features = super().supported_features()
from jupyter_scheduler.models import JobFeature

features[JobFeature.track_cell_execution] = False
return features

features = DisabledTrackingExecutionManager.supported_features()

from jupyter_scheduler.models import JobFeature

assert JobFeature.track_cell_execution in features
assert features[JobFeature.track_cell_execution] is False
Loading
Loading