diff --git a/lib/messages/prune.ts b/lib/messages/prune.ts index 7361b74..f556a9e 100644 --- a/lib/messages/prune.ts +++ b/lib/messages/prune.ts @@ -1,8 +1,10 @@ import type { SessionState, WithParts } from "../state" import type { Logger } from "../logger" import type { PluginConfig } from "../config" -import { getLastUserMessage, extractParameterKey, buildToolIdList } from "./utils" import { loadPrompt } from "../prompt" +import { extractParameterKey, buildToolIdList } from "./utils" +import { getLastUserMessage } from "../shared-utils" +import { UserMessage } from "@opencode-ai/sdk" const PRUNED_TOOL_OUTPUT_REPLACEMENT = '[Output removed to save context - information superseded or no longer needed]' const NUDGE_STRING = loadPrompt("nudge") @@ -51,7 +53,7 @@ export const insertPruneToolContext = ( } const lastUserMessage = getLastUserMessage(messages) - if (!lastUserMessage || lastUserMessage.info.role !== 'user') { + if (!lastUserMessage) { return } @@ -72,10 +74,10 @@ export const insertPruneToolContext = ( sessionID: lastUserMessage.info.sessionID, role: "user", time: { created: Date.now() }, - agent: lastUserMessage.info.agent || "build", + agent: (lastUserMessage.info as UserMessage).agent || "build", model: { - providerID: lastUserMessage.info.model.providerID, - modelID: lastUserMessage.info.model.modelID + providerID: (lastUserMessage.info as UserMessage).model.providerID, + modelID: (lastUserMessage.info as UserMessage).model.modelID } }, parts: [ @@ -118,9 +120,6 @@ const pruneToolOutputs = ( if (part.state.status === 'completed') { part.state.output = PRUNED_TOOL_OUTPUT_REPLACEMENT } - // if (part.state.status === 'error') { - // part.state.error = PRUNED_TOOL_OUTPUT_REPLACEMENT - // } } } } diff --git a/lib/messages/utils.ts b/lib/messages/utils.ts index 9638e74..1ae2c7d 100644 --- a/lib/messages/utils.ts +++ b/lib/messages/utils.ts @@ -1,5 +1,3 @@ -import { UserMessage } from "@opencode-ai/sdk" -import { Logger } from "../logger" import type { WithParts } from "../state" /** @@ -73,38 +71,6 @@ export const extractParameterKey = (tool: string, parameters: any): string => { return paramStr.substring(0, 50) } -export const getLastUserMessage = ( - messages: WithParts[] -): WithParts | null => { - for (let i = messages.length - 1; i >= 0; i--) { - const msg = messages[i] - if (msg.info.role === 'user') { - return msg - } - } - return null -} - -export function getCurrentParams( - messages: WithParts[], - logger: Logger -): { - providerId: string | undefined, - modelId: string | undefined, - agent: string | undefined -} { - const userMsg = getLastUserMessage(messages) - if (!userMsg) { - logger.debug("No user message found when determining current params") - return { providerId: undefined, modelId: undefined, agent: undefined } - } - const agent: string = (userMsg.info as UserMessage).agent - const providerId: string | undefined = (userMsg.info as UserMessage).model.providerID - const modelId: string | undefined = (userMsg.info as UserMessage).model.modelID - - return { providerId, modelId, agent } -} - export function buildToolIdList(messages: WithParts[]): string[] { const toolIds: string[] = [] for (const msg of messages) { diff --git a/lib/shared-utils.ts b/lib/shared-utils.ts new file mode 100644 index 0000000..cdcfb5b --- /dev/null +++ b/lib/shared-utils.ts @@ -0,0 +1,13 @@ +import { WithParts } from "./state" + +export const getLastUserMessage = ( + messages: WithParts[] +): WithParts | null => { + for (let i = messages.length - 1; i >= 0; i--) { + const msg = messages[i] + if (msg.info.role === 'user') { + return msg + } + } + return null +} diff --git a/lib/state/state.ts b/lib/state/state.ts index 91e3f92..19dc854 100644 --- a/lib/state/state.ts +++ b/lib/state/state.ts @@ -1,8 +1,8 @@ import type { SessionState, ToolParameterEntry, WithParts } from "./types" import type { Logger } from "../logger" import { loadSessionState } from "./persistence" -import { getLastUserMessage } from "../messages/utils" -import { isSubAgentSession } from "../utils" +import { isSubAgentSession } from "./utils" +import { getLastUserMessage } from "../shared-utils" export const checkSession = async ( client: any, diff --git a/lib/state/utils.ts b/lib/state/utils.ts new file mode 100644 index 0000000..4cc10ce --- /dev/null +++ b/lib/state/utils.ts @@ -0,0 +1,8 @@ +export async function isSubAgentSession(client: any, sessionID: string): Promise { + try { + const result = await client.session.get({ path: { id: sessionID } }) + return !!result.data?.parentID + } catch (error: any) { + return false + } +} diff --git a/lib/strategies/deduplication.ts b/lib/strategies/deduplication.ts index 61cc484..eaa9798 100644 --- a/lib/strategies/deduplication.ts +++ b/lib/strategies/deduplication.ts @@ -1,8 +1,8 @@ import { PluginConfig } from "../config" import { Logger } from "../logger" import type { SessionState, WithParts } from "../state" -import { calculateTokensSaved } from "../utils" import { buildToolIdList } from "../messages/utils" +import { calculateTokensSaved } from "./utils" /** * Deduplication strategy - prunes older tool calls that have identical diff --git a/lib/strategies/on-idle.ts b/lib/strategies/on-idle.ts index 16698b3..50dbb2f 100644 --- a/lib/strategies/on-idle.ts +++ b/lib/strategies/on-idle.ts @@ -4,10 +4,9 @@ import type { Logger } from "../logger" import type { PluginConfig } from "../config" import { buildAnalysisPrompt } from "../prompt" import { selectModel, ModelInfo } from "../model-selector" -import { calculateTokensSaved } from "../utils" -import { getCurrentParams } from "../messages/utils" import { saveSessionState } from "../state/persistence" import { sendUnifiedNotification } from "../ui/notification" +import { calculateTokensSaved, getCurrentParams } from "./utils" export interface OnIdleResult { prunedCount: number diff --git a/lib/strategies/prune-tool.ts b/lib/strategies/prune-tool.ts index 6807070..a83694f 100644 --- a/lib/strategies/prune-tool.ts +++ b/lib/strategies/prune-tool.ts @@ -1,14 +1,14 @@ import { tool } from "@opencode-ai/plugin" import type { SessionState, ToolParameterEntry, WithParts } from "../state" import type { PluginConfig } from "../config" -import { getCurrentParams, buildToolIdList } from "../messages/utils" -import { calculateTokensSaved } from "../utils" +import { buildToolIdList } from "../messages/utils" import { PruneReason, sendUnifiedNotification } from "../ui/notification" -import { formatPruningResultForTool } from "../ui/display-utils" +import { formatPruningResultForTool } from "../ui/utils" import { ensureSessionInitialized } from "../state" import { saveSessionState } from "../state/persistence" import type { Logger } from "../logger" import { loadPrompt } from "../prompt" +import { calculateTokensSaved, getCurrentParams } from "./utils" /** Tool description loaded from prompts/tool.txt */ const TOOL_DESCRIPTION = loadPrompt("tool") diff --git a/lib/utils.ts b/lib/strategies/utils.ts similarity index 66% rename from lib/utils.ts rename to lib/strategies/utils.ts index 842b964..af18963 100644 --- a/lib/utils.ts +++ b/lib/strategies/utils.ts @@ -1,5 +1,28 @@ -import { WithParts } from "./state" +import { WithParts } from "../state" +import { UserMessage } from "@opencode-ai/sdk" +import { Logger } from "../logger" import { encode } from 'gpt-tokenizer' +import { getLastUserMessage } from "../shared-utils" + +export function getCurrentParams( + messages: WithParts[], + logger: Logger +): { + providerId: string | undefined, + modelId: string | undefined, + agent: string | undefined +} { + const userMsg = getLastUserMessage(messages) + if (!userMsg) { + logger.debug("No user message found when determining current params") + return { providerId: undefined, modelId: undefined, agent: undefined } + } + const agent: string = (userMsg.info as UserMessage).agent + const providerId: string | undefined = (userMsg.info as UserMessage).model.providerID + const modelId: string | undefined = (userMsg.info as UserMessage).model.modelID + + return { providerId, modelId, agent } +} /** * Estimates token counts for a batch of texts using gpt-tokenizer. @@ -47,19 +70,3 @@ export const calculateTokensSaved = ( return 0 } } - -export function formatTokenCount(tokens: number): string { - if (tokens >= 1000) { - return `${(tokens / 1000).toFixed(1)}K`.replace('.0K', 'K') + ' tokens' - } - return tokens.toString() + ' tokens' -} - -export async function isSubAgentSession(client: any, sessionID: string): Promise { - try { - const result = await client.session.get({ path: { id: sessionID } }) - return !!result.data?.parentID - } catch (error: any) { - return false - } -} diff --git a/lib/ui/notification.ts b/lib/ui/notification.ts index 00ad378..ead50ac 100644 --- a/lib/ui/notification.ts +++ b/lib/ui/notification.ts @@ -1,7 +1,6 @@ import type { Logger } from "../logger" import type { SessionState } from "../state" -import { formatTokenCount } from "../utils" -import { formatPrunedItemsList } from "./display-utils" +import { formatPrunedItemsList, formatTokenCount } from "./utils" import { ToolParameterEntry } from "../state" import { PluginConfig } from "../config" diff --git a/lib/ui/display-utils.ts b/lib/ui/utils.ts similarity index 92% rename from lib/ui/display-utils.ts rename to lib/ui/utils.ts index deb23a3..11335fa 100644 --- a/lib/ui/display-utils.ts +++ b/lib/ui/utils.ts @@ -1,6 +1,13 @@ import { ToolParameterEntry } from "../state" import { extractParameterKey } from "../messages/utils" +export function formatTokenCount(tokens: number): string { + if (tokens >= 1000) { + return `${(tokens / 1000).toFixed(1)}K`.replace('.0K', 'K') + ' tokens' + } + return tokens.toString() + ' tokens' +} + export function truncate(str: string, maxLen: number = 60): string { if (str.length <= maxLen) return str return str.slice(0, maxLen - 3) + '...'