Skip to content

Commit 97a16d7

Browse files
updated the changes
1 parent dd8ef69 commit 97a16d7

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

llmtune/cli/toolkit.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from os import listdir
33
from os.path import exists, join
44

5-
import torch
65
import pandas as pd
6+
import torch
77
import typer
88
import yaml
99
from pydantic import ValidationError
@@ -13,10 +13,11 @@
1313
from llmtune.finetune.lora import LoRAFinetune
1414
from llmtune.inference.lora import LoRAInference
1515
from llmtune.pydantic_models.config_model import Config
16+
from llmtune.qa.generics import LLMTestSuite, QaTestRegistry
1617
from llmtune.ui.rich_ui import RichUI
1718
from llmtune.utils.ablation_utils import generate_permutations
1819
from llmtune.utils.save_utils import DirectoryHelper
19-
from llmtune.qa.generics import QaTestRegistry, LLMTestSuite
20+
2021

2122
hf_utils.logging.set_verbosity_error()
2223
torch._logging.set_logs(all=logging.CRITICAL)
@@ -82,13 +83,8 @@ def run_one_experiment(config: Config, config_path: str) -> None:
8283
tests = QaTestRegistry.create_tests_from_list(llm_tests)
8384
# TODO: Load results.csv
8485
results_df = pd.read_csv(results_file_path)
85-
prompts = results_df["prompt"].tolist()
86-
ground_truths = results_df["ground_truth"].tolist()
87-
model_preds = results_df["model_prediction"].tolist()
88-
# TODO: Run Unit Tests
89-
test_suite = LLMTestSuite(tests, prompts, ground_truths, model_preds)
90-
# TODO: Save Unit Test Results
91-
test_suite.save_test_results("unit_test_results.csv")
86+
test_suite = LLMTestSuite.from_csv(results_file_path, tests)
87+
test_suite.save_test_results(os.path.join(qa_path, "unit_test_results.csv"))
9288

9389

9490
@app.command()

llmtune/qa/generics.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def __init__(
5050

5151
self.test_results = {}
5252

53+
@staticmethod
54+
def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
55+
results_df = pd.read_csv(file_path)
56+
prompts = results_df["prompt"].tolist()
57+
ground_truths = results_df["ground_truth"].tolist()
58+
model_preds = results_df["model_prediction"].tolist()
59+
return LLMTestSuite(tests, prompts, ground_truths, model_preds)
60+
5361
def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
5462
test_results = {}
5563
for test in zip(self.tests):

0 commit comments

Comments
 (0)