Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 9f8c15f

Browse files
committed
OpenAI endpoint starting to work.
1 parent 6cac18a commit 9f8c15f

File tree

3 files changed

+97
-113
lines changed

3 files changed

+97
-113
lines changed

lib/completions/endpoints/open_ai.rb

Lines changed: 61 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -93,99 +93,85 @@ def prepare_request(payload)
9393
end
9494

9595
def final_log_update(log)
96-
log.request_tokens = @prompt_tokens if @prompt_tokens
97-
log.response_tokens = @completion_tokens if @completion_tokens
96+
log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
97+
log.response_tokens = processor.completion_tokens if processor.completion_tokens
9898
end
9999

100-
def extract_completion_from(response_raw)
101-
json = JSON.parse(response_raw, symbolize_names: true)
100+
class OpenAiMessageProcessor
101+
attr_reader :prompt_tokens, :completion_tokens
102102

103-
if @streaming_mode
104-
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
105-
@completion_tokens ||= json.dig(:usage, :completion_tokens)
103+
def initialize
104+
@tool = nil
105+
@tool_arguments = +""
106+
@prompt_tokens = nil
107+
@completion_tokens = nil
106108
end
107109

108-
parsed = json.dig(:choices, 0)
109-
return if !parsed
110+
def process_streamed_message(json)
111+
rval = nil
110112

111-
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
112-
@has_function_call ||= response_h.dig(:tool_calls).present?
113-
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
114-
end
113+
tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
114+
content = json.dig(:choices, 0, :delta, :content)
115+
116+
finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []
115117

116-
def partials_from(decoded_chunk)
117-
decoded_chunk
118-
.split("\n")
119-
.map do |line|
120-
data = line.split("data: ", 2)[1]
121-
data == "[DONE]" ? nil : data
118+
if tool_calls.present?
119+
id = tool_calls.dig(0, :id)
120+
name = tool_calls.dig(0, :function, :name)
121+
arguments = tool_calls.dig(0, :function, :arguments)
122+
#index = tool_calls[0].dig(:index)
123+
if id.present? && @tool
124+
if @tool_arguments.present?
125+
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
126+
@tool.parameters = parsed_args
127+
end
128+
rval = @tool
129+
@tool = nil
130+
end
131+
132+
if id.present?
133+
@tool_arguments = +""
134+
@tool = ToolCall.new(id: id, name: name)
135+
end
136+
137+
@tool_arguments << arguments.to_s
138+
elsif finished_tools && @tool
139+
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
140+
@tool.parameters = parsed_args
141+
rval = @tool
142+
@tool = nil
143+
elsif content.present?
144+
rval = content
122145
end
123-
.compact
146+
147+
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
148+
@completion_tokens ||= json.dig(:usage, :completion_tokens)
149+
150+
rval
151+
end
152+
end
153+
154+
def decode_chunk(chunk)
155+
@decoder ||= JsonStreamDecoder.new
156+
(@decoder << chunk).map do |parsed_json|
157+
processor.process_streamed_message(parsed_json)
158+
end.flatten.compact
124159
end
125160

126161
def has_tool?(_response_data)
127162
@has_function_call
128163
end
129164

130-
def native_tool_support?
131-
true
165+
def xml_tools_enabled?
166+
false
132167
end
133168

134-
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
135-
if @streaming_mode
136-
return function_buffer if !partial
137-
else
138-
partial = payload
139-
end
140-
141-
@args_buffer ||= +""
142-
143-
f_name = partial.dig(:function, :name)
144-
145-
@current_function ||= function_buffer.at("invoke")
146-
147-
if f_name
148-
current_name = function_buffer.at("tool_name").content
149-
150-
if current_name.blank?
151-
# first call
152-
else
153-
# we have a previous function, so we need to add a noop
154-
@args_buffer = +""
155-
@current_function =
156-
function_buffer.at("function_calls").add_child(
157-
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
158-
)
159-
end
160-
end
161-
162-
@current_function.at("tool_name").content = f_name if f_name
163-
@current_function.at("tool_id").content = partial[:id] if partial[:id]
164-
165-
args = partial.dig(:function, :arguments)
166-
167-
# allow for SPACE within arguments
168-
if args && args != ""
169-
@args_buffer << args
170-
171-
begin
172-
json_args = JSON.parse(@args_buffer, symbolize_names: true)
173-
174-
argument_fragments =
175-
json_args.reduce(+"") do |memo, (arg_name, value)|
176-
memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
177-
end
178-
argument_fragments << "\n"
179-
180-
@current_function.at("parameters").children =
181-
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
182-
rescue JSON::ParserError
183-
return function_buffer
184-
end
185-
end
169+
private
186170

187-
function_buffer
171+
def processor
172+
@processor ||= OpenAiMessageProcessor.new
188173
end
174+
189175
end
190176
end
191177
end

lib/completions/tool_call.rb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ def initialize(id:, name:, parameters: nil)
1111
@parameters = parameters || {}
1212
end
1313

14+
def parameters=(parameters)
15+
raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
16+
@parameters = parameters
17+
end
18+
1419
def ==(other)
1520
id == other.id && name == other.name && parameters == other.parameters
1621
end

spec/lib/completions/endpoints/open_ai_spec.rb

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -487,40 +487,38 @@ def request_body(prompt, stream: false, tool_call: false)
487487
488488
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
489489
490-
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
490+
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}
491491
492492
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
493493
494494
data: [DONE]
495495
TEXT
496496

497497
open_ai_mock.stub_raw(raw_data)
498-
content = +""
498+
response = []
499499

500500
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
501501

502-
endpoint.perform_completion!(dialect, user) { |partial| content << partial }
503-
504-
expected = <<~TEXT
505-
<function_calls>
506-
<invoke>
507-
<tool_name>search</tool_name>
508-
<parameters>
509-
<search_query>Discourse AI bot</search_query>
510-
</parameters>
511-
<tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
512-
</invoke>
513-
<invoke>
514-
<tool_name>search</tool_name>
515-
<parameters>
516-
<query>Discourse AI bot</query>
517-
</parameters>
518-
<tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
519-
</invoke>
520-
</function_calls>
521-
TEXT
502+
endpoint.perform_completion!(dialect, user) { |partial| response << partial }
503+
504+
tool_calls = [
505+
DiscourseAi::Completions::ToolCall.new(
506+
name: "search",
507+
id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
508+
parameters: {
509+
search_query: "Discourse AI bot",
510+
},
511+
),
512+
DiscourseAi::Completions::ToolCall.new(
513+
name: "search",
514+
id: "call_H7YkbgYurHpyJqzwUN4bghwN",
515+
parameters: {
516+
query: "Discourse AI bot2",
517+
},
518+
),
519+
]
522520

523-
expect(content).to eq(expected)
521+
expect(response).to eq(tool_calls)
524522
end
525523

526524
it "uses proper token accounting" do
@@ -593,21 +591,16 @@ def request_body(prompt, stream: false, tool_call: false)
593591
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
594592
endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
595593

596-
expect(partials.length).to eq(1)
597-
598-
function_call = (<<~TXT).strip
599-
<function_calls>
600-
<invoke>
601-
<tool_name>google</tool_name>
602-
<parameters>
603-
<query>Adabas 9.1</query>
604-
</parameters>
605-
<tool_id>func_id</tool_id>
606-
</invoke>
607-
</function_calls>
608-
TXT
609-
610-
expect(partials[0].strip).to eq(function_call)
594+
tool_call =
595+
DiscourseAi::Completions::ToolCall.new(
596+
id: "func_id",
597+
name: "google",
598+
parameters: {
599+
query: "Adabas 9.1",
600+
},
601+
)
602+
603+
expect(partials).to eq([tool_call])
611604
end
612605
end
613606
end

Comments (0)

There are no comments on this commit.