Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 0abd4b1

Browse files
authored
FIX: Sentiment classification results need to be transformed before saving (#983)
1 parent 120a20c commit 0abd4b1

File tree

3 files changed

+73
-1
lines changed

3 files changed

+73
-1
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# frozen_string_literal: true
2+
3+
# Data migration for sentiment rows in `classification_results`.
#
# Rows whose `classification` was saved as a JSON array of
# { "label" => ..., "score" => ... } entries are rewritten into a single
# { label => score } JSON object.
class FixClassificationData < ActiveRecord::Migration[7.2]
  # Rewrites every array-shaped sentiment classification as a hash.
  def up
    # A row is selected when its serialized classification starts with "[",
    # i.e. it is a JSON array rather than a JSON object.
    classifications = DB.query(<<~SQL)
      SELECT id, classification
      FROM classification_results
      WHERE classification_type = 'sentiment'
        AND SUBSTRING(LTRIM(classification::text), 1, 1) = '['
    SQL

    # NOTE(review): assumes DB.query returns `classification` already decoded
    # into an Array of Hashes with string keys — confirm against the DB helper.
    transformed =
      classifications.map do |row|
        fixed = row.classification.to_h { |entry| [entry["label"], entry["score"]] }
        { id: row.id, fixed_classification: fixed }
      end

    # Nothing is stored in the array format: skip the UPDATE round-trip.
    return if transformed.empty?

    # All fixes travel as one JSON array bound to :values; the subquery
    # unpacks it into (id, fixed_classification) pairs joined on id.
    DB.exec(<<~SQL, values: transformed.to_json)
      UPDATE classification_results
      SET classification = N.fixed_classification
      FROM (
        SELECT (value::jsonb->'id')::integer AS id, (value::jsonb->'fixed_classification')::jsonb AS fixed_classification
        FROM jsonb_array_elements(:values::jsonb)
      ) N
      WHERE classification_results.id = N.id
        AND classification_type = 'sentiment'
    SQL
  end

  # The array form is not restored; rolling back is refused.
  #
  # @raise [ActiveRecord::IrreversibleMigration] always
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end

lib/sentiment/post_classification.rb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,15 @@ def classifiers
8383
end
8484

8585
# Sends `content` to DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify
# and returns the response after it has been passed through `transform_result`.
#
# @param content  value to classify, forwarded unchanged
# @param config   classifier configuration, forwarded unchanged
# @param base_url forwarded unchanged; defaults to Discourse.base_url
# @return the value produced by `transform_result` for the raw response
def request_with(content, config, base_url = Discourse.base_url)
  transform_result(
    DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url),
  )
end
90+
91+
# Collapses a list of classifier entries into a single { label => score } hash.
#
# @param result [Enumerable<Hash>] entries read through their :label and
#   :score symbol keys
# @return [Hash] each entry's :label mapped to its :score; when a label
#   appears more than once, the last entry wins
def transform_result(result)
  result.to_h { |entry| [entry[:label], entry[:score]] }
end
8896

8997
def store_classification(target, classification)

spec/lib/modules/sentiment/post_classification_spec.rb

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
1212
end
1313

14+
# Spec helper: loads the ClassificationResult recorded for `post` by the
# "cardiffnlp/twitter-roberta-base-sentiment-latest" model and expects its
# classification to be keyed by exactly negative, neutral and positive.
def check_classification_for(post)
  sentiment_model = "cardiffnlp/twitter-roberta-base-sentiment-latest"
  stored = ClassificationResult.find_by(model_used: sentiment_model, target: post)

  expect(stored.classification.keys).to contain_exactly("negative", "neutral", "positive")
end
23+
1424
describe "#classify!" do
1525
it "does nothing if the post content is blank" do
1626
post_1.update_columns(raw: "")
@@ -28,6 +38,13 @@
2838

2939
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
3040
end
41+
42+
# After #classify! runs against a stubbed classifier, the stored result for
# the post must pass check_classification_for.
it "classification results must be { emotion => score }" do
  SentimentInferenceStubs.stub_classification(post_1)
  subject.classify!(post_1)

  check_classification_for(post_1)
end
3148
end
3249

3350
describe "#classify_bulk!" do
@@ -43,5 +60,15 @@
4360
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
4461
expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
4562
end
63+
64+
# After #bulk_classify! runs over both posts with stubbed classifiers, the
# stored result for each post must pass check_classification_for.
it "classification results must be { emotion => score }" do
  posts = [post_1, post_2]
  posts.each { |post| SentimentInferenceStubs.stub_classification(post) }

  subject.bulk_classify!(Post.where(id: posts.map(&:id)))

  posts.each { |post| check_classification_for(post) }
end
4673
end
4774
end

0 commit comments

Comments
 (0)