Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 690d6e6

Browse files
committed
REFACTOR: Tidy-up embedding endpoints config.
Two changes worth mentioning: `#instance` returns a fully configured embedding endpoint, ready to use, and all endpoints now respond to the same method with the same signature — `perform!(text)`. This makes it easier to reuse them when generating embeddings in bulk.
1 parent 1a10680 commit 690d6e6

17 files changed

+207
-123
lines changed

lib/embeddings/vector_representations/all_mpnet_base_v2.rb

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,7 @@ def dependant_setting_names
2424
end
2525

2626
def vector_from(text, asymetric: false)
27-
DiscourseAi::Inference::DiscourseClassifier.perform!(
28-
"#{discourse_embeddings_endpoint}/api/v1/classify",
29-
self.class.name,
30-
text,
31-
SiteSetting.ai_embeddings_discourse_service_api_key,
32-
)
27+
inference_client.perform!(text)
3328
end
3429

3530
def dimensions
@@ -59,6 +54,10 @@ def pg_index_type
5954
def tokenizer
6055
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
6156
end
57+
58+
def inference_client
59+
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
60+
end
6261
end
6362
end
6463
end

lib/embeddings/vector_representations/base.rb

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -426,16 +426,8 @@ def save_to_db(target, vector, digest)
426426
end
427427
end
428428

429-
def discourse_embeddings_endpoint
430-
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
431-
service =
432-
DiscourseAi::Utils::DnsSrv.lookup(
433-
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
434-
)
435-
"https://#{service.target}:#{service.port}"
436-
else
437-
SiteSetting.ai_embeddings_discourse_service_api_endpoint
438-
end
429+
def inference_client
430+
raise NotImplementedError
439431
end
440432
end
441433
end

lib/embeddings/vector_representations/bge_large_en.rb

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,12 @@ def dependant_setting_names
3333
def vector_from(text, asymetric: false)
3434
text = "#{asymmetric_query_prefix} #{text}" if asymetric
3535

36-
if SiteSetting.ai_cloudflare_workers_api_token.present?
37-
DiscourseAi::Inference::CloudflareWorkersAi
38-
.perform!(inference_model_name, { text: text })
39-
.dig(:result, :data)
40-
.first
41-
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
42-
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
43-
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
44-
elsif discourse_embeddings_endpoint.present?
45-
DiscourseAi::Inference::DiscourseClassifier.perform!(
46-
"#{discourse_embeddings_endpoint}/api/v1/classify",
47-
inference_model_name.split("/").last,
48-
text,
49-
SiteSetting.ai_embeddings_discourse_service_api_key,
50-
)
51-
else
52-
raise "No inference endpoint configured"
53-
end
36+
client = inference_client
37+
38+
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
39+
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
40+
41+
inference_client.perform!(text)
5442
end
5543

5644
def inference_model_name
@@ -88,6 +76,21 @@ def tokenizer
8876
def asymmetric_query_prefix
8977
"Represent this sentence for searching relevant passages:"
9078
end
79+
80+
def inference_client
81+
if SiteSetting.ai_cloudflare_workers_api_token.present?
82+
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
83+
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
84+
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
85+
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
86+
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
87+
DiscourseAi::Inference::DiscourseClassifier.instance(
88+
inference_model_name.split("/").last,
89+
)
90+
else
91+
raise "No inference endpoint configured"
92+
end
93+
end
9194
end
9295
end
9396
end

lib/embeddings/vector_representations/bge_m3.rb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def dependant_setting_names
2020

2121
def vector_from(text, asymetric: false)
2222
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
23-
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
23+
inference_client.perform!(truncated_text)
2424
end
2525

2626
def dimensions
@@ -50,6 +50,10 @@ def pg_index_type
5050
def tokenizer
5151
DiscourseAi::Tokenizer::BgeM3Tokenizer
5252
end
53+
54+
def inference_client
55+
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
56+
end
5357
end
5458
end
5559
end

lib/embeddings/vector_representations/gemini.rb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ def pg_index_type
4343
end
4444

4545
def vector_from(text, asymetric: false)
46-
response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
47-
response[:embedding][:values]
46+
inference_client.perform!(text).dig(:embedding, :values)
4847
end
4948

5049
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
@@ -53,6 +52,10 @@ def vector_from(text, asymetric: false)
5352
def tokenizer
5453
DiscourseAi::Tokenizer::OpenAiTokenizer
5554
end
55+
56+
def inference_client
57+
DiscourseAi::Inference::GeminiEmbeddings.instance
58+
end
5659
end
5760
end
5861
end

lib/embeddings/vector_representations/multilingual_e5_large.rb

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,16 @@ def dependant_setting_names
2929
end
3030

3131
def vector_from(text, asymetric: false)
32-
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
33-
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
34-
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
35-
elsif discourse_embeddings_endpoint.present?
36-
DiscourseAi::Inference::DiscourseClassifier.perform!(
37-
"#{discourse_embeddings_endpoint}/api/v1/classify",
38-
self.class.name,
39-
"query: #{text}",
40-
SiteSetting.ai_embeddings_discourse_service_api_key,
41-
)
32+
client = inference_client
33+
34+
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
35+
if needs_truncation
36+
text = tokenizer.truncate(text, max_sequence_length - 2)
4237
else
43-
raise "No inference endpoint configured"
38+
text = "query: #{text}"
4439
end
40+
41+
client.perform!(text)
4542
end
4643

4744
def id
@@ -71,6 +68,17 @@ def pg_index_type
7168
def tokenizer
7269
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
7370
end
71+
72+
def inference_client
73+
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
74+
DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
75+
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
76+
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
77+
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
78+
else
79+
raise "No inference endpoint configured"
80+
end
81+
end
7482
end
7583
end
7684
end

lib/embeddings/vector_representations/text_embedding_3_large.rb

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,19 @@ def pg_index_type
4545
end
4646

4747
def vector_from(text, asymetric: false)
48-
response =
49-
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
50-
text,
51-
model: self.class.name,
52-
dimensions: dimensions,
53-
)
54-
response[:data].first[:embedding]
48+
inference_client.perform!(text)
5549
end
5650

5751
def tokenizer
5852
DiscourseAi::Tokenizer::OpenAiTokenizer
5953
end
54+
55+
def inference_client
56+
DiscourseAi::Inference::OpenAiEmbeddings.instance(
57+
model: self.class.name,
58+
dimensions: dimensions,
59+
)
60+
end
6061
end
6162
end
6263
end

lib/embeddings/vector_representations/text_embedding_3_small.rb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ def pg_index_type
4343
end
4444

4545
def vector_from(text, asymetric: false)
46-
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
47-
response[:data].first[:embedding]
46+
inference_client.perform!(text)
4847
end
4948

5049
def tokenizer
5150
DiscourseAi::Tokenizer::OpenAiTokenizer
5251
end
52+
53+
def inference_client
54+
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
55+
end
5356
end
5457
end
5558
end

lib/embeddings/vector_representations/text_embedding_ada_002.rb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ def pg_index_type
4343
end
4444

4545
def vector_from(text, asymetric: false)
46-
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
47-
response[:data].first[:embedding]
46+
inference_client.perform!(text)
4847
end
4948

5049
def tokenizer
5150
DiscourseAi::Tokenizer::OpenAiTokenizer
5251
end
52+
53+
def inference_client
54+
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
55+
end
5356
end
5457
end
5558
end

lib/inference/cloudflare_workers_ai.rb

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,38 @@
33
module ::DiscourseAi
44
module Inference
55
class CloudflareWorkersAi
6-
def self.perform!(model, content)
7-
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
6+
def initialize(account_id, api_token, model, referer = Discourse.base_url)
7+
@account_id = account_id
8+
@api_token = api_token
9+
@model = model
10+
@referer = referer
11+
end
12+
13+
def self.instance(model)
14+
new(
15+
SiteSetting.ai_cloudflare_workers_account_id,
16+
SiteSetting.ai_cloudflare_workers_api_token,
17+
model,
18+
)
19+
end
820

9-
account_id = SiteSetting.ai_cloudflare_workers_account_id
10-
token = SiteSetting.ai_cloudflare_workers_api_token
21+
attr_reader :account_id, :api_token, :model, :referer
1122

12-
base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
13-
headers["Authorization"] = "Bearer #{token}"
23+
def perform!(content)
24+
headers = {
25+
"Referer" => Discourse.base_url,
26+
"Content-Type" => "application/json",
27+
"Authorization" => "Bearer #{api_token}",
28+
}
1429

15-
endpoint = "#{base_url}#{model}"
30+
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"
1631

1732
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
1833
response = conn.post(endpoint, content.to_json, headers)
1934

20-
raise Net::HTTPBadResponse if ![200].include?(response.status)
21-
2235
case response.status
2336
when 200
24-
JSON.parse(response.body, symbolize_names: true)
37+
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
2538
when 429
2639
# TODO add a AdminDashboard Problem?
2740
else

0 commit comments

Comments
 (0)