Skip to content

Commit 3b84b0f

Browse files
committed
add more tree tests
1 parent 9b735e0 commit 3b84b0f

File tree

5 files changed

+220
-9
lines changed

5 files changed

+220
-9
lines changed

prody/tests/datafiles/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
package that contains test modules and files as well."""
55

66

7-
from os.path import join, isfile, split, splitext
7+
from os.path import join, isfile, splitext
88
from prody.tests import TestCase
99

1010
from numpy import array
1111
import numpy as np
1212

1313
from prody import parsePDB, parseDCD, parseMMCIF, parseMMTF
14-
from prody import parseSparseMatrix, parseArray, loadModel
15-
from prody.tests import TEMPDIR, TESTDIR
14+
from prody import parseSparseMatrix, parseArray, parseTree, loadModel
15+
from prody.tests import TESTDIR
1616

1717

1818
DATA_FILES = {
@@ -453,6 +453,16 @@
453453
'n_atoms': 4,
454454
'long_resname': 'ACET',
455455
'short_resname': 'ACE'
456+
},
457+
'upgma_tree': {
458+
'file': 'simple_tree_upgma.nwk',
459+
'n_leaves': 4,
460+
'n_top_clades': 2,
461+
},
462+
'nj_tree': {
463+
'file': 'simple_tree_nj.nwk',
464+
'n_leaves': 4,
465+
'n_top_clades': 3,
456466
}
457467
}
458468

@@ -463,7 +473,8 @@
463473
'.coo': parseSparseMatrix, '.dat': parseArray,
464474
'.txt': np.loadtxt,
465475
'.npy': lambda fn, **kwargs: np.load(fn, allow_pickle=True),
466-
'.gz': lambda fn, **kwargs: PARSERS[splitext(fn)[1]](fn, **kwargs)
476+
'.gz': lambda fn, **kwargs: PARSERS[splitext(fn)[1]](fn, **kwargs),
477+
'.nwk': parseTree
467478
}
468479

469480

prody/tests/datafiles/simple_tree_nj.nwk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
(A:0.25000,(C:1.00000,B:0.50000):0.50000,D:0.75000):0.00000;
prody/tests/datafiles/simple_tree_upgma.nwk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
((C:0.75000,B:0.75000):0.12500,(D:0.50000,A:0.50000):0.37500):0.00000;

prody/tests/utilities/test_catchall.py

Lines changed: 176 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""This module contains some unit tests for :mod:`prody.utilities.catchall` module,
2-
starting with calcTree."""
2+
starting with tree-related ones."""
33

4+
import os
5+
import tempfile
46
import numpy as np
57

68
from prody.tests import unittest
7-
from prody.utilities import calcTree
9+
from prody.utilities import calcTree, findSubgroups, writeTree, parseTree
10+
from prody.tests.datafiles import parseDatafile, pathDatafile
811

912
class TestCalcTree(unittest.TestCase):
1013

@@ -49,5 +52,176 @@ def testCalcTreeMismatchSize(self):
4952
with self.assertRaises(ValueError):
5053
calcTree(names, distance_matrix)
5154

55+
56+
class TestFindSubgroups(unittest.TestCase):
57+
58+
def setUp(self):
59+
"""Set up a test tree for findSubgroups tests."""
60+
# Create a simple distance matrix with clear clustering
61+
# Points A,B are close (distance 0.5), C,D are close (distance 0.5)
62+
# But A,B are far from C,D (distance 5)
63+
self.names = ['A', 'B', 'C', 'D']
64+
self.distance_matrix = np.array([[0.0, 0.5, 5.0, 5.0],
65+
[0.5, 0.0, 5.0, 5.0],
66+
[5.0, 5.0, 0.0, 0.5],
67+
[5.0, 5.0, 0.5, 0.0]])
68+
self.tree = calcTree(self.names, self.distance_matrix, method='upgma')
69+
70+
def testFindSubgroupsNaiveMethod(self):
71+
"""Test findSubgroups with naive method."""
72+
# Using cutoff 2.0 should separate into 2 subgroups
73+
subgroups = findSubgroups(self.tree, 2.0, method='naive')
74+
self.assertIsNotNone(subgroups)
75+
self.assertEqual(len(subgroups), 2)
76+
# Check that subgroups contain the expected names
77+
all_names = [name for subgroup in subgroups for name in subgroup]
78+
self.assertEqual(set(all_names), set(self.names))
79+
80+
def testFindSubgroupsNaiveLargeCutoff(self):
81+
"""Test findSubgroups with naive method and large cutoff."""
82+
# Using cutoff 10.0 should keep everything in one subgroup
83+
subgroups = findSubgroups(self.tree, 10.0, method='naive')
84+
self.assertEqual(len(subgroups), 1)
85+
self.assertEqual(set(subgroups[0]), set(self.names))
86+
87+
def testFindSubgroupsNaiveTinyCutoff(self):
88+
"""Test findSubgroups with naive method and tiny cutoff."""
89+
# Using cutoff 0.1 should separate all into individual subgroups
90+
subgroups = findSubgroups(self.tree, 0.1, method='naive')
91+
self.assertEqual(len(subgroups), 4)
92+
# Each subgroup should have one member
93+
for subgroup in subgroups:
94+
self.assertEqual(len(subgroup), 1)
95+
96+
def testFindSubgroupsReturnsListOfLists(self):
97+
"""Test that findSubgroups returns a list of lists."""
98+
subgroups = findSubgroups(self.tree, 2.0, method='naive')
99+
self.assertIsInstance(subgroups, list)
100+
for subgroup in subgroups:
101+
self.assertIsInstance(subgroup, list)
102+
103+
104+
class TestParseTree(unittest.TestCase):
105+
106+
def testParseUPGMATree(self):
107+
"""Test parsing a UPGMA tree from a file."""
108+
tree_fn = pathDatafile('upgma_tree')
109+
tree = parseTree(tree_fn)
110+
self.assertIsNotNone(tree)
111+
# Check that tree has expected number of leaves
112+
leaves = tree.get_terminals()
113+
self.assertEqual(len(leaves), 4)
114+
# Check that tree has expected number of top-level clades
115+
self.assertEqual(len(tree.root.clades), 2)
116+
117+
def testParseNJTree(self):
118+
"""Test parsing a neighbor-joining tree from a file."""
119+
tree_fn = pathDatafile('nj_tree')
120+
tree = parseTree(tree_fn)
121+
self.assertIsNotNone(tree)
122+
# Check that tree has expected number of leaves
123+
leaves = tree.get_terminals()
124+
self.assertEqual(len(leaves), 4)
125+
# Check that tree has expected number of top-level clades
126+
self.assertEqual(len(tree.root.clades), 3)
127+
128+
def testParseTreeTreeType(self):
129+
"""Test that parseTree returns a Biopython Tree object."""
130+
try:
131+
from Bio import Phylo
132+
tree = parseDatafile('upgma_tree')
133+
self.assertIsInstance(tree, Phylo.BaseTree.Tree)
134+
except ImportError:
135+
self.skipTest("Biopython not available")
136+
137+
def testParseTreeWrongFilepath(self):
138+
"""Test parseTree with non-existent file."""
139+
with self.assertRaises((AssertionError, FileNotFoundError)):
140+
parseTree('/nonexistent/path/to/tree.nwk')
141+
142+
def testParseTreeWrongFileType(self):
143+
"""Test parseTree with invalid filename argument."""
144+
with self.assertRaises(TypeError):
145+
parseTree(123)
146+
147+
148+
class TestWriteTree(unittest.TestCase):
149+
150+
def setUp(self):
151+
"""Set up test trees for writing."""
152+
self.upgma_tree = parseDatafile('upgma_tree')
153+
self.nj_tree = parseDatafile('nj_tree')
154+
# Create a temporary directory for test files
155+
self.temp_dir = tempfile.mkdtemp()
156+
157+
def tearDown(self):
158+
"""Clean up temporary test files."""
159+
import shutil
160+
if os.path.exists(self.temp_dir):
161+
shutil.rmtree(self.temp_dir)
162+
163+
def testWriteUPGMATree(self):
164+
"""Test writing a UPGMA tree to a file."""
165+
output_file = os.path.join(self.temp_dir, 'test_upgma.nwk')
166+
try:
167+
writeTree(output_file, self.upgma_tree)
168+
self.assertTrue(os.path.exists(output_file))
169+
# Verify the file is not empty
170+
with open(output_file, 'r') as f:
171+
content = f.read()
172+
self.assertTrue(len(content) > 0)
173+
# Check for Newick format markers
174+
self.assertIn(';', content)
175+
except ImportError:
176+
self.skipTest("Biopython not available")
177+
178+
def testWriteNJTree(self):
179+
"""Test writing a neighbor-joining tree to a file."""
180+
output_file = os.path.join(self.temp_dir, 'test_nj.nwk')
181+
try:
182+
writeTree(output_file, self.nj_tree)
183+
self.assertTrue(os.path.exists(output_file))
184+
# Verify the file is not empty
185+
with open(output_file, 'r') as f:
186+
content = f.read()
187+
self.assertTrue(len(content) > 0)
188+
# Check for Newick format markers
189+
self.assertIn(';', content)
190+
except ImportError:
191+
self.skipTest("Biopython not available")
192+
193+
def testWriteTreeWrongFilename(self):
194+
"""Test writeTree with invalid filename argument."""
195+
with self.assertRaises(TypeError):
196+
writeTree(123, self.upgma_tree)
197+
198+
def testWriteTreeWrongTreeType(self):
199+
"""Test writeTree with invalid tree argument."""
200+
output_file = os.path.join(self.temp_dir, 'test.nwk')
201+
with self.assertRaises(TypeError):
202+
writeTree(output_file, "not a tree")
203+
204+
def testWriteTreeWrongFormat(self):
205+
"""Test writeTree with invalid format argument."""
206+
output_file = os.path.join(self.temp_dir, 'test.nwk')
207+
with self.assertRaises(TypeError):
208+
writeTree(output_file, self.upgma_tree, format_str=123)
209+
210+
def testWriteAndParseRoundtrip(self):
211+
"""Test writing a tree and then parsing it back."""
212+
output_file = os.path.join(self.temp_dir, 'roundtrip.nwk')
213+
try:
214+
# Write the tree
215+
writeTree(output_file, self.upgma_tree)
216+
# Parse it back
217+
parsed_tree = parseTree(output_file)
218+
# Verify the parsed tree is valid
219+
self.assertIsNotNone(parsed_tree)
220+
leaves = parsed_tree.get_terminals()
221+
self.assertEqual(len(leaves), 4)
222+
self.assertEqual(len(parsed_tree.root.clades), 2)
223+
except ImportError:
224+
self.skipTest("Biopython not available")
225+
52226
if __name__ == '__main__':
53227
unittest.main()

prody/utilities/catchall.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from .logger import LOGGER
99

1010

11-
__all__ = ['calcTree', 'clusterMatrix',
11+
__all__ = ['calcTree', 'writeTree', 'parseTree',
12+
'clusterMatrix',
1213
'showLines', 'showMatrix', 'showBars',
1314
'reorderMatrix', 'findSubgroups', 'getCoords',
1415
'getLinkage', 'getTreeFromLinkage', 'clusterSubfamilies',
@@ -239,7 +240,7 @@ def getTreeFromLinkage(names, linkage):
239240
:arg linkage: linkage matrix
240241
:type linkage: :class:`~numpy.ndarray`
241242
"""
242-
try:
243+
try:
243244
from Bio.Phylo.BaseTree import Tree, Clade
244245
except ImportError:
245246
raise ImportError('Phylo module could not be imported. '
@@ -367,7 +368,7 @@ def writeTree(filename, tree, format_str='newick'):
367368
:arg format_str: a string specifying the format for the tree
368369
:type format_str: str
369370
"""
370-
try:
371+
try:
371372
from Bio import Phylo
372373
except ImportError:
373374
raise ImportError('Phylo module could not be imported. '
@@ -385,6 +386,29 @@ def writeTree(filename, tree, format_str='newick'):
385386

386387
Phylo.write(tree, filename, format_str)
387388

389+
def parseTree(filename, format_str='newick'):
390+
""" Parse a tree from a file using Biopython.
391+
392+
:arg filename: name of the input file to parse
393+
:type filename: str
394+
395+
:arg format_str: a string specifying the format for the tree
396+
:type format_str: str
397+
"""
398+
try:
399+
from Bio import Phylo
400+
except ImportError:
401+
raise ImportError('Phylo module could not be imported. '
402+
'Reinstall ProDy or install Biopython '
403+
'to solve the problem.')
404+
405+
if not isinstance(filename, str):
406+
raise TypeError('filename should be a string')
407+
408+
if not isinstance(format_str, str):
409+
raise TypeError('format_str should be a string')
410+
411+
return Phylo.read(filename, format_str)
388412

389413
def clusterMatrix(distance_matrix=None, similarity_matrix=None, labels=None, return_linkage=None, **kwargs):
390414
"""

0 commit comments

Comments
 (0)