@@ -24,21 +24,21 @@ def __init__(self, resids):
2424 self .resids = np .array (resids )
2525
2626
27- class TestCombineContacts :
28- """Test class for CombineContacts functionality."""
29-
30- def setup_method ( self ):
31- """Set up test fixtures."""
32- self . temp_dir = tempfile . mkdtemp ( )
33- self . original_dir = os . getcwd ()
34- os .chdir (self . temp_dir )
35-
36- def teardown_method ( self ):
37- """Clean up after tests."""
38- os . chdir ( self . original_dir )
39- shutil . rmtree ( self . temp_dir )
40-
41- def create_mock_contacts ( self , filename , n_contacts = 100 , cutoff = 7.0 ,
27+ @ pytest . fixture
28+ def temp_dir ():
29+ """Create a temporary directory for tests."""
30+ temp_dir = tempfile . mkdtemp ()
31+ original_dir = os . getcwd ()
32+ os . chdir ( temp_dir )
33+ yield temp_dir
34+ os .chdir (original_dir )
35+ shutil . rmtree ( temp_dir )
36+
37+
38+ @ pytest . fixture
39+ def create_mock_contacts ():
40+ """Factory fixture for creating mock contact files."""
41+ def _create_mock_contacts ( filename , n_contacts = 100 , cutoff = 7.0 ,
4242 ts = 0.1 , traj_name = "test.xtc" , top_name = "test.pdb" ):
4343 """Create a mock contact file for testing."""
4444 # Create mock atom groups that can be pickled
@@ -72,12 +72,18 @@ def create_mock_contacts(self, filename, n_contacts=100, cutoff=7.0,
7272 pickle .dump (contacts , f , protocol = 5 )
7373
7474 return contacts , metadata
75+
76+ return _create_mock_contacts
77+
78+
79+ class TestCombineContacts :
80+ """Test class for CombineContacts functionality."""
7581
76- def test_combine_contacts_basic (self ):
82+ def test_combine_contacts_basic (self , temp_dir , create_mock_contacts ):
7783 """Test basic contact combination functionality."""
7884 # Create two mock contact files
79- contacts1 , meta1 = self . create_mock_contacts ("contacts1.pkl" , n_contacts = 50 )
80- contacts2 , meta2 = self . create_mock_contacts ("contacts2.pkl" , n_contacts = 75 ,
85+ contacts1 , meta1 = create_mock_contacts ("contacts1.pkl" , n_contacts = 50 )
86+ contacts2 , meta2 = create_mock_contacts ("contacts2.pkl" , n_contacts = 75 ,
8187 traj_name = "test2.xtc" )
8288
8389 # Combine them
@@ -110,10 +116,10 @@ def test_combine_contacts_basic(self):
110116 assert np .all (traj_sources [:50 ] == 0 ) # First 50 from file 0
111117 assert np .all (traj_sources [50 :] == 1 ) # Next 75 from file 1
112118
113- def test_incompatible_cutoffs (self ):
119+ def test_incompatible_cutoffs (self , temp_dir , create_mock_contacts ):
114120 """Test that incompatible cutoffs raise an error."""
115- self . create_mock_contacts ("contacts1.pkl" , cutoff = 7.0 )
116- self . create_mock_contacts ("contacts2.pkl" , cutoff = 8.0 ) # Different cutoff
121+ create_mock_contacts ("contacts1.pkl" , cutoff = 7.0 )
122+ create_mock_contacts ("contacts2.pkl" , cutoff = 8.0 ) # Different cutoff
117123
118124 combiner = CombineContacts (
119125 contact_files = ["contacts1.pkl" , "contacts2.pkl" ]
@@ -122,10 +128,10 @@ def test_incompatible_cutoffs(self):
122128 with pytest .raises (ValueError , match = "Incompatible cutoffs" ):
123129 combiner .run ()
124130
125- def test_incompatible_atom_groups (self ):
131+ def test_incompatible_atom_groups (self , temp_dir , create_mock_contacts ):
126132 """Test that incompatible atom groups raise an error."""
127133 # Create first file with standard residues
128- contacts1 , _ = self . create_mock_contacts ("contacts1.pkl" )
134+ contacts1 , _ = create_mock_contacts ("contacts1.pkl" )
129135
130136 # Create second file with different protein residues
131137 mock_ag1 = MockAtomGroup ([10 , 20 , 30 ]) # Different resids
@@ -153,10 +159,10 @@ def test_incompatible_atom_groups(self):
153159 with pytest .raises (ValueError , match = "Incompatible ag1 residues" ):
154160 combiner .run ()
155161
156- def test_different_timesteps_warning (self , capsys ):
162+ def test_different_timesteps_warning (self , temp_dir , create_mock_contacts , capsys ):
157163 """Test that different timesteps produce a warning."""
158- self . create_mock_contacts ("contacts1.pkl" , ts = 0.1 )
159- self . create_mock_contacts ("contacts2.pkl" , ts = 0.2 ) # Different timestep
164+ create_mock_contacts ("contacts1.pkl" , ts = 0.1 )
165+ create_mock_contacts ("contacts2.pkl" , ts = 0.2 ) # Different timestep
160166
161167 combiner = CombineContacts (
162168 contact_files = ["contacts1.pkl" , "contacts2.pkl" ]
@@ -168,16 +174,16 @@ def test_different_timesteps_warning(self, capsys):
168174 captured = capsys .readouterr ()
169175 assert "WARNING: Different timesteps detected" in captured .out
170176
171- def test_minimum_files_required (self ):
177+ def test_minimum_files_required (self , temp_dir , create_mock_contacts ):
172178 """Test that at least 2 files are required."""
173- self . create_mock_contacts ("contacts1.pkl" )
179+ create_mock_contacts ("contacts1.pkl" )
174180
175181 with pytest .raises (ValueError , match = "At least 2 contact files are required" ):
176182 CombineContacts (contact_files = ["contacts1.pkl" ])
177183
178- def test_missing_file (self ):
184+ def test_missing_file (self , temp_dir , create_mock_contacts ):
179185 """Test handling of missing contact files."""
180- self . create_mock_contacts ("contacts1.pkl" )
186+ create_mock_contacts ("contacts1.pkl" )
181187
182188 combiner = CombineContacts (
183189 contact_files = ["contacts1.pkl" , "nonexistent.pkl" ]
@@ -186,10 +192,10 @@ def test_missing_file(self):
186192 with pytest .raises (FileNotFoundError , match = "Contact file not found" ):
187193 combiner .run ()
188194
189- def test_skip_validation (self ):
195+ def test_skip_validation (self , temp_dir , create_mock_contacts ):
190196 """Test skipping compatibility validation."""
191- self . create_mock_contacts ("contacts1.pkl" , cutoff = 7.0 )
192- self . create_mock_contacts ("contacts2.pkl" , cutoff = 8.0 ) # Different cutoff
197+ create_mock_contacts ("contacts1.pkl" , cutoff = 7.0 )
198+ create_mock_contacts ("contacts2.pkl" , cutoff = 8.0 ) # Different cutoff
193199
194200 combiner = CombineContacts (
195201 contact_files = ["contacts1.pkl" , "contacts2.pkl" ],
@@ -200,11 +206,11 @@ def test_skip_validation(self):
200206 output_file = combiner .run ()
201207 assert os .path .exists (output_file )
202208
203- def test_combined_contacts_detection (self ):
209+ def test_combined_contacts_detection (self , temp_dir , create_mock_contacts ):
204210 """Test that combined contact files are properly detected."""
205211 # Create and combine contacts
206- self . create_mock_contacts ("contacts1.pkl" , n_contacts = 30 )
207- self . create_mock_contacts ("contacts2.pkl" , n_contacts = 40 , traj_name = "test2.xtc" )
212+ create_mock_contacts ("contacts1.pkl" , n_contacts = 30 )
213+ create_mock_contacts ("contacts2.pkl" , n_contacts = 40 , traj_name = "test2.xtc" )
208214
209215 combiner = CombineContacts (
210216 contact_files = ["contacts1.pkl" , "contacts2.pkl" ],
0 commit comments