import pandas as pd
import tiktoken
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from pathlib import Path

from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC

@OPERATOR_REGISTRY.register()
class ChunkedPromptedGenerator(OperatorABC):
    """
    Prompt-based generation operator with automatic chunking of inputs.
    - Counts tokens with tiktoken or a HuggingFace AutoTokenizer;
    - Inputs longer than max_chunk_len are split by recursive bisection;
    - Reads content from the given input file path and saves the generated
      result to the corresponding output file path;
    - The final result is the per-chunk generations joined with separator.
    """

    def __init__(
        self,
        llm_serving: LLMServingABC,
        system_prompt: str = "You are a helpful agent.",
        json_schema: dict = None,
        max_chunk_len: int = 128000,
        enc=None,  # any tokenizer supporting len(enc.encode(text)) works, e.g. tiktoken or HuggingFace's AutoTokenizer
        separator: str = "\n",
    ):
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.system_prompt = system_prompt
        self.json_schema = json_schema
        self.max_chunk_len = max_chunk_len
        # Default to tiktoken's cl100k_base, resolved here rather than in the
        # signature so the encoding is not loaded at class-definition time.
        self.enc = enc if enc is not None else tiktoken.get_encoding("cl100k_base")
        self.separator = separator

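    # Illustrative json_schema (hypothetical example): a plain JSON Schema dict
    # that, when provided, is forwarded to the serving backend in run() below
    # to constrain each chunk's output, e.g.:
    #     {"type": "object",
    #      "properties": {"summary": {"type": "string"}},
    #      "required": ["summary"]}
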
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "基于提示词的生成算子,支持长文本自动分chunk。"
                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
                "多个生成结果以separator拼接成最终输出字符串。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
                "- json_schema:可选,生成结果的JSON Schema约束\n"
                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace的AutoTokenizer\n"
                "- separator:chunk结果拼接分隔符,默认为换行符\n"
            )
        else:
            return (
                "Prompt-based generator with recursive chunk splitting. "
                "Splits long text inputs into chunks using recursive bisection so that each chunk does not exceed max_chunk_len tokens. "
                "Reads content from the specified input file paths and saves generated results to the designated output file paths. "
                "The per-chunk results are joined into a single output string with the specified separator.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing the LLMServingABC interface\n"
                "- system_prompt: System prompt defining model behavior, default 'You are a helpful agent.'\n"
                "- max_chunk_len: Maximum token length per chunk, default 128000\n"
                "- input_path_key: Field name for the input file path, default 'input_path'\n"
                "- output_path_key: Field name for the output file path, default 'output_path'\n"
                "- json_schema: Optional JSON Schema constraint on the generated results\n"
                "- enc: Encoder used for token counting; must implement encode. Default is tiktoken's cl100k_base encoding; a HuggingFace AutoTokenizer also works\n"
                "- separator: Separator used to join chunk results, default newline\n"
            )

    # === Token counting ===
    def _count_tokens(self, text: str) -> int:
        return len(self.enc.encode(text))
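
    # Any object whose encode(text) returns a token sequence can serve as
    # `enc`. Illustrative sketch (assumes the `transformers` package is
    # installed; the model name is a placeholder):
    #     from transformers import AutoTokenizer
    #     hf_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #     op = ChunkedPromptedGenerator(llm_serving, enc=hf_tok)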

    # === Recursive bisection chunking ===
    def _split_recursive(self, text: str) -> list[str]:
        """Recursively split text into chunks of at most max_chunk_len tokens."""
        token_len = self._count_tokens(text)
        if token_len <= self.max_chunk_len or len(text) <= 1:
            # The length guard prevents infinite recursion on text that can no
            # longer be halved.
            return [text]
        else:
            mid = len(text) // 2
            left, right = text[:mid], text[mid:]
            return self._split_recursive(left) + self._split_recursive(right)
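
    # Worked example (hypothetical numbers): a 300000-token document with
    # max_chunk_len=128000 splits into two ~150000-token halves; each half
    # still exceeds the limit and splits again, yielding four chunks of
    # roughly 75000 tokens. Bisection is by character index, so the halves
    # are balanced in characters, not exactly in tokens.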

    def run(
        self,
        storage: DataFlowStorage,
        input_path_key: str = "input_path",
        output_path_key: str = "output_path",
    ):
        self.logger.info("Running ChunkedPromptedGenerator...")
        dataframe = storage.read("dataframe")
        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")

        all_llm_inputs = []
        row_chunk_map = []  # number of chunks produced by each row

        # === Collect all chunks first ===
        for i, row in dataframe.iterrows():
            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')

            chunks = self._split_recursive(raw_content)
            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")

            system_prompt = self.system_prompt + "\n"
            llm_inputs = [system_prompt + chunk for chunk in chunks]
            all_llm_inputs.extend(llm_inputs)
            row_chunk_map.append(len(chunks))

        # === One batched, concurrent generation call ===
        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")

        try:
            if self.json_schema:
                all_responses = self.llm_serving.generate_from_input(
                    all_llm_inputs, json_schema=self.json_schema
                )
            else:
                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
        except Exception as e:
            self.logger.error(f"Global generation failed: {e}")
            all_generated_results = [[] for _ in range(len(dataframe))]
        else:
            # === Re-split the flat responses back into per-row groups ===
            all_generated_results = []
            idx = 0
            for num_chunks in row_chunk_map:
                if num_chunks == 0:
                    all_generated_results.append([])
                else:
                    all_generated_results.append(all_responses[idx:idx + num_chunks])
                    idx += num_chunks
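            # Example: with row_chunk_map == [2, 1, 3], the loop above takes
            # slices all_responses[0:2], all_responses[2:3] and all_responses[3:6],
            # restoring the per-row grouping of the flat batch.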

        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
            # Derive the output path from the input path; Path.with_suffix
            # handles dots in directory names that a plain split('.') would not.
            output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt'
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(self.separator.join(gen_results))
            dataframe.at[i, output_path_key] = output_path

        output_file = storage.write(dataframe)
        self.logger.info(f"Generation complete. Output saved to {output_file}")
        return output_path_key
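
# --- Usage sketch (illustrative; names below are hypothetical placeholders) ---
# Assumes `my_serving` is any concrete LLMServingABC implementation and
# `storage` is a DataFlowStorage whose "dataframe" holds an `input_path`
# column of UTF-8 text-file paths:
#
#     op = ChunkedPromptedGenerator(
#         llm_serving=my_serving,
#         system_prompt="Summarize the following document.",
#         max_chunk_len=8000,
#     )
#     op.run(storage, input_path_key="input_path", output_path_key="output_path")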