Skip to content

Commit 3d4c219

Browse files
committed
Simplify "not in" operators
1 parent 7c1d2ee commit 3d4c219

File tree

12 files changed

+350
-263
lines changed

12 files changed

+350
-263
lines changed

ast/node.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,12 @@ type UnaryNode struct {
8888

8989
type BinaryNode struct {
9090
base
91+
Regexp *regexp.Regexp
9192
Operator string
9293
Left Node
9394
Right Node
9495
}
9596

96-
type MatchesNode struct {
97-
base
98-
Regexp *regexp.Regexp
99-
Left Node
100-
Right Node
101-
}
102-
10397
type ChainNode struct {
10498
base
10599
Node Node

ast/visitor.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ func Walk(node *Node, v Visitor) {
2020
case *BinaryNode:
2121
Walk(&n.Left, v)
2222
Walk(&n.Right, v)
23-
case *MatchesNode:
24-
Walk(&n.Left, v)
25-
Walk(&n.Right, v)
2623
case *ChainNode:
2724
Walk(&n.Node, v)
2825
case *MemberNode:

checker/checker.go

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package checker
33
import (
44
"fmt"
55
"reflect"
6+
"regexp"
67

78
"github.com/antonmedv/expr/ast"
89
"github.com/antonmedv/expr/conf"
@@ -78,8 +79,6 @@ func (v *visitor) visit(node ast.Node) (reflect.Type, info) {
7879
t, i = v.UnaryNode(n)
7980
case *ast.BinaryNode:
8081
t, i = v.BinaryNode(n)
81-
case *ast.MatchesNode:
82-
t, i = v.MatchesNode(n)
8382
case *ast.ChainNode:
8483
t, i = v.ChainNode(n)
8584
case *ast.MemberNode:
@@ -232,23 +231,6 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
232231
return boolType, info{}
233232
}
234233

235-
case "in", "not in":
236-
if (isString(l) || isAny(l)) && isStruct(r) {
237-
return boolType, info{}
238-
}
239-
if isMap(r) {
240-
return boolType, info{}
241-
}
242-
if isArray(r) {
243-
return boolType, info{}
244-
}
245-
if isAny(l) && anyOf(r, isString, isArray, isMap) {
246-
return boolType, info{}
247-
}
248-
if isAny(l) && isAny(r) {
249-
return boolType, info{}
250-
}
251-
252234
case "<", ">", ">=", "<=":
253235
if isNumber(l) && isNumber(r) {
254236
return boolType, info{}
@@ -315,6 +297,38 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
315297
return anyType, info{}
316298
}
317299

300+
case "in":
301+
if (isString(l) || isAny(l)) && isStruct(r) {
302+
return boolType, info{}
303+
}
304+
if isMap(r) {
305+
return boolType, info{}
306+
}
307+
if isArray(r) {
308+
return boolType, info{}
309+
}
310+
if isAny(l) && anyOf(r, isString, isArray, isMap) {
311+
return boolType, info{}
312+
}
313+
if isAny(l) && isAny(r) {
314+
return boolType, info{}
315+
}
316+
317+
case "matches":
318+
if s, ok := node.Right.(*ast.StringNode); ok {
319+
r, err := regexp.Compile(s.Value)
320+
if err != nil {
321+
return v.error(node, err.Error())
322+
}
323+
node.Regexp = r
324+
}
325+
if isString(l) && isString(r) {
326+
return boolType, info{}
327+
}
328+
if or(l, r, isString) {
329+
return boolType, info{}
330+
}
331+
318332
case "contains", "startsWith", "endsWith":
319333
if isString(l) && isString(r) {
320334
return boolType, info{}
@@ -340,20 +354,6 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
340354
return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
341355
}
342356

343-
func (v *visitor) MatchesNode(node *ast.MatchesNode) (reflect.Type, info) {
344-
l, _ := v.visit(node.Left)
345-
r, _ := v.visit(node.Right)
346-
347-
if isString(l) && isString(r) {
348-
return boolType, info{}
349-
}
350-
if or(l, r, isString) {
351-
return boolType, info{}
352-
}
353-
354-
return v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r)
355-
}
356-
357357
func (v *visitor) ChainNode(node *ast.ChainNode) (reflect.Type, info) {
358358
return v.visit(node.Node)
359359
}

checker/checker_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ var successTests = []string{
6060
"String in MapOfFoo",
6161
"String matches 'ok'",
6262
"String matches Any",
63+
"String not matches Any",
6364
"String not in ArrayOfFoo",
6465
"StringPtr == nil",
6566
"[1, 2, 3] == []",

compiler/compiler.go

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,6 @@ func (c *compiler) compile(node ast.Node) {
161161
c.UnaryNode(n)
162162
case *ast.BinaryNode:
163163
c.BinaryNode(n)
164-
case *ast.MatchesNode:
165-
c.MatchesNode(n)
166164
case *ast.ChainNode:
167165
c.ChainNode(n)
168166
case *ast.MemberNode:
@@ -329,17 +327,6 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
329327
c.compile(node.Right)
330328
c.patchJump(end)
331329

332-
case "in":
333-
c.compile(node.Left)
334-
c.compile(node.Right)
335-
c.emit(OpIn)
336-
337-
case "not in":
338-
c.compile(node.Left)
339-
c.compile(node.Right)
340-
c.emit(OpIn)
341-
c.emit(OpNot)
342-
343330
case "<":
344331
c.compile(node.Left)
345332
c.compile(node.Right)
@@ -390,6 +377,21 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
390377
c.compile(node.Right)
391378
c.emit(OpExponent)
392379

380+
case "in":
381+
c.compile(node.Left)
382+
c.compile(node.Right)
383+
c.emit(OpIn)
384+
385+
case "matches":
386+
if node.Regexp != nil {
387+
c.compile(node.Left)
388+
c.emit(OpMatchesConst, c.addConstant(node.Regexp))
389+
} else {
390+
c.compile(node.Left)
391+
c.compile(node.Right)
392+
c.emit(OpMatches)
393+
}
394+
393395
case "contains":
394396
c.compile(node.Left)
395397
c.compile(node.Right)
@@ -416,17 +418,6 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
416418
}
417419
}
418420

419-
func (c *compiler) MatchesNode(node *ast.MatchesNode) {
420-
if node.Regexp != nil {
421-
c.compile(node.Left)
422-
c.emit(OpMatchesConst, c.addConstant(node.Regexp))
423-
return
424-
}
425-
c.compile(node.Left)
426-
c.compile(node.Right)
427-
c.emit(OpMatches)
428-
}
429-
430421
func (c *compiler) ChainNode(node *ast.ChainNode) {
431422
c.chains = append(c.chains, []int{})
432423
c.compile(node.Node)

optimizer/in_array.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ type inArray struct{}
1111
func (*inArray) Visit(node *Node) {
1212
switch n := (*node).(type) {
1313
case *BinaryNode:
14-
if n.Operator == "in" || n.Operator == "not in" {
14+
if n.Operator == "in" {
1515
if array, ok := n.Right.(*ArrayNode); ok {
1616
if len(array.Nodes) > 0 {
1717
t := n.Left.Type()

optimizer/in_range.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ type inRange struct{}
99
func (*inRange) Visit(node *Node) {
1010
switch n := (*node).(type) {
1111
case *BinaryNode:
12-
if n.Operator == "in" || n.Operator == "not in" {
12+
if n.Operator == "in" {
1313
if rng, ok := n.Right.(*BinaryNode); ok && rng.Operator == ".." {
1414
if from, ok := rng.Left.(*IntegerNode); ok {
1515
if to, ok := rng.Right.(*IntegerNode); ok {
@@ -26,12 +26,6 @@ func (*inRange) Visit(node *Node) {
2626
Right: to,
2727
},
2828
})
29-
if n.Operator == "not in" {
30-
Patch(node, &UnaryNode{
31-
Operator: "not",
32-
Node: *node,
33-
})
34-
}
3529
}
3630
}
3731
}

parser/lexer/lexer.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ func (l *lexer) emitEOF() {
9595
l.startLoc = l.loc
9696
}
9797

98+
func (l *lexer) skip() {
99+
l.start = l.end
100+
l.startLoc = l.loc
101+
}
102+
98103
func (l *lexer) word() string {
99104
return l.input[l.start:l.end]
100105
}
@@ -118,22 +123,26 @@ func (l *lexer) acceptRun(valid string) {
118123
l.backup()
119124
}
120125

121-
func (l *lexer) acceptWord(word string) bool {
122-
pos, loc, prev := l.end, l.loc, l.prev
123-
124-
// Skip spaces (U+0020) if any
126+
func (l *lexer) skipSpaces() {
125127
r := l.peek()
126128
for ; r == ' '; r = l.peek() {
127129
l.next()
128130
}
131+
l.skip()
132+
}
133+
134+
func (l *lexer) acceptWord(word string) bool {
135+
pos, loc, prev := l.end, l.loc, l.prev
136+
137+
l.skipSpaces()
129138

130139
for _, ch := range word {
131140
if l.next() != ch {
132141
l.end, l.loc, l.prev = pos, loc, prev
133142
return false
134143
}
135144
}
136-
if r = l.peek(); r != ' ' && r != eof {
145+
if r := l.peek(); r != ' ' && r != eof {
137146
l.end, l.loc, l.prev = pos, loc, prev
138147
return false
139148
}

parser/lexer/lexer_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ var lexTests = []lexTest{
112112
{
113113
`not in not abc not i not(false) not in not in`,
114114
[]Token{
115-
{Kind: Operator, Value: "not in"},
115+
{Kind: Operator, Value: "not"},
116+
{Kind: Operator, Value: "in"},
116117
{Kind: Operator, Value: "not"},
117118
{Kind: Identifier, Value: "abc"},
118119
{Kind: Operator, Value: "not"},
@@ -121,8 +122,10 @@ var lexTests = []lexTest{
121122
{Kind: Bracket, Value: "("},
122123
{Kind: Identifier, Value: "false"},
123124
{Kind: Bracket, Value: ")"},
124-
{Kind: Operator, Value: "not in"},
125-
{Kind: Operator, Value: "not in"},
125+
{Kind: Operator, Value: "not"},
126+
{Kind: Operator, Value: "in"},
127+
{Kind: Operator, Value: "not"},
128+
{Kind: Operator, Value: "in"},
126129
{Kind: EOF},
127130
},
128131
},
@@ -137,7 +140,8 @@ var lexTests = []lexTest{
137140
{
138141
"not in",
139142
[]Token{
140-
{Kind: Operator, Value: "not in"},
143+
{Kind: Operator, Value: "not"},
144+
{Kind: Operator, Value: "in"},
141145
{Kind: EOF},
142146
},
143147
},

parser/lexer/state.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,29 @@ loop:
127127
}
128128

129129
func not(l *lexer) stateFn {
130-
switch l.acceptWord("in") {
131-
case true:
132-
l.emitValue(Operator, "not in")
133-
case false:
134-
l.emitValue(Operator, "not")
130+
l.emit(Operator)
131+
132+
l.skipSpaces()
133+
134+
pos, loc, prev := l.end, l.loc, l.prev
135+
136+
// Get the next word.
137+
for {
138+
r := l.next()
139+
if IsAlphaNumeric(r) {
140+
// absorb
141+
} else {
142+
l.backup()
143+
break
144+
}
135145
}
136146

147+
switch l.word() {
148+
case "in", "matches", "contains", "startsWith", "endsWith":
149+
l.emit(Operator)
150+
default:
151+
l.end, l.loc, l.prev = pos, loc, prev
152+
}
137153
return root
138154
}
139155

0 commit comments

Comments
 (0)