Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 04f25bb

Browse files
committed
Allow AI tool to call search directly
1 parent 53c086a commit 04f25bb

File tree

6 files changed

+432
-91
lines changed

6 files changed

+432
-91
lines changed

lib/ai_bot/tool_runner.rb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def framework_script
7373
};
7474
7575
const discourse = {
76+
search: function(params) {
77+
return _discourse_search(params);
78+
},
7679
getPost: _discourse_get_post,
7780
getUser: _discourse_get_user,
7881
getPersona: function(name) {
@@ -341,6 +344,21 @@ def attach_discourse(mini_racer_context)
341344
end
342345
end,
343346
)
347+
348+
mini_racer_context.attach(
349+
"_discourse_search",
350+
->(params) do
351+
in_attached_function do
352+
search_params = params.symbolize_keys
353+
if search_params.delete(:with_private)
354+
search_params[:current_user] = Discourse.system_user
355+
end
356+
search_params[:result_style] = :detailed
357+
results = DiscourseAi::Utils::Search.perform_search(**search_params)
358+
recursive_as_json(results)
359+
end
360+
end,
361+
)
344362
end
345363

346364
def attach_upload(mini_racer_context)

lib/ai_bot/tools/search.rb

Lines changed: 25 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def signature
3434
enum: %w[latest latest_topic oldest views likes],
3535
},
3636
{
37-
name: "limit",
37+
name: "max_results",
3838
description:
3939
"limit number of results returned (generally prefer to just keep to default)",
4040
type: "integer",
@@ -103,102 +103,38 @@ def search_query
103103

104104
def invoke
105105
search_terms = []
106-
107106
search_terms << options[:base_query] if options[:base_query].present?
108-
search_terms << search_query.strip if search_query.present?
107+
search_terms << search_query if search_query.present?
109108
search_args.each { |key, value| search_terms << "#{key}:#{value}" if value.present? }
110109

111-
guardian = nil
112-
if options[:search_private] && context[:user]
113-
guardian = Guardian.new(context[:user])
114-
else
115-
guardian = Guardian.new
116-
search_terms << "status:public"
117-
end
118-
119-
search_string = search_terms.join(" ").to_s
120-
@last_query = search_string
121-
122-
yield(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
110+
@last_query = search_terms.join(" ").to_s
123111

124-
results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian)
112+
yield(I18n.t("discourse_ai.ai_bot.searching", query: @last_query))
125113

126114
max_results = calculate_max_results(llm)
127-
results_limit = parameters[:limit] || max_results
128-
results_limit = max_results if parameters[:limit].to_i > max_results
129-
130-
should_try_semantic_search =
131-
SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present?
132-
133-
max_semantic_results = max_results / 4
134-
results_limit = results_limit - max_semantic_results if should_try_semantic_search
135-
136-
posts = results&.posts || []
137-
posts = posts[0..results_limit.to_i - 1]
138-
139-
if should_try_semantic_search
140-
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
141-
topic_ids = Set.new(posts.map(&:topic_id))
142-
143-
search = ::Search.new(search_string, guardian: guardian)
144-
145-
results = nil
146-
begin
147-
results = semantic_search.search_for_topics(search.term)
148-
rescue => e
149-
Discourse.warn_exception(e, message: "Semantic search failed")
150-
end
151-
152-
if results
153-
results = search.apply_filters(results)
154-
155-
results.each do |post|
156-
next if topic_ids.include?(post.topic_id)
157-
158-
topic_ids << post.topic_id
159-
posts << post
160-
161-
break if posts.length >= max_results
162-
end
163-
end
115+
if parameters[:max_results].to_i > 0
116+
max_results = [parameters[:max_results].to_i, max_results].min
164117
end
165118

166-
@last_num_results = posts.length
167-
# this is the general pattern from core
168-
# if there are millions of hidden tags it may fail
169-
hidden_tags = nil
170-
171-
if posts.blank?
172-
{ args: parameters, rows: [], instruction: "nothing was found, expand your search" }
173-
else
174-
format_results(posts, args: parameters) do |post|
175-
category_names = [
176-
post.topic.category&.parent_category&.name,
177-
post.topic.category&.name,
178-
].compact.join(" > ")
179-
row = {
180-
title: post.topic.title,
181-
url: Discourse.base_path + post.url,
182-
username: post.user&.username,
183-
excerpt: post.excerpt,
184-
created: post.created_at,
185-
category: category_names,
186-
likes: post.like_count,
187-
topic_views: post.topic.views,
188-
topic_likes: post.topic.like_count,
189-
topic_replies: post.topic.posts_count - 1,
190-
}
191-
192-
if SiteSetting.tagging_enabled
193-
hidden_tags ||= DiscourseTagging.hidden_tag_names
194-
# using map over pluck to avoid n+1 (assuming caller preloading)
195-
tags = post.topic.tags.map(&:name) - hidden_tags
196-
row[:tags] = tags.join(", ") if tags.present?
197-
end
198-
199-
row
200-
end
201-
end
119+
search_query_with_base = [options[:base_query], search_query].compact.join(" ").strip
120+
121+
results =
122+
DiscourseAi::Utils::Search.perform_search(
123+
search_query: search_query_with_base,
124+
category: parameters[:category],
125+
user: parameters[:user],
126+
order: parameters[:order],
127+
max_posts: parameters[:max_posts],
128+
tags: parameters[:tags],
129+
before: parameters[:before],
130+
after: parameters[:after],
131+
status: parameters[:status],
132+
max_results: max_results,
133+
current_user: options[:search_private] ? context[:user] : nil,
134+
)
135+
136+
@last_num_results = results[:rows]&.length || 0
137+
results
202138
end
203139

204140
protected

lib/utils/search.rb

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
4+
module Utils
5+
class Search
6+
def self.perform_search(
7+
search_query: nil,
8+
category: nil,
9+
user: nil,
10+
order: nil,
11+
max_posts: nil,
12+
tags: nil,
13+
before: nil,
14+
after: nil,
15+
status: nil,
16+
hyde: true,
17+
max_results: 20,
18+
current_user: nil,
19+
result_style: :compact
20+
)
21+
search_terms = []
22+
23+
search_terms << search_query.strip if search_query.present?
24+
search_terms << "category:#{category}" if category.present?
25+
search_terms << "user:#{user}" if user.present?
26+
search_terms << "order:#{order}" if order.present?
27+
search_terms << "max_posts:#{max_posts}" if max_posts.present?
28+
search_terms << "tags:#{tags}" if tags.present?
29+
search_terms << "before:#{before}" if before.present?
30+
search_terms << "after:#{after}" if after.present?
31+
search_terms << "status:#{status}" if status.present?
32+
33+
guardian = Guardian.new(current_user)
34+
35+
search_string = search_terms.join(" ").to_s
36+
37+
results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian)
38+
results_limit = max_results
39+
40+
should_try_semantic_search =
41+
SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present?
42+
43+
max_semantic_results = max_results / 4
44+
results_limit = results_limit - max_semantic_results if should_try_semantic_search
45+
46+
posts = results&.posts || []
47+
posts = posts[0..results_limit.to_i - 1]
48+
49+
if should_try_semantic_search
50+
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
51+
topic_ids = Set.new(posts.map(&:topic_id))
52+
53+
search = ::Search.new(search_string, guardian: guardian)
54+
55+
semantic_results = nil
56+
begin
57+
semantic_results = semantic_search.search_for_topics(search.term, hyde: hyde)
58+
rescue => e
59+
Discourse.warn_exception(e, message: "Semantic search failed")
60+
end
61+
62+
if semantic_results
63+
semantic_results = search.apply_filters(semantic_results)
64+
65+
semantic_results.each do |post|
66+
next if topic_ids.include?(post.topic_id)
67+
68+
topic_ids << post.topic_id
69+
posts << post
70+
71+
break if posts.length >= max_results
72+
end
73+
end
74+
end
75+
76+
hidden_tags = nil
77+
78+
# Construct search_args hash for consistent return format
79+
search_args = {
80+
search_query: search_query,
81+
category: category,
82+
user: user,
83+
order: order,
84+
max_posts: max_posts,
85+
tags: tags,
86+
before: before,
87+
after: after,
88+
status: status,
89+
max_results: max_results,
90+
}.compact
91+
92+
if posts.blank?
93+
{ args: search_args, rows: [], instruction: "nothing was found, expand your search" }
94+
else
95+
format_results(posts, args: search_args, result_style: result_style) do |post|
96+
category_names = [
97+
post.topic.category&.parent_category&.name,
98+
post.topic.category&.name,
99+
].compact.join(" > ")
100+
row = {
101+
title: post.topic.title,
102+
url: Discourse.base_path + post.url,
103+
username: post.user&.username,
104+
excerpt: post.excerpt,
105+
created: post.created_at,
106+
category: category_names,
107+
likes: post.like_count,
108+
topic_views: post.topic.views,
109+
topic_likes: post.topic.like_count,
110+
topic_replies: post.topic.posts_count - 1,
111+
}
112+
113+
if SiteSetting.tagging_enabled
114+
hidden_tags ||= DiscourseTagging.hidden_tag_names
115+
tags = post.topic.tags.map(&:name) - hidden_tags
116+
row[:tags] = tags.join(", ") if tags.present?
117+
end
118+
119+
row
120+
end
121+
end
122+
end
123+
124+
private
125+
126+
def self.format_results(rows, args: nil, result_style:)
127+
rows = rows&.map { |row| yield row } if block_given?
128+
129+
if result_style == :compact
130+
index = -1
131+
column_indexes = {}
132+
133+
rows =
134+
rows&.map do |data|
135+
new_row = []
136+
data.each do |key, value|
137+
found_index = column_indexes[key.to_s] ||= (index += 1)
138+
new_row[found_index] = value
139+
end
140+
new_row
141+
end
142+
column_names = column_indexes.keys
143+
end
144+
145+
result = { column_names: column_names, rows: rows }
146+
result[:args] = args if args
147+
result
148+
end
149+
end
150+
end
151+
end

spec/lib/modules/ai_bot/tools/search_spec.rb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898

9999
results = search.invoke(&progress_blk)
100100

101-
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" })
101+
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake", max_results: 60 })
102102
expect(results[:rows]).to eq([])
103103
end
104104

@@ -131,7 +131,9 @@
131131
search.invoke(&progress_blk)
132132
end
133133

134-
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
134+
expect(results[:args]).to eq(
135+
{ max_results: 60, search_query: "hello world, sam", status: "public" },
136+
)
135137
expect(results[:rows].length).to eq(1)
136138

137139
# it also works with no query
@@ -174,6 +176,7 @@
174176
[param[:name], "test"]
175177
end
176178
end
179+
.compact
177180
.to_h
178181
.symbolize_keys
179182

0 commit comments

Comments
 (0)