diff --git a/lib/messages/inject.ts b/lib/messages/inject.ts index debfd83..bec4a0c 100644 --- a/lib/messages/inject.ts +++ b/lib/messages/inject.ts @@ -3,7 +3,7 @@ import type { Logger } from "../logger" import type { PluginConfig } from "../config" import { loadPrompt } from "../prompts" import { extractParameterKey, buildToolIdList, createSyntheticUserMessage } from "./utils" -import { getLastUserMessage } from "../shared-utils" +import { getLastUserMessage, findUserVariant } from "../shared-utils" const getNudgeString = (config: PluginConfig): string => { const discardEnabled = config.tools.discard.enabled @@ -125,5 +125,6 @@ export const insertPruneToolContext = ( if (!lastUserMessage) { return } - messages.push(createSyntheticUserMessage(lastUserMessage, prunableToolsContent)) + const variant = findUserVariant(messages) + messages.push(createSyntheticUserMessage(lastUserMessage, prunableToolsContent, variant)) } diff --git a/lib/messages/utils.ts b/lib/messages/utils.ts index 756c1de..499d677 100644 --- a/lib/messages/utils.ts +++ b/lib/messages/utils.ts @@ -6,7 +6,7 @@ import type { UserMessage } from "@opencode-ai/sdk" const SYNTHETIC_MESSAGE_ID = "msg_01234567890123456789012345" const SYNTHETIC_PART_ID = "prt_01234567890123456789012345" -export const createSyntheticUserMessage = (baseMessage: WithParts, content: string): WithParts => { +export const createSyntheticUserMessage = (baseMessage: WithParts, content: string, variant?: string): WithParts => { const userInfo = baseMessage.info as UserMessage return { info: { @@ -19,6 +19,8 @@ export const createSyntheticUserMessage = (baseMessage: WithParts, content: stri providerID: userInfo.model.providerID, modelID: userInfo.model.modelID, }, + // @opencode-ai/sdk doesn't yet ship a variant type + ...(variant !== undefined && { variant }), }, parts: [ { diff --git a/lib/shared-utils.ts b/lib/shared-utils.ts index ce3be56..7b016dc 100644 --- a/lib/shared-utils.ts +++ b/lib/shared-utils.ts @@ -13,3 +13,16 @@ export const getLastUserMessage = (messages: WithParts[]): WithParts | null => { } return null } + +export const findUserVariant = (messages: WithParts[]): string | undefined => { + for (let i = messages.length - 1; i >= 0; i--) { + const msg = messages[i] + if (msg.info.role === "user") { + const variant = (msg.info as any).variant + if (variant !== undefined && variant !== null) { + return variant + } + } + } + return undefined +}