1+ from ._cluster_node import ClusterNode
2+ from collections import deque
13import heapq
24from numbers import Integral
35import numpy as np
@@ -58,22 +60,25 @@ def fit(self, X, y):
5860 order = "C" ,
5961 )
6062 n_samples , _ = X .shape
61- # We start with all samples in a single cluster
63+ # We start with all samples being in a single cluster
6264 self .n_clusters_ = 1
6365 # We assign all samples a label of zero
6466 labels = np .zeros (n_samples , dtype = np .uint32 )
65- clusters = []
67+ leaves = []
6668 scores = []
6769 label = 0
70+ root = ClusterNode (label )
71+ self .cluster_tree_ = root
6872 # The entire dataset has a discrimination score of zero
6973 score = 0
70- heap = [(None , label , score )]
74+ heap = [(None , root , score )]
7175 for _ in range (self .bahc_max_iter ):
7276 if not heap :
7377 # If the heap is empty we stop iterating
7478 break
7579 # Take the cluster with the highest standard deviation of metric y
76- _ , label , score = heapq .heappop (heap )
80+ _ , node , score = heapq .heappop (heap )
81+ label = node .label
7782 cluster_indices = np .nonzero (labels == label )[0 ]
7883 cluster = X [cluster_indices ]
7984
@@ -97,32 +102,73 @@ def fit(self, X, y):
97102 mask1 [indices1 ] = False
98103 score1 = np .mean (y [mask1 ]) - np .mean (y [indices1 ])
99104 if max (score0 , score1 ) >= score :
105+ std0 = np .std (y [indices0 ])
106+ node0 = ClusterNode (label )
100107 # heapq implements min-heap
101108 # so we have to negate std before pushing
102- std0 = np .std (y [indices0 ])
103- heapq .heappush (heap , (- std0 , label , score0 ))
109+ heapq .heappush (heap , (- std0 , node0 , score0 ))
104110 std1 = np .std (y [indices1 ])
105- heapq .heappush (heap , (- std1 , self .n_clusters_ , score1 ))
111+ node1 = ClusterNode (self .n_clusters_ )
112+ heapq .heappush (heap , (- std1 , node1 , score1 ))
106113 labels [indices1 ] = self .n_clusters_
114+ # TODO: Increase n_clusters_ by clustering_model.n_clusters_ - 1
107115 self .n_clusters_ += 1
116+ children = [node0 , node1 ]
117+ node .split (clustering_model , children )
108118 else :
109- clusters .append (label )
119+ leaves .append (node )
110120 scores .append (score )
111121 else :
112- clusters .append (label )
122+ leaves .append (node )
113123 scores .append (score )
114124 if heap :
115- clusters = np .concatenate ([clusters , [label for _ , label , _ in heap ]])
125+ # TODO: Check if this can be made more efficient
126+ leaves .extend ((node for _ , node , _ in heap ))
116127 scores = np .concatenate ([scores , [score for _ , _ , score in heap ]])
117128 else :
118- clusters = np .array (clusters )
119129 scores = np .array (scores )
120130
121131 # We sort clusters by decreasing scores
122132 indices = np .argsort (- scores )
123- clusters = clusters [indices ]
124133 self .scores_ = scores [indices ]
125- mapping = np .zeros (self .n_clusters_ , dtype = np .uint32 )
126- mapping [clusters ] = np .arange (self .n_clusters_ , dtype = np .uint32 )
127- self .labels_ = mapping [labels ]
134+ leaf_labels = np .array ([leaf .label for leaf in leaves ])
135+ leaf_labels = leaf_labels [indices ]
136+ # TODO: Check this!!!
137+ for i , leaf in enumerate (leaves ):
138+ leaf .label = leaf_labels [i ]
139+ label_mapping = np .zeros (self .n_clusters_ , dtype = np .uint32 )
140+ label_mapping [leaf_labels ] = np .arange (self .n_clusters_ , dtype = np .uint32 )
141+ self .labels_ = label_mapping [labels ]
128142 return self
143+
144+ def predict (self , X ):
145+ """Predict the cluster labels for the given data.
146+
147+ Parameters
148+ ----------
149+ X : array-like of shape (n_samples, n_features)
150+ """
151+ # TODO: Assert that fit has been called
152+ # TODO: Assert that X has the same number of features as the data used to fit
153+ # TODO: Assert that clustering_model has predict method
154+ # TODO: Validate X
155+ n_samples , _ = X .shape
156+ labels = np .zeros (n_samples , dtype = np .uint32 )
157+ queue = deque ([(self .cluster_tree_ , np .arange (n_samples ))])
158+ while queue :
159+ node , indices = queue .popleft ()
160+ if node .is_leaf :
161+ labels [indices ] = node .label
162+ else :
163+ cluster = X [indices ]
164+ clustering_model = node .clustering_model
165+ cluster_labels = clustering_model .predict (cluster )
166+ if hasattr (clustering_model , "n_clusters_" ):
167+ n_clusters = clustering_model .n_clusters_
168+ else :
169+ n_clusters = len (np .unique (cluster_labels ))
170+ for i in range (n_clusters ):
171+ child_indices = indices [np .nonzero (cluster_labels == i )[0 ]]
172+ if child_indices .size > 0 :
173+ queue .append ((node .children [i ], child_indices ))
174+ return labels
0 commit comments