77import numpy as np
88
99from ragas .prompt import PydanticPrompt
10- from ragas .testset .graph import KnowledgeGraph
10+ from ragas .testset .graph import KnowledgeGraph , Node
1111from ragas .testset .persona import Persona , PersonaList
1212from ragas .testset .synthesizers .multi_hop .base import (
1313 MultiHopQuerySynthesizer ,
@@ -38,9 +38,26 @@ class MultiHopSpecificQuerySynthesizer(MultiHopQuerySynthesizer):
3838 """
3939
4040 name : str = "multi_hop_specific_query_synthesizer"
41+ relation_type : str = "entities_overlap"
42+ property_name : str = "entities"
4143 theme_persona_matching_prompt : PydanticPrompt = ThemesPersonasMatchingPrompt ()
4244 generate_query_reference_prompt : PydanticPrompt = QueryAnswerGenerationPrompt ()
4345
46+ def get_node_clusters (self , knowledge_graph : KnowledgeGraph ) -> t .List [t .Set [Node ]]:
47+
48+ cluster_dict = knowledge_graph .find_direct_clusters (
49+ relationship_condition = lambda rel : (
50+ True if rel .type == self .relation_type else False
51+ )
52+ )
53+ logger .info ("found %d clusters" , len (cluster_dict ))
54+ node_clusters = []
55+ for key_node , list_of_nodes in cluster_dict .items ():
56+ for node in list_of_nodes :
57+ node_clusters .append ((key_node , node ))
58+
59+ return node_clusters
60+
4461 async def _generate_scenarios (
4562 self ,
4663 n : int ,
@@ -61,26 +78,21 @@ async def _generate_scenarios(
6178 4. Return the list of scenarios of length n
6279 """
6380
64- cluster_dict = knowledge_graph .find_direct_clusters (
65- relationship_condition = lambda rel : (
66- True if rel .type == "entities_overlap" else False
81+ node_clusters = self .get_node_clusters (knowledge_graph )
82+
83+ if len (node_clusters ) == 0 :
84+ raise ValueError (
85+ "No clusters found in the knowledge graph. Try changing the relationship condition."
6786 )
68- )
87+
88+ num_sample_per_cluster = int (np .ceil (n / len (node_clusters )))
6989
7090 valid_relationships = [
7191 rel
7292 for rel in knowledge_graph .relationships
73- if rel .type == "entities_overlap"
93+ if rel .type == self . relation_type
7494 ]
75-
76- node_clusters = []
77- for key_node , list_of_nodes in cluster_dict .items ():
78- for node in list_of_nodes :
79- node_clusters .append ((key_node , node ))
80-
81- logger .info ("found %d clusters" , len (cluster_dict ))
8295 scenarios = []
83- num_sample_per_cluster = int (np .ceil (n / len (node_clusters )))
8496
8597 for cluster in node_clusters :
8698 if len (scenarios ) < n :
@@ -106,7 +118,7 @@ async def _generate_scenarios(
106118 overlapped_items ,
107119 PersonaList (personas = persona_list ),
108120 persona_concepts ,
109- property_name = "entities" ,
121+ property_name = self . property_name ,
110122 )
111123 base_scenarios = self .sample_diverse_combinations (
112124 base_scenarios , num_sample_per_cluster
0 commit comments