diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb index db0828f2f..3300b4797 100644 --- a/app/jobs/scheduled/embeddings_backfill.rb +++ b/app/jobs/scheduled/embeddings_backfill.rb @@ -75,9 +75,9 @@ def execute(args) # First, we'll try to backfill embeddings for posts that have none posts .where("#{table_name}.post_id IS NULL") - .find_in_batches do |batch| - vector_rep.gen_bulk_reprensentations(batch) - rebaked += batch.size + .find_each do |t| + vector_rep.generate_representation_from(t) + rebaked += 1 end return if rebaked >= limit @@ -90,28 +90,24 @@ def execute(args) OR #{table_name}.strategy_version < #{strategy.version} SQL - .find_in_batches do |batch| - vector_rep.gen_bulk_reprensentations(batch) - rebaked += batch.size + .find_each do |t| + vector_rep.generate_representation_from(t) + rebaked += 1 end return if rebaked >= limit # Finally, we'll try to backfill embeddings for posts that have outdated # embeddings due to edits. Here we only do 10% of the limit - posts_batch_size = 1000 - - outdated_post_ids = - posts - .where("#{table_name}.updated_at < ?", 7.days.ago) - .order("random()") - .limit((limit - rebaked) / 10) - .pluck(:id) - - outdated_post_ids.each_slice(posts_batch_size) do |batch| - vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC")) - rebaked += batch.length - end + posts + .where("#{table_name}.updated_at < ?", 7.days.ago) + .order("random()") + .limit((limit - rebaked) / 10) + .pluck(:id) + .each do |id| + vector_rep.generate_representation_from(Post.find_by(id: id)) + rebaked += 1 + end rebaked end @@ -124,13 +120,14 @@ def populate_topic_embeddings(vector_rep, topics, force: false) topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force ids = topics.pluck("topics.id") - batch_size = 1000 - ids.each_slice(batch_size) do |batch| - vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC")) - done += batch.length + ids.each do |id| + topic = Topic.find_by(id: id) + if topic + vector_rep.generate_representation_from(topic) + done += 1 + end end - done end end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index e1f3ff497..be6b46b57 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -50,38 +50,8 @@ def vector_from(text, asymetric: false) raise NotImplementedError end - def gen_bulk_reprensentations(relation) - http_pool_size = 100 - pool = - Concurrent::CachedThreadPool.new( - min_threads: 0, - max_threads: http_pool_size, - idletime: 30, - ) - - embedding_gen = inference_client - promised_embeddings = - relation.map do |record| - materials = { target: record, text: prepare_text(record) } - - Concurrent::Promises - .fulfilled_future(materials, pool) - .then_on(pool) do |w_prepared_text| - w_prepared_text.merge( - embedding: embedding_gen.perform!(w_prepared_text[:text]), - digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]), - ) - end - end - - Concurrent::Promises - .zip(*promised_embeddings) - .value! - .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) } - end - def generate_representation_from(target, persist: true) - text = prepare_text(target) + text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2) return if text.blank? target_column = @@ -459,10 +429,6 @@ def save_to_db(target, vector, digest) def inference_client raise NotImplementedError end - - def prepare_text(record) - @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2) - end end end end diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb index 605ec8b55..c7ef3c0fe 100644 --- a/lib/embeddings/vector_representations/multilingual_e5_large.rb +++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb @@ -34,7 +34,7 @@ def vector_from(text, asymetric: false) needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings") if needs_truncation text = tokenizer.truncate(text, max_sequence_length - 2) - elsif !text.starts_with?("query:") + else text = "query: #{text}" end @@ -79,14 +79,6 @@ def inference_client raise "No inference endpoint configured" end end - - def prepare_text(record) - if inference_client.class.name.include?("DiscourseClassifier") - return "query: #{super(record)}" - end - - super(record) - end end end end diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb index fce9f6123..9689a3c61 100644 --- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb +++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb @@ -1,15 +1,14 @@ # frozen_string_literal: true RSpec.shared_examples "generates and store embedding using with vector representation" do - let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions } - let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions } + before { @expected_embedding = [0.0038493] * vector_rep.dimensions } describe "#vector_from" do it "creates a vector from a given string" do text = "This is a piece of text" - stub_vector_mapping(text, expected_embedding_1) + stub_vector_mapping(text, @expected_embedding) - expect(vector_rep.vector_from(text)).to eq(expected_embedding_1) + expect(vector_rep.vector_from(text)).to eq(@expected_embedding) end end @@ -25,11 +24,11 @@ vector_rep.tokenizer, vector_rep.max_sequence_length - 2, ) - stub_vector_mapping(text, expected_embedding_1) + stub_vector_mapping(text, @expected_embedding) vector_rep.generate_representation_from(topic) - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) + expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id) end it "creates a vector from a post and stores it in the database" do @@ -39,45 +38,11 @@ vector_rep.tokenizer, vector_rep.max_sequence_length - 2, ) - stub_vector_mapping(text, expected_embedding_1) + stub_vector_mapping(text, @expected_embedding) vector_rep.generate_representation_from(post) - expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id) - end - end - - describe "#gen_bulk_reprensentations" do - fab!(:topic) { Fabricate(:topic) } - fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } - fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) } - - fab!(:topic_2) { Fabricate(:topic) } - fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) } - fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) } - - it "creates a vector for each object in the relation" do - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - - text2 = - truncation.prepare_text_from( - topic_2, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - - stub_vector_mapping(text, expected_embedding_1) - stub_vector_mapping(text2, expected_embedding_2) - - vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id])) - - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) + expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id) end end @@ -93,7 +58,7 @@ vector_rep.tokenizer, vector_rep.max_sequence_length - 2, ) - stub_vector_mapping(text, expected_embedding_1) + stub_vector_mapping(text, @expected_embedding) vector_rep.generate_representation_from(topic) expect(