Skip to content

Commit 9751a1d

Browse files
ModelStorage (#41)
* Implemented ModelStorage class
1 parent 93bd367 commit 9751a1d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1235
-572
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
3131
- name: run tests
3232
run: |
33-
python -m poetry run poe test
33+
python -m poetry run poe test || exit 1

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ __pycache__
1010
.pytest_cache
1111
.mypy_cache
1212
test.ipynb
13-
*.pyc
13+
*.pyc
14+
docs/build

dialogue2graph/cli/commands/generate_data.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,21 @@
11
import json
2-
import os
32
from pathlib import Path
4-
from langchain_openai import ChatOpenAI
5-
from dialogue2graph.datasets.complex_dialogues.generation import LoopedGraphGenerator
3+
from dialogue2graph.pipelines.topic_generation.pipeline import TopicGenerationPipeline
4+
from dialogue2graph.pipelines.model_storage import ModelStorage
5+
6+
ms = ModelStorage()
67

78

89
def generate_data(topic: str, config: dict, output_path: str):
910
"""Generate dialogue data for a given topic"""
1011

11-
if config == {}:
12-
gen_model = ChatOpenAI(
13-
model="gpt-4o",
14-
api_key=os.getenv("OPENAI_API_KEY"),
15-
base_url=os.getenv("OPENAI_BASE_URL"),
16-
)
17-
18-
val_model = ChatOpenAI(
19-
model="gpt-3.5-turbo",
20-
api_key=os.getenv("OPENAI_API_KEY"),
21-
base_url=os.getenv("OPENAI_BASE_URL"),
22-
temperature=0,
23-
)
24-
else:
25-
gen_model = ChatOpenAI(
26-
model=config["models"].get("generation-model", {}).get("name", "gpt-4o"),
27-
temperature=config["models"].get("generation-model", {}).get("temperature", 0.7),
28-
api_key=os.getenv("OPENAI_API_KEY"),
29-
base_url=os.getenv("OPENAI_BASE_URL"),
30-
)
31-
val_model = ChatOpenAI(
32-
model=config["models"].get("validation-model", {}).get("name", "gpt-3.5-turbo"),
33-
temperature=config["models"].get("validation-model", {}).get("temperature", 0.7),
34-
api_key=os.getenv("OPENAI_API_KEY"),
35-
base_url=os.getenv("OPENAI_BASE_URL"),
36-
)
37-
38-
pipeline = LoopedGraphGenerator(
39-
generation_model=gen_model,
40-
validation_model=val_model,
41-
)
12+
if config != {}:
13+
ms.load(config)
14+
15+
pipeline = TopicGenerationPipeline(ms)
16+
17+
result = pipeline.invoke(topic)
18+
print("Result:", result.graph_dict)
4219

4320
result = pipeline.invoke(topic=topic)
4421

dialogue2graph/cli/commands/generate_graph_algo.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,19 @@
11
import json
22
from pathlib import Path
33
from dialogue2graph.pipelines.d2g_algo.pipeline import Pipeline
4-
from dialogue2graph.pipelines.models import ModelsAPI
4+
from dialogue2graph.pipelines.model_storage import ModelStorage
55

6-
models = ModelsAPI()
6+
ms = ModelStorage()
77

88

9-
def generate_algo(dialogues: str, config: dict, output_path: str):
9+
def generate_algo(dialogues: str, config: Path, output_path: str):
1010
"""Generates graph from dialogues via d2g_algo pipeline using parameters from config
1111
and saves graph dictionary to output_path"""
1212

13-
if config == {}:
14-
filler_name = "chatgpt-4o-latest"
15-
formatter_name = "gpt-4o-mini"
16-
filler_temp = 0
17-
formatter_temp = 0
18-
sim_name = "BAAI/bge-m3"
19-
device = "cpu"
20-
else:
21-
filler_name = config["models"].get("filler-model", {}).get("name", "chatgpt-4o-latest")
22-
formatter_name = config["models"].get("formatter-model", {}).get("name", "gpt-4o-mini")
23-
filler_temp = config["models"].get("filler-model", {}).get("temperature", 0)
24-
formatter_temp = config["models"].get("formatter-model", {}).get("temperature", 0)
25-
sim_name = config["models"].get("sim-model", {}).get("name", "BAAI/bge-m3")
26-
device = config["models"].get("sim-model", {}).get("device", "cpu")
27-
28-
filling_llm = models("llm", name=filler_name, temp=filler_temp)
29-
formatting_llm = models("llm", name=formatter_name, temp=formatter_temp)
30-
sim_model = models("similarity", name=sim_name, device=device)
31-
32-
pipeline = Pipeline(filling_llm=filling_llm, formatting_llm=formatting_llm, sim_model=sim_model)
13+
if config != {}:
14+
ms.load(config)
15+
16+
pipeline = Pipeline(ms)
3317

3418
result = pipeline.invoke(dialogues)
3519
print("Result:", result.graph_dict)

dialogue2graph/cli/commands/generate_graph_extender.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,19 @@
11
import json
22
from pathlib import Path
33
from dialogue2graph.pipelines.d2g_extender.pipeline import Pipeline
4-
from dialogue2graph.pipelines.models import ModelsAPI
4+
from dialogue2graph.pipelines.model_storage import ModelStorage
55

6-
models = ModelsAPI()
6+
ms = ModelStorage()
77

88

9-
def generate_extender(dialogues: str, config: dict, output_path: str):
9+
def generate_extender(dialogues: str, config: Path, output_path: str):
1010
"""Generates graph from dialogues via d2g_llm pipeline using parameters from config
1111
and saves graph dictionary to output_path"""
1212

13-
if config == {}:
14-
extender_name = "chatgpt-4o-latest"
15-
extender_temp = 0
16-
filler_name = "chatgpt-4o-latest"
17-
filler_temp = 0
18-
formatter_name = "gpt-4o-mini"
19-
formatter_temp = 0
20-
sim_name = "BAAI/bge-m3"
21-
device = "cpu"
22-
else:
23-
extender_name = config["models"].get("extender-model", {}).get("name", "chatgpt-4o-latest")
24-
extender_temp = config["models"].get("extender-model", {}).get("temperature", 0)
25-
filler_name = config["models"].get("filler-model", {}).get("name", "chatgpt-4o-latest")
26-
filler_temp = config["models"].get("filler-model", {}).get("temperature", 0)
27-
formatter_name = config["models"].get("formatter-model", {}).get("name", "gpt-4o-mini")
28-
formatter_temp = config["models"].get("formatter-model", {}).get("temperature", 0)
29-
sim_name = config["models"].get("sim-model", {}).get("name", "BAAI/bge-m3")
30-
device = config["models"].get("sim-model", {}).get("device", "cpu")
31-
32-
extending_llm = models("llm", name=extender_name, temp=extender_temp)
33-
filling_llm = models("llm", name=filler_name, temp=filler_temp)
34-
formatting_llm = models("llm", name=formatter_name, temp=formatter_temp)
35-
sim_model = models("similarity", name=sim_name, device=device)
36-
37-
pipeline = Pipeline(extending_llm=extending_llm, filling_llm=filling_llm, formatting_llm=formatting_llm, sim_model=sim_model)
13+
if config != {}:
14+
ms.load(config)
15+
16+
pipeline = Pipeline(ms)
3817

3918
result = pipeline.invoke(dialogues)
4019
print("Result:", result.graph_dict)

dialogue2graph/cli/commands/generate_graph_llm.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,18 @@
11
import json
22
from pathlib import Path
33
from dialogue2graph.pipelines.d2g_llm.pipeline import Pipeline
4-
from dialogue2graph.pipelines.models import ModelsAPI
4+
from dialogue2graph.pipelines.model_storage import ModelStorage
55

6-
models = ModelsAPI()
6+
ms = ModelStorage()
77

88

9-
def generate_llm(dialogues: str, config: dict, output_path: str):
9+
def generate_llm(dialogues: str, config: Path, output_path: str):
1010
"""Generates graph from dialogues via d2g_llm pipeline using parameters from config
1111
and saves graph dictionary to output_path"""
1212

13-
if config == {}:
14-
grouper_name = "chatgpt-4o-latest"
15-
grouper_temp = 0
16-
filler_name = "chatgpt-4o-latest"
17-
filler_temp = 0
18-
formatter_name = "gpt-4o-mini"
19-
formatter_temp = 0
20-
sim_name = "BAAI/bge-m3"
21-
device = "cpu"
22-
else:
23-
grouper_name = config["models"].get("grouper-model", {}).get("name", "chatgpt-4o-latest")
24-
grouper_temp = config["models"].get("grouper-model", {}).get("temperature", 0)
25-
filler_name = config["models"].get("filler-model", {}).get("name", "chatgpt-4o-latest")
26-
filler_temp = config["models"].get("filler-model", {}).get("temperature", 0)
27-
formatter_name = config["models"].get("formatter-model", {}).get("name", "gpt-4o-mini")
28-
formatter_temp = config["models"].get("formatter-model", {}).get("temperature", 0)
29-
sim_name = config["models"].get("sim-model", {}).get("name", "BAAI/bge-m3")
30-
device = config["models"].get("sim-model", {}).get("device", "cpu")
31-
32-
grouping_llm = models("llm", name=grouper_name, temp=grouper_temp)
33-
filling_llm = models("llm", name=filler_name, temp=filler_temp)
34-
formatting_llm = models("llm", name=formatter_name, temp=formatter_temp)
35-
sim_model = models("similarity", name=sim_name, device=device)
36-
37-
pipeline = Pipeline(grouping_llm=grouping_llm, filling_llm=filling_llm, formatting_llm=formatting_llm, sim_model=sim_model)
13+
if config != {}:
14+
ms.load(config)
15+
pipeline = Pipeline(ms)
3816

3917
result = pipeline.invoke(dialogues)
4018
print("Result:", result.graph_dict)

dialogue2graph/cli/main.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import click
2-
import yaml
32
from dotenv import load_dotenv
43
from .commands.generate_data import generate_data
54
from .commands.generate_graph_algo import generate_algo
@@ -21,14 +20,7 @@ def cli():
2120
def gen_data(env: str, cfg: str, topic: str, output: str):
2221
"""Generate dialogue data for a given topic"""
2322
load_dotenv(env)
24-
with open(cfg) as stream:
25-
config: dict = {}
26-
try:
27-
config = yaml.safe_load(stream)
28-
except yaml.YAMLError as exc:
29-
print(exc)
30-
31-
generate_data(topic, config, output)
23+
generate_data(topic, cfg, output)
3224

3325

3426
@cli.command()
@@ -39,13 +31,7 @@ def gen_data(env: str, cfg: str, topic: str, output: str):
3931
def gen_graph_algo(env: str, cfg: str, dialogues: str, output: str):
4032
"""Generate graph from dialogues data via d2g_algo pipeline"""
4133
load_dotenv(env)
42-
with open(cfg) as stream:
43-
config: dict = {}
44-
try:
45-
config = yaml.safe_load(stream)
46-
except yaml.YAMLError as exc:
47-
print(exc)
48-
generate_algo(dialogues, config, output)
34+
generate_algo(dialogues, cfg, output)
4935

5036

5137
@cli.command()
@@ -56,13 +42,7 @@ def gen_graph_algo(env: str, cfg: str, dialogues: str, output: str):
5642
def gen_graph_llm(env: str, cfg: str, dialogues: str, output: str):
5743
"""Generate graph from dialogues data via d2g_llm pipeline"""
5844
load_dotenv(env)
59-
with open(cfg) as stream:
60-
config: dict = {}
61-
try:
62-
config = yaml.safe_load(stream)
63-
except yaml.YAMLError as exc:
64-
print(exc)
65-
generate_llm(dialogues, config, output)
45+
generate_llm(dialogues, cfg, output)
6646

6747

6848
@cli.command()
@@ -73,13 +53,7 @@ def gen_graph_llm(env: str, cfg: str, dialogues: str, output: str):
7353
def gen_graph_extender(env: str, cfg: str, dialogues: str, output: str):
7454
"""Generate graph from dialogues data via d2g_llm pipeline"""
7555
load_dotenv(env)
76-
with open(cfg) as stream:
77-
config: dict = {}
78-
try:
79-
config = yaml.safe_load(stream)
80-
except yaml.YAMLError as exc:
81-
print(exc)
82-
generate_extender(dialogues, config, output)
56+
generate_extender(dialogues, cfg, output)
8357

8458

8559
if __name__ == "__main__":
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from dialogue2graph.datasets.core import Dataset

dialogue2graph/datasets/complex_dialogues/generation.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dialogue2graph.pipelines.core.graph import BaseGraph, Graph
1616
from dialogue2graph.pipelines.core.algorithms import TopicGraphGenerator
1717
from dialogue2graph.pipelines.core.schemas import GraphGenerationResult, DialogueGraph
18+
from dialogue2graph.pipelines.model_storage import ModelStorage
1819
from dialogue2graph.utils.prompt_caching import setup_cache, add_uuid_to_prompt
1920

2021
from .prompts import cycle_graph_generation_prompt_informal, cycle_graph_repair_prompt, graph_example
@@ -276,19 +277,30 @@ def __call__(self, topic: str) -> PipelineResult:
276277

277278

278279
class LoopedGraphGenerator(TopicGraphGenerator):
279-
generation_model: BaseChatModel
280-
validation_model: BaseChatModel
280+
"""Graph generator for topic-based dialogue generation with model storage support"""
281+
282+
model_storage: ModelStorage = Field(description="Model storage")
283+
generation_llm: str = Field(description="LLM for graph generation")
284+
validation_llm: str = Field(description="LLM for validation")
285+
theme_validation_llm: str = Field(description="LLM for theme validation")
281286
pipeline: GenerationPipeline
282287

283-
def __init__(self, generation_model: BaseChatModel, validation_model: BaseChatModel, theme_validation_model: BaseChatModel):
288+
def __init__(
289+
self,
290+
model_storage: ModelStorage,
291+
generation_llm: str,
292+
validation_llm: str,
293+
theme_validation_llm: str,
294+
):
284295
super().__init__(
285-
generation_model=generation_model,
286-
validation_model=validation_model,
287-
theme_validation_model=theme_validation_model,
296+
model_storage=model_storage,
297+
generation_llm=generation_llm,
298+
validation_llm=validation_llm,
299+
theme_validation_llm=theme_validation_llm,
288300
pipeline=GenerationPipeline(
289-
generation_model=generation_model,
290-
validation_model=validation_model,
291-
theme_validation_model=theme_validation_model,
301+
generation_model=model_storage.storage[generation_llm].model,
302+
validation_model=model_storage.storage[validation_llm].model,
303+
theme_validation_model=model_storage.storage[theme_validation_llm].model,
292304
generation_prompt=cycle_graph_generation_prompt_informal,
293305
repair_prompt=cycle_graph_repair_prompt,
294306
),
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .dataset import Dataset
2+
3+
__all__ = ["Dataset"]

0 commit comments

Comments
 (0)