Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit eb6d86a

Browse files
committed
FEATURE: RAG search within tools
This implements the back end for RAG search
1 parent 03eccbe commit eb6d86a

File tree

5 files changed

+168
-13
lines changed

5 files changed

+168
-13
lines changed

app/jobs/regular/digest_rag_upload.rb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ def chunk_document(file:, tokenizer:, chunk_tokens:, overlap_tokens:)
126126

127127
while overlap_token_ids.present?
128128
begin
129-
overlap = tokenizer.decode(overlap_token_ids) + split_char
129+
padding = split_char
130+
padding = " " if padding.empty?
131+
overlap = tokenizer.decode(overlap_token_ids) + padding
130132
break if overlap.encoding == Encoding::UTF_8
131133
rescue StandardError
132134
# it is possible that we truncated mid char
@@ -135,7 +137,7 @@ def chunk_document(file:, tokenizer:, chunk_tokens:, overlap_tokens:)
135137
end
136138

137139
# remove first word it is probably truncated
138-
overlap = overlap.split(" ", 2).last
140+
overlap = overlap.split(/\s/, 2).last.to_s.lstrip
139141
end
140142
end
141143

app/models/ai_tool.rb

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class AiTool < ActiveRecord::Base
77
validates :script, presence: true, length: { maximum: 100_000 }
88
validates :created_by_id, presence: true
99
belongs_to :created_by, class_name: "User"
10+
has_many :rag_document_fragments, dependent: :destroy, as: :target
1011

1112
def signature
1213
{ name: name, description: description, parameters: parameters.map(&:symbolize_keys) }
@@ -173,14 +174,16 @@ def self.presets
173174
#
174175
# Table name: ai_tools
175176
#
176-
# id :bigint not null, primary key
177-
# name :string not null
178-
# description :string not null
179-
# summary :string not null
180-
# parameters :jsonb not null
181-
# script :text not null
182-
# created_by_id :integer not null
183-
# enabled :boolean default(TRUE), not null
184-
# created_at :datetime not null
185-
# updated_at :datetime not null
177+
# id :bigint not null, primary key
178+
# name :string not null
179+
# description :string not null
180+
# summary :string not null
181+
# parameters :jsonb not null
182+
# script :text not null
183+
# created_by_id :integer not null
184+
# enabled :boolean default(TRUE), not null
185+
# created_at :datetime not null
186+
# updated_at :datetime not null
187+
# rag_chunk_tokens :integer default(374), not null
188+
# rag_chunk_overlap_tokens :integer default(10), not null
186189
#
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# frozen_string_literal: true
2+
3+
# Adds rag_chunk_tokens and rag_chunk_overlap_tokens to ai_tools.
# Both columns are NOT NULL integers with defaults (374 and 10 respectively).
class AddRagColumnsToAiTools < ActiveRecord::Migration[7.1]
  def change
    add_column :ai_tools, :rag_chunk_tokens, :integer, null: false, default: 374
    add_column :ai_tools, :rag_chunk_overlap_tokens, :integer, null: false, default: 10
  end
end

lib/ai_bot/tool_runner.rb

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def mini_racer_context
3535
)
3636
attach_truncate(ctx)
3737
attach_http(ctx)
38+
attach_index(ctx)
3839
ctx.eval(framework_script)
3940
ctx
4041
end
@@ -50,6 +51,10 @@ def framework_script
5051
const llm = {
5152
truncate: _llm_truncate,
5253
};
54+
55+
const index = {
56+
search: _index_search,
57+
}
5358
function details() { return ""; };
5459
JS
5560
end
@@ -105,13 +110,73 @@ def invoke
105110

106111
private
107112

113+
# Hard cap on the number of fragments a single search may return.
MAX_FRAGMENTS = 200

# Finds the RAG fragments attached to this tool that are most similar to +query+.
#
# @param query [String] text that is embedded and used for the similarity search
# @param filenames [Array<String>, String, nil] when given, only fragments belonging to
#   uploads with a matching original_filename are returned (value is passed to +where+)
# @param limit [Integer] maximum number of results; coerced with #to_i and capped at
#   MAX_FRAGMENTS
# @return [Array<Hash>] +{ fragment:, metadata: }+ hashes, in the order produced by the
#   similarity search; empty when limit < 1 or when no upload matches
def rag_search(query, filenames: nil, limit: 10)
  limit = limit.to_i
  return [] if limit < 1
  limit = [MAX_FRAGMENTS, limit].min

  upload_ids = UploadReference.where(target_id: tool.id, target_type: "AiTool").pluck(:upload_id)
  if filenames
    upload_ids = Upload.where(id: upload_ids).where(original_filename: filenames).pluck(:id)
  end
  return [] if upload_ids.empty?

  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
  vector_rep =
    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
  query_vector = vector_rep.vector_from(query)
  fragment_ids =
    vector_rep.asymmetric_rag_fragment_similarity_search(
      query_vector,
      target_type: "AiTool",
      target_id: tool.id,
      limit: limit,
      offset: 0,
    )

  fragments_by_id =
    RagDocumentFragment
      .where(id: fragment_ids, upload_id: upload_ids)
      .pluck(:id, :fragment, :metadata)
      .to_h { |id, fragment, metadata| [id, { fragment: fragment, metadata: metadata }] }

  # The similarity search is scoped to the tool, while the rows above are additionally
  # restricted to the selected uploads, so some ids may have no row. Skip those instead
  # of handing nil entries back to the script.
  fragment_ids.take(limit).filter_map { |fragment_id| fragments_by_id[fragment_id] }
end
156+
108157
# Registers `_llm_truncate(text, length)` on the given context; the call is forwarded to
# the truncate method of @llm's tokenizer.
def attach_truncate(mini_racer_context)
  mini_racer_context.attach(
    "_llm_truncate",
    ->(text, length) { @llm.tokenizer.truncate(text, length) },
  )
end
114163

164+
# Registers `_index_search(query, options)` on the given context; the call is forwarded
# to #rag_search with the option keys symbolized.
#
# +options+ may be omitted or nil, in which case #rag_search runs with its defaults.
def attach_index(mini_racer_context)
  mini_racer_context.attach(
    "_index_search",
    ->(query, options = nil) do
      begin
        # Flag is set for the duration of the Ruby call and always cleared afterwards.
        self.running_attached_function = true
        options = (options || {}).symbolize_keys
        rag_search(query, **options)
      ensure
        self.running_attached_function = false
      end
    end,
  )
end
179+
115180
def attach_http(mini_racer_context)
116181
mini_racer_context.attach(
117182
"_http_get",

spec/models/ai_tool_spec.rb

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,21 @@
44
fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") }
55
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
66

7-
def create_tool(parameters: nil, script: nil)
7+
# Persists an AiTool for use in the examples. Any argument left nil falls back to the
# default used below (a single "query" string parameter, an echo script, 374 chunk
# tokens and 10 overlap tokens).
def create_tool(
  parameters: nil,
  script: nil,
  rag_chunk_tokens: nil,
  rag_chunk_overlap_tokens: nil
)
  AiTool.create!(
    name: "test",
    description: "test",
    parameters: parameters || [{ name: "query", type: "string", description: "perform a search" }],
    script: script || "function invoke(params) { return params; }",
    created_by_id: 1,
    summary: "Test tool summary",
    rag_chunk_tokens: rag_chunk_tokens || 374,
    rag_chunk_overlap_tokens: rag_chunk_overlap_tokens || 10,
  )
end
1724

@@ -193,4 +200,74 @@ def create_tool(parameters: nil, script: nil)
193200
result = runner.invoke
194201
expect(result[:error]).to eq("Script terminated due to timeout")
195202
end
203+
204+
context "when defining RAG fragments" do
  before do
    SiteSetting.authorized_extensions = "txt"
    SiteSetting.ai_embeddings_enabled = true
    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
    SiteSetting.ai_embeddings_model = "bge-large-en"

    Jobs.run_immediately!
  end

  # Writes +content+ to a temp file and creates an upload from it for the system user.
  # Tempfile.create returns the value of its block, which is the created upload.
  def create_upload(content, filename)
    Tempfile.create(filename) do |file|
      file.write(content)
      file.rewind

      UploadCreator.new(file, filename).create_for(Discourse.system_user.id)
    end
  end

  # Each embedding request is answered with a 1024-element vector filled with the next
  # integer (1, 2, 3, ...). Ever-increasing embeddings give fully consistent search
  # results.
  def stub_embeddings
    @counter = 0
    stub_request(:post, "http://test.com/api/v1/classify").to_return(
      status: 200,
      body: lambda { |_req| ([@counter += 1] * 1024).to_json },
      headers: {
      },
    )
  end

  it "allows search within uploads" do
    stub_embeddings

    upload1 = create_upload(<<~TXT, "test.txt")
      1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
    TXT

    upload2 = create_upload(<<~TXT, "test.txt")
      30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
    TXT

    tool = create_tool(rag_chunk_tokens: 10, rag_chunk_overlap_tokens: 4, script: <<~JS)
      function invoke(params) {
        let result1 = index.search("testing a search", { limit: 1 });
        let result2 = index.search("testing another search", { limit: 3, filenames: ["test.txt"] });

        return [result1, result2];
      }
    JS

    RagDocumentFragment.link_target_and_uploads(tool, [upload1.id, upload2.id])

    result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke

    expected = [
      [{ "fragment" => "44 45 46 47 48 49 50", "metadata" => nil }],
      [
        { "fragment" => "44 45 46 47 48 49 50", "metadata" => nil },
        { "fragment" => "36 37 38 39 40 41 42 43 44 45", "metadata" => nil },
        { "fragment" => "30 31 32 33 34 35 36 37", "metadata" => nil },
      ],
    ]

    expect(result).to eq(expected)
  end
end
196273
end

0 commit comments

Comments
 (0)