Skip to content

Commit 67976f1

Browse files
committed
fixing llm QA methods + integration
1 parent def3450 commit 67976f1

File tree

4 files changed

+39
-23
lines changed

4 files changed

+39
-23
lines changed

llmtune/cli/toolkit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from llmtune.finetune.lora import LoRAFinetune
1616
from llmtune.inference.lora import LoRAInference
1717
from llmtune.pydantic_models.config_model import Config
18-
from llmtune.qa.generics import LLMTestSuite, QaTestRegistry
18+
from llmtune.qa.generics import LLMTestSuite
19+
from llmtune.qa.qa_tests import QaTestRegistry
1920
from llmtune.ui.rich_ui import RichUI
2021
from llmtune.utils.ablation_utils import generate_permutations
2122
from llmtune.utils.save_utils import DirectoryHelper
@@ -92,6 +93,7 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
9293
tests = QaTestRegistry.create_tests_from_list(llm_tests)
9394
test_suite = LLMTestSuite.from_csv(results_file_path, tests)
9495
test_suite.save_test_results(qa_file_path)
96+
test_suite.print_test_results()
9597

9698

9799
@app.command("run")

llmtune/pydantic_models/config_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,4 @@ class Config(BaseModel):
244244
lora: LoraConfig
245245
training: TrainingConfig
246246
inference: InferenceConfig
247+
qa: QaConfig

llmtune/qa/generics.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pathlib import Path
12
import statistics
23
from abc import ABC, abstractmethod
34
from typing import Dict, List, Union
@@ -18,23 +19,6 @@ def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[f
1819
pass
1920

2021

21-
class QaTestRegistry:
22-
registry = {}
23-
24-
@classmethod
25-
def register(cls, *names):
26-
def inner_wrapper(wrapped_class):
27-
for name in names:
28-
cls.registry[name] = wrapped_class
29-
return wrapped_class
30-
31-
return inner_wrapper
32-
33-
@classmethod
34-
def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
35-
return [cls.create_test(test) for test in test_names]
36-
37-
3822
class LLMTestSuite:
3923
def __init__(
4024
self,
@@ -51,11 +35,17 @@ def __init__(
5135
self._results = {}
5236

5337
@staticmethod
54-
def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
38+
def from_csv(
39+
file_path: str,
40+
tests: List[LLMQaTest],
41+
prompt_col: str = "Prompt",
42+
gold_col: str = "Ground Truth",
43+
pred_col="Predicted",
44+
) -> "LLMTestSuite":
5545
results_df = pd.read_csv(file_path)
56-
prompts = results_df["prompt"].tolist()
57-
ground_truths = results_df["ground_truth"].tolist()
58-
model_preds = results_df["model_prediction"].tolist()
46+
prompts = results_df[prompt_col].tolist()
47+
ground_truths = results_df[gold_col].tolist()
48+
model_preds = results_df[pred_col].tolist()
5949
return LLMTestSuite(tests, prompts, ground_truths, model_preds)
6050

6151
def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
@@ -84,5 +74,11 @@ def print_test_results(self):
8474

8575
def save_test_results(self, path: str):
8676
# TODO: save these!
77+
path = Path(path)
78+
dir = path.parent
79+
80+
if not dir.exists():
81+
dir.mkdir(parents=True, exist_ok=True)
82+
8783
resultant_dataframe = pd.DataFrame(self.test_results)
8884
resultant_dataframe.to_csv(path, index=False)

llmtune/qa/qa_tests.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from rouge_score import rouge_scorer
1010
from transformers import DistilBertModel, DistilBertTokenizer
1111

12-
from llmtune.qa.generics import LLMQaTest, QaTestRegistry
12+
from llmtune.qa.generics import LLMQaTest
1313

1414

1515
model_name = "distilbert-base-uncased"
@@ -21,6 +21,23 @@
2121
nltk.download("averaged_perceptron_tagger")
2222

2323

24+
class QaTestRegistry:
25+
registry = {}
26+
27+
@classmethod
28+
def register(cls, *names):
29+
def inner_wrapper(wrapped_class):
30+
for name in names:
31+
cls.registry[name] = wrapped_class
32+
return wrapped_class
33+
34+
return inner_wrapper
35+
36+
@classmethod
37+
def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
38+
return [cls.registry[test]() for test in test_names]
39+
40+
2441
@QaTestRegistry.register("summary_length")
2542
class LengthTest(LLMQaTest):
2643
@property

0 commit comments

Comments
 (0)