Skip to content

Commit e4f60e5

Browse files
authored
fix(embeddingModel): add embedding model to MCP handler (googleapis#2310)
- Add embedding model to mcp handlers - Add integration tests
1 parent d7af21b commit e4f60e5

File tree

10 files changed

+319
-6
lines changed

10 files changed

+319
-6
lines changed

.ci/integration.cloudbuild.yaml

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ steps:
8787
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
8888
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
8989
secretEnv:
90-
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
90+
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
9191
volumes:
9292
- name: "go"
9393
path: "/gopath"
@@ -134,7 +134,7 @@ steps:
134134
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
135135
- "ALLOYDB_POSTGRES_REGION=$_REGION"
136136
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
137-
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
137+
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
138138
volumes:
139139
- name: "go"
140140
path: "/gopath"
@@ -305,7 +305,7 @@ steps:
305305
- "POSTGRES_HOST=$_POSTGRES_HOST"
306306
- "POSTGRES_PORT=$_POSTGRES_PORT"
307307
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
308-
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
308+
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
309309
volumes:
310310
- name: "go"
311311
path: "/gopath"
@@ -964,6 +964,13 @@ steps:
964964
965965
availableSecrets:
966966
secretManager:
967+
# Common secrets
968+
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
969+
env: CLIENT_ID
970+
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
971+
env: API_KEY
972+
973+
# Resource-specific secrets
967974
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
968975
env: CLOUD_SQL_POSTGRES_USER
969976
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
@@ -980,8 +987,6 @@ availableSecrets:
980987
env: POSTGRES_USER
981988
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
982989
env: POSTGRES_PASS
983-
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
984-
env: CLIENT_ID
985990
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
986991
env: NEO4J_USER
987992
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest

internal/server/mcp/v20241105/method.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
183183
}
184184
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
185185

186+
embeddingModels := resourceMgr.GetEmbeddingModelMap()
187+
params, err = tool.EmbedParams(ctx, params, embeddingModels)
188+
if err != nil {
189+
err = fmt.Errorf("error embedding parameters: %w", err)
190+
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
191+
}
192+
186193
// run tool invocation and generate response.
187194
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
188195
if err != nil {

internal/server/mcp/v20250326/method.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
183183
}
184184
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
185185

186+
embeddingModels := resourceMgr.GetEmbeddingModelMap()
187+
params, err = tool.EmbedParams(ctx, params, embeddingModels)
188+
if err != nil {
189+
err = fmt.Errorf("error embedding parameters: %w", err)
190+
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
191+
}
192+
186193
// run tool invocation and generate response.
187194
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
188195
if err != nil {

internal/server/mcp/v20250618/method.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
176176
}
177177
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
178178

179+
embeddingModels := resourceMgr.GetEmbeddingModelMap()
180+
params, err = tool.EmbedParams(ctx, params, embeddingModels)
181+
if err != nil {
182+
err = fmt.Errorf("error embedding parameters: %w", err)
183+
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
184+
}
185+
179186
// run tool invocation and generate response.
180187
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
181188
if err != nil {

internal/server/mcp/v20251125/method.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
176176
}
177177
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
178178

179+
embeddingModels := resourceMgr.GetEmbeddingModelMap()
180+
params, err = tool.EmbedParams(ctx, params, embeddingModels)
181+
if err != nil {
182+
err = fmt.Errorf("error embedding parameters: %w", err)
183+
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
184+
}
185+
179186
// run tool invocation and generate response.
180187
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
181188
if err != nil {

tests/alloydbpg/alloydb_pg_integration_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
147147
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
148148
defer teardownTable2(t)
149149

150+
// Set up table for semanti search
151+
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
152+
defer tearDownVectorTable(t)
153+
150154
// Write config into a file and pass it to command
151155
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
152156
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
153157
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
154158
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
155159

160+
// Add semantic search tool config
161+
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
162+
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
163+
156164
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
157165

158166
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)

tests/cloudsqlpg/cloud_sql_pg_integration_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
132132
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
133133
defer teardownTable2(t)
134134

135+
// Set up table for semantic search
136+
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
137+
defer tearDownVectorTable(t)
138+
135139
// Write config into a file and pass it to command
136140
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
137141
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
138142
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
139143
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
140144

145+
// Add semantic search tool config
146+
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
147+
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
148+
141149
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
142150
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
143151
if err != nil {
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
186194
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
187195
tests.RunPostgresListRolesTest(t, ctx, pool)
188196
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
197+
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
189198
}
190199

191200
// Test connection with different IP type

0 commit comments

Comments
 (0)