Skip to content

Commit 8519a8a

Browse files
Authored by yupesh (Yuriy Peshkichev) and NotBioWaste905
Feat/model validation (#48)
* Initialize models outside of the ModelStorage
* ModelStorage.add() now demands a type instead of a string in the model_type field
* Update user guides

---------
Co-authored-by: Yuriy Peshkichev <[email protected]>
Co-authored-by: NotBioWaste905 <[email protected]>
1 parent b36babb commit 8519a8a

File tree

25 files changed

+384
-719
lines changed

25 files changed

+384
-719
lines changed

dialogue2graph/datasets/complex_dialogues/generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
The module provides graph generator capable of creating complex validated graphs.
66
"""
77

8-
import logging
98
import os
109
from enum import Enum
1110
from typing import Optional, Dict, Any, Union
@@ -37,8 +36,9 @@
3736
)
3837

3938
# Configure logging
40-
logging.basicConfig(level=logging.INFO)
41-
logger = logging.getLogger(__name__)
39+
from dialogue2graph.utils.logger import Logger
40+
41+
logger = Logger(__file__)
4242

4343

4444
class ErrorType(str, Enum):

dialogue2graph/metrics/llm_metrics/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
The module contains functions that checks Graphs and Dialogues for various metrics using LLM calls.
66
"""
77

8-
import logging
98
import json
109
from typing import List, TypedDict, Union
1110
from pydantic import BaseModel, Field
@@ -22,8 +21,9 @@
2221
from langchain.chat_models import ChatOpenAI
2322
from langchain.schema import HumanMessage
2423

25-
# Set up logging
26-
logging.basicConfig(level=logging.INFO)
24+
from dialogue2graph.utils.logger import Logger
25+
26+
logger = Logger(__file__)
2727

2828

2929
class InvalidTransition(TypedDict):

dialogue2graph/metrics/llm_validators/validators.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from dialogue2graph.pipelines.model_storage import ModelStorage
1414
from dialogue2graph.metrics.similarity import compare_strings
1515

16-
from langchain_core.language_models.chat_models import BaseChatModel
16+
from langchain_openai import ChatOpenAI
17+
from langchain_core.language_models import BaseChatModel
18+
from langchain_huggingface import HuggingFaceEmbeddings
1719
from langchain.prompts import PromptTemplate
1820
from langchain.output_parsers import PydanticOutputParser
1921

@@ -125,7 +127,10 @@ def is_greeting_repeated_emb_llm(
125127
starts = START_TURNS
126128

127129
if model_storage.storage.get(embedder_name):
128-
if not model_storage.storage.get(embedder_name).model_type == "emb":
130+
if (
131+
not model_storage.storage.get(embedder_name).model_type
132+
== HuggingFaceEmbeddings
133+
):
129134
raise TypeError(f"The {embedder_name} model is not an embedder")
130135
embedder_model = model_storage.storage[embedder_name].model
131136
else:
@@ -134,7 +139,7 @@ def is_greeting_repeated_emb_llm(
134139
)
135140

136141
if model_storage.storage.get(llm_name):
137-
if not model_storage.storage.get(llm_name).model_type == "llm":
142+
if not model_storage.storage.get(llm_name).model_type == ChatOpenAI:
138143
raise TypeError(f"The {llm_name} model is not an LLM")
139144
llm_model = model_storage.storage[llm_name].model
140145
else:
@@ -183,7 +188,10 @@ def is_dialog_closed_too_early_emb_llm(
183188
ends = END_TURNS
184189

185190
if model_storage.storage.get(embedder_name):
186-
if not model_storage.storage.get(embedder_name).model_type == "emb":
191+
if (
192+
not model_storage.storage.get(embedder_name).model_type
193+
== HuggingFaceEmbeddings
194+
):
187195
raise TypeError(f"The {embedder_name} model is not an embedder")
188196
embedder_model = model_storage.storage[embedder_name].model
189197
else:
@@ -192,7 +200,7 @@ def is_dialog_closed_too_early_emb_llm(
192200
)
193201

194202
if model_storage.storage.get(llm_name):
195-
if not model_storage.storage.get(llm_name).model_type == "llm":
203+
if not model_storage.storage.get(llm_name).model_type == ChatOpenAI:
196204
raise TypeError(f"The {llm_name} model is not an LLM")
197205
llm_model = model_storage.storage[llm_name].model
198206
else:

dialogue2graph/metrics/no_llm_metrics/metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
from dialogue2graph.pipelines.core.graph import BaseGraph
1515
from dialogue2graph.pipelines.core.dialogue import Dialogue
16+
from dialogue2graph.utils.logger import Logger
17+
18+
logger = Logger(__file__)
1619

1720

1821
logging.basicConfig(level=logging.INFO)

dialogue2graph/pipelines/core/dialogue_sampling.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77

88
import itertools
9-
import logging
109
from typing import Literal
1110
import pandas as pd
1211
from dialogue2graph.pipelines.core.graph import BaseGraph
@@ -19,8 +18,9 @@
1918
from dialogue2graph.pipelines.helpers.find_cycle_ends import find_cycle_ends
2019
from langchain_core.language_models.chat_models import BaseChatModel
2120

22-
logging.basicConfig(level=logging.INFO)
23-
logger = logging.getLogger(__name__)
21+
from dialogue2graph.utils.logger import Logger
22+
23+
logger = Logger(__file__)
2424

2525

2626
class _DialogPathsCounter:
@@ -189,23 +189,21 @@ def remove_duplicated_paths(node_paths: list[list[int]]) -> list[list[int]]:
189189
return res
190190

191191

192-
def get_dialogue_doublets(seq: list[list[dict]]) -> set[tuple[str]]:
193-
"""Find all dialogue doublets with (edge, target) utterances
194-
195-
Args:
196-
seq: sequence of dialogs
197-
198-
Returns:
199-
Set of (user_utterance, assistant_utterance)
200-
"""
201-
doublets = set()
202-
for dialogue in seq:
203-
user_texts = [d["text"] for d in dialogue if d["participant"] == "user"]
204-
assist_texts = [d["text"] for d in dialogue if d["participant"] == "assistant"]
205-
if len(assist_texts) > len(user_texts):
206-
user_texts += [""]
207-
doublets.update(zip(user_texts, assist_texts))
208-
return doublets
192+
# def get_dialogue_doublets(seq: list[list[dict]]) -> set[tuple[str]]:
193+
# """Find all dialogue doublets with (edge, target) utterances
194+
# Args:
195+
# seq: sequence of dialogs
196+
# Returns:
197+
# Set of (user_utterance, assistant_utterance)
198+
# """
199+
# doublets = set()
200+
# for dialogue in seq:
201+
# user_texts = [d["text"] for d in dialogue if d["participant"] == "user"]
202+
# assist_texts = [d["text"] for d in dialogue if d["participant"] == "assistant"]
203+
# if len(assist_texts) > len(user_texts):
204+
# user_texts += [""]
205+
# doublets.update(zip(user_texts, assist_texts))
206+
# return doublets
209207

210208

211209
def get_dialogue_triplets(seq: list[list[dict]]) -> set[tuple[str]]:
@@ -239,9 +237,7 @@ def remove_duplicated_dialogues(seq: list[list[dict]]) -> list[list[dict]]:
239237
return []
240238
uniq_seq = [non_empty_seq[0]]
241239
for s in non_empty_seq[1:]:
242-
if not get_dialogue_doublets([s]).issubset(
243-
get_dialogue_doublets(uniq_seq)
244-
) or not get_dialogue_triplets([s]).issubset(get_dialogue_triplets(uniq_seq)):
240+
if not get_dialogue_triplets([s]).issubset(get_dialogue_triplets(uniq_seq)):
245241
uniq_seq.append(s)
246242
return uniq_seq
247243

dialogue2graph/pipelines/core/graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from typing import Optional, Any
1111
import matplotlib.pyplot as plt
1212
import abc
13-
import logging
1413

15-
logger = logging.getLogger(__name__)
14+
from dialogue2graph.utils.logger import Logger
15+
16+
logger = Logger(__file__)
1617

1718

1819
class BaseGraph(BaseModel, abc.ABC):
@@ -140,13 +141,13 @@ def load_graph(self):
140141
"""
141142
self.graph = nx.DiGraph()
142143
nodes = sorted([v["id"] for v in self.graph_dict["nodes"]])
143-
logging.debug(f"Nodes: {nodes}")
144+
logger.debug(f"Nodes: {nodes}")
144145

145146
self.node_mapping = {}
146147
renumber_flg = nodes != list(range(1, len(nodes) + 1))
147148
if renumber_flg:
148149
self.node_mapping = {node_id: idx + 1 for idx, node_id in enumerate(nodes)}
149-
logging.debug(f"Renumber flag: {renumber_flg}")
150+
logger.debug(f"Renumber flag: {renumber_flg}")
150151

151152
for node in self.graph_dict["nodes"]:
152153
cur_node_id = node["id"]

dialogue2graph/pipelines/d2g_extender/pipeline.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77

88
from typing import Callable
99
from dotenv import load_dotenv
10+
1011
from dialogue2graph.pipelines.core.pipeline import BasePipeline
1112
from dialogue2graph.pipelines.model_storage import ModelStorage
1213
from dialogue2graph.pipelines.d2g_extender.three_stages_extender import LLMGraphExtender
14+
from langchain_openai import ChatOpenAI
15+
from langchain_huggingface import HuggingFaceEmbeddings
1316

1417
load_dotenv()
1518

@@ -32,6 +35,37 @@ def __init__(
3235
end_evals: list[Callable] = None,
3336
step: int = 2,
3437
):
38+
# if model is not in model storage put the default model there
39+
model_storage.add(
40+
key=extending_llm,
41+
config={"model_name": "chatgpt-4o-latest", "temperature": 0},
42+
model_type="llm",
43+
)
44+
45+
model_storage.add(
46+
key=filling_llm,
47+
config={"mode_name": "o3-mini", "temperature": 1},
48+
model_type=ChatOpenAI,
49+
)
50+
51+
model_storage.add(
52+
key=formatting_llm,
53+
config={"model_name": "gpt-4o-mini", "temperature": 0},
54+
model_type=ChatOpenAI,
55+
)
56+
57+
model_storage.add(
58+
key=dialog_llm,
59+
config={"model_name": "o3-mini", "temperature": 1},
60+
model_type=ChatOpenAI,
61+
)
62+
63+
model_storage.add(
64+
key=sim_model,
65+
config={"model_name": "BAAI/bge-m3", "device": "cpu"},
66+
model_type=HuggingFaceEmbeddings,
67+
)
68+
3569
super().__init__(
3670
name=name,
3771
steps=[

dialogue2graph/pipelines/d2g_extender/three_stages_extender.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from langchain.schema import HumanMessage
1414
from langchain.prompts import PromptTemplate
1515

16+
from dialogue2graph.utils.logger import Logger
1617
from dialogue2graph import metrics
1718
from dialogue2graph.pipelines.core.dialogue_sampling import RecursiveDialogueSampler
1819
from dialogue2graph.pipelines.d2g_light.three_stages_light import LightGraphGenerator
@@ -44,6 +45,8 @@ class DialogueNodes(BaseModel):
4445
logging.basicConfig(level=logging.INFO)
4546
logger = logging.getLogger(__name__)
4647
logging.getLogger("langchain_core.vectorstores.base").setLevel(logging.ERROR)
48+
logger = Logger(__file__)
49+
4750
dialogue_sampler = RecursiveDialogueSampler()
4851

4952

dialogue2graph/pipelines/d2g_light/pipeline.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from dialogue2graph.pipelines.core.pipeline import BasePipeline
1111
from dialogue2graph.pipelines.d2g_light.three_stages_light import LightGraphGenerator
1212
from dialogue2graph.pipelines.model_storage import ModelStorage
13+
from langchain_openai import ChatOpenAI
14+
from langchain_huggingface import HuggingFaceEmbeddings
1315

1416
load_dotenv()
1517

@@ -27,6 +29,22 @@ def __init__(
2729
step2_evals: list[Callable] = None,
2830
end_evals: list[Callable] = None,
2931
):
32+
# if model is not in model storage put the default model there
33+
model_storage.add(
34+
key=filling_llm,
35+
config={"model_name": "chatgpt-4o-latest", "temperature": 0},
36+
model_type=ChatOpenAI,
37+
)
38+
model_storage.add(
39+
key=formatting_llm,
40+
config={"model_name": "gpt-4o-mini", "temperature": 0},
41+
model_type=ChatOpenAI,
42+
)
43+
model_storage.add(
44+
key=sim_model,
45+
config={"model_name": "BAAI/bge-m3", "model_kwargs": {"device": "cpu"}},
46+
model_type=HuggingFaceEmbeddings,
47+
)
3048
super().__init__(
3149
name=name,
3250
steps=[

dialogue2graph/pipelines/d2g_light/three_stages_light.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __init__(
6969
filling_llm: str = "three_stages_light_filling_llm:v1",
7070
formatting_llm: str = "three_stages_light_formatting_llm:v1",
7171
sim_model: str = "three_stages_light_sim_model:v1",
72-
7372
step2_evals: list[Callable] | None = [],
7473
end_evals: list[Callable] | None = [],
7574
):

0 commit comments

Comments (0)