Skip to content

Commit 0352c77

Browse files
authored
Merge pull request #131 from viveksingh-ctrl/integrate-main-qa
Integrate main qa
2 parents 999f7b2 + dc42fb0 commit 0352c77

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

llmtune/cli/toolkit.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from llmtune.finetune.lora import LoRAFinetune
1616
from llmtune.inference.lora import LoRAInference
1717
from llmtune.pydantic_models.config_model import Config
18+
from llmtune.qa.generics import LLMTestSuite, QaTestRegistry
1819
from llmtune.ui.rich_ui import RichUI
1920
from llmtune.utils.ablation_utils import generate_permutations
2021
from llmtune.utils.save_utils import DirectoryHelper
@@ -84,15 +85,13 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
8485
else:
8586
RichUI.results_found(results_path)
8687

87-
# QA -------------------------------
88-
# RichUI.before_qa()
89-
# qa_path = dir_helper.save_paths.qa
90-
# if not exists(qa_path) or not listdir(qa_path):
91-
# # TODO: Instantiate unit test classes
92-
# # TODO: Load results.csv
93-
# # TODO: Run Unit Tests
94-
# # TODO: Save Unit Test Results
95-
# pass
88+
RichUI.before_qa()
89+
qa_file_path = dir_helper.save_paths.qa_file
90+
if not qa_file_path.exists():
91+
llm_tests = config.qa.llm_tests
92+
tests = QaTestRegistry.create_tests_from_list(llm_tests)
93+
test_suite = LLMTestSuite.from_csv(results_file_path, tests)
94+
test_suite.save_test_results(qa_file_path)
9695

9796

9897
@app.command("run")

llmtune/qa/generics.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def __init__(
5050

5151
self.test_results = {}
5252

53+
@staticmethod
54+
def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
55+
results_df = pd.read_csv(file_path)
56+
prompts = results_df["prompt"].tolist()
57+
ground_truths = results_df["ground_truth"].tolist()
58+
model_preds = results_df["model_prediction"].tolist()
59+
return LLMTestSuite(tests, prompts, ground_truths, model_preds)
60+
5361
def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
5462
test_results = {}
5563
for test in zip(self.tests):

0 commit comments

Comments
 (0)