Skip to content

Commit d889594

Browse files
committed
Improve filtering with --include and --exclude.
Closes #31.
1 parent 300da10 commit d889594

File tree

4 files changed

+96
-15
lines changed

4 files changed

+96
-15
lines changed

cmd/relationship.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ import (
1717
"encoding/base64"
1818
"encoding/gob"
1919
"fmt"
20+
"github.com/VirusTotal/vt-cli/utils"
2021
"os"
2122
"path"
2223
"sync"
2324

24-
"github.com/VirusTotal/vt-cli/utils"
2525
vt "github.com/VirusTotal/vt-go"
2626
homedir "github.com/mitchellh/go-homedir"
2727
"github.com/spf13/cobra"
@@ -106,7 +106,7 @@ func NewRelationshipsCmd(collection, objectType, use string) *cobra.Command {
106106
Args: cobra.ExactArgs(1),
107107
RunE: func(cmd *cobra.Command, args []string) error {
108108
var wg sync.WaitGroup
109-
var m sync.Map
109+
var sm sync.Map
110110
for _, r := range objectRelationshipsMap[objectType] {
111111
wg.Add(1)
112112
go func(relationshipName string) {
@@ -118,17 +118,30 @@ func NewRelationshipsCmd(collection, objectType, use string) *cobra.Command {
118118
if err != nil {
119119
fmt.Println(err)
120120
} else if len(objs) > 0 {
121-
m.Store(relationshipName, objs)
121+
sm.Store(relationshipName, objs)
122122
}
123123
wg.Done()
124124
}(r.Name)
125125
}
126126
wg.Wait()
127+
128+
m := make(map[string]interface{})
129+
sm.Range(func(key, value interface{}) bool {
130+
m[key.(string)] = value
131+
return true
132+
})
133+
134+
if viper.IsSet("include") || viper.IsSet("exclude") {
135+
m = utils.FilterMap(m,
136+
viper.GetStringSlice("include"),
137+
viper.GetStringSlice("exclude"))
138+
}
139+
127140
p, err := NewPrinter(cmd)
128141
if err != nil {
129142
return err
130143
}
131-
return p.PrintSyncMap(&m)
144+
return p.Print(m)
132145
},
133146
}
134147

utils/filter.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
package utils
1515

1616
import (
17-
"reflect"
18-
1917
glob "github.com/gobwas/glob"
18+
"reflect"
19+
"strings"
2020
)
2121

22-
// FilterMap receives a map with string keys and arbirary values (possibly
22+
// FilterMap receives a map with string keys and arbitrary values (possibly
2323
// other maps) and return a new map which is a subset of the original one
24-
// contaning only the keys matching any of the patterns in "include" and
24+
// containing only the keys matching any of the patterns in "include" and
2525
// excluding keys matching any of the patterns in "exclude". The logic for
2626
// determining if a key matches the pattern goes as follow:
2727
//
@@ -45,10 +45,27 @@ func FilterMap(m map[string]interface{}, include, exclude []string) map[string]i
4545
cp := glob.MustCompile(p, '.')
4646
includeGlob[i] = cp
4747
}
48+
// For each include pattern that do not ends with **, add the same pattern
49+
// but ended in .**. This because when someone says that she wants to include
50+
// "foo.bar", where "foo.bar" is dictionary, what she actually expects is
51+
// getting the dictionary with all its keys, but the keys inside the
52+
// dictionary don't match the "foo.bar" pattern, so we add "foo.bar.**".
53+
for _, p := range include {
54+
if !strings.HasSuffix(p, "**") {
55+
includeGlob = append(includeGlob, glob.MustCompile(p+".**", '.'))
56+
}
57+
}
4858
for i, p := range exclude {
4959
cp := glob.MustCompile(p, '.')
5060
excludeGlob[i] = cp
5161
}
62+
// The same happens if you exclude "foo.bar", what you actually mean is
63+
// excluding "foo.bar" and "foo.bar.**".
64+
for _, p := range exclude {
65+
if !strings.HasSuffix(p, "**") {
66+
excludeGlob = append(excludeGlob, glob.MustCompile(p+".**", '.'))
67+
}
68+
}
5269
filtered := filterMap(reflect.ValueOf(m), includeGlob, excludeGlob, "")
5370
return filtered.Interface().(map[string]interface{})
5471
}

utils/filter_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,31 @@ var testCases = []testCase{
7373
},
7474
},
7575

76+
testCase{
77+
include: []string{"foo"},
78+
exclude: []string{"**.quux"},
79+
input: testMap,
80+
output: map[string]interface{}{
81+
82+
},
83+
},
84+
85+
testCase{
86+
include: []string{"foo**"},
87+
exclude: []string{"**.key1"},
88+
input: testMap,
89+
output: map[string]interface{}{
90+
"foo": map[string]interface{}{
91+
"qux": map[string]interface{}{
92+
"quux": map[string]interface{}{
93+
"key2": "val2",
94+
"key3": []string{"val3"},
95+
},
96+
},
97+
},
98+
},
99+
},
100+
76101
testCase{
77102
include: []string{"foo**"},
78103
exclude: []string{"**.key1"},
@@ -179,6 +204,21 @@ var testCases = []testCase{
179204
},
180205
},
181206

207+
testCase{
208+
include: []string{"qux"},
209+
input: testMap,
210+
output: map[string]interface{}{
211+
"qux": []interface{}{
212+
map[string]interface{}{
213+
"key4": "val4",
214+
},
215+
map[string]interface{}{
216+
"key5": "val5",
217+
},
218+
},
219+
},
220+
},
221+
182222
testCase{
183223
include: []string{"ary"},
184224
input: testMap,

utils/printer.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ func (p *Printer) PrintSyncMap(sm *sync.Map) error {
6363
m[key.(string)] = value
6464
return true
6565
})
66+
if viper.IsSet("include") || viper.IsSet("exclude") {
67+
m = FilterMap(m,
68+
viper.GetStringSlice("include"),
69+
viper.GetStringSlice("exclude"))
70+
}
6671
return p.Print(m)
6772
}
6873

@@ -95,21 +100,27 @@ func ObjectToMap(obj *vt.Object) map[string]interface{} {
95100
m[name] = l
96101
}
97102
}
98-
if viper.IsSet("include") || viper.IsSet("exclude") {
99-
m = FilterMap(m,
100-
viper.GetStringSlice("include"),
101-
viper.GetStringSlice("exclude"))
102-
}
103103
return m
104104
}
105105

106106
// PrintObjects prints all the specified objects to stdout.
107107
func (p *Printer) PrintObjects(objs []*vt.Object) error {
108108
list := make([]map[string]interface{}, 0)
109109
for _, obj := range objs {
110-
list = append(list, ObjectToMap(obj))
110+
m := ObjectToMap(obj)
111+
if viper.IsSet("include") || viper.IsSet("exclude") {
112+
m = FilterMap(m,
113+
viper.GetStringSlice("include"),
114+
viper.GetStringSlice("exclude"))
115+
}
116+
if len(m) > 0 {
117+
list = append(list, m)
118+
}
119+
}
120+
if len(list) > 0 {
121+
return p.Print(list)
111122
}
112-
return p.Print(list)
123+
return nil
113124
}
114125

115126
// PrintObject prints the specified object to stdout.

0 commit comments

Comments
 (0)