Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit f2238ed

Browse files
committed
Seed embedding definition from old settings
1 parent 1c08bea commit f2238ed

File tree

2 files changed

+200
-1
lines changed

2 files changed

+200
-1
lines changed

assets/javascripts/discourse/components/ai-embedding-editor.gjs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ export default class AiEmbeddingEditor extends Component {
124124
await this.editingModel.save();
125125

126126
if (isNew) {
127-
this.args.embeddings.addObject(this.args.model);
127+
this.args.embeddings.addObject(this.editingModel);
128128
this.router.transitionTo(
129129
"adminPlugins.show.discourse-ai-embeddings.index"
130130
);
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# frozen_string_literal: true
2+
3+
# Seeds one `embedding_definitions` row from the legacy embeddings site
# settings (model name plus per-provider credentials) and points
# `ai_embeddings_selected_model` at it.
#
# Nothing is written when the configured model maps to no known provider or
# when the provider's credentials are incomplete.
class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0]
  # Models mapped to the "hugging_face" provider (the original code calls these TEI models).
  TEI_MODELS = %w[bge-large-en bge-m3 multilingual-e5-large].freeze

  # Models mapped to the "open_ai" provider.
  OPEN_AI_MODELS = %w[text-embedding-3-large text-embedding-3-small text-embedding-ada-002].freeze

  def up
    current_model = fetch_setting("ai_embeddings_model") || "bge-large-en"

    provider = provider_for(current_model)
    return if provider.blank?

    attrs = creds_for(provider)
    return if attrs.blank?

    attrs = attrs.merge(model_attrs(current_model))
    attrs[:display_name] = current_model
    attrs[:provider] = provider
    persist_config(attrs)
  end

  # No-op: rolling back leaves the seeded row and the setting in place.
  def down
  end

  # Utils

  # Reads a raw setting value.
  #
  # @param name [String] site setting name
  # @return [String, nil] the value stored in `site_settings`, else the
  #   `DISCOURSE_<NAME>` environment variable, else nil
  def fetch_setting(name)
    DB.query_single(
      "SELECT value FROM site_settings WHERE name = :setting_name",
      setting_name: name,
    ).first || ENV["DISCOURSE_#{name&.upcase}"]
  end

  # Maps a legacy model name to the provider that serves it.
  #
  # @param model [String] legacy value of `ai_embeddings_model`
  # @return [String, nil] "cloudflare", "hugging_face", "google", "open_ai",
  #   or nil for an unrecognized model
  def provider_for(model)
    cloudflare_api_token = fetch_setting("ai_cloudflare_workers_api_token")

    # bge-large-en is listed under both Cloudflare and TEI; a configured
    # Cloudflare token takes precedence.
    return "cloudflare" if model == "bge-large-en" && cloudflare_api_token.present?
    return "hugging_face" if TEI_MODELS.include?(model)
    return "google" if model == "gemini"
    return "open_ai" if OPEN_AI_MODELS.include?(model)

    nil
  end

  # Builds the endpoint URL and API key for a provider from the old settings.
  #
  # @param provider [String] a value returned by #provider_for
  # @return [Hash, nil] `{ url:, api_key: }`, or nil when a required setting
  #   is blank or the provider is unknown
  def creds_for(provider)
    case provider
    when "cloudflare"
      api_key = fetch_setting("ai_cloudflare_workers_api_token")
      account_id = fetch_setting("ai_cloudflare_workers_account_id")

      return if api_key.blank? || account_id.blank?

      {
        url:
          "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5",
        api_key: api_key,
      }
    when "hugging_face"
      endpoint = fetch_setting("ai_hugging_face_tei_endpoint")

      # Fall back to the SRV-record setting, stored without its scheme.
      if endpoint.blank?
        endpoint = fetch_setting("ai_hugging_face_tei_endpoint_srv")
        endpoint = "srv://#{endpoint}" if endpoint.present?
      end

      api_key = fetch_setting("ai_hugging_face_tei_api_key")

      # NOTE(review): a TEI endpoint without an API key is skipped here;
      # confirm that key-less deployments do not need to be migrated.
      return if endpoint.blank? || api_key.blank?

      { url: endpoint, api_key: api_key }
    when "google"
      api_key = fetch_setting("ai_gemini_api_key")

      return if api_key.blank?

      {
        url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
        api_key: api_key,
      }
    when "open_ai"
      endpoint = fetch_setting("ai_openai_embeddings_url")
      api_key = fetch_setting("ai_openai_api_key")

      # NOTE(review): fetch_setting only sees values stored in the table or
      # ENV, so a site relying on a built-in default URL is skipped; confirm
      # whether a fallback URL is wanted here.
      return if endpoint.blank? || api_key.blank?

      { url: endpoint, api_key: api_key }
    end
  end

  # Returns the fixed attributes for a legacy model.
  #
  # The hard-coded `id` values are kept so the new row matches embeddings
  # that were already generated (see #persist_config).
  #
  # @param model_name [String] legacy model name
  # @return [Hash] dimensions, max_sequence_length, id, pg_function,
  #   tokenizer_class and, for OpenAI models, provider_params; any
  #   unrecognized name gets the text-embedding-ada-002 attributes
  def model_attrs(model_name)
    case model_name
    when "bge-large-en"
      {
        dimensions: 1024,
        max_sequence_length: 512,
        id: 4,
        pg_function: "<#>",
        tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
      }
    when "bge-m3"
      {
        dimensions: 1024,
        max_sequence_length: 8192,
        id: 8,
        pg_function: "<#>",
        tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
      }
    when "gemini"
      {
        dimensions: 768,
        max_sequence_length: 1536,
        id: 5,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
      }
    when "multilingual-e5-large"
      {
        dimensions: 1024,
        max_sequence_length: 512,
        id: 3,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
      }
    when "text-embedding-3-large"
      {
        dimensions: 2000,
        max_sequence_length: 8191,
        id: 7,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-3-large",
        },
      }
    when "text-embedding-3-small"
      {
        dimensions: 1536,
        max_sequence_length: 8191,
        id: 6,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-3-small",
        },
      }
    else
      {
        dimensions: 1536,
        max_sequence_length: 8191,
        id: 2,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-ada-002",
        },
      }
    end
  end

  # Inserts the embedding definition, moves the id sequence past the
  # hard-coded id, and selects the new definition in site settings.
  #
  # @param attrs [Hash] merged result of #creds_for and #model_attrs plus
  #   :display_name and :provider
  def persist_config(attrs)
    insert_definition_sql = <<~SQL
      INSERT INTO embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, created_at, updated_at)
      VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :now, :now)
    SQL

    DB.exec(
      insert_definition_sql,
      id: attrs[:id],
      display_name: attrs[:display_name],
      dimensions: attrs[:dimensions],
      max_sequence_length: attrs[:max_sequence_length],
      pg_function: attrs[:pg_function],
      provider: attrs[:provider],
      tokenizer_class: attrs[:tokenizer_class],
      url: attrs[:url],
      api_key: attrs[:api_key],
      # A Ruby Hash is not a bindable SQL value, so serialize it; nil stays
      # NULL for models without provider params.
      # NOTE(review): assumes the column is json/jsonb (or text) - confirm.
      provider_params: attrs[:provider_params]&.to_json,
      now: Time.zone.now,
    )

    # The id is hard-coded to match already generated embeddings, so restart
    # the sequence after it to avoid conflicts with future rows.
    DB.exec(
      "ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq",
      new_seq: attrs[:id].to_i + 1,
    )

    DB.exec(
      "UPDATE site_settings SET value=:id WHERE name = 'ai_embeddings_selected_model'",
      id: attrs[:id],
    )

    # The UPDATE above touches no rows when the setting was never saved,
    # which would leave the new definition unselected; create the row then.
    # NOTE(review): assumes site_settings has (name, data_type, value,
    # created_at, updated_at) and that data_type 3 means integer - confirm.
    DB.exec(<<~SQL, id: attrs[:id].to_s)
      INSERT INTO site_settings (name, data_type, value, created_at, updated_at)
      SELECT 'ai_embeddings_selected_model', 3, :id, NOW(), NOW()
      WHERE NOT EXISTS (
        SELECT 1 FROM site_settings WHERE name = 'ai_embeddings_selected_model'
      )
    SQL
  end
end

0 commit comments

Comments
 (0)