Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,24 @@
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from dataflow.core import LLMServingABC
from dataflow.prompts.general_text import ConsistentQueryPrompt, ConsistentResponsePrompt
from dataflow.core.prompt import prompt_restrict
from dataflow.prompts.general_text import ConsistentChatPrompt
from dataflow.core.prompt import DIYPromptABC, prompt_restrict
from typing import Union

@prompt_restrict(
ConsistentQueryPrompt,
ConsistentResponsePrompt
ConsistentChatPrompt
)

@OPERATOR_REGISTRY.register()
class ConsistentChatGenerator(OperatorABC):
def __init__(self, llm_serving: LLMServingABC = None, num_dialogs_per_intent = 20, num_turns_per_dialog = 6, temperature = 0.9, prompt_template : Union[ConsistentChatPrompt, DIYPromptABC] = None):
    """Set up the multi-turn dialogue generator.

    Args:
        llm_serving: LLM serving object used to generate queries and responses.
        num_dialogs_per_intent: Number of dialogs generated per intent.
        num_turns_per_dialog: Number of turns per dialog.
        temperature: Sampling temperature; stored on the instance.
        prompt_template: Prompt object whose ``build_prompt`` is called with
            ``mode="query"`` and ``mode="response"``. When omitted, a default
            ``ConsistentChatPrompt`` is created so the operator stays usable
            without an explicit template.
    """
    self.logger = get_logger()
    self.logger.info(f'Initializing {self.__class__.__name__}...')
    self.llm_serving = llm_serving
    # Based on the topic_dict in the existing prompt, it is recommended to keep
    # this value below 1000 (which can generate 9000 conversation data).
    # Otherwise, add more topic_dict entries in
    # dataflow.prompts.general_text.ConsistentChatPrompt to increase data richness.
    self.num_dialogs_per_intent = num_dialogs_per_intent
    self.num_turns_per_dialog = num_turns_per_dialog
    self.temperature = temperature
    # A bare None here would only fail later, inside run(), with an opaque
    # AttributeError on build_prompt; fall back to the built-in prompt instead.
    # NOTE(review): assumes ConsistentChatPrompt() takes no required arguments,
    # like the two prompt classes it replaces - confirm.
    if prompt_template is None:
        prompt_template = ConsistentChatPrompt()
    self.prompt_template = prompt_template
    self.logger.info(f'{self.__class__.__name__} initialized.')

@staticmethod
Expand All @@ -37,6 +36,7 @@ def get_desc(lang: str = "zh"):
"- num_dialogs_per_intent:每个意图生成的对话数量,默认20\n"
"- num_turns_per_dialog:每个对话的轮次数量,默认6\n"
"- temperature:生成温度,控制输出随机性,默认0.9\n"
"- prompt_template:提示词模板对象,用于定义提示结构\n"
"输出参数:\n"
"- 包含category和conversation字段的DataFrame,其中conversation为多轮对话列表"
)
Expand All @@ -48,6 +48,7 @@ def get_desc(lang: str = "zh"):
"- num_dialogs_per_intent: Number of dialogs generated per intent, default 20\n"
"- num_turns_per_dialog: Number of turns per dialog, default 6\n"
"- temperature: Sampling temperature for generation, default 0.9\n"
"- prompt_template: Prompt template object, for defining the prompt structure\n"
"Output Parameters:\n"
"- DataFrame containing 'category' and 'conversation' fields, where conversation is a list of multi-turn dialogues"
)
Expand All @@ -57,7 +58,7 @@ def get_desc(lang: str = "zh"):
def run(self, storage: DataFlowStorage):

# Step 1: Generate all queries using LLM
all_query_prompts = self.query_prompt.build_prompt(num_dialogs_per_intent=self.num_dialogs_per_intent)
all_query_prompts = self.prompt_template.build_prompt(mode="query", num_dialogs_per_intent=self.num_dialogs_per_intent)
# Step 2: Generate queries by calling llm_serving once
self.logger.info("Generating queries...")
queries_list = self.llm_serving.generate_from_input(user_inputs=all_query_prompts)
Expand All @@ -78,7 +79,7 @@ def run(self, storage: DataFlowStorage):
for queries in valid_queries:
category = queries.get("category")
turns = queries.get("turns")
all_response_prompts.append(self.response_prompt.build_prompt(topic=category, queries=turns))
all_response_prompts.append(self.prompt_template.build_prompt(mode="response", topic=category, queries=turns))
self.logger.info("Generating responses...")
responses_list = self.llm_serving.generate_from_input(user_inputs=all_response_prompts)

Expand Down
9 changes: 6 additions & 3 deletions dataflow/operators/text_sft/generate/condor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@
import pandas as pd
from dataflow.core import LLMServingABC
from dataflow.prompts.general_text import CondorQuestionPrompt
from dataflow.core.prompt import prompt_restrict
from dataflow.core.prompt import DIYPromptABC, prompt_restrict
from typing import Union

@prompt_restrict(
CondorQuestionPrompt
)

@OPERATOR_REGISTRY.register()
class CondorGenerator(OperatorABC):
def __init__(self, llm_serving: LLMServingABC = None, num_samples=15, use_task_diversity=True, prompt_template: Union[CondorQuestionPrompt, DIYPromptABC] = None):
    """Set up the Condor question generator.

    Args:
        llm_serving: LLM serving object used for generation.
        num_samples: Total number of samples to generate. Each prompt yields
            questions at three difficulty levels, so ``num_samples // 3``
            prompts are used.
        use_task_diversity: Whether task scenarios are used to increase
            diversity.
        prompt_template: Prompt object used to build question prompts. When
            omitted, a default ``CondorQuestionPrompt`` is created, matching
            the behavior before this parameter existed.
    """
    # Based on the existing topics, it is recommended to set num_samples below 5000.
    # Otherwise, it is recommended to add topics in
    # dataflow.prompts.general_text.CondorPrompt on your own to increase data richness.
    self.logger = get_logger()
    self.logger.info(f'Initializing {self.__class__.__name__}...')
    self.llm_serving = llm_serving
    self.num_questions = num_samples // 3  # each prompt generates questions at 3 difficulty levels
    # Keep the operator usable when no template is passed: storing None would
    # break the first build_prompt call.
    if prompt_template is None:
        prompt_template = CondorQuestionPrompt()
    self.prompt = prompt_template
    self.use_task_diversity = use_task_diversity  # whether to use task scenarios to increase diversity
    self.logger.info(f'{self.__class__.__name__} initialized.')

Expand All @@ -33,6 +34,7 @@ def get_desc(lang: str = "zh"):
"输入参数:\n"
"- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
"- num_samples:生成样本总数,建议小于5000,默认值为15\n"
"- prompt_template:提示词模板对象,用于定义提示结构\n"
"输出参数:\n"
"- 包含'difficulty'、'instruction'和'output'字段的DataFrame\n"
"- 返回生成的DataFrame用于后续处理"
Expand All @@ -44,6 +46,7 @@ def get_desc(lang: str = "zh"):
"Input Parameters:\n"
"- llm_serving: LLM serving object implementing LLMServingABC interface\n"
"- num_samples: Total number of samples to generate, recommended to be less than 5000, default is 15\n\n"
"- prompt_template: Prompt template object, for defining the prompt structure\n"
"Output Parameters:\n"
"- DataFrame containing 'difficulty', 'instruction', and 'output' fields\n"
"- Returns generated DataFrame for subsequent processing"
Expand Down
19 changes: 10 additions & 9 deletions dataflow/operators/text_sft/refine/condor_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,21 @@
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from dataflow.core import LLMServingABC
from dataflow.prompts.general_text import CondorCritiquePrompt, CondorRefinePrompt
from dataflow.core.prompt import prompt_restrict
from dataflow.prompts.general_text import CondorRefinePrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from typing import Union

@prompt_restrict(
CondorCritiquePrompt,
CondorRefinePrompt
)

@OPERATOR_REGISTRY.register()
class CondorRefiner(OperatorABC):
def __init__(self, llm_serving: LLMServingABC = None, prompt_template: Union[CondorRefinePrompt, DIYPromptABC] = None):
    """Set up the Condor refiner.

    Args:
        llm_serving: LLM serving object used to generate critiques and
            refined answers.
        prompt_template: Prompt object whose ``build_prompt`` is called with
            ``mode="critique"`` and ``mode="refine"``. When omitted, a default
            ``CondorRefinePrompt`` is created.
    """
    self.logger = get_logger()
    self.logger.info(f'Initializing {self.__class__.__name__}...')
    self.llm_serving = llm_serving
    # Storing None would make generate_critique / generate_refined_answer fail
    # with an AttributeError; fall back to the built-in prompt instead.
    # NOTE(review): assumes CondorRefinePrompt handles both the "critique" and
    # "refine" modes - confirm against dataflow.prompts.general_text.
    if prompt_template is None:
        prompt_template = CondorRefinePrompt()
    self.prompt_template = prompt_template
    self.logger.info(f'{self.__class__.__name__} initialized.')

@staticmethod
Expand All @@ -33,6 +32,7 @@ def get_desc(lang: str = "zh"):
"- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
"- input_instruction_key:输入指令字段名,默认为'instruction'\n"
"- input_output_key:输入回复字段名,默认为'output'\n"
"- prompt_template:提示词模板对象,用于定义提示结构\n"
"输出参数:\n"
"- 包含优化后回复的DataFrame\n"
"- 返回包含优化后回复字段名的列表,用于后续算子引用"
Expand All @@ -44,7 +44,8 @@ def get_desc(lang: str = "zh"):
"Input Parameters:\n"
"- llm_serving: LLM serving object implementing LLMServingABC interface\n"
"- input_instruction_key: Field name for input instructions, default is 'instruction'\n"
"- input_output_key: Field name for input responses, default is 'output'\n\n"
"- input_output_key: Field name for input responses, default is 'output'\n"
"- prompt_template: Prompt template object, for defining the prompt structure\n"
"Output Parameters:\n"
"- DataFrame containing refined responses\n"
"- List containing refined response field name for subsequent operator reference"
Expand All @@ -56,13 +57,13 @@ def get_desc(lang: str = "zh"):

def generate_critique(self, question, answer):
    """Generate critiques for a batch of question/answer pairs.

    Args:
        question: Iterable of questions.
        answer: Iterable of answers, paired positionally with ``question``.

    Returns:
        Whatever ``llm_serving.generate_from_input`` returns for the batch of
        critique prompts.
    """
    # Build every critique prompt first so the LLM is called once for the batch.
    critique_prompts = [
        self.prompt_template.build_prompt(mode="critique", question=q, answer=a)
        for q, a in zip(question, answer)
    ]
    return self.llm_serving.generate_from_input(critique_prompts)

def generate_refined_answer(self, question, answer, critique):
    """Generate refined answers for a batch of question/answer/critique triples.

    Args:
        question: Iterable of questions.
        answer: Iterable of original answers, paired positionally.
        critique: Iterable of critiques, paired positionally.

    Returns:
        List of refined answers with the ``[Improved Answer Start]`` and
        ``[Improved Answer End]`` markers removed and surrounding whitespace
        stripped.
    """
    # Build every refine prompt first so the LLM is called once for the batch.
    refine_prompts = [
        self.prompt_template.build_prompt(mode="refine", question=q, answer=a, critique=c)
        for q, a, c in zip(question, answer, critique)
    ]
    raw_answers = self.llm_serving.generate_from_input(refine_prompts)
    # Use a distinct loop name so the `answer` parameter is not shadowed.
    return [
        raw.replace('[Improved Answer Start]', '').replace('[Improved Answer End]', '').strip()
        for raw in raw_answers
    ]
Expand Down
Loading