1- import pytest
21import numpy as np
2+ import pytest
33
4- from themap .data .molecule_dataset import MoleculeDataset
54from themap .data .molecule_datapoint import MoleculeDatapoint
5+ from themap .data .molecule_dataset import MoleculeDataset
6+
67
78def test_MoleculeDataset_load_from_file (dataset_CHEMBL2219236 ):
89 """Test loading MoleculeDataset from file."""
@@ -18,18 +19,12 @@ def test_MoleculeDataset_load_from_file(dataset_CHEMBL2219236):
1819 # Test the __repr__ method
1920 assert str (dataset ) == "MoleculeDataset(task_id=CHEMBL2219236, task_size=157)"
2021
22+
2123def test_MoleculeDataset_validation ():
2224 """Test input validation in MoleculeDataset."""
2325 # Test valid initialization
2426 dataset = MoleculeDataset (
25- task_id = "test_task" ,
26- data = [
27- MoleculeDatapoint (
28- task_id = "test_task" ,
29- smiles = "c1ccccc1" ,
30- bool_label = True
31- )
32- ]
27+ task_id = "test_task" , data = [MoleculeDatapoint (task_id = "test_task" , smiles = "c1ccccc1" , bool_label = True )]
3328 )
3429 assert dataset .task_id == "test_task"
3530 assert len (dataset ) == 1
@@ -38,30 +33,31 @@ def test_MoleculeDataset_validation():
3833 with pytest .raises (TypeError ):
3934 MoleculeDataset (
4035 task_id = 123 , # Should be string
41- data = []
36+ data = [],
4237 )
4338
4439 # Test invalid data
4540 with pytest .raises (TypeError ):
4641 MoleculeDataset (
4742 task_id = "test_task" ,
48- data = "not_a_list" # Should be list
43+ data = "not_a_list" , # Should be list
4944 )
5045
5146 # Test invalid data items
5247 with pytest .raises (TypeError ):
5348 MoleculeDataset (
5449 task_id = "test_task" ,
55- data = ["not_a_MoleculeDatapoint" ] # Should be MoleculeDatapoint
50+ data = ["not_a_MoleculeDatapoint" ], # Should be MoleculeDatapoint
5651 )
5752
53+
5854def test_MoleculeDataset_properties ():
5955 """Test MoleculeDataset properties."""
6056 # Create a test dataset
6157 datapoints = [
6258 MoleculeDatapoint ("test_task" , "c1ccccc1" , True ),
6359 MoleculeDatapoint ("test_task" , "c1ccccc1" , False ),
64- MoleculeDatapoint ("test_task" , "c1ccccc1" , True )
60+ MoleculeDatapoint ("test_task" , "c1ccccc1" , True ),
6561 ]
6662 dataset = MoleculeDataset ("test_task" , datapoints )
6763
@@ -83,13 +79,14 @@ def test_MoleculeDataset_properties():
8379 # Test get_ratio property
8480 assert dataset .get_ratio == 0.67 # 2/3 rounded to 2 decimal places
8581
82+
8683def test_MoleculeDataset_filter ():
8784 """Test MoleculeDataset filtering."""
8885 # Create a test dataset
8986 datapoints = [
9087 MoleculeDatapoint ("test_task" , "c1ccccc1" , True ),
9188 MoleculeDatapoint ("test_task" , "c1ccccc1" , False ),
92- MoleculeDatapoint ("test_task" , "c1ccccc1" , True )
89+ MoleculeDatapoint ("test_task" , "c1ccccc1" , True ),
9390 ]
9491 dataset = MoleculeDataset ("test_task" , datapoints )
9592
@@ -98,22 +95,23 @@ def test_MoleculeDataset_filter():
9895 assert len (filtered_dataset ) == 2
9996 assert all (dp .bool_label for dp in filtered_dataset )
10097
98+
10199def test_MoleculeDataset_statistics ():
102100 """Test MoleculeDataset statistics."""
103101 # Create a test dataset
104102 datapoints = [
105103 MoleculeDatapoint ("test_task" , "c1ccccc1" , True ),
106104 MoleculeDatapoint ("test_task" , "c1ccccc1" , False ),
107- MoleculeDatapoint ("test_task" , "c1ccccc1" , True )
105+ MoleculeDatapoint ("test_task" , "c1ccccc1" , True ),
108106 ]
109107 dataset = MoleculeDataset ("test_task" , datapoints )
110108
111109 # Get statistics
112110 stats = dataset .get_statistics ()
113-
111+
114112 # Check statistics
115113 assert stats ["size" ] == 3
116114 assert stats ["positive_ratio" ] == 0.67
117115 assert isinstance (stats ["avg_molecular_weight" ], float )
118116 assert isinstance (stats ["avg_atoms" ], float )
119- assert isinstance (stats ["avg_bonds" ], float )
117+ assert isinstance (stats ["avg_bonds" ], float )
0 commit comments