Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit af71ede

Browse files
committed
Seed embedding definition from old settings
1 parent 1c08bea commit af71ede

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

assets/javascripts/discourse/components/ai-embedding-editor.gjs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ export default class AiEmbeddingEditor extends Component {
124124
await this.editingModel.save();
125125

126126
if (isNew) {
127-
this.args.embeddings.addObject(this.args.model);
127+
this.args.embeddings.addObject(this.editingModel);
128128
this.router.transitionTo(
129129
"adminPlugins.show.discourse-ai-embeddings.index"
130130
);
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# frozen_string_literal: true
2+
3+
class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0]
4+
# Seeds one embedding definition from the legacy embedding settings.
#
# Reads the "ai_embeddings_model" setting (using "bge-large-en" when it is
# not set), resolves the provider and its credentials, merges in the
# per-model attributes and persists the result. Does nothing when no provider
# matches the model or when no credentials are available for it.
def up
  legacy_model = fetch_setting("ai_embeddings_model") || "bge-large-en"

  provider = provider_for(legacy_model)
  return unless provider.present?

  credentials = creds_for(provider)
  return unless credentials.present?

  definition = credentials.merge(model_attrs(legacy_model))
  definition[:display_name] = legacy_model
  definition[:provider] = provider

  persist_config(definition)
end
19+
20+
# No-op: rolling back does not remove or alter anything written by #up.
def down
end
22+
23+
# Helpers used by #up: settings lookup, provider detection, credentials, model attributes and persistence.
24+
25+
# Looks up a setting value by name.
#
# The value stored in the site_settings table wins; when the query yields no
# value, the DISCOURSE_<UPCASED NAME> environment variable is used instead.
#
# @param name [String, nil] setting name
# @return [String, nil] the stored or environment value, nil when neither exists
def fetch_setting(name)
  stored_values =
    DB.query_single(
      "SELECT value FROM site_settings WHERE name = :setting_name",
      setting_name: name,
    )

  stored = stored_values.first
  return stored if stored

  ENV["DISCOURSE_#{name&.upcase}"]
end
31+
32+
# Maps a legacy embedding model name to the provider that serves it.
#
# "bge-large-en" is served by "cloudflare" when a Cloudflare Workers API token
# is configured and by "hugging_face" otherwise.
#
# @param model [String] legacy model name
# @return [String, nil] "cloudflare", "hugging_face", "google" or "open_ai";
#   nil when the model is not recognized
def provider_for(model)
  cf_token = fetch_setting("ai_cloudflare_workers_api_token")

  case model
  when "bge-large-en"
    cf_token.present? ? "cloudflare" : "hugging_face"
  when "bge-m3", "multilingual-e5-large"
    "hugging_face"
  when "gemini"
    "google"
  when "text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"
    "open_ai"
  end
end
46+
47+
# Builds the connection attributes for +provider+ from the legacy settings.
#
# @param provider [String] "cloudflare", "hugging_face", "google" or "open_ai"
# @return [Hash, nil] a hash with +:url+ and +:api_key+, or nil when the
#   provider is unknown or a required setting is blank
def creds_for(provider)
  case provider
  when "cloudflare"
    api_key = fetch_setting("ai_cloudflare_workers_api_token")
    account_id = fetch_setting("ai_cloudflare_workers_account_id")

    return if api_key.blank? || account_id.blank?

    {
      url:
        "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5",
      api_key: api_key,
    }
  when "hugging_face"
    endpoint = fetch_setting("ai_hugging_face_tei_endpoint")

    # No direct endpoint configured: fall back to the SRV record setting.
    if endpoint.blank?
      srv_record = fetch_setting("ai_hugging_face_tei_endpoint_srv")
      endpoint = "srv://#{srv_record}" if srv_record.present?
    end

    api_key = fetch_setting("ai_hugging_face_tei_api_key")

    # NOTE(review): a TEI endpoint without an API key is skipped here; confirm
    # that keyless self-hosted endpoints are not expected to be migrated.
    return if endpoint.blank? || api_key.blank?

    { url: endpoint, api_key: api_key }
  when "google"
    api_key = fetch_setting("ai_gemini_api_key")

    return if api_key.blank?

    {
      url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
      api_key: api_key,
    }
  when "open_ai"
    endpoint = fetch_setting("ai_openai_embeddings_url")
    # A setting left at its default has no stored value, so a blank URL does
    # not mean OpenAI is unconfigured: use the public embeddings endpoint.
    # NOTE(review): assumes the setting's default is this URL - confirm.
    endpoint = "https://api.openai.com/v1/embeddings" if endpoint.blank?

    api_key = fetch_setting("ai_openai_api_key")

    return if api_key.blank?

    { url: endpoint, api_key: api_key }
  end
end
97+
98+
# Returns the fixed embedding attributes for a legacy model name: vector
# dimensions, maximum sequence length, a hardcoded id, the distance operator
# stored as pg_function, and the tokenizer class name. OpenAI models also
# carry +:provider_params+ with the model name.
#
# Any name that is not listed explicitly gets the "text-embedding-ada-002"
# attributes.
#
# @param model_name [String] legacy model name
# @return [Hash] a new attribute hash
def model_attrs(model_name)
  # The three OpenAI variants differ only in id, dimensions and model name.
  open_ai_attrs =
    lambda do |id, dimensions, open_ai_model|
      {
        dimensions: dimensions,
        max_sequence_length: 8191,
        id: id,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: open_ai_model,
        },
      }
    end

  case model_name
  when "bge-large-en"
    {
      dimensions: 1024,
      max_sequence_length: 512,
      id: 4,
      pg_function: "<#>",
      tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
    }
  when "bge-m3"
    {
      dimensions: 1024,
      max_sequence_length: 8192,
      id: 8,
      pg_function: "<#>",
      tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
    }
  when "gemini"
    {
      dimensions: 768,
      max_sequence_length: 1536,
      id: 5,
      pg_function: "<=>",
      tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
    }
  when "multilingual-e5-large"
    {
      dimensions: 1024,
      max_sequence_length: 512,
      id: 3,
      pg_function: "<=>",
      tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
    }
  when "text-embedding-3-large"
    open_ai_attrs.call(7, 2000, "text-embedding-3-large")
  when "text-embedding-3-small"
    open_ai_attrs.call(6, 1536, "text-embedding-3-small")
  else
    open_ai_attrs.call(2, 1536, "text-embedding-ada-002")
  end
end
166+
167+
# Inserts the embedding definition described by +attrs+, restarts the id
# sequence just past the inserted id, and updates the
# "ai_embeddings_selected_model" site setting row to that id.
#
# @param attrs [Hash] expects :id, :display_name, :dimensions,
#   :max_sequence_length, :pg_function, :provider, :tokenizer_class, :url,
#   :api_key and optionally :provider_params (a Hash)
def persist_config(attrs)
  # provider_params is a Ruby Hash (or nil). Serialize it to JSON text so it is
  # bound as a plain string instead of a Hash the SQL parameter encoder cannot
  # quote. Relies on Hash#to_json, which Rails loads.
  serialized_params = attrs[:provider_params]&.to_json

  DB.exec(
    <<~SQL,
      INSERT INTO embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, created_at, updated_at)
      VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :now, :now)
    SQL
    id: attrs[:id],
    display_name: attrs[:display_name],
    dimensions: attrs[:dimensions],
    max_sequence_length: attrs[:max_sequence_length],
    pg_function: attrs[:pg_function],
    provider: attrs[:provider],
    tokenizer_class: attrs[:tokenizer_class],
    url: attrs[:url],
    api_key: attrs[:api_key],
    provider_params: serialized_params,
    now: Time.zone.now,
  )

  # We hardcoded the ID to match with already generated embeddings. Let's restart the seq to avoid conflicts.
  DB.exec(
    "ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq",
    new_seq: attrs[:id].to_i + 1,
  )

  # NOTE(review): this UPDATE changes nothing when no site_settings row named
  # 'ai_embeddings_selected_model' exists yet - confirm whether an INSERT is
  # needed for that case.
  DB.exec(
    "UPDATE site_settings SET value=:id WHERE name = 'ai_embeddings_selected_model'",
    id: attrs[:id],
  )
end
197+
end

0 commit comments

Comments
 (0)