 import pandas as pd
+import tiktoken
 from dataflow.utils.registry import OPERATOR_REGISTRY
 from dataflow import get_logger
+from pathlib import Path

 from dataflow.utils.storage import DataFlowStorage
 from dataflow.core import OperatorABC
@@ -87,3 +89,139 @@ def run(self, storage: DataFlowStorage, input_key: str = "raw_content", output_k |
         # Save the updated dataframe to the output file
         output_file = storage.write(dataframe)
         return output_key
+
+
+class ChunkedPromptedGenerator(OperatorABC):
+    """
+    Prompt-based generation operator that automatically chunks long inputs.
+    - Uses tiktoken to count tokens exactly.
+    - Inputs longer than max_chunk_len are split by recursive bisection.
+    - Content is read from the given input file path and the generated result
+      is saved to the corresponding output file path.
+    - The per-chunk generations are joined into one string with `separator`.
+    """
+
+    def __init__(
+        self,
+        llm_serving: LLMServingABC,
+        system_prompt: str = "You are a helpful agent.",
+        json_schema: dict = None,
+        max_chunk_len: int = 128000,
+        enc=None,
+        separator: str = "\n",
+    ):
+        self.logger = get_logger()
+        self.llm_serving = llm_serving
+        self.system_prompt = system_prompt
+        self.json_schema = json_schema
+        self.max_chunk_len = max_chunk_len
+        # Default to tiktoken's cl100k_base encoder; any object with an
+        # encode() method (e.g. a HuggingFace AutoTokenizer) also works.
+        self.enc = enc if enc is not None else tiktoken.get_encoding("cl100k_base")
+        self.separator = separator
+
+    @staticmethod
+    def get_desc(lang: str = "zh"):
+        if lang == "zh":
+            return (
+                "基于提示词的生成算子,支持长文本自动分chunk。"
+                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
+                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
+                "多个生成结果以separator拼接成最终输出字符串。\n"
+                "输入参数:\n"
+                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
+                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
+                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
+                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
+                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
+                "- json_schema:可选,生成结果的JSON Schema约束\n"
+                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
+                "- separator:chunk结果拼接分隔符,默认为换行符\n"
+            )
+        else:
+            return (
+                "Prompt-based generator with recursive chunk splitting. "
+                "Splits long text inputs into chunks by recursive bisection so that no chunk exceeds max_chunk_len tokens. "
+                "Reads content from the specified input file paths and saves generated results to the designated output file paths. "
+                "Multiple generated results are joined into a single string using the specified separator.\n"
+                "Input Parameters:\n"
+                "- llm_serving: LLM serving object implementing the LLMServingABC interface\n"
+                "- system_prompt: System prompt defining model behavior, default is 'You are a helpful agent.'\n"
+                "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
+                "- input_path_key: Field name for the input file path, default is 'input_path'\n"
+                "- output_path_key: Field name for the output file path, default is 'output_path'\n"
+                "- json_schema: Optional JSON Schema constraint for the generated results\n"
+                "- enc: Encoder used for token counting; must provide an encode() method. Default is tiktoken's cl100k_base encoder; a HuggingFace AutoTokenizer also works\n"
+                "- separator: Separator used to join chunk results, default is the newline character\n"
+            )
+
+    # === Token counting ===
+    def _count_tokens(self, text: str) -> int:
+        return len(self.enc.encode(text))
+
+    # === Recursive bisection chunking ===
+    def _split_recursive(self, text: str) -> list[str]:
+        """Recursively split text into chunks of at most max_chunk_len tokens."""
+        token_len = self._count_tokens(text)
+        if token_len <= self.max_chunk_len:
+            return [text]
+        else:
+            # Halve on the character midpoint and recurse on both halves.
+            mid = len(text) // 2
+            left, right = text[:mid], text[mid:]
+            return self._split_recursive(left) + self._split_recursive(right)
+
+    def run(
+        self,
+        storage: DataFlowStorage,
+        input_path_key: str = "input_path",
+        output_path_key: str = "output_path",
+    ):
+        self.logger.info("Running ChunkedPromptedGenerator...")
+        dataframe = storage.read("dataframe")
+        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")
+
+        all_llm_inputs = []
+        row_chunk_map = []  # number of chunks produced for each row
+
+        # === First, collect the chunks of every row ===
+        for i, row in dataframe.iterrows():
+            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')
+
+            chunks = self._split_recursive(raw_content)
+            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")
+
+            system_prompt = self.system_prompt + "\n"
+            llm_inputs = [system_prompt + chunk for chunk in chunks]
+            all_llm_inputs.extend(llm_inputs)
+            row_chunk_map.append(len(chunks))
+
+        # === Send all chunks to the LLM in a single batched call ===
+        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")
+
+        try:
+            if self.json_schema:
+                all_responses = self.llm_serving.generate_from_input(
+                    all_llm_inputs, json_schema=self.json_schema
+                )
+            else:
+                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
+        except Exception as e:
+            self.logger.error(f"Global generation failed: {e}")
+            all_generated_results = [[] for _ in range(len(dataframe))]
+        else:
+            # === Regroup the flat response list into per-row lists ===
+            all_generated_results = []
+            idx = 0
+            for num_chunks in row_chunk_map:
+                if num_chunks == 0:
+                    all_generated_results.append([])
+                else:
+                    all_generated_results.append(all_responses[idx:idx + num_chunks])
+                idx += num_chunks
+
+        # === Write the joined result of each row next to its input file ===
+        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
+            # Derive '<input stem>_llm_output.txt' from the input path; Path.with_suffix
+            # avoids truncating paths that contain additional dots.
+            output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt'
+            with open(output_path, 'w', encoding='utf-8') as f:
+                f.write(self.separator.join(gen_results))
+            dataframe.at[i, output_path_key] = output_path
+
+        output_file = storage.write(dataframe)
+        self.logger.info(f"Generation complete. Output saved to {output_file}")
+        return output_path_key
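For intuition on the chunking step: `_split_recursive` halves the text on the character midpoint until every piece fits the token budget, so the result is a balanced binary split whose pieces each stay at or below `max_chunk_len` tokens (a boundary may fall mid-word, since the split point is a character index rather than a token boundary). A minimal standalone sketch of the same idea, using whitespace-separated words as a stand-in for tiktoken tokens:

```python
# Standalone sketch of the recursive-bisection idea behind _split_recursive.
# Whitespace-separated words stand in for tiktoken tokens here.
def split_recursive(text: str, max_len: int) -> list[str]:
    # Stop once the piece fits the budget; otherwise halve on the character midpoint.
    if len(text.split()) <= max_len:
        return [text]
    mid = len(text) // 2
    return split_recursive(text[:mid], max_len) + split_recursive(text[mid:], max_len)

long_text = " ".join(f"word{i}" for i in range(1000))
chunks = split_recursive(long_text, max_len=100)
assert all(len(c.split()) <= 100 for c in chunks)   # every chunk fits the budget
assert "".join(chunks) == long_text                 # nothing is lost at the boundaries
```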
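And a hypothetical usage sketch of the new operator. The serving object and storage backend are placeholders for whatever `LLMServingABC` implementation and `DataFlowStorage` instance the pipeline already provides, and the prompt, limits, and column names are made up for illustration, not part of this change:

```python
# Hypothetical wiring -- `my_llm_serving` and `storage` stand for an existing
# LLMServingABC implementation and DataFlowStorage backend; they are not
# defined in this PR.
from transformers import AutoTokenizer  # optional: any object with an encode() method

generator = ChunkedPromptedGenerator(
    llm_serving=my_llm_serving,
    system_prompt="Summarize the following document.",
    max_chunk_len=8000,                          # keep chunks well under the model's context window
    enc=AutoTokenizer.from_pretrained("gpt2"),   # alternative to the default cl100k_base encoder
    separator="\n\n",
)

# The dataframe read from `storage` is expected to carry an 'input_path' column;
# the operator writes one '<input stem>_llm_output.txt' file per row and records
# its path under the 'output_path' column.
generator.run(storage, input_path_key="input_path", output_path_key="output_path")
```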