Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 12 additions & 26 deletions lib/inference/hugging_face_text_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ def initialize(endpoint, key, referer = Discourse.base_url)
attr_reader :endpoint, :key, :referer

class << self
# True when a Hugging Face TEI (text-embeddings-inference) endpoint is
# configured, either as a direct URL site setting or via a DNS SRV record.
# NOTE(review): this span sits on the deleted side of the diff — the method
# was removed from lib/inference/hugging_face_text_embeddings.rb in this PR.
def configured?
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
end

def reranker_configured?
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
Expand Down Expand Up @@ -50,32 +45,23 @@ def rerank(content, candidates)

JSON.parse(response.body, symbolize_names: true)
end
end

def classify(content, model_config, base_url = Discourse.base_url)
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = model_config.api_key
headers["Authorization"] = "Bearer #{model_config.api_key}"

body = { inputs: content, truncate: true }.to_json

api_endpoint = model_config.endpoint
if api_endpoint.present? && api_endpoint.start_with?("srv://")
service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
api_endpoint = "https://#{service.target}:#{service.port}"
end
# Classifies +content+ by sentiment: issues the HTTP request through
# do_request! and returns the parsed JSON body with symbolized keys.
def classify_by_sentiment!(content)
response = do_request!(content)

# NOTE(review): the next two lines reference api_endpoint/body/headers,
# which are not defined in this method — they appear to be deleted-side
# diff lines interleaved by the scraped diff view; confirm against the
# merged file before treating them as part of this method.
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
response = conn.post(api_endpoint, body, headers)
JSON.parse(response.body, symbolize_names: true)
end

if response.status != 200
raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}")
end
# Sends +content+ to the configured endpoint via do_request! and returns
# the parsed JSON response (symbolized keys).
# NOTE(review): the adjacent diff lines show a ".first" variant of this
# return value — confirm against the merged file which revision is final.
def perform!(content)
response = do_request!(content)

JSON.parse(response.body, symbolize_names: true)
end
JSON.parse(response.body, symbolize_names: true).first
end

def perform!(content)
private

def do_request!(content)
headers = { "Referer" => referer, "Content-Type" => "application/json" }
body = { inputs: content, truncate: true }.to_json

Expand All @@ -89,7 +75,7 @@ def perform!(content)

raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)

JSON.parse(response.body, symbolize_names: true).first
response
end
end
end
Expand Down
39 changes: 27 additions & 12 deletions lib/sentiment/post_classification.rb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def bulk_classify!(relation)

available_classifiers = classifiers
return if available_classifiers.blank?
base_url = Discourse.base_url

promised_classifications =
relation
Expand All @@ -70,12 +69,14 @@ def bulk_classify!(relation)
already_classified = w_text[:target].sentiment_classifications.map(&:model_used)

classifiers_for_target =
available_classifiers.reject { |ac| already_classified.include?(ac.model_name) }
available_classifiers.reject do |ac|
already_classified.include?(ac[:model_name])
end

promised_target_results =
classifiers_for_target.map do |c|
classifiers_for_target.map do |cft|
Concurrent::Promises.future_on(pool) do
results[c.model_name] = request_with(w_text[:text], c, base_url)
results[cft[:model_name]] = request_with(cft[:client], w_text[:text])
end
end

Expand All @@ -98,26 +99,40 @@ def bulk_classify!(relation)

def classify!(target)
return if target.blank?
return if classifiers.blank?
available_classifiers = classifiers
return if available_classifiers.blank?

to_classify = prepare_text(target)
return if to_classify.blank?

already_classified = target.sentiment_classifications.map(&:model_used)
classifiers_for_target =
classifiers.reject { |ac| already_classified.include?(ac.model_name) }
available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) }

results =
classifiers_for_target.reduce({}) do |memo, model|
memo[model.model_name] = request_with(to_classify, model)
classifiers_for_target.reduce({}) do |memo, cft|
memo[cft[:model_name]] = request_with(cft[:client], to_classify)
memo
end

store_classification(target, results)
end

# Builds one classifier descriptor per configured sentiment model: a hash
# with :model_name and a ready-to-use :client
# (DiscourseAi::Inference::HuggingFaceTextEmbeddings). srv:// endpoints are
# resolved to a concrete https host:port via a DNS SRV lookup first.
def classifiers
# NOTE(review): the next line is the old one-line body left over from the
# deleted side of the diff (a dead expression here) — it should not be in
# the merged file.
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map do |config|
api_endpoint = config.endpoint

# Resolve srv:// pseudo-URLs to the actual service target and port.
if api_endpoint.present? && api_endpoint.start_with?("srv://")
service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
api_endpoint = "https://#{service.target}:#{service.port}"
end

{
model_name: config.model_name,
client:
DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(api_endpoint, config.api_key),
}
end
end

def has_classifiers?
Expand All @@ -137,9 +152,9 @@ def prepare_text(target)
Tokenizer::BertTokenizer.truncate(content, 512)
end

def request_with(content, config, base_url = Discourse.base_url)
result =
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
# Runs sentiment classification for +content+ against a prepared
# HuggingFaceTextEmbeddings +client+ (as built by #classifiers) and
# normalizes the raw API response via transform_result.
def request_with(client, content)
result = client.classify_by_sentiment!(content)

transform_result(result)
end

Expand Down