Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 105 additions & 95 deletions app/jobs/scheduled/embeddings_backfill.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,105 +18,115 @@ def execute(args)
)
end

rebaked = 0
production_vector = DiscourseAi::Embeddings::Vector.instance

vector = DiscourseAi::Embeddings::Vector.instance
vector_def = vector.vdef
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE

topics =
Topic
.joins(
"LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}",
if SiteSetting.ai_embeddings_backfill_model.present? &&
SiteSetting.ai_embeddings_backfill_model != SiteSetting.ai_embeddings_selected_model
backfill_vector =
DiscourseAi::Embeddings::Vector.new(
EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model),
)
.where(archetype: Archetype.default)
.where(deleted_at: nil)
.order("topics.bumped_at DESC")

rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))

return if rebaked >= limit

# Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version
relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL

rebaked += populate_topic_embeddings(vector, relation, force: true)

return if rebaked >= limit
end

# Finally, we'll try to backfill embeddings for topics that have outdated
# embeddings due to edits or new replies. Here we only do 10% of the limit
relation =
topics
.where("#{table_name}.updated_at < ?", 6.hours.ago)
.where("#{table_name}.updated_at < topics.updated_at")
topic_work_list = []
topic_work_list << production_vector
topic_work_list << backfill_vector if backfill_vector

topic_work_list.each do |vector|
rebaked = 0
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
vector_def = vector.vdef

topics =
Topic
.joins(
"LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}",
)
.where(archetype: Archetype.default)
.where(deleted_at: nil)
.order("topics.bumped_at DESC")

rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))

next if rebaked >= limit

# Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version
relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL

rebaked += populate_topic_embeddings(vector, relation, force: true)

next if rebaked >= limit

# Finally, we'll try to backfill embeddings for topics that have outdated
# embeddings due to edits or new replies. Here we only do 10% of the limit
relation =
topics
.where("#{table_name}.updated_at < ?", 6.hours.ago)
.where("#{table_name}.updated_at < topics.updated_at")
.limit((limit - rebaked) / 10)

populate_topic_embeddings(vector, relation, force: true)

next unless SiteSetting.ai_embeddings_per_post_enabled

# Now for posts
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
posts_batch_size = 1000

posts =
Post
.joins(
"LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}",
)
.where(deleted_at: nil)
.where(post_type: Post.types[:regular])

# First, we'll try to backfill embeddings for posts that have none
posts
.where("#{table_name}.post_id IS NULL")
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

next if rebaked >= limit

# Then, we'll try to backfill embeddings for posts that have outdated
# embeddings, be it model or strategy version
posts
.where(<<~SQL)
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

next if rebaked >= limit

# Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)

populate_topic_embeddings(vector, relation, force: true)

return if rebaked >= limit

return unless SiteSetting.ai_embeddings_per_post_enabled

# Now for posts
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
posts_batch_size = 1000

posts =
Post
.joins(
"LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}",
)
.where(deleted_at: nil)
.where(post_type: Post.types[:regular])

# First, we'll try to backfill embeddings for posts that have none
posts
.where("#{table_name}.post_id IS NULL")
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

return if rebaked >= limit

# Then, we'll try to backfill embeddings for posts that have outdated
# embeddings, be it model or strategy version
posts
.where(<<~SQL)
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

return if rebaked >= limit

# Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

rebaked
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end
end
end

private
Expand Down
12 changes: 9 additions & 3 deletions config/settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,26 @@ discourse_ai:
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
area: "ai-features/embeddings"
ai_embeddings_backfill_model:
type: enum
default: ""
allow_any: false
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
hidden: true
ai_embeddings_per_post_enabled:
default: false
hidden: true
ai_embeddings_generate_for_pms:
ai_embeddings_generate_for_pms:
default: false
area: "ai-features/embeddings"
ai_embeddings_semantic_related_topics_enabled:
default: false
client: true
area: "ai-features/embeddings"
ai_embeddings_semantic_related_topics:
ai_embeddings_semantic_related_topics:
default: 5
area: "ai-features/embeddings"
ai_embeddings_semantic_related_include_closed_topics:
ai_embeddings_semantic_related_include_closed_topics:
default: true
area: "ai-features/embeddings"
ai_embeddings_backfill_batch_size:
Expand Down
7 changes: 5 additions & 2 deletions lib/embeddings/schema.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ class Schema
MissingEmbeddingError = Class.new(StandardError)

class << self
def for(target_klass)
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
def for(target_klass, vector_def: nil)
vector_def =
EmbeddingDefinition.find_by(
id: SiteSetting.ai_embeddings_selected_model,
) if vector_def.nil?
raise "Invalid embeddings selected model" if vector_def.nil?

case target_klass&.name
Expand Down
13 changes: 11 additions & 2 deletions lib/embeddings/semantic_related.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
module DiscourseAi
module Embeddings
class SemanticRelated
CACHE_PREFIX = "semantic-suggested-topic-"

def self.clear_cache_for(topic)
Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
Expand Down Expand Up @@ -79,14 +81,21 @@ def self.related_topics_for_crawler(controller)
)
end

# Removes every cache entry whose key matches "#{CACHE_PREFIX}*".
#
# Each key reported by Discourse.cache.keys is cut down to the segment after
# its last ":" before being passed to Discourse.cache.delete.
# NOTE(review): presumably this strips a namespace prefix that the cache adds
# to stored keys -- confirm against the cache implementation.
#
# Returns the list of keys that Discourse.cache.keys reported.
def self.clear_cache!
  matching_keys = Discourse.cache.keys("#{CACHE_PREFIX}*")

  matching_keys.each do |stored_key|
    bare_key = stored_key.split(":").last
    Discourse.cache.delete(bare_key)
  end
end

private

# Builds the cache key for a topic's semantic suggestions.
#
# @param topic_id [#to_s] id of the topic the key belongs to
# @return [String] CACHE_PREFIX followed by the topic id
def semantic_suggested_key(topic_id)
  "#{CACHE_PREFIX}#{topic_id}"
end

# Builds the companion "build-" key for a topic's semantic suggestions.
#
# @param topic_id [#to_s] id of the topic the key belongs to
# @return [String] "build-" + CACHE_PREFIX + the topic id
def build_semantic_suggested_key(topic_id)
  "build-#{CACHE_PREFIX}#{topic_id}"
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/embeddings/vector.rb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def gen_bulk_reprensentations(relation)
idletime: 30,
)

schema = DiscourseAi::Embeddings::Schema.for(relation.first.class)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: @vdef)

embedding_gen = vdef.inference_client
promised_embeddings =
Expand Down Expand Up @@ -58,7 +58,7 @@ def generate_representation_from(target)
text = vdef.prepare_target_text(target)
return if text.blank?

schema = DiscourseAi::Embeddings::Schema.for(target.class)
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: @vdef)

new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest
Expand Down
25 changes: 20 additions & 5 deletions spec/jobs/scheduled/embeddings_backfill_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@
end

fab!(:vector_def) { Fabricate(:embedding_definition) }
fab!(:vector_def2) { Fabricate(:embedding_definition) }
fab!(:embedding_array) { Array.new(1024) { 1 } }

before do
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_backfill_batch_size = 1
SiteSetting.ai_embeddings_per_post_enabled = true
Jobs.run_immediately!
end

it "backfills topics based on bumped_at date" do
embedding = Array.new(1024) { 1 }

WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
status: 200,
body: JSON.dump(embedding),
body: JSON.dump(embedding_array),
)
end

it "backfills topics based on bumped_at date" do
Jobs::EmbeddingsBackfill.new.execute({})

topic_ids =
Expand Down Expand Up @@ -68,4 +68,19 @@

expect(index_date).to be_within_one_second_of(Time.zone.now)
end

it "backfills embeddings for the ai_embeddings_backfill_model" do
  # Select a second embedding definition as the backfill model and raise the
  # batch size for this example.
  SiteSetting.ai_embeddings_backfill_model = vector_def2.id
  SiteSetting.ai_embeddings_backfill_batch_size = 100

  Jobs::EmbeddingsBackfill.new.execute({})

  # Collect the topic ids that now have an embedding row for the backfill model.
  backfilled_topic_ids =
    DB.query_single(
      "SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE model_id = ?",
      vector_def2.id,
    )

  expected_topic_ids = [first_topic.id, second_topic.id, third_topic.id]
  expect(backfilled_topic_ids).to match_array(expected_topic_ids)
end
end
Loading