Commit c8b3d0b

Merge pull request #162 from georgian-io/unit-tests
[Unit Tests] Introducing Unit Tests
2 parents (79a6889 + 9b9ec34), commit c8b3d0b

21 files changed: +998, -21 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -47,3 +47,6 @@ ENV/
 env.bak/
 venv.bak/
 
+# Coverage Report
+.coverage
+/htmlcov

Makefile

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
+test-coverage:
+	pytest --cov=llmtune tests/
+
+fix-format:
+	ruff check --fix
+	ruff format
+
 build-release:
 	rm -rf dist
 	rm -rf build
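The new test-coverage target runs the suite in tests/ under pytest-cov with coverage collection scoped to the llmtune package, while fix-format chains ruff's autofix pass (ruff check --fix) with its formatter (ruff format), so a single make invocation handles both lint fixes and formatting.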

llmtune/qa/generics.py

Lines changed: 7 additions & 7 deletions
@@ -48,7 +48,7 @@ def __init__(
         self.ground_truths = ground_truths
         self.model_preds = model_preds
 
-        self.test_results = {}
+        self._results = {}
 
     @staticmethod
     def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
@@ -60,29 +60,29 @@ def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
 
     def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
         test_results = {}
-        for test in zip(self.tests):
+        for test in self.tests:
             metrics = []
             for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds):
                 metrics.append(test.get_metric(prompt, ground_truth, model_pred))
             test_results[test.test_name] = metrics
 
-        self.test_results = test_results
+        self._results = test_results
         return test_results
 
     @property
     def test_results(self):
-        return self.test_results if self.test_results else self.run_tests()
+        return self._results if self._results else self.run_tests()
 
     def print_test_results(self):
-        result_dictionary = self.test_results()
+        result_dictionary = self.test_results
         column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
         median_values = {key: statistics.median(column_data[key]) for key in column_data}
         stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
         # Use the RichUI class to display the table
-        RichUI.display_table(result_dictionary, mean_values, median_values, stdev_values)
+        RichUI.qa_display_table(result_dictionary, mean_values, median_values, stdev_values)
 
     def save_test_results(self, path: str):
         # TODO: save these!
-        resultant_dataframe = pd.DataFrame(self.test_results())
+        resultant_dataframe = pd.DataFrame(self.test_results)
         resultant_dataframe.to_csv(path, index=False)
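Two real bugs are fixed here beyond the storage rename. First, for test in zip(self.tests) yielded 1-tuples, so test.test_name and test.get_metric(...) would raise AttributeError; iterating the list directly resolves that. Second, an instance attribute cannot share a name with a property on the same class: assigning self.test_results = {} against a setter-less test_results property raises AttributeError, and a property body that reads self.test_results recurses into itself. Routing storage through the private _results attribute fixes both, and call sites now read test_results as a plain attribute instead of calling it. A minimal standalone sketch of the backing-attribute pattern (hypothetical Suite class, not the project's code):

class Suite:
    def __init__(self):
        # Backing attribute; the underscored name avoids shadowing
        # the read-only property below.
        self._results = {}

    def run_tests(self):
        self._results = {"summary_length": [3, 5, 4]}
        return self._results

    @property
    def test_results(self):
        # Lazily compute on first access; callers use no parentheses.
        return self._results if self._results else self.run_tests()


suite = Suite()
print(suite.test_results)  # {'summary_length': [3, 5, 4]}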

llmtune/qa/qa_tests.py

Lines changed: 12 additions & 12 deletions
@@ -9,7 +9,7 @@
 from rouge_score import rouge_scorer
 from transformers import DistilBertModel, DistilBertTokenizer
 
-from llmtune.qa.generics import LLMQaTest, TestRegistry
+from llmtune.qa.generics import LLMQaTest, QaTestRegistry
 
 
 model_name = "distilbert-base-uncased"
@@ -21,7 +21,7 @@
 nltk.download("averaged_perceptron_tagger")
 
 
-@TestRegistry.register("summary_length")
+@QaTestRegistry.register("summary_length")
 class LengthTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -31,7 +31,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         return abs(len(ground_truth) - len(model_prediction))
 
 
-@TestRegistry.register("jaccard_similarity")
+@QaTestRegistry.register("jaccard_similarity")
 class JaccardSimilarityTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -45,10 +45,10 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         union_size = len(set_ground_truth.union(set_model_prediction))
 
         similarity = intersection_size / union_size if union_size != 0 else 0
-        return similarity
+        return float(similarity)
 
 
-@TestRegistry.register("dot_product")
+@QaTestRegistry.register("dot_product")
 class DotProductSimilarityTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -64,10 +64,10 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         embedding_ground_truth = self._encode_sentence(ground_truth)
         embedding_model_prediction = self._encode_sentence(model_prediction)
         dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
-        return dot_product_similarity
+        return float(dot_product_similarity)
 
 
-@TestRegistry.register("rouge_score")
+@QaTestRegistry.register("rouge_score")
 class RougeScoreTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -79,7 +79,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         return float(scores["rouge1"].precision)
 
 
-@TestRegistry.register("word_overlap")
+@QaTestRegistry.register("word_overlap")
 class WordOverlapTest(LLMQaTest):
     @property
     def test_name(self) -> str:
@@ -100,7 +100,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
 
         common_words = words_model_prediction.intersection(words_ground_truth)
         overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
-        return overlap_percentage
+        return float(overlap_percentage)
 
 
 class PosCompositionTest(LLMQaTest):
@@ -112,7 +112,7 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
         return round(len(pos_words) / total_words, 2)
 
 
-@TestRegistry.register("verb_percent")
+@QaTestRegistry.register("verb_percent")
 class VerbPercent(PosCompositionTest):
     @property
     def test_name(self) -> str:
@@ -122,7 +122,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
         return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])
 
 
-@TestRegistry.register("adjective_percent")
+@QaTestRegistry.register("adjective_percent")
 class AdjectivePercent(PosCompositionTest):
     @property
     def test_name(self) -> str:
@@ -132,7 +132,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
         return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
 
 
-@TestRegistry.register("noun_percent")
+@QaTestRegistry.register("noun_percent")
 class NounPercent(PosCompositionTest):
     @property
     def test_name(self) -> str:
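The TestRegistry to QaTestRegistry rename is mechanical, but the pattern it names is the backbone of this file: each test class registers itself under a string key, so a suite can be assembled from names listed in a config. A minimal sketch of a decorator-based registry in that spirit (an assumed shape; the real implementation lives in llmtune/qa/generics.py and is not shown in this diff, and the _registry and get names here are hypothetical):

from typing import Dict, Type


class QaTestRegistry:
    # Maps a string key to the test class registered under it.
    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str):
        def wrapper(test_class: Type) -> Type:
            cls._registry[name] = test_class
            return test_class  # returned unchanged, so the class stays usable
        return wrapper

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._registry[name]


@QaTestRegistry.register("summary_length")
class LengthTest:
    pass


assert QaTestRegistry.get("summary_length") is LengthTest

The float(...) casts added alongside the rename are also worth keeping: np.dot returns a NumPy scalar, and normalizing every metric to a plain Python float presumably keeps the downstream statistics and CSV serialization uniform.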

llmtune/ui/rich_ui.py

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ def qa_found():
         pass
 
     @staticmethod
-    def qa_display_table(self, result_dictionary, mean_values, median_values, stdev_values):
+    def qa_display_table(result_dictionary, mean_values, median_values, stdev_values):
         # Create a table
         table = Table(show_header=True, header_style="bold", title="Test Results")
 
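Dropping self is a genuine fix rather than a style change: qa_display_table is decorated with @staticmethod, so Python passes no implicit first argument, and the stray self parameter would either swallow the first data argument or make the four-argument call fail with a missing-argument TypeError. A tiny illustration with a hypothetical Demo class:

class Demo:
    @staticmethod
    def broken(self, x):
        # On a staticmethod, 'self' is just an ordinary first parameter.
        return x

    @staticmethod
    def fixed(x):
        return x


print(Demo.broken(1, 2))  # 2; the 1 is silently consumed by 'self'
print(Demo.fixed(2))      # 2, as intended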

poetry.lock

Lines changed: 151 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 17 additions & 0 deletions
@@ -67,6 +67,9 @@ shellingham = "^1.5.4"
 
 [tool.poetry.group.dev.dependencies]
 ruff = "~0.3.5"
+pytest = "^8.1.1"
+pytest-cov = "^5.0.0"
+pytest-mock = "^3.14.0"
 
 [build-system]
 requires = ["poetry-core", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
@@ -92,4 +95,18 @@ indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "auto"
 
+[tool.coverage.run]
+omit = [
+    # Ignore UI for now as this might change quite often
+    "llmtune/ui/*",
+    "llmtune/utils/rich_print_utils.py"
+]
+
+[tool.coverage.report]
+skip_empty = true
+exclude_also = [
+    "pass",
+]
 
+[tool.pytest.ini_options]
+addopts = "--cov=llmtune --cov-report term-missing"
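With addopts set under [tool.pytest.ini_options], a bare pytest run already collects coverage and prints a term-missing report, so plain pytest and make test-coverage should behave the same. The [tool.coverage.run] omit list keeps the fast-changing UI layer out of the coverage denominator, and exclude_also = ["pass"] stops placeholder pass bodies (like the UI stubs above) from counting as uncovered lines.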

test_utils/__init__.py

Whitespace-only changes.
