Skip to content

Commit 9a6d72c

Browse files
Krsto ProrokovićKrsto Proroković
authored andcommitted
Add cluster_tree attribute and predict function to BAHC
1 parent 9310cb6 commit 9a6d72c

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from sklearn.base import ClusterMixin
2+
from typing import Self
3+
4+
class ClusterNode:
5+
def __init__(self, label: int):
6+
"""
7+
Initialize a node in the cluster tree
8+
9+
Parameters:
10+
-----------
11+
label : int
12+
The cluster label for this node (required as all nodes start as leaves)
13+
"""
14+
self.label = label
15+
self.clustering_model = None
16+
self.children = []
17+
18+
@property
19+
def is_leaf(self):
20+
return len(self.children) == 0
21+
22+
def split(self, clustering_model: ClusterMixin, children: list[Self]):
23+
"""
24+
Split this node by setting its clustering model and adding children
25+
26+
This converts the node to an internal node and removes its label
27+
28+
Parameters:
29+
-----------
30+
clustering_model : ClusterMixin
31+
The clustering model used to split this node
32+
children : list of ClusterNode
33+
The child nodes resulting from the split
34+
"""
35+
self.label = None
36+
self.clustering_model = clustering_model
37+
self.children = children
38+
39+
def get_leaves(self) -> list[Self]:
40+
"""
41+
Get all leaf nodes in the subtree rooted at this node
42+
43+
Returns:
44+
--------
45+
list of ClusterNode
46+
All leaf nodes in the subtree
47+
"""
48+
if not self.children:
49+
return [self]
50+
51+
leaves = []
52+
for child in self.children:
53+
leaves.extend(child.get_leaves())
54+
return leaves

0 commit comments

Comments
 (0)