44 GenerationPipeline ,
55 LoopedGraphGenerator ,
66 ErrorType ,
7- GenerationError
7+ GenerationError ,
88)
99from dialogue2graph .pipelines .core .dialogue import Dialogue
1010from dialogue2graph .pipelines .core .graph import Graph , BaseGraph
1111from dialogue2graph .pipelines .core .dialogue_sampling import RecursiveDialogueSampler
12- from dialogue2graph .metrics .automatic_metrics import (
13- jaccard_edges ,
14- jaccard_nodes ,
15- triplet_match ,
16- all_utterances_present
17- )
12+ from dialogue2graph .metrics .automatic_metrics import jaccard_edges , jaccard_nodes , triplet_match , all_utterances_present
1813from langchain_core .language_models .chat_models import BaseChatModel
1914
15+
2016class MockChatModel (BaseChatModel ):
2117 """Mock chat model for testing"""
22-
18+
2319 def _generate (self , * args , ** kwargs ):
2420 return {"generations" : [{"text" : "test response" }]}
25-
21+
2622 def _llm_type (self ):
2723 return "mock"
2824
25+
2926def test_cycle_graph_generator_init ():
3027 """Test CycleGraphGenerator initialization"""
3128 generator = CycleGraphGenerator ()
3229 assert isinstance (generator , CycleGraphGenerator )
3330
31+
3432def test_generation_pipeline_init ():
3533 """Test GenerationPipeline initialization"""
3634 model = MockChatModel ()
37- pipeline = GenerationPipeline (
38- generation_model = model ,
39- validation_model = model ,
40- generation_prompt = None ,
41- repair_prompt = None
42- )
35+ pipeline = GenerationPipeline (generation_model = model , validation_model = model , generation_prompt = None , repair_prompt = None )
4336 assert isinstance (pipeline , GenerationPipeline )
4437
38+
4539def test_looped_graph_generator_init ():
4640 """Test LoopedGraphGenerator initialization"""
4741 model = MockChatModel ()
48- generator = LoopedGraphGenerator (
49- generation_model = model ,
50- validation_model = model
51- )
42+ generator = LoopedGraphGenerator (generation_model = model , validation_model = model )
5243 assert isinstance (generator , LoopedGraphGenerator )
5344
45+
5446def test_dialogue_init ():
5547 """Test Dialogue initialization"""
56- messages = [
57- {"participant" : "assistant" , "text" : "Hello" },
58- {"participant" : "user" , "text" : "Hi" }
59- ]
48+ messages = [{"participant" : "assistant" , "text" : "Hello" }, {"participant" : "user" , "text" : "Hi" }]
6049 dialogue = Dialogue .from_list (messages )
6150 assert isinstance (dialogue , Dialogue )
6251 assert len (dialogue .messages ) == 2
6352
53+
6454def test_graph_init ():
6555 """Test Graph initialization"""
66- graph_dict = {
67- "nodes" : [
68- {
69- "id" : 1 ,
70- "label" : "start" ,
71- "is_start" : True ,
72- "utterances" : ["Hello" ]
73- }
74- ],
75- "edges" : []
76- }
56+ graph_dict = {"nodes" : [{"id" : 1 , "label" : "start" , "is_start" : True , "utterances" : ["Hello" ]}], "edges" : []}
7757 graph = Graph (graph_dict = graph_dict )
7858 assert isinstance (graph , BaseGraph )
7959
60+
8061def test_recursive_dialogue_sampler_init ():
8162 """Test RecursiveDialogueSampler initialization"""
8263 sampler = RecursiveDialogueSampler ()
8364 assert isinstance (sampler , RecursiveDialogueSampler )
8465
66+
8567def test_error_type_enum ():
8668 """Test ErrorType enum initialization"""
8769 assert ErrorType .INVALID_GRAPH_STRUCTURE == "invalid_graph_structure"
@@ -90,12 +72,10 @@ def test_error_type_enum():
9072 assert ErrorType .INVALID_THEME == "invalid_theme"
9173 assert ErrorType .GENERATION_FAILED == "generation_failed"
9274
75+
9376def test_generation_error_init ():
9477 """Test GenerationError initialization"""
95- error = GenerationError (
96- error_type = ErrorType .INVALID_GRAPH_STRUCTURE ,
97- message = "Test error"
98- )
78+ error = GenerationError (error_type = ErrorType .INVALID_GRAPH_STRUCTURE , message = "Test error" )
9979 assert isinstance (error , GenerationError )
10080 assert error .error_type == ErrorType .INVALID_GRAPH_STRUCTURE
101- assert error .message == "Test error"
81+ assert error .message == "Test error"
0 commit comments