Skip to content

Commit 9b735e0

Browse files
committed
add calcTree tests to confirm not broken
1 parent e60e937 commit 9b735e0

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""This module contains some unit tests for :mod:`prody.utilities.catchall` module,
2+
starting with calcTree."""
3+
4+
import numpy as np
5+
6+
from prody.tests import unittest
7+
from prody.utilities import calcTree
8+
9+
class TestCalcTree(unittest.TestCase):
10+
11+
def testCalcTreeUPGMA(self):
12+
"""Test calcTree with UPGMA method."""
13+
names = ['A', 'B', 'C', 'D']
14+
distance_matrix = np.array([[0, 1, 2, 1],
15+
[1, 0, 1.5, 2],
16+
[2, 1.5, 0, 2],
17+
[1, 2, 2, 0]])
18+
tree = calcTree(names, distance_matrix, method='upgma')
19+
self.assertIsNotNone(tree)
20+
# Check that tree has 4 leaves and they include the names
21+
leaves = tree.get_terminals()
22+
self.assertEqual(len(leaves), 4)
23+
self.assertEqual(set([leaf.name for leaf in leaves]), set(names))
24+
# Check that the tree has split evenly as expected for UPGMA
25+
self.assertEqual(len(tree.root.clades), 2)
26+
27+
def testCalcTreeNJ(self):
28+
"""Test calcTree with NJ method."""
29+
names = ['A', 'B', 'C', 'D']
30+
distance_matrix = np.array([[0, 1, 2, 1],
31+
[1, 0, 1.5, 2],
32+
[2, 1.5, 0, 2],
33+
[1, 2, 2, 0]])
34+
tree = calcTree(names, distance_matrix, method='nj')
35+
self.assertIsNotNone(tree)
36+
leaves = tree.get_terminals()
37+
# Check that tree has 4 leaves and they include the names
38+
self.assertEqual(len(leaves), 4)
39+
self.assertEqual(set([leaf.name for leaf in leaves]), set(names))
40+
# Check that the tree has split unevenly as expected for NJ
41+
self.assertEqual(len(tree.root.clades), 3)
42+
43+
def testCalcTreeMismatchSize(self):
44+
"""Test calcTree with mismatched names and matrix sizes."""
45+
names = ['A', 'B']
46+
distance_matrix = np.array([[0, 1, 2],
47+
[1, 0, 1.5],
48+
[2, 1.5, 0]])
49+
with self.assertRaises(ValueError):
50+
calcTree(names, distance_matrix)
51+
52+
if __name__ == '__main__':
53+
unittest.main()

0 commit comments

Comments
 (0)