Skip to content

Commit 1078890

Browse files
authored
mcp: make session encryption seed rotatable (#1357)
1 parent c1ea619 commit 1078890

File tree

10 files changed

+149
-19
lines changed

10 files changed

+149
-19
lines changed

cmd/controller/main.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ type flags struct {
5555
// extProcMaxRecvMsgSize is the maximum message size in bytes that the gRPC server can receive.
5656
extProcMaxRecvMsgSize int
5757
// maxRecvMsgSize is the maximum message size in bytes that the gRPC extension server can receive.
58-
maxRecvMsgSize int
59-
watchNamespaces []string
60-
cacheSyncTimeout time.Duration
58+
maxRecvMsgSize int
59+
mcpSessionEncryptionSeed string
60+
watchNamespaces []string
61+
cacheSyncTimeout time.Duration
6162
}
6263

6364
// parsePullPolicy parses string into a k8s PullPolicy.
@@ -187,6 +188,13 @@ func parseAndValidateFlags(args []string) (flags, error) {
187188
2*time.Minute, // This is the controller-runtime default
188189
"Maximum time to wait for k8s caches to sync",
189190
)
191+
mcpSessionEncryptionSeed := fs.String(
192+
"mcpSessionEncryptionSeed",
193+
"seed",
194+
"Arbitrary string seed used to derive the MCP session encryption key. "+
195+
"Do not include commas as they are used as separators. You can optionally pass \"fallback\" seed after the first one to allow for key rotation. "+
196+
"For example: \"new-seed,old-seed-for-fallback\". The fallback seed is only used for decryption.",
197+
)
190198

191199
if err := fs.Parse(args); err != nil {
192200
err = fmt.Errorf("failed to parse flags: %w", err)
@@ -268,6 +276,7 @@ func parseAndValidateFlags(args []string) (flags, error) {
268276
maxRecvMsgSize: *maxRecvMsgSize,
269277
watchNamespaces: parseWatchNamespaces(*watchNamespaces),
270278
cacheSyncTimeout: *cacheSyncTimeout,
279+
mcpSessionEncryptionSeed: *mcpSessionEncryptionSeed,
271280
}, nil
272281
}
273282

@@ -355,6 +364,7 @@ func main() {
355364
ExtProcExtraEnvVars: parsedFlags.extProcExtraEnvVars,
356365
ExtProcImagePullSecrets: parsedFlags.extProcImagePullSecrets,
357366
ExtProcMaxRecvMsgSize: parsedFlags.extProcMaxRecvMsgSize,
367+
MCPSessionEncryptionSeed: parsedFlags.mcpSessionEncryptionSeed,
358368
}); err != nil {
359369
setupLog.Error(err, "failed to start controller")
360370
}

cmd/extproc/mainlib/main.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ func parseAndValidateFlags(args []string) (extProcFlags, error) {
105105
fs.StringVar(&flags.mcpSessionEncryptionSeed,
106106
"mcpSessionEncryptionSeed",
107107
"mcp",
108-
"seed used to derive the MCP session encryption key. This should be set to a secure value in production.",
108+
"Arbitrary string seed used to derive the MCP session encryption key. "+
109+
"Do not include commas as they are used as separators. You can optionally pass \"fallback\" seed after the first one to allow for key rotation. "+
110+
"For example: \"new-seed,old-seed-for-fallback\". The fallback seed is only used for decryption.",
109111
)
110112
fs.DurationVar(&flags.mcpWriteTimeout, "mcpWriteTimeout", 120*time.Second,
111113
"The maximum duration before timing out writes of the MCP response")
@@ -259,7 +261,8 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
259261

260262
var mcpServer *http.Server
261263
if mcpLis != nil {
262-
mcpSessionCrypto := mcpproxy.DefaultSessionCrypto(flags.mcpSessionEncryptionSeed)
264+
seed, fallbackSeed, _ := strings.Cut(flags.mcpSessionEncryptionSeed, ",")
265+
mcpSessionCrypto := mcpproxy.DefaultSessionCrypto(seed, fallbackSeed)
263266
var mcpProxyMux *http.ServeMux
264267
var mcpProxy *mcpproxy.MCPProxy
265268
mcpProxy, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics,

internal/controller/controller.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ type Options struct {
8585
ExtProcImagePullSecrets string
8686
// ExtProcMaxRecvMsgSize is the maximum message size in bytes that the gRPC server can receive for extProc.
8787
ExtProcMaxRecvMsgSize int
88+
// MCPSessionEncryptionSeed is the seed used to derive the encryption key for MCP session encryption.
89+
MCPSessionEncryptionSeed string
8890
}
8991

9092
// StartControllers starts the controllers for the AI Gateway.
@@ -215,6 +217,7 @@ func StartControllers(ctx context.Context, mgr manager.Manager, config *rest.Con
215217
options.ExtProcImagePullSecrets,
216218
options.ExtProcMaxRecvMsgSize,
217219
isKubernetes133OrLater(versionInfo, logger),
220+
options.MCPSessionEncryptionSeed,
218221
))
219222
mgr.GetWebhookServer().Register("/mutate", &webhook.Admission{Handler: h})
220223
}

internal/controller/gateway_mutator.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ type gatewayMutator struct {
4545
extProcImagePullSecrets []corev1.LocalObjectReference
4646
extProcMaxRecvMsgSize int
4747

48+
// mcpSessionEncryptionSeed is the seed used to derive the encryption key for MCP session data.
49+
mcpSessionEncryptionSeed string
50+
4851
// Whether to run the extProc container as a sidecar (true) or as a normal container (false).
4952
// This is essentially a workaround for old k8s versions, and we can remove this in the future.
5053
extProcAsSideCar bool
@@ -54,6 +57,7 @@ func newGatewayMutator(c client.Client, kube kubernetes.Interface, logger logr.L
5457
extProcImage string, extProcImagePullPolicy corev1.PullPolicy, extProcLogLevel,
5558
udsPath, metricsRequestHeaderAttributes, spanRequestHeaderAttributes, rootPrefix, extProcExtraEnvVars, extProcImagePullSecrets string, extProcMaxRecvMsgSize int,
5659
extProcAsSideCar bool,
60+
mcpSessionEncryptionSeed string,
5761
) *gatewayMutator {
5862
var parsedEnvVars []corev1.EnvVar
5963
if extProcExtraEnvVars != "" {
@@ -90,6 +94,7 @@ func newGatewayMutator(c client.Client, kube kubernetes.Interface, logger logr.L
9094
extProcImagePullSecrets: parsedImagePullSecrets,
9195
extProcMaxRecvMsgSize: extProcMaxRecvMsgSize,
9296
extProcAsSideCar: extProcAsSideCar,
97+
mcpSessionEncryptionSeed: mcpSessionEncryptionSeed,
9398
}
9499
}
95100

@@ -123,7 +128,8 @@ func (g *gatewayMutator) buildExtProcArgs(filterConfigFullPath string, extProcAd
123128
"-maxRecvMsgSize", fmt.Sprintf("%d", g.extProcMaxRecvMsgSize),
124129
}
125130
if needMCP {
126-
args = append(args, "-mcpAddr", ":"+strconv.Itoa(internalapi.MCPProxyPort))
131+
args = append(args, "-mcpAddr", ":"+strconv.Itoa(internalapi.MCPProxyPort),
132+
"-mcpSessionEncryptionSeed", g.mcpSessionEncryptionSeed)
127133
}
128134

129135
// Add metrics header label mapping if configured.

internal/controller/gateway_mutator_test.go

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package controller
77

88
import (
99
"fmt"
10+
"strconv"
1011
"testing"
1112

1213
"github.com/stretchr/testify/require"
@@ -21,6 +22,7 @@ import (
2122
gwapiv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2"
2223

2324
aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1"
25+
"github.com/envoyproxy/ai-gateway/internal/internalapi"
2426
)
2527

2628
func TestGatewayMutator_Default(t *testing.T) {
@@ -51,6 +53,7 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
5153
extProcImagePullSecrets string
5254
extprocTest func(t *testing.T, container corev1.Container)
5355
podTest func(t *testing.T, pod corev1.Pod)
56+
needMCP bool
5457
}{
5558
{
5659
name: "basic extproc container",
@@ -61,6 +64,25 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
6164
require.Empty(t, pod.Spec.ImagePullSecrets)
6265
},
6366
},
67+
{
68+
name: "basic extproc container with MCPRoute",
69+
needMCP: true,
70+
extprocTest: func(t *testing.T, container corev1.Container) {
71+
var foundMCPAddr, foundMCPSeed bool
72+
for i, arg := range container.Args {
73+
switch arg {
74+
case "-mcpAddr":
75+
foundMCPAddr = true
76+
require.Equal(t, ":"+strconv.Itoa(internalapi.MCPProxyPort), container.Args[i+1])
77+
case "-mcpSessionEncryptionSeed":
78+
foundMCPSeed = true
79+
require.Equal(t, "seed", container.Args[i+1])
80+
}
81+
}
82+
require.True(t, foundMCPAddr)
83+
require.True(t, foundMCPSeed)
84+
},
85+
},
6486
{
6587
name: "with extra env vars",
6688
extProcExtraEnvVars: "OTEL_SERVICE_NAME=ai-gateway-extproc;OTEL_TRACES_EXPORTER=otlp",
@@ -190,6 +212,22 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
190212
})
191213
require.NoError(t, err)
192214

215+
if tt.needMCP {
216+
err = fakeClient.Create(t.Context(), &aigv1a1.MCPRoute{
217+
ObjectMeta: metav1.ObjectMeta{Name: "test-mcp", Namespace: gwNamespace},
218+
Spec: aigv1a1.MCPRouteSpec{
219+
ParentRefs: []gwapiv1a2.ParentReference{
220+
{
221+
Name: gwName,
222+
Kind: ptr.To(gwapiv1a2.Kind("Gateway")),
223+
Group: ptr.To(gwapiv1a2.Group("gateway.networking.k8s.io")),
224+
},
225+
},
226+
},
227+
})
228+
require.NoError(t, err)
229+
}
230+
193231
pod := &corev1.Pod{
194232
ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "test-namespace"},
195233
Spec: corev1.PodSpec{
@@ -227,7 +265,9 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
227265

228266
require.Equal(t, "ai-gateway-extproc", extProcContainer.Name)
229267
tt.extprocTest(t, extProcContainer)
230-
tt.podTest(t, *pod)
268+
if tt.podTest != nil {
269+
tt.podTest(t, *pod)
270+
}
231271
})
232272
}
233273
})
@@ -239,7 +279,7 @@ func newTestGatewayMutator(fakeClient client.Client, fakeKube *fake2.Clientset,
239279
return newGatewayMutator(
240280
fakeClient, fakeKube, ctrl.Log, "docker.io/envoyproxy/ai-gateway-extproc:latest", corev1.PullIfNotPresent,
241281
"info", "/tmp/extproc.sock", metricsRequestHeaderAttributes, spanRequestHeaderAttributes, "/v1", extProcExtraEnvVars, extProcImagePullSecrets, 512*1024*1024,
242-
sidecar,
282+
sidecar, "seed",
243283
)
244284
}
245285

internal/mcpproxy/crypto.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,25 @@ type SessionCrypto interface {
2626
}
2727

2828
// DefaultSessionCrypto returns a SessionCrypto implementation using PBKDF2 for key derivation and AES-GCM for encryption.
29-
func DefaultSessionCrypto(seed string) SessionCrypto {
30-
return pbkdf2AesGcm{
29+
func DefaultSessionCrypto(seed, fallbackSeed string) SessionCrypto {
30+
primary := &pbkdf2AesGcm{
3131
seed: seed, // Seed used to derive the encryption key.
3232
saltSize: 16, // Salt size for PBKDF2.
3333
keyLength: 32, // Key length for AES-256.
3434
iterations: 100_000, // Number of PBKDF2 iterations (trade security vs performance).
3535
}
36+
if fallbackSeed == "" {
37+
return primary
38+
}
39+
return &fallbackEnabledSessionCrypto{
40+
primary: primary,
41+
fallback: &pbkdf2AesGcm{
42+
seed: fallbackSeed,
43+
saltSize: 16,
44+
keyLength: 32,
45+
iterations: 100_000,
46+
},
47+
}
3648
}
3749

3850
// pbkdf2AesGcm implements SessionCrypto using PBKDF2 for key derivation and AES-GCM for encryption.
@@ -119,3 +131,26 @@ func (p pbkdf2AesGcm) Decrypt(encrypted string) (string, error) {
119131
}
120132
return string(plaintext), nil
121133
}
134+
135+
// fallbackEnabledSessionCrypto tries to decrypt using the primary SessionCrypto first.
136+
// If that fails and a fallback SessionCrypto is provided, it tries to decrypt using the fallback.
137+
type fallbackEnabledSessionCrypto struct {
138+
primary, fallback SessionCrypto
139+
}
140+
141+
// Encrypt always uses the primary SessionCrypto.
142+
func (f fallbackEnabledSessionCrypto) Encrypt(plaintext string) (string, error) {
143+
return f.primary.Encrypt(plaintext)
144+
}
145+
146+
// Decrypt tries the primary SessionCrypto first, and if that fails and a fallback is provided, it tries the fallback.
147+
func (f fallbackEnabledSessionCrypto) Decrypt(encrypted string) (string, error) {
148+
plaintext, err := f.primary.Decrypt(encrypted)
149+
if err == nil {
150+
return plaintext, nil
151+
}
152+
if f.fallback != nil {
153+
return f.fallback.Decrypt(encrypted)
154+
}
155+
return "", err
156+
}

internal/mcpproxy/crypto_test.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
)
1313

1414
func TestSessionEncryption(t *testing.T) {
15-
sc := DefaultSessionCrypto("test")
15+
sc := DefaultSessionCrypto("test", "")
1616

1717
enc, err := sc.Encrypt("plaintext")
1818
require.NoError(t, err)
@@ -23,7 +23,7 @@ func TestSessionEncryption(t *testing.T) {
2323
}
2424

2525
func TestEncryptionIsSalted(t *testing.T) {
26-
sc := DefaultSessionCrypto("test")
26+
sc := DefaultSessionCrypto("test", "")
2727

2828
enc1, err := sc.Encrypt("plaintext")
2929
require.NoError(t, err)
@@ -34,8 +34,8 @@ func TestEncryptionIsSalted(t *testing.T) {
3434
}
3535

3636
func TestDecryptWrongSeed(t *testing.T) {
37-
sc1 := DefaultSessionCrypto("test1")
38-
sc2 := DefaultSessionCrypto("test2")
37+
sc1 := DefaultSessionCrypto("test1", "")
38+
sc2 := DefaultSessionCrypto("test2", "")
3939

4040
enc, err := sc1.Encrypt("plaintext")
4141
require.NoError(t, err)
@@ -45,9 +45,34 @@ func TestDecryptWrongSeed(t *testing.T) {
4545
require.Empty(t, dec)
4646
}
4747

48+
func TestDecryptFallbackSeed(t *testing.T) {
49+
sc1 := DefaultSessionCrypto("test1", "")
50+
sc2 := DefaultSessionCrypto("test2", "test1")
51+
52+
// Decrypting should work with the fallback seed.
53+
enc, err := sc1.Encrypt("plaintext")
54+
require.NoError(t, err)
55+
dec, err := sc2.Decrypt(enc)
56+
require.NoError(t, err)
57+
require.Equal(t, "plaintext", dec)
58+
59+
// Encrypting should happen with the latest seed.
60+
enc2, err := sc2.Encrypt("plaintext2")
61+
require.NoError(t, err)
62+
require.NotEqual(t, enc, enc2)
63+
64+
dec2, err := sc1.Decrypt(enc2)
65+
require.Error(t, err)
66+
require.Empty(t, dec2)
67+
68+
dec2, err = sc2.Decrypt(enc2)
69+
require.NoError(t, err)
70+
require.Equal(t, "plaintext2", dec2)
71+
}
72+
4873
func TestDecryptDifferentInstancesSameSeed(t *testing.T) {
49-
sc1 := DefaultSessionCrypto("test")
50-
sc2 := DefaultSessionCrypto("test")
74+
sc1 := DefaultSessionCrypto("test", "")
75+
sc2 := DefaultSessionCrypto("test", "")
5176

5277
enc, err := sc1.Encrypt("plaintext")
5378
require.NoError(t, err)

internal/mcpproxy/handlers_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func newTestMCPProxy() *MCPProxy {
4242

4343
func newTestMCPProxyWithTracer(t tracingapi.MCPTracer) *MCPProxy {
4444
return &MCPProxy{
45-
sessionCrypto: DefaultSessionCrypto("test"),
45+
sessionCrypto: DefaultSessionCrypto("test", ""),
4646
mcpProxyConfig: &mcpProxyConfig{
4747
backendListenerAddr: "http://test-backend",
4848
routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{

internal/mcpproxy/mcpproxy_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ var noopTracer = tracing.NoopMCPTracer{}
6161

6262
func TestNewMCPProxy(t *testing.T) {
6363
l := slog.Default()
64-
proxy, mux, err := NewMCPProxy(l, stubMetrics{}, noopTracer, DefaultSessionCrypto("test"))
64+
proxy, mux, err := NewMCPProxy(l, stubMetrics{}, noopTracer, DefaultSessionCrypto("test", ""))
6565

6666
require.NoError(t, err)
6767
require.NotNil(t, proxy)
@@ -72,7 +72,7 @@ func TestNewMCPProxy(t *testing.T) {
7272

7373
func TestMCPProxy_HTTPMethods(t *testing.T) {
7474
l := slog.Default()
75-
_, mux, err := NewMCPProxy(l, stubMetrics{}, noopTracer, DefaultSessionCrypto("test"))
75+
_, mux, err := NewMCPProxy(l, stubMetrics{}, noopTracer, DefaultSessionCrypto("test", ""))
7676
require.NoError(t, err)
7777

7878
// Test unsupported method.

manifests/charts/ai-gateway-helm/values.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,16 @@ controller:
165165
nodeSelector: {}
166166
tolerations: []
167167
affinity: {}
168+
168169
# maxRecvMsgSize is the maximum message size in bytes that the gRPC extension server can receive
169170
# from xDS (envoy-gateway).
170171
# This value should be increased in setups where count/complexity of xDS
171172
# resources (configuration) is big. Defaults to 4MB.
172173
# maxRecvMsgSize: "4194304"
174+
175+
# mcpSessionEncryptionSeed is an arbitrary string seed used to derive the MCP session encryption key.
176+
# Do not include commas as they are used as separators. You can optionally pass a "fallback" seed after the first one to allow for key rotation.
177+
# For example: "new-seed,old-seed-for-fallback". The fallback seed is only used for decryption.
178+
#
179+
# This value should be set to a secure random string in production environments instead of the default below.
180+
mcpSessionEncryptionSeed: "default-insecure-seed"

0 commit comments

Comments
 (0)