Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions app/jobs/regular/post_sentiment_analysis.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ def execute(args)
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
return if post&.raw.blank?

DiscourseAi::PostClassificator.new(
DiscourseAi::Sentiment::SentimentClassification.new,
).classify!(post)
DiscourseAi::Sentiment::PostClassification.new.classify!(post)
end
end
end
4 changes: 2 additions & 2 deletions lib/inference/hugging_face_text_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def rerank(content, candidates)
JSON.parse(response.body, symbolize_names: true)
end

def classify(content, model_config)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
def classify(content, model_config, base_url = Discourse.base_url)
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = model_config.api_key
headers["Authorization"] = "Bearer #{model_config.api_key}"

Expand Down
111 changes: 111 additions & 0 deletions lib/sentiment/post_classification.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# frozen_string_literal: true

module DiscourseAi
  module Sentiment
    # Runs every configured sentiment model against posts and persists one
    # ClassificationResult row per (target, model) pair.
    class PostClassification
      # Classifies a collection of posts, issuing the inference requests
      # concurrently on a dedicated thread pool.
      #
      # Posts whose prepared text is blank are skipped. Results are written
      # only after every request has completed; if any request fails, the
      # error is re-raised by +value!+ and nothing is stored.
      #
      # @param relation [Enumerable] posts to classify (an ActiveRecord
      #   relation or a plain array such as a +find_in_batches+ batch).
      # @return [void] the return value is not meaningful.
      def bulk_classify!(relation)
        available_classifiers = classifiers
        # No models configured: nothing to request, so don't spin up a pool.
        return if available_classifiers.blank?

        http_pool_size = 100
        pool =
          Concurrent::CachedThreadPool.new(
            min_threads: 0,
            max_threads: http_pool_size,
            idletime: 30,
          )

        begin
          # Read once on the calling thread and handed to the workers
          # explicitly instead of being looked up inside pool threads.
          base_url = Discourse.base_url

          promised_classifications =
            relation.filter_map do |record|
              text = prepare_text(record)
              next if text.blank?

              Concurrent::Promises
                .fulfilled_future({ target: record, text: text }, pool)
                .then_on(pool) do |w_text|
                  # Written to from several pool threads at once.
                  results = Concurrent::Hash.new

                  # One request per configured model for this post.
                  promised_target_results =
                    available_classifiers.map do |c|
                      Concurrent::Promises.future_on(pool) do
                        results[c.model_name] = request_with(w_text[:text], c, base_url)
                      end
                    end

                  Concurrent::Promises
                    .zip(*promised_target_results)
                    .then_on(pool) { |_| w_text.merge(classification: results) }
                end
                .flat(1) # unwrap the future returned from the block above
            end

          Concurrent::Promises
            .zip(*promised_classifications)
            .value!
            .each { |r| store_classification(r[:target], r[:classification]) }
        ensure
          # Always release the pool threads, even when a request raised.
          pool.shutdown
          pool.wait_for_termination
        end
      end

      # Classifies a single post synchronously, one model after another.
      #
      # @param target [Post, nil] the post to classify.
      # @return [void] returns early (nil) when the target, its prepared
      #   text, or the list of configured models is blank.
      def classify!(target)
        return if target.blank?

        to_classify = prepare_text(target)
        return if to_classify.blank?

        available_classifiers = classifiers
        return if available_classifiers.blank?

        results =
          available_classifiers.each_with_object({}) do |model, memo|
            memo[model.model_name] = request_with(to_classify, model)
          end

        store_classification(target, results)
      end

      private

      # Builds the text sent to the models: the first post of a topic is
      # prefixed with the topic title, every other post uses its raw content
      # only. The result is passed through the BERT tokenizer's +truncate+
      # with a limit of 512.
      def prepare_text(target)
        content =
          if target.post_number == 1
            "#{target.topic.title}\n#{target.raw}"
          else
            target.raw
          end

        Tokenizer::BertTokenizer.truncate(content, 512)
      end

      # Model configurations to run, as exposed by the sentiment site setting.
      def classifiers
        DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
      end

      # Sends +content+ to the endpoint described by +config+.
      # +base_url+ is forwarded to the inference client, which uses it as the
      # request's Referer.
      def request_with(content, config, base_url = Discourse.base_url)
        DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
      end

      # Upserts one row per model for +target+. Existing rows (matched on
      # target_id, target_type and model_used) only get their classification
      # replaced.
      #
      # @param target [#id] the classified record.
      # @param classification [Hash] model name => raw classification result.
      def store_classification(target, classification)
        # Never hand an empty attribute list to upsert_all.
        return if classification.blank?

        # Single timestamp so created_at and updated_at agree on insert.
        now = Time.zone.now

        attrs =
          classification.map do |model_name, classifications|
            {
              model_used: model_name,
              target_id: target.id,
              target_type: target.class.sti_name,
              classification_type: :sentiment,
              classification: classifications,
              updated_at: now,
              created_at: now,
            }
          end

        ClassificationResult.upsert_all(
          attrs,
          unique_by: %i[target_id target_type model_used],
          update_only: %i[classification],
        )
      end
    end
  end
end
7 changes: 2 additions & 5 deletions lib/tasks/modules/sentiment/backfill.rake
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@ task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args|
.where("category_id IN (?)", public_categories)
.where(posts: { deleted_at: nil })
.where(topics: { deleted_at: nil })
.order("posts.id ASC")
.find_each do |post|
.find_in_batches do |batch|
print "."
DiscourseAi::PostClassificator.new(
DiscourseAi::Sentiment::SentimentClassification.new,
).classify!(post)
DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch)
end
end
47 changes: 47 additions & 0 deletions spec/lib/modules/sentiment/post_classification_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# frozen_string_literal: true

require_relative "../../../support/sentiment_inference_stubs"

# Exercises single and bulk sentiment classification against stubbed
# inference endpoints.
RSpec.describe DiscourseAi::Sentiment::PostClassification do
  fab!(:post_1) { Fabricate(:post, post_number: 2) }

  before do
    SiteSetting.ai_sentiment_enabled = true
    # Three model configs; each classified post should end up with one
    # ClassificationResult per entry.
    SiteSetting.ai_sentiment_model_configs =
      "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
  end

  describe "#classify!" do
    it "does nothing if the post content is blank" do
      post_1.update_columns(raw: "")

      subject.classify!(post_1)

      expect(ClassificationResult.where(target: post_1).count).to be_zero
    end

    it "successfully classifies the post" do
      expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
      SentimentInferenceStubs.stub_classification(post_1)

      subject.classify!(post_1)

      expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
    end
  end

  # Label matches the method actually under test (was "#classify_bulk!").
  describe "#bulk_classify!" do
    fab!(:post_2) { Fabricate(:post, post_number: 2) }

    it "classifies all given posts" do
      expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
      SentimentInferenceStubs.stub_classification(post_1)
      SentimentInferenceStubs.stub_classification(post_2)

      subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))

      expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
      expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
    end
  end
end