
Commit 593d9ed

feat: add comprehensive prompt caching support for Groq provider
- Enable supportsPromptCache flag for all Groq models with 80% discount pricing
- Add groqUsePromptCache setting to enable/disable caching
- Implement GroqCacheStrategy for optimal message formatting
- Override createMessage to handle multiple cache token field names
- Add conversation cache state management
- Add comprehensive test coverage for caching functionality

Similar to Cline PR #5697 but adapted for Groq automatic prefix caching
1 parent 34abaf0 commit 593d9ed
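The "multiple cache token field names" item above is the subtle part: Groq's streaming usage payload may report cached prompt tokens under more than one key, and the tests further down exercise three variants. A minimal sketch of that normalization, assuming only the field names from the tests; the interface and function names here are hypothetical illustrations, not the actual GroqHandler code:

// Hypothetical sketch of the field-name normalization described in the
// commit message; the real createMessage override may differ.
interface PromptTokensDetailsSketch {
	cached_tokens?: number
	cache_read_input_tokens?: number
	cache_tokens?: number
}

function extractCachedTokens(details?: PromptTokensDetailsSketch): number {
	// Check each known field-name variant in turn and fall back to 0
	// when no cached-token count is reported.
	return details?.cached_tokens ?? details?.cache_read_input_tokens ?? details?.cache_tokens ?? 0
}

Per the usage assertions in the test suite, the handler then subtracts the cached count from prompt_tokens to report inputTokens and surfaces the cached count as cacheReadTokens.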

File tree

Showing 5 changed files with 705 additions and 2 deletions.


packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
@@ -266,6 +266,7 @@ const xaiSchema = apiModelIdProviderModelSchema.extend({
 
 const groqSchema = apiModelIdProviderModelSchema.extend({
 	groqApiKey: z.string().optional(),
+	groqUsePromptCache: z.boolean().optional(),
 })
 
 const huggingFaceSchema = baseProviderSettingsSchema.extend({
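Since both fields in the extended schema are optional, settings persisted before this change keep validating. A quick zod sketch under that assumption; groqSettingsSketch is a hypothetical standalone stand-in for the real groqSchema, which extends apiModelIdProviderModelSchema rather than a bare object:

import { z } from "zod"

// Hypothetical stand-in for groqSchema, reduced to the two Groq fields.
const groqSettingsSketch = z.object({
	groqApiKey: z.string().optional(),
	groqUsePromptCache: z.boolean().optional(),
})

// Older configs without the flag still parse; the flag comes back
// undefined, which callers can treat as caching disabled.
groqSettingsSketch.parse({ groqApiKey: "key" })
groqSettingsSketch.parse({ groqApiKey: "key", groqUsePromptCache: true })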

src/api/providers/__tests__/groq.spec.ts

Lines changed: 279 additions & 0 deletions
@@ -320,5 +320,284 @@ describe("GroqHandler", () => {
 			cacheWriteTokens: 0,
 			cacheReadTokens: 0, // Default to 0 when not provided
 		})
+
+	describe("Prompt Caching", () => {
+		it("should use caching strategy when groqUsePromptCache is enabled", async () => {
+			const handlerWithCache = new GroqHandler({
+				groqApiKey: "test-groq-api-key",
+				groqUsePromptCache: true,
+			})
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "Test system prompt for caching"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{ role: "user", content: "First message" },
+				{ role: "assistant", content: "First response" },
+				{ role: "user", content: "Second message" },
+			]
+
+			const messageGenerator = handlerWithCache.createMessage(systemPrompt, messages)
+			await messageGenerator.next()
+
+			// Verify that the messages were formatted with the system prompt
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					messages: expect.arrayContaining([
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "First message" },
+						{ role: "assistant", content: "First response" },
+						{ role: "user", content: "Second message" },
+					]),
+				}),
+				undefined,
+			)
+		})
+
+		it("should not use caching strategy when groqUsePromptCache is disabled", async () => {
+			const handlerWithoutCache = new GroqHandler({
+				groqApiKey: "test-groq-api-key",
+				groqUsePromptCache: false,
+			})
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "Test system prompt without caching"
+			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
+
+			const messageGenerator = handlerWithoutCache.createMessage(systemPrompt, messages)
+			await messageGenerator.next()
+
+			// Verify standard formatting is used
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					messages: expect.arrayContaining([
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Test message" },
+					]),
+				}),
+				undefined,
+			)
+		})
+
+		it("should handle multiple cache read token field names", async () => {
+			const testContent = "Test content"
+
+			// Test different field names that Groq might use for cached tokens
+			const cacheFieldVariations = [
+				{ cached_tokens: 30 },
+				{ cache_read_input_tokens: 40 },
+				{ cache_tokens: 50 },
+			]
+
+			for (const cacheFields of cacheFieldVariations) {
+				vitest.clearAllMocks()
+
+				mockCreate.mockImplementationOnce(() => {
+					return {
+						[Symbol.asyncIterator]: () => ({
+							next: vitest
+								.fn()
+								.mockResolvedValueOnce({
+									done: false,
+									value: { choices: [{ delta: { content: testContent } }] },
+								})
+								.mockResolvedValueOnce({
+									done: false,
+									value: {
+										choices: [{ delta: {} }],
+										usage: {
+											prompt_tokens: 100,
+											completion_tokens: 20,
+											prompt_tokens_details: cacheFields,
+										},
+									},
+								})
+								.mockResolvedValueOnce({ done: true }),
+						}),
+					}
+				})
+
+				const stream = handler.createMessage("system prompt", [])
+				const chunks = []
+				for await (const chunk of stream) {
+					chunks.push(chunk)
+				}
+
+				// Get the expected cached tokens value
+				const expectedCachedTokens = Object.values(cacheFields)[0]
+
+				// Should properly extract cached tokens from any of the field names
+				expect(chunks[1]).toEqual({
+					type: "usage",
+					inputTokens: 100 - expectedCachedTokens,
+					outputTokens: 20,
+					cacheWriteTokens: 0,
+					cacheReadTokens: expectedCachedTokens,
+				})
+			}
+		})
+
+		it("should maintain conversation cache state across multiple messages", async () => {
+			const handlerWithCache = new GroqHandler({
+				groqApiKey: "test-groq-api-key",
+				groqUsePromptCache: true,
+			})
+
+			mockCreate.mockImplementation(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "System prompt for conversation"
+			const firstMessages: Anthropic.Messages.MessageParam[] = [
+				{ role: "user", content: "First user message" },
+			]
+
+			// First call
+			const firstGenerator = handlerWithCache.createMessage(systemPrompt, firstMessages)
+			await firstGenerator.next()
+
+			// Add more messages for second call
+			const secondMessages: Anthropic.Messages.MessageParam[] = [
+				...firstMessages,
+				{ role: "assistant", content: "First assistant response" },
+				{ role: "user", content: "Second user message" },
+			]
+
+			// Second call with extended conversation
+			const secondGenerator = handlerWithCache.createMessage(systemPrompt, secondMessages)
+			await secondGenerator.next()
+
+			// Both calls should have been made
+			expect(mockCreate).toHaveBeenCalledTimes(2)
+
+			// Verify the second call has all messages
+			const secondCallArgs = mockCreate.mock.calls[1][0]
+			expect(secondCallArgs.messages).toHaveLength(4) // system + 3 messages
+		})
+
+		it("should handle complex message content with caching", async () => {
+			const handlerWithCache = new GroqHandler({
+				groqApiKey: "test-groq-api-key",
+				groqUsePromptCache: true,
+			})
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "System prompt"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [
+						{ type: "text", text: "Part 1" },
+						{ type: "text", text: "Part 2" },
+					],
+				},
+				{
+					role: "assistant",
+					content: [
+						{ type: "text", text: "Response part 1" },
+						{ type: "text", text: "Response part 2" },
+					],
+				},
+			]
+
+			const messageGenerator = handlerWithCache.createMessage(systemPrompt, messages)
+			await messageGenerator.next()
+
+			// Verify that complex content is properly converted
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					messages: expect.arrayContaining([
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Part 1\nPart 2" },
+						{ role: "assistant", content: "Response part 1\nResponse part 2" },
+					]),
+				}),
+				undefined,
+			)
+		})
+
+		it("should respect model's supportsPromptCache flag", async () => {
+			// Mock the getModel method to return a model without cache support
+			const modelId: GroqModelId = "llama-3.1-8b-instant"
+
+			const handlerWithCache = new GroqHandler({
+				apiModelId: modelId,
+				groqApiKey: "test-groq-api-key",
+				groqUsePromptCache: true, // Enabled but we'll mock the model to not support it
+			})
+
+			// Override getModel to return a model without cache support
+			const originalGetModel = handlerWithCache.getModel.bind(handlerWithCache)
+			handlerWithCache.getModel = () => {
+				const model = originalGetModel()
+				return {
+					...model,
+					info: {
+						...model.info,
+						supportsPromptCache: false, // Override to false for this test
+					},
+				}
+			}
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "Test system prompt"
+			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
+
+			const messageGenerator = handlerWithCache.createMessage(systemPrompt, messages)
+			await messageGenerator.next()
+
+			// Should use standard formatting when model doesn't support caching
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					messages: expect.arrayContaining([
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Test message" },
+					]),
+				}),
+				undefined,
+			)
+		})
+	})
 	})
 })
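Taken together, the last two tests pin down a two-part gate: the caching strategy is used only when the groqUsePromptCache setting is on and the selected model's supportsPromptCache flag is true; otherwise the handler falls back to standard formatting. A minimal sketch of that gate, with hypothetical option and model-info shapes; the actual createMessage override and GroqCacheStrategy may be organized differently:

// Hypothetical shapes for illustration; the real handler options and
// model info types live elsewhere in the codebase.
interface GroqOptionsSketch {
	groqUsePromptCache?: boolean
}

interface ModelInfoSketch {
	supportsPromptCache: boolean
}

// Both the user opt-in and the model capability must hold; either one
// missing sends the request down the standard formatting path.
function shouldUseCacheStrategy(options: GroqOptionsSketch, info: ModelInfoSketch): boolean {
	return options.groqUsePromptCache === true && info.supportsPromptCache
}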
