Skip to content

Commit 5bf0b76

Browse files
Copilotdrewlanenga
andauthored
Add IP filter builtin function (#52)
* Initial plan * Implement ipfilter builtin function with comprehensive tests Co-authored-by: drewlanenga <[email protected]> * Fix ipfilter return values for parsing errors Co-authored-by: drewlanenga <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: drewlanenga <[email protected]>
1 parent edc52cc commit 5bf0b76

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

expr/builtins/builtins.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ func LoadAllBuiltins() {
135135
// MySQL Builtins
136136
expr.FuncAdd("cast", &Cast{})
137137
expr.FuncAdd("char_length", &Length{})
138+
139+
// Network functions
140+
expr.FuncAdd("ipfilter", &IPFilter{})
138141
})
139142
}
140143

expr/builtins/builtins_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,25 @@ var builtinTests = []testBuiltins{
804804
{`json.jmespath(json_field, "[?b].tags | [0 ")`, nil},
805805
{`json.jmespath(json_bad, "[?b].tags | [0 ")`, value.ErrValue},
806806
{`json.jmespath(json_object, "name")`, value.NewStringValue("bob")},
807+
808+
// IP Filter tests
809+
{`ipfilter("192.168.1.100", "192.168.1.0/24")`, value.BoolValueTrue},
810+
{`ipfilter("192.168.1.1", "192.168.1.0/24")`, value.BoolValueTrue},
811+
{`ipfilter("192.168.1.254", "192.168.1.0/24")`, value.BoolValueTrue},
812+
{`ipfilter("192.168.2.1", "192.168.1.0/24")`, value.BoolValueFalse},
813+
{`ipfilter("10.0.0.1", "192.168.1.0/24")`, value.BoolValueFalse},
814+
{`ipfilter("127.0.0.1", "127.0.0.0/8")`, value.BoolValueTrue},
815+
{`ipfilter("128.0.0.1", "127.0.0.0/8")`, value.BoolValueFalse},
816+
{`ipfilter("10.10.10.10", "10.0.0.0/8")`, value.BoolValueTrue},
817+
{`ipfilter("11.10.10.10", "10.0.0.0/8")`, value.BoolValueFalse},
818+
// IPv6 tests
819+
{`ipfilter("2001:db8::1", "2001:db8::/32")`, value.BoolValueTrue},
820+
{`ipfilter("2001:db9::1", "2001:db8::/32")`, value.BoolValueFalse},
821+
{`ipfilter("::1", "::1/128")`, value.BoolValueTrue},
822+
// Invalid inputs should fail evaluation
823+
{`ipfilter("invalid.ip", "192.168.1.0/24")`, nil},
824+
{`ipfilter("192.168.1.1", "invalid/cidr")`, nil},
825+
{`ipfilter("192.168.1.1", "192.168.1.0/35")`, nil}, // Invalid CIDR range
807826
}
808827

809828
var testValidation = []string{
@@ -923,6 +942,11 @@ var testValidation = []string{
923942
`json.jmespath(json_field)`, // Must have 2 args
924943
`json.jmespath(json_field, 1)`, // Must have 2 args, 2nd must be string
925944
`json.jmespath(json_bad, "")`,
945+
946+
// IP Filter validation
947+
`ipfilter()`, // Must have 2 args
948+
`ipfilter("192.168.1.1")`, // Must have 2 args
949+
`ipfilter("192.168.1.1", "192.168.1.0/24", "extra")`, // Must have only 2 args
926950
}
927951
var testValidationx = []string{
928952
`tolower()`, `lower(a,b)`, // must be one arg

expr/builtins/network.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package builtins
2+
3+
import (
4+
"fmt"
5+
"net"
6+
7+
"github.com/lytics/qlbridge/expr"
8+
"github.com/lytics/qlbridge/value"
9+
)
10+
11+
// IPFilter determines whether an IP address is contained within a given CIDR subnet
12+
//
13+
// ipfilter("192.168.1.100", "192.168.1.0/24") => true
14+
// ipfilter("10.0.0.1", "192.168.1.0/24") => false
15+
type IPFilter struct{}
16+
17+
// Type is Bool
18+
func (m *IPFilter) Type() value.ValueType { return value.BoolType }
19+
20+
func (m *IPFilter) Validate(n *expr.FuncNode) (expr.EvaluatorFunc, error) {
21+
if len(n.Args) != 2 {
22+
return nil, fmt.Errorf("Expected 2 args for ipfilter(ip_address, subnet_cidr) but got %s", n)
23+
}
24+
return ipFilterEval, nil
25+
}
26+
27+
func ipFilterEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) {
28+
// Convert arguments to strings
29+
ipStr, ipOk := value.ValueToString(args[0])
30+
cidrStr, cidrOk := value.ValueToString(args[1])
31+
32+
if !ipOk || !cidrOk {
33+
return value.BoolValueFalse, false
34+
}
35+
36+
// Parse the IP address
37+
ip := net.ParseIP(ipStr)
38+
if ip == nil {
39+
// Invalid IP address
40+
return value.BoolValueFalse, false
41+
}
42+
43+
// Parse the CIDR notation
44+
_, ipNet, err := net.ParseCIDR(cidrStr)
45+
if err != nil {
46+
// Invalid CIDR notation
47+
return value.BoolValueFalse, false
48+
}
49+
50+
// Check if IP is contained in the subnet
51+
if ipNet.Contains(ip) {
52+
return value.BoolValueTrue, true
53+
}
54+
55+
return value.BoolValueFalse, true
56+
}

0 commit comments

Comments
 (0)