Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions db/migrate/20241129190708_fix_classification_data.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# frozen_string_literal: true

class FixClassificationData < ActiveRecord::Migration[7.2]
def up
classifications = DB.query(<<~SQL)
SELECT id, classification
FROM classification_results
WHERE classification_type = 'sentiment'
AND SUBSTRING(LTRIM(classification::text), 1, 1) = '['
SQL

transformed =
classifications.reduce([]) do |memo, c|
hash_result = {}
c.classification.each { |r| hash_result[r["label"]] = r["score"] }

memo << { id: c.id, fixed_classification: hash_result }
end

transformed_json = transformed.to_json

DB.exec(<<~SQL, values: transformed_json)
UPDATE classification_results
SET classification = N.fixed_classification
FROM (
SELECT (value::jsonb->'id')::integer AS id, (value::jsonb->'fixed_classification')::jsonb AS fixed_classification
FROM jsonb_array_elements(:values::jsonb)
) N
WHERE classification_results.id = N.id
AND classification_type = 'sentiment'
SQL
end

def down
raise ActiveRecord::IrreversibleMigration
end
end
10 changes: 9 additions & 1 deletion lib/sentiment/post_classification.rb
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,15 @@ def classifiers
end

def request_with(content, config, base_url = Discourse.base_url)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
result =
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
transform_result(result)
end

def transform_result(result)
hash_result = {}
result.each { |r| hash_result[r[:label]] = r[:score] }
hash_result
end

def store_classification(target, classification)
Expand Down
27 changes: 27 additions & 0 deletions spec/lib/modules/sentiment/post_classification_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
end

def check_classification_for(post)
result =
ClassificationResult.find_by(
model_used: "cardiffnlp/twitter-roberta-base-sentiment-latest",
target: post,
)

expect(result.classification.keys).to contain_exactly("negative", "neutral", "positive")
end

describe "#classify!" do
it "does nothing if the post content is blank" do
post_1.update_columns(raw: "")
Expand All @@ -28,6 +38,13 @@

expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
end

it "classification results must be { emotion => score }" do
SentimentInferenceStubs.stub_classification(post_1)

subject.classify!(post_1)
check_classification_for(post_1)
end
end

describe "#classify_bulk!" do
Expand All @@ -43,5 +60,15 @@
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
end

it "classification results must be { emotion => score }" do
SentimentInferenceStubs.stub_classification(post_1)
SentimentInferenceStubs.stub_classification(post_2)

subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))

check_classification_for(post_1)
check_classification_for(post_2)
end
end
end