|
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 | import numpy as np |
9 | | -import os |
10 | 9 | import sys |
11 | 10 | from unittest.mock import patch, MagicMock, call |
12 | 11 | from basicrta.cluster import ProcessProtein |
| 12 | +from basicrta.tests.utils import work_in |
13 | 13 |
|
14 | 14 |
|
15 | 15 | class TestProcessProtein: |
@@ -175,6 +175,144 @@ def test_plot_protein_calls_util_function(self, mock_plot_protein): |
175 | 175 | assert 'label_cutoff' in kwargs |
176 | 176 | assert kwargs['label_cutoff'] == 2.5 |
177 | 177 |
|
| 178 | + def test_write_data_with_existing_data(self, tmp_path): |
| 179 | + """Test write_data method when taus and bars are already set.""" |
| 180 | + pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0) |
| 181 | + |
| 182 | + # Set up test data as numpy arrays (matching the actual implementation) |
| 183 | + pp.residues = np.array(["R100", "R101", "R102"]) |
| 184 | + pp.taus = np.array([1.0, 2.0, 3.0]) |
| 185 | + pp.bars = np.array([[0.5, 0.6, 0.7], [1.5, 1.6, 1.7]]) |
| 186 | + |
| 187 | + # Create output file in temporary directory |
| 188 | + output_file = tmp_path / "test_taus" |
| 189 | + |
| 190 | + # Call write_data |
| 191 | + pp.write_data(str(output_file)) |
| 192 | + |
| 193 | + # Verify the file was created |
| 194 | + assert output_file.with_suffix('.npy').exists() |
| 195 | + |
| 196 | + # Load and verify the data |
| 197 | + saved_data = np.load(str(output_file) + '.npy') |
| 198 | + |
| 199 | + # Expected data format: [resid, tau, CI_lower, CI_upper] |
| 200 | + expected_data = np.array([ |
| 201 | + [100, 1.0, 0.5, 1.5], # R100 -> 100 |
| 202 | + [101, 2.0, 0.6, 1.6], # R101 -> 101 |
| 203 | + [102, 3.0, 0.7, 1.7] # R102 -> 102 |
| 204 | + ]) |
| 205 | + |
| 206 | + assert np.array_equal(saved_data, expected_data) |
| 207 | + |
| 208 | + @patch('basicrta.cluster.ProcessProtein.get_taus') |
| 209 | + def test_write_data_calls_get_taus_when_needed(self, mock_get_taus, tmp_path): |
| 210 | + """Test write_data method calls get_taus when taus is None.""" |
| 211 | + pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0) |
| 212 | + |
| 213 | + # Define the test data |
| 214 | + test_taus = np.array([1.5, 2.5, 3.5]) |
| 215 | + test_bars = np.array([[0.3, 0.4, 0.5], [1.7, 1.8, 1.9]]) |
| 216 | + test_residues = np.array(["R200", "R201", "R202"]) |
| 217 | + |
| 218 | + # Set up mock to return values AND set instance attributes (like real get_taus) |
| 219 | + def mock_get_taus_side_effect(): |
| 220 | + pp.taus = test_taus |
| 221 | + pp.bars = test_bars |
| 222 | + pp.residues = test_residues |
| 223 | + return test_taus, test_bars |
| 224 | + |
| 225 | + mock_get_taus.side_effect = mock_get_taus_side_effect |
| 226 | + |
| 227 | + # Create output file in temporary directory |
| 228 | + output_file = tmp_path / "test_taus_from_get_taus" |
| 229 | + |
| 230 | + # Ensure taus is None to trigger get_taus call |
| 231 | + pp.taus = None |
| 232 | + |
| 233 | + # Call write_data |
| 234 | + pp.write_data(str(output_file)) |
| 235 | + |
| 236 | + # Verify get_taus was called |
| 237 | + mock_get_taus.assert_called_once() |
| 238 | + |
| 239 | + # Verify the file was created and contains expected data |
| 240 | + assert output_file.with_suffix('.npy').exists() |
| 241 | + saved_data = np.load(str(output_file) + '.npy') |
| 242 | + |
| 243 | + expected_data = np.array([ |
| 244 | + [200, 1.5, 0.3, 1.7], # R200 -> 200 |
| 245 | + [201, 2.5, 0.4, 1.8], # R201 -> 201 |
| 246 | + [202, 3.5, 0.5, 1.9] # R202 -> 202 |
| 247 | + ]) |
| 248 | + |
| 249 | + assert np.array_equal(saved_data, expected_data) |
| 250 | + |
| 251 | + @patch('basicrta.cluster.glob') |
| 252 | + @patch('basicrta.cluster.Pool') |
| 253 | + @patch('basicrta.util.get_bars') |
| 254 | + def test_get_taus_returns_values(self, mock_get_bars, mock_pool, mock_glob): |
| 255 | + """Test that get_taus method returns values as documented in docstring.""" |
| 256 | + pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0) |
| 257 | + |
| 258 | + # Mock the directory structure |
| 259 | + mock_glob.return_value = ["basicrta-7.0/R100", "basicrta-7.0/R101"] |
| 260 | + |
| 261 | + # Mock the multiprocessing pool to return test data |
| 262 | + mock_pool_instance = mock_pool.return_value.__enter__.return_value |
| 263 | + mock_imap_results = [ |
| 264 | + ("R100", [0.1, 1.5, 2.8], "path1"), |
| 265 | + ("R101", [0.2, 2.0, 3.2], "path2") |
| 266 | + ] |
| 267 | + mock_pool_instance.imap.return_value = mock_imap_results |
| 268 | + |
| 269 | + # Mock get_bars to return test confidence intervals |
| 270 | + test_bars = np.array([[0.5, 0.6], [2.5, 2.6]]) |
| 271 | + mock_get_bars.return_value = test_bars |
| 272 | + |
| 273 | + # Call get_taus and verify it returns values |
| 274 | + result = pp.get_taus(nproc=1) |
| 275 | + |
| 276 | + # Verify the method returns a tuple as documented |
| 277 | + assert isinstance(result, tuple) |
| 278 | + assert len(result) == 2 |
| 279 | + |
| 280 | + returned_taus, returned_bars = result |
| 281 | + |
| 282 | + # Verify the returned values match the instance attributes |
| 283 | + assert np.array_equal(returned_taus, pp.taus) |
| 284 | + assert np.array_equal(returned_bars, pp.bars) |
| 285 | + |
| 286 | + # Verify the values are what we expect |
| 287 | + expected_taus = np.array([1.5, 2.0]) # Middle values from tau arrays |
| 288 | + assert np.array_equal(returned_taus, expected_taus) |
| 289 | + assert np.array_equal(returned_bars, test_bars) |
| 290 | + |
| 291 | + def test_write_data_with_default_filename(self, tmp_path): |
| 292 | + """Test write_data method uses default filename when none provided.""" |
| 293 | + pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0) |
| 294 | + |
| 295 | + # Set up test data |
| 296 | + pp.residues = np.array(["R300", "R301"]) |
| 297 | + pp.taus = np.array([4.0, 5.0]) |
| 298 | + pp.bars = np.array([[0.8, 0.9], [2.0, 2.1]]) |
| 299 | + |
| 300 | + with work_in(tmp_path): |
| 301 | + # Call write_data without filename (should use default) |
| 302 | + pp.write_data() |
| 303 | + |
| 304 | + # Verify default file was created |
| 305 | + default_file = tmp_path / "tausout.npy" |
| 306 | + assert default_file.exists() |
| 307 | + |
| 308 | + # Verify data integrity |
| 309 | + saved_data = np.load(default_file) |
| 310 | + expected_data = np.array([ |
| 311 | + [300, 4.0, 0.8, 2.0], |
| 312 | + [301, 5.0, 0.9, 2.1] |
| 313 | + ]) |
| 314 | + assert np.array_equal(saved_data, expected_data) |
| 315 | + |
178 | 316 |
|
179 | 317 | class TestClusterScript: |
180 | 318 | """Tests for the command-line script functionality.""" |
|
0 commit comments