diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake index 6d4ffd0f0..1981e6ca0 100644 --- a/lib/tasks/modules/embeddings/database.rake +++ b/lib/tasks/modules/embeddings/database.rake @@ -1,24 +1,21 @@ # frozen_string_literal: true desc "Backfill embeddings for all topics and posts" -task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args| +task "ai:embeddings:backfill", %i[embedding_def_id concurrency] => [:environment] do |_, args| public_categories = Category.where(read_restricted: false).pluck(:id) - if args[:model].present? - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new( - strategy, - ) + if args[:embedding_def_id].present? + vdef = EmbeddingDefinition.find(args[:embedding_def_id]) + vector_rep = DiscourseAi::Embeddings::Vector.new(vdef) else - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector_rep = DiscourseAi::Embeddings::Vector.instance end - table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE + topics_table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE topics = Topic - .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id") - .where("#{table_name}.topic_id IS NULL") + .joins("LEFT JOIN #{topics_table_name} ON #{topics_table_name}.topic_id = topics.id") + .where("#{topics_table_name}.topic_id IS NULL") .where("category_id IN (?)", public_categories) .where(deleted_at: nil) .order("topics.id DESC") @@ -29,11 +26,11 @@ task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, ar end end - table_name = vector_rep.post_table_name + posts_table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE posts = Post - .joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id") - .where("#{table_name}.post_id IS NULL") + .joins("LEFT JOIN #{posts_table_name} ON #{posts_table_name}.post_id = posts.id") + .where("#{posts_table_name}.post_id IS NULL") .where(deleted_at: nil) .order("posts.id DESC")