diff --git a/structs/structs.go b/structs/structs.go index a5ae9bde..47496aef 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -1,6 +1,9 @@ package structs -import "reflect" +import ( + "errors" + "reflect" +) // CallbackFunc on the struct field // example: @@ -35,3 +38,65 @@ func Walk(s interface{}, callback CallbackFunc) { } } } + +// FilterStruct filters the struct based on include and exclude fields and returns a new struct. +// - input: the original struct. +// - includeFields: list of fields to include (if empty, includes all). +// - excludeFields: list of fields to exclude (processed after include). +func FilterStruct(input interface{}, includeFields, excludeFields []string) (interface{}, error) { + val := reflect.ValueOf(input) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil, errors.New("input must be a struct") + } + + includeMap := make(map[string]bool) + excludeMap := make(map[string]bool) + + for _, field := range includeFields { + includeMap[field] = true + } + for _, field := range excludeFields { + excludeMap[field] = true + } + + typeOfStruct := val.Type() + filteredStruct := reflect.New(typeOfStruct).Elem() + + for i := 0; i < val.NumField(); i++ { + field := typeOfStruct.Field(i) + fieldName := field.Name + fieldValue := val.Field(i) + + if (len(includeMap) == 0 || includeMap[fieldName]) && !excludeMap[fieldName] { + filteredStruct.Field(i).Set(fieldValue) + } + } + + return filteredStruct.Interface(), nil +} + +// GetStructFields returns all the top-level field names from the given struct. +// - input: the original struct. +// Returns a slice of field names or an error if the input is not a struct. +func GetStructFields(input interface{}) ([]string, error) { + val := reflect.ValueOf(input) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil, errors.New("input must be a struct") + } + + fields := make([]string, 0, val.NumField()) + typeOfStruct := val.Type() + for i := 0; i < val.NumField(); i++ { + fields = append(fields, typeOfStruct.Field(i).Name) + } + + return fields, nil +} diff --git a/structs/structs_test.go b/structs/structs_test.go new file mode 100644 index 00000000..014840b0 --- /dev/null +++ b/structs/structs_test.go @@ -0,0 +1,123 @@ +package structs + +import ( + "reflect" + "testing" +) + +type TestStruct struct { + Name string + Age int + Address string +} + +type NestedStruct struct { + Basic TestStruct + PtrField *TestStruct +} + +func TestFilterStruct(t *testing.T) { + s := TestStruct{ + Name: "John", + Age: 30, + Address: "New York", + } + + tests := []struct { + name string + input interface{} + includeFields []string + excludeFields []string + want TestStruct + wantErr bool + }{ + { + name: "include specific fields", + input: s, + includeFields: []string{"Name", "Age"}, + excludeFields: []string{}, + want: TestStruct{ + Name: "John", + Age: 30, + }, + wantErr: false, + }, + { + name: "exclude specific fields", + input: s, + includeFields: []string{}, + excludeFields: []string{"Address"}, + want: TestStruct{ + Name: "John", + Age: 30, + }, + wantErr: false, + }, + { + name: "non-struct input", + input: "not a struct", + includeFields: []string{}, + excludeFields: []string{}, + want: TestStruct{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FilterStruct(tt.input, tt.includeFields, tt.excludeFields) + if (err != nil) != tt.wantErr { + t.Errorf("FilterStruct() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FilterStruct() = %v, want %v", got, tt.want) + } + } + }) + } +} + +func TestGetStructFields(t *testing.T) { + s := TestStruct{ + Name: "John", + Age: 30, + Address: "New York", + } + + tests := []struct { + name string + input interface{} + want []string + wantErr bool + }{ + { + name: "valid struct", + input: s, + want: []string{"Name", "Age", "Address"}, + wantErr: false, + }, + { + name: "non-struct input", + input: "not a struct", + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetStructFields(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("GetStructFields() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetStructFields() = %v, want %v", got, tt.want) + } + } + }) + } +}