Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 62d3e2b

Browse files
committed
Anthropic implementation of partial streaming
1 parent 9551b1a commit 62d3e2b

File tree

10 files changed

+781
-25
lines changed

10 files changed

+781
-25
lines changed

lib/completions/anthropic_message_processor.rb

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,92 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
44
class AnthropicToolCall
55
attr_reader :name, :raw_json, :id
66

7-
def initialize(name, id)
7+
def initialize(name, id, partial_tool_calls: false)
88
@name = name
99
@id = id
1010
@raw_json = +""
11+
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
12+
@streaming_parser = ToolCallProgressTracker.new(self) if partial_tool_calls
1113
end
1214

1315
def append(json)
1416
@raw_json << json
17+
@streaming_parser << json if @streaming_parser
18+
end
19+
20+
def notify_progress(key, value)
21+
@tool_call.partial = true
22+
@tool_call.parameters[key.to_sym] = value
23+
@has_new_data = true
24+
end
25+
26+
def has_partial?
27+
@has_new_data
28+
end
29+
30+
def partial_tool_call
31+
@has_new_data = false
32+
@tool_call
1533
end
1634

1735
def to_tool_call
1836
parameters = JSON.parse(raw_json, symbolize_names: true)
19-
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
37+
@tool_call.partial = false
38+
@tool_call.parameters = parameters
39+
@tool_call
40+
end
41+
end
42+
43+
class ToolCallProgressTracker
44+
attr_reader :current_key, :current_value, :tool_call
45+
46+
def initialize(tool_call)
47+
@tool_call = tool_call
48+
@current_key = nil
49+
@current_value = nil
50+
@parser = DiscourseAi::Completions::JsonStreamingParser.new
51+
52+
@parser.key do |k|
53+
@current_key = k
54+
@current_value = nil
55+
end
56+
57+
@parser.value do |v|
58+
tool_call.notify_progress(@current_key, v) if @current_key
59+
end
60+
end
61+
62+
def <<(json)
63+
# llm could send broken json
64+
# in that case just deal with it later
65+
# don't stream
66+
return if @broken
67+
68+
begin
69+
@parser << json
70+
rescue DiscourseAi::Utils::ParserError
71+
@broken = true
72+
return
73+
end
74+
75+
if @parser.state == :start_string && @current_key
76+
# this is is worth notifying
77+
tool_call.notify_progress(@current_key, @parser.buf)
78+
end
79+
80+
if @parser.state == :end_value
81+
@current_key = nil
82+
end
2083
end
2184
end
2285

2386
attr_reader :tool_calls, :input_tokens, :output_tokens
2487

25-
def initialize(streaming_mode:)
88+
def initialize(streaming_mode:, partial_tool_calls: false)
2689
@streaming_mode = streaming_mode
2790
@tool_calls = []
2891
@current_tool_call = nil
92+
@partial_tool_calls = partial_tool_calls
2993
end
3094

3195
def to_tool_calls
@@ -38,11 +102,19 @@ def process_streamed_message(parsed)
38102
tool_name = parsed.dig(:content_block, :name)
39103
tool_id = parsed.dig(:content_block, :id)
40104
result = @current_tool_call.to_tool_call if @current_tool_call
41-
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
105+
@current_tool_call =
106+
AnthropicToolCall.new(
107+
tool_name,
108+
tool_id,
109+
partial_tool_calls: @partial_tool_calls,
110+
) if tool_name
42111
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
43112
if @current_tool_call
44113
tool_delta = parsed.dig(:delta, :partial_json).to_s
45114
@current_tool_call.append(tool_delta)
115+
if @current_tool_call.has_partial?
116+
result = @current_tool_call.partial_tool_call
117+
end
46118
else
47119
result = parsed.dig(:delta, :text).to_s
48120
end

lib/completions/endpoints/anthropic.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def decode(response_data)
107107

108108
def processor
109109
@processor ||=
110-
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
110+
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode, partial_tool_calls: partial_tool_calls)
111111
end
112112

113113
def has_tool?(_response_data)

lib/completions/endpoints/base.rb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ module DiscourseAi
44
module Completions
55
module Endpoints
66
class Base
7+
attr_reader :partial_tool_calls
8+
79
CompletionFailed = Class.new(StandardError)
810
TIMEOUT = 60
911

@@ -58,8 +60,10 @@ def perform_completion!(
5860
model_params = {},
5961
feature_name: nil,
6062
feature_context: nil,
63+
partial_tool_calls: false,
6164
&blk
6265
)
66+
@partial_tool_calls = partial_tool_calls
6367
model_params = normalize_model_params(model_params)
6468
orig_blk = blk
6569

lib/completions/endpoints/canned_response.rb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def perform_completion!(
2828
_user,
2929
_model_params,
3030
feature_name: nil,
31-
feature_context: nil
31+
feature_context: nil,
32+
partial_tool_calls: false
3233
)
3334
@dialect = dialect
3435
response = responses[completions]

lib/completions/endpoints/fake.rb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def perform_completion!(
120120
user,
121121
model_params = {},
122122
feature_name: nil,
123-
feature_context: nil
123+
feature_context: nil,
124+
partial_tool_calls: false
124125
)
125126
last_call = { dialect: dialect, user: user, model_params: model_params }
126127
self.class.last_call = last_call

lib/completions/endpoints/open_ai.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def perform_completion!(
3333
model_params = {},
3434
feature_name: nil,
3535
feature_context: nil,
36+
partial_tool_calls: false,
3637
&blk
3738
)
3839
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?

0 commit comments

Comments
 (0)