Skip to content

Commit 33ad323

Browse files
Langchain::Assistant when using OpenAI accepts a message with image_url (#799)
* Langchain::Assistant when using OpenAI accept a message with image_url * CHANGELOG entry + fixing linter
1 parent bdafbc1 commit 33ad323

File tree

7 files changed

+112
-35
lines changed

7 files changed

+112
-35
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## [Unreleased]
2+
- Assistant can now process image_urls in the messages (currently only for OpenAI)
23

34
## [0.16.1] - 2024-09-30
45
- Deprecate Langchain::LLM::GooglePalm

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,12 @@ assistant = Langchain::Assistant.new(
501501
# Add a user message and run the assistant
502502
assistant.add_message_and_run!(content: "What's the latest news about AI?")
503503

504+
# Supply an image to the assistant
505+
assistant.add_message_and_run!(
506+
content: "Show me a picture of a cat",
507+
image: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
508+
)
509+
504510
# Access the conversation thread
505511
messages = assistant.messages
506512

lib/langchain/assistants/assistant.rb

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,14 @@ def initialize(
6363

6464
# Add a user message to the messages array
6565
#
66-
# @param content [String] The content of the message
6766
# @param role [String] The role attribute of the message. Default: "user"
67+
# @param content [String] The content of the message
68+
# @param image_url [String] The URL of the image to include in the message
6869
# @param tool_calls [Array<Hash>] The tool calls to include in the message
6970
# @param tool_call_id [String] The ID of the tool call to include in the message
7071
# @return [Array<Langchain::Message>] The messages
71-
def add_message(content: nil, role: "user", tool_calls: [], tool_call_id: nil)
72-
message = build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
72+
def add_message(role: "user", content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
73+
message = build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
7374

7475
# Call the callback with the message
7576
add_message_callback.call(message) if add_message_callback # rubocop:disable Style/SafeNavigation
@@ -145,17 +146,17 @@ def run!
145146
# @param content [String] The content of the message
146147
# @param auto_tool_execution [Boolean] Whether or not to automatically run tools
147148
# @return [Array<Langchain::Message>] The messages
148-
def add_message_and_run(content:, auto_tool_execution: false)
149-
add_message(content: content, role: "user")
149+
def add_message_and_run(content: nil, image_url: nil, auto_tool_execution: false)
150+
add_message(content: content, image_url: image_url, role: "user")
150151
run(auto_tool_execution: auto_tool_execution)
151152
end
152153

153154
# Add a user message and run the assistant with automatic tool execution
154155
#
155156
# @param content [String] The content of the message
156157
# @return [Array<Langchain::Message>] The messages
157-
def add_message_and_run!(content:)
158-
add_message_and_run(content: content, auto_tool_execution: true)
158+
def add_message_and_run!(content: nil, image_url: nil)
159+
add_message_and_run(content: content, image_url: image_url, auto_tool_execution: true)
159160
end
160161

161162
# Submit tool output
@@ -388,11 +389,12 @@ def run_tools(tool_calls)
388389
#
389390
# @param role [String] The role of the message
390391
# @param content [String] The content of the message
392+
# @param image_url [String] The URL of the image to include in the message
391393
# @param tool_calls [Array<Hash>] The tool calls to include in the message
392394
# @param tool_call_id [String] The ID of the tool call to include in the message
393395
# @return [Langchain::Message] The Message object
394-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
395-
@llm_adapter.build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
396+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
397+
@llm_adapter.build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
396398
end
397399

398400
# Increment the tokens count based on the last interaction with the LLM
@@ -443,7 +445,7 @@ def extract_tool_call_args(tool_call:)
443445
raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
444446
end
445447

446-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
448+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
447449
raise NotImplementedError, "Subclasses must implement build_message"
448450
end
449451
end
@@ -457,7 +459,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
457459
params
458460
end
459461

460-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
462+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
463+
warn "Image URL is not supported by Ollama currently" if image_url
464+
461465
Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
462466
end
463467

@@ -506,8 +510,8 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
506510
params
507511
end
508512

509-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
510-
Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
513+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
514+
Langchain::Messages::OpenAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
511515
end
512516

513517
# Extract the tool call information from the OpenAI tool call hash
@@ -564,7 +568,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
564568
params
565569
end
566570

567-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
571+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
572+
warn "Image URL is not supported by MistralAI currently" if image_url
573+
568574
Langchain::Messages::MistralAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
569575
end
570576

@@ -623,7 +629,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
623629
params
624630
end
625631

626-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
632+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
633+
warn "Image URL is not supported by Google Gemini" if image_url
634+
627635
Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
628636
end
629637

@@ -676,7 +684,9 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
676684
params
677685
end
678686

679-
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
687+
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
688+
warn "Image URL is not supported by Anthropic currently" if image_url
689+
680690
Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
681691
end
682692

lib/langchain/assistants/messages/base.rb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
module Langchain
44
module Messages
55
class Base
6-
attr_reader :role, :content, :tool_calls, :tool_call_id
6+
attr_reader :role,
7+
:content,
8+
:image_url,
9+
:tool_calls,
10+
:tool_call_id
711

812
# Check if the message came from a user
913
#

lib/langchain/assistants/messages/openai_message.rb

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,25 @@ class OpenAIMessage < Base
1515

1616
# Initialize a new OpenAI message
1717
#
18-
# @param [String] The role of the message
19-
# @param [String] The content of the message
20-
# @param [Array<Hash>] The tool calls made in the message
21-
# @param [String] The ID of the tool call
22-
def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
18+
# @param role [String] The role of the message
19+
# @param content [String] The content of the message
20+
# @param image_url [String] The URL of the image
21+
# @param tool_calls [Array<Hash>] The tool calls made in the message
22+
# @param tool_call_id [String] The ID of the tool call
23+
def initialize(
24+
role:,
25+
content: nil,
26+
image_url: nil,
27+
tool_calls: [],
28+
tool_call_id: nil
29+
)
2330
raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
2431
raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
2532

2633
@role = role
2734
# Some Tools return content as a JSON hence `.to_s`
2835
@content = content.to_s
36+
@image_url = image_url
2937
@tool_calls = tool_calls
3038
@tool_call_id = tool_call_id
3139
end
@@ -43,9 +51,30 @@ def llm?
4351
def to_hash
4452
{}.tap do |h|
4553
h[:role] = role
46-
h[:content] = content if content # Content is nil for tool calls
47-
h[:tool_calls] = tool_calls if tool_calls.any?
48-
h[:tool_call_id] = tool_call_id if tool_call_id
54+
55+
if tool_calls.any?
56+
h[:tool_calls] = tool_calls
57+
else
58+
h[:tool_call_id] = tool_call_id if tool_call_id
59+
60+
h[:content] = []
61+
62+
if content && !content.empty?
63+
h[:content] << {
64+
type: "text",
65+
text: content
66+
}
67+
end
68+
69+
if image_url
70+
h[:content] << {
71+
type: "image_url",
72+
image_url: {
73+
url: image_url
74+
}
75+
}
76+
end
77+
end
4978
end
5079
end
5180

spec/langchain/assistants/assistant_spec.rb

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@
8787
expect(subject.messages.first.content).to eq("hello")
8888
end
8989

90+
it "adds a message with image_url" do
91+
message_with_image = {role: "user", content: "hello", image_url: "https://example.com/image.jpg"}
92+
subject = described_class.new(llm: llm, messages: [])
93+
94+
expect {
95+
subject.add_message(**message_with_image)
96+
}.to change { subject.messages.count }.from(0).to(1)
97+
expect(subject.messages.first).to be_a(Langchain::Messages::OpenAIMessage)
98+
expect(subject.messages.first.role).to eq("user")
99+
expect(subject.messages.first.content).to eq("hello")
100+
expect(subject.messages.first.image_url).to eq("https://example.com/image.jpg")
101+
end
102+
90103
it "calls the add_message_callback with the message" do
91104
callback = double("callback", call: true)
92105
subject = described_class.new(llm: llm, messages: [], add_message_callback: callback)
@@ -211,8 +224,8 @@
211224
allow(subject.llm).to receive(:chat)
212225
.with(
213226
messages: [
214-
{role: "system", content: instructions},
215-
{role: "user", content: "Please calculate 2+2"}
227+
{role: "system", content: [{type: "text", text: instructions}]},
228+
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]}
216229
],
217230
tools: calculator.class.function_schemas.to_openai_format,
218231
tool_choice: "auto"
@@ -255,16 +268,16 @@
255268
allow(subject.llm).to receive(:chat)
256269
.with(
257270
messages: [
258-
{role: "system", content: instructions},
259-
{role: "user", content: "Please calculate 2+2"},
260-
{role: "assistant", content: "", tool_calls: [
271+
{role: "system", content: [{type: "text", text: instructions}]},
272+
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]},
273+
{role: "assistant", tool_calls: [
261274
{
262275
"function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "langchain_tool_calculator__execute"},
263276
"id" => "call_9TewGANaaIjzY31UCpAAGLeV",
264277
"type" => "function"
265278
}
266279
]},
267-
{content: "4.0", role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
280+
{content: [{type: "text", text: "4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
268281
],
269282
tools: calculator.class.function_schemas.to_openai_format,
270283
tool_choice: "auto"

spec/langchain/assistants/messages/openai_message_spec.rb

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) }
1111

1212
it "returns a hash with the role and content key" do
13-
expect(message.to_hash).to eq({role: "user", content: "Hello, world!"})
13+
expect(message.to_hash).to eq({role: "user", content: [{type: "text", text: "Hello, world!"}]})
1414
end
1515
end
1616

1717
context "when tool_call_id is not nil" do
1818
let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") }
1919

2020
it "returns a hash with the tool_call_id key" do
21-
expect(message.to_hash).to eq({role: "tool", content: "Hello, world!", tool_call_id: "123"})
21+
expect(message.to_hash).to eq({role: "tool", content: [{type: "text", text: "Hello, world!"}], tool_call_id: "123"})
2222
end
2323
end
2424

@@ -29,10 +29,24 @@
2929
"function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}}
3030
}
3131

32-
let(:message) { described_class.new(role: "assistant", content: "", tool_calls: [tool_call], tool_call_id: nil) }
32+
let(:message) { described_class.new(role: "assistant", tool_calls: [tool_call], tool_call_id: nil) }
3333

3434
it "returns a hash with the tool_calls key" do
35-
expect(message.to_hash).to eq({role: "assistant", content: "", tool_calls: [tool_call]})
35+
expect(message.to_hash).to eq({role: "assistant", tool_calls: [tool_call]})
36+
end
37+
end
38+
39+
context "when image_url is present" do
40+
let(:message) { described_class.new(role: "user", content: "Please describe this image", image_url: "https://example.com/image.jpg") }
41+
42+
it "returns a hash with the image_url key" do
43+
expect(message.to_hash).to eq({
44+
role: "user",
45+
content: [
46+
{type: "text", text: "Please describe this image"},
47+
{type: "image_url", image_url: {url: "https://example.com/image.jpg"}}
48+
]
49+
})
3650
end
3751
end
3852
end

0 commit comments

Comments
 (0)