Skip to content

Commit b4f4986

Browse files
committed
refactor and update in line with comments
1 parent 5f3a16f commit b4f4986

File tree

2 files changed

+119
-28
lines changed

2 files changed

+119
-28
lines changed

jupyter_scheduler/executors.py

Lines changed: 98 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import io
23
import os
34
import shutil
@@ -143,6 +144,8 @@ def execute(self):
143144
except CellExecutionError as e:
144145
raise e
145146
finally:
147+
if getattr(job, "mlflow_logging", False):
148+
self.log_to_mlflow(job, nb)
146149
self.add_side_effects_files(staging_dir)
147150
self.create_output_files(job, nb)
148151

@@ -174,33 +177,103 @@ def create_output_files(self, job: DescribeJob, notebook_node):
174177
output, _ = cls().from_notebook_node(notebook_node)
175178
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
176179
f.write(output)
180+
181+
def log_to_mlflow(self, job, nb):
    """Log the job's parameters, tagged cells and output files to MLflow.

    Everything is recorded under the existing run ``job.mlflow_run_id`` on
    the tracking server at ``MLFLOW_SERVER_URI``.

    Args:
        job: Job description; ``mlflow_run_id``, ``parameters`` and
            ``output_formats`` are read from it.
        nb: Notebook node whose cells are scanned for the ``mlflow_log``,
            ``mlflow_log_input`` and ``mlflow_log_output`` tags.
    """
    # NOTE(review): staged file names are assumed to end in a 7-field,
    # dash-separated timestamp -- confirm against the code that builds
    # ``self.staging_paths``.
    timestamp_fields = 7

    mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
    with mlflow.start_run(run_id=job.mlflow_run_id):
        if job.parameters:
            mlflow.log_params(job.parameters)

        for cell_idx, cell in enumerate(nb.cells):
            tags = cell.metadata.get("tags") or ()
            if "mlflow_log" in tags:
                # ``mlflow_log`` already covers both input and output.
                self.mlflow_log(cell, cell_idx)
                continue
            # Checked independently so that a cell tagged with both
            # ``mlflow_log_input`` and ``mlflow_log_output`` gets both logged.
            if "mlflow_log_input" in tags:
                self.mlflow_log_input(cell, cell_idx)
            if "mlflow_log_output" in tags:
                self.mlflow_log_output(cell, cell_idx)

        # NOTE(review): the staged output files must already exist when this
        # runs -- confirm the call order in ``execute``.
        for output_format in job.output_formats:
            output_path = self.staging_paths[output_format]
            directory, staged_name = os.path.split(output_path)
            stem, extension = os.path.splitext(staged_name)
            stem_parts = stem.split("-")

            # The artifact is uploaded under the name without its timestamp;
            # the timestamp itself is recorded as a run parameter.
            artifact_name = "-".join(stem_parts[:-timestamp_fields]) + extension
            timestamp = "-".join(stem_parts[-timestamp_fields:]).split(".")[0]
            artifact_path = os.path.join(directory, artifact_name)

            shutil.copy(output_path, artifact_path)
            try:
                mlflow.log_param("timestamp", timestamp)
                mlflow.log_artifact(artifact_path, "")
            finally:
                # Never leave the renamed copy behind, even if logging fails.
                os.remove(artifact_path)
209+
210+
def mlflow_log(self, cell, cell_idx):
    """Log both the source and the output of ``cell`` to MLflow.

    Args:
        cell: Notebook cell to record.
        cell_idx: Index of the cell, used in the artifact names.
    """
    for log_part in (self.mlflow_log_input, self.mlflow_log_output):
        log_part(cell, cell_idx)
213+
214+
def mlflow_log_input(self, cell, cell_idx):
    """Store the cell's source as the MLflow text artifact ``cell_<idx>_input.txt``.

    Args:
        cell: Notebook cell whose ``source`` is logged.
        cell_idx: Index of the cell, used in the artifact name.
    """
    artifact_name = f"cell_{cell_idx}_input.txt"
    mlflow.log_text(cell.source, artifact_name)
216+
217+
def mlflow_log_output(self, cell, cell_idx):
    """Log the output of a code or markdown cell to MLflow.

    Markdown cells are handed to ``_log_markdown_output``; code cells that
    have an ``outputs`` attribute are handed to ``_log_code_output``. Any
    other cell is ignored.

    Args:
        cell: Notebook cell to inspect.
        cell_idx: Index of the cell, used in the artifact names.
    """
    kind = cell.cell_type
    if kind == "markdown":
        self._log_markdown_output(cell, cell_idx)
        return
    if kind == "code" and hasattr(cell, "outputs"):
        self._log_code_output(cell_idx, cell.outputs)
222+
223+
def _log_code_output(self, cell_idx, outputs):
    """Log every output of a code cell to MLflow.

    Stream outputs go to ``_log_stream_output``. For outputs that have a
    ``data`` bundle, each MIME entry is dispatched by type: ``text/plain``
    is logged as text, ``text/html``, ``application/pdf`` and ``image*``
    go to their dedicated helpers. Other MIME types and outputs without
    ``data`` are skipped.

    Artifact names are keyed on the position of the output within the
    cell (not on the position of the MIME key inside one output), so
    several outputs of the same cell no longer overwrite each other.

    Args:
        cell_idx: Index of the cell, used in the artifact names.
        outputs: Iterable of the cell's output objects.
    """
    for output_idx, output in enumerate(outputs):
        if output.output_type == "stream":
            self._log_stream_output(cell_idx, output_idx, output)
            continue
        if not hasattr(output, "data"):
            continue

        for mime_type in output.data:
            if mime_type == "text/plain":
                mlflow.log_text(
                    output.data[mime_type],
                    f"cell_{cell_idx}_output_{output_idx}.txt",
                )
            elif mime_type == "text/html":
                self._log_html_output(output, cell_idx, output_idx)
            elif mime_type == "application/pdf":
                self._log_pdf_output(output, cell_idx, output_idx)
            elif mime_type.startswith("image"):
                self._log_image_output(output, cell_idx, output_idx, mime_type)
240+
241+
def _log_stream_output(self, cell_idx, output_idx, output):
    """Log a stream output's text as ``cell_<cell>_output_<output>.txt``.

    Args:
        cell_idx: Index of the cell, used in the artifact name.
        output_idx: Index of the output, used in the artifact name.
        output: Stream output whose ``text`` is joined and logged.
    """
    stream_text = "".join(output.text)
    artifact_name = f"cell_{cell_idx}_output_{output_idx}.txt"
    mlflow.log_text(stream_text, artifact_name)
243+
244+
def _log_html_output(self, output, cell_idx, output_idx):
    """Log an output's ``text/html`` entry as ``cell_<cell>_output_<output>.html``.

    Does nothing when the output has no ``text/html`` entry. A list value
    is concatenated into a single string before logging.

    Args:
        output: Output object whose ``data`` bundle is read.
        cell_idx: Index of the cell, used in the artifact name.
        output_idx: Index of the output, used in the artifact name.
    """
    if "text/html" not in output.data:
        return

    markup = output.data["text/html"]
    markup = "".join(markup) if isinstance(markup, list) else markup
    mlflow.log_text(markup, f"cell_{cell_idx}_output_{output_idx}.html")
250+
251+
def _log_pdf_output(self, output, cell_idx, output_idx):
    """Decode an output's ``application/pdf`` entry and log it as an artifact.

    The payload is base64-decoded and uploaded as
    ``cell_<cell>_output_<output>.pdf``. A leading ``<prefix>,`` part is
    stripped when present; payloads without a comma are decoded as-is
    (this matches ``_log_image_output`` and avoids an ``IndexError``).
    The file is staged in a temporary directory that is removed
    afterwards, so nothing is left in the working directory.

    Args:
        output: Output object whose ``data`` bundle is read.
        cell_idx: Index of the cell, used in the artifact name.
        output_idx: Index of the output, used in the artifact name.
    """
    import tempfile

    encoded = output.data["application/pdf"]
    if "," in encoded:
        encoded = encoded.split(",")[1]
    pdf_data = base64.b64decode(encoded)

    file_name = f"cell_{cell_idx}_output_{output_idx}.pdf"
    # log_artifact names the artifact after the local file's base name, so
    # staging in a temp directory keeps the artifact name unchanged.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdf_path = os.path.join(tmp_dir, file_name)
        with open(pdf_path, "wb") as pdf_file:
            pdf_file.write(pdf_data)
        mlflow.log_artifact(pdf_path)
256+
257+
def _log_image_output(self, output, cell_idx, output_idx, mime_type):
    """Write an image output to a temporary file and log it as an artifact.

    The artifact is named ``cell_<cell>_output_<output>.<ext>`` where
    ``<ext>`` is the MIME subtype. ``image/svg+xml`` is written as UTF-8
    text with the ``.svg`` extension instead of being base64-decoded;
    every other image type is base64-decoded, after stripping a leading
    ``<prefix>,`` part when present.

    Logging is best-effort: any failure while decoding, writing or
    uploading is reported with ``print`` and otherwise ignored. The file
    is staged in a temporary directory, so it is cleaned up even when the
    upload fails.

    Args:
        output: Output object whose ``data`` bundle is read.
        cell_idx: Index of the cell, used in the artifact name.
        output_idx: Index of the output, used in the artifact name.
        mime_type: Key of the image entry inside ``output.data``.
    """
    import tempfile

    image_data_str = output.data[mime_type]
    is_svg = mime_type == "image/svg+xml"
    if is_svg:
        # SVG bundles hold XML text; commas inside it are not a prefix marker.
        payload = image_data_str
    elif "," in image_data_str:
        payload = image_data_str.split(",")[1]
    else:
        payload = image_data_str

    try:
        if is_svg:
            image_data = payload.encode("utf-8")
            image_extension = "svg"
        else:
            image_data = base64.b64decode(payload)
            image_extension = mime_type.split("/")[1]

        filename = f"cell_{cell_idx}_output_{output_idx}.{image_extension}"
        # log_artifact names the artifact after the local file's base name,
        # so staging in a temp directory keeps the artifact name unchanged.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, filename)
            with open(image_path, "wb") as image_file:
                image_file.write(image_data)
            mlflow.log_artifact(image_path)
    except Exception as e:
        print(f"Error logging image output in cell {cell_idx}, output {output_idx}: {e}")
274+
275+
def _log_markdown_output(self, cell, cell_idx):
    """Log a markdown cell's source as the text artifact ``cell_<idx>_output_0.md``.

    Args:
        cell: Markdown cell whose ``source`` is logged.
        cell_idx: Index of the cell, used in the artifact name.
    """
    artifact_name = f"cell_{cell_idx}_output_0.md"
    mlflow.log_text(cell.source, artifact_name)
204277

205278
def supported_features(cls) -> Dict[JobFeature, bool]:
206279
return {

jupyter_scheduler/scheduler.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import os
33
import random
44
import shutil
5+
import signal
56
import subprocess
67
from typing import Dict, List, Optional, Type, Union
8+
import sys
79
from uuid import uuid4
810

911
import fsspec
@@ -407,17 +409,31 @@ class Scheduler(BaseScheduler):
407409
task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner")
408410

409411
def start_mlflow_server(self):
    """Launch ``mlflow server`` as a child process and point MLflow at it.

    The server is started on ``MLFLOW_SERVER_HOST``/``MLFLOW_SERVER_PORT``
    in its own session, so that the whole process group can later be
    signalled at once. The tracking URI is then set to
    ``MLFLOW_SERVER_URI``.

    Returns:
        The ``subprocess.Popen`` handle of the started server.
    """
    command = [
        "mlflow",
        "server",
        "--host",
        MLFLOW_SERVER_HOST,
        "--port",
        MLFLOW_SERVER_PORT,
    ]
    # start_new_session=True makes the child call setsid() itself; it is the
    # documented replacement for preexec_fn=os.setsid, which the subprocess
    # documentation warns is not safe in the presence of threads.
    mlflow_process = subprocess.Popen(command, start_new_session=True)
    mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
    return mlflow_process
425+
426+
def stop_mlflow_server(self):
    """Terminate the MLflow server's process group, if one is recorded.

    Sends ``SIGTERM`` to the process group of ``self.mlflow_process``,
    waits for the process, clears ``self.mlflow_process`` and prints a
    confirmation. Does nothing when ``self.mlflow_process`` is ``None``.
    A server that has already exited and been reaped is treated as
    stopped instead of raising ``ProcessLookupError``.
    """
    process = self.mlflow_process
    if process is None:
        return

    try:
        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    except ProcessLookupError:
        # The server is already gone; there is nothing left to signal.
        pass
    process.wait()
    self.mlflow_process = None
    print("MLFlow server stopped")
432+
433+
def mlflow_signal_handler(self, signum, frame):
    """Signal handler: stop the MLflow server, then exit with status 0.

    Args:
        signum: Number of the received signal (unused).
        frame: Stack frame active when the signal arrived (unused).

    Raises:
        SystemExit: Always, with exit code 0, after the server is stopped.
    """
    print("Shutting down MLFlow server")
    self.stop_mlflow_server()
    raise SystemExit(0)
421437

422438
def __init__(
423439
self,
@@ -434,7 +450,9 @@ def __init__(
434450
if self.task_runner_class:
435451
self.task_runner = self.task_runner_class(scheduler=self, config=config)
436452

437-
self.start_mlflow_server()
453+
self.mlflow_process = self.start_mlflow_server()
454+
signal.signal(signal.SIGINT, self.mlflow_signal_handler)
455+
signal.signal(signal.SIGTERM, self.mlflow_signal_handler)
438456

439457
@property
440458
def db_session(self):

0 commit comments

Comments
 (0)