Skip to content

Commit 46ac26b

Browse files
committed
cockroach + bq
1 parent a41472c commit 46ac26b

File tree

7 files changed

+39
-49
lines changed

7 files changed

+39
-49
lines changed

internal/server/mocks.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
2222
"github.com/googleapis/genai-toolbox/internal/prompts"
2323
"github.com/googleapis/genai-toolbox/internal/tools"
24+
"github.com/googleapis/genai-toolbox/internal/util"
2425
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2526
)
2627

@@ -34,7 +35,7 @@ type MockTool struct {
3435
requiresClientAuthrorization bool
3536
}
3637

37-
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
38+
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, util.ToolboxError) {
3839
mock := []any{t.Name}
3940
return mock, nil
4041
}

internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
143143
}
144144

145145
if !source.IsDatasetAllowed(projectId, datasetId) {
146-
return nil, util.NewClientServerError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), http.StatusInternalServerError, nil)
146+
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
147147
}
148148

149149
bqClient, _, err := source.RetrieveClientAndService(accessToken)

internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,26 +104,27 @@ type Tool struct {
104104
mcpManifest tools.McpManifest
105105
}
106106

107-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
107+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
108108
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
109109
if err != nil {
110-
return nil, err
110+
return nil, util.NewClientServerError("source used is not compatible with the tool", 500, err)
111111
}
112112

113113
paramsMap := params.AsMap()
114114
sql, ok := paramsMap["sql"].(string)
115115
if !ok {
116-
return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"])
116+
return nil, util.NewAgentError(fmt.Sprintf("parameter 'sql' is required, unable to cast %v", paramsMap["sql"]), nil)
117117
}
118+
118119
logger, err := util.LoggerFromContext(ctx)
119120
if err != nil {
120-
return nil, fmt.Errorf("error getting logger: %s", err)
121+
return nil, util.NewClientServerError("error getting logger", 500, err)
121122
}
122-
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
123+
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", t.Type, sql))
123124

124125
results, err := source.Query(ctx, sql)
125126
if err != nil {
126-
return nil, fmt.Errorf("unable to execute query: %w", err)
127+
return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err))
127128
}
128129
defer results.Close()
129130

@@ -133,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
133134
for results.Next() {
134135
v, err := results.Values()
135136
if err != nil {
136-
return nil, fmt.Errorf("unable to parse row: %w", err)
137+
return nil, util.NewClientServerError("unable to parse row", 500, err)
137138
}
138139
row := orderedmap.Row{}
139140
for i, f := range fields {
@@ -143,16 +144,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
143144
}
144145

145146
if err := results.Err(); err != nil {
146-
return nil, fmt.Errorf("unable to execute query: %w", err)
147+
return nil, util.ProcessGeneralError(fmt.Errorf("error during row iteration: %w", err))
147148
}
148149

149150
return out, nil
150151
}
151152

152-
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
153-
return parameters.ParseParams(t.Parameters, data, claims)
154-
}
155-
156153
func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
157154
return params, nil
158155
}

internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/googleapis/genai-toolbox/internal/sources"
2424
"github.com/googleapis/genai-toolbox/internal/sources/cockroachdb"
2525
"github.com/googleapis/genai-toolbox/internal/tools"
26+
"github.com/googleapis/genai-toolbox/internal/util"
2627
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2728
"github.com/jackc/pgx/v5"
2829
)
@@ -116,15 +117,15 @@ type Tool struct {
116117
mcpManifest tools.McpManifest
117118
}
118119

119-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
120+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
120121
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
121122
if err != nil {
122-
return nil, err
123+
return nil, util.NewClientServerError("source used is not compatible with the tool", 500, err)
123124
}
124125

125126
results, err := source.Query(ctx, listSchemasStatement)
126127
if err != nil {
127-
return nil, fmt.Errorf("unable to execute query: %w", err)
128+
return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err))
128129
}
129130
defer results.Close()
130131

@@ -134,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
134135
for results.Next() {
135136
values, err := results.Values()
136137
if err != nil {
137-
return nil, fmt.Errorf("unable to parse row: %w", err)
138+
return nil, util.NewClientServerError("unable to parse row", 500, err)
138139
}
139140
rowMap := make(map[string]any)
140141
for i, field := range fields {
@@ -144,16 +145,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
144145
}
145146

146147
if err := results.Err(); err != nil {
147-
return nil, fmt.Errorf("error reading query results: %w", err)
148+
return nil, util.ProcessGeneralError(fmt.Errorf("error reading query results: %w", err))
148149
}
149150

150151
return out, nil
151152
}
152153

153-
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
154-
return parameters.ParseParams(t.AllParams, data, claims)
155-
}
156-
157154
func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
158155
return params, nil
159156
}

internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/googleapis/genai-toolbox/internal/sources"
2424
"github.com/googleapis/genai-toolbox/internal/sources/cockroachdb"
2525
"github.com/googleapis/genai-toolbox/internal/tools"
26+
"github.com/googleapis/genai-toolbox/internal/util"
2627
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2728
"github.com/jackc/pgx/v5"
2829
)
@@ -179,26 +180,26 @@ type Tool struct {
179180
mcpManifest tools.McpManifest
180181
}
181182

182-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
183+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
183184
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
184185
if err != nil {
185-
return nil, err
186+
return nil, util.NewClientServerError("source used is not compatible with the tool", 500, err)
186187
}
187188

188189
paramsMap := params.AsMap()
189190

190191
tableNames, ok := paramsMap["table_names"].(string)
191192
if !ok {
192-
return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string")
193+
return nil, util.NewAgentError("invalid 'table_names' parameter; expected a string", nil)
193194
}
194195
outputFormat, _ := paramsMap["output_format"].(string)
195196
if outputFormat != "simple" && outputFormat != "detailed" {
196-
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
197+
return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil)
197198
}
198199

199200
results, err := source.Query(ctx, listTablesStatement, tableNames, outputFormat)
200201
if err != nil {
201-
return nil, fmt.Errorf("unable to execute query: %w", err)
202+
return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err))
202203
}
203204
defer results.Close()
204205

@@ -208,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
208209
for results.Next() {
209210
values, err := results.Values()
210211
if err != nil {
211-
return nil, fmt.Errorf("unable to parse row: %w", err)
212+
return nil, util.NewClientServerError("unable to parse row", 500, err)
212213
}
213214
rowMap := make(map[string]any)
214215
for i, field := range fields {
@@ -218,16 +219,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
218219
}
219220

220221
if err := results.Err(); err != nil {
221-
return nil, fmt.Errorf("error reading query results: %w", err)
222+
return nil, util.ProcessGeneralError(fmt.Errorf("error reading query results: %w", err))
222223
}
223224

224225
return out, nil
225226
}
226227

227-
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
228-
return parameters.ParseParams(t.AllParams, data, claims)
229-
}
230-
231228
func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
232229
return params, nil
233230
}

internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/googleapis/genai-toolbox/internal/sources"
2424
"github.com/googleapis/genai-toolbox/internal/sources/cockroachdb"
2525
"github.com/googleapis/genai-toolbox/internal/tools"
26+
"github.com/googleapis/genai-toolbox/internal/util"
2627
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
2728
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2829
"github.com/jackc/pgx/v5"
@@ -110,26 +111,26 @@ type Tool struct {
110111
mcpManifest tools.McpManifest
111112
}
112113

113-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
114+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
114115
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
115116
if err != nil {
116-
return nil, err
117+
return nil, util.NewClientServerError("source used is not compatible with the tool", 500, err)
117118
}
118119

119120
paramsMap := params.AsMap()
120121
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
121122
if err != nil {
122-
return nil, fmt.Errorf("unable to extract template params %w", err)
123+
return nil, util.NewAgentError(fmt.Sprintf("unable to resolve template params: %v", err), err)
123124
}
124125

125126
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
126127
if err != nil {
127-
return nil, fmt.Errorf("unable to extract standard params %w", err)
128+
return nil, util.NewAgentError(fmt.Sprintf("unable to extract standard params: %v", err), err)
128129
}
129130
sliceParams := newParams.AsSlice()
130131
results, err := source.Query(ctx, newStatement, sliceParams...)
131132
if err != nil {
132-
return nil, fmt.Errorf("unable to execute query: %w", err)
133+
return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err))
133134
}
134135
defer results.Close()
135136

@@ -139,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
139140
for results.Next() {
140141
v, err := results.Values()
141142
if err != nil {
142-
return nil, fmt.Errorf("unable to parse row: %w", err)
143+
return nil, util.NewClientServerError("unable to parse row", 500, err)
143144
}
144145
row := orderedmap.Row{}
145146
for i, f := range fields {
@@ -149,16 +150,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
149150
}
150151

151152
if err := results.Err(); err != nil {
152-
return nil, fmt.Errorf("unable to execute query: %w", err)
153+
return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err))
153154
}
154155

155156
return out, nil
156157
}
157158

158-
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
159-
return parameters.ParseParams(t.AllParams, data, claims)
160-
}
161-
162159
func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
163160
return params, nil
164161
}

tests/bigquery/bigquery_integration_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,8 +1166,8 @@ func runBigQueryWriteModeBlockedTest(t *testing.T, tableNameParam, datasetName s
11661166
wantResult string
11671167
}{
11681168
{"SELECT statement should succeed", fmt.Sprintf("SELECT id, name FROM %s WHERE id = 1", tableNameParam), http.StatusOK, "", `[{"id":1,"name":"Alice"}]`},
1169-
{"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""},
1170-
{"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""},
1169+
{"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusOK, "write mode is 'blocked', only SELECT statements are allowed", ""},
1170+
{"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusOK, "write mode is 'blocked', only SELECT statements are allowed", ""},
11711171
}
11721172

11731173
for _, tc := range testCases {
@@ -1710,7 +1710,8 @@ func runBigQueryDataTypeTests(t *testing.T) {
17101710
api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke",
17111711
requestHeader: map[string]string{},
17121712
requestBody: bytes.NewBuffer([]byte(`{"int_val": 123}`)),
1713-
isErr: true,
1713+
want: `{\"error\":\"parameter \\\"string_val\\\" is required\"}`,
1714+
isErr: false,
17141715
},
17151716
{
17161717
name: "invoke my-array-datatype-tool",
@@ -2931,7 +2932,7 @@ func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, ta
29312932
api: "http://127.0.0.1:5000/api/tool/my-search-catalog-tool/invoke",
29322933
requestHeader: map[string]string{},
29332934
requestBody: bytes.NewBuffer([]byte(`{}`)),
2934-
isErr: true,
2935+
isErr: false,
29352936
},
29362937
{
29372938
name: "invoke my-search-catalog-tool",

0 commit comments

Comments
 (0)