|
| 1 | +import copy |
| 2 | +import json |
| 3 | +import threading |
| 4 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional |
| 5 | +from transformers import ( |
| 6 | + AutoTokenizer, |
| 7 | + Pipeline, |
| 8 | + pipeline, |
| 9 | + TextIteratorStreamer, |
| 10 | +) |
| 11 | + |
| 12 | +from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser |
| 13 | +from aiconfig.model_parser import InferenceOptions |
| 14 | +from aiconfig.schema import ( |
| 15 | + ExecuteResult, |
| 16 | + Output, |
| 17 | + OutputDataWithValue, |
| 18 | + Prompt, |
| 19 | + PromptMetadata, |
| 20 | +) |
| 21 | +from aiconfig.util.params import resolve_prompt |
| 22 | + |
# Circular dependency type hints
| 24 | +if TYPE_CHECKING: |
| 25 | + from aiconfig.Config import AIConfigRuntime |
| 26 | + |
| 27 | + |
| 28 | +# Step 1: define Helpers |
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filter model settings down to the generation parameters supported by the
    HuggingFace `transformers` summarization pipeline, dropping anything else.

    Keys are matched case-insensitively and lower-cased in the result. The
    supported set mirrors the `transformers` text-generation config options.

    Args:
        model_settings (Dict[str, Any]): Raw model settings from the prompt metadata.

    Returns:
        Dict[str, Any]: Only the supported completion parameters, keys lower-cased.
    """

    supported_keys = {
        "max_length",
        "max_new_tokens",
        "min_length",
        "min_new_tokens",
        "early_stopping",
        "max_time",
        "do_sample",
        "num_beams",
        "num_beam_groups",
        "penalty_alpha",
        "use_cache",
        "temperature",
        "top_k",
        "top_p",
        "typical_p",
        "epsilon_cutoff",
        "eta_cutoff",
        "diversity_penalty",
        "repetition_penalty",
        "encoder_repetition_penalty",
        "length_penalty",
        "no_repeat_ngram_size",
        "bad_words_ids",
        "force_words_ids",
        "renormalize_logits",
        "constraints",
        "forced_bos_token_id",
        "forced_eos_token_id",
        "remove_invalid_values",
        "exponential_decay_length_penalty",
        "suppress_tokens",
        "begin_suppress_tokens",
        "forced_decoder_ids",
        "sequence_bias",
        "guidance_scale",
        "low_memory",
        "num_return_sequences",
        "output_attentions",
        "output_hidden_states",
        "output_scores",
        "return_dict_in_generate",
        "pad_token_id",
        "bos_token_id",
        "eos_token_id",
        "encoder_no_repeat_ngram_size",
        "decoder_start_token_id",
        "num_assistant_tokens",
        "num_assistant_tokens_schedule",
    }

    # Case-insensitive filter: keep only supported settings, normalized to lowercase keys.
    return {key.lower(): value for key, value in model_settings.items() if key.lower() in supported_keys}
| 92 | + |
| 93 | + |
def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
    """
    Build a single non-streaming Output from one summarization pipeline result.

    Args:
        result (Dict[str, str]): One pipeline response entry; must contain "summary_text".
        execution_count (int): Index of this output in the response list.

    Returns:
        Output: An ExecuteResult wrapping the summary text.
    """
    return ExecuteResult(
        output_type="execute_result",
        data=result["summary_text"],
        execution_count=execution_count,
        metadata={},
    )
| 107 | + |
| 108 | + |
def construct_stream_output(
    streamer: TextIteratorStreamer,
    options: InferenceOptions,
) -> Output:
    """
    Build the Output for a streaming response, forwarding each chunk of new
    text to the caller's stream callback as it arrives.

    Args:
        streamer (TextIteratorStreamer): Iterator yielding decoded text chunks. See:
            https://huggingface.co/docs/transformers/internal/generation_utils#transformers.TextIteratorStreamer
        options (InferenceOptions): Carries the `stream_callback` invoked per chunk.

    Returns:
        Output: An ExecuteResult whose data is the full accumulated text.
    """
    result = ExecuteResult(
        output_type="execute_result",
        data="",  # filled in once the stream is fully drained
        execution_count=0,  # streaming supports a single output only
        metadata={},
    )

    text_so_far = ""
    for chunk in streamer:
        # Guard: the streamer can yield non-string sentinel values; skip them.
        if isinstance(chunk, str):
            text_so_far += chunk
            options.stream_callback(chunk, text_so_far, 0)

    result.data = text_so_far
    return result
| 139 | + |
| 140 | + |
class HuggingFaceTextSummarizationTransformer(ParameterizedModelParser):
    """
    A model parser for HuggingFace text-summarization models, run locally
    through the `transformers` pipeline API.
    """

    def __init__(self):
        """
        Returns:
            HuggingFaceTextSummarizationTransformer

        Usage:
            1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceTextSummarizationTransformer()
            2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        # Cache of instantiated pipelines keyed by model name, so each model
        # is loaded at most once per parser instance.
        self.summarizers: Dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser.
        """
        return "HuggingFaceTextSummarizationTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.

        Args:
            prompt_name (str): Name for the serialized prompt.
            data (Any): Completion params for HF text summarization; must contain "prompt".
            ai_config (AIConfigRuntime): Config used to resolve model metadata.
            parameters (Optional[Dict[str, Any]]): Prompt parameters to store in metadata.

        Returns:
            List[Prompt]: Serialized representation of the prompt and inference settings.
        """
        # Deep-copy so we can pop keys without mutating the caller's data.
        data = copy.deepcopy(data)

        # The prompt text is stored as the Prompt input, not in the settings.
        prompt_input = data["prompt"]
        data.pop("prompt", None)

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
        )
        return [prompt]

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        _options,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Defines how to parse a prompt in the .aiconfig for a particular model
        and constructs the completion params for that model.

        Args:
            prompt (Prompt): The prompt to deserialize.
            aiconfig (AIConfigRuntime): Config the prompt belongs to.
            _options: Unused inference options (kept for interface parity).
            params (Optional[Dict[str, Any]]): Parameters for prompt resolution.

        Returns:
            dict: Model-specific completion parameters, including the resolved "prompt".
        """
        # Avoid a mutable default argument; treat None as "no params".
        if params is None:
            params = {}

        # Build completion data from the prompt's model settings.
        model_settings = self.get_model_settings(prompt, aiconfig)
        completion_data = refine_chat_completion_params(model_settings)

        # Add resolved prompt text.
        completion_data["prompt"] = resolve_prompt(prompt, params, aiconfig)
        return completion_data

    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
        """
        Invoked to run a prompt in the .aiconfig. Performs the actual model
        inference based on the provided prompt and inference settings.

        Args:
            prompt (Prompt): The prompt to run.
            aiconfig (AIConfigRuntime): Config the prompt belongs to.
            options (InferenceOptions): Runtime options; controls streaming.
            parameters (Dict[str, Any]): Parameters for prompt resolution.

        Returns:
            List[Output]: The outputs produced by the model (also set on the prompt).
        """
        completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
        inputs = completion_data.pop("prompt", None)

        # Lazily construct and cache one pipeline per model name.
        model_name: str = aiconfig.get_model_name(prompt)
        if isinstance(model_name, str) and model_name not in self.summarizers:
            self.summarizers[model_name] = pipeline("summarization", model=model_name)
        summarizer = self.summarizers[model_name]

        # Stream only if enabled in runtime options AND not explicitly disabled
        # in the completion data.
        streamer = None
        should_stream = (options.stream if options else False) and ("stream" not in completion_data or completion_data.get("stream") is not False)
        if should_stream:
            # `from_pretrained` returns a concrete tokenizer instance for this model.
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            streamer = TextIteratorStreamer(tokenizer)
            completion_data["streamer"] = streamer

        outputs: List[Output] = []
        output = None

        def _summarize():
            # Single call site for the pipeline so both branches run identically.
            return summarizer(inputs, **completion_data)

        if not should_stream:
            response: List[Any] = _summarize()
            for count, result in enumerate(response):
                output = construct_regular_output(result, count)
                outputs.append(output)
        else:
            if completion_data.get("num_return_sequences", 1) > 1:
                raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
            if not streamer:
                raise ValueError("Stream option is selected but streamer is not initialized")

            # For streaming, cannot call `summarizer` directly otherwise response will be blocking.
            thread = threading.Thread(target=_summarize)
            thread.start()
            output = construct_stream_output(streamer, options)
            # The streamer is exhausted once generation ends; join so we do not
            # leave a dangling worker thread behind.
            thread.join()
            if output is not None:
                outputs.append(output)

        prompt.outputs = outputs
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        """
        Extract the text from an output of this prompt, defaulting to the
        latest output when none is given. Returns "" if no text is available.
        """
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        # TODO (rossdanlm): Handle multiple outputs in list
        # https://github.com/lastmile-ai/aiconfig/issues/467
        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data
            if isinstance(output_data, OutputDataWithValue):
                if isinstance(output_data.value, str):
                    return output_data.value
                # HuggingFace text summarization does not support function
                # calls so shouldn't get here, but just being safe
                return json.dumps(output_data.value, indent=2)
        return ""
0 commit comments