Skip to content

Commit 5cb4107

Browse files
committed
log jobs and job defs in mlflow
1 parent dad6c5f commit 5cb4107

File tree

3 files changed

+32
-64
lines changed

3 files changed

+32
-64
lines changed

jupyter_scheduler/executors.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,19 @@ def execute(self):
139139
kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
140140
)
141141

142-
try:
143-
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
144-
except CellExecutionError as e:
145-
raise e
146-
finally:
147-
if getattr(job, "mlflow_logging", False):
148-
self.log_to_mlflow(job, nb)
149-
self.add_side_effects_files(staging_dir)
150-
self.create_output_files(job, nb)
142+
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
143+
with mlflow.start_run(run_id=job.mlflow_run_id):
144+
try:
145+
ep.preprocess(nb, {"metadata": {"path": staging_dir}})
146+
if job.parameters:
147+
mlflow.log_params(job.parameters)
148+
except CellExecutionError as e:
149+
raise e
150+
finally:
151+
if getattr(job, "mlflow_logging", False):
152+
self.log_to_mlflow(job, nb)
153+
self.add_side_effects_files(staging_dir)
154+
self.create_output_files(job, nb)
151155

152156
def add_side_effects_files(self, staging_dir: str):
153157
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
@@ -175,8 +179,10 @@ def create_output_files(self, job: DescribeJob, notebook_node):
175179
for output_format in job.output_formats:
176180
cls = nbconvert.get_exporter(output_format)
177181
output, _ = cls().from_notebook_node(notebook_node)
182+
output_path = self.staging_paths[output_format]
178183
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
179184
f.write(output)
185+
mlflow.log_artifact(output_path)
180186

181187
def log_to_mlflow(self, job, nb):
182188
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)

jupyter_scheduler/scheduler.py

Lines changed: 10 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22
import os
33
import random
44
import shutil
5-
import signal
6-
import subprocess
75
from typing import Dict, List, Optional, Type, Union
8-
import sys
9-
from uuid import uuid4
6+
import subprocess
107

118
import fsspec
12-
import mlflow
139
import psutil
1410
from jupyter_core.paths import jupyter_data_dir
1511
from jupyter_server.transutils import _i18n
@@ -50,10 +46,6 @@
5046
create_output_filename,
5147
)
5248

53-
MLFLOW_SERVER_HOST = "127.0.0.1"
54-
MLFLOW_SERVER_PORT = "5000"
55-
MLFLOW_SERVER_URI = f"http://{MLFLOW_SERVER_HOST}:{MLFLOW_SERVER_PORT}"
56-
5749

5850
class BaseScheduler(LoggingConfigurable):
5951
"""Base class for schedulers. A default implementation
@@ -409,31 +401,20 @@ class Scheduler(BaseScheduler):
409401
task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner")
410402

411403
def start_mlflow_server(self):
412-
mlflow_process = subprocess.Popen(
404+
subprocess.Popen(
413405
[
414406
"mlflow",
415407
"server",
408+
"--backend-store-uri",
409+
"./mlruns",
410+
"--default-artifact-root",
411+
"./mlartifacts",
416412
"--host",
417-
MLFLOW_SERVER_HOST,
413+
"0.0.0.0",
418414
"--port",
419-
MLFLOW_SERVER_PORT,
420-
],
421-
preexec_fn=os.setsid,
415+
"5000",
416+
]
422417
)
423-
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
424-
return mlflow_process
425-
426-
def stop_mlflow_server(self):
427-
if self.mlflow_process is not None:
428-
os.killpg(os.getpgid(self.mlflow_process.pid), signal.SIGTERM)
429-
self.mlflow_process.wait()
430-
self.mlflow_process = None
431-
print("MLFlow server stopped")
432-
433-
def mlflow_signal_handler(self, signum, frame):
434-
print("Shutting down MLFlow server")
435-
self.stop_mlflow_server()
436-
sys.exit(0)
437418

438419
def __init__(
439420
self,
@@ -450,9 +431,7 @@ def __init__(
450431
if self.task_runner_class:
451432
self.task_runner = self.task_runner_class(scheduler=self, config=config)
452433

453-
self.mlflow_process = self.start_mlflow_server()
454-
signal.signal(signal.SIGINT, self.mlflow_signal_handler)
455-
signal.signal(signal.SIGTERM, self.mlflow_signal_handler)
434+
self.start_mlflow_server()
456435

457436
@property
458437
def db_session(self):
@@ -502,21 +481,6 @@ def create_job(self, model: CreateJob) -> str:
502481
if not model.output_formats:
503482
model.output_formats = []
504483

505-
mlflow_client = mlflow.MlflowClient()
506-
507-
if model.job_definition_id and model.mlflow_experiment_id:
508-
experiment_id = model.mlflow_experiment_id
509-
else:
510-
experiment_id = mlflow_client.create_experiment(f"{model.input_filename}-{uuid4()}")
511-
model.mlflow_experiment_id = experiment_id
512-
input_file_path = os.path.join(self.root_dir, model.input_uri)
513-
mlflow.log_artifact(input_file_path, "input")
514-
515-
mlflow_run = mlflow_client.create_run(
516-
experiment_id=experiment_id, run_name=f"{model.input_filename}-{uuid4()}"
517-
)
518-
model.mlflow_run_id = mlflow_run.info.run_id
519-
520484
job = Job(**model.dict(exclude_none=True, exclude={"input_uri"}))
521485

522486
session.add(job)
@@ -664,12 +628,6 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
664628
if not self.file_exists(model.input_uri):
665629
raise InputUriError(model.input_uri)
666630

667-
mlflow_client = mlflow.MlflowClient()
668-
experiment_id = mlflow_client.create_experiment(f"{model.input_filename}-{uuid4()}")
669-
model.mlflow_experiment_id = experiment_id
670-
input_file_path = os.path.join(self.root_dir, model.input_uri)
671-
mlflow.log_artifact(input_file_path, "input")
672-
673631
job_definition = JobDefinition(**model.dict(exclude_none=True, exclude={"input_uri"}))
674632
session.add(job_definition)
675633
session.commit()

src/mainviews/create-job.tsx

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ export function CreateJob(props: ICreateJobProps): JSX.Element {
176176

177177
const handleInputChange = (event: ChangeEvent<HTMLInputElement>) => {
178178
const target = event.target;
179-
180179
const parameterNameIdx = parameterNameMatch(target.name);
181180
const parameterValueIdx = parameterValueMatch(target.name);
182181
const newParams = props.model.parameters || [];
@@ -323,7 +322,10 @@ export function CreateJob(props: ICreateJobProps): JSX.Element {
323322
idempotency_token: props.model.idempotencyToken,
324323
tags: props.model.tags,
325324
runtime_environment_parameters: props.model.runtimeEnvironmentParameters,
326-
package_input_folder: props.model.packageInputFolder
325+
package_input_folder: props.model.packageInputFolder,
326+
mlflow_logging: props.model.mlflowLogging,
327+
mlflow_experiment_id: props.model.mlflowExperimentId,
328+
mlflow_run_id: props.model.mlflowRunId
327329
};
328330

329331
if (props.model.parameters !== undefined) {
@@ -372,7 +374,9 @@ export function CreateJob(props: ICreateJobProps): JSX.Element {
372374
runtime_environment_parameters: props.model.runtimeEnvironmentParameters,
373375
schedule: props.model.schedule,
374376
timezone: props.model.timezone,
375-
package_input_folder: props.model.packageInputFolder
377+
package_input_folder: props.model.packageInputFolder,
378+
mlflow_logging: props.model.mlflowLogging,
379+
mlflow_experiment_id: props.model.mlflowExperimentId
376380
};
377381

378382
if (props.model.parameters !== undefined) {

0 commit comments

Comments
 (0)