Skip to content

Commit ff0b76c

Browse files
Working on lint
1 parent 1d9c3b2 commit ff0b76c

File tree

11 files changed

+59
-95
lines changed

11 files changed

+59
-95
lines changed

dialogue2graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from dialogue2graph.pipelines.core.dialogue import Dialogue
2-
from dialogue2graph.pipelines.core.graph import Graph
2+
from dialogue2graph.pipelines.core.graph import Graph

dialogue2graph/datasets/complex_dialogues/generation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
repair_prompt: Optional[PromptTemplate],
8585
min_cycles: int = 0,
8686
max_fix_attempts: int = 2,
87-
use_cache: bool = True
87+
use_cache: bool = True,
8888
):
8989
super().__init__(
9090
generation_model=generation_model,
@@ -93,7 +93,7 @@ def __init__(
9393
repair_prompt=repair_prompt,
9494
min_cycles=min_cycles,
9595
max_fix_attempts=max_fix_attempts,
96-
use_cache= True
96+
use_cache=True,
9797
)
9898

9999
def validate_graph_cycle_requirement(self, graph: BaseGraph, min_cycles: int = 2) -> Dict[str, Any]:
@@ -137,7 +137,7 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
137137
prompt=self.repair_prompt,
138138
invalid_transitions=initial_validation["invalid_transitions"],
139139
graph_json=current_graph.graph_dict,
140-
use_cache=self.use_cache
140+
use_cache=self.use_cache,
141141
)
142142

143143
validation = are_triplets_valid(current_graph, self.validation_model, return_type="detailed")
@@ -199,7 +199,8 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
199199
invalid_transitions = transition_validation["validation_details"]["invalid_transitions"]
200200
return GenerationError(
201201
error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
202-
message=f"Found {len(invalid_transitions)} invalid transitions after {transition_validation['validation_details']['attempts_made']} fix attempts",
202+
message=f"Found {len(invalid_transitions)} invalid transitions"
203+
f"after {transition_validation['validation_details']['attempts_made']} fix attempts",
203204
)
204205

205206
return GraphGenerationResult(graph=transition_validation["graph"].graph_dict, topic=topic, dialogues=sampled_dialogues)

dialogue2graph/datasets/complex_dialogues/prompts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# flake8: noqa
12
from langchain.prompts import PromptTemplate
23

34
cycle_graph_generation_prompt = PromptTemplate.from_template(

dialogue2graph/pipelines/core/dialogue_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dialogue2graph.pipelines.core.graph import BaseGraph
55
from dialogue2graph.pipelines.core.dialogue import Dialogue
66
from dialogue2graph.pipelines.core.algorithms import DialogueGenerator
7-
from dialogue2graph.metrics.automatic_metrics import all_utterances_present, all_roles_correct
7+
from dialogue2graph.metrics.automatic_metrics import all_utterances_present
88

99

1010
# @AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue)

dialogue2graph/pipelines/cycled_graphs/prompts/prompts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# flake8: noqa
12
from langchain.prompts import PromptTemplate
23

34

dialogue2graph/utils/prompt_caching.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
from dotenv import load_dotenv
44
from typing import Optional
55
from langchain_core.globals import set_llm_cache
6-
from langchain_openai import ChatOpenAI
76
from langchain_community.cache import SQLAlchemyCache, Base
87
from langchain_core.load.load import loads
98
from langchain_core.load.dump import dumps
10-
from langchain_core.outputs import ChatGeneration, Generation
9+
from langchain_core.outputs import Generation
1110

1211
from sqlalchemy import Column, Integer, String, create_engine, select, DateTime
1312
from sqlalchemy.sql import func

scripts/codestyle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _run_flake():
2020
# black formats binary operators after line breaks
2121
"--ignore=W503",
2222
"--ignore=E501",
23+
"--ignore=W293",
2324
"--per-file-ignores="
2425
# allow imports in init files without use
2526
"**/__init__.py:F401 ",

scripts/metrics.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,27 @@ def _run_check():
1515
if algorithm in previous_metrics:
1616
differences[algorithm] = {
1717
"all_paths_sampled_diff": metrics.get("all_paths_sampled_avg", 0) - previous_metrics[algorithm].get("all_paths_sampled_avg", 0),
18-
"all_utterances_present_diff": metrics.get("all_utterances_present_avg", 0)
19-
- previous_metrics[algorithm].get("all_utterances_present_avg", 0),
18+
"all_utterances_present_diff": metrics.get("all_utterances_present_avg", 0) -
19+
previous_metrics[algorithm].get("all_utterances_present_avg", 0),
2020
"all_roles_correct_diff": metrics.get("all_roles_correct_avg", 0) - previous_metrics[algorithm].get("all_roles_correct_avg", 0),
2121
"is_correct_length_diff": metrics.get("is_correct_lenght_avg", 0) - previous_metrics[algorithm].get("is_correct_lenght_avg", 0),
2222
"are_triplets_valid_diff": metrics.get("are_triplets_valid", 0) - previous_metrics[algorithm].get("are_triplets_valid", 0),
2323
"is_theme_valid_diff": metrics.get("is_theme_valid_avg", 0) - previous_metrics[algorithm].get("is_theme_valid_avg", 0),
2424
"total_diff": (
25-
metrics.get("all_paths_sampled_avg", 0)
26-
+ metrics.get("all_utterances_present_avg", 0)
27-
+ metrics.get("all_roles_correct_avg", 0)
28-
+ metrics.get("is_correct_lenght_avg", 0)
29-
+ metrics.get("are_triplets_valid", 0)
30-
+ metrics.get("is_theme_valid_avg", 0)
31-
)
32-
- (
33-
previous_metrics[algorithm].get("all_paths_sampled_avg", 0)
34-
+ previous_metrics[algorithm].get("all_utterances_present_avg", 0)
35-
+ previous_metrics[algorithm].get("all_roles_correct_avg", 0)
36-
+ previous_metrics[algorithm].get("is_correct_lenght_avg", 0)
37-
+ previous_metrics[algorithm].get("are_triplets_valid", 0)
38-
+ metrics.get("is_theme_valid_avg", 0)
25+
metrics.get("all_paths_sampled_avg", 0) +
26+
metrics.get("all_utterances_present_avg", 0) +
27+
metrics.get("all_roles_correct_avg", 0) +
28+
metrics.get("is_correct_lenght_avg", 0) +
29+
metrics.get("are_triplets_valid", 0) +
30+
metrics.get("is_theme_valid_avg", 0)
31+
) -
32+
(
33+
previous_metrics[algorithm].get("all_paths_sampled_avg", 0) +
34+
previous_metrics[algorithm].get("all_utterances_present_avg", 0) +
35+
previous_metrics[algorithm].get("all_roles_correct_avg", 0) +
36+
previous_metrics[algorithm].get("is_correct_lenght_avg", 0) +
37+
previous_metrics[algorithm].get("are_triplets_valid", 0) +
38+
metrics.get("is_theme_valid_avg", 0)
3939
),
4040
}
4141

tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# Test package initialization
1+
# Test package initialization

tests/test_initialization.py

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,84 +4,66 @@
44
GenerationPipeline,
55
LoopedGraphGenerator,
66
ErrorType,
7-
GenerationError
7+
GenerationError,
88
)
99
from dialogue2graph.pipelines.core.dialogue import Dialogue
1010
from dialogue2graph.pipelines.core.graph import Graph, BaseGraph
1111
from dialogue2graph.pipelines.core.dialogue_sampling import RecursiveDialogueSampler
12-
from dialogue2graph.metrics.automatic_metrics import (
13-
jaccard_edges,
14-
jaccard_nodes,
15-
triplet_match,
16-
all_utterances_present
17-
)
12+
from dialogue2graph.metrics.automatic_metrics import jaccard_edges, jaccard_nodes, triplet_match, all_utterances_present
1813
from langchain_core.language_models.chat_models import BaseChatModel
1914

15+
2016
class MockChatModel(BaseChatModel):
2117
"""Mock chat model for testing"""
22-
18+
2319
def _generate(self, *args, **kwargs):
2420
return {"generations": [{"text": "test response"}]}
25-
21+
2622
def _llm_type(self):
2723
return "mock"
2824

25+
2926
def test_cycle_graph_generator_init():
3027
"""Test CycleGraphGenerator initialization"""
3128
generator = CycleGraphGenerator()
3229
assert isinstance(generator, CycleGraphGenerator)
3330

31+
3432
def test_generation_pipeline_init():
3533
"""Test GenerationPipeline initialization"""
3634
model = MockChatModel()
37-
pipeline = GenerationPipeline(
38-
generation_model=model,
39-
validation_model=model,
40-
generation_prompt=None,
41-
repair_prompt=None
42-
)
35+
pipeline = GenerationPipeline(generation_model=model, validation_model=model, generation_prompt=None, repair_prompt=None)
4336
assert isinstance(pipeline, GenerationPipeline)
4437

38+
4539
def test_looped_graph_generator_init():
4640
"""Test LoopedGraphGenerator initialization"""
4741
model = MockChatModel()
48-
generator = LoopedGraphGenerator(
49-
generation_model=model,
50-
validation_model=model
51-
)
42+
generator = LoopedGraphGenerator(generation_model=model, validation_model=model)
5243
assert isinstance(generator, LoopedGraphGenerator)
5344

45+
5446
def test_dialogue_init():
5547
"""Test Dialogue initialization"""
56-
messages = [
57-
{"participant": "assistant", "text": "Hello"},
58-
{"participant": "user", "text": "Hi"}
59-
]
48+
messages = [{"participant": "assistant", "text": "Hello"}, {"participant": "user", "text": "Hi"}]
6049
dialogue = Dialogue.from_list(messages)
6150
assert isinstance(dialogue, Dialogue)
6251
assert len(dialogue.messages) == 2
6352

53+
6454
def test_graph_init():
6555
"""Test Graph initialization"""
66-
graph_dict = {
67-
"nodes": [
68-
{
69-
"id": 1,
70-
"label": "start",
71-
"is_start": True,
72-
"utterances": ["Hello"]
73-
}
74-
],
75-
"edges": []
76-
}
56+
graph_dict = {"nodes": [{"id": 1, "label": "start", "is_start": True, "utterances": ["Hello"]}], "edges": []}
7757
graph = Graph(graph_dict=graph_dict)
7858
assert isinstance(graph, BaseGraph)
7959

60+
8061
def test_recursive_dialogue_sampler_init():
8162
"""Test RecursiveDialogueSampler initialization"""
8263
sampler = RecursiveDialogueSampler()
8364
assert isinstance(sampler, RecursiveDialogueSampler)
8465

66+
8567
def test_error_type_enum():
8668
"""Test ErrorType enum initialization"""
8769
assert ErrorType.INVALID_GRAPH_STRUCTURE == "invalid_graph_structure"
@@ -90,12 +72,10 @@ def test_error_type_enum():
9072
assert ErrorType.INVALID_THEME == "invalid_theme"
9173
assert ErrorType.GENERATION_FAILED == "generation_failed"
9274

75+
9376
def test_generation_error_init():
9477
"""Test GenerationError initialization"""
95-
error = GenerationError(
96-
error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
97-
message="Test error"
98-
)
78+
error = GenerationError(error_type=ErrorType.INVALID_GRAPH_STRUCTURE, message="Test error")
9979
assert isinstance(error, GenerationError)
10080
assert error.error_type == ErrorType.INVALID_GRAPH_STRUCTURE
101-
assert error.message == "Test error"
81+
assert error.message == "Test error"

0 commit comments

Comments
 (0)