diff --git a/.gitignore b/.gitignore
index ec3aa5a..aab6656 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ secrets.toml
 *.pyc
 .vscode/
 .coverage*
+.cache
diff --git a/README.md b/README.md
index da6fc32..3b787ee 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,11 @@ scratch.
 * You can manage branches and commits for `modelplane-flights` directly
   from jupyter.
 
+## Caching
+
+Annotator and SUT responses will be cached (locally) unless you pass the
+`disable_cache` flag to the appropriate calls.
+
 ## CLI
 
 You can also interact with modelplane via CLI. Run `poetry run modelplane --help`
diff --git a/docker-compose.yaml b/docker-compose.yaml
index fa35443..80c9760 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -73,6 +73,8 @@ services:
       - "8888:8888"
     volumes:
       - ./flightpaths:/app/flightpaths
+      # Used for caching of SUT/annotator results
+      - ./flightpaths/.cache:/app/flightpaths/.cache
       # Volume not needed if not using modelplane-flights for sharing notebooks
       - ../modelplane-flights:/app/flightpaths/flights
       # Volume not needed if using cloud storage for artifacts
diff --git a/flightpaths/Annotator Development Template.ipynb b/flightpaths/Annotator Development Template.ipynb
index 55982b2..7587259 100644
--- a/flightpaths/Annotator Development Template.ipynb
+++ b/flightpaths/Annotator Development Template.ipynb
@@ -50,9 +50,7 @@
     "\n",
     "The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
     "\n",
-    "You can cache prompt responses via `cache_dir`.\n",
-    "\n",
-    "Finally, `n_jobs` can adjust the parallelism."
+    "Finally, `num_workers` can adjust the parallelism."
    ]
   },
   {
@@ -66,8 +64,7 @@
     "experiment = \"experiment_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
     "prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
     "ground_truth = \"data/fakegroundtruth.csv\"\n",
-    "cache_dir = None\n",
-    "n_jobs = 4"
+    "num_workers = 4"
    ]
   },
   {
@@ -168,8 +165,7 @@
     "    sut_id=sut_id,\n",
     "    experiment=experiment,\n",
     "    prompts=prompts,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   },
@@ -194,8 +190,7 @@
     "    annotator_ids=[annotator_id],\n",
     "    experiment=experiment,\n",
     "    response_run_id=run_id,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   },
diff --git a/flightpaths/Ensemble Development Template.ipynb b/flightpaths/Ensemble Development Template.ipynb
index e16d4a3..472cee6 100644
--- a/flightpaths/Ensemble Development Template.ipynb
+++ b/flightpaths/Ensemble Development Template.ipynb
@@ -42,9 +42,7 @@
     "\n",
     "The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
     "\n",
-    "You can cache prompt responses via `cache_dir`.\n",
-    "\n",
-    "Finally, `n_jobs` can adjust the parallelism."
+    "Finally, `num_workers` can adjust the parallelism."
    ]
   },
   {
@@ -58,8 +56,7 @@
     "experiment = \"ensemble_experiment_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
     "prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
     "ground_truth = \"data/fakegroundtruth.csv\"\n",
-    "cache_dir = None\n",
-    "n_jobs = 4"
+    "num_workers = 4"
    ]
   },
   {
@@ -226,8 +223,7 @@
     "    sut_id=sut_id,\n",
     "    experiment=experiment,\n",
     "    prompts=prompts,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   },
@@ -253,8 +249,7 @@
     "    ensemble_strategy=\"anyunsafe\",\n",
     "    experiment=experiment,\n",
     "    response_run_id=run_id,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   },
@@ -281,14 +276,6 @@
     "    ground_truth=ground_truth,\n",
     ")"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "476edda5-cfca-489f-90ac-bac11f9530ef",
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
diff --git a/flightpaths/Running the Evaluator with Mods.ipynb b/flightpaths/Running the Evaluator with Mods.ipynb
index 29c3bab..a424a23 100644
--- a/flightpaths/Running the Evaluator with Mods.ipynb
+++ b/flightpaths/Running the Evaluator with Mods.ipynb
@@ -54,12 +54,10 @@
    "source": [
     "## Settings\n",
     "\n",
-    "* The `sut_id` refers to the model that generates the responses to the prompts. It is currently set to a demo SUT.\n",
-    "* The `experiment` variable will be used to organize the various runs in mlflow.\n",
-    "* The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
-    "* You can cache prompt responses via `cache_dir`.\n",
-    "\n",
-    "Finally, `n_jobs` can adjust the parallelism."
+    "* `sut_id` refers to the model that generates the responses to the prompts. It is currently set to a demo SUT.\n",
+    "* `experiment` will be used to organize the various runs in mlflow.\n",
+    "* `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.\n",
+    "* `num_workers` can adjust the parallelism."
    ]
   },
   {
@@ -73,8 +71,7 @@
     "experiment = \"fp_private_\" + datetime.date.today().strftime(\"%Y%m%d\")\n",
     "prompts = \"data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv\"\n",
     "ground_truth = \"data/fakegroundtruth.csv\"\n",
-    "cache_dir = None\n",
-    "n_jobs = 4"
+    "num_workers = 4"
    ]
   },
   {
@@ -148,8 +145,7 @@
     "    sut_id=sut_id,\n",
     "    experiment=experiment,\n",
     "    prompts=prompts,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   },
@@ -169,8 +165,7 @@
     "    ensemble_id=\"official-1.0\",\n",
     "    experiment=experiment,\n",
     "    response_run_id=run_id,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")\n",
     "```"
    ]
@@ -190,8 +185,7 @@
     "    ensemble_strategy=\"anyunsafe\",\n",
     "    experiment=experiment,\n",
     "    response_run_id=run_id,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   }
diff --git a/flightpaths/vLLM Annotator.ipynb b/flightpaths/vLLM Annotator.ipynb
index 3799b2d..df2fd99 100644
--- a/flightpaths/vLLM Annotator.ipynb
+++ b/flightpaths/vLLM Annotator.ipynb
@@ -59,8 +59,7 @@
     "dvc_repo = \"https://github.com/mlcommons/modelplane.git\"\n",
     "prompts = \"flightpaths/data/demo_prompts_mini.csv\"\n",
     "ground_truth = \"data/fakegroundtruth.csv\"\n",
-    "cache_dir = None\n",
-    "n_jobs = 4\n",
+    "num_workers = 4\n",
     "\n",
     "vllm_host = \"http://vllm:8001/v1\"\n",
     "vllm_model = \"mlc/not-real-model\"\n",
@@ -90,8 +89,7 @@
     "    experiment=experiment,\n",
     "    dvc_repo=dvc_repo,\n",
     "    prompts=prompts,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   },
@@ -237,8 +235,7 @@
     "    annotator_ids=[vllm_annotator_uid],\n",
     "    experiment=experiment,\n",
     "    response_run_id=run_id,\n",
-    "    cache_dir=cache_dir,\n",
-    "    n_jobs=n_jobs,\n",
+    "    num_workers=num_workers,\n",
     ")"
    ]
   }
diff --git a/src/modelplane/cli.py b/src/modelplane/cli.py
index 19db47d..52521ef 100644
--- a/src/modelplane/cli.py
+++ b/src/modelplane/cli.py
@@ -60,16 +60,16 @@ def list_suts_cli():
     help="URL of the DVC repo to get the prompts from. E.g. https://github.com/my-org/my-repo.git. Can specify the revision using the `#` suffix, e.g. https://github.com/my-org/my-repo.git#main.",
 )
 @click.option(
-    "--cache_dir",
-    type=str,
-    default=None,
-    help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
+    "--disable_cache",
+    is_flag=True,
+    default=False,
+    help="Disable caching of LLM responses. If set, the pipeline will not cache SUT/annotator responses. Otherwise, cached responses will be stored locally in `.cache`.",
 )
 @click.option(
-    "--n_jobs",
+    "--num_workers",
     type=int,
     default=1,
-    help="The number of jobs to run in parallel. Defaults to 1.",
+    help="The number of workers to run in parallel. Defaults to 1.",
 )
 @load_from_dotenv
 def get_sut_responses(
@@ -77,8 +77,8 @@ def get_sut_responses(
     prompts: str,
     experiment: str,
     dvc_repo: str | None = None,
-    cache_dir: str | None = None,
-    n_jobs: int = 1,
+    disable_cache: bool = False,
+    num_workers: int = 1,
 ):
     """
     Run the pipeline to get responses from SUTs.
@@ -88,8 +88,8 @@
         prompts=prompts,
         experiment=experiment,
         dvc_repo=dvc_repo,
-        cache_dir=cache_dir,
-        n_jobs=n_jobs,
+        disable_cache=disable_cache,
+        num_workers=num_workers,
     )
 
 
@@ -148,16 +148,16 @@
     help="Use the response_run_id to save annotation artifact. Any existing annotation artifact will be overwritten. If not set, a new run will be created. Only applies if not using response_run_file.",
 )
 @click.option(
-    "--cache_dir",
-    type=str,
-    default=None,
-    help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
+    "--disable_cache",
+    is_flag=True,
+    default=False,
+    help="Disable caching of LLM responses. If set, the pipeline will not cache SUT/annotator responses. Otherwise, cached responses will be stored locally in `.cache`.",
 )
 @click.option(
-    "--n_jobs",
+    "--num_workers",
     type=int,
     default=1,
-    help="The number of jobs to run in parallel. Defaults to 1.",
+    help="The number of workers to run in parallel. Defaults to 1.",
 )
 @load_from_dotenv
 def get_annotations(
@@ -169,8 +169,8 @@
     ensemble_strategy: str | None = None,
     ensemble_id: str | None = None,
     overwrite: bool = False,
-    cache_dir: str | None = None,
-    n_jobs: int = 1,
+    disable_cache: bool = False,
+    num_workers: int = 1,
 ):
     return annotate(
         experiment=experiment,
@@ -181,8 +181,8 @@
         ensemble_strategy=ensemble_strategy,
         ensemble_id=ensemble_id,
         overwrite=overwrite,
-        cache_dir=cache_dir,
-        n_jobs=n_jobs,
+        disable_cache=disable_cache,
+        num_workers=num_workers,
     )
 
 
diff --git a/src/modelplane/runways/annotator.py b/src/modelplane/runways/annotator.py
index cb41a8d..d078ca4 100644
--- a/src/modelplane/runways/annotator.py
+++ b/src/modelplane/runways/annotator.py
@@ -18,6 +18,7 @@
 from modelplane.mlflow.loghelpers import log_tags
 
 from modelplane.runways.utils import (
+    CACHE_DIR,
     MODELGAUGE_RUN_TAG_NAME,
     PROMPT_RESPONSE_ARTIFACT_NAME,
     RUN_TYPE_ANNOTATOR,
@@ -47,8 +48,8 @@ def annotate(
     ensemble_strategy: str | None = None,
     ensemble_id: str | None = None,
     overwrite: bool = False,
-    cache_dir: str | None = None,
-    n_jobs: int = 1,
+    disable_cache: bool = False,
+    num_workers: int = 1,
 ) -> str:
     """
     Run annotations and record measurements.
@@ -57,8 +58,9 @@ def annotate(
     pipeline_kwargs = _get_annotator_settings(
         annotator_ids, ensemble_strategy, ensemble_id
     )
-    pipeline_kwargs["cache_dir"] = cache_dir
-    pipeline_kwargs["num_workers"] = n_jobs
+    if not disable_cache:
+        pipeline_kwargs["cache_dir"] = CACHE_DIR
+    pipeline_kwargs["num_workers"] = num_workers
 
     # set the tags
     tags = {RUN_TYPE_TAG_NAME: RUN_TYPE_ANNOTATOR}
@@ -80,10 +82,7 @@ def annotate(
     else:
         run_id = None
 
-    params = {
-        "cache_dir": cache_dir,
-        "n_jobs": n_jobs,
-    }
+    params = {"num_workers": num_workers}
 
     with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, tags=tags) as run:
         mlflow.log_params(params)
diff --git a/src/modelplane/runways/responder.py b/src/modelplane/runways/responder.py
index 5f5b7c6..9c81c41 100644
--- a/src/modelplane/runways/responder.py
+++ b/src/modelplane/runways/responder.py
@@ -9,6 +9,7 @@
 from modelgauge.sut_registry import SUTS
 
 from modelplane.runways.utils import (
+    CACHE_DIR,
     MODELGAUGE_RUN_TAG_NAME,
     RUN_TYPE_RESPONDER,
     RUN_TYPE_TAG_NAME,
@@ -24,15 +25,12 @@ def respond(
     prompts: str,
     experiment: str,
     dvc_repo: str | None = None,
-    cache_dir: str | None = None,
-    n_jobs: int = 1,
+    disable_cache: bool = False,
+    num_workers: int = 1,
 ) -> str:
     secrets = setup_sut_credentials(sut_id)
     sut = SUTS.make_instance(uid=sut_id, secrets=secrets)
-    params = {
-        "cache_dir": cache_dir,
-        "n_jobs": n_jobs,
-    }
+    params = {"num_workers": num_workers}
     tags = {"sut_id": sut_id, RUN_TYPE_TAG_NAME: RUN_TYPE_RESPONDER}
 
     experiment_id = get_experiment_id(experiment)
@@ -44,10 +42,10 @@
         input_data = build_input(path=prompts, dvc_repo=dvc_repo, dest_dir=tmp)
         input_data.log_input()
         pipeline_runner = PromptRunner(
-            num_workers=n_jobs,
+            num_workers=num_workers,
             input_path=input_data.local_path(),
             output_dir=pathlib.Path(tmp),
-            cache_dir=cache_dir,
+            cache_dir=None if disable_cache else CACHE_DIR,
             suts={sut_id: sut},
         )
 
diff --git a/src/modelplane/runways/utils.py b/src/modelplane/runways/utils.py
index d7c324a..1f9cb70 100644
--- a/src/modelplane/runways/utils.py
+++ b/src/modelplane/runways/utils.py
@@ -22,6 +22,7 @@
 RUN_TYPE_ANNOTATOR = "annotate"
 RUN_TYPE_SCORER = "score"
 MODELGAUGE_RUN_TAG_NAME = "modelgauge_run_id"
+CACHE_DIR = ".cache"
 
 
 def is_debug_mode() -> bool:
diff --git a/tests/it/runways/test_e2e.py b/tests/it/runways/test_e2e.py
index 6561dea..ac84a05 100644
--- a/tests/it/runways/test_e2e.py
+++ b/tests/it/runways/test_e2e.py
@@ -20,21 +20,21 @@ def test_e2e():
     prompts = "tests/data/prompts.csv"
     ground_truth = "tests/data/ground_truth.csv"
     experiment = "test_experiment_" + time.strftime("%Y%m%d%H%M%S", time.localtime())
-    n_jobs = 1
+    num_workers = 1
 
     run_id = check_responder(
         sut_id=sut_id,
         prompts=prompts,
         experiment=experiment,
-        cache_dir=None,
-        n_jobs=n_jobs,
+        disable_cache=True,
+        num_workers=num_workers,
     )
     run_id = check_annotator(
         response_run_id=run_id,
         annotator_ids=[TEST_ANNOTATOR_ID],
         experiment=experiment,
-        cache_dir=None,
-        n_jobs=n_jobs,
+        disable_cache=True,
+        num_workers=num_workers,
     )
     check_scorer(
         annotation_run_id=run_id,
@@ -48,17 +48,16 @@ def check_responder(
     sut_id: str,
     prompts: str,
     experiment: str,
-    cache_dir: str | None,
-    n_jobs: int,
+    disable_cache: bool,
+    num_workers: int,
 ):
-    with tempfile.TemporaryDirectory() as cache_dir:
-        run_id = respond(
-            sut_id=sut_id,
-            prompts=prompts,
-            experiment=experiment,
-            cache_dir=cache_dir,
-            n_jobs=n_jobs,
-        )
+    run_id = respond(
+        sut_id=sut_id,
+        prompts=prompts,
+        experiment=experiment,
+        disable_cache=disable_cache,
+        num_workers=num_workers,
+    )
 
     # confirm experiment exists
     exp = mlflow.get_experiment_by_name(experiment)
@@ -69,8 +68,7 @@
     run = mlflow.get_run(run_id)
     params = run.data.params
     tags = run.data.tags
-    assert params.get("cache_dir") == cache_dir
-    assert params.get("n_jobs") == str(n_jobs)
+    assert params.get("num_workers") == str(num_workers)
     assert tags.get("sut_id") == sut_id
 
     # validate responses
@@ -99,18 +97,17 @@
     response_run_id: str,
     annotator_ids: List[str],
     experiment: str,
-    cache_dir: str | None,
-    n_jobs: int,
+    disable_cache: bool,
+    num_workers: int,
 ):
     # run the annotator
-    with tempfile.TemporaryDirectory() as cache_dir:
-        run_id = annotate(
-            response_run_id=response_run_id,
-            annotator_ids=annotator_ids,
-            experiment=experiment,
-            cache_dir=cache_dir,
-            n_jobs=n_jobs,
-        )
+    run_id = annotate(
+        response_run_id=response_run_id,
+        annotator_ids=annotator_ids,
+        experiment=experiment,
+        disable_cache=disable_cache,
+        num_workers=num_workers,
+    )
     # confirm experiment exists
     exp = mlflow.get_experiment_by_name(experiment)
     assert exp is not None
@@ -120,8 +117,7 @@
     params = run.data.params
     tags = run.data.tags
     metrics = run.data.metrics
-    assert params.get("cache_dir") == cache_dir
-    assert params.get("n_jobs") == str(n_jobs)
+    assert params.get("num_workers") == str(num_workers)
     assert tags.get(f"annotator_{TEST_ANNOTATOR_ID}") == "true"
 
     # expect 5 safe (every other item)
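For reference, here is a minimal usage sketch of the options introduced by this change. It is not part of the diff: the import paths follow the modules touched above, while the SUT id, annotator id, prompt path, and experiment name are placeholders for illustration.

# Usage sketch (not part of the diff). Import paths follow the modules
# changed above; ids, paths, and the experiment name are placeholders.
from modelplane.runways.annotator import annotate
from modelplane.runways.responder import respond

# Get SUT responses. Caching is on by default and writes to the local
# .cache directory (CACHE_DIR); pass disable_cache=True to skip it.
run_id = respond(
    sut_id="demo_sut",                 # placeholder SUT id
    prompts="data/prompts.csv",        # placeholder prompt CSV
    experiment="caching_demo",
    disable_cache=False,
    num_workers=4,
)

# Annotate the responses from that run with the same caching behavior.
annotate(
    response_run_id=run_id,
    annotator_ids=["demo_annotator"],  # placeholder annotator id
    experiment="caching_demo",
    disable_cache=False,
    num_workers=4,
)

With caching left on, repeated runs over the same prompts reuse the responses stored under `.cache`, which the docker-compose change above also mounts into the container.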