Skip to content

Commit c80d8d6

Browse files
authored
feat: Add support for Map parameter type (#51)
* feat: Add support for Map parameter type * seperate validation of parameter schema from validation of input * fix tests * fix tests * fix tests * fix validate functino * fix validate functino * fix validate function * fix unit tests * fix unit tests * set version to 34 in tbgenkit tests * fix tests * fix tests * add AdditionalProperties in inputSchema * skip genkit test * update toolbox version to 0.12.0 * minor changes * minor changes * minor changes
1 parent 5664121 commit c80d8d6

File tree

10 files changed

+647
-44
lines changed

10 files changed

+647
-44
lines changed

core/client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,19 @@ func (tc *ToolboxClient) newToolboxTool(
9999

100100
// Iterate over the tool's parameters from the schema to categorize them.
101101
for _, p := range schema.Parameters {
102+
103+
if ap, ok := p.AdditionalProperties.(map[string]any); ok {
104+
apParam, err := mapToSchema(ap)
105+
if err != nil {
106+
return nil, nil, nil, err
107+
}
108+
p.AdditionalProperties = apParam
109+
}
110+
// Validate parameter schema
111+
if err := p.ValidateDefinition(); err != nil {
112+
// Return a detailed error indicating which tool failed validation.
113+
return nil, nil, nil, fmt.Errorf("invalid schema for tool '%s': %w", name, err)
114+
}
102115
paramSchema[p.Name] = struct{}{}
103116

104117
if len(p.AuthSources) > 0 {

core/e2e_setup_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ func getEnvVar(key string) string {
4343
return value
4444
}
4545

46-
func accessSecretVersion(ctx context.Context, projectID, secretID string) string {
46+
func accessSecretVersion(ctx context.Context, projectID, secretID string, version string) string {
4747
client, err := secretmanager.NewClient(ctx)
4848
if err != nil {
4949
log.Fatalf("Failed to create secretmanager client: %v", err)
5050
}
5151
defer client.Close()
5252

5353
req := &secretmanagerpb.AccessSecretVersionRequest{
54-
Name: fmt.Sprintf("projects/%s/secrets/%s/versions/latest", projectID, secretID),
54+
Name: fmt.Sprintf("projects/%s/secrets/%s/versions/%s", projectID, secretID, version),
5555
}
5656

5757
result, err := client.AccessSecretVersion(ctx, req)

core/e2e_test.go

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ func TestMain(m *testing.M) {
4343

4444
// Get secrets and auth tokens
4545
log.Println("Fetching secrets and auth tokens...")
46-
toolsManifestContent := accessSecretVersion(ctx, projectID, "sdk_testing_tools")
47-
clientID1 := accessSecretVersion(ctx, projectID, "sdk_testing_client1")
48-
clientID2 := accessSecretVersion(ctx, projectID, "sdk_testing_client2")
46+
toolsManifestContent := accessSecretVersion(ctx, projectID, "sdk_testing_tools", "34")
47+
clientID1 := accessSecretVersion(ctx, projectID, "sdk_testing_client1", "latest")
48+
clientID2 := accessSecretVersion(ctx, projectID, "sdk_testing_client2", "latest")
4949
authToken1 = getAuthToken(ctx, clientID1)
5050
authToken2 = getAuthToken(ctx, clientID2)
5151

@@ -78,7 +78,6 @@ func TestMain(m *testing.M) {
7878
os.Exit(exitCode)
7979
}
8080

81-
// TestE2E_Basic maps to the TestBasicE2E class
8281
func TestE2E_Basic(t *testing.T) {
8382
// Helper to create a new client for each sub-test, like a function-scoped fixture
8483
newClient := func(t *testing.T) *core.ToolboxClient {
@@ -132,7 +131,7 @@ func TestE2E_Basic(t *testing.T) {
132131
toolset, err := client.LoadToolset("", context.Background())
133132
require.NoError(t, err)
134133

135-
assert.Len(t, toolset, 6)
134+
assert.Len(t, toolset, 7)
136135
toolNames := make(map[string]struct{})
137136
for _, tool := range toolset {
138137
toolNames[tool.Name()] = struct{}{}
@@ -144,6 +143,7 @@ func TestE2E_Basic(t *testing.T) {
144143
"get-row-by-id": {},
145144
"get-n-rows": {},
146145
"search-rows": {},
146+
"process-data": {},
147147
}
148148
assert.Equal(t, expectedTools, toolNames)
149149
})
@@ -182,7 +182,6 @@ func TestE2E_Basic(t *testing.T) {
182182
})
183183
}
184184

185-
// TestE2E_BindParams maps to the TestBindParams class
186185
func TestE2E_BindParams(t *testing.T) {
187186
newClient := func(t *testing.T) *core.ToolboxClient {
188187
client, err := core.NewToolboxClient("http://localhost:5000")
@@ -236,7 +235,6 @@ func TestE2E_BindParams(t *testing.T) {
236235
})
237236
}
238237

239-
// TestE2E_Auth maps to the TestAuth class
240238
func TestE2E_Auth(t *testing.T) {
241239
newClient := func(t *testing.T) *core.ToolboxClient {
242240
client, err := core.NewToolboxClient("http://localhost:5000")
@@ -338,7 +336,6 @@ func TestE2E_Auth(t *testing.T) {
338336
})
339337
}
340338

341-
// TestE2E_OptionalParams maps to the TestOptionalParams class
342339
func TestE2E_OptionalParams(t *testing.T) {
343340
// Helper to create a new client
344341
newClient := func(t *testing.T) *core.ToolboxClient {
@@ -475,3 +472,112 @@ func TestE2E_OptionalParams(t *testing.T) {
475472
assert.Equal(t, "null", response, "Response should be null for non-matching data")
476473
})
477474
}
475+
476+
func TestE2E_MapParams(t *testing.T) {
477+
// Helper to create a new client
478+
newClient := func(t *testing.T) *core.ToolboxClient {
479+
client, err := core.NewToolboxClient("http://localhost:5000")
480+
require.NoError(t, err, "Failed to create ToolboxClient")
481+
return client
482+
}
483+
484+
// Helper to load the process-data tool
485+
processDataTool := func(t *testing.T, client *core.ToolboxClient) *core.ToolboxTool {
486+
tool, err := client.LoadTool("process-data", context.Background())
487+
require.NoError(t, err, "Failed to load tool 'process-data'")
488+
return tool
489+
}
490+
491+
t.Run("test_tool_schema_is_correct", func(t *testing.T) {
492+
client := newClient(t)
493+
tool := processDataTool(t, client)
494+
params := tool.Parameters()
495+
496+
// Convert slice to map for easy lookup
497+
paramMap := make(map[string]core.ParameterSchema)
498+
for _, p := range params {
499+
paramMap[p.Name] = p
500+
}
501+
502+
// Verify 'execution_context' parameter.
503+
execCtxParam, ok := paramMap["execution_context"]
504+
require.True(t, ok, "'execution_context' parameter should exist")
505+
assert.True(t, execCtxParam.Required, "'execution_context' should be required")
506+
assert.Equal(t, "object", execCtxParam.Type, "'execution_context' type should be object")
507+
508+
// Verify 'user_scores' parameter.
509+
userScoresParam, ok := paramMap["user_scores"]
510+
require.True(t, ok, "'user_scores' parameter should exist")
511+
assert.True(t, userScoresParam.Required, "'user_scores' should be required")
512+
assert.Equal(t, "object", userScoresParam.Type, "'user_scores' type should be object")
513+
514+
// Verify 'feature_flags' parameter.
515+
featureFlagsParam, ok := paramMap["feature_flags"]
516+
require.True(t, ok, "'feature_flags' parameter should exist")
517+
assert.False(t, featureFlagsParam.Required, "'feature_flags' should be optional")
518+
assert.Equal(t, "object", featureFlagsParam.Type, "'feature_flags' type should be object")
519+
})
520+
521+
t.Run("test_run_tool_with_all_map_params", func(t *testing.T) {
522+
client := newClient(t)
523+
tool := processDataTool(t, client)
524+
525+
// Invoke the tool with valid map parameters.
526+
response, err := tool.Invoke(context.Background(), map[string]any{
527+
"execution_context": map[string]any{
528+
"env": "prod",
529+
"id": 1234,
530+
"user": 1234.5,
531+
},
532+
"user_scores": map[string]any{
533+
"user1": 100,
534+
"user2": 200,
535+
},
536+
"feature_flags": map[string]any{
537+
"new_feature": true,
538+
},
539+
})
540+
require.NoError(t, err)
541+
respStr, ok := response.(string)
542+
require.True(t, ok, "Response should be a string")
543+
544+
assert.Contains(t, respStr, `"execution_context":{"env":"prod","id":1234,"user":1234.5}`)
545+
assert.Contains(t, respStr, `"user_scores":{"user1":100,"user2":200}`)
546+
assert.Contains(t, respStr, `"feature_flags":{"new_feature":true}`)
547+
})
548+
549+
t.Run("test_run_tool_omitting_optional_map", func(t *testing.T) {
550+
client := newClient(t)
551+
tool := processDataTool(t, client)
552+
553+
// Invoke the tool without the optional 'feature_flags' parameter.
554+
response, err := tool.Invoke(context.Background(), map[string]any{
555+
"execution_context": map[string]any{"env": "dev"},
556+
"user_scores": map[string]any{"user3": 300},
557+
})
558+
require.NoError(t, err)
559+
respStr, ok := response.(string)
560+
require.True(t, ok, "Response should be a string")
561+
562+
assert.Contains(t, respStr, `"execution_context":{"env":"dev"}`)
563+
assert.Contains(t, respStr, `"user_scores":{"user3":300}`)
564+
assert.Contains(t, respStr, `"feature_flags":null`)
565+
})
566+
567+
t.Run("test_run_tool_with_wrong_map_value_type", func(t *testing.T) {
568+
client := newClient(t)
569+
tool := processDataTool(t, client)
570+
571+
// Attempt to invoke the tool with an incorrect type in a map value.
572+
_, err := tool.Invoke(context.Background(), map[string]any{
573+
"execution_context": map[string]any{"env": "staging"},
574+
"user_scores": map[string]any{
575+
"user4": "not-an-integer",
576+
},
577+
})
578+
579+
// Assert that an error was returned.
580+
require.Error(t, err, "Expected an error for wrong map value type")
581+
assert.Contains(t, err.Error(), "expects an integer, but got string", "Error message should indicate a validation failure")
582+
})
583+
}

core/protocol.go

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ import (
2121

2222
// Schema for a tool parameter.
2323
type ParameterSchema struct {
24-
Name string `json:"name"`
25-
Type string `json:"type"`
26-
Required bool `json:"required,omitempty"`
27-
Description string `json:"description"`
28-
AuthSources []string `json:"authSources,omitempty"`
29-
Items *ParameterSchema `json:"items,omitempty"`
24+
Name string `json:"name"`
25+
Type string `json:"type"`
26+
Required bool `json:"required,omitempty"`
27+
Description string `json:"description"`
28+
AuthSources []string `json:"authSources,omitempty"`
29+
Items *ParameterSchema `json:"items,omitempty"`
30+
AdditionalProperties any `json:"additionalProperties,omitempty"`
3031
}
3132

3233
// validateType is a helper for manual type checking.
@@ -64,22 +65,90 @@ func (p *ParameterSchema) validateType(value any) error {
6465
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
6566
return fmt.Errorf("parameter '%s' expects an array/slice, but got %T", p.Name, value)
6667
}
67-
if p.Items == nil {
68-
return fmt.Errorf("parameter '%s' is an array but is missing item type definition", p.Name)
69-
}
7068
for i := range v.Len() {
7169
item := v.Index(i).Interface()
7270

7371
if err := p.Items.validateType(item); err != nil {
7472
return fmt.Errorf("error in array '%s' at index %d: %w", p.Name, i, err)
7573
}
7674
}
75+
case "object":
76+
// First, check that the value is a map with string keys.
77+
valMap, ok := value.(map[string]any)
78+
if !ok {
79+
return fmt.Errorf("parameter '%s' expects a map, but got %T", p.Name, value)
80+
}
81+
82+
switch ap := p.AdditionalProperties.(type) {
83+
// No validation of values
84+
case bool:
85+
86+
// Validate type for each value in map
87+
case *ParameterSchema:
88+
for key, val := range valMap {
89+
if err := ap.validateType(val); err != nil {
90+
return fmt.Errorf("error in object '%s' for key '%s': %w", p.Name, key, err)
91+
}
92+
}
93+
94+
default:
95+
// This is a schema / manifest error.
96+
return fmt.Errorf(
97+
"invalid schema for parameter '%s': AdditionalProperties must be a boolean or a map[string]any, but got %T",
98+
p.Name,
99+
ap,
100+
)
101+
}
77102
default:
78103
return fmt.Errorf("unknown type '%s' in schema for parameter '%s'", p.Type, p.Name)
79104
}
80105
return nil
81106
}
82107

108+
// ValidateDefinition checks if the schema itself is well-formed.
109+
func (p *ParameterSchema) ValidateDefinition() error {
110+
if p.Type == "" {
111+
return fmt.Errorf("schema validation failed for '%s': type is missing", p.Name)
112+
}
113+
114+
switch p.Type {
115+
case "array":
116+
if p.Items == nil {
117+
return fmt.Errorf("parameter '%s' is an array but is missing item type definition", p.Name)
118+
}
119+
// Recursively validate the nested schema's definition.
120+
if err := p.Items.ValidateDefinition(); err != nil {
121+
return err
122+
}
123+
124+
case "object":
125+
switch ap := p.AdditionalProperties.(type) {
126+
case bool:
127+
// Valid scenario
128+
case *ParameterSchema:
129+
if err := ap.ValidateDefinition(); err != nil {
130+
return err
131+
}
132+
default:
133+
// Any other type is an invalid schema definition.
134+
return fmt.Errorf(
135+
"invalid schema for parameter '%s': AdditionalProperties must be a boolean or a schema, but got %T",
136+
p.Name,
137+
ap,
138+
)
139+
}
140+
141+
case "string", "integer", "float", "boolean":
142+
// No type-specific rules for these.
143+
break
144+
145+
default:
146+
return fmt.Errorf("unknown schema type '%s' for parameter '%s'", p.Type, p.Name)
147+
}
148+
149+
return nil
150+
}
151+
83152
// Schema for a tool.
84153
type ToolSchema struct {
85154
Description string `json:"description"`

0 commit comments

Comments
 (0)