Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions lib/ai_helper/semantic_categorizer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def categories
return [] unless SiteSetting.ai_embeddings_enabled

candidates = nearest_neighbors(limit: 100)
return [] if candidates.empty?

candidate_ids = candidates.map(&:first)

::Topic
Expand Down Expand Up @@ -52,6 +54,8 @@ def tags
return [] unless SiteSetting.ai_embeddings_enabled

candidates = nearest_neighbors(limit: 100)
return [] if candidates.empty?

candidate_ids = candidates.map(&:first)

count_column = Tag.topic_count_column(@user.guardian) # Determine the count column
Expand Down Expand Up @@ -94,11 +98,21 @@ def nearest_neighbors(limit: 100)

raw_vector = vector_rep.vector_from(@text)

muted_category_ids = nil
if @user.present?
muted_category_ids =
CategoryUser.where(
user: @user,
notification_level: CategoryUser.notification_levels[:muted],
).pluck(:category_id)
end

vector_rep.asymmetric_topics_similarity_search(
raw_vector,
limit: limit,
offset: 0,
return_distance: true,
exclude_category_ids: muted_category_ids,
)
end
end
Expand Down
30 changes: 26 additions & 4 deletions lib/embeddings/vector_representations/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,22 @@ def post_id_from_representation(raw_vector)
SQL
end

def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
def asymmetric_topics_similarity_search(
raw_vector,
limit:,
offset:,
return_distance: false,
exclude_category_ids: nil
)
builder = DB.build(<<~SQL)
WITH candidates AS (
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
/*join*/
/*where*/
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
Expand All @@ -176,6 +182,22 @@ def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_dist
OFFSET :offset
SQL

builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: id,
strategy_id: @strategy.id,
)

if exclude_category_ids.present?
builder.join("topics t on t.id = topic_id")
builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i))
t.category_id NOT IN (:exclude_category_ids) AND
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
SQL
end

results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset)

if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
Expand Down
43 changes: 43 additions & 0 deletions spec/lib/modules/ai_helper/semantic_categorizer_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# frozen_string_literal: true

# Covers category suggestions from SemanticCategorizer for a user who has
# muted one category: topics from that category must not contribute
# suggestions, while topics from other categories still do.
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
  fab!(:user)

  # A category the user has muted, plus a topic living in it.
  fab!(:muted_category) { Fabricate(:category) }
  fab!(:category_mute) do
    muted_level = CategoryUser.notification_levels[:muted]
    CategoryUser.create!(user: user, category: muted_category, notification_level: muted_level)
  end
  fab!(:muted_topic) { Fabricate(:topic, category: muted_category) }

  # A regular (non-muted) category with its own topic.
  fab!(:category)
  fab!(:topic) { Fabricate(:topic, category: category) }

  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
  let(:vector_rep) do
    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
  end
  let(:categorizer) { described_class.new({ text: "hello" }, user) }
  # Every stubbed embedding request answers with this same constant vector.
  let(:expected_embedding) { Array.new(vector_rep.dimensions, 0.0038493) }

  before do
    SiteSetting.ai_embeddings_enabled = true
    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
    SiteSetting.ai_embeddings_model = "bge-large-en"

    # Stub the embeddings endpoint so no real HTTP request is made.
    classify_url = "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify"
    WebMock.stub_request(:post, classify_url).to_return(
      status: 200,
      body: JSON.dump(expected_embedding),
    )

    # Store embeddings for both topics, regular one first.
    [topic, muted_topic].each { |t| vector_rep.generate_representation_from(t) }
  end

  it "respects user muted categories when making suggestions" do
    suggestions = categorizer.categories
    category_ids = suggestions.map { |suggestion| suggestion[:id] }

    expect(category_ids).not_to include(muted_category.id)
    expect(category_ids).to include(category.id)
  end
end
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,38 @@
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
).to contain_exactly(topic.id)
end

# Checks both halves of the `exclude_category_ids` filter on
# `asymmetric_topics_similarity_search`:
#   1. a topic directly inside an excluded category is filtered out, and
#   2. a topic inside a *subcategory* of an excluded category is filtered out.
it "can exclude categories" do
  similar_vector = [0.0038494] * vector_rep.dimensions
  text =
    truncation.prepare_text_from(
      topic,
      vector_rep.tokenizer,
      vector_rep.max_sequence_length - 2,
    )
  stub_vector_mapping(text, expected_embedding_1)
  vector_rep.generate_representation_from(topic)

  # Capture the original category id up front. Once the topic is moved into
  # the child category below, `topic.category_id` is the child's id, and
  # excluding that would only repeat the direct-exclusion check.
  parent_category_id = topic.category_id

  # Direct exclusion: the topic sits in the excluded category itself.
  expect(
    vector_rep.asymmetric_topics_similarity_search(
      similar_vector,
      limit: 1,
      offset: 0,
      exclude_category_ids: [parent_category_id],
    ),
  ).to be_empty

  child_category = Fabricate(:category, parent_category_id: parent_category_id)
  topic.update!(category_id: child_category.id)

  # Subcategory exclusion: only the parent is excluded, yet the topic (now in
  # the child category) must still be filtered out.
  expect(
    vector_rep.asymmetric_topics_similarity_search(
      similar_vector,
      limit: 1,
      offset: 0,
      exclude_category_ids: [parent_category_id],
    ),
  ).to be_empty
end
end
end