diff --git a/dialog2graph/pipelines/d2g_llm/three_stages_llm.py b/dialog2graph/pipelines/d2g_llm/three_stages_llm.py index 401f4b11..85f27684 100644 --- a/dialog2graph/pipelines/d2g_llm/three_stages_llm.py +++ b/dialog2graph/pipelines/d2g_llm/three_stages_llm.py @@ -14,7 +14,7 @@ from langchain.output_parsers import PydanticOutputParser, OutputFixingParser from langchain.schema import HumanMessage from langchain_openai import ChatOpenAI -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_huggingface.embeddings import HuggingFaceEmbeddings from dialog2graph import metrics diff --git a/dialog2graph/pipelines/helpers/parse_data.py b/dialog2graph/pipelines/helpers/parse_data.py index c59afbe8..e944cac2 100644 --- a/dialog2graph/pipelines/helpers/parse_data.py +++ b/dialog2graph/pipelines/helpers/parse_data.py @@ -19,7 +19,7 @@ logger = Logger(__name__) -RawDialogsType = dict | list[list] | list[dict] | Dialog | list[Dialog] | PosixPath +RawDialogsType = dict | list[list] | list[dict] | list[list[dict]] | Dialog | list[Dialog] | PosixPath ValidatedDialogType = ( List[DialogMessage] | List[List[DialogMessage]] | Dialog | List[Dialog] ) diff --git a/dialog2graph/pipelines/model_storage.py b/dialog2graph/pipelines/model_storage.py index 9bc5a343..7a9d5533 100644 --- a/dialog2graph/pipelines/model_storage.py +++ b/dialog2graph/pipelines/model_storage.py @@ -5,14 +5,13 @@ import yaml import re import dotenv -from pydantic._internal._model_construction import ModelMetaclass -from typing import Union, Dict +from typing import Union, Dict, Type from pathlib import Path from pydantic import BaseModel, Field, model_validator -from langchain_community.chat_models import ChatOpenAI +from langchain_openai.chat_models import ChatOpenAI from langchain_core.language_models import BaseChatModel -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_huggingface.embeddings import HuggingFaceEmbeddings from dialog2graph.utils.logger 
import Logger @@ -62,7 +61,7 @@ class StoredData(BaseModel): Attributes: key (str): Key for the stored model. config (dict): Configuration for the stored model. - model_type (Union[Literal["llm"], Literal["emb"]]): Type of the stored model, either "llm" (language model) or "emb" (embedding model). + model_type (Type[BaseChatModel | HuggingFaceEmbeddings]): Type of the stored model, for example ChatOpenAI, HuggingFaceEmbeddings. model (Union[HuggingFaceEmbeddings, BaseChatModel]): The actual model object, which can either be a HuggingFaceEmbeddings instance or a BaseChatModel instance. Methods: @@ -76,7 +75,7 @@ class StoredData(BaseModel): key: str = Field(description="Key for the stored model") config: dict = Field(description="Configuration for the stored model") - model_type: ModelMetaclass = Field(description="Type of the stored model") + model_type: Type[BaseChatModel | HuggingFaceEmbeddings] = Field(description="Type of the stored model") model: Union[ HuggingFaceEmbeddings, BaseChatModel, @@ -140,7 +139,7 @@ def add( self, key: str, config: dict, - model_type: ModelMetaclass, + model_type: Type[BaseChatModel | HuggingFaceEmbeddings], overwright: bool = False, ): """ @@ -155,7 +154,7 @@ def add( Args: key (str): The unique identifier for the model configuration. config (dict): The configuration dictionary for initializing the model. - model_type (ModelMetaclass): The type name of the model to be added. + model_type (Type[BaseChatModel | HuggingFaceEmbeddings]): The type of the model to be added. overwright (bool): Whether to overwright model existing under same key . 
Raises: diff --git a/examples/pipeline_usage/models_config.yml b/examples/pipeline_usage/models_config.yml index 4a6c6c31..21181317 100644 --- a/examples/pipeline_usage/models_config.yml +++ b/examples/pipeline_usage/models_config.yml @@ -1,20 +1,20 @@ -d2g_llm_filling_llm:v1: +my_filling_model: config: - name: o3-mini - temperature: 1 - model_type: llm -d2g_llm_grouping_llm:v1: + model_name: gpt-3.5-turbo + model_type: ChatOpenAI +three_stages_llm_formatting_llm:v1: config: - name: gpt-4o-latest + model_name: gpt-4o-mini temperature: 0 - model_type: llm -d2g_llm_sim_model:v1: + model_type: ChatOpenAI +three_stages_llm_grouping_llm:v1: + config: + model_name: chatgpt-4o-latest + temperature: 0 + model_type: ChatOpenAI +three_stages_llm_sim_model:v1: config: model_kwargs: device: cpu - model_name: cointegrated/LaBSE-en-ru - model_type: emb -my_model: - config: - name: gpt-3.5-turbo - model_type: llm + model_name: BAAI/bge-m3 + model_type: HuggingFaceEmbeddings diff --git a/examples/pipeline_usage/multiple_dialogs.json b/examples/pipeline_usage/multiple_dialogs.json new file mode 100644 index 00000000..148759ef --- /dev/null +++ b/examples/pipeline_usage/multiple_dialogs.json @@ -0,0 +1,240 @@ +[ + [ + { + "text": "Hello! How can I assist you with your calendar today?", + "participant": "assistant" + }, + { + "text": "I'm having trouble with my calendar not syncing properly.", + "participant": "user" + }, + { + "text": "I'm sorry to hear you're experiencing a sync delay. 
Can you tell me more about the issue?", + "participant": "assistant" + }, + { + "text": "Yes, I checked but it still doesn't sync.", + "participant": "user" + }, + { + "text": "Could you verify that the time zone settings are properly configured on your devices?", + "participant": "assistant" + }, + { + "text": "Alright, please guide me through the steps.", + "participant": "user" + }, + { + "text": "Sure, let me guide you through adjusting your time zone settings.", + "participant": "assistant" + }, + { + "text": "I want to change the time zone instead.", + "participant": "user" + }, + { + "text": "Of course! Which device would you like to adjust the time zone settings on?", + "participant": "assistant" + }, + { + "text": "Sure, let's adjust the time zone settings.", + "participant": "user" + }, + { + "text": "Sure, let me guide you through adjusting your time zone settings.", + "participant": "assistant" + }, + { + "text": "I've followed the steps and it seems to be working now.", + "participant": "user" + }, + { + "text": "Great! Is your calendar syncing correctly now?", + "participant": "assistant" + }, + { + "text": "Actually, it's still not syncing. I want to try something else.", + "participant": "user" + }, + { + "text": "I'm sorry we couldn't resolve the issue. If you need further assistance, feel free to reach out anytime. Have a great day!", + "participant": "assistant" + } + ], + [ + { + "text": "Hello! How can I assist you with your calendar today?", + "participant": "assistant" + }, + { + "text": "I'm having trouble with my calendar not syncing properly.", + "participant": "user" + }, + { + "text": "I'm sorry to hear you're experiencing a sync delay. 
Can you tell me more about the issue?", + "participant": "assistant" + }, + { + "text": "Yes, I checked but it still doesn't sync.", + "participant": "user" + }, + { + "text": "Have you checked if your time zone settings are correct on all your devices?", + "participant": "assistant" + }, + { + "text": "Alright, please guide me through the steps.", + "participant": "user" + }, + { + "text": "Sure, let me guide you through adjusting your time zone settings.", + "participant": "assistant" + }, + { + "text": "I want to change the time zone instead.", + "participant": "user" + }, + { + "text": "Of course! Which device would you like to adjust the time zone settings on?", + "participant": "assistant" + }, + { + "text": "Sure, let's adjust the time zone settings.", + "participant": "user" + }, + { + "text": "Sure, let me guide you through adjusting your time zone settings.", + "participant": "assistant" + }, + { + "text": "I've followed the steps and it seems to be working now.", + "participant": "user" + }, + { + "text": "Great! Is your calendar syncing correctly now?", + "participant": "assistant" + }, + { + "text": "Actually, it's still not syncing. I want to try something else.", + "participant": "user" + }, + { + "text": "I'm sorry we couldn't resolve the issue. If you need further assistance, feel free to reach out anytime. Have a great day!", + "participant": "assistant" + } + ], + [ + { + "text": "Hello! How can I assist you with your calendar today?", + "participant": "assistant" + }, + { + "text": "I'm having trouble with my calendar not syncing properly.", + "participant": "user" + }, + { + "text": "I'm sorry to hear you're experiencing a sync delay. 
Can you tell me more about the issue?", + "participant": "assistant" + }, + { + "text": "Yes, I checked but it still doesn't sync.", + "participant": "user" + }, + { + "text": "Could you verify that the time zone settings are properly configured on your devices?", + "participant": "assistant" + }, + { + "text": "Time zone settings are correct, but the issue persists.", + "participant": "user" + }, + { + "text": "If time zone settings are correct, we can try refreshing the calendar app or reinstalling it.", + "participant": "assistant" + }, + { + "text": "After adjusting, it still doesn't sync.", + "participant": "user" + }, + { + "text": "Great! Is your calendar syncing correctly now?", + "participant": "assistant" + }, + { + "text": "Actually, it's still not syncing. I want to try something else.", + "participant": "user" + }, + { + "text": "I'm sorry we couldn't resolve the issue. If you need further assistance, feel free to reach out anytime. Have a great day!", + "participant": "assistant" + } + ], + [ + { + "text": "Hello! How can I assist you with your calendar today?", + "participant": "assistant" + }, + { + "text": "I'm having trouble with my calendar not syncing properly.", + "participant": "user" + }, + { + "text": "I'm sorry to hear you're experiencing a sync delay. Can you tell me more about the issue?", + "participant": "assistant" + }, + { + "text": "Yes, I checked but it still doesn't sync.", + "participant": "user" + }, + { + "text": "Have you checked if your time zone settings are correct on all your devices?", + "participant": "assistant" + }, + { + "text": "Time zone settings are correct, but the issue persists.", + "participant": "user" + }, + { + "text": "If time zone settings are correct, we can try refreshing the calendar app or reinstalling it.", + "participant": "assistant" + }, + { + "text": "After adjusting, it still doesn't sync.", + "participant": "user" + }, + { + "text": "Great! 
Is your calendar syncing correctly now?", + "participant": "assistant" + }, + { + "text": "Actually, it's still not syncing. I want to try something else.", + "participant": "user" + }, + { + "text": "I'm sorry we couldn't resolve the issue. If you need further assistance, feel free to reach out anytime. Have a great day!", + "participant": "assistant" + } + ], + [ + { + "text": "Hello! How can I assist you with your calendar today?", + "participant": "assistant" + }, + { + "text": "I'm having trouble with my calendar not syncing properly.", + "participant": "user" + }, + { + "text": "I'm sorry to hear you're experiencing a sync delay. Can you tell me more about the issue?", + "participant": "assistant" + }, + { + "text": "I think I'll handle it later. Thanks anyway.", + "participant": "user" + }, + { + "text": "I'm sorry we couldn't resolve the issue. If you need further assistance, feel free to reach out anytime. Have a great day!", + "participant": "assistant" + } + ] +] \ No newline at end of file diff --git a/examples/pipeline_usage/pipeline_example.ipynb b/examples/pipeline_usage/pipeline_example.ipynb index 554536d8..55615da4 100644 --- a/examples/pipeline_usage/pipeline_example.ipynb +++ b/examples/pipeline_usage/pipeline_example.ipynb @@ -11,24 +11,48 @@ "First of all we need to import the `ModelStorage` and `Pipeline` we will be using." 
] }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import getpass\n", + "import logging\n", + "import warnings\n", + "\n", + "logging.disable(level=logging.CRITICAL)\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/askatasuna/Документы/DeepPavlov/chatsky-llm-autoconfig/.venv/lib/python3.12/site-packages/pydantic/_internal/_fields.py:198: UserWarning: Field name \"validate\" in \"Dialog\" shadows an attribute in parent \"BaseModel\"\n", - " warnings.warn(\n", - "2025-04-01 21:17:00,978 - datasets - INFO - PyTorch version 2.6.0 available.\n" - ] - } - ], + "outputs": [], + "source": [ + "# Required to run D2GLLMPipeline\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", + "\n", + "# Replace if you use different base_url\n", + "base_url = \"\"\n", + "if base_url:\n", + " os.environ[\"OPENAI_BASE_URL\"] = base_url" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], "source": [ + "from langchain_openai.chat_models import ChatOpenAI\n", + "\n", "from dialog2graph.pipelines.model_storage import ModelStorage\n", - "from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline" + "from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline\n", + "from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType" ] }, { @@ -42,59 +66,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-04-01 21:17:01,652 - dialog2graph.pipelines.model_storage - INFO - Added llm model 'my_model' to storage\n", - "2025-04-01 21:17:01,707 - dialog2graph.pipelines.model_storage - INFO - Added llm model 'd2g_llm_grouping_llm:v1' to storage\n", - 
"2025-04-01 21:17:01,760 - dialog2graph.pipelines.model_storage - INFO - Added llm model 'd2g_llm_filling_llm:v1' to storage\n", - "2025-04-01 21:17:01,762 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: cointegrated/LaBSE-en-ru\n", - "2025-04-01 21:17:05,250 - dialog2graph.pipelines.model_storage - INFO - Added emb model 'd2g_llm_sim_model:v1' to storage\n" - ] - } - ], + "outputs": [], "source": [ "ms = ModelStorage()\n", "ms.add(\n", - " \"my_model\",\n", - " config={\"name\": \"gpt-3.5-turbo\"},\n", - " model_type=\"llm\",\n", + " \"my_filling_model\",\n", + " config={\"model_name\": \"gpt-3.5-turbo\"},\n", + " model_type=ChatOpenAI,\n", ")\n", - "\n", - "pipe = D2GLLMPipeline(ms, formatting_llm=\"my_model\")" + "pipe = D2GLLMPipeline(\"d2g_pipeline\", model_storage=ms, filling_llm=\"my_filling_model\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example of building graph based on single dialogue" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 50, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-04-01 21:17:10,679 - httpx - INFO - HTTP Request: POST https://api.vsegpt.ru/v1/chat/completions \"HTTP/1.1 200 OK\"\n", - "2025-04-01 21:17:10,721 - chromadb.telemetry.product.posthog - INFO - Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.\n", - "/home/askatasuna/Документы/DeepPavlov/chatsky-llm-autoconfig/dialog2graph/utils/vector_stores.py:42: UserWarning: Relevance scores must be between 0 and 1, got [(Document(id='0', metadata={'id': 0}, page_content='hey there! how can i help you today?'), 0.9999999999998331), (Document(id='1', metadata={'id': 1}, page_content='sure! i can help with that. when is your flight, and where are you departing from?'), 0.3425369380032218), (Document(id='3', metadata={'id': 3}, page_content='no worries! 
feel free to reach out anytime.'), 0.09999855712188555), (Document(id='4', metadata={'id': 4}, page_content=\"you're welcome! have a fantastic trip!\"), 0.012217798441700123), (Document(id='2', metadata={'id': 2}, page_content=\"if you'd prefer, i can send you options for ride-share services instead. would you like that?\"), -0.15093265214397067)]\n", - " docs = self.assistant_store.similarity_search_with_relevance_scores(\n", - "/home/askatasuna/Документы/DeepPavlov/chatsky-llm-autoconfig/dialog2graph/utils/vector_stores.py:42: UserWarning: Relevance scores must be between 0 and 1, got [(Document(id='2', metadata={'id': 2}, page_content=\"if you'd prefer, i can send you options for ride-share services instead. would you like that?\"), 0.9999999999997823), (Document(id='1', metadata={'id': 1}, page_content='sure! i can help with that. when is your flight, and where are you departing from?'), 0.07645962652224825), (Document(id='4', metadata={'id': 4}, page_content=\"you're welcome! have a fantastic trip!\"), -0.1350976596899427), (Document(id='0', metadata={'id': 0}, page_content='hey there! how can i help you today?'), -0.15093282073136471), (Document(id='3', metadata={'id': 3}, page_content='no worries! feel free to reach out anytime.'), -0.1615920958945578)]\n", - " docs = self.assistant_store.similarity_search_with_relevance_scores(\n", - "/home/askatasuna/Документы/DeepPavlov/chatsky-llm-autoconfig/dialog2graph/utils/vector_stores.py:42: UserWarning: Relevance scores must be between 0 and 1, got [(Document(id='3', metadata={'id': 3}, page_content='no worries! feel free to reach out anytime.'), 0.9999999999998337), (Document(id='4', metadata={'id': 4}, page_content=\"you're welcome! have a fantastic trip!\"), 0.1991303893367219), (Document(id='0', metadata={'id': 0}, page_content='hey there! how can i help you today?'), 0.09999855712188555), (Document(id='1', metadata={'id': 1}, page_content='sure! i can help with that. 
when is your flight, and where are you departing from?'), 0.09619067365262324), (Document(id='2', metadata={'id': 2}, page_content=\"if you'd prefer, i can send you options for ride-share services instead. would you like that?\"), -0.1615924330693459)]\n", - " docs = self.assistant_store.similarity_search_with_relevance_scores(\n", - "/home/askatasuna/Документы/DeepPavlov/chatsky-llm-autoconfig/dialog2graph/utils/vector_stores.py:42: UserWarning: Relevance scores must be between 0 and 1, got [(Document(id='4', metadata={'id': 4}, page_content=\"you're welcome! have a fantastic trip!\"), 0.9999999999998329), (Document(id='3', metadata={'id': 3}, page_content='no worries! feel free to reach out anytime.'), 0.1991303893367219), (Document(id='1', metadata={'id': 1}, page_content='sure! i can help with that. when is your flight, and where are you departing from?'), 0.04436421132532331), (Document(id='0', metadata={'id': 0}, page_content='hey there! how can i help you today?'), 0.012217629854306078), (Document(id='2', metadata={'id': 2}, page_content=\"if you'd prefer, i can send you options for ride-share services instead. would you like that?\"), -0.1350977439836396)]\n", - " docs = self.assistant_store.similarity_search_with_relevance_scores(\n", - "2025-04-01 21:17:11,456 - dialog2graph.pipelines.core.graph - WARNING - requires pygraphviz http://pygraphviz.github.io/.\n", - "Install pygraphviz from http://pygraphviz.github.io/ .\n", - "Falling back to default layout.\n" - ] - }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -121,55 +120,139 @@ " {\"text\": \"Alright, thanks anyway.\", \"participant\": \"user\"},\n", " {\"text\": \"You're welcome! Have a fantastic trip!\", \"participant\": \"assistant\"},\n", "]\n", - "graph = pipe.invoke(data)\n", + "data = PipelineRawDataType(dialogs=data)\n", + "graph, report = pipe.invoke(data)\n", "graph.visualise_short(\"transfer\")" ] }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{1: 'greeting',\n", + " 2: 'ask_flight_info',\n", + " 3: 'offer_rideshare',\n", + " 4: 'closing_no_help',\n", + " 5: 'farewell'}" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "{node['id']: node['label'] for node in graph.graph_dict['nodes']}" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Also you can load and dump the models with their configs to the `.yml` files." + "Use multiple dialogues to build complex graphs" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-04-01 21:17:11,641 - dialog2graph.pipelines.model_storage - INFO - Saved 4 models to models_config.yml\n" - ] + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with open(\"multiple_dialogs.json\") as f:\n", + " multiple_dialogs = json.load(f)\n", + "\n", + "data = PipelineRawDataType(dialogs=multiple_dialogs)\n", + "graph, report = pipe.invoke(data)\n", + "graph.visualise_short(\"transfer\")" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{1: 'start',\n", + " 2: 'ask_issue_details',\n", + " 3: 'ask_timezone_verification',\n", + " 4: 'guide_timezone_adjustment',\n", + " 5: 'ask_device_for_timezone',\n", + " 6: 'suggest_refresh_or_reinstall',\n", + " 7: 'ask_sync_confirmation',\n", + " 8: 'closing_unresolved'}" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" } ], + "source": [ + "{node['id']: node['label'] for node in graph.graph_dict['nodes']}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also dump the models with their configs to a `.yml` file and load them back."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], "source": [ "ms.save(\"models_config.yml\")" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "d2g_llm_filling_llm:v1:\n", - " name: o3-mini\n", - " temperature: 1\n", - "d2g_llm_grouping_llm:v1:\n", - " name: gpt-4o-latest\n", - " temperature: 0\n", - "d2g_llm_sim_model:v1:\n", - " model_kwargs:\n", - " device: cpu\n", - " model_name: cointegrated/LaBSE-en-ru\n", - "my_model:\n", - " name: gpt-3.5-turbo\n" + "my_filling_model:\n", + " config:\n", + " model_name: gpt-3.5-turbo\n", + " model_type: ChatOpenAI\n", + "three_stages_llm_formatting_llm:v1:\n", + " config:\n", + " model_name: gpt-4o-mini\n", + " temperature: 0\n", + " model_type: ChatOpenAI\n", + "three_stages_llm_grouping_llm:v1:\n", + " config:\n", + " model_name: chatgpt-4o-latest\n", + " temperature: 0\n", + " model_type: ChatOpenAI\n", + "three_stages_llm_sim_model:v1:\n", + " config:\n", + " model_kwargs:\n", + " device: cpu\n", + " model_name: BAAI/bge-m3\n", + " model_type: HuggingFaceEmbeddings\n" ] } ], @@ -180,7 +263,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "dialog2graph-py3.12 (3.12.3)", "language": "python", "name": "python3" },