Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1d24291
Add possibility to run embedding test without model
d0rich Jul 14, 2025
0fba3d4
Move DB client tests to submodule
d0rich Jul 14, 2025
3111193
Refactor annotations table
d0rich Jul 14, 2025
0390572
Refactor test cases table
d0rich Jul 14, 2025
b8f0f2c
Tests for annotations table
d0rich Jul 14, 2025
dcf7a07
Fix embeddings test
d0rich Jul 14, 2025
27147f7
TestCases table test
d0rich Jul 14, 2025
54c0377
Requirements table tests
d0rich Jul 14, 2025
41d10d3
Turn on foreign keys and CasesToAnnos table tests
d0rich Jul 14, 2025
53af2cc
CasesToAnnos table - add insertion fact return
d0rich Jul 14, 2025
2f16167
AnnosToReqs table tests
d0rich Jul 14, 2025
303f1d5
Semver implementation
d0rich Jul 14, 2025
9d922ad
Check SQLite version
d0rich Jul 14, 2025
91198a4
Move torch dependencies to production group
d0rich Jul 14, 2025
01e5cd3
Switch logging level to warning
d0rich Jul 14, 2025
00e89dc
Configure tests CI
d0rich Jul 14, 2025
add8dbe
Fix coverage script
d0rich Jul 14, 2025
e7ce46c
Ignore errors for coverage
d0rich Jul 14, 2025
1b895f6
Run tests without coverage
d0rich Jul 14, 2025
3722da7
Revert "Run tests without coverage"
d0rich Jul 14, 2025
da6e666
Try avoid torch installation in uv
d0rich Jul 14, 2025
ac28546
Check only project files
d0rich Jul 14, 2025
f85136b
Do not download production dependencies on coverage report generation
d0rich Jul 14, 2025
8dbf74d
Fix wrong method call
d0rich Jul 14, 2025
0533c6d
Install ruff
d0rich Jul 14, 2025
32b8c84
Lint in CI
d0rich Jul 14, 2025
30d496a
Configure ruff linter
d0rich Jul 14, 2025
914be61
Check formatting in CI
d0rich Jul 14, 2025
87f1cc3
Format files
d0rich Jul 14, 2025
c8dc7a0
Add last empty lines to files
d0rich Jul 14, 2025
86bc984
Correct old SQLite error message
d0rich Jul 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: test

on: [push, workflow_dispatch]

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'

- name: Setup uv
uses: astral-sh/setup-uv@v5

- name: Install dependencies
run: uv sync --no-group production

- name: Run tests with coverage
run: |
uv run --no-group production -m coverage run --source=test2text -m unittest discover tests
uv run --no-group production -m coverage report --ignore-errors

- name: Lint
run: uvx ruff check

- name: Check formatting
run: uvx ruff check --select E --ignore "E402,E501" --fix
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.idea
.venv
.venv
.coverage
13 changes: 13 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
## Utilities commands

Run linter from CI with automatic fixes if possible:

```bash
uvx ruff check --fix
```

Automatically format code:

```bash
uvx ruff format
```
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ erDiagram
`CasesToAnnos` table.
5. Create report about the coverage of the requirements by the test cases:
- By running your SQL queries in `requirements.db` SQLite database.
- Or by running `report.py` script.
- Or by running `report.py` script.
29 changes: 18 additions & 11 deletions index_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@

BATCH_SIZE = 100

if __name__ == '__main__':
db = DbClient('./private/requirements.db')
annotations_folder = Path('./private/annotations')
if __name__ == "__main__":
db = DbClient("./private/requirements.db")
annotations_folder = Path("./private/annotations")
# Write annotations to the database
for i, file in enumerate(os.listdir(annotations_folder)):
logging.info(f'Processing file {i + 1}: {file}')
with open(annotations_folder / file, newline='', encoding='utf-8', mode='r') as csvfile:
logging.info(f"Processing file {i + 1}: {file}")
with open(
annotations_folder / file, newline="", encoding="utf-8", mode="r"
) as csvfile:
reader = csv.reader(csvfile)
for row in reader:
[summary, _, test_script, test_case, *_] = row
anno_id = db.annotations.insert(summary=summary)
tc_id = db.test_cases.insert(test_script=test_script, test_case=test_case)
anno_id = db.annotations.get_or_insert(summary=summary)
tc_id = db.test_cases.get_or_insert(
test_script=test_script, test_case=test_case
)
db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id)
db.conn.commit()
# Embed annotations
Expand All @@ -38,17 +42,20 @@ def write_batch():
embeddings = embed_annotations_batch([annotation for _, annotation in batch])
for i, (anno_id, annotation) in enumerate(batch):
embedding = embeddings[i]
db.conn.execute("""
db.conn.execute(
"""
UPDATE Annotations
SET embedding = ?
WHERE id = ?
""", (embedding, anno_id))
""",
(embedding, anno_id),
)
db.conn.commit()
batch = []

for i, (anno_id, summary) in enumerate(annotations.fetchall()):
if i % 100 == 0:
logging.info(f'Processing annotation {i + 1}/{annotations_count}')
logging.info(f"Processing annotation {i + 1}/{annotations_count}")
batch.append((anno_id, summary))
if len(batch) == BATCH_SIZE:
write_batch()
Expand All @@ -57,4 +64,4 @@ def write_batch():
cursor = db.conn.execute("""
SELECT COUNT(*) FROM Annotations
""")
print(cursor.fetchone())
print(cursor.fetchone())
21 changes: 14 additions & 7 deletions index_requirements.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,38 @@
import csv
import logging

logging.basicConfig(level=logging.DEBUG)
from test2text.db import DbClient
from test2text.embeddings.embed import embed_requirements_batch

BATCH_SIZE = 100

if __name__ == '__main__':
db = DbClient('./private/requirements.db')
if __name__ == "__main__":
db = DbClient("./private/requirements.db")
# Index requirements
with open('./private/TRACEABILITY MATRIX.csv', newline='', encoding='utf-8', mode='r') as csvfile:
with open(
"./private/TRACEABILITY MATRIX.csv", newline="", encoding="utf-8", mode="r"
) as csvfile:
reader = csv.reader(csvfile)
for _ in range(3):
next(reader)
batch = []
last_requirement = ''
last_requirement = ""

def write_batch():
global batch
embeddings = embed_requirements_batch([requirement for _, requirement in batch])
embeddings = embed_requirements_batch(
[requirement for _, requirement in batch]
)
for i, (external_id, requirement) in enumerate(batch):
embedding = embeddings[i]
db.requirements.insert(requirement, embedding, external_id)
db.conn.commit()
batch = []

for row in reader:
[external_id, requirement, *_] = row
if requirement.startswith('...'):
if requirement.startswith("..."):
requirement = last_requirement + requirement[3:]
last_requirement = requirement
batch.append((external_id, requirement))
Expand All @@ -36,4 +43,4 @@ def write_batch():
cursor = db.conn.execute("""
SELECT COUNT(*) FROM Requirements
""")
print(cursor.fetchone())
print(cursor.fetchone())
14 changes: 9 additions & 5 deletions link_reqs_to_annos.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
from test2text.db import DbClient

if __name__ == '__main__':
db = DbClient('./private/requirements.db')
if __name__ == "__main__":
db = DbClient("./private/requirements.db")
db.annos_to_reqs.recreate_table()
# Link requirements to annotations
annotations = db.conn.execute("""
Expand All @@ -17,7 +18,7 @@
""")
# Visualize distances
distances = []
logger.info('Processing distances')
logger.info("Processing distances")
current_req_id = None
current_req_annos = 0
for i, (anno_id, req_id, distance) in enumerate(annotations.fetchall()):
Expand All @@ -26,9 +27,12 @@
current_req_id = req_id
current_req_annos = 0
if current_req_annos < 5 or distance < 0.7:
db.annos_to_reqs.insert(annotation_id=anno_id, requirement_id=req_id, cached_distance=distance)
db.annos_to_reqs.insert(
annotation_id=anno_id, requirement_id=req_id, cached_distance=distance
)
current_req_annos += 1
db.conn.commit()
import matplotlib.pyplot as plt

plt.hist(distances, bins=100)
plt.savefig('./private/distances.png')
plt.savefig("./private/distances.png")
22 changes: 18 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,28 @@ authors = [
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"einops>=0.8.1",
"matplotlib>=3.9.4",
"sentence-transformers>=4.0.1",
"sqlite-vec>=0.1.6",
"tabbyset>=1.0.0",
"torch",
"tabbyset>=1.0.0"
]

[dependency-groups]
dev = [
"coverage>=7.9.2",
"ruff>=0.12.3",
]
production = [
"einops>=0.8.1",
"sentence-transformers>=4.0.1",
"torch"
]

[tool.uv]
default-groups = "all"

[tool.ruff.lint]
ignore = ["E402"]

[tool.uv.sources]
torch = {index = "pytorch-cpu"}

Expand Down
64 changes: 45 additions & 19 deletions report.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from sympy import limit

from test2text.db import DbClient
from tqdm import tqdm


def add_new_line(summary):
return summary.replace('\n', '<br>')
return summary.replace("\n", "<br>")


if __name__ == '__main__':
with open('./private/report.html', 'w', newline='', encoding='utf-8') as f:
if __name__ == "__main__":
with open("./private/report.html", "w", newline="", encoding="utf-8") as f:
f.write("""
<html>
<head>
Expand All @@ -18,20 +18,32 @@ def add_new_line(summary):
</head>
<body>
<main style="padding: 1rem;">
<article class="prose prose-sm container" style="max-width: 48rem; margin: 0 auto;">
<article class="prose prose-sm container"
style="max-width: 48rem; margin: 0 auto;">
""")

db = DbClient('./private/requirements.db')
all_reqs_count = db.conn.execute('SELECT COUNT(*) FROM Requirements').fetchone()[0]
db = DbClient("./private/requirements.db")
all_reqs_count = db.conn.execute(
"SELECT COUNT(*) FROM Requirements"
).fetchone()[0]

f.write('<nav style="break-after: page;"><h1>Table of Contents</h1><ul>')

for requirement in tqdm(db.conn.execute('SELECT * FROM Requirements').fetchall(),
desc='Generating table of contents', unit='requirements', total=all_reqs_count):
for requirement in tqdm(
db.conn.execute("SELECT * FROM Requirements").fetchall(),
desc="Generating table of contents",
unit="requirements",
total=all_reqs_count,
):
req_id, req_external_id, req_summary, _ = requirement
f.write(f'<li><a href="#req_{req_id}">Requirement {req_external_id} ({req_id})</a></li>')
f.write(f"""
<li>
<a href="#req_{req_id}">
Requirement {req_external_id} ({req_id})
</a>
</li>""")

f.write('</ul></nav>')
f.write("</ul></nav>")

data = db.conn.execute("""
SELECT
Expand Down Expand Up @@ -60,9 +72,12 @@ def add_new_line(summary):
current_req_id = None
current_annotations = {}
current_test_scripts = set()
progress_bar = tqdm(total=all_reqs_count, desc='Generating report', unit='requirements')
progress_bar = tqdm(
total=all_reqs_count, desc="Generating report", unit="requirements"
)

written_count = 0

def write_requirement():
global written_count
# if written_count > 5:
Expand All @@ -77,15 +92,26 @@ def write_requirement():
""")
for anno_id, (anno_summary, distance) in current_annotations.items():
f.write(
f'<li>Annotation {anno_id} (distance: {distance:.3f}): <p>{add_new_line(anno_summary)}</p></li>')
f.write('</ul>')
f.write('<h3>Test Scripts</h3><ul>')
f"<li>Annotation {anno_id} (distance: {distance:.3f}): <p>{add_new_line(anno_summary)}</p></li>"
)
f.write("</ul>")
f.write("<h3>Test Scripts</h3><ul>")
for test_script in current_test_scripts:
f.write(f"<li>{test_script}</li>")
f.write('</ul></section>')
f.write("</ul></section>")

for row in data.fetchall():
req_id, req_external_id, req_summary, anno_id, anno_summary, distance, case_id, test_script, test_case = row
(
req_id,
req_external_id,
req_summary,
anno_id,
anno_summary,
distance,
case_id,
test_script,
test_case,
) = row
if req_id != current_req_id:
if current_req_id is not None:
write_requirement()
Expand All @@ -102,4 +128,4 @@ def write_requirement():
</main>
</body>
</html>
""")
""")
3 changes: 2 additions & 1 deletion test2text/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import DbClient
__all__ = ["DbClient"]
from .client import DbClient
Loading