diff --git a/structs/structs.go b/structs/structs.go index 0c48466..95fd645 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -40,19 +40,13 @@ func Walk(s interface{}, callback CallbackFunc) { } } -// FilterStruct filters the struct based on include and exclude fields and returns a new struct. -// - input: the original struct. -// - includeFields: list of fields to include (if empty, includes all). -// - excludeFields: list of fields to exclude (processed after include). -func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, error) { - var zeroValue T +func walkFilteredFields[T any](input T, includeFields, excludeFields []string, walker func(field reflect.StructField, value reflect.Value)) error { val := reflect.ValueOf(input) if val.Kind() == reflect.Ptr { val = val.Elem() } - if val.Kind() != reflect.Struct { - return zeroValue, errors.New("input must be a struct") + return errors.New("input must be a struct") } includeMap := make(map[string]bool) @@ -66,7 +60,6 @@ func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, err } typeOfStruct := val.Type() - filteredStruct := reflect.New(typeOfStruct).Elem() for i := 0; i < val.NumField(); i++ { field := typeOfStruct.Field(i) @@ -77,13 +70,62 @@ func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, err fieldValue := val.Field(i) if (len(includeMap) == 0 || includeMap[fieldName]) && !excludeMap[fieldName] { - filteredStruct.Field(i).Set(fieldValue) + walker(field, fieldValue) } } + return nil +} + +// FilterStruct filters the struct based on include and exclude fields and returns a new struct. +// - input: the original struct. +// - includeFields: list of fields to include (if empty, includes all). +// - excludeFields: list of fields to exclude (processed after include). +func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, error) { + var zeroValue T + val := reflect.ValueOf(input) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + filteredStruct := reflect.New(val.Type()).Elem() + + walker := func(field reflect.StructField, value reflect.Value) { + filteredStruct.FieldByName(field.Name).Set(value) + } + + if err := walkFilteredFields(input, includeFields, excludeFields, walker); err != nil { + return zeroValue, err + } return filteredStruct.Interface().(T), nil } +func FilterStructToMap[T any](input T, includeFields, excludeFields []string) (map[string]any, error) { + resultMap := make(map[string]any) + + walker := func(field reflect.StructField, value reflect.Value) { + jsonTag := field.Tag.Get("json") + jsonKey := strings.Split(jsonTag, ",")[0] + + if jsonKey == "" || jsonKey == "-" { + return + } + + fieldValue := value.Interface() + if strings.Contains(jsonTag, "omitempty") && value.IsZero() { + return + } + + resultMap[jsonKey] = fieldValue + } + + if err := walkFilteredFields(input, includeFields, excludeFields, walker); err != nil { + return nil, err + } + + return resultMap, nil +} + // GetStructFields returns all the top-level field names from the given struct. // - input: the original struct. // Returns a slice of field names or an error if the input is not a struct. diff --git a/structs/structs_test.go b/structs/structs_test.go index 5e4dea0..9e1c9b5 100644 --- a/structs/structs_test.go +++ b/structs/structs_test.go @@ -16,6 +16,15 @@ type NestedStruct struct { PtrField *TestStruct } +type MapTestStruct struct { + Name string `json:"name"` + Age int `json:"age,omitempty"` + Password string `json:"-"` + IsActive bool `json:"is_active"` + Address string `json:"address"` + Country string `json:"country,omitempty"` +} + func TestFilterStruct(t *testing.T) { s := TestStruct{ Name: "John", @@ -25,7 +34,7 @@ func TestFilterStruct(t *testing.T) { tests := []struct { name string - input interface{} + input any includeFields []string excludeFields []string want TestStruct @@ -79,6 +88,94 @@ func TestFilterStruct(t *testing.T) { } } +func TestFilterStructToMap(t *testing.T) { + s := MapTestStruct{ + Name: "John", + Age: 30, // To test omitempty on a non zero value + Password: "secret-password", // To test an ignored tag (json:"-") + IsActive: true, + Address: "New York", + Country: "", // To test omitempty on a zero value + } + + tests := []struct { + name string + input any + includeFields []string + excludeFields []string + want map[string]any + wantErr bool + }{ + { + name: "no filtering", + input: s, + includeFields: nil, + excludeFields: nil, + want: map[string]any{ + "name": "John", + "age": 30, + "is_active": true, + "address": "New York", + }, + wantErr: false, + }, + { + name: "include specific fields", + input: s, + includeFields: []string{"Name", "Address"}, + excludeFields: []string{}, + want: map[string]any{ + "name": "John", + "address": "New York", + }, + wantErr: false, + }, + { + name: "exclude specific fields", + input: s, + includeFields: []string{}, + excludeFields: []string{"Address", "IsActive"}, + want: map[string]any{ + "name": "John", + "age": 30, + }, + wantErr: false, + }, + { + name: "include and exclude", + input: s, + includeFields: []string{"Name", "Age", "Address"}, + excludeFields: []string{"Age"}, + want: map[string]any{ + "name": "John", + "address": "New York", + }, + wantErr: false, + }, + { + name: "non-struct input", + input: "not a struct", + includeFields: []string{}, + excludeFields: []string{}, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FilterStructToMap(tt.input, tt.includeFields, tt.excludeFields) + if (err != nil) != tt.wantErr { + t.Errorf("FilterStructToMap() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FilterStructToMap() got = %v, want %v", got, tt.want) + } + }) + } +} + func TestGetStructFields(t *testing.T) { s := TestStruct{ Name: "John", @@ -88,7 +185,7 @@ func TestGetStructFields(t *testing.T) { tests := []struct { name string - input interface{} + input any want []string wantErr bool }{