diff --git a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb
index 1a4b8002e..7e4a2ad7a 100644
--- a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb
+++ b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb
@@ -24,12 +24,7 @@ def dependant_setting_names
         end

         def vector_from(text, asymetric: false)
-          DiscourseAi::Inference::DiscourseClassifier.perform!(
-            "#{discourse_embeddings_endpoint}/api/v1/classify",
-            self.class.name,
-            text,
-            SiteSetting.ai_embeddings_discourse_service_api_key,
-          )
+          inference_client.perform!(text)
         end

         def dimensions
@@ -59,6 +54,10 @@ def pg_index_type
         def tokenizer
           DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb
index d5c23d210..be6b46b57 100644
--- a/lib/embeddings/vector_representations/base.rb
+++ b/lib/embeddings/vector_representations/base.rb
@@ -426,16 +426,8 @@ def save_to_db(target, vector, digest)
           end
         end

-        def discourse_embeddings_endpoint
-          if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
-            service =
-              DiscourseAi::Utils::DnsSrv.lookup(
-                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
-              )
-            "https://#{service.target}:#{service.port}"
-          else
-            SiteSetting.ai_embeddings_discourse_service_api_endpoint
-          end
+        def inference_client
+          raise NotImplementedError
         end
       end
     end
diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb
index 601c85a16..923ee19ec 100644
--- a/lib/embeddings/vector_representations/bge_large_en.rb
+++ b/lib/embeddings/vector_representations/bge_large_en.rb
@@ -33,24 +33,12 @@ def dependant_setting_names

         def vector_from(text, asymetric: false)
           text = "#{asymmetric_query_prefix} #{text}" if asymetric
-          if SiteSetting.ai_cloudflare_workers_api_token.present?
-            DiscourseAi::Inference::CloudflareWorkersAi
-              .perform!(inference_model_name, { text: text })
-              .dig(:result, :data)
-              .first
-          elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
-            truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-            DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
-          elsif discourse_embeddings_endpoint.present?
-            DiscourseAi::Inference::DiscourseClassifier.perform!(
-              "#{discourse_embeddings_endpoint}/api/v1/classify",
-              inference_model_name.split("/").last,
-              text,
-              SiteSetting.ai_embeddings_discourse_service_api_key,
-            )
-          else
-            raise "No inference endpoint configured"
-          end
+          client = inference_client
+
+          needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
+          text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
+
+          client.perform!(text)
         end

         def inference_model_name
@@ -88,6 +76,21 @@ def tokenizer
         def asymmetric_query_prefix
           "Represent this sentence for searching relevant passages:"
         end
+
+        def inference_client
+          if SiteSetting.ai_cloudflare_workers_api_token.present?
+            DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
+          elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
+            DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
+          elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
+              SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
+            DiscourseAi::Inference::DiscourseClassifier.instance(
+              inference_model_name.split("/").last,
+            )
+          else
+            raise "No inference endpoint configured"
+          end
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/bge_m3.rb b/lib/embeddings/vector_representations/bge_m3.rb
index c220cf750..d7e963fc5 100644
--- a/lib/embeddings/vector_representations/bge_m3.rb
+++ b/lib/embeddings/vector_representations/bge_m3.rb
@@ -20,7 +20,7 @@ def dependant_setting_names

         def vector_from(text, asymetric: false)
           truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-          DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
+          inference_client.perform!(truncated_text)
         end

         def dimensions
@@ -50,6 +50,10 @@ def pg_index_type
         def tokenizer
           DiscourseAi::Tokenizer::BgeM3Tokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/gemini.rb b/lib/embeddings/vector_representations/gemini.rb
index 86b7afaee..a693849d3 100644
--- a/lib/embeddings/vector_representations/gemini.rb
+++ b/lib/embeddings/vector_representations/gemini.rb
@@ -43,8 +43,7 @@ def pg_index_type
         end

         def vector_from(text, asymetric: false)
-          response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
-          response[:embedding][:values]
+          inference_client.perform!(text).dig(:embedding, :values)
         end

         # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
@@ -53,6 +52,10 @@ def vector_from(text, asymetric: false)
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::GeminiEmbeddings.instance
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb
index 8267f9387..c7ef3c0fe 100644
--- a/lib/embeddings/vector_representations/multilingual_e5_large.rb
+++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb
@@ -29,19 +29,16 @@ def dependant_setting_names
         end

         def vector_from(text, asymetric: false)
-          if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
-            truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-            DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
-          elsif discourse_embeddings_endpoint.present?
-            DiscourseAi::Inference::DiscourseClassifier.perform!(
-              "#{discourse_embeddings_endpoint}/api/v1/classify",
-              self.class.name,
-              "query: #{text}",
-              SiteSetting.ai_embeddings_discourse_service_api_key,
-            )
+          client = inference_client
+
+          needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
+          if needs_truncation
+            text = tokenizer.truncate(text, max_sequence_length - 2)
           else
-            raise "No inference endpoint configured"
+            text = "query: #{text}"
           end
+
+          client.perform!(text)
         end

         def id
@@ -71,6 +68,17 @@ def pg_index_type
         def tokenizer
           DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
         end
+
+        def inference_client
+          if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
+            DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance
+          elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
+              SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
+            DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
+          else
+            raise "No inference endpoint configured"
+          end
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/text_embedding_3_large.rb b/lib/embeddings/vector_representations/text_embedding_3_large.rb
index 626428c95..202d66de6 100644
--- a/lib/embeddings/vector_representations/text_embedding_3_large.rb
+++ b/lib/embeddings/vector_representations/text_embedding_3_large.rb
@@ -45,18 +45,19 @@ def pg_index_type
         end

         def vector_from(text, asymetric: false)
-          response =
-            DiscourseAi::Inference::OpenAiEmbeddings.perform!(
-              text,
-              model: self.class.name,
-              dimensions: dimensions,
-            )
-          response[:data].first[:embedding]
+          inference_client.perform!(text)
         end

         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::OpenAiEmbeddings.instance(
+            model: self.class.name,
+            dimensions: dimensions,
+          )
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/text_embedding_3_small.rb b/lib/embeddings/vector_representations/text_embedding_3_small.rb
index fbac4bc78..87f311859 100644
--- a/lib/embeddings/vector_representations/text_embedding_3_small.rb
+++ b/lib/embeddings/vector_representations/text_embedding_3_small.rb
@@ -43,13 +43,16 @@ def pg_index_type
         end

         def vector_from(text, asymetric: false)
-          response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
-          response[:data].first[:embedding]
+          inference_client.perform!(text)
         end

         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/embeddings/vector_representations/text_embedding_ada_002.rb
index 2079e028c..1e570b983 100644
--- a/lib/embeddings/vector_representations/text_embedding_ada_002.rb
+++ b/lib/embeddings/vector_representations/text_embedding_ada_002.rb
@@ -43,13 +43,16 @@ def pg_index_type
         end

         def vector_from(text, asymetric: false)
-          response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
-          response[:data].first[:embedding]
+          inference_client.perform!(text)
         end

         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
+        end
       end
     end
   end
diff --git a/lib/inference/cloudflare_workers_ai.rb b/lib/inference/cloudflare_workers_ai.rb
index 099ae5be9..b0cd59266 100644
--- a/lib/inference/cloudflare_workers_ai.rb
+++ b/lib/inference/cloudflare_workers_ai.rb
@@ -3,25 +3,38 @@
 module ::DiscourseAi
   module Inference
     class CloudflareWorkersAi
-      def self.perform!(model, content)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+      def initialize(account_id, api_token, model, referer = Discourse.base_url)
+        @account_id = account_id
+        @api_token = api_token
+        @model = model
+        @referer = referer
+      end
+
+      def self.instance(model)
+        new(
+          SiteSetting.ai_cloudflare_workers_account_id,
+          SiteSetting.ai_cloudflare_workers_api_token,
+          model,
+        )
+      end

-        account_id = SiteSetting.ai_cloudflare_workers_account_id
-        token = SiteSetting.ai_cloudflare_workers_api_token
+      attr_reader :account_id, :api_token, :model, :referer

-        base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
-        headers["Authorization"] = "Bearer #{token}"
+      def perform!(content)
+        headers = {
+          "Referer" => referer,
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer #{api_token}",
+        }

-        endpoint = "#{base_url}#{model}"
+        endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"

         conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
         response = conn.post(endpoint, content.to_json, headers)

-        raise Net::HTTPBadResponse if ![200].include?(response.status)
-
         case response.status
         when 200
-          JSON.parse(response.body, symbolize_names: true)
+          JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
         when 429
           # TODO add a AdminDashboard Problem?
         else
diff --git a/lib/inference/discourse_classifier.rb b/lib/inference/discourse_classifier.rb
index 3784a190c..46f912ddc 100644
--- a/lib/inference/discourse_classifier.rb
+++ b/lib/inference/discourse_classifier.rb
@@ -3,9 +3,36 @@
 module ::DiscourseAi
   module Inference
     class DiscourseClassifier
-      def self.perform!(endpoint, model, content, api_key)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+      def initialize(endpoint, api_key, model, referer = Discourse.base_url)
+        @endpoint = endpoint
+        @api_key = api_key
+        @model = model
+        @referer = referer
+      end
+
+      def self.instance(model)
+        endpoint =
+          if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
+            service =
+              DiscourseAi::Utils::DnsSrv.lookup(
+                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
+              )
+            "https://#{service.target}:#{service.port}"
+          else
+            SiteSetting.ai_embeddings_discourse_service_api_endpoint
+          end
+
+        new(
+          "#{endpoint}/api/v1/classify",
+          SiteSetting.ai_embeddings_discourse_service_api_key,
+          model,
+        )
+      end
+
+      attr_reader :endpoint, :api_key, :model, :referer
+
+      def perform!(content)
+        headers = { "Referer" => referer, "Content-Type" => "application/json" }
         headers["X-API-KEY"] = api_key if api_key.present?

         conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
diff --git a/lib/inference/gemini_embeddings.rb b/lib/inference/gemini_embeddings.rb
index cedda24cc..13fb62c5a 100644
--- a/lib/inference/gemini_embeddings.rb
+++ b/lib/inference/gemini_embeddings.rb
@@ -3,12 +3,17 @@
 module ::DiscourseAi
   module Inference
     class GeminiEmbeddings
-      def self.perform!(content)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+      def initialize(api_key, referer = Discourse.base_url)
+        @api_key = api_key
+        @referer = referer
+      end

-        url =
-          "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{SiteSetting.ai_gemini_api_key}"
+      attr_reader :api_key, :referer

+      def perform!(content)
+        headers = { "Referer" => referer, "Content-Type" => "application/json" }
+        url =
+          "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
         body = { content: { parts: [{ text: content }] } }

         conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb
index 0e904a946..743a2b570 100644
--- a/lib/inference/hugging_face_text_embeddings.rb
+++ b/lib/inference/hugging_face_text_embeddings.rb
@@ -3,30 +3,36 @@
 module ::DiscourseAi
   module Inference
     class HuggingFaceTextEmbeddings
-      class << self
-        def perform!(content)
-          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
-          body = { inputs: content, truncate: true }.to_json
+      def initialize(endpoint, key, referer = Discourse.base_url)
+        @endpoint = endpoint
+        @key = key
+        @referer = referer
+      end

-          if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
-            service =
-              DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
-            api_endpoint = "https://#{service.target}:#{service.port}"
-          else
-            api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
-          end
+      attr_reader :endpoint, :key, :referer

-          if SiteSetting.ai_hugging_face_tei_api_key.present?
-            headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
-            headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_tei_api_key}"
-          end
-
-          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
-          response = conn.post(api_endpoint, body, headers)
+      class << self
+        def instance
+          endpoint =
+            if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+              service =
+                DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
+              "https://#{service.target}:#{service.port}"
+            else
+              SiteSetting.ai_hugging_face_tei_endpoint
+            end
+
+          new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
+        end

-          raise Net::HTTPBadResponse if ![200].include?(response.status)
+        def configured?
+          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
+            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+        end

-          JSON.parse(response.body, symbolize_names: true)
+        def reranker_configured?
+          SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
+            SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
         end

         def rerank(content, candidates)
@@ -80,16 +86,23 @@ def classify(content, model_config)
           JSON.parse(response.body, symbolize_names: true)
         end
+      end

-        def reranker_configured?
-          SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
-            SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
-        end
+      def perform!(content)
+        headers = { "Referer" => referer, "Content-Type" => "application/json" }
+        body = { inputs: content, truncate: true }.to_json

-        def configured?
-          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
-            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+        if key.present?
+          headers["X-API-KEY"] = key
+          headers["Authorization"] = "Bearer #{key}"
         end
+
+        conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
+        response = conn.post(endpoint, body, headers)
+
+        raise Net::HTTPBadResponse if ![200].include?(response.status)
+
+        JSON.parse(response.body, symbolize_names: true).first
       end
     end
   end
diff --git a/lib/inference/open_ai_embeddings.rb b/lib/inference/open_ai_embeddings.rb
index 9ffcaa49a..e3e6551c8 100644
--- a/lib/inference/open_ai_embeddings.rb
+++ b/lib/inference/open_ai_embeddings.rb
@@ -3,13 +3,26 @@
 module ::DiscourseAi
   module Inference
     class OpenAiEmbeddings
-      def self.perform!(content, model:, dimensions: nil)
+      def initialize(endpoint, api_key, model, dimensions)
+        @endpoint = endpoint
+        @api_key = api_key
+        @model = model
+        @dimensions = dimensions
+      end
+
+      attr_reader :endpoint, :api_key, :model, :dimensions
+
+      def self.instance(model:, dimensions: nil)
+        new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
+      end
+
+      def perform!(content)
         headers = { "Content-Type" => "application/json" }

-        if SiteSetting.ai_openai_embeddings_url.include?("azure")
-          headers["api-key"] = SiteSetting.ai_openai_api_key
+        if endpoint.include?("azure")
+          headers["api-key"] = api_key
         else
-          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+          headers["Authorization"] = "Bearer #{api_key}"
         end

         payload = { model: model, input: content }
@@ -20,7 +33,7 @@

         case response.status
         when 200
-          JSON.parse(response.body, symbolize_names: true)
+          JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
         when 429
           # TODO add a AdminDashboard Problem?
         else
diff --git a/lib/nsfw/classification.rb b/lib/nsfw/classification.rb
index c87ba8d17..a6f99439a 100644
--- a/lib/nsfw/classification.rb
+++ b/lib/nsfw/classification.rb
@@ -54,12 +54,11 @@ def evaluate_with_model(model, upload)
         upload_url = Discourse.store.cdn_url(upload.url)
         upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")

-        DiscourseAi::Inference::DiscourseClassifier.perform!(
+        DiscourseAi::Inference::DiscourseClassifier.new(
           "#{endpoint}/api/v1/classify",
-          model,
-          upload_url,
           SiteSetting.ai_nsfw_inference_service_api_key,
-        )
+          model,
+        ).perform!(upload_url)
       end

       def available_models
diff --git a/lib/toxicity/toxicity_classification.rb b/lib/toxicity/toxicity_classification.rb
index c178d2e18..1756d3b29 100644
--- a/lib/toxicity/toxicity_classification.rb
+++ b/lib/toxicity/toxicity_classification.rb
@@ -42,12 +42,11 @@ def should_flag_based_on?(verdicts)

       def request(target_to_classify)
         data =
-          ::DiscourseAi::Inference::DiscourseClassifier.perform!(
+          ::DiscourseAi::Inference::DiscourseClassifier.new(
             "#{endpoint}/api/v1/classify",
-            SiteSetting.ai_toxicity_inference_service_api_model,
-            content_of(target_to_classify),
             SiteSetting.ai_toxicity_inference_service_api_key,
-          )
+            SiteSetting.ai_toxicity_inference_service_api_model,
+          ).perform!(content_of(target_to_classify))

         { available_model => data }
       end
diff --git a/spec/shared/inference/openai_embeddings_spec.rb b/spec/shared/inference/openai_embeddings_spec.rb
index e938b6a44..7db19a7ec 100644
--- a/spec/shared/inference/openai_embeddings_spec.rb
+++ b/spec/shared/inference/openai_embeddings_spec.rb
@@ -26,10 +26,11 @@
     ).to_return(status: 200, body: body_json, headers: {})

     result =
-      DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
+      DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
+        "hello",
+      )

-    expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
-    expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
+    expect(result).to eq([0.0, 0.1])
   end

   it "supports openai embeddings" do
@@ -54,13 +55,11 @@
     ).to_return(status: 200, body: body_json, headers: {})

     result =
-      DiscourseAi::Inference::OpenAiEmbeddings.perform!(
-        "hello",
+      DiscourseAi::Inference::OpenAiEmbeddings.instance(
         model: "text-embedding-ada-002",
         dimensions: 1000,
-      )
+      ).perform!("hello")

-    expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
-    expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
+    expect(result).to eq([0.0, 0.1])
   end
 end
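
Net effect of the change above: each inference backend becomes a small client object. Endpoint, key, and model resolution (including the DNS SRV lookups) happens once in `.instance`, while `perform!` only makes the HTTP call and unwraps the response into a bare embedding vector. A minimal usage sketch, assuming the site settings referenced in the diff are configured; the sample input and variable names are illustrative only:

```ruby
# Build a configured client once; .instance reads the relevant SiteSetting
# values and returns a ready-to-use instance.
client = DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002")

# perform! now returns the embedding itself (response unwrapping moved into
# each client), so callers no longer dig through the raw JSON payload.
vector = client.perform!("hello")
# => [0.0, 0.1, ...]
```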