From 6cc612231d9df84507d561357042fa1e22288045 Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Fri, 29 Nov 2024 08:54:35 +1100 Subject: [PATCH 1/2] FEATURE: exclude muted categories from category suggester The logic here is that users do not particularly care about topics in categories they have muted, so we can exclude those topics from tag and category suggestions --- lib/ai_helper/semantic_categorizer.rb | 14 ++++++ lib/embeddings/vector_representations/base.rb | 30 +++++++++++-- .../ai_helper/semantic_categorizer_spec.rb | 43 +++++++++++++++++++ .../vector_rep_shared_examples.rb | 33 ++++++++++++++ 4 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 spec/lib/modules/ai_helper/semantic_categorizer_spec.rb diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index 7fa46f292..612030c8f 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -12,6 +12,8 @@ def categories return [] unless SiteSetting.ai_embeddings_enabled candidates = nearest_neighbors(limit: 100) + return [] if candidates.empty? + candidate_ids = candidates.map(&:first) ::Topic @@ -52,6 +54,8 @@ def tags return [] unless SiteSetting.ai_embeddings_enabled candidates = nearest_neighbors(limit: 100) + return [] if candidates.empty? + candidate_ids = candidates.map(&:first) count_column = Tag.topic_count_column(@user.guardian) # Determine the count column @@ -94,11 +98,21 @@ def nearest_neighbors(limit: 100) raw_vector = vector_rep.vector_from(@text) + muted_category_ids = nil + if @user.present?
+ muted_category_ids = + CategoryUser.where( + user: @user, + notification_level: CategoryUser.notification_levels[:muted], + ).pluck(:category_id) + end + vector_rep.asymmetric_topics_similarity_search( raw_vector, limit: limit, offset: 0, return_distance: true, + exclude_category_ids: muted_category_ids, ) end end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 5f1130020..831a98707 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -151,16 +151,22 @@ def post_id_from_representation(raw_vector) SQL end - def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false) - results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset) + def asymmetric_topics_similarity_search( + raw_vector, + limit:, + offset:, + return_distance: false, + exclude_category_ids: nil + ) + builder = DB.build(<<~SQL) WITH candidates AS ( SELECT topic_id, embeddings::halfvec(#{dimensions}) AS embeddings FROM #{topic_table_name} - WHERE - model_id = #{id} AND strategy_id = #{@strategy.id} + /*join*/ + /*where*/ ORDER BY binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions})) LIMIT :limit * 2 @@ -176,6 +182,22 @@ def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_dist OFFSET :offset SQL + builder.where( + "model_id = :model_id AND strategy_id = :strategy_id", + model_id: id, + strategy_id: @strategy.id, + ) + + if exclude_category_ids.present? 
+ builder.join("topics t on t.id = topic_id") + builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i)) + t.category_id NOT IN (:exclude_category_ids) AND + t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids)) + SQL + end + + results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset) + if return_distance results.map { |r| [r.topic_id, r.distance] } else diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb new file mode 100644 index 000000000..2df40837b --- /dev/null +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -0,0 +1,43 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do + fab!(:user) + fab!(:muted_category) { Fabricate(:category) } + fab!(:category_mute) do + CategoryUser.create!( + user: user, + category: muted_category, + notification_level: CategoryUser.notification_levels[:muted], + ) + end + fab!(:muted_topic) { Fabricate(:topic, category: muted_category) } + fab!(:category) { Fabricate(:category) } + fab!(:topic) { Fabricate(:topic, category: category) } + + let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } + let(:vector_rep) do + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) + end + let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } + let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } + + before do + SiteSetting.ai_embeddings_enabled = true + SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + SiteSetting.ai_embeddings_model = "bge-large-en" + + WebMock.stub_request( + :post, + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + ).to_return(status: 200, body: JSON.dump(expected_embedding)) + + 
vector_rep.generate_representation_from(topic) + vector_rep.generate_representation_from(muted_topic) + end + + it "respects user muted categories when making suggestions" do + category_ids = categorizer.categories.map { |c| c[:id] } + expect(category_ids).not_to include(muted_category.id) + expect(category_ids).to include(category.id) + end +end diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb index 075e6930b..a91344039 100644 --- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb +++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb @@ -104,5 +104,38 @@ vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0), ).to contain_exactly(topic.id) end + + it "can exclude categories" do + similar_vector = [0.0038494] * vector_rep.dimensions + text = + truncation.prepare_text_from( + topic, + vector_rep.tokenizer, + vector_rep.max_sequence_length - 2, + ) + stub_vector_mapping(text, expected_embedding_1) + vector_rep.generate_representation_from(topic) + + expect( + vector_rep.asymmetric_topics_similarity_search( + similar_vector, + limit: 1, + offset: 0, + exclude_category_ids: [topic.category_id], + ), + ).to be_empty + + child_category = Fabricate(:category, parent_category_id: topic.category_id) + topic.update!(category_id: child_category.id) + + expect( + vector_rep.asymmetric_topics_similarity_search( + similar_vector, + limit: 1, + offset: 0, + exclude_category_ids: [topic.category_id], + ), + ).to be_empty + end end end From e639e39a8aa092a14001e8815c762660c4b4626b Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Fri, 29 Nov 2024 08:57:15 +1100 Subject: [PATCH 2/2] lint --- spec/lib/modules/ai_helper/semantic_categorizer_spec.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb 
b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb index 2df40837b..8d40b572e 100644 --- a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -11,7 +11,7 @@ ) end fab!(:muted_topic) { Fabricate(:topic, category: muted_category) } - fab!(:category) { Fabricate(:category) } + fab!(:category) fab!(:topic) { Fabricate(:topic, category: category) } let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }