
Commit 1031855

V0.1.0 mvp (#50)
* Refactoring and documentation improvements for the 0.1.0 release
1 parent bdd0776 commit 1031855

66 files changed, +977 −1025 lines changed

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+name: build_and_publish_release
+
+on:
+  workflow_dispatch:
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    if: github.ref == 'refs/heads/main'
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          virtualenvs-create: false
+
+      - name: Configure Poetry
+        run: |
+          poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }}
+          poetry config http-basic.pypi __token__ ${{ secrets.PYPI_API_TOKEN }}
+
+      - name: Build and publish
+        run: |
+          poetry build
+          poetry publish

.github/workflows/test_release.yml

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+name: test_release
+
+on:
+  push:
+    branches: '**'
+  pull_request:
+    branches:
+      - main
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/dev' && github.ref != 'refs/heads/main' }}
+
+jobs:
+  test_full:
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+        os: [macOS-latest, windows-latest, ubuntu-latest]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: set up python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: setup poetry and install dependencies
+        run: |
+          python -m pip install --upgrade pip poetry==1.8.4
+
+      - name: build release
+        run: |
+          python -m poetry build
+
+      - name: install and test installed package
+        shell: bash
+        run: |
+          python -m venv test_env
+          . ${GITHUB_WORKSPACE}/test_env/bin/activate || . ${GITHUB_WORKSPACE}/test_env/Scripts/activate
+          pip install ./dist/*.whl
+          pip install pytest
+          # Debug information
+          echo "Current directory: $(pwd)"
+          echo "Directory contents:"
+          ls -la
+          # Actually run the tests with explicit path
+          python -m pytest tests/ -v
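
For reference (not part of this commit), the new test_release job can be approximated locally. The sketch below mirrors its steps under the stated assumptions: Poetry is installed, the wheel lands in dist/, and the suite lives in tests/, exactly as the workflow above expects.

```python
# Local approximation of the "install and test installed package" CI job (a sketch,
# not part of the commit): build a wheel, install it into a fresh venv with pytest,
# then run the test suite against the installed package.
import subprocess
import sys
import venv
from pathlib import Path


def build_and_test(repo_root: str = ".") -> None:
    root = Path(repo_root).resolve()
    # Build the distribution the same way the workflow does.
    subprocess.run([sys.executable, "-m", "poetry", "build"], cwd=root, check=True)

    # Create an isolated environment, mirroring `python -m venv test_env`.
    env_dir = root / "test_env"
    venv.create(env_dir, with_pip=True)
    bin_dir = env_dir / ("Scripts" if sys.platform == "win32" else "bin")

    # Install the freshly built wheel plus pytest (assumes exactly one wheel in dist/).
    wheel = next((root / "dist").glob("*.whl"))
    subprocess.run([str(bin_dir / "pip"), "install", str(wheel), "pytest"], check=True)

    # Run the tests with an explicit path, as the workflow does.
    subprocess.run(
        [str(bin_dir / "python"), "-m", "pytest", str(root / "tests"), "-v"],
        check=True,
    )


if __name__ == "__main__":
    build_and_test()
```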

README.md

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ Choose LLMs for generating and validating dialogue graph and invoke graph genera
 
 ```python
 from dialogue2graph.datasets.complex_dialogues.generation import LoopedGraphGenerator
-from langchain_openai import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI
 
 
 gen_model = ChatOpenAI(

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 from dialogue2graph.datasets.complex_dialogues.generation import CycleGraphGenerator
+from dialogue2graph.datasets.augment_dialogues.augmentation import DialogueAugmenter
 
-__all__ = ["CycleGraphGenerator"]
+__all__ = ["CycleGraphGenerator", "DialogueAugmenter"]

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from dialogue2graph.datasets.augment_dialogues.augmentation import DialogueAugmenter
+
+__all__ = [
+    "DialogueAugmenter",
+]

dialogue2graph/datasets/augment_dialogues/augmentation.py

Lines changed: 48 additions & 43 deletions
@@ -8,27 +8,31 @@
 from dialogue2graph.pipelines.core.algorithms import DialogAugmentation
 from dialogue2graph.pipelines.core.dialogue import Dialogue
 from dialogue2graph.pipelines.model_storage import ModelStorage
-from dialogue2graph.metrics.no_llm_metrics.metrics import (
-    is_correct_length, match_roles
-)
+from dialogue2graph.metrics.no_llm_metrics.metrics import is_correct_length, match_roles
 
 logging.getLogger("langchain_core.vectorstores.base").setLevel(logging.ERROR)
 
+
 class AugmentedTurn(BaseModel):
     """Dialogue turn to augment"""
+
     participant: str
-    text: list[str] = Field(..., description="List of utterance variations for this turn")
+    text: list[str] = Field(
+        ..., description="List of utterance variations for this turn"
+    )
+
 
 class DialogueSequence(BaseModel):
     """Result as dialogue sequence"""
+
     result: list[AugmentedTurn] = Field(..., description="Sequence of augmented turns")
 
 
 class DialogueAugmenter(DialogAugmentation):
     """Class for dialogue augmentation.
-
+
     Augments dialogues while preserving structure and conversation flow by rephrasing original dialogue lines."""
-
+
     model_storage: ModelStorage = Field(..., description="Model storage instance")
     generation_llm: str = Field(..., description="Key for generation LLM in storage")
     formatting_llm: str = Field(..., description="Key for formatting LLM in storage")
@@ -40,33 +44,34 @@ def invoke(
         topic: str = "",
     ) -> Union[list[Dialogue], str]:
         """Augment dialogue while preserving conversation structure.
-
+
         Args:
             dialogue: Input Dialogue object to augment
             prompt: Required augmentation prompt template
            topic: Contextual topic for augmentation (default: empty)
-
+
         Returns:
             List of augmented Dialogue objects or error message
         """
-        if prompt == '':
-            return 'Preprocessing failed: prompt should be a valid instruction for LLM'
-
+        if prompt == "":
+            return "Preprocessing failed: prompt should be a valid instruction for LLM"
+
         try:
             message_dicts = [msg.model_dump() for msg in dialogue.messages]
             if message_dicts == []:
-                return 'Preprocessing failed: no messages found in the dialogue'
-
+                return "Preprocessing failed: no messages found in the dialogue"
+
             augmentation_prompt = PromptTemplate.from_template(prompt)
             parser = JsonOutputParser(pydantic_object=DialogueSequence)
-
+
             fixed_parser = OutputFixingParser.from_llm(
-                parser=parser,
-                llm=self._get_llm(self.formatting_llm)
+                parser=parser, llm=self._get_llm(self.formatting_llm)
+            )
+
+            chain = (
+                augmentation_prompt | self._get_llm(self.generation_llm) | fixed_parser
             )
 
-            chain = augmentation_prompt | self._get_llm(self.generation_llm) | fixed_parser
-
             for attempt in range(3):
                 try:
                     result = chain.invoke({"topic": topic, "dialogue": message_dicts})
@@ -76,58 +81,55 @@ def invoke(
                     except Exception as e:
                         logging.error(f"Error creating dialogues: {str(e)}")
                         return f"Post-processing failed: {str(e)}"
-
+
                 except ValidationError as ve:
-                    logging.warning(f"Validation error attempt {attempt+1}: {ve}")
+                    logging.warning(f"Validation error attempt {attempt + 1}: {ve}")
 
                 except Exception as e:
                     logging.error(f"Unexpected error: {str(e)}")
                     if attempt == 2:
                         return f"Augmentation failed: {str(e)}"
-
+
             return "Augmentation failed after 3 attempts"
-
+
         except Exception as e:
             logging.exception("Critical error in augmentation pipeline")
             return f"Critical error: {str(e)}"
 
     async def ainvoke(self, *args, **kwargs):
         """Async version of invoke"""
         return self.invoke(*args, **kwargs)
-
-    async def evaluate(
-        self,
-        dialogue: Dialogue,
-        prompt: str,
-        topic: str = ""
-    ) -> dict:
+
+    async def evaluate(self, dialogue: Dialogue, prompt: str, topic: str = "") -> dict:
         """Evaluate augmentation quality with dictionary report format."""
         result = self.invoke(dialogue, prompt, topic)
-
+
         if isinstance(result, str):
             return {"error": result}
-
-        report = {}
+
+        report = {}
         for i, augmented_dialogue in enumerate(result):
-            try:
-                report[f'augmented_dialogue_{i}'] = {
+            try:
+                report[f"augmented_dialogue_{i}"] = {
                     "match_roles": match_roles(dialogue, augmented_dialogue),
-                    "correct_length": is_correct_length(dialogue, augmented_dialogue)
+                    "correct_length": is_correct_length(dialogue, augmented_dialogue),
                 }
             except Exception as e:
-                logging.error(f"Error while calculating metrics: {str(e)}")
+                logging.error(f"Error while calculating metrics: {str(e)}")
         return report
 
     def _get_llm(self, llm_key: str):
         """Get model from model storage safely"""
         if llm_key not in self.model_storage.storage:
             raise ValueError(f"LLM key '{llm_key}' not found in model storage")
         return self.model_storage.storage[llm_key].model
-
-    def _combine_one_dialogue(self, augmentation_result: DialogueSequence, i: int) -> dict:
+
+    def _combine_one_dialogue(
+        self, augmentation_result: DialogueSequence, i: int
+    ) -> dict:
         """Combine new augmented dialogues from utterance variations"""
         new_augmented_dialogue = {}
-        new_augmented_dialogue['messages'] = []
+        new_augmented_dialogue["messages"] = []
         roles_to_add = [turn.participant for turn in augmentation_result.result]
         utterances_to_add = [turn.text[i] for turn in augmentation_result.result]
 
@@ -139,13 +141,13 @@ def _combine_one_dialogue(self, augmentation_result: DialogueSequence, i: int) -
 
         return new_augmented_dialogue
 
-    def _create_dialogues(self, result: dict) -> list[Dialogue]:
+    def _create_dialogues(self, result: dict) -> list[Dialogue]:
         """Create a list of Dialogue objects"""
         try:
             augmentation_result = DialogueSequence(result=result)
         except Exception as e:
             logging.error(f"Wrong type of augmentation result: {str(e)}")
-            return f"Creating a list of Dialogue objects failed: {str(e)}"
+            return f"Creating a list of Dialogue objects failed: {str(e)}"
 
         utterances_lists = [turn.text for turn in augmentation_result.result]
         lens = [len(uttr_list) for uttr_list in utterances_lists]
@@ -154,5 +156,8 @@ def _create_dialogues(self, result: dict) -> list[Dialogue]:
         for i in range(min(lens)):
             new_augmented_dialogue = self._combine_one_dialogue(augmentation_result, i)
             augmented_dialogues.append(new_augmented_dialogue)
-
-        return [Dialogue.from_list(new_augmented_dialogue['messages']) for new_augmented_dialogue in augmented_dialogues]
+
+        return [
+            Dialogue.from_list(new_augmented_dialogue["messages"])
+            for new_augmented_dialogue in augmented_dialogues
+        ]
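
A minimal usage sketch of the reworked DialogueAugmenter (not part of the commit): the field names, the invoke() signature, and the ModelStorage.add() shape are taken from this diff, while the storage keys, model names, message schema, and prompt wording below are illustrative assumptions.

```python
# Hedged usage sketch for DialogueAugmenter; keys, model names, message-dict keys and
# the prompt text are assumptions, the API shape follows this diff.
import os

from dialogue2graph.pipelines.model_storage import ModelStorage
from dialogue2graph.pipelines.core.dialogue import Dialogue
from dialogue2graph.datasets.augment_dialogues.augmentation import DialogueAugmenter

storage = ModelStorage()  # assumes an empty storage can be constructed directly
for key in ("aug_generation_llm", "aug_formatting_llm"):  # hypothetical keys
    storage.add(
        key=key,
        config={
            "name": "gpt-4o-mini",  # assumed model name
            "api_key": os.getenv("OPENAI_API_KEY"),
            "base_url": os.getenv("OPENAI_BASE_URL"),
        },
        model_type="llm",
    )

augmenter = DialogueAugmenter(
    model_storage=storage,
    generation_llm="aug_generation_llm",
    formatting_llm="aug_formatting_llm",
)

# The prompt template must reference {topic} and {dialogue}; invoke() rejects an
# empty prompt with a "Preprocessing failed" message.
prompt = (
    "Rephrase every turn of this dialogue about {topic}, keeping roles and turn "
    "order intact, and return several variations per turn: {dialogue}"
)

dialogue = Dialogue.from_list(  # message-dict keys are an assumed schema
    [
        {"participant": "assistant", "text": "How can I help you today?"},
        {"participant": "user", "text": "I need to book a table for two."},
    ]
)

result = augmenter.invoke(dialogue, prompt, topic="restaurant booking")
if isinstance(result, str):
    print("Augmentation failed:", result)
else:
    print(f"Got {len(result)} augmented dialogues")
```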

dialogue2graph/datasets/complex_dialogues/generation.py

Lines changed: 53 additions & 4 deletions
@@ -6,6 +6,7 @@
 """
 
 import logging
+import os
 from enum import Enum
 from typing import Optional, Dict, Any, Union
 
@@ -62,6 +63,7 @@ class GenerationError(BaseModel):
 
 class CycleGraphGenerator(BaseModel):
     """Class for generating graph with cycles"""
+
     cache: Optional[Any] = Field(default=None, exclude=True)
 
     class Config:
@@ -99,6 +101,7 @@ def evaluate(self, *args, report_type="dict", **kwargs):
 
 class GenerationPipeline(BaseModel):
     """Class for generation pipeline"""
+
     cache: Optional[Any] = Field(default=None, exclude=True)
     generation_model: BaseChatModel
     theme_validation_model: BaseChatModel
@@ -392,10 +395,20 @@ class LoopedGraphGenerator(TopicGraphGenerator):
     """Graph generator for topic-based dialogue generation with model storage support"""
 
     model_storage: ModelStorage = Field(description="Model storage")
-    generation_llm: str = Field(description="LLM for graph generation")
-    validation_llm: str = Field(description="LLM for validation")
-    cycle_ends_llm: str = Field(description="LLM for dialog sampler to find cycle ends")
-    theme_validation_llm: str = Field(description="LLM for theme validation")
+    generation_llm: str = Field(
+        description="LLM for graph generation", default="looped_graph_generation_llm:v1"
+    )
+    validation_llm: str = Field(
+        description="LLM for validation", default="looped_graph_validation_llm:v1"
+    )
+    cycle_ends_llm: str = Field(
+        description="LLM for dialog sampler to find cycle ends",
+        default="looped_graph_cycle_ends_llm:v1",
+    )
+    theme_validation_llm: str = Field(
+        description="LLM for theme validation",
+        default="looped_graph_theme_validation_llm:v1",
+    )
     pipeline: GenerationPipeline
 
     def __init__(
@@ -406,6 +419,42 @@ def __init__(
         cycle_ends_llm: str,
         theme_validation_llm: str,
     ):
+        # check if models are in model storage
+        # if model is not in model storage put the default model there
+        if generation_llm not in model_storage.storage:
+            model_storage.add(
+                key=generation_llm,
+                config={
+                    "name": "gpt-4o-latest",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "base_url": os.getenv("OPENAI_BASE_URL"),
+                },
+                model_type="llm",
+            )
+
+        if validation_llm not in model_storage.storage:
+            model_storage.add(
+                key=validation_llm,
+                config={
+                    "name": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "base_url": os.getenv("OPENAI_BASE_URL"),
+                    "temperature": 0,
+                },
+                model_type="llm",
+            )
+
+        if theme_validation_llm not in model_storage.storage:
+            model_storage.add(
+                key=theme_validation_llm,
+                config={
+                    "name": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "base_url": os.getenv("OPENAI_BASE_URL"),
+                    "temperature": 0,
+                },
+                model_type="llm",
+            )
         super().__init__(
             model_storage=model_storage,
             generation_llm=generation_llm,
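
The constructor above now back-fills the generation, validation, and theme-validation keys with default OpenAI models configured from OPENAI_API_KEY / OPENAI_BASE_URL when those keys are missing from the ModelStorage. The sketch below illustrates that behaviour (not part of the commit); the custom key and model names are illustrative, and any constructor arguments beyond those visible in the diff are assumptions.

```python
# Sketch of the new default-model behaviour; key/model names are assumptions, the
# ModelStorage.add() call shape and the default key values come from this diff.
import os

from dialogue2graph.pipelines.model_storage import ModelStorage
from dialogue2graph.datasets.complex_dialogues.generation import LoopedGraphGenerator

storage = ModelStorage()  # assumes an empty storage can be constructed directly

# Register only the generation model explicitly...
storage.add(
    key="my_generation_llm",  # hypothetical key
    config={
        "name": "gpt-4o",  # assumed model name
        "api_key": os.getenv("OPENAI_API_KEY"),
        "base_url": os.getenv("OPENAI_BASE_URL"),
    },
    model_type="llm",
)

# ...and let __init__ register its defaults (gpt-4o-latest / gpt-3.5-turbo) for the
# generation, validation, and theme-validation keys that are not yet in the storage;
# the cycle-ends key is passed through unchanged.
generator = LoopedGraphGenerator(
    model_storage=storage,
    generation_llm="my_generation_llm",
    validation_llm="looped_graph_validation_llm:v1",
    cycle_ends_llm="looped_graph_cycle_ends_llm:v1",
    theme_validation_llm="looped_graph_theme_validation_llm:v1",
)
```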

dialogue2graph/pipelines/core/pipeline.py

Lines changed: 1 addition & 2 deletions
@@ -24,8 +24,7 @@
 
 
 class BasePipeline(BaseModel):
-    # TODO: add docs
-    """Abstract class for base pipeline"""
+    """Base class for pipelines"""
 
     name: str = Field(description="Name of the pipeline")
     steps: list[
