
Commit 00b7acd

[AIC-py] hf machine translation (text to text) model parser (#753)
[AIC-py] HF machine translation model parser: text translation, sans streaming. Streaming is not trivial here.

Test:

```
import asyncio
from aiconfig_extension_hugging_face.local_inference.text_translation import HuggingFaceTextTranslationTransformer
from aiconfig import AIConfigRuntime, InferenceOptions, CallbackManager

# Load the aiconfig.
mp = HuggingFaceTextTranslationTransformer()
AIConfigRuntime.register_model_parser(mp, "translation_en_to_fr")
config = AIConfigRuntime.load("/Users/jonathan/Projects/aiconfig/test_hf_transl.aiconfig.json")
config.callback_manager = CallbackManager([])


def print_stream(data, _accumulated_data, _index):
    print(data, end="", flush=True)


async def run():
    # print("Stream")
    # options = InferenceOptions(stream=True, stream_callback=print_stream)
    # out = await config.run("test_hf_trans", options=options)
    # print("Output:\n", out)

    print("no stream")
    options = InferenceOptions(stream=False)
    out = await config.run("test_hf_trans", options=options)
    print("Output:\n", out)


asyncio.run(run())
```

---

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/753).

* __->__ #753
* #740
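The aiconfig file the test script loads is not part of this diff. A minimal `test_hf_transl.aiconfig.json` consistent with the script might look like the following sketch; the prompt text and exact metadata layout are assumptions:

```
{
  "name": "test_hf_transl",
  "schema_version": "latest",
  "metadata": { "parameters": {}, "models": {} },
  "prompts": [
    {
      "name": "test_hf_trans",
      "input": "Hello! How are you doing today?",
      "metadata": { "model": "translation_en_to_fr" }
    }
  ]
}
```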
2 parents 1625d10 + 9fd71be commit 00b7acd


3 files changed: +614 -1 lines changed

`aiconfig_extension_hugging_face/__init__.py`: 10 additions & 1 deletion

```
@@ -1,7 +1,16 @@
 from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
 from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
 from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
+from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
+from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
 
-LOCAL_INFERENCE_CLASSES = ["HuggingFaceText2ImageDiffusor", "HuggingFaceTextGenerationTransformer"]
+# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
+
+LOCAL_INFERENCE_CLASSES = [
+    "HuggingFaceText2ImageDiffusor",
+    "HuggingFaceTextGenerationTransformer",
+    "HuggingFaceTextSummarizationTransformer",
+    "HuggingFaceTextTranslationTransformer",
+]
 REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"]
 __ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
```
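For context, a sketch of how these exports get wired up, mirroring the registration call in the commit message. Only `"translation_en_to_fr"` appears in the test script; the summarization checkpoint name below is an assumed example:

```
# Sketch: register the newly exported parsers with the aiconfig runtime.
from aiconfig import AIConfigRuntime
from aiconfig_extension_hugging_face import (
    HuggingFaceTextSummarizationTransformer,
    HuggingFaceTextTranslationTransformer,
)

# run_inference resolves the prompt's model name to a HF checkpoint, so the
# summarization parser is registered under a checkpoint name (example only).
AIConfigRuntime.register_model_parser(HuggingFaceTextSummarizationTransformer(), "facebook/bart-large-cnn")
AIConfigRuntime.register_model_parser(HuggingFaceTextTranslationTransformer(), "translation_en_to_fr")
```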
`aiconfig_extension_hugging_face/local_inference/text_summarization.py`: 305 additions & 0 deletions (new file, shown below)
```
import copy
import json
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import (
    AutoTokenizer,
    Pipeline,
    pipeline,
    TextIteratorStreamer,
)

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
    ExecuteResult,
    Output,
    OutputDataWithValue,
    Prompt,
    PromptMetadata,
)
from aiconfig.util.params import resolve_prompt

# Circular Dependency Type Hints
if TYPE_CHECKING:
    from aiconfig.Config import AIConfigRuntime


# Step 1: define Helpers
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
    """
    Refines the completion params for the HF text summarization api. Removes any unsupported params.
    The supported keys were found by looking at the HF text summarization api. `huggingface_hub.InferenceClient.text_summarization()`
    """

    supported_keys = {
        "max_length",
        "max_new_tokens",
        "min_length",
        "min_new_tokens",
        "early_stopping",
        "max_time",
        "do_sample",
        "num_beams",
        "num_beam_groups",
        "penalty_alpha",
        "use_cache",
        "temperature",
        "top_k",
        "top_p",
        "typical_p",
        "epsilon_cutoff",
        "eta_cutoff",
        "diversity_penalty",
        "repetition_penalty",
        "encoder_repetition_penalty",
        "length_penalty",
        "no_repeat_ngram_size",
        "bad_words_ids",
        "force_words_ids",
        "renormalize_logits",
        "constraints",
        "forced_bos_token_id",
        "forced_eos_token_id",
        "remove_invalid_values",
        "exponential_decay_length_penalty",
        "suppress_tokens",
        "begin_suppress_tokens",
        "forced_decoder_ids",
        "sequence_bias",
        "guidance_scale",
        "low_memory",
        "num_return_sequences",
        "output_attentions",
        "output_hidden_states",
        "output_scores",
        "return_dict_in_generate",
        "pad_token_id",
        "bos_token_id",
        "eos_token_id",
        "encoder_no_repeat_ngram_size",
        "decoder_start_token_id",
        "num_assistant_tokens",
        "num_assistant_tokens_schedule",
    }

    completion_data = {}
    for key in model_settings:
        if key.lower() in supported_keys:
            completion_data[key.lower()] = model_settings[key]

    return completion_data
```
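A quick sketch of how the filter behaves; the settings dict below is hypothetical:

```
# Hypothetical model settings as they might appear in an aiconfig prompt.
model_settings = {
    "max_new_tokens": 60,
    "num_beams": 4,
    "stream": True,  # not a generate kwarg, so the filter drops it
}
completion_data = refine_chat_completion_params(model_settings)
assert completion_data == {"max_new_tokens": 60, "num_beams": 4}
```

The new file continues with the output constructors: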
```
def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
    """
    Construct regular output per response result, without streaming enabled
    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": result["summary_text"],
            "execution_count": execution_count,
            "metadata": {},
        }
    )
    return output


def construct_stream_output(
    streamer: TextIteratorStreamer,
    options: InferenceOptions,
) -> Output:
    """
    Constructs the output for a stream response.

    Args:
        streamer (TextIteratorStreamer): Streams the output. See:
            https://huggingface.co/docs/transformers/v4.35.2/en/internal/summarization_utils#transformers.TextIteratorStreamer
        options (InferenceOptions): The inference options. Used to determine
            the stream callback.

    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": "",  # We update this below
            "execution_count": 0,  # Multiple outputs are not supported for streaming
            "metadata": {},
        }
    )
    accumulated_message = ""
    for new_text in streamer:
        if isinstance(new_text, str):
            accumulated_message += new_text
            options.stream_callback(new_text, accumulated_message, 0)

    output.data = accumulated_message
    return output
```
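`construct_stream_output` drives a callback with the `(data, accumulated_data, index)` signature. A minimal callback, taken from the test script in the commit message:

```
# Minimal stream callback: print each new chunk as it arrives.
def print_stream(data, _accumulated_data, _index):
    print(data, end="", flush=True)

options = InferenceOptions(stream=True, stream_callback=print_stream)
```

The model parser class itself follows: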
```
class HuggingFaceTextSummarizationTransformer(ParameterizedModelParser):
    """
    A model parser for HuggingFace models of type text summarization task using transformers.
    """

    def __init__(self):
        """
        Returns:
            HuggingFaceTextSummarizationTransformer

        Usage:
        1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceTextSummarizationTransformer()
        2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        self.summarizers: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceTextSummarizationTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.

        Args:
            prompt_name (str): The prompt to be serialized.
            inference_settings (dict): Model-specific inference settings to be serialized.

        Returns:
            List[Prompt]: Serialized representation of the prompt and inference settings.
        """
        data = copy.deepcopy(data)

        # assume data is completion params for HF text summarization
        prompt_input = data["prompt"]

        # Prompt is handled, remove from data
        data.pop("prompt", None)

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
        )
        return [prompt]

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        _options,
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Defines how to parse a prompt in the .aiconfig for a particular model
        and constructs the completion params for that model.

        Args:
            serialized_data (str): Serialized data from the .aiconfig.

        Returns:
            dict: Model-specific completion parameters.
        """
        # Build Completion data
        model_settings = self.get_model_settings(prompt, aiconfig)
        completion_data = refine_chat_completion_params(model_settings)

        # Add resolved prompt
        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
        completion_data["prompt"] = resolved_prompt
        return completion_data
```
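A rough `serialize` call, assuming `data` carries the completion params plus a `prompt` key; all names below are illustrative:

```
# Illustrative only: store a prompt plus settings into the aiconfig.
prompts = await parser.serialize(
    prompt_name="summarize_article",
    data={"prompt": "Long article text ...", "max_new_tokens": 100},
    ai_config=config,
)
# deserialize later rebuilds {"max_new_tokens": 100, "prompt": <resolved text>}.
```

The class continues with `run_inference`: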
```
    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method should perform
        the actual model inference based on the provided prompt and inference settings.

        Args:
            prompt (str): The input prompt.
            inference_settings (dict): Model-specific inference settings.

        Returns:
            InferenceResponse: The response from the model.
        """
        completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
        inputs = completion_data.pop("prompt", None)

        model_name: str = aiconfig.get_model_name(prompt)
        if isinstance(model_name, str) and model_name not in self.summarizers:
            self.summarizers[model_name] = pipeline("summarization", model=model_name)
        summarizer = self.summarizers[model_name]

        # if stream enabled in runtime options and config, then stream. Otherwise don't stream.
        streamer = None
        should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
        if should_stream:
            tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
            streamer = TextIteratorStreamer(tokenizer)
            completion_data["streamer"] = streamer

        outputs: List[Output] = []
        output = None

        def _summarize():
            return summarizer(inputs, **completion_data)

        if not should_stream:
            response: List[Any] = _summarize()
            for count, result in enumerate(response):
                output = construct_regular_output(result, count)
                outputs.append(output)
        else:
            if completion_data.get("num_return_sequences", 1) > 1:
                raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
            if not streamer:
                raise ValueError("Stream option is selected but streamer is not initialized")

            # For streaming, cannot call `summarizer` directly otherwise response will be blocking
            thread = threading.Thread(target=_summarize)
            thread.start()
            output = construct_stream_output(streamer, options)
            if output is not None:
                outputs.append(output)

        prompt.outputs = outputs
        return prompt.outputs
```
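The streaming branch relies on the standard transformers pattern: the pipeline call blocks until generation finishes, so it runs on a worker thread while `TextIteratorStreamer` yields tokens to the consumer. A stand-alone sketch of that pattern; the checkpoint name is an example, not from this diff:

```
# Stand-alone sketch of the thread + streamer pattern used above.
from threading import Thread
from transformers import AutoTokenizer, TextIteratorStreamer, pipeline

model_name = "facebook/bart-large-cnn"  # example checkpoint
summarizer = pipeline("summarization", model=model_name)
streamer = TextIteratorStreamer(AutoTokenizer.from_pretrained(model_name))

# The pipeline call blocks until generation completes, so run it off-thread
# and consume tokens from the streamer as they are produced.
thread = Thread(target=lambda: summarizer("Long article text ...", streamer=streamer))
thread.start()
for token in streamer:
    print(token, end="", flush=True)
thread.join()
```

`get_output_text`, below, completes the class: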
```
    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        # TODO (rossdanlm): Handle multiple outputs in list
        # https://github.com/lastmile-ai/aiconfig/issues/467
        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data
            if isinstance(output_data, OutputDataWithValue):
                if isinstance(output_data.value, str):
                    return output_data.value
                # HuggingFace Text summarization does not support function
                # calls so shouldn't get here, but just being safe
                return json.dumps(output_data.value, indent=2)
        return ""
```
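Putting the pieces together, an end-to-end run analogous to the translation test in the commit message; the checkpoint, config file name, and prompt name here are all assumptions:

```
# Hedged end-to-end sketch for the summarization parser (names assumed).
import asyncio

from aiconfig import AIConfigRuntime, CallbackManager, InferenceOptions
from aiconfig_extension_hugging_face.local_inference.text_summarization import (
    HuggingFaceTextSummarizationTransformer,
)

AIConfigRuntime.register_model_parser(HuggingFaceTextSummarizationTransformer(), "facebook/bart-large-cnn")
config = AIConfigRuntime.load("test_hf_summ.aiconfig.json")  # assumed file
config.callback_manager = CallbackManager([])


async def run():
    options = InferenceOptions(stream=False)
    out = await config.run("test_hf_summ", options=options)  # assumed prompt name
    print("Output:\n", out)


asyncio.run(run())
```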
