Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions go/ai/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,146 @@ func (t *mockTool) Definition() *ToolDefinition {
func (t *mockTool) RunRaw(ctx context.Context, input any) (any, error) {
return nil, nil
}

func TestWithInputTypeDefaultValues(t *testing.T) {
t.Run("struct field values are captured as DefaultInput", func(t *testing.T) {
type TestInput struct {
Name string `json:"name"`
Age int `json:"age"`
Active bool `json:"active"`
Balance float64 `json:"balance"`
}

input := TestInput{
Name: "John",
Age: 30,
Active: true,
Balance: 100.50,
}

opt := WithInputType(input).(*inputOptions)

expectedDefaults := map[string]any{
"name": "John",
"age": float64(30),
"active": true,
"balance": 100.50,
}

if diff := cmp.Diff(expectedDefaults, opt.DefaultInput); diff != "" {
t.Errorf("DefaultInput mismatch (-want +got):\n%s", diff)
}
})

t.Run("zero values are included in DefaultInput", func(t *testing.T) {
type TestInput struct {
Name string `json:"name"`
Count int `json:"count"`
Active bool `json:"active"`
}

input := TestInput{} // all zero values

opt := WithInputType(input).(*inputOptions)

expectedDefaults := map[string]any{
"name": "",
"count": float64(0),
"active": false,
}

if diff := cmp.Diff(expectedDefaults, opt.DefaultInput); diff != "" {
t.Errorf("DefaultInput should include zero values, diff (-want +got):\n%s", diff)
}
})

t.Run("map input is used directly as DefaultInput", func(t *testing.T) {
input := map[string]any{
"name": "default",
"age": 25,
}

opt := WithInputType(input).(*inputOptions)

if diff := cmp.Diff(input, opt.DefaultInput); diff != "" {
t.Errorf("DefaultInput should match map input, diff (-want +got):\n%s", diff)
}
})
Comment on lines +709 to +720
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The test map input is used directly as DefaultInput confirms a behavior in WithInputType that creates an inconsistency in how default inputs are handled. When a map[string]any is provided, it bypasses the JSON marshaling/unmarshaling cycle that struct inputs go through.

This leads to different types for numeric values in DefaultInput:

  • For structs, int values become float64 (e.g., lines 677, 700).
  • For maps, int values remain int (as tested here).

This inconsistency can lead to subtle bugs if downstream code expects all numbers in DefaultInput to be float64. To ensure consistent behavior, it would be better to process all input types through the same JSON marshaling/unmarshaling logic in WithInputType.

I'd recommend updating WithInputType to remove the special case for map[string]any and then adjusting this test to expect float64 values, similar to the other tests in this file. Since I cannot suggest changes to option.go, here is how this test could be updated after that change:

t.Run("map input is converted consistently", func(t *testing.T) {
    input := map[string]any{
        "name": "default",
        "age":  25,
    }

    opt := WithInputType(input).(*inputOptions)

    expected := map[string]any{
        "name": "default",
        "age":  float64(25),
    }

    if diff := cmp.Diff(expected, opt.DefaultInput); diff != "" {
        t.Errorf("DefaultInput should be consistent with struct processing, diff (-want +got):\n%s", diff)
    }
})


t.Run("jsonschema default tag is reflected in schema", func(t *testing.T) {
type TestInputWithDefaults struct {
Name string `json:"name" jsonschema:"default=guest"`
Age int `json:"age" jsonschema:"default=25"`
Active bool `json:"active" jsonschema:"default=true"`
}

opt := WithInputType(TestInputWithDefaults{}).(*inputOptions)

props, ok := opt.InputSchema["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties in schema")
}

nameSchema, ok := props["name"].(map[string]any)
if !ok {
t.Fatal("expected name property in schema")
}
if nameSchema["default"] != "guest" {
t.Errorf("expected name default to be 'guest', got %v", nameSchema["default"])
}

ageSchema, ok := props["age"].(map[string]any)
if !ok {
t.Fatal("expected age property in schema")
}
if ageSchema["default"] != float64(25) {
t.Errorf("expected age default to be 25, got %v", ageSchema["default"])
}

activeSchema, ok := props["active"].(map[string]any)
if !ok {
t.Fatal("expected active property in schema")
}
if activeSchema["default"] != true {
t.Errorf("expected active default to be true, got %v", activeSchema["default"])
}
})
Comment on lines +722 to +759
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test contains repetitive checks for each field in the schema. To improve maintainability and readability, consider refactoring it into a table-driven test, similar to the pattern used in TestSchemaAsMapWithDefaults in go/internal/base/json_test.go.

	t.Run("jsonschema default tag is reflected in schema", func(t *testing.T) {
		type TestInputWithDefaults struct {
			Name   string `json:"name" jsonschema:"default=guest"`
			Age    int    `json:"age" jsonschema:"default=25"`
			Active bool   `json:"active" jsonschema:"default=true"`
		}

		opt := WithInputType(TestInputWithDefaults{}).(*inputOptions)

		props, ok := opt.InputSchema["properties"].(map[string]any)
		if !ok {
			t.Fatal("expected properties in schema")
		}

		tests := []struct {
			field    string
			expected any
		}{
			{"name", "guest"},
			{"age", float64(25)},
			{"active", true},
		}

		for _, tc := range tests {
			t.Run(tc.field, func(t *testing.T) {
				prop, ok := props[tc.field].(map[string]any)
				if !ok {
					t.Fatalf("expected %s property in schema", tc.field)
				}
				if prop["default"] != tc.expected {
					t.Errorf("expected default for %s to be %v, got %v", tc.field, tc.expected, prop["default"])
				}
			})
		}
	})


t.Run("struct values take precedence over jsonschema defaults", func(t *testing.T) {
type TestInputWithDefaults struct {
Name string `json:"name" jsonschema:"default=guest"`
Age int `json:"age" jsonschema:"default=25"`
}

input := TestInputWithDefaults{
Name: "admin",
Age: 40,
}

opt := WithInputType(input).(*inputOptions)

// DefaultInput should have the struct values, not the jsonschema defaults
expectedDefaults := map[string]any{
"name": "admin",
"age": float64(40),
}

if diff := cmp.Diff(expectedDefaults, opt.DefaultInput); diff != "" {
t.Errorf("struct values should be used as DefaultInput, diff (-want +got):\n%s", diff)
}

// But the schema should still have the jsonschema tag defaults
props, ok := opt.InputSchema["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties in schema")
}

nameSchema, ok := props["name"].(map[string]any)
if !ok {
t.Fatal("expected name property in schema")
}
if nameSchema["default"] != "guest" {
t.Errorf("schema should retain jsonschema default 'guest', got %v", nameSchema["default"])
}
})
}
208 changes: 208 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,214 @@ type HelloPromptInput struct {
Name string
}

func TestPromptDefaultInput(t *testing.T) {
reg := registry.New()

type GreetingInput struct {
Name string `json:"name"`
Greeting string `json:"greeting"`
}

t.Run("uses struct values as defaults when nil input provided", func(t *testing.T) {
p := DefinePrompt(
reg, "greeting-nil",
WithPrompt("{{greeting}} {{name}}!"),
WithInputType(GreetingInput{Name: "World", Greeting: "Hello"}),
)

req, err := p.Render(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

want := "Hello World!"
got := req.Messages[0].Content[0].Text
if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("partial input uses defaults for missing optional fields", func(t *testing.T) {
// Note: Fields must use omitempty to be optional and allow partial input
type PartialGreetingInput struct {
Name string `json:"name,omitempty"`
Greeting string `json:"greeting,omitempty"`
}

p := DefinePrompt(
reg, "greeting-partial",
WithPrompt("{{greeting}} {{name}}!"),
WithInputType(PartialGreetingInput{Name: "World", Greeting: "Hello"}),
)

// Only provide name, greeting should come from default
req, err := p.Render(context.Background(), map[string]any{"name": "Alice"})
if err != nil {
t.Fatal(err)
}

want := "Hello Alice!"
got := req.Messages[0].Content[0].Text
if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("provided input overrides defaults", func(t *testing.T) {
p := DefinePrompt(
reg, "greeting-override",
WithPrompt("{{greeting}} {{name}}!"),
WithInputType(GreetingInput{Name: "World", Greeting: "Hello"}),
)

// Provide both values, should override defaults
req, err := p.Render(context.Background(), map[string]any{"name": "Bob", "greeting": "Hi"})
if err != nil {
t.Fatal(err)
}

want := "Hi Bob!"
got := req.Messages[0].Content[0].Text
if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("map default input works", func(t *testing.T) {
p := DefinePrompt(
reg, "greeting-map",
WithPrompt("{{greeting}} {{name}}!"),
WithInputType(map[string]any{"name": "Universe", "greeting": "Howdy"}),
)

req, err := p.Render(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

want := "Howdy Universe!"
got := req.Messages[0].Content[0].Text
if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("zero values are treated as valid defaults", func(t *testing.T) {
type CountInput struct {
Count int `json:"count"`
Show bool `json:"show"`
}

p := DefinePrompt(
reg, "count-zero",
WithPrompt("Count: {{count}}, Show: {{show}}"),
WithInputType(CountInput{Count: 0, Show: false}),
)

req, err := p.Render(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

want := "Count: 0, Show: false"
got := req.Messages[0].Content[0].Text
if got != want {
t.Errorf("got %q want %q", got, want)
}
})
}

func TestPromptDefaultInputWithJsonschemaDefaults(t *testing.T) {
reg := registry.New()

// Test that jsonschema defaults in tags are correctly reflected in schema
type InputWithSchemaDefaults struct {
Name string `json:"name" jsonschema:"default=guest"`
Priority int `json:"priority" jsonschema:"default=5"`
}

t.Run("struct values override jsonschema defaults during rendering", func(t *testing.T) {
p := DefinePrompt(
reg, "schema-defaults-override",
WithPrompt("{{name}} (priority: {{priority}})"),
WithInputType(InputWithSchemaDefaults{Name: "admin", Priority: 10}),
)

req, err := p.Render(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

// Should use struct values, not jsonschema defaults
want := "admin (priority: 10)"
got := req.Messages[0].Content[0].Text
if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("schema has jsonschema defaults exposed", func(t *testing.T) {
p := DefinePrompt(
reg, "schema-defaults-exposed",
WithPrompt("{{name}}"),
WithInputType(InputWithSchemaDefaults{Name: "admin", Priority: 10}),
)

desc := p.(api.Action).Desc()
schema := desc.InputSchema

props, ok := schema["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties in input schema")
}

nameSchema, ok := props["name"].(map[string]any)
if !ok {
t.Fatal("expected name property in schema")
}

if nameSchema["default"] != "guest" {
t.Errorf("expected schema default to be 'guest', got %v", nameSchema["default"])
}

prioritySchema, ok := props["priority"].(map[string]any)
if !ok {
t.Fatal("expected priority property in schema")
}

if prioritySchema["default"] != float64(5) {
t.Errorf("expected schema default to be 5, got %v", prioritySchema["default"])
}
})
Comment on lines +299 to +331
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test has repetitive checks for schema properties. For better readability and maintainability, you could refactor this into a table-driven test. This would make it easier to add more properties to check in the future.

	t.Run("schema has jsonschema defaults exposed", func(t *testing.T) {
		p := DefinePrompt(
			reg, "schema-defaults-exposed",
			WithPrompt("{{name}}"),
			WithInputType(InputWithSchemaDefaults{Name: "admin", Priority: 10}),
		)

		desc := p.(api.Action).Desc()
		schema := desc.InputSchema

		props, ok := schema["properties"].(map[string]any)
		if !ok {
			t.Fatal("expected properties in input schema")
		}

		testCases := []struct {
			field    string
			expected any
		}{
			{"name", "guest"},
			{"priority", float64(5)},
		}

		for _, tc := range testCases {
			t.Run(tc.field, func(t *testing.T) {
				prop, ok := props[tc.field].(map[string]any)
				if !ok {
					t.Fatalf("expected %s property in schema", tc.field)
				}
				if prop["default"] != tc.expected {
					t.Errorf("expected schema default to be %v, got %v", tc.expected, prop["default"])
				}
			})
		}
	})


t.Run("defaultInput in metadata uses struct values not jsonschema defaults", func(t *testing.T) {
p := DefinePrompt(
reg, "schema-defaults-metadata",
WithPrompt("{{name}}"),
WithInputType(InputWithSchemaDefaults{Name: "admin", Priority: 10}),
)

desc := p.(api.Action).Desc()
promptMeta, ok := desc.Metadata["prompt"].(map[string]any)
if !ok {
t.Fatal("expected prompt metadata")
}

defaultInput, ok := promptMeta["defaultInput"].(map[string]any)
if !ok {
t.Fatal("expected defaultInput in prompt metadata")
}

if defaultInput["name"] != "admin" {
t.Errorf("expected defaultInput name to be 'admin', got %v", defaultInput["name"])
}

if defaultInput["priority"] != float64(10) {
t.Errorf("expected defaultInput priority to be 10, got %v", defaultInput["priority"])
}
})
Comment on lines +333 to +358
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous test, this one can be refactored into a table-driven test to reduce code duplication and improve maintainability.

	t.Run("defaultInput in metadata uses struct values not jsonschema defaults", func(t *testing.T) {
		p := DefinePrompt(
			reg, "schema-defaults-metadata",
			WithPrompt("{{name}}"),
			WithInputType(InputWithSchemaDefaults{Name: "admin", Priority: 10}),
		)

		desc := p.(api.Action).Desc()
		promptMeta, ok := desc.Metadata["prompt"].(map[string]any)
		if !ok {
			t.Fatal("expected prompt metadata")
		}

		defaultInput, ok := promptMeta["defaultInput"].(map[string]any)
		if !ok {
			t.Fatal("expected defaultInput in prompt metadata")
		}

		testCases := []struct {
			field    string
			expected any
		}{
			{"name", "admin"},
			{"priority", float64(10)},
		}

		for _, tc := range testCases {
			t.Run(tc.field, func(t *testing.T) {
				if defaultInput[tc.field] != tc.expected {
					t.Errorf("expected defaultInput %s to be %v, got %v", tc.field, tc.expected, defaultInput[tc.field])
				}
			})
		}
	})

}

func definePromptModel(reg api.Registry) Model {
return DefineModel(reg, "test/chat",
&ModelOptions{Supports: &ModelSupports{
Expand Down
Loading
Loading