Skip to content

Commit 0223e4d

Browse files
committed
Add better nil checks
1 parent 7157a97 commit 0223e4d

File tree

4 files changed

+53
-14
lines changed

4 files changed

+53
-14
lines changed

expr_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ func TestExpr(t *testing.T) {
436436
return ret
437437
},
438438
Inc: func(a int) int { return a + 1 },
439+
Nil: nil,
439440
}
440441

441442
tests := []struct {
@@ -731,13 +732,17 @@ func TestExpr(t *testing.T) {
731732
"bar",
732733
},
733734
{
734-
`Nil`,
735+
`NilStruct`,
735736
(*time.Time)(nil),
736737
},
737738
{
738-
`Nil == nil && nil == Nil && nil == nil && Nil == Nil`,
739+
`Nil == nil && nil == Nil && nil == nil && Nil == Nil && NilInt == nil && NilSlice == nil && NilStruct == nil`,
739740
true,
740741
},
742+
{
743+
`0 == nil || "str" == nil || true == nil`,
744+
false,
745+
},
741746
}
742747

743748
for _, tt := range tests {
@@ -786,7 +791,10 @@ type mockEnv struct {
786791
Segments []*segment
787792
BirthDay time.Time
788793
Now time.Time
789-
Nil *time.Time
794+
Nil interface{}
795+
NilStruct *time.Time
796+
NilInt *int
797+
NilSlice []ticket
790798
}
791799

792800
func (e *mockEnv) GetInt() int {

vm/generate/main.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,13 @@ func main() {
123123
}
124124
if helper.string {
125125
echo(`case string:`)
126-
echo(`return x %v b.(string)`, op)
126+
echo(`switch y := b.(type) {`)
127+
echo(`case string: return x %v y`, op)
128+
echo(`}`)
127129
}
128130
echo(`}`)
129131
if name == "equal" {
130-
echo(`// Two nil values should be considered as equal.`)
131-
echo(`if (a == nil || reflect.ValueOf(a).IsNil()) && (b == nil || reflect.ValueOf(b).IsNil()) { return true }`)
132+
echo(`if isNil(a) && isNil(b) { return true }`)
132133
echo(`return reflect.DeepEqual(a, b)`)
133134
} else {
134135
echo(`panic(fmt.Sprintf("invalid operation: %%T %%v %%T", a, "%v", b))`, op)

vm/helpers.go

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,12 @@ func equal(a, b interface{}) interface{} {
332332
return x == y
333333
}
334334
case string:
335-
return x == b.(string)
335+
switch y := b.(type) {
336+
case string:
337+
return x == y
338+
}
336339
}
337-
// Two nil values should be considered as equal.
338-
if (a == nil || reflect.ValueOf(a).IsNil()) && (b == nil || reflect.ValueOf(b).IsNil()) {
340+
if isNil(a) && isNil(b) {
339341
return true
340342
}
341343
return reflect.DeepEqual(a, b)
@@ -668,7 +670,10 @@ func less(a, b interface{}) interface{} {
668670
return x < y
669671
}
670672
case string:
671-
return x < b.(string)
673+
switch y := b.(type) {
674+
case string:
675+
return x < y
676+
}
672677
}
673678
panic(fmt.Sprintf("invalid operation: %T %v %T", a, "<", b))
674679
}
@@ -1000,7 +1005,10 @@ func more(a, b interface{}) interface{} {
10001005
return x > y
10011006
}
10021007
case string:
1003-
return x > b.(string)
1008+
switch y := b.(type) {
1009+
case string:
1010+
return x > y
1011+
}
10041012
}
10051013
panic(fmt.Sprintf("invalid operation: %T %v %T", a, ">", b))
10061014
}
@@ -1332,7 +1340,10 @@ func lessOrEqual(a, b interface{}) interface{} {
13321340
return x <= y
13331341
}
13341342
case string:
1335-
return x <= b.(string)
1343+
switch y := b.(type) {
1344+
case string:
1345+
return x <= y
1346+
}
13361347
}
13371348
panic(fmt.Sprintf("invalid operation: %T %v %T", a, "<=", b))
13381349
}
@@ -1664,7 +1675,10 @@ func moreOrEqual(a, b interface{}) interface{} {
16641675
return x >= y
16651676
}
16661677
case string:
1667-
return x >= b.(string)
1678+
switch y := b.(type) {
1679+
case string:
1680+
return x >= y
1681+
}
16681682
}
16691683
panic(fmt.Sprintf("invalid operation: %T %v %T", a, ">=", b))
16701684
}
@@ -1996,7 +2010,10 @@ func add(a, b interface{}) interface{} {
19962010
return x + y
19972011
}
19982012
case string:
1999-
return x + b.(string)
2013+
switch y := b.(type) {
2014+
case string:
2015+
return x + y
2016+
}
20002017
}
20012018
panic(fmt.Sprintf("invalid operation: %T %v %T", a, "+", b))
20022019
}

vm/runtime.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,16 @@ func toFloat64(a interface{}) float64 {
313313
panic(fmt.Sprintf("invalid operation: float64(%T)", x))
314314
}
315315
}
316+
317+
func isNil(v interface{}) bool {
318+
if v == nil {
319+
return true
320+
}
321+
r := reflect.ValueOf(v)
322+
switch r.Kind() {
323+
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice:
324+
return r.IsNil()
325+
default:
326+
return false
327+
}
328+
}

0 commit comments

Comments
 (0)