Skip to content

Commit 8d19231

Browse files
authored
[serve] allow array content inputs for LLMs (#39829)
fix bug; add tests
1 parent 34a1fc6 commit 8d19231

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

src/transformers/commands/serving.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -829,13 +829,22 @@ def get_processor_inputs_from_inbound_messages(messages, modality: Modality):
829829
parsed_message = {"role": message["role"], "content": []}
830830

831831
if modality == Modality.LLM:
832-
# If we're working with LLMs, then "content" is a single string.
833-
content = message["content"] if isinstance(message["content"], str) else message["content"]["text"]
834-
parsed_message["content"] = content
832+
# Input: `content` is a string or a list of dictionaries with a "type" key; entries of type "text" also carry a "text" key.
833+
# Output: `content` is a string.
834+
if isinstance(message["content"], str):
835+
parsed_content = message["content"]
836+
elif isinstance(message["content"], list):
837+
parsed_content = []
838+
for content in message["content"]:
839+
if content["type"] == "text":
840+
parsed_content.append(content["text"])
841+
parsed_content = " ".join(parsed_content)
842+
parsed_message["content"] = parsed_content
835843

836844
elif modality == Modality.VLM:
837-
# If we're working with VLMs, then "content" is a dictionary, containing a "type" key indicating
838-
# which other key will be present and the type of the value of said key.
845+
# Input: `content` is a string or a list of dictionaries with a "type" key (possible types: "text",
846+
# "image_url").
847+
# Output: `content` is a list of dictionaries with a "type" key.
839848
if isinstance(message["content"], str):
840849
parsed_message["content"].append({"type": "text", "text": message["content"]})
841850
else:

tests/commands/test_serving.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,37 @@ def test_processor_inputs_from_inbound_messages_llm(self):
282282
outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
283283
self.assertListEqual(expected_outputs, outputs)
284284

285+
messages_with_type = [
286+
{"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
287+
{
288+
"role": "assistant",
289+
"content": [
290+
{"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
291+
],
292+
},
293+
{"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
294+
]
295+
outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages_with_type, modality)
296+
self.assertListEqual(expected_outputs, outputs)
297+
298+
messages_multiple_text = [
299+
{
300+
"role": "user",
301+
"content": [
302+
{"type": "text", "text": "How are you doing?"},
303+
{"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"},
304+
],
305+
},
306+
]
307+
expected_outputs_multiple_text = [
308+
{
309+
"role": "user",
310+
"content": "How are you doing? I'm doing great, thank you for asking! How can I assist you today?",
311+
},
312+
]
313+
outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages_multiple_text, modality)
314+
self.assertListEqual(expected_outputs_multiple_text, outputs)
315+
285316
def test_processor_inputs_from_inbound_messages_vlm_text_only(self):
286317
modality = Modality.VLM
287318
messages = [

0 commit comments

Comments
 (0)