Skip to content

Commit 1df4c24

Browse files
authored
refact: pkg/database decisions filter, queries (#3635)
* refact pkg/database: rename query methods * refact pkg/database: extract decisionfilter.go * refact pkg/database: rename filters -> filter * handle error with since==nil * lint
1 parent d64ee2a commit 1df4c24

File tree

4 files changed

+237
-219
lines changed

4 files changed

+237
-219
lines changed

pkg/database/alertfilter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func alertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
255255
return predicates, nil
256256
}
257257

258-
func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) {
258+
func applyAlertFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) {
259259
preds, err := alertPredicatesFromFilter(filter)
260260
if err != nil {
261261
return nil, err

pkg/database/alerts.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -742,15 +742,15 @@ func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []
742742
return alertIDs, nil
743743
}
744744

745-
func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) {
745+
func (c *Client) AlertsCountPerScenario(ctx context.Context, filter map[string][]string) (map[string]int, error) {
746746
var res []struct {
747747
Scenario string
748748
Count int
749749
}
750750

751751
query := c.Ent.Alert.Query()
752752

753-
query, err := BuildAlertRequestFromFilter(query, filters)
753+
query, err := applyAlertFilter(query, filter)
754754
if err != nil {
755755
return nil, fmt.Errorf("failed to build alert request: %w", err)
756756
}
@@ -801,7 +801,7 @@ func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]s
801801
for {
802802
alerts := c.Ent.Alert.Query()
803803

804-
alerts, err := BuildAlertRequestFromFilter(alerts, filter)
804+
alerts, err := applyAlertFilter(alerts, filter)
805805
if err != nil {
806806
return nil, err
807807
}

pkg/database/decisionfilter.go

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
package database
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"strings"
7+
8+
"github.com/pkg/errors"
9+
10+
"github.com/crowdsecurity/crowdsec/pkg/csnet"
11+
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
12+
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
13+
"github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
14+
"github.com/crowdsecurity/crowdsec/pkg/types"
15+
)
16+
17+
func applyDecisionFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) {
18+
var (
19+
rng csnet.Range
20+
err error
21+
)
22+
23+
contains := true
24+
/*if contains is true, return bans that *contains* the given value (value is the inner)
25+
else, return bans that are *contained* by the given value (value is the outer)*/
26+
27+
/*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */
28+
if v, ok := filter["simulated"]; ok {
29+
if v[0] == "false" {
30+
query = query.Where(decision.SimulatedEQ(false))
31+
}
32+
33+
delete(filter, "simulated")
34+
} else {
35+
query = query.Where(decision.SimulatedEQ(false))
36+
}
37+
38+
for param, value := range filter {
39+
switch param {
40+
case "contains":
41+
contains, err = strconv.ParseBool(value[0])
42+
if err != nil {
43+
return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
44+
}
45+
case "scopes", "scope": // Swagger mentions both of them, let's just support both to make sure we don't break anything
46+
scopes := strings.Split(value[0], ",")
47+
for i, scope := range scopes {
48+
switch strings.ToLower(scope) {
49+
case "ip":
50+
scopes[i] = types.Ip
51+
case "range":
52+
scopes[i] = types.Range
53+
case "country":
54+
scopes[i] = types.Country
55+
case "as":
56+
scopes[i] = types.AS
57+
}
58+
}
59+
60+
query = query.Where(decision.ScopeIn(scopes...))
61+
case "value":
62+
query = query.Where(decision.ValueEQ(value[0]))
63+
case "type":
64+
query = query.Where(decision.TypeEQ(value[0]))
65+
case "origins":
66+
query = query.Where(
67+
decision.OriginIn(strings.Split(value[0], ",")...),
68+
)
69+
case "scenarios_containing":
70+
predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold)
71+
query = query.Where(decision.Or(predicates...))
72+
case "scenarios_not_containing":
73+
predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold)
74+
query = query.Where(decision.Not(
75+
decision.Or(
76+
predicates...,
77+
),
78+
))
79+
case "ip", "range":
80+
rng, err = csnet.NewRange(value[0])
81+
if err != nil {
82+
return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err)
83+
}
84+
case "limit":
85+
limit, err := strconv.Atoi(value[0])
86+
if err != nil {
87+
return nil, errors.Wrapf(InvalidFilter, "invalid limit value : %s", err)
88+
}
89+
90+
query = query.Limit(limit)
91+
case "offset":
92+
offset, err := strconv.Atoi(value[0])
93+
if err != nil {
94+
return nil, errors.Wrapf(InvalidFilter, "invalid offset value : %s", err)
95+
}
96+
97+
query = query.Offset(offset)
98+
case "id_gt":
99+
id, err := strconv.Atoi(value[0])
100+
if err != nil {
101+
return nil, errors.Wrapf(InvalidFilter, "invalid id_gt value : %s", err)
102+
}
103+
104+
query = query.Where(decision.IDGT(id))
105+
}
106+
}
107+
108+
query, err = decisionIPFilter(query, contains, rng)
109+
if err != nil {
110+
return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err)
111+
}
112+
113+
return query, nil
114+
}
115+
116+
func decisionIPv4Filter(decisions *ent.DecisionQuery, contains bool, rng csnet.Range) (*ent.DecisionQuery, error) {
117+
if contains {
118+
/*Decision contains {start_ip,end_ip}*/
119+
return decisions.Where(decision.And(
120+
decision.StartIPLTE(rng.Start.Addr),
121+
decision.EndIPGTE(rng.End.Addr),
122+
decision.IPSizeEQ(int64(rng.Size())))), nil
123+
}
124+
125+
/*Decision is contained within {start_ip,end_ip}*/
126+
return decisions.Where(decision.And(
127+
decision.StartIPGTE(rng.Start.Addr),
128+
decision.EndIPLTE(rng.End.Addr),
129+
decision.IPSizeEQ(int64(rng.Size())))), nil
130+
}
131+
132+
func decisionIPv6Filter(decisions *ent.DecisionQuery, contains bool, rng csnet.Range) (*ent.DecisionQuery, error) {
133+
/*decision contains {start_ip,end_ip}*/
134+
if contains {
135+
return decisions.Where(decision.And(
136+
// matching addr size
137+
decision.IPSizeEQ(int64(rng.Size())),
138+
decision.Or(
139+
// decision.start_ip < query.start_ip
140+
decision.StartIPLT(rng.Start.Addr),
141+
decision.And(
142+
// decision.start_ip == query.start_ip
143+
decision.StartIPEQ(rng.Start.Addr),
144+
// decision.start_suffix <= query.start_suffix
145+
decision.StartSuffixLTE(rng.Start.Sfx),
146+
)),
147+
decision.Or(
148+
// decision.end_ip > query.end_ip
149+
decision.EndIPGT(rng.End.Addr),
150+
decision.And(
151+
// decision.end_ip == query.end_ip
152+
decision.EndIPEQ(rng.End.Addr),
153+
// decision.end_suffix >= query.end_suffix
154+
decision.EndSuffixGTE(rng.End.Sfx),
155+
),
156+
),
157+
)), nil
158+
}
159+
160+
/*decision is contained within {start_ip,end_ip}*/
161+
return decisions.Where(decision.And(
162+
// matching addr size
163+
decision.IPSizeEQ(int64(rng.Size())),
164+
decision.Or(
165+
// decision.start_ip > query.start_ip
166+
decision.StartIPGT(rng.Start.Addr),
167+
decision.And(
168+
// decision.start_ip == query.start_ip
169+
decision.StartIPEQ(rng.Start.Addr),
170+
// decision.start_suffix >= query.start_suffix
171+
decision.StartSuffixGTE(rng.Start.Sfx),
172+
)),
173+
decision.Or(
174+
// decision.end_ip < query.end_ip
175+
decision.EndIPLT(rng.End.Addr),
176+
decision.And(
177+
// decision.end_ip == query.end_ip
178+
decision.EndIPEQ(rng.End.Addr),
179+
// decision.end_suffix <= query.end_suffix
180+
decision.EndSuffixLTE(rng.End.Sfx),
181+
),
182+
),
183+
)), nil
184+
}
185+
186+
func decisionIPFilter(decisions *ent.DecisionQuery, contains bool, rng csnet.Range) (*ent.DecisionQuery, error) {
187+
switch rng.Size() {
188+
case 4:
189+
return decisionIPv4Filter(decisions, contains, rng)
190+
case 16:
191+
return decisionIPv6Filter(decisions, contains, rng)
192+
case 0:
193+
return decisions, nil
194+
default:
195+
return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", rng.Size())
196+
}
197+
}
198+
199+
func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision {
200+
words := strings.Split(s, ",")
201+
predicates := make([]predicate.Decision, len(words))
202+
203+
for i, word := range words {
204+
predicates[i] = predicateFunc(word)
205+
}
206+
207+
return predicates
208+
}

0 commit comments

Comments
 (0)