Skip to content

Commit 703e4fb

Browse files
authored
Fix the CLI. (#16)
* Fix the CLI. * Documentation for CLI. * CLI tests. * Centralize CLI commands. * Also centralize plugin loading.
1 parent 9219dd5 commit 703e4fb

File tree

7 files changed

+185
-118
lines changed

7 files changed

+185
-118
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,25 @@ given docker-compose.yaml file will start up:
3838
1. The runs can be monitored in MLFlow wherever you have that set up. If
3939
local with the default setup, http://localhost:8080.
4040
41+
## CLI
42+
43+
You can also interact with modelplane via the CLI. Run `poetry run modelplane --help`
44+
for more details.
45+
46+
*Important:* You must set the `MLFLOW_TRACKING_URI` environment variable.
47+
For example, if you've brought up MLFlow using the docker compose process above,
48+
you could run:
49+
```
50+
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane get-sut-responses --sut_id {sut_id} --prompts tests/data/prompts.csv --experiment expname
51+
```
52+
After running the command, you'll see the `run_id` in the output from MLFlow,
53+
or you can get the `run_id` via the MLFlow UI.
54+
55+
Then you can run annotations with:
56+
```
57+
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane annotate --annotator_id {annotator_id} --experiment expname --response_run_id {run_id}
58+
```
59+
4160
## TODO
4261
4362
- [ ] Scoring against ground truth (measurement runner functionality)

src/modelplane/runways/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from modelgauge.load_plugins import load_plugins
2+
3+
4+
load_plugins(disable_progress_bar=True)

src/modelplane/runways/annotator.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import tempfile
1111
from collections import defaultdict
1212

13-
import click
1413
import jsonlines
1514
import mlflow
1615
import mlflow.artifacts
@@ -19,7 +18,6 @@
1918

2019
from modelgauge.annotation_pipeline import ANNOTATOR_CSV_INPUT_COLUMNS
2120
from modelgauge.annotator_registry import ANNOTATORS
22-
from modelgauge.load_plugins import load_plugins
2321
from modelgauge.pipeline_runner import AnnotatorRunner
2422

2523
from modelplane.runways.utils import (
@@ -28,65 +26,6 @@
2826
is_debug_mode,
2927
setup_annotator_credentials,
3028
)
31-
from modelplane.utils.env import load_from_dotenv
32-
33-
load_plugins(disable_progress_bar=True)
34-
35-
36-
@click.command(name="annotate")
37-
@click.option(
38-
"--annotator_id",
39-
type=str,
40-
required=True,
41-
help="The SUT UID to use.",
42-
)
43-
@click.option(
44-
"--experiment",
45-
type=str,
46-
required=True,
47-
help="The experiment name to use. If the experiment does not exist, it will be created.",
48-
)
49-
@click.option(
50-
"--response_run_id",
51-
type=str,
52-
required=True,
53-
help="The run ID corresponding to the responses to annotate.",
54-
)
55-
@click.option(
56-
"--overwrite",
57-
is_flag=True,
58-
default=False,
59-
help="Use the response_run_id to save annotation artifact. Any existing annotation artifact will be overwritten. If not set, a new run will be created.",
60-
)
61-
@click.option(
62-
"--cache_dir",
63-
type=str,
64-
default=None,
65-
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
66-
)
67-
@click.option(
68-
"--n_jobs",
69-
type=int,
70-
default=1,
71-
help="The number of jobs to run in parallel. Defaults to 1.",
72-
)
73-
@load_from_dotenv
74-
def get_annotations(
75-
annotator_id: str,
76-
experiment: str,
77-
response_run_id: str,
78-
overwrite: bool = False,
79-
cache_dir: str | None = None,
80-
n_jobs: int = 1,
81-
):
82-
return annotate(
83-
annotator_id=annotator_id,
84-
experiment=experiment,
85-
response_run_id=response_run_id,
86-
overwrite=overwrite,
87-
cache_dir=cache_dir,
88-
n_jobs=n_jobs,
89-
)
9029

9130

9231
def annotate(

src/modelplane/runways/responder.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
import pathlib
44
import tempfile
55

6-
import click
76
import mlflow
87

9-
from modelgauge.load_plugins import load_plugins
108
from modelgauge.pipeline_runner import PromptRunner
119
from modelgauge.sut_registry import SUTS
1210

@@ -16,60 +14,6 @@
1614
is_debug_mode,
1715
setup_sut_credentials,
1816
)
19-
from modelplane.utils.env import load_from_dotenv
20-
21-
load_plugins(disable_progress_bar=True)
22-
23-
24-
@click.command(name="get-responses")
25-
@click.option(
26-
"--sut_id",
27-
type=str,
28-
required=True,
29-
help="The SUT UID to use.",
30-
)
31-
@click.option(
32-
"--prompts",
33-
type=str,
34-
required=True,
35-
help="The path to the input prompts file.",
36-
)
37-
@click.option(
38-
"--experiment",
39-
type=str,
40-
required=True,
41-
help="The experiment name to use. If the experiment does not exist, it will be created.",
42-
)
43-
@click.option(
44-
"--cache_dir",
45-
type=str,
46-
default=None,
47-
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
48-
)
49-
@click.option(
50-
"--n_jobs",
51-
type=int,
52-
default=1,
53-
help="The number of jobs to run in parallel. Defaults to 1.",
54-
)
55-
@load_from_dotenv
56-
def get_sut_responses(
57-
sut_id: str,
58-
prompts: str,
59-
experiment: str,
60-
cache_dir: str | None = None,
61-
n_jobs: int = 1,
62-
):
63-
"""
64-
Run the pipeline to get responses from SUTs.
65-
"""
66-
return respond(
67-
sut_id=sut_id,
68-
prompts=prompts,
69-
experiment=experiment,
70-
cache_dir=cache_dir,
71-
n_jobs=n_jobs,
72-
)
7317

7418

7519
def respond(

src/modelplane/runways/run.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import click
2+
3+
4+
from modelplane.runways.annotator import annotate
5+
from modelplane.runways.responder import respond
6+
from modelplane.utils.env import load_from_dotenv
7+
8+
9+
@click.group(name="modelplane")
10+
def cli():
11+
pass
12+
13+
14+
@cli.command(name="get-sut-responses")
15+
@click.option(
16+
"--sut_id",
17+
type=str,
18+
required=True,
19+
help="The SUT UID to use.",
20+
)
21+
@click.option(
22+
"--prompts",
23+
type=str,
24+
required=True,
25+
help="The path to the input prompts file.",
26+
)
27+
@click.option(
28+
"--experiment",
29+
type=str,
30+
required=True,
31+
help="The experiment name to use. If the experiment does not exist, it will be created.",
32+
)
33+
@click.option(
34+
"--cache_dir",
35+
type=str,
36+
default=None,
37+
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
38+
)
39+
@click.option(
40+
"--n_jobs",
41+
type=int,
42+
default=1,
43+
help="The number of jobs to run in parallel. Defaults to 1.",
44+
)
45+
@load_from_dotenv
46+
def get_sut_responses(
47+
sut_id: str,
48+
prompts: str,
49+
experiment: str,
50+
cache_dir: str | None = None,
51+
n_jobs: int = 1,
52+
):
53+
"""
54+
Run the pipeline to get responses from SUTs.
55+
"""
56+
return respond(
57+
sut_id=sut_id,
58+
prompts=prompts,
59+
experiment=experiment,
60+
cache_dir=cache_dir,
61+
n_jobs=n_jobs,
62+
)
63+
64+
65+
@cli.command(name="annotate")
66+
@click.option(
67+
"--annotator_id",
68+
type=str,
69+
required=True,
70+
help="The annotator UID to use.",
71+
)
72+
@click.option(
73+
"--experiment",
74+
type=str,
75+
required=True,
76+
help="The experiment name to use. If the experiment does not exist, it will be created.",
77+
)
78+
@click.option(
79+
"--response_run_id",
80+
type=str,
81+
required=True,
82+
help="The run ID corresponding to the responses to annotate.",
83+
)
84+
@click.option(
85+
"--overwrite",
86+
is_flag=True,
87+
default=False,
88+
help="Use the response_run_id to save annotation artifact. Any existing annotation artifact will be overwritten. If not set, a new run will be created.",
89+
)
90+
@click.option(
91+
"--cache_dir",
92+
type=str,
93+
default=None,
94+
help="The cache directory. Defaults to None. Local directory used to cache LLM responses.",
95+
)
96+
@click.option(
97+
"--n_jobs",
98+
type=int,
99+
default=1,
100+
help="The number of jobs to run in parallel. Defaults to 1.",
101+
)
102+
@load_from_dotenv
103+
def get_annotations(
104+
annotator_id: str,
105+
experiment: str,
106+
response_run_id: str,
107+
overwrite: bool = False,
108+
cache_dir: str | None = None,
109+
n_jobs: int = 1,
110+
):
111+
return annotate(
112+
annotator_id=annotator_id,
113+
experiment=experiment,
114+
response_run_id=response_run_id,
115+
overwrite=overwrite,
116+
cache_dir=cache_dir,
117+
n_jobs=n_jobs,
118+
)
119+
120+
121+
if __name__ == "__main__":
122+
cli()

tests/it/test_cli.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from click.testing import CliRunner
2+
3+
from modelplane.runways.run import cli
4+
5+
6+
def test_main_help():
7+
runner = CliRunner()
8+
result = runner.invoke(
9+
cli,
10+
[
11+
"--help",
12+
],
13+
)
14+
assert result.exit_code == 0
15+
assert "get-sut-responses" in result.output
16+
assert "annotate" in result.output
17+
18+
19+
def test_get_sut_responses_help():
20+
runner = CliRunner()
21+
result = runner.invoke(
22+
cli,
23+
[
24+
"get-sut-responses",
25+
"--help",
26+
],
27+
)
28+
assert result.exit_code == 0
29+
30+
31+
def test_annotate_help():
32+
runner = CliRunner()
33+
result = runner.invoke(
34+
cli,
35+
[
36+
"annotate",
37+
"--help",
38+
],
39+
)
40+
assert result.exit_code == 0

tests/it/test_health.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Ensures the mlflow tracking server is live.
22

3-
import mlflow
43
from modelplane.mlflow.health import tracking_server_is_live
54

65

0 commit comments

Comments
 (0)