Skip to content

Commit f43d767

Browse files
authored
chore: Add support for optional parameters (#25)
* chore: Add support for optional parameters * change error message * minor fix
1 parent c02d3af commit f43d767

File tree

4 files changed

+236
-3
lines changed

4 files changed

+236
-3
lines changed

core/e2e_test.go

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func TestE2E_Basic(t *testing.T) {
169169
// The Go SDK performs validation inside Invoke, so we check the error there.
170170
_, err := tool.Invoke(context.Background(), map[string]any{})
171171
require.Error(t, err)
172-
assert.Contains(t, err.Error(), "parameter \"num_rows\" is required")
172+
assert.Contains(t, err.Error(), "missing required parameter 'num_rows'")
173173
})
174174

175175
t.Run("test_run_tool_wrong_param_type", func(t *testing.T) {
@@ -337,3 +337,141 @@ func TestE2E_Auth(t *testing.T) {
337337
assert.Contains(t, err.Error(), "no field named row_data in claims")
338338
})
339339
}
340+
341+
// TestE2E_OptionalParams maps to the TestOptionalParams class
342+
func TestE2E_OptionalParams(t *testing.T) {
343+
// Helper to create a new client
344+
newClient := func(t *testing.T) *core.ToolboxClient {
345+
client, err := core.NewToolboxClient("http://localhost:5000")
346+
require.NoError(t, err, "Failed to create ToolboxClient")
347+
return client
348+
}
349+
350+
// Helper to load the search-rows tool
351+
searchRowsTool := func(t *testing.T, client *core.ToolboxClient) *core.ToolboxTool {
352+
tool, err := client.LoadTool("search-rows", context.Background())
353+
require.NoError(t, err, "Failed to load tool 'search-rows'")
354+
return tool
355+
}
356+
357+
t.Run("test_tool_schema_is_correct", func(t *testing.T) {
358+
client := newClient(t)
359+
tool := searchRowsTool(t, client)
360+
params := tool.Parameters()
361+
362+
// Convert slice to map for easy lookup
363+
paramMap := make(map[string]core.ParameterSchema)
364+
for _, p := range params {
365+
paramMap[p.Name] = p
366+
}
367+
368+
// Check required parameter 'email'
369+
emailParam, ok := paramMap["email"]
370+
require.True(t, ok, "email parameter should exist")
371+
assert.True(t, emailParam.Required, "'email' should be required")
372+
assert.Equal(t, "string", emailParam.Type)
373+
374+
// Check optional parameter 'data'
375+
dataParam, ok := paramMap["data"]
376+
require.True(t, ok, "data parameter should exist")
377+
assert.False(t, dataParam.Required, "'data' should be optional")
378+
assert.Equal(t, "string", dataParam.Type)
379+
380+
// Check optional parameter 'id'
381+
idParam, ok := paramMap["id"]
382+
require.True(t, ok, "id parameter should exist")
383+
assert.False(t, idParam.Required, "'id' should be optional")
384+
assert.Equal(t, "integer", idParam.Type)
385+
})
386+
387+
t.Run("test_run_tool_omitting_optionals", func(t *testing.T) {
388+
client := newClient(t)
389+
tool := searchRowsTool(t, client)
390+
391+
// Test case 1: Optional params are completely omitted
392+
response1, err1 := tool.Invoke(context.Background(), map[string]any{
393+
"email": "[email protected]",
394+
})
395+
require.NoError(t, err1)
396+
respStr1, ok1 := response1.(string)
397+
require.True(t, ok1)
398+
assert.Contains(t, respStr1, `"email":"[email protected]"`)
399+
assert.Contains(t, respStr1, "row2")
400+
assert.NotContains(t, respStr1, "row3")
401+
402+
// Test case 2: Optional params are explicitly nil
403+
// This should produce the same result as omitting them
404+
response2, err2 := tool.Invoke(context.Background(), map[string]any{
405+
"email": "[email protected]",
406+
"data": nil,
407+
"id": nil,
408+
})
409+
require.NoError(t, err2)
410+
respStr2, ok2 := response2.(string)
411+
require.True(t, ok2)
412+
assert.Equal(t, respStr1, respStr2)
413+
})
414+
415+
t.Run("test_run_tool_with_all_params_provided", func(t *testing.T) {
416+
client := newClient(t)
417+
tool := searchRowsTool(t, client)
418+
response, err := tool.Invoke(context.Background(), map[string]any{
419+
"email": "[email protected]",
420+
"data": "row3",
421+
"id": 3,
422+
})
423+
require.NoError(t, err)
424+
respStr, ok := response.(string)
425+
require.True(t, ok)
426+
assert.Contains(t, respStr, `"email":"[email protected]"`)
427+
assert.Contains(t, respStr, `"id":3`)
428+
assert.Contains(t, respStr, "row3")
429+
assert.NotContains(t, respStr, "row2")
430+
})
431+
432+
t.Run("test_run_tool_missing_required_param", func(t *testing.T) {
433+
client := newClient(t)
434+
tool := searchRowsTool(t, client)
435+
_, err := tool.Invoke(context.Background(), map[string]any{
436+
"data": "row5",
437+
"id": 5,
438+
})
439+
require.Error(t, err)
440+
assert.Contains(t, err.Error(), "missing required parameter 'email'")
441+
})
442+
443+
t.Run("test_run_tool_required_param_is_nil", func(t *testing.T) {
444+
client := newClient(t)
445+
tool := searchRowsTool(t, client)
446+
_, err := tool.Invoke(context.Background(), map[string]any{
447+
"email": nil,
448+
"id": 5,
449+
})
450+
require.Error(t, err)
451+
assert.Contains(t, err.Error(), "parameter 'email' is required but received a nil value")
452+
})
453+
454+
// Corresponds to tests that check server-side logic by providing data that doesn't match
455+
t.Run("test_run_tool_with_non_matching_data", func(t *testing.T) {
456+
client := newClient(t)
457+
tool := searchRowsTool(t, client)
458+
459+
// Test with a different email
460+
response, err := tool.Invoke(context.Background(), map[string]any{
461+
"email": "[email protected]",
462+
"id": 3,
463+
"data": "row3",
464+
})
465+
require.NoError(t, err)
466+
assert.Equal(t, "null", response, "Response should be null for non-matching email")
467+
468+
// Test with different data
469+
response, err = tool.Invoke(context.Background(), map[string]any{
470+
"email": "[email protected]",
471+
"id": 3,
472+
"data": "row4", // This data doesn't match the id
473+
})
474+
require.NoError(t, err)
475+
assert.Equal(t, "null", response, "Response should be null for non-matching data")
476+
})
477+
}

core/protocol.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
type ParameterSchema struct {
2424
Name string `json:"name"`
2525
Type string `json:"type"`
26+
Required bool `json:"required,omitempty"`
2627
Description string `json:"description"`
2728
AuthSources []string `json:"authSources,omitempty"`
2829
Items *ParameterSchema `json:"items,omitempty"`
@@ -31,7 +32,10 @@ type ParameterSchema struct {
3132
// validateType is a helper for manual type checking.
3233
func (p *ParameterSchema) validateType(value any) error {
3334
if value == nil {
34-
return fmt.Errorf("parameter '%s' received a nil value", p.Name)
35+
if p.Required {
36+
return fmt.Errorf("parameter '%s' is required but received a nil value", p.Name)
37+
}
38+
return nil
3539
}
3640

3741
switch p.Type {

core/protocol_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,82 @@ func TestParameterSchemaUndefinedType(t *testing.T) {
254254
}
255255

256256
}
257+
258+
func TestOptionalStringParameter(t *testing.T) {
259+
schema := ParameterSchema{
260+
Name: "nickname",
261+
Type: "string",
262+
Description: "An optional nickname",
263+
Required: false, // Explicitly optional
264+
}
265+
266+
t.Run("allows nil value for optional parameter", func(t *testing.T) {
267+
err := schema.validateType(nil)
268+
if err != nil {
269+
t.Errorf("validateType() with nil should not return an error for an optional parameter, but got: %v", err)
270+
}
271+
})
272+
273+
t.Run("allows valid string value", func(t *testing.T) {
274+
err := schema.validateType("my-name")
275+
if err != nil {
276+
t.Errorf("validateType() should not return an error for a valid string, but got: %v", err)
277+
}
278+
})
279+
}
280+
281+
func TestRequiredParameter(t *testing.T) {
282+
schema := ParameterSchema{
283+
Name: "id",
284+
Type: "integer",
285+
Description: "A required ID",
286+
Required: true, // Explicitly required
287+
}
288+
289+
t.Run("rejects nil value for required parameter", func(t *testing.T) {
290+
err := schema.validateType(nil)
291+
if err == nil {
292+
t.Errorf("validateType() with nil should return an error for a required parameter, but it didn't")
293+
}
294+
})
295+
296+
t.Run("allows valid integer value", func(t *testing.T) {
297+
err := schema.validateType(12345)
298+
if err != nil {
299+
t.Errorf("validateType() should not return an error for a valid integer, but got: %v", err)
300+
}
301+
})
302+
}
303+
304+
func TestOptionalArrayParameter(t *testing.T) {
305+
schema := ParameterSchema{
306+
Name: "optional_scores",
307+
Type: "array",
308+
Description: "An optional list of scores",
309+
Required: false,
310+
Items: &ParameterSchema{
311+
Type: "integer",
312+
},
313+
}
314+
315+
t.Run("allows nil value for optional array", func(t *testing.T) {
316+
err := schema.validateType(nil)
317+
if err != nil {
318+
t.Errorf("validateType() with nil should not return an error for an optional array, but got: %v", err)
319+
}
320+
})
321+
322+
t.Run("allows valid integer slice", func(t *testing.T) {
323+
err := schema.validateType([]int{95, 100})
324+
if err != nil {
325+
t.Errorf("validateType() should not return an error for a valid slice, but got: %v", err)
326+
}
327+
})
328+
329+
t.Run("rejects slice with wrong item type", func(t *testing.T) {
330+
err := schema.validateType([]string{"not", "an", "int"})
331+
if err == nil {
332+
t.Errorf("validateType() should have returned an error for a slice with incorrect item types, but it didn't")
333+
}
334+
})
335+
}

core/tool.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,22 @@ func (tt *ToolboxTool) validateAndBuildPayload(input map[string]any) (map[string
343343
}
344344
}
345345

346+
for _, param := range tt.parameters {
347+
if param.Required {
348+
// A required parameter must be present in either the user input or as a bound parameter.
349+
_, isProvided := input[param.Name]
350+
_, isBound := tt.boundParams[param.Name]
351+
352+
if !isProvided && !isBound {
353+
return nil, fmt.Errorf("missing required parameter '%s'", param.Name)
354+
}
355+
}
356+
}
357+
346358
// Initialize the final payload with the validated user input.
347359
finalPayload := make(map[string]any, len(input)+len(tt.boundParams))
348360
for k, v := range input {
349-
if _, ok := paramSchema[k]; ok {
361+
if _, ok := paramSchema[k]; ok && v != nil {
350362
finalPayload[k] = v
351363
}
352364
}

0 commit comments

Comments
 (0)