Skip to content

Commit 149cb5e

Browse files
authored
Set up default caching options for modelplane users (#67)
* Use num_workers to be consistent with modelgauge. * Set up default caching options for SUT/annotator responses * Replace `n_jobs` and `cache_dir` with `num_workers` in notebooks. * Fixes. * Explain caching in readme.md.
1 parent f2641f2 commit 149cb5e

File tree

12 files changed

+86
-111
lines changed

12 files changed

+86
-111
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ secrets.toml
66
*.pyc
77
.vscode/
88
.coverage*
9+
.cache

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ scratch.
7373
* You can manage branches and commits for
7474
`modelplane-flights` directly from jupyter.
7575

76+
## Caching
77+
78+
Annotator and SUT responses will be cached (locally) unless you pass the
79+
`disable_cache` flag to the appropriate calls.
80+
7681
## CLI
7782

7883
You can also interact with modelplane via CLI. Run `poetry run modelplane --help`

docker-compose.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ services:
7373
- "8888:8888"
7474
volumes:
7575
- ./flightpaths:/app/flightpaths
76+
# Used for caching of SUT/annotator results
77+
- ./flightpaths/.cache:/app/flightpaths/.cache
7678
# Volume not needed if not using modelplane-flights for sharing notebooks
7779
- ../modelplane-flights:/app/flightpaths/flights
7880
# Volume not needed if using cloud storage for artifacts

flightpaths/Annotator Development Template.ipynb

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@
5050
"\n",
5151
"The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
5252
"\n",
53-
"You can cache prompt responses via `cache_dir`.\n",
54-
"\n",
55-
"Finally, `n_jobs` can adjust the parallelism."
53+
"Finally, `num_workers` can adjust the parallelism."
5654
]
5755
},
5856
{
@@ -66,8 +64,7 @@
6664
"experiment = \"experiment_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
6765
"prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
6866
"ground_truth = \"data/fakegroundtruth.csv\"\n",
69-
"cache_dir = None\n",
70-
"n_jobs = 4"
67+
"num_workers = 4"
7168
]
7269
},
7370
{
@@ -168,8 +165,7 @@
168165
" sut_id=sut_id,\n",
169166
" experiment=experiment,\n",
170167
" prompts=prompts,\n",
171-
" cache_dir=cache_dir,\n",
172-
" n_jobs=n_jobs,\n",
168+
" num_workers=num_workers,\n",
173169
")"
174170
]
175171
},
@@ -194,8 +190,7 @@
194190
" annotator_ids=[annotator_id],\n",
195191
" experiment=experiment,\n",
196192
" response_run_id=run_id,\n",
197-
" cache_dir=cache_dir,\n",
198-
" n_jobs=n_jobs,\n",
193+
" num_workers=num_workers,\n",
199194
")"
200195
]
201196
},

flightpaths/Ensemble Development Template.ipynb

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@
4242
"\n",
4343
"The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
4444
"\n",
45-
"You can cache prompt responses via `cache_dir`.\n",
46-
"\n",
47-
"Finally, `n_jobs` can adjust the parallelism."
45+
"Finally, `num_workers` can adjust the parallelism."
4846
]
4947
},
5048
{
@@ -58,8 +56,7 @@
5856
"experiment = \"ensemble_experiment_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
5957
"prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
6058
"ground_truth = \"data/fakegroundtruth.csv\"\n",
61-
"cache_dir = None\n",
62-
"n_jobs = 4"
59+
"num_workers = 4"
6360
]
6461
},
6562
{
@@ -226,8 +223,7 @@
226223
" sut_id=sut_id,\n",
227224
" experiment=experiment,\n",
228225
" prompts=prompts,\n",
229-
" cache_dir=cache_dir,\n",
230-
" n_jobs=n_jobs,\n",
226+
" num_workers=num_workers,\n",
231227
")"
232228
]
233229
},
@@ -253,8 +249,7 @@
253249
" ensemble_strategy=\"anyunsafe\",\n",
254250
" experiment=experiment,\n",
255251
" response_run_id=run_id,\n",
256-
" cache_dir=cache_dir,\n",
257-
" n_jobs=n_jobs,\n",
252+
" num_workers=num_workers,\n",
258253
")"
259254
]
260255
},
@@ -281,14 +276,6 @@
281276
" ground_truth=ground_truth,\n",
282277
")"
283278
]
284-
},
285-
{
286-
"cell_type": "code",
287-
"execution_count": null,
288-
"id": "476edda5-cfca-489f-90ac-bac11f9530ef",
289-
"metadata": {},
290-
"outputs": [],
291-
"source": []
292279
}
293280
],
294281
"metadata": {

flightpaths/Running the Evaluator with Mods.ipynb

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,10 @@
5454
"source": [
5555
"## Settings\n",
5656
"\n",
57-
"* The `sut_id` refers to the model that generates the responses to the prompts. It is currently set to a demo SUT.\n",
58-
"* The `experiment` variable will be used to organize the various runs in mlflow.\n",
59-
"* The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
60-
"* You can cache prompt responses via `cache_dir`.\n",
61-
"\n",
62-
"Finally, `n_jobs` can adjust the parallelism."
57+
"* `sut_id` refers to the model that generates the responses to the prompts. It is currently set to a demo SUT.\n",
58+
"* `experiment` variable will be used to organize the various runs in mlflow.\n",
59+
"* `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
60+
"* `num_workers` can adjust the parallelism."
6361
]
6462
},
6563
{
@@ -73,8 +71,7 @@
7371
"experiment = \"fp_private_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
7472
"prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
7573
"ground_truth = \"data/fakegroundtruth.csv\"\n",
76-
"cache_dir = None\n",
77-
"n_jobs = 4"
74+
"num_workers = 4"
7875
]
7976
},
8077
{
@@ -148,8 +145,7 @@
148145
" sut_id=sut_id,\n",
149146
" experiment=experiment,\n",
150147
" prompts=prompts,\n",
151-
" cache_dir=cache_dir,\n",
152-
" n_jobs=n_jobs,\n",
148+
" num_workers=num_workers,\n",
153149
")"
154150
]
155151
},
@@ -169,8 +165,7 @@
169165
" ensemble_id=\"official-1.0\",\n",
170166
" experiment=experiment,\n",
171167
" response_run_id=run_id,\n",
172-
" cache_dir=cache_dir,\n",
173-
" n_jobs=n_jobs,\n",
168+
" num_workers=num_workers,\n",
174169
")\n",
175170
"```"
176171
]
@@ -190,8 +185,7 @@
190185
" ensemble_strategy=\"anyunsafe\",\n",
191186
" experiment=experiment,\n",
192187
" response_run_id=run_id,\n",
193-
" cache_dir=cache_dir,\n",
194-
" n_jobs=n_jobs,\n",
188+
" num_workers=num_workers,\n",
195189
")"
196190
]
197191
}

flightpaths/vLLM Annotator.ipynb

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@
5959
"dvc_repo = \"https://github.com/mlcommons/modelplane.git\"\n",
6060
"prompts = \"flightpaths/data/demo_prompts_mini.csv\"\n",
6161
"ground_truth = \"data/fakegroundtruth.csv\"\n",
62-
"cache_dir = None\n",
63-
"n_jobs = 4\n",
62+
"num_workers = 4\n",
6463
"\n",
6564
"vllm_host = \"http://vllm:8001/v1\"\n",
6665
"vllm_model = \"mlc/not-real-model\"\n",
@@ -90,8 +89,7 @@
9089
" experiment=experiment,\n",
9190
" dvc_repo=dvc_repo,\n",
9291
" prompts=prompts,\n",
93-
" cache_dir=cache_dir,\n",
94-
" n_jobs=n_jobs,\n",
92+
" num_workers=num_workers,\n",
9593
")"
9694
]
9795
},
@@ -237,8 +235,7 @@
237235
" annotator_ids=[vllm_annotator_uid],\n",
238236
" experiment=experiment,\n",
239237
" response_run_id=run_id,\n",
240-
" cache_dir=cache_dir,\n",
241-
" n_jobs=n_jobs,\n",
238+
" num_workers=num_workers,\n",
242239
")"
243240
]
244241
}

src/modelplane/cli.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,25 +60,25 @@ def list_suts_cli():
6060
help="URL of the DVC repo to get the prompts from. E.g. https://github.com/my-org/my-repo.git. Can specify the revision using the `#` suffix, e.g. https://github.com/my-org/my-repo.git#main.",
6161
)
6262
@click.option(
63-
"--cache_dir",
64-
type=str,
65-
default=None,
66-
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
63+
"--disable_cache",
64+
is_flag=True,
65+
default=False,
66+
help="Disable caching of LLM responses. If set, the pipeline will not cache SUT/annotator responses. Otherwise, cached responses will be stored locally in `.cache`.",
6767
)
6868
@click.option(
69-
"--n_jobs",
69+
"--num_workers",
7070
type=int,
7171
default=1,
72-
help="The number of jobs to run in parallel. Defaults to 1.",
72+
help="The number of workers to run in parallel. Defaults to 1.",
7373
)
7474
@load_from_dotenv
7575
def get_sut_responses(
7676
sut_id: str,
7777
prompts: str,
7878
experiment: str,
7979
dvc_repo: str | None = None,
80-
cache_dir: str | None = None,
81-
n_jobs: int = 1,
80+
disable_cache: bool = False,
81+
num_workers: int = 1,
8282
):
8383
"""
8484
Run the pipeline to get responses from SUTs.
@@ -88,8 +88,8 @@ def get_sut_responses(
8888
prompts=prompts,
8989
experiment=experiment,
9090
dvc_repo=dvc_repo,
91-
cache_dir=cache_dir,
92-
n_jobs=n_jobs,
91+
disable_cache=disable_cache,
92+
num_workers=num_workers,
9393
)
9494

9595

@@ -148,16 +148,16 @@ def get_sut_responses(
148148
help="Use the response_run_id to save annotation artifact. Any existing annotation artifact will be overwritten. If not set, a new run will be created. Only applies if not using response_run_file.",
149149
)
150150
@click.option(
151-
"--cache_dir",
152-
type=str,
153-
default=None,
154-
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
151+
"--disable_cache",
152+
is_flag=True,
153+
default=False,
154+
help="Disable caching of LLM responses. If set, the pipeline will not cache SUT/annotator responses. Otherwise, cached responses will be stored locally in `.cache`.",
155155
)
156156
@click.option(
157-
"--n_jobs",
157+
"--num_workers",
158158
type=int,
159159
default=1,
160-
help="The number of jobs to run in parallel. Defaults to 1.",
160+
help="The number of workers to run in parallel. Defaults to 1.",
161161
)
162162
@load_from_dotenv
163163
def get_annotations(
@@ -169,8 +169,8 @@ def get_annotations(
169169
ensemble_strategy: str | None = None,
170170
ensemble_id: str | None = None,
171171
overwrite: bool = False,
172-
cache_dir: str | None = None,
173-
n_jobs: int = 1,
172+
disable_cache: bool = False,
173+
num_workers: int = 1,
174174
):
175175
return annotate(
176176
experiment=experiment,
@@ -181,8 +181,8 @@ def get_annotations(
181181
ensemble_strategy=ensemble_strategy,
182182
ensemble_id=ensemble_id,
183183
overwrite=overwrite,
184-
cache_dir=cache_dir,
185-
n_jobs=n_jobs,
184+
disable_cache=disable_cache,
185+
num_workers=num_workers,
186186
)
187187

188188

src/modelplane/runways/annotator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from modelplane.mlflow.loghelpers import log_tags
2020
from modelplane.runways.utils import (
21+
CACHE_DIR,
2122
MODELGAUGE_RUN_TAG_NAME,
2223
PROMPT_RESPONSE_ARTIFACT_NAME,
2324
RUN_TYPE_ANNOTATOR,
@@ -47,8 +48,8 @@ def annotate(
4748
ensemble_strategy: str | None = None,
4849
ensemble_id: str | None = None,
4950
overwrite: bool = False,
50-
cache_dir: str | None = None,
51-
n_jobs: int = 1,
51+
disable_cache: bool = False,
52+
num_workers: int = 1,
5253
) -> str:
5354
"""
5455
Run annotations and record measurements.
@@ -57,8 +58,9 @@ def annotate(
5758
pipeline_kwargs = _get_annotator_settings(
5859
annotator_ids, ensemble_strategy, ensemble_id
5960
)
60-
pipeline_kwargs["cache_dir"] = cache_dir
61-
pipeline_kwargs["num_workers"] = n_jobs
61+
if not disable_cache:
62+
pipeline_kwargs["cache_dir"] = CACHE_DIR
63+
pipeline_kwargs["num_workers"] = num_workers
6264

6365
# set the tags
6466
tags = {RUN_TYPE_TAG_NAME: RUN_TYPE_ANNOTATOR}
@@ -80,10 +82,7 @@ def annotate(
8082
else:
8183
run_id = None
8284

83-
params = {
84-
"cache_dir": cache_dir,
85-
"n_jobs": n_jobs,
86-
}
85+
params = {"num_workers": num_workers}
8786

8887
with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, tags=tags) as run:
8988
mlflow.log_params(params)

src/modelplane/runways/responder.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from modelgauge.sut_registry import SUTS
1010

1111
from modelplane.runways.utils import (
12+
CACHE_DIR,
1213
MODELGAUGE_RUN_TAG_NAME,
1314
RUN_TYPE_RESPONDER,
1415
RUN_TYPE_TAG_NAME,
@@ -24,15 +25,12 @@ def respond(
2425
prompts: str,
2526
experiment: str,
2627
dvc_repo: str | None = None,
27-
cache_dir: str | None = None,
28-
n_jobs: int = 1,
28+
disable_cache: bool = False,
29+
num_workers: int = 1,
2930
) -> str:
3031
secrets = setup_sut_credentials(sut_id)
3132
sut = SUTS.make_instance(uid=sut_id, secrets=secrets)
32-
params = {
33-
"cache_dir": cache_dir,
34-
"n_jobs": n_jobs,
35-
}
33+
params = {"num_workers": num_workers}
3634
tags = {"sut_id": sut_id, RUN_TYPE_TAG_NAME: RUN_TYPE_RESPONDER}
3735

3836
experiment_id = get_experiment_id(experiment)
@@ -44,10 +42,10 @@ def respond(
4442
input_data = build_input(path=prompts, dvc_repo=dvc_repo, dest_dir=tmp)
4543
input_data.log_input()
4644
pipeline_runner = PromptRunner(
47-
num_workers=n_jobs,
45+
num_workers=num_workers,
4846
input_path=input_data.local_path(),
4947
output_dir=pathlib.Path(tmp),
50-
cache_dir=cache_dir,
48+
cache_dir=None if disable_cache else CACHE_DIR,
5149
suts={sut_id: sut},
5250
)
5351

0 commit comments

Comments
 (0)