Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ secrets.toml
*.pyc
.vscode/
.coverage*
.cache
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ scratch.
* You can manage branches and commits for
`modelplane-flights` directly from jupyter.

## Caching

Annotator and SUT responses are cached locally (in a `.cache` directory) unless
you disable caching by passing `disable_cache=True` to the relevant calls (or
`--disable_cache` on the CLI).
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good choice, I think.


## CLI

You can also interact with modelplane via CLI. Run `poetry run modelplane --help`
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ services:
- "8888:8888"
volumes:
- ./flightpaths:/app/flightpaths
# Used for caching of SUT/annotator results
- ./flightpaths/.cache:/app/flightpaths/.cache
# Volume not needed if not using modelplane-flights for sharing notebooks
- ../modelplane-flights:/app/flightpaths/flights
# Volume not needed if using cloud storage for artifacts
Expand Down
13 changes: 4 additions & 9 deletions flightpaths/Annotator Development Template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@
"\n",
"The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
"\n",
"You can cache prompt responses via `cache_dir`.\n",
"\n",
"Finally, `n_jobs` can adjust the parallelism."
"Finally, `num_workers` can adjust the parallelism."
]
},
{
Expand All @@ -66,8 +64,7 @@
"experiment = \"experiment_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
"prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
"ground_truth = \"data/fakegroundtruth.csv\"\n",
"cache_dir = None\n",
"n_jobs = 4"
"num_workers = 4"
]
},
{
Expand Down Expand Up @@ -168,8 +165,7 @@
" sut_id=sut_id,\n",
" experiment=experiment,\n",
" prompts=prompts,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
},
Expand All @@ -194,8 +190,7 @@
" annotator_ids=[annotator_id],\n",
" experiment=experiment,\n",
" response_run_id=run_id,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
},
Expand Down
21 changes: 4 additions & 17 deletions flightpaths/Ensemble Development Template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@
"\n",
"The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
"\n",
"You can cache prompt responses via `cache_dir`.\n",
"\n",
"Finally, `n_jobs` can adjust the parallelism."
"Finally, `num_workers` can adjust the parallelism."
]
},
{
Expand All @@ -58,8 +56,7 @@
"experiment = \"ensemble_experiment_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
"prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
"ground_truth = \"data/fakegroundtruth.csv\"\n",
"cache_dir = None\n",
"n_jobs = 4"
"num_workers = 4"
]
},
{
Expand Down Expand Up @@ -226,8 +223,7 @@
" sut_id=sut_id,\n",
" experiment=experiment,\n",
" prompts=prompts,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
},
Expand All @@ -253,8 +249,7 @@
" ensemble_strategy=\"anyunsafe\",\n",
" experiment=experiment,\n",
" response_run_id=run_id,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
},
Expand All @@ -281,14 +276,6 @@
" ground_truth=ground_truth,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "476edda5-cfca-489f-90ac-bac11f9530ef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
22 changes: 8 additions & 14 deletions flightpaths/Running the Evaluator with Mods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,10 @@
"source": [
"## Settings\n",
"\n",
"* The `sut_id` refers to the model that generates the responses to the prompts. It is currently set to a demo SUT.\n",
"* The `experiment` variable will be used to organize the various runs in mlflow.\n",
"* The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
"* You can cache prompt responses via `cache_dir`.\n",
"\n",
"Finally, `n_jobs` can adjust the parallelism."
"* `sut_id` refers to the model that generates the responses to the prompts. It is currently set to a demo SUT.\n",
"* `experiment` is used to organize the various runs in mlflow.\n",
"* `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
"* `num_workers` can adjust the parallelism."
]
},
{
Expand All @@ -73,8 +71,7 @@
"experiment = \"fp_private_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
"prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
"ground_truth = \"data/fakegroundtruth.csv\"\n",
"cache_dir = None\n",
"n_jobs = 4"
"num_workers = 4"
]
},
{
Expand Down Expand Up @@ -148,8 +145,7 @@
" sut_id=sut_id,\n",
" experiment=experiment,\n",
" prompts=prompts,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
},
Expand All @@ -169,8 +165,7 @@
" ensemble_id=\"official-1.0\",\n",
" experiment=experiment,\n",
" response_run_id=run_id,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")\n",
"```"
]
Expand All @@ -190,8 +185,7 @@
" ensemble_strategy=\"anyunsafe\",\n",
" experiment=experiment,\n",
" response_run_id=run_id,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
}
Expand Down
9 changes: 3 additions & 6 deletions flightpaths/vLLM Annotator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@
"dvc_repo = \"https://github.com/mlcommons/modelplane.git\"\n",
"prompts = \"flightpaths/data/demo_prompts_mini.csv\"\n",
"ground_truth = \"data/fakegroundtruth.csv\"\n",
"cache_dir = None\n",
"n_jobs = 4\n",
"num_workers = 4\n",
"\n",
"vllm_host = \"http://vllm:8001/v1\"\n",
"vllm_model = \"mlc/not-real-model\"\n",
Expand Down Expand Up @@ -90,8 +89,7 @@
" experiment=experiment,\n",
" dvc_repo=dvc_repo,\n",
" prompts=prompts,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
},
Expand Down Expand Up @@ -237,8 +235,7 @@
" annotator_ids=[vllm_annotator_uid],\n",
" experiment=experiment,\n",
" response_run_id=run_id,\n",
" cache_dir=cache_dir,\n",
" n_jobs=n_jobs,\n",
" num_workers=num_workers,\n",
")"
]
}
Expand Down
40 changes: 20 additions & 20 deletions src/modelplane/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,25 @@ def list_suts_cli():
help="URL of the DVC repo to get the prompts from. E.g. https://github.com/my-org/my-repo.git. Can specify the revision using the `#` suffix, e.g. https://github.com/my-org/my-repo.git#main.",
)
@click.option(
"--cache_dir",
type=str,
default=None,
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
"--disable_cache",
is_flag=True,
default=False,
help="Disable caching of LLM responses. If set, the pipeline will not cache SUT/annotator responses. Otherwise, cached responses will be stored locally in `.cache`.",
)
@click.option(
"--n_jobs",
"--num_workers",
type=int,
default=1,
help="The number of jobs to run in parallel. Defaults to 1.",
help="The number of workers to run in parallel. Defaults to 1.",
)
@load_from_dotenv
def get_sut_responses(
sut_id: str,
prompts: str,
experiment: str,
dvc_repo: str | None = None,
cache_dir: str | None = None,
n_jobs: int = 1,
disable_cache: bool = False,
num_workers: int = 1,
):
"""
Run the pipeline to get responses from SUTs.
Expand All @@ -88,8 +88,8 @@ def get_sut_responses(
prompts=prompts,
experiment=experiment,
dvc_repo=dvc_repo,
cache_dir=cache_dir,
n_jobs=n_jobs,
disable_cache=disable_cache,
num_workers=num_workers,
)


Expand Down Expand Up @@ -148,16 +148,16 @@ def get_sut_responses(
help="Use the response_run_id to save annotation artifact. Any existing annotation artifact will be overwritten. If not set, a new run will be created. Only applies if not using response_run_file.",
)
@click.option(
"--cache_dir",
type=str,
default=None,
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
"--disable_cache",
is_flag=True,
default=False,
help="Disable caching of LLM responses. If set, the pipeline will not cache SUT/annotator responses. Otherwise, cached responses will be stored locally in `.cache`.",
)
@click.option(
"--n_jobs",
"--num_workers",
type=int,
default=1,
help="The number of jobs to run in parallel. Defaults to 1.",
help="The number of workers to run in parallel. Defaults to 1.",
)
@load_from_dotenv
def get_annotations(
Expand All @@ -169,8 +169,8 @@ def get_annotations(
ensemble_strategy: str | None = None,
ensemble_id: str | None = None,
overwrite: bool = False,
cache_dir: str | None = None,
n_jobs: int = 1,
disable_cache: bool = False,
num_workers: int = 1,
):
return annotate(
experiment=experiment,
Expand All @@ -181,8 +181,8 @@ def get_annotations(
ensemble_strategy=ensemble_strategy,
ensemble_id=ensemble_id,
overwrite=overwrite,
cache_dir=cache_dir,
n_jobs=n_jobs,
disable_cache=disable_cache,
num_workers=num_workers,
)


Expand Down
15 changes: 7 additions & 8 deletions src/modelplane/runways/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from modelplane.mlflow.loghelpers import log_tags
from modelplane.runways.utils import (
CACHE_DIR,
MODELGAUGE_RUN_TAG_NAME,
PROMPT_RESPONSE_ARTIFACT_NAME,
RUN_TYPE_ANNOTATOR,
Expand Down Expand Up @@ -47,8 +48,8 @@ def annotate(
ensemble_strategy: str | None = None,
ensemble_id: str | None = None,
overwrite: bool = False,
cache_dir: str | None = None,
n_jobs: int = 1,
disable_cache: bool = False,
num_workers: int = 1,
) -> str:
"""
Run annotations and record measurements.
Expand All @@ -57,8 +58,9 @@ def annotate(
pipeline_kwargs = _get_annotator_settings(
annotator_ids, ensemble_strategy, ensemble_id
)
pipeline_kwargs["cache_dir"] = cache_dir
pipeline_kwargs["num_workers"] = n_jobs
if not disable_cache:
pipeline_kwargs["cache_dir"] = CACHE_DIR
pipeline_kwargs["num_workers"] = num_workers

# set the tags
tags = {RUN_TYPE_TAG_NAME: RUN_TYPE_ANNOTATOR}
Expand All @@ -80,10 +82,7 @@ def annotate(
else:
run_id = None

params = {
"cache_dir": cache_dir,
"n_jobs": n_jobs,
}
params = {"num_workers": num_workers}

with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, tags=tags) as run:
mlflow.log_params(params)
Expand Down
14 changes: 6 additions & 8 deletions src/modelplane/runways/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from modelgauge.sut_registry import SUTS

from modelplane.runways.utils import (
CACHE_DIR,
MODELGAUGE_RUN_TAG_NAME,
RUN_TYPE_RESPONDER,
RUN_TYPE_TAG_NAME,
Expand All @@ -24,15 +25,12 @@ def respond(
prompts: str,
experiment: str,
dvc_repo: str | None = None,
cache_dir: str | None = None,
n_jobs: int = 1,
disable_cache: bool = False,
num_workers: int = 1,
) -> str:
secrets = setup_sut_credentials(sut_id)
sut = SUTS.make_instance(uid=sut_id, secrets=secrets)
params = {
"cache_dir": cache_dir,
"n_jobs": n_jobs,
}
params = {"num_workers": num_workers}
tags = {"sut_id": sut_id, RUN_TYPE_TAG_NAME: RUN_TYPE_RESPONDER}

experiment_id = get_experiment_id(experiment)
Expand All @@ -44,10 +42,10 @@ def respond(
input_data = build_input(path=prompts, dvc_repo=dvc_repo, dest_dir=tmp)
input_data.log_input()
pipeline_runner = PromptRunner(
num_workers=n_jobs,
num_workers=num_workers,
input_path=input_data.local_path(),
output_dir=pathlib.Path(tmp),
cache_dir=cache_dir,
cache_dir=None if disable_cache else CACHE_DIR,
suts={sut_id: sut},
)

Expand Down
1 change: 1 addition & 0 deletions src/modelplane/runways/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RUN_TYPE_ANNOTATOR = "annotate"
RUN_TYPE_SCORER = "score"
MODELGAUGE_RUN_TAG_NAME = "modelgauge_run_id"
# Directory passed as `cache_dir` to the SUT and annotator pipelines when
# caching is enabled (i.e. `disable_cache` is not set).
CACHE_DIR = ".cache"


def is_debug_mode() -> bool:
Expand Down
Loading