diff --git a/packages/slack-bot/src/classifier/index.test.ts b/packages/slack-bot/src/classifier/index.test.ts new file mode 100644 index 0000000..b43cbea --- /dev/null +++ b/packages/slack-bot/src/classifier/index.test.ts @@ -0,0 +1,153 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Env, RepoConfig } from "../types"; + +const { + mockMessagesCreate, + mockGetAvailableRepos, + mockBuildRepoDescriptions, + mockGetReposByChannel, +} = vi.hoisted(() => ({ + mockMessagesCreate: vi.fn(), + mockGetAvailableRepos: vi.fn(), + mockBuildRepoDescriptions: vi.fn(), + mockGetReposByChannel: vi.fn(), +})); + +vi.mock("@anthropic-ai/sdk", () => ({ + default: vi.fn().mockImplementation(() => ({ + messages: { + create: mockMessagesCreate, + }, + })), +})); + +vi.mock("./repos", () => ({ + getAvailableRepos: mockGetAvailableRepos, + buildRepoDescriptions: mockBuildRepoDescriptions, + getReposByChannel: mockGetReposByChannel, +})); + +import { RepoClassifier } from "./index"; + +const TEST_REPOS: RepoConfig[] = [ + { + id: "acme/prod", + owner: "acme", + name: "prod", + fullName: "acme/prod", + displayName: "prod", + description: "Production worker", + defaultBranch: "main", + private: true, + aliases: ["production"], + keywords: ["worker", "slack"], + }, + { + id: "acme/web", + owner: "acme", + name: "web", + fullName: "acme/web", + displayName: "web", + description: "Web application", + defaultBranch: "main", + private: true, + aliases: ["frontend"], + keywords: ["react", "ui"], + }, +]; + +const TEST_ENV = { + ANTHROPIC_API_KEY: "test-api-key", + CLASSIFICATION_MODEL: "claude-haiku-4-5", +} as Env; + +describe("RepoClassifier", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGetAvailableRepos.mockResolvedValue(TEST_REPOS); + mockGetReposByChannel.mockResolvedValue([]); + mockBuildRepoDescriptions.mockResolvedValue("- acme/prod\n- acme/web"); + }); + + it("uses tool output when provider returns valid structured classification", async () => { + mockMessagesCreate.mockResolvedValue({ + content: [ + { + type: "tool_use", + id: "toolu_1", + name: "classify_repository", + input: { + repoId: "acme/prod", + confidence: "high", + reasoning: "The message explicitly mentions prod.", + alternatives: [], + }, + }, + ], + }); + + const classifier = new RepoClassifier(TEST_ENV); + const result = await classifier.classify("please fix prod slack alerts", undefined, "trace-1"); + + expect(result.repo?.fullName).toBe("acme/prod"); + expect(result.confidence).toBe("high"); + expect(result.needsClarification).toBe(false); + expect(mockMessagesCreate).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0, + tool_choice: expect.objectContaining({ + type: "tool", + name: "classify_repository", + }), + tools: [expect.objectContaining({ name: "classify_repository" })], + }) + ); + }); + + it("asks for clarification when tool payload is invalid", async () => { + mockMessagesCreate.mockResolvedValue({ + content: [ + { + type: "tool_use", + id: "toolu_2", + name: "classify_repository", + input: { + repoId: "acme/prod", + confidence: "certain", + reasoning: "Totally sure", + alternatives: [], + }, + }, + ], + }); + + const classifier = new RepoClassifier(TEST_ENV); + const result = await classifier.classify("please update prod deployment config"); + + expect(result.repo).toBeNull(); + expect(result.confidence).toBe("low"); + expect(result.needsClarification).toBe(true); + expect(result.reasoning).toContain("structured model output"); + expect(result.alternatives).toHaveLength(2); + }); + + it("asks for clarification when tool output is missing", async () => { + mockMessagesCreate.mockResolvedValue({ + content: [ + { + type: "text", + text: '{"repoId":"acme/web","confidence":"high","reasoning":"Mentions frontend and UI.","alternatives":[]}', + }, + ], + }); + + const classifier = new RepoClassifier(TEST_ENV); + const result = await classifier.classify("frontend UI issue in web app"); + + expect(result.repo).toBeNull(); + expect(result.confidence).toBe("low"); + expect(result.needsClarification).toBe(true); + expect(result.reasoning).toContain("structured model output"); + expect(result.alternatives).toHaveLength(2); + }); +}); diff --git a/packages/slack-bot/src/classifier/index.ts b/packages/slack-bot/src/classifier/index.ts index eba9e86..8f67ca9 100644 --- a/packages/slack-bot/src/classifier/index.ts +++ b/packages/slack-bot/src/classifier/index.ts @@ -11,6 +11,38 @@ import { getAvailableRepos, buildRepoDescriptions, getReposByChannel } from "./r import { createLogger } from "../logger"; const log = createLogger("classifier"); +const CLASSIFY_REPO_TOOL_NAME = "classify_repository"; +const CONFIDENCE_LEVELS: ClassificationResult["confidence"][] = ["high", "medium", "low"]; + +const CLASSIFY_REPO_TOOL: Anthropic.Messages.Tool = { + name: CLASSIFY_REPO_TOOL_NAME, + description: + "Classify which repository a Slack message refers to. Use repoId as null when uncertain.", + input_schema: { + type: "object", + properties: { + repoId: { + type: ["string", "null"], + description: "Repository ID/fullName if confident enough to choose one, otherwise null.", + }, + confidence: { + type: "string", + enum: CONFIDENCE_LEVELS, + }, + reasoning: { + type: "string", + description: "Brief explanation of classification decision.", + }, + alternatives: { + type: "array", + items: { type: "string" }, + description: "Alternative repository IDs/fullNames when confidence is not high.", + }, + }, + required: ["repoId", "confidence", "reasoning", "alternatives"], + additionalProperties: false, + }, +}; /** * Build the classification prompt for the LLM. @@ -63,21 +95,11 @@ Consider: ## Response Format -Respond with a JSON object (no markdown code blocks): -{ - "repoId": "owner/name" or null if unclear, - "confidence": "high" | "medium" | "low", - "reasoning": "Brief explanation of why you chose this repo", - "alternatives": ["owner/name", ...] // Other possible repos if confidence is not high -} - -If no repository matches or the message doesn't seem to be about code: -{ - "repoId": null, - "confidence": "low", - "reasoning": "Explanation of why no repo was identified", - "alternatives": [] -}`; +Return your decision by calling the ${CLASSIFY_REPO_TOOL_NAME} tool with: +- repoId: "owner/name" or null if unclear +- confidence: "high" | "medium" | "low" +- reasoning: brief explanation +- alternatives: other possible repos when confidence is not high`; } /** @@ -87,7 +109,65 @@ interface LLMResponse { repoId: string | null; confidence: "high" | "medium" | "low"; reasoning: string; - alternatives?: string[]; + alternatives: string[]; +} + +function normalizeModelResponse(raw: unknown): LLMResponse { + if (!raw || typeof raw !== "object" || Array.isArray(raw)) { + throw new Error("LLM response was not an object"); + } + + const input = raw as Record; + const rawRepoId = input.repoId; + const repoId = + rawRepoId === null + ? null + : typeof rawRepoId === "string" && rawRepoId.trim().length > 0 + ? rawRepoId.trim() + : null; + + const rawConfidence = typeof input.confidence === "string" ? input.confidence.trim() : ""; + const confidence = rawConfidence.toLowerCase(); + if (!CONFIDENCE_LEVELS.includes(confidence as ClassificationResult["confidence"])) { + throw new Error(`Invalid confidence value: ${rawConfidence || String(input.confidence)}`); + } + + if (typeof input.reasoning !== "string" || input.reasoning.trim().length === 0) { + throw new Error("Missing reasoning in LLM response"); + } + + if (!Array.isArray(input.alternatives)) { + throw new Error("Alternatives must be an array"); + } + + const alternatives = input.alternatives + .filter((value): value is string => typeof value === "string") + .map((value) => value.trim()) + .filter((value) => value.length > 0); + + if (alternatives.length !== input.alternatives.length) { + throw new Error("Invalid alternatives in LLM response"); + } + + return { + repoId, + confidence: confidence as ClassificationResult["confidence"], + reasoning: input.reasoning.trim(), + alternatives: [...new Set(alternatives)], + }; +} + +function extractStructuredResponse(response: Anthropic.Messages.Message): LLMResponse { + const toolUseBlock = response.content.find( + (block): block is Anthropic.Messages.ToolUseBlock => + block.type === "tool_use" && block.name === CLASSIFY_REPO_TOOL_NAME + ); + + if (!toolUseBlock) { + throw new Error("No structured tool_use classification in LLM response"); + } + + return normalizeModelResponse(toolUseBlock.input); } /** @@ -155,6 +235,13 @@ export class RepoClassifier { const response = await this.client.messages.create({ model: this.env.CLASSIFICATION_MODEL || "claude-haiku-4-5", max_tokens: 500, + temperature: 0, + tools: [CLASSIFY_REPO_TOOL], + tool_choice: { + type: "tool", + name: CLASSIFY_REPO_TOOL_NAME, + disable_parallel_tool_use: true, + }, messages: [ { role: "user", @@ -163,14 +250,7 @@ export class RepoClassifier { ], }); - // Extract text from response - const textContent = response.content.find((c) => c.type === "text"); - if (!textContent || textContent.type !== "text") { - throw new Error("No text response from LLM"); - } - - // Parse JSON response - const llmResult = JSON.parse(textContent.text) as LLMResponse; + const llmResult = extractStructuredResponse(response); // Find the matched repo let matchedRepo: RepoConfig | null = null; @@ -185,16 +265,14 @@ export class RepoClassifier { // Find alternative repos const alternatives: RepoConfig[] = []; - if (llmResult.alternatives) { - for (const altId of llmResult.alternatives) { - const altRepo = repos.find( - (r) => - r.id.toLowerCase() === altId.toLowerCase() || - r.fullName.toLowerCase() === altId.toLowerCase() - ); - if (altRepo && altRepo.id !== matchedRepo?.id) { - alternatives.push(altRepo); - } + for (const altId of llmResult.alternatives) { + const altRepo = repos.find( + (r) => + r.id.toLowerCase() === altId.toLowerCase() || + r.fullName.toLowerCase() === altId.toLowerCase() + ); + if (altRepo && altRepo.id !== matchedRepo?.id) { + alternatives.push(altRepo); } } @@ -217,85 +295,15 @@ export class RepoClassifier { channel_id: context?.channelId, }); - // Fallback: try simple keyword matching - return this.fallbackClassification(message, repos, context); - } - } - - /** - * Fallback classification using simple keyword matching. - */ - private fallbackClassification( - message: string, - repos: RepoConfig[], - context?: ThreadContext - ): ClassificationResult { - const messageLower = message.toLowerCase(); - - // Score each repo based on keyword matches - const scored = repos.map((repo) => { - let score = 0; - - // Check repo name - if (messageLower.includes(repo.name.toLowerCase())) { - score += 10; - } - - // Check owner - if (messageLower.includes(repo.owner.toLowerCase())) { - score += 5; - } - - // Check aliases - for (const alias of repo.aliases || []) { - if (messageLower.includes(alias.toLowerCase())) { - score += 8; - } - } - - // Check keywords - for (const keyword of repo.keywords || []) { - if (messageLower.includes(keyword.toLowerCase())) { - score += 3; - } - } - - // Check channel association - if (context?.channelId && repo.channelAssociations?.includes(context.channelId)) { - score += 15; - } - - return { repo, score }; - }); - - // Sort by score - scored.sort((a, b) => b.score - a.score); - - const topMatch = scored[0]; - const hasMatch = topMatch && topMatch.score > 0; - - if (!hasMatch) { return { repo: null, confidence: "low", - reasoning: "Could not determine repository from message content.", - alternatives: repos.slice(0, 3), + reasoning: + "Could not classify repository from structured model output. Please select a repository.", + alternatives: repos.slice(0, 5), needsClarification: true, }; } - - const confidence = topMatch.score >= 10 ? "high" : topMatch.score >= 5 ? "medium" : "low"; - - return { - repo: topMatch.repo, - confidence, - reasoning: `Matched based on keyword analysis (score: ${topMatch.score})`, - alternatives: scored - .slice(1, 4) - .filter((s) => s.score > 0) - .map((s) => s.repo), - needsClarification: confidence !== "high", - }; } }