Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 269f169

Browse files
authored
DEV: add API to get topic info from a custom tool (#1197)
Previously we could only get post info
1 parent 7d7c169 commit 269f169

File tree

2 files changed

+77
-15
lines changed

2 files changed

+77
-15
lines changed

lib/ai_bot/tool_runner.rb

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def framework_script
7777
return _discourse_search(params);
7878
},
7979
getPost: _discourse_get_post,
80+
getTopic: _discourse_get_topic,
8081
getUser: _discourse_get_user,
8182
getPersona: function(name) {
8283
return {
@@ -276,7 +277,28 @@ def attach_discourse(mini_racer_context)
276277
post = Post.find_by(id: post_id)
277278
return nil if post.nil?
278279
guardian = Guardian.new(Discourse.system_user)
279-
recursive_as_json(PostSerializer.new(post, scope: guardian, root: false))
280+
obj =
281+
recursive_as_json(
282+
PostSerializer.new(post, scope: guardian, root: false, add_raw: true),
283+
)
284+
topic_obj =
285+
recursive_as_json(
286+
ListableTopicSerializer.new(post.topic, scope: guardian, root: false),
287+
)
288+
obj["topic"] = topic_obj
289+
obj
290+
end
291+
end,
292+
)
293+
294+
mini_racer_context.attach(
295+
"_discourse_get_topic",
296+
->(topic_id) do
297+
in_attached_function do
298+
topic = Topic.find_by(id: topic_id)
299+
return nil if topic.nil?
300+
guardian = Guardian.new(Discourse.system_user)
301+
recursive_as_json(ListableTopicSerializer.new(topic, scope: guardian, root: false))
280302
end
281303
end,
282304
)

spec/models/ai_tool_spec.rb

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
RSpec.describe AiTool do
44
fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") }
55
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
6+
fab!(:topic)
7+
fab!(:post) { Fabricate(:post, topic: topic, raw: "bananas are a tasty fruit") }
68

79
def create_tool(
810
parameters: nil,
@@ -329,36 +331,74 @@ def stub_embeddings
329331
end
330332
end
331333

334+
context "when using the topic API" do
335+
it "can fetch topic details" do
336+
script = <<~JS
337+
function invoke(params) {
338+
return discourse.getTopic(params.topic_id);
339+
}
340+
JS
341+
342+
tool = create_tool(script: script)
343+
runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil, context: {})
344+
345+
result = runner.invoke
346+
347+
expect(result["id"]).to eq(topic.id)
348+
expect(result["title"]).to eq(topic.title)
349+
expect(result["archetype"]).to eq("regular")
350+
expect(result["posts_count"]).to eq(1)
351+
end
352+
end
353+
354+
context "when using the post API" do
355+
it "can fetch post details" do
356+
script = <<~JS
357+
function invoke(params) {
358+
const post = discourse.getPost(params.post_id);
359+
return {
360+
post: post,
361+
topic: post.topic
362+
}
363+
}
364+
JS
365+
366+
tool = create_tool(script: script)
367+
runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil, context: {})
368+
369+
result = runner.invoke
370+
post_hash = result["post"]
371+
topic_hash = result["topic"]
372+
373+
expect(post_hash["id"]).to eq(post.id)
374+
expect(post_hash["topic_id"]).to eq(topic.id)
375+
expect(post_hash["raw"]).to eq(post.raw)
376+
377+
expect(topic_hash["id"]).to eq(topic.id)
378+
end
379+
end
380+
332381
context "when using the search API" do
333382
before { SearchIndexer.enable }
334383
after { SearchIndexer.disable }
335384

336385
it "can perform a discourse search" do
337-
# Create a new topic
338-
topic = Fabricate(:topic, title: "Test Search Topic", category: Fabricate(:category))
339-
post = Fabricate(:post, topic: topic, raw: "This is a test post content, banana")
340-
341-
# Ensure the topic is indexed
342386
SearchIndexer.index(topic, force: true)
343387
SearchIndexer.index(post, force: true)
344388

345-
# Define the tool script
346389
script = <<~JS
347-
function invoke(params) {
348-
return discourse.search({ search_query: params.query });
349-
}
350-
JS
390+
function invoke(params) {
391+
return discourse.search({ search_query: params.query });
392+
}
393+
JS
351394

352-
# Create the tool and runner
353395
tool = create_tool(script: script)
354396
runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {})
355397

356-
# Invoke the tool and get the results
357398
result = runner.invoke
358399

359-
# Verify the topic is found
360400
expect(result["rows"].length).to be > 0
361-
expect(result["rows"].first["title"]).to eq("Test Search Topic")
401+
expect(result["rows"].first["title"]).to eq(topic.title)
362402
end
363403
end
364404
end

0 commit comments

Comments
 (0)