Skip to content

Commit 13a4c6e

Browse files
author
Rossdan Craig rossdan@lastmileai.dev
committed
[HF][streaming][2/n] Text Translation
TSIA. Adding streaming output support for the text translation model parser. I also fixed a bug where we didn't pass the `"translation"` key into the pipeline. ## Test Plan Rebase onto 5b74344 and test it. Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands: ```bash aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'" aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path ``` With Streaming: https://github.com/lastmile-ai/aiconfig/assets/151060367/d7bc9df2-2993-4709-bf9b-c5b7979fb00f Without Streaming: https://github.com/lastmile-ai/aiconfig/assets/151060367/71eb6ab3-5d6f-4c5d-8b82-f3daf4c5e610
1 parent 074b768 commit 13a4c6e

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

extensions/HuggingFace/python/requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ huggingface_hub
1010

1111
#Hugging Face Libraries - Local Inference Transformers & Diffusers
1212
accelerate # Used to help speed up image generation
13-
diffusers # Used for image + audio generation
13+
diffusers # Used for image generation
14+
scipy # array -> wav file, text-speech. torchaudio.save seems broken.
15+
sentencepiece # Used for text translation
1416
torch
1517
torchvision
1618
torchaudio
17-
scipy # array -> wav file, text-speech. torchaudio.save seems broken.
1819
transformers # Used for text generation
1920

2021
#Other

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,19 @@ def construct_stream_output(
129129
"metadata": {},
130130
}
131131
)
132+
132133
accumulated_message = ""
133134
for new_text in streamer:
134135
if isinstance(new_text, str):
136+
# For some reason these symbols aren't filtered out by the streamer
137+
new_text = new_text.replace("</s>", "")
138+
new_text = new_text.replace("<s>", "")
139+
new_text = new_text.replace("<pad>", "")
140+
135141
accumulated_message += new_text
136142
options.stream_callback(new_text, accumulated_message, 0)
137143
output.data = accumulated_message
144+
138145
return output
139146

140147

@@ -240,19 +247,26 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
240247

241248
model_name: str = aiconfig.get_model_name(prompt)
242249
if isinstance(model_name, str) and model_name not in self.translators:
243-
self.translators[model_name] = pipeline(model_name)
250+
self.translators[model_name] = pipeline("translation", model_name)
244251
translator = self.translators[model_name]
245252

246253
# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
247254
streamer = None
248-
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
255+
should_stream = (options.stream if options else False) and (
256+
not "stream" in completion_data or completion_data.get("stream") != False
257+
)
249258
if should_stream:
250-
raise NotImplementedError("Streaming is not supported for HuggingFace Text Translation")
259+
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
260+
streamer = TextIteratorStreamer(tokenizer)
261+
completion_data["streamer"] = streamer
262+
263+
def _translate():
264+
return translator(inputs, **completion_data)
251265

252266
outputs: List[Output] = []
253267
output = None
254268
if not should_stream:
255-
response: List[Any] = translator(inputs, **completion_data)
269+
response: List[Any] = _translate()
256270
for count, result in enumerate(response):
257271
output = construct_regular_output(result, count)
258272
outputs.append(output)
@@ -263,7 +277,7 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
263277
raise ValueError("Stream option is selected but streamer is not initialized")
264278

265279
# For streaming, cannot call `translator` directly otherwise response will be blocking
266-
thread = threading.Thread(target=translator, kwargs=completion_data)
280+
thread = threading.Thread(target=_translate)
267281
thread.start()
268282
output = construct_stream_output(streamer, options)
269283
if output is not None:

0 commit comments

Comments (0)