Skip to content

Commit af18c3d

Browse files
committed
tests: add tests
1 parent 24522e3 commit af18c3d

17 files changed

+1496
-815
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
repos:
22
# Ruff - replaces flake8, isort, black, and more
33
- repo: https://github.com/astral-sh/ruff-pre-commit
4-
rev: v0.12.4 # Use latest stable version
4+
rev: v0.12.4
55
hooks:
66
- id: ruff
7-
args: [--fix] # Automatically fix what can be fixed
7+
args: [--fix] # Apply auto-fixes when ruff runs through pre-commit
88
- id: ruff-format
99

1010
# Type checking
@@ -13,8 +13,7 @@ repos:
1313
hooks:
1414
- id: mypy
1515
additional_dependencies: [types-requests]
16-
args: [--ignore-missing-imports]
17-
exclude: ^(scripts/|third_party/|themap/models/otdd/)
16+
# No args needed here: mypy reads its configuration from pyproject.toml automatically
1817

1918
# Basic file checks
2019
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -27,7 +26,7 @@ repos:
2726
- id: check-merge-conflict
2827
- id: check-case-conflict
2928
- id: check-added-large-files
30-
args: ['--maxkb=1000'] # Prevent large files
29+
args: ['--maxkb=1000']
3130

3231
# Optional: run the tests as a pre-commit hook (can be slow)
3332
# Uncomment if you want tests to run on every commit

pyproject.toml

Lines changed: 18 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"numpy",
2020
"pandas",
2121
"matplotlib",
22-
"seaborn",
22+
"seaborn",
2323
"scikit-learn",
2424
"dpu-utils>=0.2.13",
2525
"rdkit",
@@ -60,7 +60,7 @@ protein = [
6060
"esm",
6161
]
6262

63-
# Optimal transport distances
63+
# Optimal transport distances
6464
otdd = [
6565
"torch==2.4.0",
6666
"torchvision==0.19.0",
@@ -78,7 +78,7 @@ all = [
7878
"torchvision==0.19.0",
7979
"torchaudio==2.4.0",
8080
"molfeat==0.11.0",
81-
"dgl<=2.0",
81+
"dgl<=2.0",
8282
"dgllife>=0.3.2",
8383
"pytorch_geometric",
8484
"fcd_torch",
@@ -152,15 +152,23 @@ exclude = [
152152
]
153153

154154
[tool.ruff]
155-
# Select rule codes to enforce. "E" and "F" are defaults. Add "I" for import sorting.
156-
# You can add more codes like "B" (flake8-bugbear) or "C4" (flake8-comprehensions) etc.
157-
lint.select = ["E", "F", "W", "I"] # W = warnings, I = isort
158-
lint.ignore = [
159-
"E501", # Line length handled by formatter
160-
]
155+
# Enable auto-fixing
156+
fix = true
157+
158+
# Set line length
161159
line-length = 110
160+
161+
# Set target version
162162
target-version = "py310"
163163

164+
# Select rule codes to enforce. "E" and "F" are defaults. Add "I" for import sorting.
165+
# You can add more codes like "B" (flake8-bugbear) or "C4" (flake8-comprehensions) etc.
166+
[tool.ruff.lint]
167+
168+
select = ["E", "F", "W", "I"] # W = warnings, I = isort
169+
ignore = ["E501"] # Line length handled by formatter
170+
171+
164172
[tool.ruff.format]
165173
# Use double quotes for strings
166174
quote-style = "double"
@@ -199,19 +207,6 @@ output = "coverage.xml"
199207

200208
[tool.mypy]
201209
python_version = "3.10"
202-
warn_return_any = true
203-
warn_unused_configs = true
204-
disallow_untyped_defs = true
205-
disallow_incomplete_defs = true
206-
check_untyped_defs = true
207-
disallow_untyped_decorators = true
208-
no_implicit_optional = true
209-
warn_redundant_casts = true
210-
warn_unused_ignores = true
211-
warn_no_return = true
212-
warn_unreachable = true
213-
strict_equality = true
214-
show_error_codes = true
215210

216211
[[tool.mypy.overrides]]
217212
module = [
@@ -220,5 +215,6 @@ module = [
220215
"molfeat.*",
221216
"dpu_utils.*",
222217
"themap.models.otdd.*",
218+
"tests.*",
223219
]
224220
ignore_missing_imports = true

tests/conftest.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
1-
import numpy as np
21
import pandas as pd
32
import pytest
4-
53
from dpu_utils.utils import RichPath
6-
from themap.data import MoleculeDatapoint, ProteinDataset, MoleculeDataset
4+
5+
from themap.data.molecule_datapoint import MoleculeDatapoint
6+
from themap.data.protein_datasets import ProteinDataset
7+
78

89
@pytest.fixture
910
def manual_smiles():
@@ -44,25 +45,29 @@ def dataset_CHEMBL2219236():
4445
def dataset_CHEMBL1963831():
4546
return RichPath.create("datasets/test/CHEMBL1963831.jsonl.gz")
4647

48+
4749
@pytest.fixture
4850
def dataset_CHEMBL1023359():
4951
return RichPath.create("datasets/test/CHEMBL1023359.jsonl.gz")
5052

53+
5154
@pytest.fixture
5255
def dataset_CHEMBL2219358():
5356
return RichPath.create("datasets/test/CHEMBL2219358.jsonl.gz")
5457

58+
5559
@pytest.fixture
5660
def dataset_CHEMBL1963831_csv():
5761
return pd.read_csv("tests/conftest/CHEMBL1963831.csv")
5862

63+
5964
@pytest.fixture
6065
def manual_protein_dataset():
6166
return ProteinDataset(
6267
task_id=["CHEMBL2219236", "CHEMBL2219358"],
63-
protein={"Q13177" : "MSDNGELEDKPPAPPVRMSSTI",
64-
"P50750" : "MAKQYDSVECPFCDEVSKYEK"}
65-
)
68+
protein={"Q13177": "MSDNGELEDKPPAPPVRMSSTI", "P50750": "MAKQYDSVECPFCDEVSKYEK"},
69+
)
70+
6671

6772
@pytest.fixture
6873
def protein_dataset_train():

tests/data/test_molecule_datapoint.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
from themap.data.molecule_datapoint import MoleculeDatapoint
44

5+
56
def test_MoleculeDatapoint(datapoint_molecule):
67
"""Test the MoleculeDatapoint class functionality."""
78
# Test the __repr__ method
@@ -19,15 +20,11 @@ def test_MoleculeDatapoint(datapoint_molecule):
1920
# Test the molecular_weight method
2021
assert round(datapoint_molecule.molecular_weight) == 78
2122

23+
2224
def test_MoleculeDatapoint_validation():
2325
"""Test input validation in MoleculeDatapoint."""
2426
# Test valid initialization
25-
datapoint = MoleculeDatapoint(
26-
task_id="test_task",
27-
smiles="c1ccccc1",
28-
bool_label=True,
29-
numeric_label=1.0
30-
)
27+
datapoint = MoleculeDatapoint(task_id="test_task", smiles="c1ccccc1", bool_label=True, numeric_label=1.0)
3128
assert datapoint.task_id == "test_task"
3229
assert datapoint.smiles == "c1ccccc1"
3330
assert datapoint.bool_label is True
@@ -38,23 +35,23 @@ def test_MoleculeDatapoint_validation():
3835
MoleculeDatapoint(
3936
task_id=123, # Should be string
4037
smiles="c1ccccc1",
41-
bool_label=True
38+
bool_label=True,
4239
)
4340

4441
# Test invalid smiles
4542
with pytest.raises(TypeError):
4643
MoleculeDatapoint(
4744
task_id="test_task",
4845
smiles=123, # Should be string
49-
bool_label=True
46+
bool_label=True,
5047
)
5148

5249
# Test invalid bool_label
5350
with pytest.raises(TypeError):
5451
MoleculeDatapoint(
5552
task_id="test_task",
5653
smiles="c1ccccc1",
57-
bool_label=1 # Should be bool
54+
bool_label=1, # Should be bool
5855
)
5956

6057
# Test invalid numeric_label
@@ -63,5 +60,5 @@ def test_MoleculeDatapoint_validation():
6360
task_id="test_task",
6461
smiles="c1ccccc1",
6562
bool_label=True,
66-
numeric_label="invalid" # Should be number or None
67-
)
63+
numeric_label="invalid", # Should be number or None
64+
)

tests/data/test_molecule_dataset.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
import pytest
21
import numpy as np
2+
import pytest
33

4-
from themap.data.molecule_dataset import MoleculeDataset
54
from themap.data.molecule_datapoint import MoleculeDatapoint
5+
from themap.data.molecule_dataset import MoleculeDataset
6+
67

78
def test_MoleculeDataset_load_from_file(dataset_CHEMBL2219236):
89
"""Test loading MoleculeDataset from file."""
@@ -18,18 +19,12 @@ def test_MoleculeDataset_load_from_file(dataset_CHEMBL2219236):
1819
# Test the __repr__ method
1920
assert str(dataset) == "MoleculeDataset(task_id=CHEMBL2219236, task_size=157)"
2021

22+
2123
def test_MoleculeDataset_validation():
2224
"""Test input validation in MoleculeDataset."""
2325
# Test valid initialization
2426
dataset = MoleculeDataset(
25-
task_id="test_task",
26-
data=[
27-
MoleculeDatapoint(
28-
task_id="test_task",
29-
smiles="c1ccccc1",
30-
bool_label=True
31-
)
32-
]
27+
task_id="test_task", data=[MoleculeDatapoint(task_id="test_task", smiles="c1ccccc1", bool_label=True)]
3328
)
3429
assert dataset.task_id == "test_task"
3530
assert len(dataset) == 1
@@ -38,30 +33,31 @@ def test_MoleculeDataset_validation():
3833
with pytest.raises(TypeError):
3934
MoleculeDataset(
4035
task_id=123, # Should be string
41-
data=[]
36+
data=[],
4237
)
4338

4439
# Test invalid data
4540
with pytest.raises(TypeError):
4641
MoleculeDataset(
4742
task_id="test_task",
48-
data="not_a_list" # Should be list
43+
data="not_a_list", # Should be list
4944
)
5045

5146
# Test invalid data items
5247
with pytest.raises(TypeError):
5348
MoleculeDataset(
5449
task_id="test_task",
55-
data=["not_a_MoleculeDatapoint"] # Should be MoleculeDatapoint
50+
data=["not_a_MoleculeDatapoint"], # Should be MoleculeDatapoint
5651
)
5752

53+
5854
def test_MoleculeDataset_properties():
5955
"""Test MoleculeDataset properties."""
6056
# Create a test dataset
6157
datapoints = [
6258
MoleculeDatapoint("test_task", "c1ccccc1", True),
6359
MoleculeDatapoint("test_task", "c1ccccc1", False),
64-
MoleculeDatapoint("test_task", "c1ccccc1", True)
60+
MoleculeDatapoint("test_task", "c1ccccc1", True),
6561
]
6662
dataset = MoleculeDataset("test_task", datapoints)
6763

@@ -83,13 +79,14 @@ def test_MoleculeDataset_properties():
8379
# Test get_ratio property
8480
assert dataset.get_ratio == 0.67 # 2/3 rounded to 2 decimal places
8581

82+
8683
def test_MoleculeDataset_filter():
8784
"""Test MoleculeDataset filtering."""
8885
# Create a test dataset
8986
datapoints = [
9087
MoleculeDatapoint("test_task", "c1ccccc1", True),
9188
MoleculeDatapoint("test_task", "c1ccccc1", False),
92-
MoleculeDatapoint("test_task", "c1ccccc1", True)
89+
MoleculeDatapoint("test_task", "c1ccccc1", True),
9390
]
9491
dataset = MoleculeDataset("test_task", datapoints)
9592

@@ -98,22 +95,23 @@ def test_MoleculeDataset_filter():
9895
assert len(filtered_dataset) == 2
9996
assert all(dp.bool_label for dp in filtered_dataset)
10097

98+
10199
def test_MoleculeDataset_statistics():
102100
"""Test MoleculeDataset statistics."""
103101
# Create a test dataset
104102
datapoints = [
105103
MoleculeDatapoint("test_task", "c1ccccc1", True),
106104
MoleculeDatapoint("test_task", "c1ccccc1", False),
107-
MoleculeDatapoint("test_task", "c1ccccc1", True)
105+
MoleculeDatapoint("test_task", "c1ccccc1", True),
108106
]
109107
dataset = MoleculeDataset("test_task", datapoints)
110108

111109
# Get statistics
112110
stats = dataset.get_statistics()
113-
111+
114112
# Check statistics
115113
assert stats["size"] == 3
116114
assert stats["positive_ratio"] == 0.67
117115
assert isinstance(stats["avg_molecular_weight"], float)
118116
assert isinstance(stats["avg_atoms"], float)
119-
assert isinstance(stats["avg_bonds"], float)
117+
assert isinstance(stats["avg_bonds"], float)

0 commit comments

Comments
 (0)