Skip to content

Commit ffbbbf3

Browse files
[RFC][AIC-py] hf summarization model parser
Very similar to text gen Test: ``` ... { "name": "test_hf_sum", "input": "HMS Duncan was a D-class destroyer ...", # [contents of https://en.wikipedia.org/wiki/HMS_Duncan_(D99)] "metadata": { "model": { "name": "stevhliu/my_awesome_billsum_model", "settings": { "min_length": 100, "max_length": 200, "num_beams": 1 } } } }, ... } ``` ``` import asyncio from aiconfig_extension_hugging_face.local_inference.text_summarization import HuggingFaceTextSummarizationTransformer from aiconfig import AIConfigRuntime, InferenceOptions, CallbackManager # Load the aiconfig. mp = HuggingFaceTextSummarizationTransformer() AIConfigRuntime.register_model_parser(mp, "stevhliu/my_awesome_billsum_model") config = AIConfigRuntime.load("/Users/jonathan/Projects/aiconfig/cookbooks/Getting-Started/travel.aiconfig.json") config.callback_manager = CallbackManager([]) def print_stream(data, _accumulated_data, _index): print(data, end="", flush=True) async def run(): print("Stream") options = InferenceOptions(stream=True, stream_callback=print_stream) out = await config.run("test_hf_sum", options=options) print("Output:\n", out) print("no stream") options = InferenceOptions(stream=False) out = await config.run("test_hf_sum", options=options) print("Output:\n", out) asyncio.run(run()) # OUT: Stream Token indices sequence length is longer than the specified maximum sequence length for this model (2778 > 512). Running this sequence through the model will result in indexing errors /opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:430: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`. warnings.warn( /opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:449: UserWarning: `num_beams` is set to 1. 
However, `length_penalty` is set to `2.0` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `length_penalty`. warnings.warn( <pad> escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the </s>Output: [ExecuteResult(output_type='execute_result', execution_count=0, data="<pad> escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the </s>", mime_type=None, metadata={})] no stream Output: [ExecuteResult(output_type='execute_result', execution_count=0, data="escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war . The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the .", mime_type=None, metadata={})] ```
1 parent 6bbe147 commit ffbbbf3

File tree

2 files changed

+312
-1
lines changed

2 files changed

+312
-1
lines changed
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient

LOCAL_INFERENCE_CLASSES = [
    "HuggingFaceText2ImageDiffusor",
    "HuggingFaceTextGenerationTransformer",
    "HuggingFaceTextSummarizationTransformer",
]
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"]

# `__all__` (lowercase) is what `from package import *` actually honors;
# the previous `__ALL__` spelling had no effect. Kept as an alias for any
# code that referenced the old name.
__all__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
__ALL__ = __all__
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
import copy
2+
import json
3+
import threading
4+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
5+
from transformers import (
6+
AutoTokenizer,
7+
Pipeline,
8+
pipeline,
9+
TextIteratorStreamer,
10+
)
11+
12+
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
13+
from aiconfig.model_parser import InferenceOptions
14+
from aiconfig.schema import (
15+
ExecuteResult,
16+
Output,
17+
OutputDataWithValue,
18+
Prompt,
19+
PromptMetadata,
20+
)
21+
from aiconfig.util.params import resolve_prompt
22+
23+
# Circular Dependency Type Hints
24+
if TYPE_CHECKING:
25+
from aiconfig.Config import AIConfigRuntime
26+
27+
28+
# Step 1: define Helpers
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filter arbitrary model settings down to the generation parameters this
    parser forwards to the HF summarization pipeline.

    Keys are matched case-insensitively against the supported set and are
    lower-cased in the returned dict; unsupported keys are silently dropped.

    Args:
        model_settings (Dict[str, Any]): Raw settings from the prompt/model metadata.

    Returns:
        Dict[str, Any]: Only the supported completion parameters, keys lower-cased.
    """

    supported_keys = {
        "max_length",
        "max_new_tokens",
        "min_length",
        "min_new_tokens",
        "early_stopping",
        "max_time",
        "do_sample",
        "num_beams",
        "num_beam_groups",
        "penalty_alpha",
        "use_cache",
        "temperature",
        "top_k",
        "top_p",
        "typical_p",
        "epsilon_cutoff",
        "eta_cutoff",
        "diversity_penalty",
        "repetition_penalty",
        "encoder_repetition_penalty",
        "length_penalty",
        "no_repeat_ngram_size",
        "bad_words_ids",
        "force_words_ids",
        "renormalize_logits",
        "constraints",
        "forced_bos_token_id",
        "forced_eos_token_id",
        "remove_invalid_values",
        "exponential_decay_length_penalty",
        "suppress_tokens",
        "begin_suppress_tokens",
        "forced_decoder_ids",
        "sequence_bias",
        "guidance_scale",
        "low_memory",
        "num_return_sequences",
        "output_attentions",
        "output_hidden_states",
        "output_scores",
        "return_dict_in_generate",
        "pad_token_id",
        "bos_token_id",
        "eos_token_id",
        "encoder_no_repeat_ngram_size",
        "decoder_start_token_id",
        "num_assistant_tokens",
        "num_assistant_tokens_schedule",
    }

    return {
        key.lower(): value
        for key, value in model_settings.items()
        if key.lower() in supported_keys
    }
92+
93+
94+
def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
    """
    Build a single non-streaming output from one pipeline response entry.

    Args:
        result (Dict[str, str]): One entry of the summarization pipeline
            response; its "summary_text" field becomes the output data.
        execution_count (int): Index of this result within the response list.

    Returns:
        Output: An ExecuteResult wrapping the summary text.
    """
    return ExecuteResult(
        output_type="execute_result",
        data=result["summary_text"],
        execution_count=execution_count,
        metadata={},
    )
107+
108+
109+
def construct_stream_output(
    streamer: TextIteratorStreamer,
    options: InferenceOptions,
) -> Output:
    """
    Constructs the output for a stream response, draining the streamer and
    forwarding every new chunk to the user's stream callback.

    Args:
        streamer (TextIteratorStreamer): Streams the output. See:
            https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.TextIteratorStreamer
        options (InferenceOptions): The inference options. Used to determine
            the stream callback, which is invoked as
            (new_text, accumulated_text, index).

    Returns:
        Output: A single ExecuteResult whose data is the fully accumulated text.
    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": "",  # Filled in once streaming completes (below)
            "execution_count": 0,  # Multiple outputs are not supported for streaming
            "metadata": {},
        }
    )
    accumulated_message = ""
    for new_text in streamer:
        # TextIteratorStreamer yields decoded text chunks; skip anything else
        # defensively (the previous debug prints have been removed).
        if not isinstance(new_text, str):
            continue
        accumulated_message += new_text
        options.stream_callback(new_text, accumulated_message, 0)

    # Assign once after the stream is exhausted rather than on every chunk.
    output.data = accumulated_message
    return output
142+
143+
144+
class HuggingFaceTextSummarizationTransformer(ParameterizedModelParser):
    """
    A model parser for HuggingFace models of type text summarization task using transformers.
    """

    def __init__(self):
        """
        Returns:
            HuggingFaceTextSummarizationTransformer

        Usage:
        1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceTextSummarizationTransformer()
        2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        # Cache of model name -> loaded summarization Pipeline so each model
        # is only loaded once per parser instance (see run_inference).
        self.summarizers: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceTextSummarizationTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.

        Args:
            prompt_name (str): The prompt to be serialized.
            data (Any): Completion params for HF text summarization; must
                contain a "prompt" key, the rest is treated as model settings.
            ai_config (AIConfigRuntime): The config used to resolve model metadata.
            parameters (Optional[Dict[str, Any]]): Parameters saved on the prompt.

        Returns:
            List[Prompt]: Serialized representation of the prompt and inference settings.
        """
        # Deep copy so popping "prompt" below does not mutate the caller's dict.
        data = copy.deepcopy(data)

        # assume data is completion params for HF text summarization
        prompt_input = data["prompt"]

        # Prompt is handled, remove from data
        data.pop("prompt", None)

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
        )
        return [prompt]

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        _options,
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Defines how to parse a prompt in the .aiconfig for a particular model
        and constructs the completion params for that model.

        Args:
            prompt (Prompt): The prompt to deserialize.
            aiconfig (AIConfigRuntime): The config the prompt belongs to.
            _options: Unused here; accepted for interface compatibility.
            params (Optional[Dict[str, Any]]): Parameters to resolve into the prompt.

        Returns:
            dict: Model-specific completion parameters, including the resolved
            prompt text under the "prompt" key.
        """
        # Build Completion data
        model_settings = self.get_model_settings(prompt, aiconfig)
        completion_data = refine_chat_completion_params(model_settings)

        # Add resolved prompt
        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
        completion_data["prompt"] = resolved_prompt
        return completion_data

    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method should perform
        the actual model inference based on the provided prompt and inference settings.

        Args:
            prompt (Prompt): The prompt to run.
            aiconfig (AIConfigRuntime): The config the prompt belongs to.
            options (InferenceOptions): Runtime options; `options.stream`
                selects streaming vs. blocking inference.
            parameters (Dict[str, Any]): Parameters resolved into the prompt.

        Returns:
            List[Output]: The outputs produced by the model (also stored on the prompt).
        """
        completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
        inputs = completion_data.pop("prompt", None)

        # Lazily load and cache the pipeline for this model name.
        model_name: str = aiconfig.get_model_name(prompt)
        if isinstance(model_name, str) and model_name not in self.summarizers:
            self.summarizers[model_name] = pipeline("summarization", model=model_name)
        summarizer = self.summarizers[model_name]

        # if stream enabled in runtime options and config, then stream. Otherwise don't stream.
        streamer = None
        should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
        if should_stream:
            # The streamer needs the model's tokenizer to decode tokens as they arrive.
            tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
            streamer = TextIteratorStreamer(tokenizer)
            completion_data["streamer"] = streamer

        outputs: List[Output] = []
        output = None

        def _summarize():
            # Closure over summarizer/inputs so it can run on a worker thread below.
            return summarizer(inputs, **completion_data)

        if not should_stream:
            response: List[Any] = _summarize()
            for count, result in enumerate(response):
                output = construct_regular_output(result, count)
                outputs.append(output)
        else:
            if completion_data.get("num_return_sequences", 1) > 1:
                raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
            if not streamer:
                raise ValueError("Stream option is selected but streamer is not initialized")

            # For streaming, cannot call `summarizer` directly otherwise response will be blocking
            thread = threading.Thread(target=_summarize)
            thread.start()
            # Drains the streamer; returns once the generation thread finishes producing.
            output = construct_stream_output(streamer, options)
            if output is not None:
                outputs.append(output)

        prompt.outputs = outputs
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        """
        Extract the text of an output; falls back to the prompt's latest output
        when `output` is not given, and returns "" when no text is available.
        """
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        # TODO (rossdanlm): Handle multiple outputs in list
        # https://github.com/lastmile-ai/aiconfig/issues/467
        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data
            if isinstance(output_data, OutputDataWithValue):
                if isinstance(output_data.value, str):
                    return output_data.value
                # HuggingFace Text summarization does not support function
                # calls so shouldn't get here, but just being safe
                return json.dumps(output_data.value, indent=2)
        return ""

0 commit comments

Comments
 (0)