Skip to content

Commit f9346c1

Browse files
authored
Add tests for SQLite and configure CI (#6)
* Add possibility to run embedding test without model
* Move DB client tests to submodule
* Refactor annotations table
* Refactor test cases table
* Tests for annotations table
* Fix embeddings test
* TestCases table test
* Requirements table tests
* Turn on foreign keys and CasesToAnnos table tests
* CasesToAnnos table - add insertion fact return
* AnnosToReqs table tests
* Semver implementation
* Check SQLite version
* Move torch dependencies to production group
* Switch logging level to warning
* Configure tests CI
* Fix coverage script
* Ignore errors for coverage
* Run tests without coverage
* Revert "Run tests without coverage" This reverts commit 1b895f6.
* Try avoid torch installation in uv
* Check only project files
* Do not download production dependencies on coverage report generation
* Fix wrong method call
* Install ruff
* Lint in CI
* Configure ruff linter
* Check formatting in CI
* Format files
* Add last empty lines to files
* Correct old SQLite error message
1 parent d063c1d commit f9346c1

36 files changed

+892
-161
lines changed

.github/workflows/test.yaml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
name: test
2+
3+
on: [push, workflow_dispatch]
4+
5+
jobs:
6+
test:
7+
runs-on: ubuntu-latest
8+
steps:
9+
- name: Checkout code
10+
uses: actions/checkout@v2
11+
12+
- name: Set up Python
13+
uses: actions/setup-python@v4
14+
with:
15+
python-version: '3.9'
16+
17+
- name: Setup uv
18+
uses: astral-sh/setup-uv@v5
19+
20+
- name: Install dependencies
21+
run: uv sync --no-group production
22+
23+
- name: Run tests with coverage
24+
run: |
25+
uv run --no-group production -m coverage run --source=test2text -m unittest discover tests
26+
uv run --no-group production -m coverage report --ignore-errors
27+
28+
- name: Lint
29+
run: uvx ruff check
30+
31+
- name: Check formatting
32+
run: uvx ruff check --select E --ignore "E402,E501" --fix

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.idea
2-
.venv
2+
.venv
3+
.coverage

CONTRIBUTING.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
## Utility commands
2+
3+
Run the same linter that CI uses, applying automatic fixes where possible:
4+
5+
```bash
6+
uvx ruff check --fix
7+
```
8+
9+
Automatically format code:
10+
11+
```bash
12+
uvx ruff format
13+
```

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ erDiagram
9797
`CasesToAnnos` table.
9898
5. Create a report on the coverage of the requirements by the test cases:
9999
- By running your own SQL queries against the `requirements.db` SQLite database.
100-
- Or by running `report.py` script.
100+
- Or by running `report.py` script.

index_annotations.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@
99

1010
BATCH_SIZE = 100
1111

12-
if __name__ == '__main__':
13-
db = DbClient('./private/requirements.db')
14-
annotations_folder = Path('./private/annotations')
12+
if __name__ == "__main__":
13+
db = DbClient("./private/requirements.db")
14+
annotations_folder = Path("./private/annotations")
1515
# Write annotations to the database
1616
for i, file in enumerate(os.listdir(annotations_folder)):
17-
logging.info(f'Processing file {i + 1}: {file}')
18-
with open(annotations_folder / file, newline='', encoding='utf-8', mode='r') as csvfile:
17+
logging.info(f"Processing file {i + 1}: {file}")
18+
with open(
19+
annotations_folder / file, newline="", encoding="utf-8", mode="r"
20+
) as csvfile:
1921
reader = csv.reader(csvfile)
2022
for row in reader:
2123
[summary, _, test_script, test_case, *_] = row
22-
anno_id = db.annotations.insert(summary=summary)
23-
tc_id = db.test_cases.insert(test_script=test_script, test_case=test_case)
24+
anno_id = db.annotations.get_or_insert(summary=summary)
25+
tc_id = db.test_cases.get_or_insert(
26+
test_script=test_script, test_case=test_case
27+
)
2428
db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id)
2529
db.conn.commit()
2630
# Embed annotations
@@ -38,17 +42,20 @@ def write_batch():
3842
embeddings = embed_annotations_batch([annotation for _, annotation in batch])
3943
for i, (anno_id, annotation) in enumerate(batch):
4044
embedding = embeddings[i]
41-
db.conn.execute("""
45+
db.conn.execute(
46+
"""
4247
UPDATE Annotations
4348
SET embedding = ?
4449
WHERE id = ?
45-
""", (embedding, anno_id))
50+
""",
51+
(embedding, anno_id),
52+
)
4653
db.conn.commit()
4754
batch = []
4855

4956
for i, (anno_id, summary) in enumerate(annotations.fetchall()):
5057
if i % 100 == 0:
51-
logging.info(f'Processing annotation {i + 1}/{annotations_count}')
58+
logging.info(f"Processing annotation {i + 1}/{annotations_count}")
5259
batch.append((anno_id, summary))
5360
if len(batch) == BATCH_SIZE:
5461
write_batch()
@@ -57,4 +64,4 @@ def write_batch():
5764
cursor = db.conn.execute("""
5865
SELECT COUNT(*) FROM Annotations
5966
""")
60-
print(cursor.fetchone())
67+
print(cursor.fetchone())

index_requirements.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,38 @@
11
import csv
22
import logging
3+
34
logging.basicConfig(level=logging.DEBUG)
45
from test2text.db import DbClient
56
from test2text.embeddings.embed import embed_requirements_batch
67

78
BATCH_SIZE = 100
89

9-
if __name__ == '__main__':
10-
db = DbClient('./private/requirements.db')
10+
if __name__ == "__main__":
11+
db = DbClient("./private/requirements.db")
1112
# Index requirements
12-
with open('./private/TRACEABILITY MATRIX.csv', newline='', encoding='utf-8', mode='r') as csvfile:
13+
with open(
14+
"./private/TRACEABILITY MATRIX.csv", newline="", encoding="utf-8", mode="r"
15+
) as csvfile:
1316
reader = csv.reader(csvfile)
1417
for _ in range(3):
1518
next(reader)
1619
batch = []
17-
last_requirement = ''
20+
last_requirement = ""
21+
1822
def write_batch():
1923
global batch
20-
embeddings = embed_requirements_batch([requirement for _, requirement in batch])
24+
embeddings = embed_requirements_batch(
25+
[requirement for _, requirement in batch]
26+
)
2127
for i, (external_id, requirement) in enumerate(batch):
2228
embedding = embeddings[i]
2329
db.requirements.insert(requirement, embedding, external_id)
2430
db.conn.commit()
2531
batch = []
32+
2633
for row in reader:
2734
[external_id, requirement, *_] = row
28-
if requirement.startswith('...'):
35+
if requirement.startswith("..."):
2936
requirement = last_requirement + requirement[3:]
3037
last_requirement = requirement
3138
batch.append((external_id, requirement))
@@ -36,4 +43,4 @@ def write_batch():
3643
cursor = db.conn.execute("""
3744
SELECT COUNT(*) FROM Requirements
3845
""")
39-
print(cursor.fetchone())
46+
print(cursor.fetchone())

link_reqs_to_annos.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import logging
2+
23
logging.basicConfig(level=logging.INFO)
34
logger = logging.getLogger()
45
from test2text.db import DbClient
56

6-
if __name__ == '__main__':
7-
db = DbClient('./private/requirements.db')
7+
if __name__ == "__main__":
8+
db = DbClient("./private/requirements.db")
89
db.annos_to_reqs.recreate_table()
910
# Link requirements to annotations
1011
annotations = db.conn.execute("""
@@ -17,7 +18,7 @@
1718
""")
1819
# Visualize distances
1920
distances = []
20-
logger.info('Processing distances')
21+
logger.info("Processing distances")
2122
current_req_id = None
2223
current_req_annos = 0
2324
for i, (anno_id, req_id, distance) in enumerate(annotations.fetchall()):
@@ -26,9 +27,12 @@
2627
current_req_id = req_id
2728
current_req_annos = 0
2829
if current_req_annos < 5 or distance < 0.7:
29-
db.annos_to_reqs.insert(annotation_id=anno_id, requirement_id=req_id, cached_distance=distance)
30+
db.annos_to_reqs.insert(
31+
annotation_id=anno_id, requirement_id=req_id, cached_distance=distance
32+
)
3033
current_req_annos += 1
3134
db.conn.commit()
3235
import matplotlib.pyplot as plt
36+
3337
plt.hist(distances, bins=100)
34-
plt.savefig('./private/distances.png')
38+
plt.savefig("./private/distances.png")

pyproject.toml

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,28 @@ authors = [
88
readme = "README.md"
99
requires-python = ">=3.9"
1010
dependencies = [
11-
"einops>=0.8.1",
1211
"matplotlib>=3.9.4",
13-
"sentence-transformers>=4.0.1",
1412
"sqlite-vec>=0.1.6",
15-
"tabbyset>=1.0.0",
16-
"torch",
13+
"tabbyset>=1.0.0"
1714
]
1815

16+
[dependency-groups]
17+
dev = [
18+
"coverage>=7.9.2",
19+
"ruff>=0.12.3",
20+
]
21+
production = [
22+
"einops>=0.8.1",
23+
"sentence-transformers>=4.0.1",
24+
"torch"
25+
]
26+
27+
[tool.uv]
28+
default-groups = "all"
29+
30+
[tool.ruff.lint]
31+
ignore = ["E402"]
32+
1933
[tool.uv.sources]
2034
torch = {index = "pytorch-cpu"}
2135

report.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from sympy import limit
2-
31
from test2text.db import DbClient
42
from tqdm import tqdm
53

4+
65
def add_new_line(summary):
7-
return summary.replace('\n', '<br>')
6+
return summary.replace("\n", "<br>")
7+
88

9-
if __name__ == '__main__':
10-
with open('./private/report.html', 'w', newline='', encoding='utf-8') as f:
9+
if __name__ == "__main__":
10+
with open("./private/report.html", "w", newline="", encoding="utf-8") as f:
1111
f.write("""
1212
<html>
1313
<head>
@@ -18,20 +18,32 @@ def add_new_line(summary):
1818
</head>
1919
<body>
2020
<main style="padding: 1rem;">
21-
<article class="prose prose-sm container" style="max-width: 48rem; margin: 0 auto;">
21+
<article class="prose prose-sm container"
22+
style="max-width: 48rem; margin: 0 auto;">
2223
""")
2324

24-
db = DbClient('./private/requirements.db')
25-
all_reqs_count = db.conn.execute('SELECT COUNT(*) FROM Requirements').fetchone()[0]
25+
db = DbClient("./private/requirements.db")
26+
all_reqs_count = db.conn.execute(
27+
"SELECT COUNT(*) FROM Requirements"
28+
).fetchone()[0]
2629

2730
f.write('<nav style="break-after: page;"><h1>Table of Contents</h1><ul>')
2831

29-
for requirement in tqdm(db.conn.execute('SELECT * FROM Requirements').fetchall(),
30-
desc='Generating table of contents', unit='requirements', total=all_reqs_count):
32+
for requirement in tqdm(
33+
db.conn.execute("SELECT * FROM Requirements").fetchall(),
34+
desc="Generating table of contents",
35+
unit="requirements",
36+
total=all_reqs_count,
37+
):
3138
req_id, req_external_id, req_summary, _ = requirement
32-
f.write(f'<li><a href="#req_{req_id}">Requirement {req_external_id} ({req_id})</a></li>')
39+
f.write(f"""
40+
<li>
41+
<a href="#req_{req_id}">
42+
Requirement {req_external_id} ({req_id})
43+
</a>
44+
</li>""")
3345

34-
f.write('</ul></nav>')
46+
f.write("</ul></nav>")
3547

3648
data = db.conn.execute("""
3749
SELECT
@@ -60,9 +72,12 @@ def add_new_line(summary):
6072
current_req_id = None
6173
current_annotations = {}
6274
current_test_scripts = set()
63-
progress_bar = tqdm(total=all_reqs_count, desc='Generating report', unit='requirements')
75+
progress_bar = tqdm(
76+
total=all_reqs_count, desc="Generating report", unit="requirements"
77+
)
6478

6579
written_count = 0
80+
6681
def write_requirement():
6782
global written_count
6883
# if written_count > 5:
@@ -77,15 +92,26 @@ def write_requirement():
7792
""")
7893
for anno_id, (anno_summary, distance) in current_annotations.items():
7994
f.write(
80-
f'<li>Annotation {anno_id} (distance: {distance:.3f}): <p>{add_new_line(anno_summary)}</p></li>')
81-
f.write('</ul>')
82-
f.write('<h3>Test Scripts</h3><ul>')
95+
f"<li>Annotation {anno_id} (distance: {distance:.3f}): <p>{add_new_line(anno_summary)}</p></li>"
96+
)
97+
f.write("</ul>")
98+
f.write("<h3>Test Scripts</h3><ul>")
8399
for test_script in current_test_scripts:
84100
f.write(f"<li>{test_script}</li>")
85-
f.write('</ul></section>')
101+
f.write("</ul></section>")
86102

87103
for row in data.fetchall():
88-
req_id, req_external_id, req_summary, anno_id, anno_summary, distance, case_id, test_script, test_case = row
104+
(
105+
req_id,
106+
req_external_id,
107+
req_summary,
108+
anno_id,
109+
anno_summary,
110+
distance,
111+
case_id,
112+
test_script,
113+
test_case,
114+
) = row
89115
if req_id != current_req_id:
90116
if current_req_id is not None:
91117
write_requirement()
@@ -102,4 +128,4 @@ def write_requirement():
102128
</main>
103129
</body>
104130
</html>
105-
""")
131+
""")

test2text/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .client import DbClient
1+
__all__ = ["DbClient"]
2+
from .client import DbClient

0 commit comments

Comments
 (0)