Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions structs/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,13 @@ func Walk(s interface{}, callback CallbackFunc) {
}
}

// FilterStruct filters the struct based on include and exclude fields and returns a new struct.
// - input: the original struct.
// - includeFields: list of fields to include (if empty, includes all).
// - excludeFields: list of fields to exclude (processed after include).
func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, error) {
var zeroValue T
func walkFilteredFields[T any](input T, includeFields, excludeFields []string, walker func(field reflect.StructField, value reflect.Value)) error {
val := reflect.ValueOf(input)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}

if val.Kind() != reflect.Struct {
return zeroValue, errors.New("input must be a struct")
return errors.New("input must be a struct")
}

includeMap := make(map[string]bool)
Expand All @@ -66,7 +60,6 @@ func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, err
}

typeOfStruct := val.Type()
filteredStruct := reflect.New(typeOfStruct).Elem()

for i := 0; i < val.NumField(); i++ {
field := typeOfStruct.Field(i)
Expand All @@ -77,13 +70,62 @@ func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, err
fieldValue := val.Field(i)

if (len(includeMap) == 0 || includeMap[fieldName]) && !excludeMap[fieldName] {
filteredStruct.Field(i).Set(fieldValue)
walker(field, fieldValue)
}
}
return nil
}

// FilterStruct filters the struct based on include and exclude fields and returns a new struct.
// - input: the original struct.
// - includeFields: list of fields to include (if empty, includes all).
// - excludeFields: list of fields to exclude (processed after include).
func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, error) {
var zeroValue T
val := reflect.ValueOf(input)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}

filteredStruct := reflect.New(val.Type()).Elem()

walker := func(field reflect.StructField, value reflect.Value) {
filteredStruct.FieldByName(field.Name).Set(value)
}

if err := walkFilteredFields(input, includeFields, excludeFields, walker); err != nil {
return zeroValue, err
}

return filteredStruct.Interface().(T), nil
}

func FilterStructToMap[T any](input T, includeFields, excludeFields []string) (map[string]any, error) {
resultMap := make(map[string]any)

walker := func(field reflect.StructField, value reflect.Value) {
jsonTag := field.Tag.Get("json")
jsonKey := strings.Split(jsonTag, ",")[0]

if jsonKey == "" || jsonKey == "-" {
return
}

fieldValue := value.Interface()
if strings.Contains(jsonTag, "omitempty") && value.IsZero() {
return
}

resultMap[jsonKey] = fieldValue
}

if err := walkFilteredFields(input, includeFields, excludeFields, walker); err != nil {
return nil, err
}

return resultMap, nil
}

// GetStructFields returns all the top-level field names from the given struct.
// - input: the original struct.
// Returns a slice of field names or an error if the input is not a struct.
Expand Down
101 changes: 99 additions & 2 deletions structs/structs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ type NestedStruct struct {
PtrField *TestStruct
}

type MapTestStruct struct {
Name string `json:"name"`
Age int `json:"age,omitempty"`
Password string `json:"-"`
IsActive bool `json:"is_active"`
Address string `json:"address"`
Country string `json:"country,omitempty"`
}

func TestFilterStruct(t *testing.T) {
s := TestStruct{
Name: "John",
Expand All @@ -25,7 +34,7 @@ func TestFilterStruct(t *testing.T) {

tests := []struct {
name string
input interface{}
input any
includeFields []string
excludeFields []string
want TestStruct
Expand Down Expand Up @@ -79,6 +88,94 @@ func TestFilterStruct(t *testing.T) {
}
}

func TestFilterStructToMap(t *testing.T) {
s := MapTestStruct{
Name: "John",
Age: 30, // To test omitempty on a non zero value
Password: "secret-password", // To test an ignored tag (json:"-")
IsActive: true,
Address: "New York",
Country: "", // To test omitempty on a zero value
}

tests := []struct {
name string
input any
includeFields []string
excludeFields []string
want map[string]any
wantErr bool
}{
{
name: "no filtering",
input: s,
includeFields: nil,
excludeFields: nil,
want: map[string]any{
"name": "John",
"age": 30,
"is_active": true,
"address": "New York",
},
wantErr: false,
},
{
name: "include specific fields",
input: s,
includeFields: []string{"Name", "Address"},
excludeFields: []string{},
want: map[string]any{
"name": "John",
"address": "New York",
},
wantErr: false,
},
{
name: "exclude specific fields",
input: s,
includeFields: []string{},
excludeFields: []string{"Address", "IsActive"},
want: map[string]any{
"name": "John",
"age": 30,
},
wantErr: false,
},
{
name: "include and exclude",
input: s,
includeFields: []string{"Name", "Age", "Address"},
excludeFields: []string{"Age"},
want: map[string]any{
"name": "John",
"address": "New York",
},
wantErr: false,
},
{
name: "non-struct input",
input: "not a struct",
includeFields: []string{},
excludeFields: []string{},
want: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := FilterStructToMap(tt.input, tt.includeFields, tt.excludeFields)
if (err != nil) != tt.wantErr {
t.Errorf("FilterStructToMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("FilterStructToMap() got = %v, want %v", got, tt.want)
}
})
}
}

func TestGetStructFields(t *testing.T) {
s := TestStruct{
Name: "John",
Expand All @@ -88,7 +185,7 @@ func TestGetStructFields(t *testing.T) {

tests := []struct {
name string
input interface{}
input any
want []string
wantErr bool
}{
Expand Down
Loading