
Commit 5b9add0

FEATURE: add a SambaNova LLM provider (#797)

Note: at the moment the context window is quite small, so the provider is mainly useful as a helper backend or a HyDE generator.

1 parent 22d1e71 commit 5b9add0

File tree

8 files changed: +162 -11 lines
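
Before the per-file diffs, a minimal usage sketch of the new provider. It assumes the model attributes used by the fabricator and spec further down; the API key is a placeholder:

# Minimal sketch, assuming the attributes from the fabricator in this commit.
# "YOUR_KEY" is a placeholder; the URL and provider name come from the diff below.
llm_model =
  LlmModel.new(
    display_name: "Samba Nova",
    name: "samba-nova",
    provider: "samba_nova",
    api_key: "YOUR_KEY",
    url: "https://api.sambanova.ai/v1/chat/completions",
  )

llm = llm_model.to_llm

# Streamed generation yields partial strings as they arrive.
llm.generate("who are you?", user: Discourse.system_user) { |partial| print partial }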

app/models/ai_api_audit_log.rb

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ module Provider
     Vllm = 5
     Cohere = 6
     Ollama = 7
+    SambaNova = 8
   end
 end

config/locales/client.en.yml

Lines changed: 2 additions & 1 deletion

@@ -243,7 +243,7 @@ en:
       confirm_delete: Are you sure you want to delete this model?
       delete: Delete
       seeded_warning: "This model is pre-configured on your site and cannot be edited."
-      in_use_warning:
+      in_use_warning:
         one: "This model is currently used by the %{settings} setting. If misconfigured, the feature won't work as expected."
        other: "This model is currently used by the following settings: %{settings}. If misconfigured, features won't work as expected. "

@@ -275,6 +275,7 @@ en:
       azure: "Azure"
       ollama: "Ollama"
       CDCK: "CDCK"
+      samba_nova: "SambaNova"

     provider_fields:
       access_key_id: "AWS Bedrock Access key ID"

lib/completions/dialects/open_ai_compatible.rb

Lines changed: 8 additions & 9 deletions

@@ -85,15 +85,14 @@ def inline_images(content, message)
       encoded_uploads = prompt.encoded_uploads(message)
       return content if encoded_uploads.blank?

-      content_w_imgs =
-        encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
-          memo << {
-            type: "image_url",
-            image_url: {
-              url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
-            },
-          }
-        end
+      encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
+        memo << {
+          type: "image_url",
+          image_url: {
+            url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
+          },
+        }
+      end
     end
   end
 end

lib/completions/endpoints/base.rb

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ def endpoint_for(provider_name)
       DiscourseAi::Completions::Endpoints::Vllm,
       DiscourseAi::Completions::Endpoints::Anthropic,
       DiscourseAi::Completions::Endpoints::Cohere,
+      DiscourseAi::Completions::Endpoints::SambaNova,
     ]

     endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
lib/completions/endpoints/samba_nova.rb

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Endpoints
+      class SambaNova < Base
+        def self.can_contact?(model_provider)
+          model_provider == "samba_nova"
+        end
+
+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # max_tokens, temperature are already supported
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          model_params
+        end
+
+        def default_options
+          { model: llm_model.name }
+        end
+
+        def provider_id
+          AiApiAuditLog::Provider::SambaNova
+        end
+
+        private
+
+        def model_uri
+          URI(llm_model.url)
+        end
+
+        def prepare_payload(prompt, model_params, dialect)
+          payload = default_options.merge(model_params).merge(messages: prompt)
+
+          payload[:stream] = true if @streaming_mode
+
+          payload
+        end
+
+        def prepare_request(payload)
+          headers = { "Content-Type" => "application/json" }
+          api_key = llm_model.api_key
+
+          headers["Authorization"] = "Bearer #{api_key}"
+
+          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+        end
+
+        def final_log_update(log)
+          log.request_tokens = @prompt_tokens if @prompt_tokens
+          log.response_tokens = @completion_tokens if @completion_tokens
+        end
+
+        def extract_completion_from(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+
+          if @streaming_mode
+            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+            @completion_tokens ||= json.dig(:usage, :completion_tokens)
+          end
+
+          parsed = json.dig(:choices, 0)
+          return if !parsed
+
+          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
+        end
+
+        def partials_from(decoded_chunk)
+          decoded_chunk
+            .split("\n")
+            .map do |line|
+              data = line.split("data: ", 2)[1]
+              data == "[DONE]" ? nil : data
+            end
+            .compact
+        end
+      end
+    end
+  end
+end
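
To make the streaming path concrete: partials_from splits each decoded network chunk on newlines and keeps the JSON payload of every "data:" line (dropping the "[DONE]" sentinel), then extract_completion_from digs the delta text out of each payload. A standalone sketch of the same parsing, with a hypothetical helper name and a sample chunk modeled on the spec below:

require "json"

# Hypothetical helper mirroring the endpoint's streaming parse (not part of the commit).
def extract_stream_text(decoded_chunk)
  decoded_chunk
    .split("\n")
    .filter_map do |line|
      data = line.split("data: ", 2)[1]
      next if data.nil? || data == "[DONE]"

      # Usage-only chunks carry no delta content; dig returns nil and is filtered out.
      JSON.parse(data, symbolize_names: true).dig(:choices, 0, :delta, :content)
    end
    .join
end

chunk = <<~SSE
  data: {"choices": [{"index": 0, "delta": {"content": "I am a bot"}}]}

  data: [DONE]
SSE

puts extract_stream_text(chunk) # => I am a bot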

lib/completions/llm.rb

Lines changed: 11 additions & 1 deletion

@@ -76,7 +76,17 @@ def presets
     end

     def provider_names
-      providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+      providers = %w[
+        aws_bedrock
+        anthropic
+        vllm
+        hugging_face
+        cohere
+        open_ai
+        google
+        azure
+        samba_nova
+      ]
       if !Rails.env.production?
         providers << "fake"
         providers << "ollama"

spec/fabricators/llm_model_fabricator.rb

Lines changed: 8 additions & 0 deletions

@@ -71,3 +71,11 @@
   api_key "ABC"
   url "https://api.cohere.ai/v1/chat"
 end
+
+Fabricator(:samba_nova_model, from: :llm_model) do
+  display_name "Samba Nova"
+  name "samba-nova"
+  provider "samba_nova"
+  api_key "ABC"
+  url "https://api.sambanova.ai/v1/chat/completions"
+end
spec/lib/completions/endpoints/samba_nova_spec.rb

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Endpoints::SambaNova do
+  fab!(:llm_model) { Fabricate(:samba_nova_model) }
+  let(:llm) { llm_model.to_llm }
+
+  it "can stream completions" do
+    body = <<~PARTS
+      data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [{"index": 0, "delta": {"content": "I am a bot"}, "logprobs": null, "finish_reason": null}]}

+      data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [], "usage": {"is_last_response": true, "total_tokens": 62, "prompt_tokens": 21, "completion_tokens": 41, "time_to_first_token": 0.09152531623840332, "end_time": 1721336361.582011, "start_time": 1721336361.413994, "total_latency": 0.16801691055297852, "total_tokens_per_sec": 369.010475171488, "completion_tokens_per_sec": 244.02305616179046, "completion_tokens_after_first_per_sec": 522.9332759819093, "completion_tokens_after_first_per_sec_first_ten": 1016.0004844667837}}

+      data: [DONE]
+    PARTS
+
+    stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
+      body:
+        "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}],\"stream\":true}",
+      headers: {
+        "Authorization" => "Bearer ABC",
+        "Content-Type" => "application/json",
+      },
+    ).to_return(status: 200, body: body, headers: {})
+
+    response = +""
+    llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
+
+    expect(response).to eq("I am a bot")
+  end
+
+  it "can perform regular completions" do
+    body = { choices: [message: { content: "I am a bot" }] }.to_json
+
+    stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
+      body:
+        "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}]}",
+      headers: {
+        "Authorization" => "Bearer ABC",
+        "Content-Type" => "application/json",
+      },
+    ).to_return(status: 200, body: body, headers: {})
+
+    response = llm.generate("who are you?", user: Discourse.system_user)
+
+    expect(response).to eq("I am a bot")
+  end
+end

0 commit comments