diff --git a/schema.go b/schema.go index 94b4530d..2d970790 100644 --- a/schema.go +++ b/schema.go @@ -14,6 +14,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/danielgtaylor/huma/v2/validation" @@ -676,6 +677,29 @@ type SchemaTransformer interface { TransformSchema(r Registry, s *Schema) *Schema } +type SchemaFactory func(Registry) *Schema + +var ( + schemaFactoryMu sync.RWMutex + schemaFactories = map[reflect.Type]SchemaFactory{} +) + +// RegisterTypeSchema associates a schema factory with the given type. +// The provided factory runs whenever SchemaFromType handles that type. +// Later calls replace any previously registered factory. +func RegisterTypeSchema(t reflect.Type, factory SchemaFactory) { + if t == nil { + panic("huma: RegisterTypeSchema called with nil type") + } + if factory == nil { + panic("huma: RegisterTypeSchema called with nil factory") + } + + schemaFactoryMu.Lock() + schemaFactories[t] = factory + schemaFactoryMu.Unlock() +} + // SchemaFromType returns a schema for a given type, using the registry to // possibly create references for nested structs. The schema that is returned // can then be passed to `huma.Validate` to efficiently validate incoming @@ -699,12 +723,46 @@ func SchemaFromType(r Registry, t reflect.Type) *Schema { return s } +func lookupTypeSchemaFactory(t reflect.Type) SchemaFactory { + if t == nil { + return nil + } + + schemaFactoryMu.RLock() + factory := schemaFactories[t] + schemaFactoryMu.RUnlock() + return factory +} + +func schemaFromRegisteredFactory(r Registry, t reflect.Type) *Schema { + factory := lookupTypeSchemaFactory(t) + if factory == nil { + return nil + } + + s := factory(r) + if s == nil { + return nil + } + + s.PrecomputeMessages() + return s +} + func schemaFromType(r Registry, t reflect.Type) *Schema { + if custom := schemaFromRegisteredFactory(r, t); custom != nil { + return custom + } + isPointer := t.Kind() == reflect.Pointer s := Schema{} t = deref(t) + if custom := schemaFromRegisteredFactory(r, t); custom != nil { + return custom + } + v := reflect.New(t).Interface() if sp, ok := v.(SchemaProvider); ok { // Special case: type provides its own schema. Do not try to generate. diff --git a/schema_test.go b/schema_test.go index 18ab17a1..b097f37e 100644 --- a/schema_test.go +++ b/schema_test.go @@ -1544,3 +1544,44 @@ func TestSchemaTransformer(t *testing.T) { updateSchema2 := huma.SchemaFromType(r, reflect.TypeOf(ExampleUpdateStruct{})) validateSchema(updateSchema2) } + +type customSchemaInt int64 + +func TestRegisterTypeSchema(t *testing.T) { + huma.RegisterTypeSchema(reflect.TypeOf(customSchemaInt(0)), func(r huma.Registry) *huma.Schema { + return &huma.Schema{Type: huma.TypeString} + }) + + registry := huma.NewMapRegistry("#/components/schemas/", huma.DefaultSchemaNamer) + + type input struct { + ID customSchemaInt `json:"id"` + Optional *customSchemaInt `json:"optional,omitempty"` + Nested struct { + Value customSchemaInt `json:"value"` + } `json:"nested"` + } + + schema := huma.SchemaFromType(registry, reflect.TypeOf(input{})) + require.NotNil(t, schema) + require.Contains(t, schema.Properties, "id") + assert.Equal(t, huma.TypeString, schema.Properties["id"].Type) + + require.Contains(t, schema.Properties, "optional") + assert.Equal(t, huma.TypeString, schema.Properties["optional"].Type) + + nested := schema.Properties["nested"] + require.NotNil(t, nested) + nestedSchema := nested + if nested.Ref != "" { + nestedSchema = registry.SchemaFromRef(nested.Ref) + require.NotNil(t, nestedSchema) + } + assert.Equal(t, huma.TypeObject, nestedSchema.Type) + require.Contains(t, nestedSchema.Properties, "value") + assert.Equal(t, huma.TypeString, nestedSchema.Properties["value"].Type) + + custom := registry.Schema(reflect.TypeOf(customSchemaInt(0)), true, "Standalone") + require.NotNil(t, custom) + assert.Equal(t, huma.TypeString, custom.Type) +}