Skip to content

Commit c2b0512

Browse files
lint
1 parent 5dc7a20 commit c2b0512

File tree

15 files changed

+841
-457
lines changed

15 files changed

+841
-457
lines changed

dialogue2graph/pipelines/model_storage.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def load(self, path: Path):
104104
try:
105105
with open(path, "r") as f:
106106
loaded_storage = yaml.safe_load(f)
107-
107+
108108
for key, config in loaded_storage.items():
109109
self.add(
110110
key=key, config=config, model_type=config.pop("model_type")
@@ -169,7 +169,9 @@ def save(self, path: str):
169169
for model_key in self.storage:
170170
storage_dump[model_key] = {}
171171
storage_dump[model_key]["config"] = self.storage[model_key].config
172-
storage_dump[model_key]["model_type"] = self.storage[model_key].model_type
172+
storage_dump[model_key]["model_type"] = self.storage[
173+
model_key
174+
].model_type
173175
yaml.dump(storage_dump, f)
174176
logger.info(f"Saved {len(self.storage)} models to {path}")
175177
except Exception as e:

docs/source/conf.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
sys.path.insert(0, os.path.abspath("../../dialogue2graph"))
1313

14-
project = 'Dialogue2Graph'
15-
copyright = '2024, Denis Kuznetsov, Anastasia Voznyuk, Andrey Chirkin'
16-
author = 'Denis Kuznetsov, Anastasia Voznyuk, Andrey Chirkin'
14+
project = "Dialogue2Graph"
15+
copyright = "2024, Denis Kuznetsov, Anastasia Voznyuk, Andrey Chirkin"
16+
author = "Denis Kuznetsov, Anastasia Voznyuk, Andrey Chirkin"
1717

1818
# Get the deployment environment
1919
on_github = os.environ.get("GITHUB_ACTIONS") == "true"
@@ -71,7 +71,10 @@
7171
html_static_path = ["_static"]
7272

7373
extlinks = {
74-
'github_source_link': ("https://github.com/deeppavlov/dialogue2graph/tree/dev/%s", None),
74+
"github_source_link": (
75+
"https://github.com/deeppavlov/dialogue2graph/tree/dev/%s",
76+
None,
77+
),
7578
}
7679

7780
# Add these configurations

experiments/exp2025_03_12_rec_models_incrementation/exp2025_03_12_rec_models_incrementation/append_chain.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
# compare_graphs
1818
# )
1919

20+
from dialogue2graph.metrics.no_llm_metrics import is_same_structure
21+
from dialogue2graph.metrics.llm_metrics import compare_graphs
2022

2123
env_settings = EnvSettings()
2224

25+
2326
# @AlgorithmRegistry.register(input_type=list[Dialogue], path_to_result=env_settings.GENERATION_SAVE_PATH, output_type=BaseGraph)
2427
class AppendChain(GraphExtender):
2528
"""
@@ -32,41 +35,48 @@ class AppendChain(GraphExtender):
3235
Returns:
3336
graph
3437
"""
38+
3539
prompt: str = ""
40+
3641
def __init__(self):
3742
super().__init__()
3843
self.prompt = PromptTemplate.from_template(prompt_dialogs_and_graph)
3944

40-
def invoke(self, dialogues: list[Dialogue] = None, graph: Graph = None) -> BaseGraph:
41-
print("model: ",env_settings.GENERATION_MODEL_NAME)
42-
base_model = ChatOpenAI(model=env_settings.GENERATION_MODEL_NAME, api_key=env_settings.OPENAI_API_KEY, base_url=env_settings.OPENAI_BASE_URL, temperature=0)
45+
def invoke(
46+
self, dialogues: list[Dialogue] = None, graph: Graph = None
47+
) -> BaseGraph:
48+
print("model: ", env_settings.GENERATION_MODEL_NAME)
49+
base_model = ChatOpenAI(
50+
model=env_settings.GENERATION_MODEL_NAME,
51+
api_key=env_settings.OPENAI_API_KEY,
52+
base_url=env_settings.OPENAI_BASE_URL,
53+
temperature=0,
54+
)
4355
model = base_model | PydanticOutputParser(pydantic_object=DialogueGraph)
4456

4557
final_prompt = self.prompt.format(
46-
orig_dial=dialogues[0],
47-
orig_graph=graph.graph_dict,
48-
new_dial=dialogues[1]
58+
orig_dial=dialogues[0], orig_graph=graph.graph_dict, new_dial=dialogues[1]
4959
)
5060

5161
result = call_llm_api(final_prompt, model, temp=0)
5262
if result is None:
5363
return Graph(graph_dict={})
5464

5565
graph_dict = result.model_dump()
56-
57-
if not all([e['target'] for e in graph_dict['edges']]):
66+
67+
if not all([e["target"] for e in graph_dict["edges"]]):
5868
return Graph(graph_dict={}), []
5969

6070
result_graph = Graph(graph_dict=graph_dict)
6171
return result_graph
6272

6373
async def ainvoke(self, *args, **kwargs):
6474
return self.invoke(*args, **kwargs)
65-
75+
6676
async def evaluate(self, dialogues, graph, target_graph):
6777
result_graph = self.invoke(dialogues, graph)
68-
# report = {
69-
# "is_same_structure": is_same_structure(result_graph, target_graph),
70-
# "graph_match": compare_graphs(result_graph, target_graph),
71-
# }
78+
report = {
79+
"is_same_structure": is_same_structure(result_graph, target_graph),
80+
"graph_match": compare_graphs(result_graph, target_graph),
81+
}
7282
return report

experiments/exp2025_03_12_rec_models_incrementation/exp2025_03_12_rec_models_incrementation/embedder.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88

99
class EnvSettings(BaseSettings, case_sensitive=True):
10-
11-
model_config = SettingsConfigDict(env_file=os.environ["PATH_TO_ENV"], env_file_encoding="utf-8")
10+
model_config = SettingsConfigDict(
11+
env_file=os.environ["PATH_TO_ENV"], env_file_encoding="utf-8"
12+
)
1213

1314
OPENAI_API_KEY: Optional[str]
1415
OPENAI_BASE_URL: Optional[str]
@@ -21,24 +22,39 @@ class EnvSettings(BaseSettings, case_sensitive=True):
2122
preloaded_models = {}
2223

2324

24-
def compare_strings(first: str, second: str, embeddings: HuggingFaceEmbeddings, embedder_th: float = 0.001) -> bool:
25+
def compare_strings(
26+
first: str,
27+
second: str,
28+
embeddings: HuggingFaceEmbeddings,
29+
embedder_th: float = 0.001,
30+
) -> bool:
2531
"""Calculate pairwise_embedding_distance between two strings based on embeddings
2632
and return True when threshold embedder_th not exceeded
2733
Return False otherwise"""
2834

2935
evaluator_2 = load_evaluator("pairwise_embedding_distance", embeddings=embeddings)
30-
score = evaluator_2.evaluate_string_pairs(prediction=first, prediction_b=second)["score"]
36+
score = evaluator_2.evaluate_string_pairs(prediction=first, prediction_b=second)[
37+
"score"
38+
]
3139
# print("SCORE: ", score)
3240
return score <= embedder_th
3341

3442

35-
def get_similarity(generated: list[str], golden: list[str], model_name: str = "BAAI/bge-m3"):
43+
def get_similarity(
44+
generated: list[str], golden: list[str], model_name: str = "BAAI/bge-m3"
45+
):
3646
""" "Calculate similarity matrix between generated and golden using model model_name"""
3747

3848
if model_name not in preloaded_models:
39-
preloaded_models[model_name] = SentenceTransformer(model_name, device=env_settings.DEVICE)
40-
41-
golden_vectors = preloaded_models[model_name].encode(golden, normalize_embeddings=True)
42-
generated_vectors = preloaded_models[model_name].encode(generated, normalize_embeddings=True)
49+
preloaded_models[model_name] = SentenceTransformer(
50+
model_name, device=env_settings.DEVICE
51+
)
52+
53+
golden_vectors = preloaded_models[model_name].encode(
54+
golden, normalize_embeddings=True
55+
)
56+
generated_vectors = preloaded_models[model_name].encode(
57+
generated, normalize_embeddings=True
58+
)
4359
similarities = generated_vectors @ golden_vectors.T
4460
return similarities

0 commit comments

Comments
 (0)