@@ -2,8 +2,8 @@ |
2 | 2 | from os import listdir |
3 | 3 | from os.path import exists, join |
4 | 4 |
5 | | -import torch |
6 | 5 | import pandas as pd |
| 6 | +import torch |
7 | 7 | import typer |
8 | 8 | import yaml |
9 | 9 | from pydantic import ValidationError |
@@ -13,10 +13,11 @@ |
13 | 13 | from llmtune.finetune.lora import LoRAFinetune |
14 | 14 | from llmtune.inference.lora import LoRAInference |
15 | 15 | from llmtune.pydantic_models.config_model import Config |
| 16 | +from llmtune.qa.generics import LLMTestSuite, QaTestRegistry |
16 | 17 | from llmtune.ui.rich_ui import RichUI |
17 | 18 | from llmtune.utils.ablation_utils import generate_permutations |
18 | 19 | from llmtune.utils.save_utils import DirectoryHelper |
19 | | -from llmtune.qa.generics import QaTestRegistry, LLMTestSuite |
| 20 | + |
20 | 21 |
21 | 22 | hf_utils.logging.set_verbosity_error() |
22 | 23 | torch._logging.set_logs(all=logging.CRITICAL) |
@@ -82,13 +83,8 @@ def run_one_experiment(config: Config, config_path: str) -> None: |
82 | 83 | tests = QaTestRegistry.create_tests_from_list(llm_tests) |
83 | 84 | # TODO: Load results.csv |
84 | 85 | results_df = pd.read_csv(results_file_path) |
85 | | - prompts = results_df["prompt"].tolist() |
86 | | - ground_truths = results_df["ground_truth"].tolist() |
87 | | - model_preds = results_df["model_prediction"].tolist() |
88 | | - # TODO: Run Unit Tests |
89 | | - test_suite = LLMTestSuite(tests, prompts, ground_truths, model_preds) |
90 | | - # TODO: Save Unit Test Results |
91 | | - test_suite.save_test_results("unit_test_results.csv") |
| 86 | + test_suite = LLMTestSuite.from_csv(results_file_path, tests) |
| 87 | + test_suite.save_test_results(join(qa_path, "unit_test_results.csv")) |
92 | 88 |
93 | 89 |
94 | 90 | @app.command() |
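
Review note: `LLMTestSuite.from_csv` now owns both the CSV parsing and the suite construction, which makes the retained `results_df = pd.read_csv(results_file_path)` line (and its TODO) just above it redundant; it could be dropped in a follow-up. For reference, here is a minimal sketch of what the new classmethod presumably does, reconstructed from the deleted call site. The real implementation lives in `llmtune/qa/generics.py`; the column names and constructor signature are taken from the removed lines, everything else is an assumption:

```python
import pandas as pd


class LLMTestSuite:
    def __init__(self, tests, prompts, ground_truths, model_preds):
        self.tests = tests
        self.prompts = prompts
        self.ground_truths = ground_truths
        self.model_preds = model_preds

    @classmethod
    def from_csv(cls, file_path: str, tests: list) -> "LLMTestSuite":
        # Assumed behavior: read the inference results CSV and pull the
        # same three columns the deleted inline code extracted, then
        # construct the suite positionally as before.
        results_df = pd.read_csv(file_path)
        return cls(
            tests,
            results_df["prompt"].tolist(),
            results_df["ground_truth"].tolist(),
            results_df["model_prediction"].tolist(),
        )
```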