Skip to content

Commit 111abdb

Browse files
committed
Enhance prompt button for openrouter
1 parent 1581ed1 commit 111abdb

File tree

13 files changed

+703
-102
lines changed

13 files changed

+703
-102
lines changed

src/api/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ import { GeminiHandler } from "./providers/gemini"
1111
import { OpenAiNativeHandler } from "./providers/openai-native"
1212
import { ApiStream } from "./transform/stream"
1313

14+
export interface SingleCompletionHandler {
15+
completePrompt(prompt: string): Promise<string>
16+
}
17+
1418
export interface ApiHandler {
1519
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
1620
getModel(): { id: string; info: ModelInfo }

src/api/providers/__tests__/openrouter.test.ts

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ describe('OpenRouterHandler', () => {
5151
})
5252
})
5353

54+
test('getModel returns default model info when options are not provided', () => {
55+
const handler = new OpenRouterHandler({})
56+
const result = handler.getModel()
57+
58+
expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
59+
expect(result.info.supportsPromptCache).toBe(true)
60+
})
61+
5462
test('createMessage generates correct stream chunks', async () => {
5563
const handler = new OpenRouterHandler(mockOptions)
5664
const mockStream = {
@@ -118,4 +126,158 @@ describe('OpenRouterHandler', () => {
118126
stream: true
119127
}))
120128
})
129+
130+
test('createMessage with middle-out transform enabled', async () => {
131+
const handler = new OpenRouterHandler({
132+
...mockOptions,
133+
openRouterUseMiddleOutTransform: true
134+
})
135+
const mockStream = {
136+
async *[Symbol.asyncIterator]() {
137+
yield {
138+
id: 'test-id',
139+
choices: [{
140+
delta: {
141+
content: 'test response'
142+
}
143+
}]
144+
}
145+
}
146+
}
147+
148+
const mockCreate = jest.fn().mockResolvedValue(mockStream)
149+
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
150+
completions: { create: mockCreate }
151+
} as any
152+
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
153+
154+
await handler.createMessage('test', []).next()
155+
156+
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
157+
transforms: ['middle-out']
158+
}))
159+
})
160+
161+
test('createMessage with Claude model adds cache control', async () => {
162+
const handler = new OpenRouterHandler({
163+
...mockOptions,
164+
openRouterModelId: 'anthropic/claude-3.5-sonnet'
165+
})
166+
const mockStream = {
167+
async *[Symbol.asyncIterator]() {
168+
yield {
169+
id: 'test-id',
170+
choices: [{
171+
delta: {
172+
content: 'test response'
173+
}
174+
}]
175+
}
176+
}
177+
}
178+
179+
const mockCreate = jest.fn().mockResolvedValue(mockStream)
180+
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
181+
completions: { create: mockCreate }
182+
} as any
183+
;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
184+
185+
const messages: Anthropic.Messages.MessageParam[] = [
186+
{ role: 'user', content: 'message 1' },
187+
{ role: 'assistant', content: 'response 1' },
188+
{ role: 'user', content: 'message 2' }
189+
]
190+
191+
await handler.createMessage('test system', messages).next()
192+
193+
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
194+
messages: expect.arrayContaining([
195+
expect.objectContaining({
196+
role: 'system',
197+
content: expect.arrayContaining([
198+
expect.objectContaining({
199+
cache_control: { type: 'ephemeral' }
200+
})
201+
])
202+
})
203+
])
204+
}))
205+
})
206+
207+
test('createMessage handles API errors', async () => {
208+
const handler = new OpenRouterHandler(mockOptions)
209+
const mockStream = {
210+
async *[Symbol.asyncIterator]() {
211+
yield {
212+
error: {
213+
message: 'API Error',
214+
code: 500
215+
}
216+
}
217+
}
218+
}
219+
220+
const mockCreate = jest.fn().mockResolvedValue(mockStream)
221+
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
222+
completions: { create: mockCreate }
223+
} as any
224+
225+
const generator = handler.createMessage('test', [])
226+
await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
227+
})
228+
229+
test('completePrompt returns correct response', async () => {
230+
const handler = new OpenRouterHandler(mockOptions)
231+
const mockResponse = {
232+
choices: [{
233+
message: {
234+
content: 'test completion'
235+
}
236+
}]
237+
}
238+
239+
const mockCreate = jest.fn().mockResolvedValue(mockResponse)
240+
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
241+
completions: { create: mockCreate }
242+
} as any
243+
244+
const result = await handler.completePrompt('test prompt')
245+
246+
expect(result).toBe('test completion')
247+
expect(mockCreate).toHaveBeenCalledWith({
248+
model: mockOptions.openRouterModelId,
249+
messages: [{ role: 'user', content: 'test prompt' }],
250+
temperature: 0,
251+
stream: false
252+
})
253+
})
254+
255+
test('completePrompt handles API errors', async () => {
256+
const handler = new OpenRouterHandler(mockOptions)
257+
const mockError = {
258+
error: {
259+
message: 'API Error',
260+
code: 500
261+
}
262+
}
263+
264+
const mockCreate = jest.fn().mockResolvedValue(mockError)
265+
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
266+
completions: { create: mockCreate }
267+
} as any
268+
269+
await expect(handler.completePrompt('test prompt'))
270+
.rejects.toThrow('OpenRouter API Error 500: API Error')
271+
})
272+
273+
test('completePrompt handles unexpected errors', async () => {
274+
const handler = new OpenRouterHandler(mockOptions)
275+
const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
276+
;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
277+
completions: { create: mockCreate }
278+
} as any
279+
280+
await expect(handler.completePrompt('test prompt'))
281+
.rejects.toThrow('OpenRouter completion error: Unexpected error')
282+
})
121283
})

src/api/providers/openrouter.ts

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ import OpenAI from "openai"
44
import { ApiHandler } from "../"
55
import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
66
import { convertToOpenAiMessages } from "../transform/openai-format"
7-
import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
7+
import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
88
import delay from "delay"
99

1010
// Add custom interface for OpenRouter params
11-
interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming {
11+
type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
1212
transforms?: string[];
1313
}
1414

@@ -17,7 +17,12 @@ interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
1717
fullResponseText: string;
1818
}
1919

20-
export class OpenRouterHandler implements ApiHandler {
20+
// Interface for providers that support single completions
21+
export interface SingleCompletionHandler {
22+
completePrompt(prompt: string): Promise<string>
23+
}
24+
25+
export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
2126
private options: ApiHandlerOptions
2227
private client: OpenAI
2328

@@ -184,4 +189,28 @@ export class OpenRouterHandler implements ApiHandler {
184189
}
185190
return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
186191
}
192+
193+
async completePrompt(prompt: string): Promise<string> {
194+
try {
195+
const response = await this.client.chat.completions.create({
196+
model: this.getModel().id,
197+
messages: [{ role: "user", content: prompt }],
198+
temperature: 0,
199+
stream: false
200+
})
201+
202+
if ("error" in response) {
203+
const error = response.error as { message?: string; code?: number }
204+
throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
205+
}
206+
207+
const completion = response as OpenAI.Chat.ChatCompletion
208+
return completion.choices[0]?.message?.content || ""
209+
} catch (error) {
210+
if (error instanceof Error) {
211+
throw new Error(`OpenRouter completion error: ${error.message}`)
212+
}
213+
throw error
214+
}
215+
}
187216
}

src/core/Cline.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for"
88
import * as path from "path"
99
import { serializeError } from "serialize-error"
1010
import * as vscode from "vscode"
11-
import { ApiHandler, buildApiHandler } from "../api"
11+
import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
1212
import { ApiStream } from "../api/transform/stream"
1313
import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
1414
import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
@@ -49,6 +49,7 @@ import { truncateHalfConversation } from "./sliding-window"
4949
import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider"
5050
import { detectCodeOmission } from "../integrations/editor/detect-omission"
5151
import { BrowserSession } from "../services/browser/BrowserSession"
52+
import { OpenRouterHandler } from "../api/providers/openrouter"
5253

5354
const cwd =
5455
vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
@@ -126,6 +127,22 @@ export class Cline {
126127
}
127128
}
128129

130+
async enhancePrompt(promptText: string): Promise<string> {
131+
if (!promptText) {
132+
throw new Error("No prompt text provided")
133+
}
134+
135+
const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no bullet points): ${promptText}`
136+
137+
// Check if the API handler supports completePrompt
138+
if (this.api instanceof OpenRouterHandler) {
139+
return this.api.completePrompt(prompt)
140+
}
141+
142+
// Otherwise just return the prompt as is
143+
return prompt;
144+
}
145+
129146
// Storing task to disk for history
130147

131148
private async ensureTaskDirectoryExists(): Promise<string> {

src/core/webview/ClineProvider.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import { openMention } from "../mentions"
2323
import { getNonce } from "./getNonce"
2424
import { getUri } from "./getUri"
2525
import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
26+
import { enhancePrompt } from "../../utils/enhance-prompt"
2627

2728
/*
2829
https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -637,6 +638,26 @@ export class ClineProvider implements vscode.WebviewViewProvider {
637638
await this.updateGlobalState("writeDelayMs", message.value)
638639
await this.postStateToWebview()
639640
break
641+
case "enhancePrompt":
642+
if (message.text) {
643+
try {
644+
const { apiConfiguration } = await this.getState()
645+
const enhanceConfig = {
646+
...apiConfiguration,
647+
apiProvider: "openrouter" as const,
648+
openRouterModelId: "gpt-4o",
649+
}
650+
const enhancedPrompt = await enhancePrompt(enhanceConfig, message.text)
651+
await this.postMessageToWebview({
652+
type: "enhancedPrompt",
653+
text: enhancedPrompt
654+
})
655+
} catch (error) {
656+
console.error("Error enhancing prompt:", error)
657+
vscode.window.showErrorMessage("Failed to enhance prompt")
658+
}
659+
}
660+
break
640661
}
641662
},
642663
null,

src/shared/ExtensionMessage.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export interface ExtensionMessage {
1919
| "openRouterModels"
2020
| "openAiModels"
2121
| "mcpServers"
22+
| "enhancedPrompt"
2223
text?: string
2324
action?:
2425
| "chatButtonClicked"

src/shared/WebviewMessage.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ export interface WebviewMessage {
4343
| "fuzzyMatchThreshold"
4444
| "preferredLanguage"
4545
| "writeDelayMs"
46+
| "enhancePrompt"
47+
| "enhancedPrompt"
48+
| "draggedImages"
4649
text?: string
4750
disabled?: boolean
4851
askResponse?: ClineAskResponse
@@ -52,10 +55,10 @@ export interface WebviewMessage {
5255
value?: number
5356
commands?: string[]
5457
audioType?: AudioType
55-
// For toggleToolAutoApprove
5658
serverName?: string
5759
toolName?: string
5860
alwaysAllow?: boolean
61+
dataUrls?: string[]
5962
}
6063

6164
export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"

0 commit comments

Comments (0)