Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit ddf2bf7

Browse files
authored
DEV: Backfill embeddings concurrently. (#941)
We are adding a new method for generating and storing embeddings in bulk, which relies on `Concurrent::Promises::Future`. Generating an embedding consists of three steps: Prepare text HTTP call to retrieve the vector Save to DB. Each one is independently executed on whatever thread the pool gives us. We are bringing a custom thread pool instead of the global executor since we want control over how many threads we spawn to limit concurrency. We also avoid firing thousands of HTTP requests when working with large batches.
1 parent 23193ee commit ddf2bf7

File tree

4 files changed

+111
-31
lines changed

4 files changed

+111
-31
lines changed

app/jobs/scheduled/embeddings_backfill.rb

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def execute(args)
7575
# First, we'll try to backfill embeddings for posts that have none
7676
posts
7777
.where("#{table_name}.post_id IS NULL")
78-
.find_each do |t|
79-
vector_rep.generate_representation_from(t)
80-
rebaked += 1
78+
.find_in_batches do |batch|
79+
vector_rep.gen_bulk_reprensentations(batch)
80+
rebaked += batch.size
8181
end
8282

8383
return if rebaked >= limit
@@ -90,24 +90,28 @@ def execute(args)
9090
OR
9191
#{table_name}.strategy_version < #{strategy.version}
9292
SQL
93-
.find_each do |t|
94-
vector_rep.generate_representation_from(t)
95-
rebaked += 1
93+
.find_in_batches do |batch|
94+
vector_rep.gen_bulk_reprensentations(batch)
95+
rebaked += batch.size
9696
end
9797

9898
return if rebaked >= limit
9999

100100
# Finally, we'll try to backfill embeddings for posts that have outdated
101101
# embeddings due to edits. Here we only do 10% of the limit
102-
posts
103-
.where("#{table_name}.updated_at < ?", 7.days.ago)
104-
.order("random()")
105-
.limit((limit - rebaked) / 10)
106-
.pluck(:id)
107-
.each do |id|
108-
vector_rep.generate_representation_from(Post.find_by(id: id))
109-
rebaked += 1
110-
end
102+
posts_batch_size = 1000
103+
104+
outdated_post_ids =
105+
posts
106+
.where("#{table_name}.updated_at < ?", 7.days.ago)
107+
.order("random()")
108+
.limit((limit - rebaked) / 10)
109+
.pluck(:id)
110+
111+
outdated_post_ids.each_slice(posts_batch_size) do |batch|
112+
vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
113+
rebaked += batch.length
114+
end
111115

112116
rebaked
113117
end
@@ -120,14 +124,13 @@ def populate_topic_embeddings(vector_rep, topics, force: false)
120124
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
121125

122126
ids = topics.pluck("topics.id")
127+
batch_size = 1000
123128

124-
ids.each do |id|
125-
topic = Topic.find_by(id: id)
126-
if topic
127-
vector_rep.generate_representation_from(topic)
128-
done += 1
129-
end
129+
ids.each_slice(batch_size) do |batch|
130+
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
131+
done += batch.length
130132
end
133+
131134
done
132135
end
133136
end

lib/embeddings/vector_representations/base.rb

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,38 @@ def vector_from(text, asymetric: false)
5050
raise NotImplementedError
5151
end
5252

53+
def gen_bulk_reprensentations(relation)
54+
http_pool_size = 100
55+
pool =
56+
Concurrent::CachedThreadPool.new(
57+
min_threads: 0,
58+
max_threads: http_pool_size,
59+
idletime: 30,
60+
)
61+
62+
embedding_gen = inference_client
63+
promised_embeddings =
64+
relation.map do |record|
65+
materials = { target: record, text: prepare_text(record) }
66+
67+
Concurrent::Promises
68+
.fulfilled_future(materials, pool)
69+
.then_on(pool) do |w_prepared_text|
70+
w_prepared_text.merge(
71+
embedding: embedding_gen.perform!(w_prepared_text[:text]),
72+
digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
73+
)
74+
end
75+
end
76+
77+
Concurrent::Promises
78+
.zip(*promised_embeddings)
79+
.value!
80+
.each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
81+
end
82+
5383
def generate_representation_from(target, persist: true)
54-
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
84+
text = prepare_text(target)
5585
return if text.blank?
5686

5787
target_column =
@@ -429,6 +459,10 @@ def save_to_db(target, vector, digest)
429459
def inference_client
430460
raise NotImplementedError
431461
end
462+
463+
def prepare_text(record)
464+
@strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
465+
end
432466
end
433467
end
434468
end

lib/embeddings/vector_representations/multilingual_e5_large.rb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def vector_from(text, asymetric: false)
3434
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
3535
if needs_truncation
3636
text = tokenizer.truncate(text, max_sequence_length - 2)
37-
else
37+
elsif !text.starts_with?("query:")
3838
text = "query: #{text}"
3939
end
4040

@@ -79,6 +79,14 @@ def inference_client
7979
raise "No inference endpoint configured"
8080
end
8181
end
82+
83+
def prepare_text(record)
84+
if inference_client.class.name.include?("DiscourseClassifier")
85+
return "query: #{super(record)}"
86+
end
87+
88+
super(record)
89+
end
8290
end
8391
end
8492
end

spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# frozen_string_literal: true
22

33
RSpec.shared_examples "generates and store embedding using with vector representation" do
4-
before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
4+
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
5+
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
56

67
describe "#vector_from" do
78
it "creates a vector from a given string" do
89
text = "This is a piece of text"
9-
stub_vector_mapping(text, @expected_embedding)
10+
stub_vector_mapping(text, expected_embedding_1)
1011

11-
expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
12+
expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
1213
end
1314
end
1415

@@ -24,11 +25,11 @@
2425
vector_rep.tokenizer,
2526
vector_rep.max_sequence_length - 2,
2627
)
27-
stub_vector_mapping(text, @expected_embedding)
28+
stub_vector_mapping(text, expected_embedding_1)
2829

2930
vector_rep.generate_representation_from(topic)
3031

31-
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
32+
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
3233
end
3334

3435
it "creates a vector from a post and stores it in the database" do
@@ -38,11 +39,45 @@
3839
vector_rep.tokenizer,
3940
vector_rep.max_sequence_length - 2,
4041
)
41-
stub_vector_mapping(text, @expected_embedding)
42+
stub_vector_mapping(text, expected_embedding_1)
4243

4344
vector_rep.generate_representation_from(post)
4445

45-
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
46+
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
47+
end
48+
end
49+
50+
describe "#gen_bulk_reprensentations" do
51+
fab!(:topic) { Fabricate(:topic) }
52+
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
53+
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
54+
55+
fab!(:topic_2) { Fabricate(:topic) }
56+
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
57+
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
58+
59+
it "creates a vector for each object in the relation" do
60+
text =
61+
truncation.prepare_text_from(
62+
topic,
63+
vector_rep.tokenizer,
64+
vector_rep.max_sequence_length - 2,
65+
)
66+
67+
text2 =
68+
truncation.prepare_text_from(
69+
topic_2,
70+
vector_rep.tokenizer,
71+
vector_rep.max_sequence_length - 2,
72+
)
73+
74+
stub_vector_mapping(text, expected_embedding_1)
75+
stub_vector_mapping(text2, expected_embedding_2)
76+
77+
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
78+
79+
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
80+
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
4681
end
4782
end
4883

@@ -58,7 +93,7 @@
5893
vector_rep.tokenizer,
5994
vector_rep.max_sequence_length - 2,
6095
)
61-
stub_vector_mapping(text, @expected_embedding)
96+
stub_vector_mapping(text, expected_embedding_1)
6297
vector_rep.generate_representation_from(topic)
6398

6499
expect(

0 commit comments

Comments
 (0)