Commit 85405c2

feat: support "CachedInputToken" type in "llmRequestCosts" (#1315)
**Description**

Many AI providers now support prompt caching on the provider side, and cached tokens are priced significantly below regular token processing; OpenAI, for example, lists cached tokens at roughly a tenth of the regular token price [1]. Envoy AI Gateway should therefore take the cached token count into account when calculating `llmRequestCosts` on an `AIGatewayRoute`. For self-hosted LLMs, cached tokens can also drastically reduce GPU usage, so users in that situation likewise want `llmRequestCosts` to reflect cached token usage.

1: https://openai.com/api/pricing/

---------

Signed-off-by: Shingo Omura <[email protected]>
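The discount this enables can be expressed in a CEL cost expression via the `cached_input_tokens` variable added in `api/v1alpha1/shared_types.go` below. As a minimal illustration of the arithmetic only — this sketch is not part of the gateway code, and the 10% cached-token rate and token counts are hypothetical:

```go
package main

import "fmt"

// effectiveCost mirrors the CEL example added by this commit: uncached input
// tokens and output tokens are billed at full rate, while cached input tokens
// are billed at a hypothetical 10% rate.
func effectiveCost(inputTokens, cachedInputTokens, outputTokens uint32) float64 {
	uncached := float64(inputTokens - cachedInputTokens)
	return uncached + float64(cachedInputTokens)*0.1 + float64(outputTokens)
}

func main() {
	// 1000 input tokens, 800 of which hit the provider-side prompt cache,
	// plus 200 output tokens: 200 + 80 + 200 = 480.
	fmt.Println(effectiveCost(1000, 800, 200))
}
```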
1 parent 90ba3e9 commit 85405c2

33 files changed (+223, -106 lines)

api/v1alpha1/ai_gateway_route.go

Lines changed: 2 additions & 0 deletions
@@ -103,6 +103,8 @@ type AIGatewayRouteSpec struct {
     //   type: OutputToken
     // - metadataKey: llm_total_token
     //   type: TotalToken
+    // - metadataKey: llm_cached_input_token
+    //   type: CachedInputToken
     // ```
     // Then, with the following BackendTrafficPolicy of Envoy Gateway, you can have three
     // rate limit buckets for each unique x-user-id header value. One bucket is for the input token,

api/v1alpha1/shared_types.go

Lines changed: 5 additions & 1 deletion
@@ -80,7 +80,7 @@ type LLMRequestCost struct {
     // and it uses "output token" as the cost. The other types are "InputToken", "TotalToken",
     // and "CEL".
     //
-    // +kubebuilder:validation:Enum=OutputToken;InputToken;TotalToken;CEL
+    // +kubebuilder:validation:Enum=OutputToken;InputToken;CachedInputToken;TotalToken;CEL
     Type LLMRequestCostType `json:"type"`
     // CEL is the CEL expression to calculate the cost of the request.
     // The CEL expression must return a signed or unsigned integer. If the
@@ -91,13 +91,15 @@ type LLMRequestCost struct {
     // * model: the model name extracted from the request content. Type: string.
     // * backend: the backend name in the form of "name.namespace". Type: string.
     // * input_tokens: the number of input tokens. Type: unsigned integer.
+    // * cached_input_tokens: the number of cached input tokens. Type: unsigned integer.
     // * output_tokens: the number of output tokens. Type: unsigned integer.
     // * total_tokens: the total number of tokens. Type: unsigned integer.
     //
     // For example, the following expressions are valid:
     //
     // * "model == 'llama' ? input_tokens + output_token * 0.5 : total_tokens"
     // * "backend == 'foo.default' ? input_tokens + output_tokens : total_tokens"
+    // * "backend == 'bar.default' ? (input_tokens - cached_input_tokens) + cached_input_tokens * 0.1 + output_tokens : total_tokens"
     // * "input_tokens + output_tokens + total_tokens"
     // * "input_tokens * output_tokens"
     //
@@ -111,6 +113,8 @@ type LLMRequestCostType string
 const (
     // LLMRequestCostTypeInputToken is the cost type of the input token.
     LLMRequestCostTypeInputToken LLMRequestCostType = "InputToken"
+    // LLMRequestCostTypeCachedInputToken is the cost type of the cached input token.
+    LLMRequestCostTypeCachedInputToken LLMRequestCostType = "CachedInputToken"
     // LLMRequestCostTypeOutputToken is the cost type of the output token.
     LLMRequestCostTypeOutputToken LLMRequestCostType = "OutputToken"
     // LLMRequestCostTypeTotalToken is the cost type of the total token.
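As a usage sketch, the new enum value is referenced from Go the same way the existing cost types are, for example when building a cost list programmatically. This assumes the API package imports as `aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1"` (the alias used elsewhere in this diff); the metadata keys are hypothetical examples:

```go
package main

import (
	"fmt"

	aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1"
)

func main() {
	costs := []aigv1a1.LLMRequestCost{
		{MetadataKey: "llm_input_token", Type: aigv1a1.LLMRequestCostTypeInputToken},
		// New in this commit: expose the cached-input-token count under its own key.
		{MetadataKey: "llm_cached_input_token", Type: aigv1a1.LLMRequestCostTypeCachedInputToken},
	}
	for _, c := range costs {
		fmt.Printf("%s -> %s\n", c.MetadataKey, c.Type)
	}
}
```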

examples/token_ratelimit/token_ratelimit.yaml

Lines changed: 22 additions & 0 deletions
@@ -49,6 +49,8 @@ spec:
   llmRequestCosts:
   - metadataKey: llm_input_token
     type: InputToken
+  - metadataKey: llm_cached_input_token
+    type: CachedInputToken
   - metadataKey: llm_output_token
     type: OutputToken
   - metadataKey: llm_total_token
@@ -164,6 +166,26 @@ spec:
             namespace: io.envoy.ai_gateway
             key: llm_total_token
 
+    # Repeat the same configuration for a different token type.
+    # This configures the cached input token limit, and it has a different budget than others,
+    # so it will be rate limited separately.
+    - clientSelectors:
+      - headers:
+        - name: x-user-id
+          type: Distinct
+      limit:
+        requests: 100
+        unit: Hour
+      cost:
+        request:
+          from: Number
+          number: 0
+        response:
+          from: Metadata
+          metadata:
+            namespace: io.envoy.ai_gateway
+            key: llm_cached_input_token
+
     # Repeat the same configuration for a different token type.
     # This configures the token limit based on the CEL expression.
     - clientSelectors:

internal/controller/gateway.go

Lines changed: 2 additions & 0 deletions
@@ -260,6 +260,8 @@ func (c *GatewayController) reconcileFilterConfigSecret(
         switch cost.Type {
         case aigv1a1.LLMRequestCostTypeInputToken:
             fc.Type = filterapi.LLMRequestCostTypeInputToken
+        case aigv1a1.LLMRequestCostTypeCachedInputToken:
+            fc.Type = filterapi.LLMRequestCostTypeCachedInputToken
         case aigv1a1.LLMRequestCostTypeOutputToken:
             fc.Type = filterapi.LLMRequestCostTypeOutputToken
         case aigv1a1.LLMRequestCostTypeTotalToken:

internal/controller/gateway_test.go

Lines changed: 5 additions & 3 deletions
@@ -194,6 +194,7 @@ func TestGatewayController_reconcileFilterConfigSecret(t *testing.T) {
                 {MetadataKey: "foo", Type: aigv1a1.LLMRequestCostTypeInputToken},
                 {MetadataKey: "bar", Type: aigv1a1.LLMRequestCostTypeOutputToken},
                 {MetadataKey: "baz", Type: aigv1a1.LLMRequestCostTypeTotalToken},
+                {MetadataKey: "qux", Type: aigv1a1.LLMRequestCostTypeCachedInputToken},
             },
         },
     },
@@ -267,12 +268,13 @@ func TestGatewayController_reconcileFilterConfigSecret(t *testing.T) {
         require.True(t, ok)
         var fc filterapi.Config
         require.NoError(t, yaml.Unmarshal([]byte(configStr), &fc))
-        require.Len(t, fc.LLMRequestCosts, 4)
+        require.Len(t, fc.LLMRequestCosts, 5)
         require.Equal(t, filterapi.LLMRequestCostTypeInputToken, fc.LLMRequestCosts[0].Type)
         require.Equal(t, filterapi.LLMRequestCostTypeOutputToken, fc.LLMRequestCosts[1].Type)
         require.Equal(t, filterapi.LLMRequestCostTypeTotalToken, fc.LLMRequestCosts[2].Type)
-        require.Equal(t, filterapi.LLMRequestCostTypeCEL, fc.LLMRequestCosts[3].Type)
-        require.Equal(t, `backend == 'foo.default' ? input_tokens + output_tokens : total_tokens`, fc.LLMRequestCosts[3].CEL)
+        require.Equal(t, filterapi.LLMRequestCostTypeCachedInputToken, fc.LLMRequestCosts[3].Type)
+        require.Equal(t, filterapi.LLMRequestCostTypeCEL, fc.LLMRequestCosts[4].Type)
+        require.Equal(t, `backend == 'foo.default' ? input_tokens + output_tokens : total_tokens`, fc.LLMRequestCosts[4].CEL)
         require.Len(t, fc.Models, 1)
         require.Equal(t, "mymodel", fc.Models[0].Name)

internal/extproc/chatcompletion_processor.go

Lines changed: 5 additions & 2 deletions
@@ -410,13 +410,13 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessResponseBody(ctx context.
         c.metrics.RecordTokenLatency(ctx, tokenUsage.OutputTokens, body.EndOfStream, c.requestHeaders)
         // Emit usage once at end-of-stream using final totals.
         if body.EndOfStream {
-            c.metrics.RecordTokenUsage(ctx, c.costs.InputTokens, c.costs.OutputTokens, c.requestHeaders)
+            c.metrics.RecordTokenUsage(ctx, c.costs.InputTokens, c.costs.CachedInputTokens, c.costs.OutputTokens, c.requestHeaders)
         }
         // TODO: if c.forcedStreamOptionIncludeUsage is true, we should not include usage in the response body since
         // that's what the clients would expect. However, it is a little bit tricky as we simply just reading the streaming
         // chunk by chunk, we only want to drop a specific line before the last chunk.
     } else {
-        c.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.OutputTokens, c.requestHeaders)
+        c.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.CachedInputTokens, tokenUsage.OutputTokens, c.requestHeaders)
     }
 
     if body.EndOfStream && len(c.config.requestCosts) > 0 {
@@ -536,6 +536,8 @@ func buildDynamicMetadata(config *processorConfig, costs *translator.LLMTokenUsa
         switch rc.Type {
         case filterapi.LLMRequestCostTypeInputToken:
             cost = costs.InputTokens
+        case filterapi.LLMRequestCostTypeCachedInputToken:
+            cost = costs.CachedInputTokens
         case filterapi.LLMRequestCostTypeOutputToken:
             cost = costs.OutputTokens
         case filterapi.LLMRequestCostTypeTotalToken:
@@ -546,6 +548,7 @@ func buildDynamicMetadata(config *processorConfig, costs *translator.LLMTokenUsa
             requestHeaders[internalapi.ModelNameHeaderKeyDefault],
             backendName,
             costs.InputTokens,
+            costs.CachedInputTokens,
             costs.OutputTokens,
             costs.TotalTokens,
         )

internal/extproc/chatcompletion_processor_test.go

Lines changed: 6 additions & 2 deletions
@@ -259,7 +259,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
         mt := &mockTranslator{
             t: t, expResponseBody: inBody,
             retBodyMutation: expBodyMut, retHeaderMutation: expHeadMut,
-            retUsedToken: translator.LLMTokenUsage{OutputTokens: 123, InputTokens: 1},
+            retUsedToken: translator.LLMTokenUsage{OutputTokens: 123, InputTokens: 1, CachedInputTokens: 1},
         }
 
         celProgInt, err := llmcostcel.NewProgram("54321")
@@ -275,6 +275,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
             requestCosts: []processorConfigRequestCost{
                 {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
                 {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}},
+                {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCachedInputToken, MetadataKey: "cached_input_token_usage"}},
                 {
                     celProg: celProgInt,
                     LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"},
@@ -304,6 +305,8 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
             GetStructValue().Fields["output_token_usage"].GetNumberValue())
         require.Equal(t, float64(1), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
             GetStructValue().Fields["input_token_usage"].GetNumberValue())
+        require.Equal(t, float64(1), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
+            GetStructValue().Fields["cached_input_token_usage"].GetNumberValue())
         require.Equal(t, float64(54321), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
             GetStructValue().Fields["cel_int"].GetNumberValue())
         require.Equal(t, float64(9999), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
@@ -356,11 +359,12 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
         // Final chunk should mark success and record usage once.
         final := &extprocv3.HttpBody{Body: []byte("chunk-final"), EndOfStream: true}
         mt.expResponseBody = final
-        mt.retUsedToken = translator.LLMTokenUsage{InputTokens: 5, OutputTokens: 138, TotalTokens: 143}
+        mt.retUsedToken = translator.LLMTokenUsage{InputTokens: 5, CachedInputTokens: 3, OutputTokens: 138, TotalTokens: 143}
         _, err = p.ProcessResponseBody(t.Context(), final)
         require.NoError(t, err)
         mm.RequireRequestSuccess(t)
         require.Equal(t, 143, mm.tokenUsageCount)       // 5 input + 138 output
+        require.Equal(t, 3, mm.cachedInputCount)        // cached input tokens
         require.Equal(t, 138, mm.streamingOutputTokens) // accumulated output tokens from stream
     })
 }

internal/extproc/messages_processor.go

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex
     c.costs.TotalTokens += tokenUsage.TotalTokens
 
     // Update metrics with token usage.
-    c.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.OutputTokens, c.requestHeaders)
+    c.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.CachedInputTokens, tokenUsage.OutputTokens, c.requestHeaders)
     if c.stream {
         c.metrics.RecordTokenLatency(ctx, tokenUsage.OutputTokens, body.EndOfStream, c.requestHeaders)
     }

internal/extproc/mocks_test.go

Lines changed: 3 additions & 1 deletion
@@ -172,6 +172,7 @@ type mockChatCompletionMetrics struct {
     backend               string
     requestSuccessCount   int
     requestErrorCount     int
+    cachedInputCount      int
     tokenUsageCount       int
     // streamingOutputTokens tracks the cumulative output tokens recorded via RecordTokenLatency.
     streamingOutputTokens int
@@ -201,8 +202,9 @@ func (m *mockChatCompletionMetrics) SetResponseModel(responseModel internalapi.R
 func (m *mockChatCompletionMetrics) SetBackend(backend *filterapi.Backend) { m.backend = backend.Name }
 
 // RecordTokenUsage implements [metrics.ChatCompletion].
-func (m *mockChatCompletionMetrics) RecordTokenUsage(_ context.Context, input, output uint32, _ map[string]string) {
+func (m *mockChatCompletionMetrics) RecordTokenUsage(_ context.Context, input, cachedInput, output uint32, _ map[string]string) {
     m.tokenUsageCount += int(input + output)
+    m.cachedInputCount += int(cachedInput)
 }
 
 // RecordTokenLatency implements [metrics.ChatCompletion].

internal/extproc/server_test.go

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ func TestServer_LoadConfig(t *testing.T) {
     require.Equal(t, "1 + 1", s.config.requestCosts[1].CEL)
     prog := s.config.requestCosts[1].celProg
     require.NotNil(t, prog)
-    val, err := llmcostcel.EvaluateProgram(prog, "", "", 1, 1, 1)
+    val, err := llmcostcel.EvaluateProgram(prog, "", "", 1, 1, 1, 1)
     require.NoError(t, err)
     require.Equal(t, uint64(2), val)
     require.Equal(t, config.Models, s.config.declaredModels)
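The extra `1` passed to `EvaluateProgram` corresponds to the new `cached_input_tokens` CEL variable. Below is a minimal sketch of compiling and evaluating a cached-token-aware expression with this package, as it might live inside the repository itself (the `internal/llmcostcel` import path cannot be used from outside the module, and the argument order — model, backend, input, cached input, output, total tokens — is inferred from the calls in this diff rather than confirmed documentation):

```go
package main

import (
	"fmt"

	"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

func main() {
	// Count only the uncached portion of the input tokens plus the output tokens.
	prog, err := llmcostcel.NewProgram("input_tokens - cached_input_tokens + output_tokens")
	if err != nil {
		panic(err)
	}
	// 1000 input tokens (800 of them cached) and 200 output tokens:
	// the expression works out to 1000 - 800 + 200 = 400.
	val, err := llmcostcel.EvaluateProgram(prog, "llama", "foo.default", 1000, 800, 200, 1200)
	if err != nil {
		panic(err)
	}
	fmt.Println(val)
}
```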
