Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions app/jobs/regular/generate_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def execute(args)
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.raw.blank?

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation

vector_rep.generate_representation_from(target)
DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
end
end
end
4 changes: 2 additions & 2 deletions app/jobs/regular/generate_rag_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class GenerateRagEmbeddings < ::Jobs::Base
def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance

# generate_representation_from compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes.
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
fragments.map { |fragment| vector.generate_representation_from(fragment) }

last_fragment = fragments.last
target = last_fragment.target
Expand Down
27 changes: 14 additions & 13 deletions app/jobs/scheduled/embeddings_backfill.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def execute(args)

rebaked = 0

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
vector_def = vector.vdef
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE

topics =
Expand All @@ -30,19 +31,19 @@ def execute(args)
.where(deleted_at: nil)
.order("topics.bumped_at DESC")

rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked))
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))

return if rebaked >= limit

# Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version
relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_rep.version}
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_rep.strategy_version}
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL

rebaked += populate_topic_embeddings(vector_rep, relation)
rebaked += populate_topic_embeddings(vector, relation)

return if rebaked >= limit

Expand All @@ -54,7 +55,7 @@ def execute(args)
.where("#{table_name}.updated_at < topics.updated_at")
.limit((limit - rebaked) / 10)

populate_topic_embeddings(vector_rep, relation, force: true)
populate_topic_embeddings(vector, relation, force: true)

return if rebaked >= limit

Expand All @@ -76,7 +77,7 @@ def execute(args)
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

Expand All @@ -86,14 +87,14 @@ def execute(args)
# embeddings, be it model or strategy version
posts
.where(<<~SQL)
#{table_name}.model_version < #{vector_rep.version}
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_rep.strategy_version}
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

Expand All @@ -107,7 +108,7 @@ def execute(args)
.limit((limit - rebaked) / 10)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

Expand All @@ -116,7 +117,7 @@ def execute(args)

private

def populate_topic_embeddings(vector_rep, topics, force: false)
def populate_topic_embeddings(vector, topics, force: false)
done = 0

topics =
Expand All @@ -126,7 +127,7 @@ def populate_topic_embeddings(vector_rep, topics, force: false)
batch_size = 1000

ids.each_slice(batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
done += batch.length
end

Expand Down
6 changes: 3 additions & 3 deletions lib/ai_bot/personas/persona.rb
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def rag_fragments_prompt(conversation_context, llm:, user:)

return nil if !consolidated_question

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings

interactions_vector = vector_rep.vector_from(consolidated_question)
interactions_vector = vector.vector_from(consolidated_question)

rag_conversation_chunks = self.class.rag_conversation_chunks
search_limit =
Expand All @@ -327,7 +327,7 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
rag_conversation_chunks
end

schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)

candidate_fragment_ids =
schema
Expand Down
5 changes: 2 additions & 3 deletions lib/ai_bot/tool_runner.rb
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,10 @@ def rag_search(query, filenames: nil, limit: 10)

return [] if upload_refs.empty?

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
query_vector = vector_rep.vector_from(query)
query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
fragment_ids =
DiscourseAi::Embeddings::Schema
.for(RagDocumentFragment, vector: vector_rep)
.for(RagDocumentFragment)
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
rag_document_fragments ON
Expand Down
6 changes: 3 additions & 3 deletions lib/ai_helper/semantic_categorizer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def tags
private

def nearest_neighbors(limit: 100)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
vector = DiscourseAi::Embeddings::Vector.instance
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)

raw_vector = vector_rep.vector_from(@text)
raw_vector = vector.vector_from(@text)

muted_category_ids = nil
if @user.present?
Expand Down
44 changes: 27 additions & 17 deletions lib/embeddings/schema.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,31 @@ class Schema

def self.for(
target_klass,
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
)
case target_klass&.name
when "Topic"
new(TOPICS_TABLE, "topic_id", vector)
new(TOPICS_TABLE, "topic_id", vector_def)
when "Post"
new(POSTS_TABLE, "post_id", vector)
new(POSTS_TABLE, "post_id", vector_def)
when "RagDocumentFragment"
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
else
raise ArgumentError, "Invalid target type for embeddings"
end
end

def initialize(table, target_column, vector)
def initialize(table, target_column, vector_def)
@table = table
@target_column = target_column
@vector = vector
@vector_def = vector_def
end

attr_reader :table, :target_column, :vector
attr_reader :table, :target_column, :vector_def

def find_by_embedding(embedding)
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
DB.query(
<<~SQL,
SELECT *
FROM #{table}
WHERE
Expand All @@ -46,10 +47,15 @@ def find_by_embedding(embedding)
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
query_embedding: embedding,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end

def find_by_target(target)
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
DB.query(
<<~SQL,
SELECT *
FROM #{table}
WHERE
Expand All @@ -58,6 +64,10 @@ def find_by_target(target)
#{target_column} = :target_id
LIMIT 1
SQL
target_id: target.id,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end

def asymmetric_similarity_search(embedding, limit:, offset:)
Expand Down Expand Up @@ -87,8 +97,8 @@ def asymmetric_similarity_search(embedding, limit:, offset:)

builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: vector.id,
strategy_id: vector.strategy_id,
model_id: vector_def.id,
strategy_id: vector_def.strategy_id,
)

yield(builder) if block_given?
Expand Down Expand Up @@ -156,7 +166,7 @@ def symmetric_similarity_search(record)

yield(builder) if block_given?

builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
Expand All @@ -176,10 +186,10 @@ def store(record, embedding, digest)
updated_at = :now
SQL
target_id: record.id,
model_id: vector.id,
model_version: vector.version,
strategy_id: vector.strategy_id,
strategy_version: vector.strategy_version,
model_id: vector_def.id,
model_version: vector_def.version,
strategy_id: vector_def.strategy_id,
strategy_version: vector_def.strategy_version,
digest: digest,
embeddings: embedding,
now: Time.zone.now,
Expand All @@ -188,7 +198,7 @@ def store(record, embedding, digest)

private

delegate :dimensions, :pg_function, to: :vector
delegate :dimensions, :pg_function, to: :vector_def
end
end
end
3 changes: 1 addition & 2 deletions lib/embeddings/semantic_related.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ def self.clear_cache_for(topic)
def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
cache_for = results_ttl(topic)

Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
DiscourseAi::Embeddings::Schema
.for(Topic, vector: vector_rep)
.for(Topic)
.symmetric_similarity_search(topic)
.map(&:topic_id)
.tap do |candidate_ids|
Expand Down
18 changes: 8 additions & 10 deletions lib/embeddings/semantic_search.rb
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def cached_query?(query)
Discourse.cache.read(embedding_key).present?
end

def vector_rep
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
def vector
@vector ||= DiscourseAi::Embeddings::Vector.instance
end

def hyde_embedding(search_term)
Expand All @@ -52,16 +52,14 @@ def hyde_embedding(search_term)

Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
end

def embedding(search_term)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)

Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
end

# this ensures the candidate topics are over selected
Expand All @@ -84,7 +82,7 @@ def search_for_topics(query, page = 1, hyde: true)

over_selection_limit = limit * OVER_SELECTION_FACTOR

schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)

candidate_topic_ids =
schema.asymmetric_similarity_search(
Expand Down Expand Up @@ -114,7 +112,7 @@ def quick_search(query)

return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance

digest = OpenSSL::Digest::SHA1.hexdigest(search_term)

Expand All @@ -129,12 +127,12 @@ def quick_search(query)
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) do
vector_rep.vector_from(search_term, asymetric: true)
vector.vector_from(search_term, asymetric: true)
end

candidate_post_ids =
DiscourseAi::Embeddings::Schema
.for(Post, vector: vector_rep)
.for(Post, vector_def: vector.vdef)
.asymmetric_similarity_search(
search_term_embedding,
limit: max_semantic_results_per_page,
Expand Down
17 changes: 13 additions & 4 deletions lib/embeddings/strategies/truncation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,28 @@ def version
1
end

def prepare_text_from(target, tokenizer, max_length)
def prepare_target_text(target, vdef)
max_length = vdef.max_sequence_length - 2

case target
when Topic
topic_truncation(target, tokenizer, max_length)
topic_truncation(target, vdef.tokenizer, max_length)
when Post
post_truncation(target, tokenizer, max_length)
post_truncation(target, vdef.tokenizer, max_length)
when RagDocumentFragment
tokenizer.truncate(target.fragment, max_length)
vdef.tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end
end

def prepare_query_text(text, vdef, asymetric: false)
qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
max_length = vdef.max_sequence_length - 2

vdef.tokenizer.truncate(text, max_length)
end

private

def topic_information(topic)
Expand Down
Loading
Loading