From d06f0065d28fa9a4d02beead4306d361f3033469 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Thu, 6 Feb 2025 11:28:34 -0300 Subject: [PATCH] DEV: Build sentiment clients outside of promises --- lib/inference/hugging_face_text_embeddings.rb | 38 ++++++------------ lib/sentiment/post_classification.rb | 39 +++++++++++++------ 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb index 954e2a30f..67a964f88 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text_embeddings.rb @@ -12,11 +12,6 @@ def initialize(endpoint, key, referer = Discourse.base_url) attr_reader :endpoint, :key, :referer class << self - def configured? - SiteSetting.ai_hugging_face_tei_endpoint.present? || - SiteSetting.ai_hugging_face_tei_endpoint_srv.present? - end - def reranker_configured? SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? || SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present? @@ -50,32 +45,23 @@ def rerank(content, candidates) JSON.parse(response.body, symbolize_names: true) end + end - def classify(content, model_config, base_url = Discourse.base_url) - headers = { "Referer" => base_url, "Content-Type" => "application/json" } - headers["X-API-KEY"] = model_config.api_key - headers["Authorization"] = "Bearer #{model_config.api_key}" - - body = { inputs: content, truncate: true }.to_json - - api_endpoint = model_config.endpoint - if api_endpoint.present? && api_endpoint.start_with?("srv://") - service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://")) - api_endpoint = "https://#{service.target}:#{service.port}" - end + def classify_by_sentiment!(content) + response = do_request!(content) - conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } - response = conn.post(api_endpoint, body, headers) + JSON.parse(response.body, symbolize_names: true) + end - if response.status != 200 - raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}") - end + def perform!(content) + response = do_request!(content) - JSON.parse(response.body, symbolize_names: true) - end + JSON.parse(response.body, symbolize_names: true).first end - def perform!(content) + private + + def do_request!(content) headers = { "Referer" => referer, "Content-Type" => "application/json" } body = { inputs: content, truncate: true }.to_json @@ -89,7 +75,7 @@ def perform!(content) raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status) - JSON.parse(response.body, symbolize_names: true).first + response end end end diff --git a/lib/sentiment/post_classification.rb b/lib/sentiment/post_classification.rb index c2dfae443..3e0bab880 100644 --- a/lib/sentiment/post_classification.rb +++ b/lib/sentiment/post_classification.rb @@ -55,7 +55,6 @@ def bulk_classify!(relation) available_classifiers = classifiers return if available_classifiers.blank? - base_url = Discourse.base_url promised_classifications = relation @@ -70,12 +69,14 @@ def bulk_classify!(relation) already_classified = w_text[:target].sentiment_classifications.map(&:model_used) classifiers_for_target = - available_classifiers.reject { |ac| already_classified.include?(ac.model_name) } + available_classifiers.reject do |ac| + already_classified.include?(ac[:model_name]) + end promised_target_results = - classifiers_for_target.map do |c| + classifiers_for_target.map do |cft| Concurrent::Promises.future_on(pool) do - results[c.model_name] = request_with(w_text[:text], c, base_url) + results[cft[:model_name]] = request_with(cft[:client], w_text[:text]) end end @@ -98,18 +99,19 @@ def bulk_classify!(relation) def classify!(target) return if target.blank? - return if classifiers.blank? + available_classifiers = classifiers + return if available_classifiers.blank? to_classify = prepare_text(target) return if to_classify.blank? already_classified = target.sentiment_classifications.map(&:model_used) classifiers_for_target = - classifiers.reject { |ac| already_classified.include?(ac.model_name) } + available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) } results = - classifiers_for_target.reduce({}) do |memo, model| - memo[model.model_name] = request_with(to_classify, model) + classifiers_for_target.reduce({}) do |memo, cft| + memo[cft[:model_name]] = request_with(cft[:client], to_classify) memo end @@ -117,7 +119,20 @@ def classify!(target) end def classifiers - DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values + DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map do |config| + api_endpoint = config.endpoint + + if api_endpoint.present? && api_endpoint.start_with?("srv://") + service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://")) + api_endpoint = "https://#{service.target}:#{service.port}" + end + + { + model_name: config.model_name, + client: + DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(api_endpoint, config.api_key), + } + end end def has_classifiers? @@ -137,9 +152,9 @@ def prepare_text(target) Tokenizer::BertTokenizer.truncate(content, 512) end - def request_with(content, config, base_url = Discourse.base_url) - result = - DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url) + def request_with(client, content) + result = client.classify_by_sentiment!(content) + transform_result(result) end