Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1d24291
Add possibility to run embedding test without model
d0rich Jul 14, 2025
0fba3d4
Move DB client tests to submodule
d0rich Jul 14, 2025
3111193
Refactor annotations table
d0rich Jul 14, 2025
0390572
Refactor test cases table
d0rich Jul 14, 2025
b8f0f2c
Tests for annotations table
d0rich Jul 14, 2025
dcf7a07
Fix embeddings test
d0rich Jul 14, 2025
27147f7
TestCases table test
d0rich Jul 14, 2025
54c0377
Requirements table tests
d0rich Jul 14, 2025
41d10d3
Turn on foreign keys and CasesToAnnos table tests
d0rich Jul 14, 2025
53af2cc
CasesToAnnos table - add insertion fact return
d0rich Jul 14, 2025
2f16167
AnnosToReqs table tests
d0rich Jul 14, 2025
303f1d5
Semver implementation
d0rich Jul 14, 2025
9d922ad
Check SQLite version
d0rich Jul 14, 2025
91198a4
Move torch dependencies to production group
d0rich Jul 14, 2025
01e5cd3
Switch logging level to warning
d0rich Jul 14, 2025
00e89dc
Configure tests CI
d0rich Jul 14, 2025
add8dbe
Fix coverage script
d0rich Jul 14, 2025
e7ce46c
Ignore errors for coverage
d0rich Jul 14, 2025
1b895f6
Run tests without coverage
d0rich Jul 14, 2025
3722da7
Revert "Run tests without coverage"
d0rich Jul 14, 2025
da6e666
Try avoid torch installation in uv
d0rich Jul 14, 2025
ac28546
Check only project files
d0rich Jul 14, 2025
f85136b
Do not download production dependencies on coverage report generation
d0rich Jul 14, 2025
8dbf74d
Fix wrong method call
d0rich Jul 14, 2025
0533c6d
Install ruff
d0rich Jul 14, 2025
32b8c84
Lint in CI
d0rich Jul 14, 2025
30d496a
Configure ruff linter
d0rich Jul 14, 2025
914be61
Check formatting in CI
d0rich Jul 14, 2025
87f1cc3
Format files
d0rich Jul 14, 2025
c8dc7a0
Add last empty lines to files
d0rich Jul 14, 2025
86bc984
Correct old SQLite error message
d0rich Jul 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# CI workflow: runs the unit-test suite with coverage on every push and on
# manual dispatch.
name: test

on: [push, workflow_dispatch]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        # v2 targets the deprecated Node 12 action runtime; v4 is the
        # currently supported major version.
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.9'

      - name: Setup uv
        uses: astral-sh/setup-uv@v5

      - name: Install dependencies
        # The "production" dependency group is left out of the CI environment.
        run: uv sync --no-group production

      - name: Run tests with coverage
        run: uv run --no-group production -m coverage run --source=test2text -m unittest discover tests

      - name: Display coverage report
        run: uv run --no-group production -m coverage report --ignore-errors
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.idea
.venv
.venv
.coverage
4 changes: 2 additions & 2 deletions index_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
reader = csv.reader(csvfile)
for row in reader:
[summary, _, test_script, test_case, *_] = row
anno_id = db.annotations.insert(summary=summary)
tc_id = db.test_cases.insert(test_script=test_script, test_case=test_case)
anno_id = db.annotations.get_or_insert(summary=summary)
tc_id = db.test_cases.get_or_insert(test_script=test_script, test_case=test_case)
db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id)
db.conn.commit()
# Embed annotations
Expand Down
2 changes: 1 addition & 1 deletion index_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def write_batch():
embeddings = embed_requirements_batch([requirement for _, requirement in batch])
for i, (external_id, requirement) in enumerate(batch):
embedding = embeddings[i]
db.requirements.insert(requirement, embedding, external_id)
db.requirements.get_or_insert(requirement, embedding, external_id)
db.conn.commit()
batch = []
for row in reader:
Expand Down
18 changes: 14 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,23 @@ authors = [
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"einops>=0.8.1",
"matplotlib>=3.9.4",
"sentence-transformers>=4.0.1",
"sqlite-vec>=0.1.6",
"tabbyset>=1.0.0",
"torch",
"tabbyset>=1.0.0"
]

[dependency-groups]
dev = [
"coverage>=7.9.2",
]
production = [
"einops>=0.8.1",
"sentence-transformers>=4.0.1",
"torch"
]

[tool.uv]
default-groups = "all"

[tool.uv.sources]
torch = {index = "pytorch-cpu"}
Expand Down
17 changes: 17 additions & 0 deletions test2text/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sqlite_vec
import logging

from test2text.utils.semver import Semver
from .tables import RequirementsTable, AnnotationsTable, AnnotationsToRequirementsTable, TestCasesTable, TestCasesToAnnotationsTable
from ..utils.path import PathParam

Expand All @@ -10,10 +11,22 @@
class DbClient:
conn: sqlite3.Connection

@staticmethod
def _check_sqlite_version():
    """
    Refuse to continue when the linked SQLite library is too old.

    Raises:
        RuntimeError: if ``sqlite3.sqlite_version`` compares lower than 3.35.0.
    """
    # 3.35.0 is the first SQLite release with the RETURNING clause.
    minimum = Semver('3.35.0')
    current = Semver(sqlite3.sqlite_version)
    if not current < minimum:
        return
    message = (f'SQLite version {current} is too old. '
               f'Required version is {minimum}. '
               'Please upgrade SQLite in your system to use this feature.')
    raise RuntimeError(message)

def __init__(self, file_path: PathParam, embedding_dim: int = 768):
    """
    Open the database at *file_path* and prepare it for use.

    Args:
        file_path: Location handed to ``sqlite3.connect``.
        embedding_dim: Embedding size stored on the client (default 768).

    Raises:
        RuntimeError: propagated from the SQLite version check.
    """
    self._check_sqlite_version()
    logger.info('Connecting to database at %s', file_path)
    self.conn = sqlite3.connect(file_path)
    self.embedding_dim = embedding_dim
    # Connection setup steps, run in this fixed order.
    setup_steps = (
        self._turn_on_foreign_keys,
        self._install_extension,
        self._init_tables,
    )
    for step in setup_steps:
        step()
    logger.info('Connected to database at %s', file_path)
Expand All @@ -24,6 +37,10 @@ def _install_extension(self):
sqlite_vec.load(self.conn)
self.conn.enable_load_extension(False)

def _turn_on_foreign_keys(self):
    """Switch on foreign-key enforcement for this connection."""
    pragma = "PRAGMA foreign_keys = ON"
    self.conn.execute(pragma)
    logger.debug('Foreign keys enabled')

def _init_tables(self):
self.requirements = RequirementsTable(self.conn, self.embedding_dim)
self.annotations = AnnotationsTable(self.conn, self.embedding_dim)
Expand Down
31 changes: 23 additions & 8 deletions test2text/db/tables/annos_to_reqs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sqlite3
from .abstract_table import AbstractTable

class AnnotationsToRequirementsTable(AbstractTable):
Expand All @@ -19,11 +20,25 @@ def recreate_table(self):
""")
self.init_table()

def insert(self, annotation_id: int, requirement_id: int, cached_distance: float):
    """
    Store a link between an annotation and a requirement.

    Rows rejected by a constraint that ``OR IGNORE`` covers are skipped
    silently. Nothing is returned.
    """
    statement = """
        INSERT OR IGNORE INTO AnnotationsToRequirements (annotation_id, requirement_id, cached_distance)
        VALUES (?, ?, ?)
    """
    values = (annotation_id, requirement_id, cached_distance)
    self.connection.execute(statement, values)
def insert(self, annotation_id: int, requirement_id: int, cached_distance: float) -> bool:
    """
    Link an annotation to a requirement and report whether a row was added.

    Args:
        annotation_id: Id of the annotation side of the link.
        requirement_id: Id of the requirement side of the link.
        cached_distance: Distance value stored alongside the link.

    Returns:
        True if a new row was inserted. False if the insert was skipped by
        ``OR IGNORE`` or rejected with ``sqlite3.IntegrityError``.
    """
    cursor = self.connection.cursor()
    try:
        cursor.execute(
            """
            INSERT OR IGNORE INTO AnnotationsToRequirements (annotation_id, requirement_id, cached_distance)
            VALUES (?, ?, ?)
            RETURNING true
            """,
            (annotation_id, requirement_id, cached_distance),
        )
        # RETURNING only produces a row when the insert really happened,
        # so the presence of a row is the answer (as a real bool).
        return cursor.fetchone() is not None
    except sqlite3.IntegrityError:
        # OR IGNORE does not cover foreign-key violations; those still
        # raise and are reported as "nothing inserted".
        return False
    finally:
        # Release the cursor on every path, including the error path.
        cursor.close()

def count(self) -> int:
    """Return the number of rows in ``AnnotationsToRequirements``."""
    query = "SELECT COUNT(*) FROM AnnotationsToRequirements"
    row = self.connection.execute(query).fetchone()
    return row[0]
9 changes: 8 additions & 1 deletion test2text/db/tables/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def init_table(self):
""").substitute(embedding_size=self.embedding_size)
)

def insert(self, summary: str, embedding: list[float] = None) -> int:
def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]:
cursor = self.connection.execute(
"""
INSERT OR IGNORE INTO Annotations (summary, embedding)
Expand All @@ -41,6 +41,13 @@ def insert(self, summary: str, embedding: list[float] = None) -> int:
cursor.close()
if result:
return result[0]
else:
return None

def get_or_insert(self, summary: str, embedding: list[float] = None) -> int:
inserted_id = self.insert(summary, embedding)
if inserted_id is not None:
return inserted_id
else:
cursor = self.connection.execute(
"""
Expand Down
38 changes: 30 additions & 8 deletions test2text/db/tables/cases_to_annos.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sqlite3

from .abstract_table import AbstractTable

class TestCasesToAnnotationsTable(AbstractTable):
Expand All @@ -12,11 +14,31 @@ def init_table(self):
)
""")

def insert(self, case_id: int, annotation_id: int):
    """
    Store a link between a test case and an annotation.

    Rows rejected by a constraint that ``OR IGNORE`` covers are skipped
    silently. Nothing is returned.
    """
    statement = """
        INSERT OR IGNORE INTO CasesToAnnos (case_id, annotation_id)
        VALUES (?, ?)
    """
    values = (case_id, annotation_id)
    self.connection.execute(statement, values)
def recreate_table(self):
    """Drop ``CasesToAnnos`` if it exists, then rebuild it via ``init_table``."""
    drop_statement = """
        DROP TABLE IF EXISTS CasesToAnnos
    """
    self.connection.execute(drop_statement)
    self.init_table()

def insert(self, case_id: int, annotation_id: int) -> bool:
    """
    Link a test case to an annotation and report whether a row was added.

    Args:
        case_id: Id of the test-case side of the link.
        annotation_id: Id of the annotation side of the link.

    Returns:
        True if a new row was inserted. False if the insert was skipped by
        ``OR IGNORE`` or rejected with ``sqlite3.IntegrityError``.
    """
    cursor = self.connection.cursor()
    try:
        cursor.execute(
            """
            INSERT OR IGNORE INTO CasesToAnnos (case_id, annotation_id)
            VALUES (?, ?)
            RETURNING true
            """,
            (case_id, annotation_id),
        )
        # RETURNING only produces a row when the insert really happened,
        # so the presence of a row is the answer (as a real bool).
        return cursor.fetchone() is not None
    except sqlite3.IntegrityError:
        # OR IGNORE does not cover foreign-key violations; those still
        # raise and are reported as "nothing inserted".
        return False
    finally:
        # Release the cursor on every path, including the error path.
        cursor.close()

def count(self) -> int:
    """Return the number of rows in ``CasesToAnnos``."""
    query = "SELECT COUNT(*) FROM CasesToAnnos"
    row = self.connection.execute(query).fetchone()
    return row[0]
11 changes: 10 additions & 1 deletion test2text/db/tables/test_case.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from .abstract_table import AbstractTable

class TestCasesTable(AbstractTable):
Expand All @@ -11,7 +13,7 @@ def init_table(self):
)
""")

def insert(self, test_script: str, test_case: str):
def insert(self, test_script: str, test_case: str) -> Optional[int]:
cursor = self.connection.execute(
"""
INSERT OR IGNORE INTO TestCases (test_script, test_case)
Expand All @@ -24,6 +26,13 @@ def insert(self, test_script: str, test_case: str):
cursor.close()
if result:
return result[0]
else:
return None

def get_or_insert(self, test_script: str, test_case: str) -> int:
inserted_id = self.insert(test_script, test_case)
if inserted_id is not None:
return inserted_id
else:
cursor = self.connection.execute(
"""
Expand Down
52 changes: 52 additions & 0 deletions test2text/utils/semver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
class Semver:
    """
    A ``major.minor.patch`` version that supports comparisons.

    Instances compare with other ``Semver`` objects and with version strings.
    Ordering against a string parses the string first; equality against a
    string compares it with ``str(self)``.
    """

    def __init__(self, version: str):
        """
        Parse a dotted version string.

        At least three dot-separated integer components are required.
        Components after the third are validated but ignored, so four-part
        versions such as SQLite's ``3.8.10.2`` are accepted.

        Raises:
            ValueError: if there are fewer than three components or a
                component is not an integer.
        """
        numbers = [int(part) for part in version.split('.')]
        if len(numbers) < 3:
            raise ValueError(
                f"Expected at least 'major.minor.patch', got {version!r}"
            )
        self.major, self.minor, self.patch = numbers[:3]

    def __str__(self):
        return f"{self.major}.{self.minor}.{self.patch}"

    def __repr__(self):
        return f"{type(self).__name__}({str(self)!r})"

    def _key(self):
        """Return the tuple used for every comparison."""
        return (self.major, self.minor, self.patch)

    @staticmethod
    def _ordering_key(other):
        """Return the comparison tuple for *other*, or None if unsupported."""
        if isinstance(other, Semver):
            return other._key()
        if isinstance(other, str):
            # Invalid version strings raise ValueError from the parser.
            return Semver(other)._key()
        return None

    def __eq__(self, other):
        if isinstance(other, Semver):
            return self._key() == other._key()
        if isinstance(other, str):
            return str(self) == other
        # Let Python try the reflected operation; falls back to False.
        return NotImplemented

    def __hash__(self):
        # Hash the canonical string so equal Semver objects, and the strings
        # they compare equal to, land in the same bucket.
        return hash(str(self))

    def __lt__(self, other):
        key = self._ordering_key(other)
        return NotImplemented if key is None else self._key() < key

    def __le__(self, other):
        key = self._ordering_key(other)
        return NotImplemented if key is None else self._key() <= key

    def __gt__(self, other):
        key = self._ordering_key(other)
        return NotImplemented if key is None else self._key() > key

    def __ge__(self, other):
        key = self._ordering_key(other)
        return NotImplemented if key is None else self._key() >= key
Empty file added tests/test_db/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions tests/test_db/test_db_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from unittest import TestCase
from test2text.db.client import DbClient

class TestDBClient(TestCase):
    """Smoke test for ``DbClient`` on an in-memory database."""

    def test_db_client(self):
        client = DbClient(':memory:')
        with self.subTest('extensions'):
            # vec_version() only resolves when the vector extension is loaded.
            row = client.conn.execute("select vec_version()").fetchone()
            (vec_version,) = row
            self.assertIsNotNone(vec_version)
Empty file.
51 changes: 51 additions & 0 deletions tests/test_db/test_tables/test_annos_to_reqs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from unittest import TestCase
from test2text.db.client import DbClient

class TestAnnosToReqsTable(TestCase):
    """Insertion behaviour of the annotations-to-requirements link table."""

    def setUp(self):
        # Each test gets its own in-memory database seeded with two
        # annotations and two requirements to link together.
        db = DbClient(':memory:')
        self.db = db
        self.anno1 = db.annotations.insert('Test Annotation 1')
        self.anno2 = db.annotations.insert('Test Annotation 2')
        self.req1 = db.requirements.insert('Test Requirement 1')
        self.req2 = db.requirements.insert('Test Requirement 2')
        # Ids expected not to exist in the parent tables.
        self.wrong_anno = 9999
        self.wrong_req = 8888

    def test_insert_single(self):
        """One valid link adds exactly one row."""
        links = self.db.annos_to_reqs
        before = links.count()
        was_inserted = links.insert(self.anno1, self.req1, 1)
        after = links.count()
        self.assertTrue(was_inserted)
        self.assertEqual(after, before + 1)

    def test_insert_multiple(self):
        """Two distinct links add two rows."""
        links = self.db.annos_to_reqs
        before = links.count()
        first = links.insert(self.anno1, self.req1, 1)
        second = links.insert(self.anno2, self.req2, 1)
        after = links.count()
        self.assertTrue(first)
        self.assertTrue(second)
        self.assertEqual(after, before + 2)

    def test_insert_duplicate(self):
        """Repeating the same link is reported as not inserted."""
        links = self.db.annos_to_reqs
        before = links.count()
        first = links.insert(self.anno1, self.req1, 1)
        second = links.insert(self.anno1, self.req1, 1)
        after = links.count()
        self.assertTrue(first)
        self.assertFalse(second)  # The repeated pair must not be stored twice
        self.assertEqual(after, before + 1)

    def test_insert_wrong_annotation(self):
        """A link to an unknown annotation is rejected."""
        links = self.db.annos_to_reqs
        before = links.count()
        was_inserted = links.insert(self.wrong_anno, self.req1, 1)
        after = links.count()
        self.assertFalse(was_inserted)  # Foreign key constraint should reject it
        self.assertEqual(after, before)

    def test_insert_wrong_requirement(self):
        """A link to an unknown requirement is rejected."""
        links = self.db.annos_to_reqs
        before = links.count()
        was_inserted = links.insert(self.anno1, self.wrong_req, 1)
        after = links.count()
        self.assertFalse(was_inserted)  # Foreign key constraint should reject it
        self.assertEqual(before, after)
Loading