@@ -206,11 +206,15 @@ def __repr__(self) -> str:
206206 def __str__ (self ) -> str :
207207 return self .__repr__ ()
208208
209- def find_clusters (
210- self , relationship_condition : t .Callable [[Relationship ], bool ] = lambda _ : True
209+ def find_indirect_clusters (
210+ self ,
211+ relationship_condition : t .Callable [[Relationship ], bool ] = lambda _ : True ,
212+ depth_limit : int = 3 ,
211213 ) -> t .List [t .Set [Node ]]:
212214 """
213- Finds clusters of nodes in the knowledge graph based on a relationship condition.
215+ Finds indirect clusters of nodes in the knowledge graph based on a relationship condition.
216+ Here if A -> B -> C -> D, then A, B, C, and D form a cluster. If there's also a path A -> B -> C -> E,
217+ it will form a separate cluster.
214218
215219 Parameters
216220 ----------
@@ -223,31 +227,95 @@ def find_clusters(
223227 A list of sets, where each set contains nodes that form a cluster.
224228 """
225229 clusters = []
226- visited = set ()
230+ visited_paths = set ()
227231
228232 relationships = [
229233 rel for rel in self .relationships if relationship_condition (rel )
230234 ]
231235
232- def dfs (node : Node , cluster : t .Set [Node ]):
233- visited .add (node )
236+ def dfs (node : Node , cluster : t .Set [Node ], depth : int , path : t .Tuple [Node , ...]):
237+ if depth >= depth_limit or path in visited_paths :
238+ return
239+ visited_paths .add (path )
234240 cluster .add (node )
241+
235242 for rel in relationships :
236- if rel . source == node and rel . target not in visited :
237- dfs ( rel .target , cluster )
238- # if the relationship is bidirectional, we need to check the reverse
243+ neighbor = None
244+ if rel . source == node and rel .target not in cluster :
245+ neighbor = rel . target
239246 elif (
240247 rel .bidirectional
241248 and rel .target == node
242- and rel .source not in visited
249+ and rel .source not in cluster
243250 ):
244- dfs (rel .source , cluster )
251+ neighbor = rel .source
252+
253+ if neighbor is not None :
254+ dfs (neighbor , cluster .copy (), depth + 1 , path + (neighbor ,))
255+
256+ # Add completed path-based cluster
257+ if len (cluster ) > 1 :
258+ clusters .append (cluster )
245259
246260 for node in self .nodes :
247- if node not in visited :
248- cluster = set ()
249- dfs (node , cluster )
250- if len (cluster ) > 1 :
261+ initial_cluster = set ()
262+ dfs (node , initial_cluster , 0 , (node ,))
263+
264+ # Remove duplicates by converting clusters to frozensets
265+ unique_clusters = [
266+ set (cluster ) for cluster in set (frozenset (c ) for c in clusters )
267+ ]
268+
269+ return unique_clusters
270+
271+ def find_direct_clusters (
272+ self , relationship_condition : t .Callable [[Relationship ], bool ] = lambda _ : True
273+ ) -> t .Dict [Node , t .List [t .Set [Node ]]]:
274+ """
275+ Finds direct clusters of nodes in the knowledge graph based on a relationship condition.
276+ Here if A->B, and A->C, then A, B, and C form a cluster.
277+
278+ Parameters
279+ ----------
280+ relationship_condition : Callable[[Relationship], bool], optional
281+ A function that takes a Relationship and returns a boolean, by default lambda _: True
282+
283+ Returns
284+ -------
285+ List[Set[Node]]
286+ A list of sets, where each set contains nodes that form a cluster.
287+ """
288+
289+ clusters = []
290+ relationships = [
291+ rel for rel in self .relationships if relationship_condition (rel )
292+ ]
293+ for node in self .nodes :
294+ cluster = set ()
295+ cluster .add (node )
296+ for rel in relationships :
297+ if rel .bidirectional :
298+ if rel .source == node :
299+ cluster .add (rel .target )
300+ elif rel .target == node :
301+ cluster .add (rel .source )
302+ else :
303+ if rel .source == node :
304+ cluster .add (rel .target )
305+
306+ if len (cluster ) > 1 :
307+ if cluster not in clusters :
251308 clusters .append (cluster )
252309
253- return clusters
310+ # Remove subsets from clusters
311+ unique_clusters = []
312+ for cluster in clusters :
313+ if not any (cluster < other for other in clusters ):
314+ unique_clusters .append (cluster )
315+ clusters = unique_clusters
316+
317+ cluster_dict = {}
318+ for cluster in clusters :
319+ cluster_dict .update ({cluster .pop (): cluster })
320+
321+ return cluster_dict
0 commit comments