-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
48 lines (39 loc) · 1.45 KB
/
main.py
File metadata and controls
48 lines (39 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import rich_click as click
from dotenv import load_dotenv
from src.llm_client import LLM_CLIENT_REGISTRY
# from src.llm_client import LLM_CLIENT_REGISTRY
from src.models.models import EvaluationConfig
from src.pipeline import EvaluationPipeline
# Load variables from a .env file into the process environment at import
# time, before the CLI command below is defined or run.
# NOTE(review): presumably this is how provider API keys reach the LLM
# clients when --api-key is omitted — confirm against src.llm_client.
load_dotenv()
# CLI entry point: collects the command-line options into an
# EvaluationConfig and runs the evaluation pipeline on the input CSV.
#
# NOTE: this is documented with comments rather than a docstring on purpose:
# click displays a command's docstring as its --help text, so adding one
# would change the CLI output.
@click.command()
@click.option(
    "--csv",
    required=True,
    type=click.Path(exists=True),
    help="Path to input CSV file",
)
@click.option(
    "--model-provider",
    default="mistral",
    type=click.Choice(list(LLM_CLIENT_REGISTRY.keys())),
    show_default=True,
    help="Provider of the Large Language Model judge.",
)
@click.option(
    "--model-name",
    required=False,
    help="Name of the LLM model",
)
@click.option(
    "--temperature",
    default=0.0,
    type=float,
    show_default=True,
    help="LLM temperature",
)
@click.option(
    "--seed",
    default=42,
    type=int,
    show_default=True,
    help="Random seed",
)
@click.option(
    "--api-key",
    required=False,
    help="API key for the LLM provider",
)
def main(csv, model_provider, model_name, temperature, seed, api_key):
    # Route every raw CLI value through EvaluationConfig first; the pipeline
    # below is fed from the config object, not from the raw arguments.
    # NOTE(review): presumably EvaluationConfig validates/normalises these
    # fields — confirm in src.models.models.
    settings = EvaluationConfig(
        csv=csv,
        model_provider=model_provider,
        model_name=model_name,
        temperature=temperature,
        seed=seed,
        api_key=api_key,
    )

    # Everything except the CSV path configures the pipeline itself; the CSV
    # path is handed to run() separately.
    judge_options = {
        "model_provider": settings.model_provider,
        "model_name": settings.model_name,
        "temperature": settings.temperature,
        "seed": settings.seed,
        "api_key": settings.api_key,
    }
    EvaluationPipeline(**judge_options).run(settings.csv)
# Invoke the click command only when this file is executed as a script;
# importing the module does not start the CLI.
if __name__ == "__main__":
    main()