
Commit fd7ccfd
FEATURE: support tool progress callbacks
This is Anthropic-only for now, but we can get a callback as a tool call is completing, which gives us the ability to show progress to the user while the function arguments are being populated.
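
As a rough usage sketch (not part of the commit's diff): the caller passes a tool_progress: proc to llm.generate, and the proc receives a hash describing the tool argument being populated. The tool name, id, and values below are invented for illustration; only the hash keys (:name, :id, :key, :value, plus :done for partially streamed strings) come from this change.

# Sketch only — llm, prompt and the logging target are placeholders.
tool_progress =
  proc do |progress|
    # e.g. { name: "search", id: "toolu_01", key: "query", value: "weather in par", done: false }
    # then  { name: "search", id: "toolu_01", key: "query", value: "weather in paris" }
    Rails.logger.info("tool #{progress[:name]} #{progress[:key]}: #{progress[:value]}")
  end

llm.generate(
  prompt,
  user: Discourse.system_user,
  feature_name: "bot",
  tool_progress: tool_progress,
) { |partial, cancel| print partial }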
1 parent 0191b41 commit fd7ccfd

File tree

11 files changed: 815 additions, 13 deletions


lib/ai_bot/bot.rb

Lines changed: 8 additions & 1 deletion

@@ -105,8 +105,15 @@ def reply(context, &update_blk)
       tool_found = false
       force_tool_if_needed(prompt, context)
 
+      tool_progress = proc { |progress| p progress }
+
       result =
-        llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
+        llm.generate(
+          prompt,
+          feature_name: "bot",
+          tool_progress: tool_progress,
+          **llm_kwargs,
+        ) do |partial, cancel|
           tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
 
           if (tools.present?)
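
The proc above is a debug placeholder: every notification is simply printed with p. Based on the tracker introduced in lib/completions/anthropic_message_processor.rb below, the printed hashes would look roughly like this (tool name, id and values invented for illustration):

{:name=>"search", :id=>"toolu_01", :key=>"query", :value=>"weath", :done=>false}
{:name=>"search", :id=>"toolu_01", :key=>"query", :value=>"weather in paris"}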

lib/completions/anthropic_message_processor.rb

Lines changed: 63 additions & 5 deletions

@@ -1,25 +1,83 @@
 # frozen_string_literal: true
 
 class DiscourseAi::Completions::AnthropicMessageProcessor
+  class ToolCallProgressTracker
+    attr_reader :current_key, :current_value, :tool_call
+
+    def initialize(tool_call)
+      @tool_call = tool_call
+      @current_key = nil
+      @current_value = nil
+      @parser = DiscourseAi::Utils::JsonStreamingParser.new
+
+      @parser.key do |k|
+        @current_key = k
+        @current_value = nil
+      end
+      @parser.value do |v|
+        @current_value = v
+
+        if @current_key
+          tool_call.tool_progress.call(
+            { name: tool_call.name, id: tool_call.id, key: @current_key, value: @current_value },
+          )
+        end
+      end
+    end
+
+    def <<(json)
+      # llm could send broken json
+      # in that case just deal with it later
+      # don't stream
+      return if @broken
+
+      begin
+        @parser << json
+      rescue DiscourseAi::Utils::ParserError
+        @broken = true
+        return
+      end
+
+      if @parser.state == :start_string && @current_key
+        # this is is worth notifying
+        tool_call.tool_progress.call(
+          {
+            name: tool_call.name,
+            id: tool_call.id,
+            key: @current_key,
+            value: @parser.buf,
+            done: false,
+          },
+        )
+      end
+    end
+  end
+
   class AnthropicToolCall
-    attr_reader :name, :raw_json, :id
+    attr_reader :name, :raw_json, :id, :tool_progress
 
-    def initialize(name, id)
+    def initialize(name, id, tool_progress)
       @name = name
       @id = id
       @raw_json = +""
+      if tool_progress
+        @tool_progress = tool_progress
+        @tool_call_progress_tracker = ToolCallProgressTracker.new(self)
+      end
     end
 
     def append(json)
       @raw_json << json
+      @tool_call_progress_tracker << json if @tool_progress
     end
   end
 
   attr_reader :tool_calls, :input_tokens, :output_tokens
 
-  def initialize(streaming_mode:)
+  def initialize(streaming_mode:, tool_progress:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @tool_progress = tool_progress
   end
 
   def to_xml_tool_calls(function_buffer)

@@ -58,7 +116,7 @@ def process_message(payload)
     if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
       tool_name = parsed.dig(:content_block, :name)
       tool_id = parsed.dig(:content_block, :id)
-      @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
+      @tool_calls << AnthropicToolCall.new(tool_name, tool_id, @tool_progress) if tool_name
     elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
       if @tool_calls.present?
         result = parsed.dig(:delta, :partial_json).to_s

@@ -83,7 +141,7 @@ def process_message(payload)
     if content.is_a?(Array)
       tool_call = content.find { |c| c[:type] == "tool_use" }
       if tool_call
-        @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
+        @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id], @tool_progress)
         @tool_calls.last.append(tool_call[:input].to_json)
       else
         result = parsed.dig(:content, 0, :text).to_s
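
A minimal sketch of exercising the new tracker on its own, assuming the DiscourseAi::Utils::JsonStreamingParser this commit relies on (not shown in this diff) buffers partial strings as the code above implies; the stub tool call, the JSON fragments, and the exact intermediate notifications are illustrative only:

# Illustration only: a stub standing in for AnthropicToolCall.
ToolStub = Struct.new(:name, :id, :tool_progress)
stub = ToolStub.new("search", "toolu_01", proc { |progress| puts progress.inspect })

tracker =
  DiscourseAi::Completions::AnthropicMessageProcessor::ToolCallProgressTracker.new(stub)

# Feed partial_json fragments as they stream in from the Anthropic API.
tracker << "{\"query\": \"weather"
# roughly => {:name=>"search", :id=>"toolu_01", :key=>"query", :value=>"weather", :done=>false}
tracker << " in paris\"}"
# roughly => {:name=>"search", :id=>"toolu_01", :key=>"query", :value=>"weather in paris"}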

lib/completions/endpoints/anthropic.rb

Lines changed: 4 additions & 1 deletion

@@ -92,7 +92,10 @@ def prepare_request(payload)
 
   def processor
     @processor ||=
-      DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+      DiscourseAi::Completions::AnthropicMessageProcessor.new(
+        streaming_mode: @streaming_mode,
+        tool_progress: @tool_progress,
+      )
   end
 
   def add_to_function_buffer(function_buffer, partial: nil, payload: nil)

lib/completions/endpoints/aws_bedrock.rb

Lines changed: 4 additions & 1 deletion

@@ -157,7 +157,10 @@ def final_log_update(log)
 
   def processor
     @processor ||=
-      DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+      DiscourseAi::Completions::AnthropicMessageProcessor.new(
+        streaming_mode: @streaming_mode,
+        tool_progress: @tool_progress,
+      )
   end
 
   def add_to_function_buffer(function_buffer, partial: nil, payload: nil)

lib/completions/endpoints/base.rb

Lines changed: 2 additions & 0 deletions

@@ -62,12 +62,14 @@ def perform_completion!(
     model_params = {},
     feature_name: nil,
     feature_context: nil,
+    tool_progress: nil,
     &blk
   )
     allow_tools = dialect.prompt.has_tools?
     model_params = normalize_model_params(model_params)
     orig_blk = blk
 
+    @tool_progress = tool_progress
     @streaming_mode = block_given?
     to_strip = xml_tags_to_strip(dialect)
     @xml_stripper =

lib/completions/endpoints/canned_response.rb

Lines changed: 2 additions & 1 deletion

@@ -28,7 +28,8 @@ def perform_completion!(
     _user,
     _model_params,
     feature_name: nil,
-    feature_context: nil
+    feature_context: nil,
+    tool_progress: nil
   )
     @dialect = dialect
     response = responses[completions]

lib/completions/endpoints/fake.rb

Lines changed: 2 additions & 1 deletion

@@ -116,7 +116,8 @@ def perform_completion!(
     user,
     model_params = {},
     feature_name: nil,
-    feature_context: nil
+    feature_context: nil,
+    tool_progress: nil
   )
     self.class.last_call = { dialect: dialect, user: user, model_params: model_params }

lib/completions/endpoints/open_ai.rb

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ def perform_completion!(
     model_params = {},
     feature_name: nil,
     feature_context: nil,
+    tool_progress: nil,
     &blk
   )
     if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?

lib/completions/llm.rb

Lines changed: 2 additions & 0 deletions

@@ -192,6 +192,7 @@ def generate(
     user:,
     feature_name: nil,
     feature_context: nil,
+    tool_progress: nil,
     &partial_read_blk
   )
     self.class.record_prompt(prompt)

@@ -226,6 +227,7 @@ def generate(
       model_params,
       feature_name: feature_name,
       feature_context: feature_context,
+      tool_progress: tool_progress,
       &partial_read_blk
     )
   end

0 commit comments