Skip to content

Commit 2c98249

Browse files
authored
more fixes for clustering (#40)
* fix #37 * more fixes for clustering - ProcessProtein.write_data() should work again - ProcessProteins.get_taus() now returns values * add tests * update CHANGELOG
1 parent 0e1cfe2 commit 2c98249

File tree

3 files changed

+150
-6
lines changed

3 files changed

+150
-6
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ The rules for this file:
2828
parameters, resolving TypeError in script execution. Added --gskip and
2929
--burnin arguments to cluster.py with default values from the research
3030
paper (gskip=1000, burnin=10000) (Issue #37)
31+
* Fixed ProcessProtein.write_data() method to handle residues as numpy array
32+
instead of dictionary, resolving AttributeError when calling the method
33+
after reprocess() or get_taus(). Also fixed get_taus() method to return
34+
values as documented. Added comprehensive test coverage for write_data()
35+
functionality (Issue #37)
3136

3237

3338
## [1.1.1] - 2025-07-18

basicrta/cluster.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class ProcessProtein(object):
3939
def __init__(self, niter, prot, cutoff,
4040
gskip=1000, burnin=10000,
4141
taus=None, bars=None):
42-
self.residues = Results()
42+
self.residues = Results() # TODO: double-check that we need to use this, it gets set in reprocess/get_taus
4343
self.niter = niter
4444
self.prot = prot
4545
self.cutoff = cutoff
@@ -150,7 +150,7 @@ def get_taus(self, nproc=1):
150150
setattr(self, 'bars', bars)
151151
setattr(self, 'residues', np.array(residues))
152152
setattr(self, 'files', np.array(results))
153-
#return taus[:, 1], bars
153+
return taus[:, 1], bars
154154

155155
def write_data(self, fname='tausout'):
156156
r"""Write :math:`\tau` values with 95\% confidence interval to a numpy
@@ -163,9 +163,10 @@ def write_data(self, fname='tausout'):
163163
if self.taus is None:
164164
taus, bars = self.get_taus()
165165

166-
keys = self.residues.keys()
167-
residues = np.array([int(res[1:]) for res in keys])
168-
data = np.stack((residues, taus, bars[0], bars[1]))
166+
# Handle residues as numpy array (from reprocess/get_taus methods)
167+
# TODO: double-check that we need to use res[1:] and can't get this easier
168+
residues = np.array([int(res[1:]) for res in self.residues])
169+
data = np.stack((residues, self.taus, self.bars[0], self.bars[1]))
169170
np.save(fname, data.T)
170171

171172
def plot_protein(self, **kwargs):

basicrta/tests/test_cluster.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
import pytest
88
import numpy as np
9-
import os
109
import sys
1110
from unittest.mock import patch, MagicMock, call
1211
from basicrta.cluster import ProcessProtein
12+
from basicrta.tests.utils import work_in
1313

1414

1515
class TestProcessProtein:
@@ -175,6 +175,144 @@ def test_plot_protein_calls_util_function(self, mock_plot_protein):
175175
assert 'label_cutoff' in kwargs
176176
assert kwargs['label_cutoff'] == 2.5
177177

178+
def test_write_data_with_existing_data(self, tmp_path):
179+
"""Test write_data method when taus and bars are already set."""
180+
pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0)
181+
182+
# Set up test data as numpy arrays (matching the actual implementation)
183+
pp.residues = np.array(["R100", "R101", "R102"])
184+
pp.taus = np.array([1.0, 2.0, 3.0])
185+
pp.bars = np.array([[0.5, 0.6, 0.7], [1.5, 1.6, 1.7]])
186+
187+
# Create output file in temporary directory
188+
output_file = tmp_path / "test_taus"
189+
190+
# Call write_data
191+
pp.write_data(str(output_file))
192+
193+
# Verify the file was created
194+
assert output_file.with_suffix('.npy').exists()
195+
196+
# Load and verify the data
197+
saved_data = np.load(str(output_file) + '.npy')
198+
199+
# Expected data format: [resid, tau, CI_lower, CI_upper]
200+
expected_data = np.array([
201+
[100, 1.0, 0.5, 1.5], # R100 -> 100
202+
[101, 2.0, 0.6, 1.6], # R101 -> 101
203+
[102, 3.0, 0.7, 1.7] # R102 -> 102
204+
])
205+
206+
assert np.array_equal(saved_data, expected_data)
207+
208+
@patch('basicrta.cluster.ProcessProtein.get_taus')
209+
def test_write_data_calls_get_taus_when_needed(self, mock_get_taus, tmp_path):
210+
"""Test write_data method calls get_taus when taus is None."""
211+
pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0)
212+
213+
# Define the test data
214+
test_taus = np.array([1.5, 2.5, 3.5])
215+
test_bars = np.array([[0.3, 0.4, 0.5], [1.7, 1.8, 1.9]])
216+
test_residues = np.array(["R200", "R201", "R202"])
217+
218+
# Set up mock to return values AND set instance attributes (like real get_taus)
219+
def mock_get_taus_side_effect():
220+
pp.taus = test_taus
221+
pp.bars = test_bars
222+
pp.residues = test_residues
223+
return test_taus, test_bars
224+
225+
mock_get_taus.side_effect = mock_get_taus_side_effect
226+
227+
# Create output file in temporary directory
228+
output_file = tmp_path / "test_taus_from_get_taus"
229+
230+
# Ensure taus is None to trigger get_taus call
231+
pp.taus = None
232+
233+
# Call write_data
234+
pp.write_data(str(output_file))
235+
236+
# Verify get_taus was called
237+
mock_get_taus.assert_called_once()
238+
239+
# Verify the file was created and contains expected data
240+
assert output_file.with_suffix('.npy').exists()
241+
saved_data = np.load(str(output_file) + '.npy')
242+
243+
expected_data = np.array([
244+
[200, 1.5, 0.3, 1.7], # R200 -> 200
245+
[201, 2.5, 0.4, 1.8], # R201 -> 201
246+
[202, 3.5, 0.5, 1.9] # R202 -> 202
247+
])
248+
249+
assert np.array_equal(saved_data, expected_data)
250+
251+
@patch('basicrta.cluster.glob')
252+
@patch('basicrta.cluster.Pool')
253+
@patch('basicrta.util.get_bars')
254+
def test_get_taus_returns_values(self, mock_get_bars, mock_pool, mock_glob):
255+
"""Test that get_taus method returns values as documented in docstring."""
256+
pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0)
257+
258+
# Mock the directory structure
259+
mock_glob.return_value = ["basicrta-7.0/R100", "basicrta-7.0/R101"]
260+
261+
# Mock the multiprocessing pool to return test data
262+
mock_pool_instance = mock_pool.return_value.__enter__.return_value
263+
mock_imap_results = [
264+
("R100", [0.1, 1.5, 2.8], "path1"),
265+
("R101", [0.2, 2.0, 3.2], "path2")
266+
]
267+
mock_pool_instance.imap.return_value = mock_imap_results
268+
269+
# Mock get_bars to return test confidence intervals
270+
test_bars = np.array([[0.5, 0.6], [2.5, 2.6]])
271+
mock_get_bars.return_value = test_bars
272+
273+
# Call get_taus and verify it returns values
274+
result = pp.get_taus(nproc=1)
275+
276+
# Verify the method returns a tuple as documented
277+
assert isinstance(result, tuple)
278+
assert len(result) == 2
279+
280+
returned_taus, returned_bars = result
281+
282+
# Verify the returned values match the instance attributes
283+
assert np.array_equal(returned_taus, pp.taus)
284+
assert np.array_equal(returned_bars, pp.bars)
285+
286+
# Verify the values are what we expect
287+
expected_taus = np.array([1.5, 2.0]) # Middle values from tau arrays
288+
assert np.array_equal(returned_taus, expected_taus)
289+
assert np.array_equal(returned_bars, test_bars)
290+
291+
def test_write_data_with_default_filename(self, tmp_path):
292+
"""Test write_data method uses default filename when none provided."""
293+
pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0)
294+
295+
# Set up test data
296+
pp.residues = np.array(["R300", "R301"])
297+
pp.taus = np.array([4.0, 5.0])
298+
pp.bars = np.array([[0.8, 0.9], [2.0, 2.1]])
299+
300+
with work_in(tmp_path):
301+
# Call write_data without filename (should use default)
302+
pp.write_data()
303+
304+
# Verify default file was created
305+
default_file = tmp_path / "tausout.npy"
306+
assert default_file.exists()
307+
308+
# Verify data integrity
309+
saved_data = np.load(default_file)
310+
expected_data = np.array([
311+
[300, 4.0, 0.8, 2.0],
312+
[301, 5.0, 0.9, 2.1]
313+
])
314+
assert np.array_equal(saved_data, expected_data)
315+
178316

179317
class TestClusterScript:
180318
"""Tests for the command-line script functionality."""

0 commit comments

Comments
 (0)