Skip to content
This repository was archived by the owner on Jul 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions app/controllers/discourse_ai/ai_helper/assistant_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,26 @@ def suggest_title
end

def suggest_category
input = get_text_param!
input_hash = { text: input }
if params[:topic_id]
opts = { topic_id: params[:topic_id] }
else
input = get_text_param!
opts = { text: input }
end

render json:
DiscourseAi::AiHelper::SemanticCategorizer.new(
input_hash,
current_user,
).categories,
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).categories,
status: 200
end

def suggest_tags
input = get_text_param!
input_hash = { text: input }
if params[:topic_id]
opts = { topic_id: params[:topic_id] }
else
input = get_text_param!
opts = { text: input }
end

render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags,
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(current_user, opts).tags,
status: 200
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,13 @@ export default class AiCategorySuggester extends Component {
@tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles";
@tracked content = null;
@tracked topicContent = null;

constructor() {
super(...arguments);
if (!this.topicContent && this.args.composer?.reply === undefined) {
this.fetchTopicContent();
}
}

async fetchTopicContent() {
await ajax(`/t/${this.args.buffered.content.id}.json`).then(
({ post_stream }) => {
this.topicContent = post_stream.posts[0].cooked;
}
);
}

get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields");
this.content = this.args.composer?.reply || this.topicContent;
const showTrigger = this.content?.length > MIN_CHARACTER_COUNT;
this.content = this.args.composer?.reply;
const showTrigger =
this.content?.length > MIN_CHARACTER_COUNT ||
this.args.topicState === "edit";

if (composerFields) {
if (showTrigger) {
Expand All @@ -62,12 +48,20 @@ export default class AiCategorySuggester extends Component {
this.loading = true;
this.triggerIcon = "spinner";

const data = {};

if (this.content) {
data.text = this.content;
} else {
data.topic_id = this.args.buffered.content.id;
}

try {
const { assistant } = await ajax(
"/discourse-ai/ai-helper/suggest_category",
{
method: "POST",
data: { text: this.content },
data,
}
);
this.suggestions = assistant;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,13 @@ export default class AiTagSuggester extends Component {
@tracked untriggers = [];
@tracked triggerIcon = "discourse-sparkles";
@tracked content = null;
@tracked topicContent = null;

constructor() {
super(...arguments);
if (!this.topicContent && this.args.composer?.reply === undefined) {
this.fetchTopicContent();
}
}

async fetchTopicContent() {
await ajax(`/t/${this.args.buffered.content.id}.json`).then(
({ post_stream }) => {
this.topicContent = post_stream.posts[0].cooked;
}
);
}

get showSuggestionButton() {
const composerFields = document.querySelector(".composer-fields");
this.content = this.args.composer?.reply || this.topicContent;
const showTrigger = this.content?.length > MIN_CHARACTER_COUNT;
this.content = this.args.composer?.reply;
const showTrigger =
this.content?.length > MIN_CHARACTER_COUNT ||
this.args.topicState === "edit";

if (composerFields) {
if (showTrigger) {
Expand Down Expand Up @@ -74,15 +60,25 @@ export default class AiTagSuggester extends Component {
this.loading = true;
this.triggerIcon = "spinner";

const data = {};

if (this.content) {
data.text = this.content;
} else {
data.topic_id = this.args.buffered.content.id;
}

try {
const { assistant } = await ajax("/discourse-ai/ai-helper/suggest_tags", {
method: "POST",
data: { text: this.content },
data,
});
this.suggestions = assistant;

const model = this.args.composer
? this.args.composer
: this.args.buffered;

if (this.#tagSelectorHasValues()) {
this.suggestions = this.suggestions.filter(
(s) => !model.get("tags").includes(s.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
}

<template>
<AiCategorySuggester @composer={{@outletArgs.composer}} />
<AiCategorySuggester @composer={{@outletArgs.composer}} @topicState="new" />
</template>
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ export default class AiTagSuggestion extends Component {
}

<template>
<AiTagSuggester @composer={{@outletArgs.composer}} />
<AiTagSuggester @composer={{@outletArgs.composer}} @topicState="new" />
</template>
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ export default class AiCategorySuggestion extends Component {
}

<template>
<AiCategorySuggester @buffered={{@outletArgs.buffered}} />
<AiCategorySuggester
@buffered={{@outletArgs.buffered}}
@topicState="edit"
/>
</template>
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ export default class AiCategorySuggestion extends Component {
}

<template>
<AiTagSuggester @buffered={{@outletArgs.buffered}} />
<AiTagSuggester @buffered={{@outletArgs.buffered}} @topicState="edit" />
</template>
}
27 changes: 22 additions & 5 deletions lib/ai_helper/semantic_categorizer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
module DiscourseAi
module AiHelper
class SemanticCategorizer
def initialize(input, user)
def initialize(user, opts)
@user = user
@text = input[:text]
@text = opts[:text]
@vector = DiscourseAi::Embeddings::Vector.instance
@schema = DiscourseAi::Embeddings::Schema.for(Topic)
@topic_id = opts[:topic_id]
end

def categories
return [] if @text.blank?
return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled?

candidates = nearest_neighbors
Expand Down Expand Up @@ -55,7 +56,7 @@ def categories
end

def tags
return [] if @text.blank?
return [] if @text.blank? && @topic_id.nil?
return [] if !DiscourseAi::Embeddings.enabled?

candidates = nearest_neighbors(limit: 100)
Expand Down Expand Up @@ -100,7 +101,23 @@ def tags
private

def nearest_neighbors(limit: 50)
raw_vector = @vector.vector_from(@text)
if @topic_id
target = Topic.find_by(id: @topic_id)
embeddings = @schema.find_by_target(target)&.embeddings

if embeddings.blank?
@text =
DiscourseAi::Summarization::Strategies::TopicSummary
.new(target)
.targets_data
.pluck(:text)
raw_vector = @vector.vector_from(@text)
else
raw_vector = JSON.parse(embeddings)
end
else
raw_vector = @vector.vector_from(@text)
end

muted_category_ids = nil
if @user.present?
Expand Down
2 changes: 1 addition & 1 deletion spec/lib/modules/ai_helper/semantic_categorizer_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
fab!(:topic) { Fabricate(:topic, category: category) }

let(:vector) { DiscourseAi::Embeddings::Vector.instance }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new(user, { text: "hello" }) }
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }

before do
Expand Down