3 | 3 | from os.path import exists, join |
4 | 4 |
5 | 5 | import torch |
| 6 | +import pandas as pd |
6 | 7 | import typer |
7 | 8 | import yaml |
8 | 9 | from pydantic import ValidationError |
15 | 16 | from llmtune.ui.rich_ui import RichUI |
16 | 17 | from llmtune.utils.ablation_utils import generate_permutations |
17 | 18 | from llmtune.utils.save_utils import DirectoryHelper |
18 | | - |
| 19 | +from llmtune.qa.generics import LLMTestSuite, QaTestRegistry
19 | 20 |
20 | 21 | hf_utils.logging.set_verbosity_error() |
21 | 22 | torch._logging.set_logs(all=logging.CRITICAL) |
@@ -73,15 +74,21 @@ def run_one_experiment(config: Config, config_path: str) -> None: |
73 | 74 | else: |
74 | 75 | RichUI.inference_found(results_path) |
75 | 76 |
76 | | - # QA ------------------------------- |
77 | | - # RichUI.before_qa() |
78 | | - # qa_path = dir_helper.save_paths.qa |
79 | | - # if not exists(qa_path) or not listdir(qa_path): |
80 | | - # # TODO: Instantiate unit test classes |
81 | | - # # TODO: Load results.csv |
82 | | - # # TODO: Run Unit Tests |
83 | | - # # TODO: Save Unit Test Results |
84 | | - # pass |
| 77 | + RichUI.before_qa() |
| 78 | + qa_path = dir_helper.save_paths.qa |
| 79 | + if not exists(qa_path) or not listdir(qa_path): |
| 80 | + # Instantiate the QA test classes named in the config
| 81 | + llm_tests = config.qa.llm_tests if config.qa else []
| 82 | + tests = QaTestRegistry.create_tests_from_list(llm_tests) |
| 83 | + # Load the inference results saved by the step above
| 84 | + results_df = pd.read_csv(results_file_path) |
| 85 | + prompts = results_df["prompt"].tolist() |
| 86 | + ground_truths = results_df["ground_truth"].tolist() |
| 87 | + model_preds = results_df["model_prediction"].tolist() |
| 88 | + # Score every (prompt, ground truth, prediction) row with each test
| 89 | + test_suite = LLMTestSuite(tests, prompts, ground_truths, model_preds) |
| 90 | + # Write the per-test metrics into the QA directory checked above
| 91 | + test_suite.save_test_results(join(qa_path, "unit_test_results.csv"))
85 | 92 |
86 | 93 |
87 | 94 | @app.command() |
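
The hunk above leans on two names imported from `llmtune.qa.generics` whose definitions are not part of this diff. Below is a minimal sketch of the interfaces the new code appears to assume; the class and method names match the import and the call sites, but everything else (the registry decorator, the metric signature, the hypothetical `exact_match` test) is an illustrative assumption, not the package's actual implementation:

```python
# Sketch of the llmtune.qa.generics interfaces assumed by the diff above.
# Only QaTestRegistry.create_tests_from_list, the LLMTestSuite constructor,
# and save_test_results are exercised by the new code; the bodies here are
# assumptions, not the library's real implementation.
from abc import ABC, abstractmethod
from typing import Dict, List, Type

import pandas as pd


class LLMQaTest(ABC):
    """One named quality check over a (prompt, ground_truth, model_prediction) row."""

    @property
    @abstractmethod
    def test_name(self) -> str: ...

    @abstractmethod
    def get_metric(self, prompt: str, ground_truth: str, model_pred: str) -> float: ...


class QaTestRegistry:
    """Maps strings from the config's qa.llm_tests list to registered test classes."""

    registry: Dict[str, Type[LLMQaTest]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(test_cls: Type[LLMQaTest]) -> Type[LLMQaTest]:
            cls.registry[name] = test_cls
            return test_cls

        return decorator

    @classmethod
    def create_tests_from_list(cls, names: List[str]) -> List[LLMQaTest]:
        # One instance per configured test name; unknown names raise KeyError.
        return [cls.registry[name]() for name in names]


class LLMTestSuite:
    """Runs every test against each row of the inference results."""

    def __init__(
        self,
        tests: List[LLMQaTest],
        prompts: List[str],
        ground_truths: List[str],
        model_preds: List[str],
    ) -> None:
        self.tests = tests
        self.prompts = prompts
        self.ground_truths = ground_truths
        self.model_preds = model_preds

    def run_tests(self) -> Dict[str, List[float]]:
        return {
            test.test_name: [
                test.get_metric(prompt, truth, pred)
                for prompt, truth, pred in zip(self.prompts, self.ground_truths, self.model_preds)
            ]
            for test in self.tests
        }

    def save_test_results(self, path: str) -> None:
        # One column per test, one row per example.
        pd.DataFrame(self.run_tests()).to_csv(path, index=False)


@QaTestRegistry.register("exact_match")  # hypothetical test name, for illustration only
class ExactMatchTest(LLMQaTest):
    @property
    def test_name(self) -> str:
        return "exact_match"

    def get_metric(self, prompt: str, ground_truth: str, model_pred: str) -> float:
        return float(ground_truth.strip() == model_pred.strip())
```

Under these assumptions, a config listing `llm_tests: [exact_match]` would yield one `ExactMatchTest` instance, and `save_test_results` would write a CSV with one column per registered metric and one row per example from `results.csv`.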