Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 0ef5289

Browse files
committed
Use AR model for embeddings features
1 parent 890b85b commit 0ef5289

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+461
-842
lines changed

app/jobs/regular/digest_rag_upload.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def execute(args)
1818
target = target_type.constantize.find_by(id: target_id)
1919
return if !target
2020

21-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
21+
vector_rep = DiscourseAi::Embeddings::Vector.instance
2222

2323
tokenizer = vector_rep.tokenizer
2424
chunk_tokens = target.rag_chunk_tokens

app/models/embedding_definition.rb

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# frozen_string_literal: true
2+
3+
# Persisted description of one embeddings model: which provider serves it,
# the URL/API key used to reach it, and the tokenizer class that goes with it.
class EmbeddingDefinition < ActiveRecord::Base
  CLOUDFLARE = "cloudflare"
  DISCOURSE = "discourse"
  HUGGING_FACE = "hugging_face"
  OPEN_AI = "open_ai"
  GEMINI = "gemini"

  class << self
    # @return [Array<String>] every value accepted for the `provider` column.
    def provider_names
      [CLOUDFLARE, DISCOURSE, HUGGING_FACE, OPEN_AI, GEMINI]
    end

    # @return [Array<String>] tokenizer class names accepted for the
    #   `tokenizer_class` column.
    def tokenizer_names
      # OpenAiTokenizer used to be listed twice; each class now appears once,
      # in the order of its first occurrence.
      [
        DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer,
        DiscourseAi::Tokenizer::BgeLargeEnTokenizer,
        DiscourseAi::Tokenizer::BgeM3Tokenizer,
        DiscourseAi::Tokenizer::OpenAiTokenizer,
        DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer,
      ].map(&:name)
    end

    # Provider-specific custom parameters and their field types.
    #
    # Defined with a plain `def`: inside `class << self`, `def self.` would
    # attach the method to the singleton class object instead, leaving
    # `EmbeddingDefinition.provider_params` undefined.
    #
    # @return [Hash{Symbol => Hash{Symbol => Symbol}}]
    def provider_params
      { discourse: { model_name: :text }, open_ai: { model_name: :text } }
    end
  end

  validates :provider, presence: true, inclusion: provider_names
  validates :display_name, presence: true, length: { maximum: 100 }
  validates :tokenizer_class, presence: true, inclusion: tokenizer_names

  # @return [Class] the tokenizer class named by `tokenizer_class`.
  def tokenizer
    tokenizer_class.constantize
  end

  # Builds the inference client that matches `provider`.
  #
  # @raise [RuntimeError] when `provider` is none of the known constants.
  def inference_client
    case provider
    when CLOUDFLARE
      cloudflare_client
    when DISCOURSE
      discourse_client
    when HUGGING_FACE
      hugging_face_client
    when OPEN_AI
      open_ai_client
    when GEMINI
      gemini_client
    else
      raise "Unknown embeddings provider"
    end
  end

  # Reads a single entry from the record's `provider_params` column.
  #
  # @param key [String] parameter name (callers in this class pass "model_name").
  # @return [Object, nil] nil when `provider_params` is nil or lacks the key.
  def lookup_custom_param(key)
    provider_params&.dig(key)
  end

  # @return [String] `url` as stored, unless it starts with "srv://", in which
  #   case the SRV record is looked up and an https URL is built from its
  #   target and port.
  def endpoint_url
    return url if !url.starts_with?("srv://")

    service = DiscourseAi::Utils::DnsSrv.lookup(url)
    "https://#{service.target}:#{service.port}"
  end

  # Delegates to the strategy to prepare query text for this definition.
  # (`asymetric` keeps its existing spelling; it is part of the interface.)
  def prepare_query_text(text, asymetric: false)
    strategy.prepare_query_text(text, self, asymetric: asymetric)
  end

  # Delegates to the strategy to prepare the text of `target` for this definition.
  def prepare_target_text(target)
    strategy.prepare_target_text(target, self)
  end

  def strategy_id
    strategy.id
  end

  def strategy_version
    strategy.version
  end

  private

  # Memoized; Truncation is the only strategy instantiated here.
  def strategy
    @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
  end

  def cloudflare_client
    DiscourseAi::Inference::CloudflareWorkersAi.new(endpoint_url, api_key)
  end

  def discourse_client
    client_url = endpoint_url
    # The classify path is appended only to SRV-resolved URLs; a plain `url`
    # is passed through exactly as stored.
    client_url = "#{client_url}/api/v1/classify" if url.starts_with?("srv://")

    DiscourseAi::Inference::DiscourseClassifier.new(
      client_url,
      api_key,
      lookup_custom_param("model_name"),
    )
  end

  def hugging_face_client
    DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(endpoint_url, api_key)
  end

  def open_ai_client
    DiscourseAi::Inference::OpenAiEmbeddings.new(
      endpoint_url,
      api_key,
      lookup_custom_param("model_name"),
      dimensions,
    )
  end

  def gemini_client
    DiscourseAi::Inference::GeminiEmbeddings.new(endpoint_url, api_key)
  end
end
121+
122+
# == Schema Information
123+
#
124+
# Table name: embedding_definitions
125+
#
126+
# id :bigint not null, primary key
127+
# display_name :string not null
128+
# dimensions :integer not null
129+
# max_sequence_length :integer not null
130+
# version :integer default(1), not null
131+
# pg_function :string not null
132+
# provider :string not null
133+
# tokenizer_class :string not null
134+
# url :string not null
135+
# api_key :string
136+
# provider_params :jsonb
137+
# created_at :datetime not null
138+
# updated_at :datetime not null
139+
#

config/locales/server.en.yml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,11 @@ en:
6868
ai_auto_image_caption_allowed_groups: "Users in these groups can toggle automatic image captioning."
6969

7070
ai_embeddings_enabled: "Enable the embeddings module."
71-
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
72-
ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
73-
ai_embeddings_model: "Use all-mpnet-base-v2 for local and fast inference in english, text-embedding-ada-002 to use OpenAI API (need API key) and multilingual-e5-large for local multilingual embeddings"
71+
ai_embeddings_selected_model: "Use the selected model for generating embeddings."
7472
ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
7573
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
7674
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
7775
ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
78-
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
7976
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
8077
ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
8178
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
@@ -439,11 +436,7 @@ en:
439436
embeddings:
440437
configuration:
441438
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
442-
choose_model: "Set 'ai embeddings model' first."
443-
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
444-
hint:
445-
one: "Make sure the `%{settings}` setting was configured."
446-
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"
439+
choose_model: "Set 'ai embeddings selected model' first."
447440

448441
llm_models:
449442
missing_provider_param: "%{param} can't be blank"

config/settings.yml

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,27 +158,12 @@ discourse_ai:
158158
default: false
159159
client: true
160160
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
161-
ai_embeddings_discourse_service_api_endpoint: ""
162-
ai_embeddings_discourse_service_api_endpoint_srv:
163-
default: ""
164-
hidden: true
165-
ai_embeddings_discourse_service_api_key:
166-
default: ""
167-
secret: true
168-
ai_embeddings_model:
161+
ai_embeddings_selected_model:
169162
type: enum
170-
default: "bge-large-en"
163+
default: ""
171164
allow_any: false
172-
choices:
173-
- all-mpnet-base-v2
174-
- text-embedding-ada-002
175-
- text-embedding-3-small
176-
- text-embedding-3-large
177-
- multilingual-e5-large
178-
- bge-large-en
179-
- gemini
180-
- bge-m3
181-
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
165+
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
166+
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
182167
ai_embeddings_per_post_enabled:
183168
default: false
184169
hidden: true
@@ -191,9 +176,6 @@ discourse_ai:
191176
ai_embeddings_backfill_batch_size:
192177
default: 250
193178
hidden: true
194-
ai_embeddings_pg_connection_string:
195-
default: ""
196-
hidden: true
197179
ai_embeddings_semantic_search_enabled:
198180
default: false
199181
client: true
@@ -213,6 +195,35 @@ discourse_ai:
213195
default: false
214196
client: true
215197
hidden: true
198+
199+
ai_embeddings_discourse_service_api_endpoint:
200+
default: ""
201+
hidden: true
202+
ai_embeddings_discourse_service_api_endpoint_srv:
203+
default: ""
204+
hidden: true
205+
ai_embeddings_discourse_service_api_key:
206+
hidden: true
207+
default: ""
208+
secret: true
209+
ai_embeddings_model:
210+
hidden: true
211+
type: enum
212+
default: "bge-large-en"
213+
allow_any: false
214+
choices:
215+
- all-mpnet-base-v2
216+
- text-embedding-ada-002
217+
- text-embedding-3-small
218+
- text-embedding-3-large
219+
- multilingual-e5-large
220+
- bge-large-en
221+
- gemini
222+
- bge-m3
223+
ai_embeddings_pg_connection_string:
224+
default: ""
225+
hidden: true
226+
216227
ai_summarization_enabled:
217228
default: false
218229
client: true
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# frozen_string_literal: true
2+
# Creates the `embedding_definitions` table.
class CreateEmbeddingDefinitions < ActiveRecord::Migration[7.2]
  def change
    create_table(:embedding_definitions) do |table|
      # Model description
      table.string(:display_name, null: false)
      table.integer(:dimensions, null: false)
      table.integer(:max_sequence_length, null: false)
      table.integer(:version, default: 1, null: false)
      table.string(:pg_function, null: false)

      # Provider / connection details
      table.string(:provider, null: false)
      table.string(:tokenizer_class, null: false)
      table.string(:url, null: false)
      table.string(:api_key)
      table.jsonb(:provider_params)

      table.timestamps
    end
  end
end

lib/ai_bot/entry_point.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def inject_into(plugin)
196196
)
197197

198198
plugin.on(:site_setting_changed) do |name, old_value, new_value|
199-
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
199+
if name == :ai_embeddings_selected_model && SiteSetting.ai_embeddings_enabled? &&
200200
new_value != old_value
201201
RagDocumentFragment.delete_all
202202
UploadReference

lib/ai_bot/personas/persona.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
327327
rag_conversation_chunks
328328
end
329329

330-
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
330+
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment)
331331

332332
candidate_fragment_ids =
333333
schema

lib/ai_helper/semantic_categorizer.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def tags
9393

9494
def nearest_neighbors(limit: 100)
9595
vector = DiscourseAi::Embeddings::Vector.instance
96-
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
96+
schema = DiscourseAi::Embeddings::Schema.for(Topic)
9797

9898
raw_vector = vector.vector_from(@text)
9999

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# frozen_string_literal: true
2+
3+
require "enum_site_setting"
4+
5+
module DiscourseAi
  module Configuration
    # Enum site-setting source backed by the `embedding_definitions` table.
    class EmbeddingDefsEnumerator < ::EnumSiteSetting
      class << self
        # Always returns true; no membership check is performed here.
        def valid_value?(val)
          true
        end

        # @return [Array<Hash>] one `{ name:, value: }` hash per row, where
        #   `name` is the row's display_name and `value` is its id.
        def values
          rows = DB.query_hash(<<~SQL)
            SELECT display_name AS name, id AS value
            FROM embedding_definitions
          SQL

          rows.map { |row| row.symbolize_keys }
        end
      end
    end
  end
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
  module Configuration
    # Site-setting validator: a value passes when it is blank or is the id of
    # an existing EmbeddingDefinition row.
    class EmbeddingDefsValidator
      # @param opts [Hash] stored as-is; not read anywhere else in this class.
      def initialize(opts = {})
        @opts = opts
      end

      # @param val [Object] candidate setting value.
      # @return [Boolean]
      def valid_value?(val)
        return true if val.blank?

        EmbeddingDefinition.exists?(id: val)
      end

      # NOTE(review): an empty string is returned, so this method supplies no
      # explanatory text for a rejected value — confirm this is intended.
      def error_message
        ""
      end
    end
  end
end

0 commit comments

Comments
 (0)