Skip to content

Commit f8a97c3

Browse files
committed
refactor and update in line with comments
1 parent c0fd724 commit f8a97c3

File tree

2 files changed

+129
-29
lines changed

2 files changed

+129
-29
lines changed

jupyter_scheduler/executors.py

Lines changed: 108 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import io
23
import os
34
import shutil
@@ -138,28 +139,15 @@ def execute(self):
138139
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
139140
)
140141

141-
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
142-
with mlflow.start_run(run_id=job.mlflow_run_id):
143-
try:
144-
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
145-
if job.parameters:
146-
mlflow.log_params(job.parameters)
147-
148-
for idx, cell in enumerate(nb.cells):
149-
if "tags" in cell.metadata and "mlflow_log" in cell.metadata["tags"]:
150-
mlflow.log_text(cell.source, f"source_cell_{idx}.txt")
151-
if cell.cell_type == "code" and cell.outputs:
152-
for output in cell.outputs:
153-
if "text/plain" in output.data:
154-
mlflow.log_text(
155-
output.data["text/plain"], f"output_cell_{idx}.txt"
156-
)
157-
158-
except CellExecutionError as e:
159-
raise e
160-
finally:
161-
self.add_side_effects_files(staging_dir)
162-
self.create_output_files(job, nb)
142+
try:
143+
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
144+
except CellExecutionError as e:
145+
raise e
146+
finally:
147+
self.add_side_effects_files(staging_dir)
148+
self.create_output_files(job, nb)
149+
if getattr(job, "mlflow_logging", False):
150+
self.log_to_mlflow(job, nb)
163151

164152
def add_side_effects_files(self, staging_dir: str):
165153
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
@@ -187,10 +175,105 @@ def create_output_files(self, job: DescribeJob, notebook_node):
187175
for output_format in job.output_formats:
188176
cls = nbconvert.get_exporter(output_format)
189177
output, _ = cls().from_notebook_node(notebook_node)
190-
output_path = self.staging_paths[output_format]
191-
with fsspec.open(output_path, "w", encoding="utf-8") as f:
178+
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
192179
f.write(output)
193-
mlflow.log_artifact(output_path)
180+
181+
def log_to_mlflow(self, job, nb):
182+
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
183+
with mlflow.start_run(run_id=job.mlflow_run_id):
184+
if job.parameters:
185+
mlflow.log_params(job.parameters)
186+
187+
for cell_idx, cell in enumerate(nb.cells):
188+
if "tags" in cell.metadata:
189+
if "mlflow_log" in cell.metadata["tags"]:
190+
self.mlflow_log(cell, cell_idx)
191+
elif "mlflow_log_input" in cell.metadata["tags"]:
192+
self.mlflow_log_input(cell, cell_idx)
193+
elif "mlflow_log_output" in cell.metadata["tags"]:
194+
self.mlflow_log_output(cell, cell_idx)
195+
196+
for output_format in job.output_formats:
197+
output_path = self.staging_paths[output_format]
198+
directory, file_name_with_extension = os.path.split(output_path)
199+
file_name, file_extension = os.path.splitext(file_name_with_extension)
200+
file_name_parts = file_name.split("-")
201+
file_name_without_timestamp = "-".join(file_name_parts[:-7])
202+
file_name_final = f"{file_name_without_timestamp}{file_extension}"
203+
new_output_path = os.path.join(directory, file_name_final)
204+
shutil.copy(output_path, new_output_path)
205+
timestamp = "-".join(file_name_parts[-7:]).split(".")[0]
206+
mlflow.log_param("job_created", timestamp)
207+
mlflow.log_artifact(new_output_path, "")
208+
os.remove(new_output_path)
209+
210+
def mlflow_log(self, cell, cell_idx):
211+
self.mlflow_log_input(cell, cell_idx)
212+
self.mlflow_log_output(cell, cell_idx)
213+
214+
def mlflow_log_input(self, cell, cell_idx):
215+
mlflow.log_text(cell.source, f"cell_{cell_idx}_input.txt")
216+
217+
def mlflow_log_output(self, cell, cell_idx):
218+
if cell.cell_type == "code" and hasattr(cell, "outputs"):
219+
self._log_code_output(cell_idx, cell.outputs)
220+
elif cell.cell_type == "markdown":
221+
self._log_markdown_output(cell, cell_idx)
222+
223+
def _log_code_output(self, cell_idx, outputs):
224+
for output_idx, output in enumerate(outputs):
225+
if output.output_type == "stream":
226+
self._log_stream_output(cell_idx, output_idx, output)
227+
elif hasattr(output, "data"):
228+
for output_data_idx, output_data in enumerate(output.data):
229+
if output_data == "text/plain":
230+
mlflow.log_text(
231+
output.data[output_data],
232+
f"cell_{cell_idx}_output_{output_data_idx}.txt",
233+
)
234+
elif output_data == "text/html":
235+
self._log_html_output(output, cell_idx, output_data_idx)
236+
elif output_data == "application/pdf":
237+
self._log_pdf_output(output, cell_idx, output_data_idx)
238+
elif output_data.startswith("image"):
239+
self._log_image_output(output, cell_idx, output_data_idx, output_data)
240+
241+
def _log_stream_output(self, cell_idx, output_idx, output):
242+
mlflow.log_text("".join(output.text), f"cell_{cell_idx}_output_{output_idx}.txt")
243+
244+
def _log_html_output(self, output, cell_idx, output_idx):
245+
if "text/html" in output.data:
246+
html_content = output.data["text/html"]
247+
if isinstance(html_content, list):
248+
html_content = "".join(html_content)
249+
mlflow.log_text(html_content, f"cell_{cell_idx}_output_{output_idx}.html")
250+
251+
def _log_pdf_output(self, output, cell_idx, output_idx):
252+
pdf_data = base64.b64decode(output.data["application/pdf"].split(",")[1])
253+
with open(f"cell_{cell_idx}_output_{output_idx}.pdf", "wb") as pdf_file:
254+
pdf_file.write(pdf_data)
255+
mlflow.log_artifact(f"cell_{cell_idx}_output_{output_idx}.pdf")
256+
257+
def _log_image_output(self, output, cell_idx, output_idx, mime_type):
258+
image_data_str = output.data[mime_type]
259+
if "," in image_data_str:
260+
image_data_base64 = image_data_str.split(",")[1]
261+
else:
262+
image_data_base64 = image_data_str
263+
264+
try:
265+
image_data = base64.b64decode(image_data_base64)
266+
image_extension = mime_type.split("/")[1]
267+
filename = f"cell_{cell_idx}_output_{output_idx}.{image_extension}"
268+
with open(filename, "wb") as image_file:
269+
image_file.write(image_data)
270+
mlflow.log_artifact(filename)
271+
os.remove(filename)
272+
except Exception as e:
273+
print(f"Error logging image output in cell {cell_idx}, output {output_idx}: {e}")
274+
275+
def _log_markdown_output(self, cell, cell_idx):
276+
mlflow.log_text(cell.source, f"cell_{cell_idx}_output_0.md")
194277

195278
def supported_features(cls) -> Dict[JobFeature, bool]:
196279
return {

jupyter_scheduler/scheduler.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import random
44
import shutil
55
from typing import Dict, List, Optional, Type, Union
6+
import signal
67
import subprocess
7-
from typing import Dict, Optional, Type, Union
8+
import sys
89
from uuid import uuid4
910

1011
import fsspec
@@ -408,17 +409,31 @@ class Scheduler(BaseScheduler):
408409
task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner")
409410

410411
def start_mlflow_server(self):
411-
subprocess.Popen(
412+
mlflow_process = subprocess.Popen(
412413
[
413414
"mlflow",
414415
"server",
415416
"--host",
416417
MLFLOW_SERVER_HOST,
417418
"--port",
418419
MLFLOW_SERVER_PORT,
419-
]
420+
],
421+
preexec_fn=os.setsid,
420422
)
421423
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
424+
return mlflow_process
425+
426+
def stop_mlflow_server(self):
427+
if self.mlflow_process is not None:
428+
os.killpg(os.getpgid(self.mlflow_process.pid), signal.SIGTERM)
429+
self.mlflow_process.wait()
430+
self.mlflow_process = None
431+
print("MLFlow server stopped")
432+
433+
def mlflow_signal_handler(self, signum, frame):
434+
print("Shutting down MLFlow server")
435+
self.stop_mlflow_server()
436+
sys.exit(0)
422437

423438
def __init__(
424439
self,
@@ -435,7 +450,9 @@ def __init__(
435450
if self.task_runner_class:
436451
self.task_runner = self.task_runner_class(scheduler=self, config=config)
437452

438-
self.start_mlflow_server()
453+
self.mlflow_process = self.start_mlflow_server()
454+
signal.signal(signal.SIGINT, self.mlflow_signal_handler)
455+
signal.signal(signal.SIGTERM, self.mlflow_signal_handler)
439456

440457
@property
441458
def db_session(self):

0 commit comments

Comments
 (0)