5 changes: 2 additions & 3 deletions src/modelplane/runways/annotator.py
@@ -27,7 +27,7 @@
     is_debug_mode,
     setup_annotator_credentials,
 )
-from modelplane.utils.input import build_input
+from modelplane.utils.input import build_and_log_input
 
 KNOWN_ENSEMBLES: Dict[str, AnnotatorSet] = {}
 # try to load the private ensemble
@@ -95,14 +95,13 @@ def annotate(
 
     with tempfile.TemporaryDirectory() as tmp:
         # load/transform the prompt responses from the specified run
-        input_data = build_input(
+        input_data = build_and_log_input(
            path=response_file,
            run_id=response_run_id,
            artifact_path=PROMPT_RESPONSE_ARTIFACT_NAME,
            dvc_repo=dvc_repo,
            dest_dir=tmp,
        )
-        input_data.log_input()
         input_path = input_data.local_path()  # type: ignore
         pipeline_kwargs["input_path"] = pathlib.Path(input_path)
         pipeline_kwargs["output_dir"] = pathlib.Path(tmp)
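The same refactor repeats in each runway below: the two-step build_input(...) / input_data.log_input() sequence collapses into a single build_and_log_input(...) call that logs as a side effect. A minimal sketch of the new call as annotate uses it, with a hypothetical run ID and artifact name, assuming an active MLflow run and a reachable tracking server:

import pathlib
import tempfile

import mlflow

from modelplane.utils.input import build_and_log_input

with mlflow.start_run():
    with tempfile.TemporaryDirectory() as tmp:
        # run_id + artifact_path selects the MLFlowArtifactInput branch;
        # passing path or dvc_repo instead would select a local or DVC input.
        input_data = build_and_log_input(
            run_id="0123456789abcdef",             # hypothetical upstream run
            artifact_path="prompt_responses.csv",  # hypothetical artifact name
            dest_dir=tmp,
        )
        input_path = pathlib.Path(input_data.local_path())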
9 changes: 6 additions & 3 deletions src/modelplane/runways/responder.py
@@ -17,7 +17,7 @@
     is_debug_mode,
     setup_sut_credentials,
 )
-from modelplane.utils.input import build_input
+from modelplane.utils.input import build_and_log_input
 
 
 def respond(
@@ -41,8 +41,11 @@ def respond(
         mlflow.log_params(params)
         # Use temporary file as mlflow will log this into the artifact store
         with tempfile.TemporaryDirectory() as tmp:
-            input_data = build_input(path=prompts, dvc_repo=dvc_repo, dest_dir=tmp)
-            input_data.log_input()
+            input_data = build_and_log_input(
+                path=prompts,
+                dvc_repo=dvc_repo,
+                dest_dir=tmp,
+            )
             pipeline_runner = build_runner(
                 num_workers=num_workers,
                 input_path=input_data.local_path(),
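What gets recorded also changes: BaseInput.log_artifact() (see the input.py diff below) uploads the file with mlflow.log_artifact and attaches provenance as run tags via mlflow.set_tags, replacing the per-class log_input() dataset logging. A sketch of the tags a local-prompts run would now carry; the file contents and column names are illustrative:

import tempfile
from pathlib import Path

import mlflow

from modelplane.utils.input import build_and_log_input

with mlflow.start_run() as run:
    with tempfile.TemporaryDirectory() as tmp:
        prompts = Path(tmp) / "prompts.csv"  # hypothetical prompts file
        prompts.write_text("prompt_uid,prompt_text\np1,hello\n")
        build_and_log_input(path=str(prompts), dest_dir=tmp)
    # LocalInput contributes {"input_type": "local", "input_path": <path>}
    tags = mlflow.get_run(run.info.run_id).data.tags
    assert tags["input_type"] == "local"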
13 changes: 6 additions & 7 deletions src/modelplane/runways/scorer.py
@@ -2,7 +2,6 @@
 
 import json
 import math
-import os
 import tempfile
 from pathlib import Path
 
@@ -19,7 +18,7 @@
     RUN_TYPE_TAG_NAME,
     get_experiment_id,
 )
-from modelplane.utils.input import build_input
+from modelplane.utils.input import build_and_log_input
 
 
 def score(
@@ -50,19 +49,19 @@ def score(
 
     with tempfile.TemporaryDirectory() as tmp:
         # Load annotations
-        annotation_input = build_input(
+        annotation_input = build_and_log_input(
            run_id=annotation_run_id,
            artifact_path=ANNOTATION_RESPONSE_ARTIFACT_NAME,
            dest_dir=tmp,
        )
-        annotation_input.log_input()
         annotation_data = AnnotationData(annotation_input.local_path(), is_json_annotation=True, sample_uid_col=sample_uid_col, annotator_uid_col=annotator_uid_col, annotation_col=annotation_col)
 
         # Load ground truth
-        ground_truth_input = build_input(
-            path=ground_truth, dvc_repo=dvc_repo, dest_dir=tmp
+        ground_truth_input = build_and_log_input(
+            path=ground_truth,
+            dvc_repo=dvc_repo,
+            dest_dir=tmp,
         )
-        ground_truth_input.log_input()
         ground_truth_data = AnnotationData(ground_truth_input.local_path(), is_json_annotation=False, annotation_col="is_safe", annotator_uid_col=None, sample_uid_col=sample_uid_col)
         mlflow.log_metric("num_ground_truth_samples", len(ground_truth_data.df))
 
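One behavioral difference worth noting: build_and_log_input refuses to run outside an active MLflow run, whereas the old build_input would happily construct an input and leave logging to a separate log_input() call. A sketch of the new guard, assuming no run has been started:

import mlflow

from modelplane.utils.input import build_and_log_input

assert mlflow.active_run() is None
try:
    build_and_log_input(path="ground_truth.csv")  # hypothetical file
except RuntimeError as err:
    print(err)  # "An active MLflow run is required to log input artifacts."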
148 changes: 96 additions & 52 deletions src/modelplane/utils/input.py
@@ -1,4 +1,3 @@
-import json
 import os
 import shutil
 from abc import ABC, abstractmethod
@@ -7,52 +6,103 @@
 
 import dvc.api
 import mlflow
+import mlflow.artifacts
+import pandas as pd
 
-from modelplane.mlflow.datasets import LocalDatasetSource
+_MLFLOW_REQUIRED_ERROR_MESSAGE = (
+    "An active MLflow run is required to log input artifacts."
+)
 
 
 class BaseInput(ABC):
     """Base class for input datasets."""
 
+    input_type: str
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        if not hasattr(cls, "input_type"):
+            raise TypeError(f"{cls.__name__} must define class attribute 'input_type'")
+
+    def log_artifact(self):
+        """Log the dataset to MLflow as an artifact to the current run."""
+        mlflow.log_artifact(str(self.local_path()))
+        mlflow.set_tags(self.input_tags())
+
     @abstractmethod
-    def log_input(self):
-        """Log the dataset to MLflow as input. This method should only be called inside an active MLflow run."""
+    def local_path(self) -> Path:
         pass
 
+    def input_tags(self) -> dict:
+        tags = {"input_type": self.input_type}
+        tags.update(self.tags_for_input_type)
+        return tags
+
     @property
     @abstractmethod
-    def local_path(self) -> Path:
+    def tags_for_input_type(self) -> dict:
         pass
 
 
 class LocalInput(BaseInput):
     """A dataset that is stored locally."""
 
+    input_type = "local"
+
     def __init__(self, path: str):
         self.path = path
 
-    def log_input(self):
-        mlf_dataset = mlflow.data.meta_dataset.MetaDataset(
-            source=LocalDatasetSource(uri=self.path),
-            name=self.path,
-        )
-        mlflow.log_input(mlf_dataset)
-
     def local_path(self) -> Path:
         return Path(self.path)
 
+    @property
+    def tags_for_input_type(self) -> dict:
+        return {"input_path": self.path}
+
+
+class DataframeInput(BaseInput):
+    """A dataset that is represented as a Pandas DataFrame."""
+
+    input_type = "dataframe"
+    _INPUT_FILE_NAME = "input.csv"
+
+    def __init__(self, df: pd.DataFrame, dest_dir: str):
+        self._local_path = Path(dest_dir) / self._INPUT_FILE_NAME
+        self.df = df
+
+    @property
+    def df(self) -> pd.DataFrame:
+        return self._df
+
+    @df.setter
+    def df(self, df: pd.DataFrame):
+        self._df = df
+        self._update_local_file()
+
+    def _update_local_file(self):
+        self.df.to_csv(self._local_path, index=False)
+
+    def local_path(self) -> Path:
+        return self._local_path
+
+    @property
+    def tags_for_input_type(self) -> dict:
+        return {}
 
 
 class DVCInput(BaseInput):
     """A dataset from a DVC remote."""
 
+    input_type = "dvc"
+
     def __init__(self, path: str, repo: str, dest_dir: str):
         repo_path = repo.split("#")
         if len(repo_path) == 2:
             repo, self.rev = repo_path
         else:
             self.rev = "main"
         self.path = path
-        self.url = dvc.api.get_url(path, repo=repo, rev=self.rev)  # For logging.
         self._local_path = self._download_dvc_file(path, repo, dest_dir)
+        self._tags = {"input_repo": repo, "input_rev": self.rev, "input_path": path}
 
     def _download_dvc_file(self, path: str, repo: str, dest_dir: str) -> str:
         local_path = os.path.join(dest_dir, path)
@@ -64,32 +114,23 @@ def _download_dvc_file(self, path: str, repo: str, dest_dir: str) -> str:
 
         return local_path
 
-    def digest(self) -> str:
-        """Return the md5 hash of the dvc file."""
-        # TODO: Check if this works with other storage options (besides google cloud)
-        segments = self.url.split("/")
-        i = segments.index("md5")
-        digest = "".join(segments[i + 1 :])
-        return digest
-
-    def log_input(self):
-        dataset = mlflow.data.meta_dataset.MetaDataset(
-            source=mlflow.data.http_dataset_source.HTTPDatasetSource(self.url),
-            name=self.path,
-            digest=self.digest(),
-        )
-        mlflow.log_input(dataset)
-
     def local_path(self) -> Path:
         return Path(self._local_path)
 
+    @property
+    def tags_for_input_type(self) -> dict:
+        return self._tags
+
 
 class MLFlowArtifactInput(BaseInput):
     """A dataset artifact from a previous MLFlow run."""
 
+    input_type = "artifact"
+
     def __init__(self, run_id: str, artifact_path: str, dest_dir: str):
         self.run_id = run_id
         self._local_path = self._download_artifacts(run_id, artifact_path, dest_dir)
+        self._tags = {"input_run_id": run_id, "input_artifact_path": artifact_path}
 
     def _download_artifacts(
         self, run_id: str, artifact_path: str, dest_dir: str
@@ -101,48 +142,51 @@ def _download_artifacts(
         )
         return os.path.join(dest_dir, artifact_path)
 
-    def log_input(self):
-        run = mlflow.get_run(self.run_id)
-        for input in run.inputs.dataset_inputs:
-            ds = input.dataset
-            source_dict = json.loads(ds.source)
-            if ds.source_type == "http":
-                source = mlflow.data.http_dataset_source.HTTPDatasetSource(
-                    source_dict["url"]
-                )
-            else:
-                source = LocalDatasetSource.from_dict(source_dict)
-            dataset = mlflow.data.dataset.Dataset(
-                source=source, name=ds.name, digest=ds.digest
-            )
-            mlflow.log_input(dataset)
-
     def local_path(self) -> Path:
         return Path(self._local_path)
 
+    @property
+    def tags_for_input_type(self) -> dict:
+        return self._tags
+
 
-def build_input(
+def build_and_log_input(
+    input_obj: Optional[BaseInput] = None,
     path: Optional[str] = None,
     run_id: Optional[str] = None,
    artifact_path: Optional[str] = None,
    dvc_repo: Optional[str] = None,
-    dest_dir: Optional[str] = None,
+    dest_dir: str = "",
+    df: Optional[pd.DataFrame] = None,
 ) -> BaseInput:
-    if dvc_repo is not None:
+    if mlflow.active_run() is None:
+        raise RuntimeError(_MLFLOW_REQUIRED_ERROR_MESSAGE)
+    # Direct input
+    if input_obj is not None:
+        inp = input_obj
+    # DF case
+    elif df is not None:
+        inp = DataframeInput(df, dest_dir=dest_dir)
+    # DVC case
+    elif dvc_repo is not None:
         if path is None:
             raise ValueError("Path must be provided when dvc_repo is provided.")
         if run_id is not None:
             raise ValueError(
                 "Cannot provide both run_id and dvc_repo to build an input."
             )
-        return DVCInput(path=path, repo=dvc_repo, dest_dir=dest_dir)
+        inp = DVCInput(path=path, repo=dvc_repo, dest_dir=dest_dir)
+    # Local case
     elif path is not None:
         if run_id is not None:
             raise ValueError("Cannot provide both path and run_id.")
-        return LocalInput(path)
+        inp = LocalInput(path)
+    # MLFlow artifact case
    elif run_id is not None:
        if artifact_path is None:
            raise ValueError("Artifact path must be provided when run_id is provided.")
-        return MLFlowArtifactInput(run_id, artifact_path, dest_dir)
+        inp = MLFlowArtifactInput(run_id, artifact_path, dest_dir)
    else:
        raise ValueError("Either path or run_id must be provided to build an input.")
+    inp.log_artifact()
+    return inp
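The __init_subclass__ hook makes the input_type contract fail fast: a subclass missing the class attribute raises at class-definition time rather than when tags are first built. A small sketch of the contract with a hypothetical subclass that is not part of this PR:

from pathlib import Path

from modelplane.utils.input import BaseInput

class InMemoryInput(BaseInput):  # hypothetical example subclass
    input_type = "in_memory"

    def __init__(self, text: str, dest_dir: str):
        self._local_path = Path(dest_dir) / "input.txt"
        self._local_path.write_text(text)

    def local_path(self) -> Path:
        return self._local_path

    @property
    def tags_for_input_type(self) -> dict:
        return {}

try:
    class BadInput(BaseInput):  # forgot to define input_type
        pass
except TypeError as err:
    print(err)  # BadInput must define class attribute 'input_type'

The new input_obj parameter lets such a pre-built input pass straight through: build_and_log_input(input_obj=InMemoryInput("hello", tmp)) skips the dispatch branches and just logs it to the active run.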