Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit cf3cbb9

Browse files
committed
formatting and spec fixes
1 parent bcb2ee9 commit cf3cbb9

File tree

6 files changed

+57
-41
lines changed

6 files changed

+57
-41
lines changed

lib/completions/dialects/nova.rb

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,11 @@ def to_payload(options = nil)
3333
stop_sequences = options[:stop_sequences]
3434
max_tokens = options[:max_tokens]
3535

36-
inference_config =
37-
options&.slice(:temperature, :top_p, :top_k)
36+
inference_config = options&.slice(:temperature, :top_p, :top_k)
3837

39-
if stop_sequences.present?
40-
inference_config[:stopSequences] = stop_sequences
41-
end
38+
inference_config[:stopSequences] = stop_sequences if stop_sequences.present?
4239

43-
if max_tokens.present?
44-
inference_config[:max_new_tokens] = max_tokens
45-
end
40+
inference_config[:max_new_tokens] = max_tokens if max_tokens.present?
4641

4742
result = { system: system, messages: messages }
4843
result[:inferenceConfig] = inference_config if inference_config.present?

lib/completions/endpoints/aws_bedrock.rb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ def bedrock_decode(chunk)
167167
while decoded
168168
parsed = JSON.parse(decoded.payload.string)
169169
if exception = decoded.headers[":exception-type"]
170-
Rails.logger.error(
171-
"#{self.class.name}: #{exception}: #{parsed}",
172-
)
170+
Rails.logger.error("#{self.class.name}: #{exception}: #{parsed}")
173171
# TODO based on how often this happens, we may want to raise so we
174172
# can retry, this may catch rate limits for example
175173
end
@@ -203,8 +201,10 @@ def final_log_update(log)
203201

204202
def processor
205203
if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
206-
@processor ||=
207-
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
204+
@processor ||=
205+
DiscourseAi::Completions::AnthropicMessageProcessor.new(
206+
streaming_mode: @streaming_mode,
207+
)
208208
else
209209
@processor ||=
210210
DiscourseAi::Completions::NovaMessageProcessor.new(streaming_mode: @streaming_mode)

lib/completions/nova_message_processor.rb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,10 @@ def process_streamed_message(parsed)
6969
@current_tool_call.append(tool_progress)
7070
end
7171

72-
if parsed[:contentBlockStop] && @current_tool_call
73-
result = @current_tool_call.to_tool_call
74-
end
72+
result = @current_tool_call.to_tool_call if parsed[:contentBlockStop] && @current_tool_call
7573

7674
if metadata = parsed[:metadata]
77-
@input_tokens = metadata.dig(:usage,:inputTokens)
75+
@input_tokens = metadata.dig(:usage, :inputTokens)
7876
@output_tokens = metadata.dig(:usage, :outputTokens)
7977
end
8078

spec/lib/completions/dialects/claude_spec.rb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# frozen_string_literal: true
22

33
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
4-
54
fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }
65

76
let :opus_dialect_klass do

spec/lib/completions/dialects/nova_spec.rb

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,9 @@
6262
{
6363
name: "get_weather",
6464
description: "Get the weather in a city",
65-
input_schema: {
66-
type: "object",
67-
properties: {
68-
location: {
69-
type: "string",
70-
description: "the city name",
71-
},
72-
},
73-
required: ["location"],
74-
},
65+
parameters: [
66+
{ name: "location", type: "string", description: "the city name", required: true },
67+
],
7568
},
7669
]
7770

@@ -98,21 +91,16 @@
9891
json: {
9992
type: "object",
10093
properties: {
101-
location: {
94+
"location" => {
10295
type: "string",
103-
description: "the city name",
96+
required: true,
10497
},
10598
},
106-
required: ["location"],
10799
},
108100
},
109101
},
110102
},
111103
],
112-
toolChoice: {
113-
auto: {
114-
},
115-
},
116104
},
117105
)
118106
end
@@ -126,18 +114,18 @@
126114

127115
dialect = nova_dialect_klass.new(prompt, llm_model)
128116

129-
options = { temperature: 0.7, top_p: 0.9, max_new_tokens: 100, stop_sequences: ["STOP"] }
117+
options = { temperature: 0.7, top_p: 0.9, max_tokens: 100, stop_sequences: ["STOP"] }
130118

131119
translated = dialect.translate
132120

133121
expected = {
134-
system: { text: "You are a helpful bot" },
122+
system: [{ text: "You are a helpful bot" }],
135123
messages: [{ role: "user", content: [{ text: "Hello" }] }],
136124
inferenceConfig: {
137125
temperature: 0.7,
138126
top_p: 0.9,
127+
stopSequences: ["STOP"],
139128
max_new_tokens: 100,
140-
stop_sequences: ["STOP"],
141129
},
142130
}
143131

spec/lib/completions/endpoints/nova_spec.rb

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def encode_message(message)
223223
)
224224

225225
# lets continue and ensure all messages are mapped correctly
226-
prompt.push(type: :tool_call, name: "time", content: {timezone: "EST"}.to_json, id: "111")
226+
prompt.push(type: :tool_call, name: "time", content: { timezone: "EST" }.to_json, id: "111")
227227
prompt.push(type: :tool, name: "time", content: "1pm".to_json, id: "111")
228228

229229
# lets just return the tool call again, this is about ensuring we encode the prompt right
@@ -239,9 +239,45 @@ def encode_message(message)
239239
proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
240240
end
241241

242-
expected = {:system=>[{:text=>"You are a helpful assistant."}], :messages=>[{:role=>"user", :content=>[{:text=>"what is the time in EST"}]}, {:role=>"assistant", :content=>[{:toolUse=>{:toolUseId=>"111", :name=>"time", :input=>nil}}]}, {:role=>"user", :content=>[{:toolResult=>{:toolUseId=>"111", :content=>[{:json=>"1pm"}]}}]}], :inferenceConfig=>{:max_new_tokens=>200}, :toolConfig=>{:tools=>[{:toolSpec=>{:name=>"time", :description=>"Will look up the current time", :inputSchema=>{:json=>{:type=>"object", :properties=>{:timezone=>{:type=>"string", :required=>true}}}}}}]}}
242+
expected = {
243+
system: [{ text: "You are a helpful assistant." }],
244+
messages: [
245+
{ role: "user", content: [{ text: "what is the time in EST" }] },
246+
{
247+
role: "assistant",
248+
content: [{ toolUse: { toolUseId: "111", name: "time", input: nil } }],
249+
},
250+
{
251+
role: "user",
252+
content: [{ toolResult: { toolUseId: "111", content: [{ json: "1pm" }] } }],
253+
},
254+
],
255+
inferenceConfig: {
256+
max_new_tokens: 200,
257+
},
258+
toolConfig: {
259+
tools: [
260+
{
261+
toolSpec: {
262+
name: "time",
263+
description: "Will look up the current time",
264+
inputSchema: {
265+
json: {
266+
type: "object",
267+
properties: {
268+
timezone: {
269+
type: "string",
270+
required: true,
271+
},
272+
},
273+
},
274+
},
275+
},
276+
},
277+
],
278+
},
279+
}
243280

244281
expect(JSON.parse(request.body, symbolize_names: true)).to eq(expected)
245-
246282
end
247283
end

0 commit comments

Comments (0)