|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import itertools |
9 | | -import logging |
10 | 9 | from typing import Literal |
11 | 10 | import pandas as pd |
12 | 11 | from dialogue2graph.pipelines.core.graph import BaseGraph |
|
19 | 18 | from dialogue2graph.pipelines.helpers.find_cycle_ends import find_cycle_ends |
20 | 19 | from langchain_core.language_models.chat_models import BaseChatModel |
21 | 20 |
|
22 | | -logging.basicConfig(level=logging.INFO) |
23 | | -logger = logging.getLogger(__name__) |
| 21 | +from dialogue2graph.utils.logger import Logger |
| 22 | + |
| 23 | +logger = Logger(__file__) |
24 | 24 |
|
25 | 25 |
|
26 | 26 | class _DialogPathsCounter: |
@@ -189,23 +189,21 @@ def remove_duplicated_paths(node_paths: list[list[int]]) -> list[list[int]]: |
189 | 189 | return res |
190 | 190 |
|
191 | 191 |
|
192 | | -def get_dialogue_doublets(seq: list[list[dict]]) -> set[tuple[str]]: |
193 | | - """Find all dialogue doublets with (edge, target) utterances |
194 | | -
|
195 | | - Args: |
196 | | - seq: sequence of dialogs |
197 | | -
|
198 | | - Returns: |
199 | | - Set of (user_utterance, assistant_utterance) |
200 | | - """ |
201 | | - doublets = set() |
202 | | - for dialogue in seq: |
203 | | - user_texts = [d["text"] for d in dialogue if d["participant"] == "user"] |
204 | | - assist_texts = [d["text"] for d in dialogue if d["participant"] == "assistant"] |
205 | | - if len(assist_texts) > len(user_texts): |
206 | | - user_texts += [""] |
207 | | - doublets.update(zip(user_texts, assist_texts)) |
208 | | - return doublets |
| 192 | +# def get_dialogue_doublets(seq: list[list[dict]]) -> set[tuple[str]]: |
| 193 | +# """Find all dialogue doublets with (edge, target) utterances |
| 194 | +# Args: |
| 195 | +# seq: sequence of dialogs |
| 196 | +# Returns: |
| 197 | +# Set of (user_utterance, assistant_utterance) |
| 198 | +# """ |
| 199 | +# doublets = set() |
| 200 | +# for dialogue in seq: |
| 201 | +# user_texts = [d["text"] for d in dialogue if d["participant"] == "user"] |
| 202 | +# assist_texts = [d["text"] for d in dialogue if d["participant"] == "assistant"] |
| 203 | +# if len(assist_texts) > len(user_texts): |
| 204 | +# user_texts += [""] |
| 205 | +# doublets.update(zip(user_texts, assist_texts)) |
| 206 | +# return doublets |
209 | 207 |
|
210 | 208 |
|
211 | 209 | def get_dialogue_triplets(seq: list[list[dict]]) -> set[tuple[str]]: |
@@ -239,9 +237,7 @@ def remove_duplicated_dialogues(seq: list[list[dict]]) -> list[list[dict]]: |
239 | 237 | return [] |
240 | 238 | uniq_seq = [non_empty_seq[0]] |
241 | 239 | for s in non_empty_seq[1:]: |
242 | | - if not get_dialogue_doublets([s]).issubset( |
243 | | - get_dialogue_doublets(uniq_seq) |
244 | | - ) or not get_dialogue_triplets([s]).issubset(get_dialogue_triplets(uniq_seq)): |
| 240 | + if not get_dialogue_triplets([s]).issubset(get_dialogue_triplets(uniq_seq)): |
245 | 241 | uniq_seq.append(s) |
246 | 242 | return uniq_seq |
247 | 243 |
|
|
0 commit comments