Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions app/jobs/scheduled/embeddings_backfill.rb
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def execute(args)
# First, we'll try to backfill embeddings for posts that have none
posts
.where("#{table_name}.post_id IS NULL")
.find_each do |t|
vector_rep.generate_representation_from(t)
rebaked += 1
.find_in_batches do |batch|
vector_rep.gen_bulk_reprensentations(batch)
rebaked += batch.size
end

return if rebaked >= limit
Expand All @@ -90,24 +90,28 @@ def execute(args)
OR
#{table_name}.strategy_version < #{strategy.version}
SQL
.find_each do |t|
vector_rep.generate_representation_from(t)
rebaked += 1
.find_in_batches do |batch|
vector_rep.gen_bulk_reprensentations(batch)
rebaked += batch.size
end

return if rebaked >= limit

# Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)
.pluck(:id)
.each do |id|
vector_rep.generate_representation_from(Post.find_by(id: id))
rebaked += 1
end
posts_batch_size = 1000

outdated_post_ids =
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)
.pluck(:id)

outdated_post_ids.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
rebaked += batch.length
end

rebaked
end
Expand All @@ -120,14 +124,13 @@ def populate_topic_embeddings(vector_rep, topics, force: false)
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force

ids = topics.pluck("topics.id")
batch_size = 1000

ids.each do |id|
topic = Topic.find_by(id: id)
if topic
vector_rep.generate_representation_from(topic)
done += 1
end
ids.each_slice(batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
done += batch.length
end

done
end
end
Expand Down
36 changes: 35 additions & 1 deletion lib/embeddings/vector_representations/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,38 @@ def vector_from(text, asymetric: false)
raise NotImplementedError
end

# Generates and persists embeddings for every record in `relation`, fanning
# the HTTP inference calls out over a thread pool so the batch completes in
# roughly one round-trip instead of N sequential ones.
#
# @param relation [Enumerable] records (topics/posts) to embed; each is passed
#   through #prepare_text and then #save_to_db with the resulting vector.
#
# NOTE: the method name carries a historic typo ("reprensentations") — it is
# kept because callers elsewhere reference it by this exact name.
def gen_bulk_reprensentations(relation)
  http_pool_size = 100
  pool =
    Concurrent::CachedThreadPool.new(
      min_threads: 0,
      max_threads: http_pool_size,
      idletime: 30,
    )

  embedding_gen = inference_client
  promised_embeddings =
    relation
      .map do |record|
        materials = { target: record, text: prepare_text(record) }

        # Skip records with nothing to embed — mirrors the early return in
        # #generate_representation_from, and avoids calling perform! with
        # blank input.
        next if materials[:text].blank?

        Concurrent::Promises
          .fulfilled_future(materials, pool)
          .then_on(pool) do |w_prepared_text|
            w_prepared_text.merge(
              embedding: embedding_gen.perform!(w_prepared_text[:text]),
              digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
            )
          end
      end
      .compact

  Concurrent::Promises
    .zip(*promised_embeddings)
    .value!
    .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
ensure
  # The pool is per-call; without an explicit shutdown its worker threads
  # linger until the 30s idle timeout (or leak entirely if value! raised).
  pool&.shutdown
  pool&.wait_for_termination(30)
end

def generate_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
text = prepare_text(target)
return if text.blank?

target_column =
Expand Down Expand Up @@ -429,6 +459,10 @@ def save_to_db(target, vector, digest)
def inference_client
raise NotImplementedError
end

# Builds the text to be embedded for `record`, truncated by the configured
# strategy to fit the model's context window (minus 2 tokens reserved for
# special tokens such as BOS/EOS).
def prepare_text(record)
  token_budget = max_sequence_length - 2
  @strategy.prepare_text_from(record, tokenizer, token_budget)
end
end
end
end
Expand Down
10 changes: 9 additions & 1 deletion lib/embeddings/vector_representations/multilingual_e5_large.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def vector_from(text, asymetric: false)
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
if needs_truncation
text = tokenizer.truncate(text, max_sequence_length - 2)
else
elsif !text.starts_with?("query:")
text = "query: #{text}"
end

Expand Down Expand Up @@ -79,6 +79,14 @@ def inference_client
raise "No inference endpoint configured"
end
end

# Same as the base implementation, but prefixes the E5-style "query: " marker
# when the configured inference client is a DiscourseClassifier — that backend
# expects the prefix, while others (e.g. HuggingFaceTextEmbeddings) do not.
def prepare_text(record)
  needs_query_prefix = inference_client.class.name.include?("DiscourseClassifier")

  needs_query_prefix ? "query: #{super(record)}" : super(record)
end
end
end
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# frozen_string_literal: true

RSpec.shared_examples "generates and store embedding using with vector representation" do
before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }

describe "#vector_from" do
  it "creates a vector from a given string" do
    sample_text = "This is a piece of text"

    stub_vector_mapping(sample_text, expected_embedding_1)

    expect(vector_rep.vector_from(sample_text)).to eq(expected_embedding_1)
  end
end

Expand All @@ -24,11 +25,11 @@
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)

vector_rep.generate_representation_from(topic)

expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
end

it "creates a vector from a post and stores it in the database" do
Expand All @@ -38,11 +39,45 @@
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)

vector_rep.generate_representation_from(post)

expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
end
end

describe "#gen_bulk_reprensentations" do
  fab!(:topic) { Fabricate(:topic) }
  fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
  fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }

  fab!(:topic_2) { Fabricate(:topic) }
  fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
  fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }

  it "creates a vector for each object in the relation" do
    text =
      truncation.prepare_text_from(
        topic,
        vector_rep.tokenizer,
        vector_rep.max_sequence_length - 2,
      )

    text2 =
      truncation.prepare_text_from(
        topic_2,
        vector_rep.tokenizer,
        vector_rep.max_sequence_length - 2,
      )

    stub_vector_mapping(text, expected_embedding_1)
    stub_vector_mapping(text2, expected_embedding_2)

    vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))

    expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
    # Fixed copy-paste bug: the original repeated the first assertion verbatim,
    # so the second topic's stored embedding was never actually verified.
    expect(vector_rep.topic_id_from_representation(expected_embedding_2)).to eq(topic_2.id)
  end
end

Expand All @@ -58,7 +93,7 @@
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)

expect(
Expand Down