Skip to content

Commit 6fa83ad

Browse files
committed
Improve type checker for named types
1 parent d63c3b5 commit 6fa83ad

File tree

9 files changed

+215
-92
lines changed

9 files changed

+215
-92
lines changed

builtin/builtin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ var Builtins = []*Function{
164164
if len(args) != 1 {
165165
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
166166
}
167-
switch kind(args[0]) {
167+
switch kind(deref.Type(args[0])) {
168168
case reflect.Interface:
169169
return integerType, nil
170170
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:

builtin/builtin_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,3 +638,16 @@ func Test_int_unwraps_underlying_value(t *testing.T) {
638638
require.NoError(t, err)
639639
assert.Equal(t, true, out)
640640
}
641+
642+
func TestBuiltin_int_with_deref(t *testing.T) {
643+
x := 42
644+
env := map[string]any{
645+
"x": &x,
646+
}
647+
program, err := expr.Compile(`int(x)`, expr.Env(env))
648+
require.NoError(t, err)
649+
650+
out, err := expr.Run(program, env)
651+
require.NoError(t, err)
652+
assert.Equal(t, 42, out)
653+
}

builtin/lib.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func Round(x any) any {
180180
}
181181

182182
func Int(x any) any {
183-
switch x := x.(type) {
183+
switch x := deref.Deref(x).(type) {
184184
case float32:
185185
return int(x)
186186
case float64:

checker/checker.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,10 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
319319
if isTime(l) && isTime(r) {
320320
return boolNature
321321
}
322-
if or(l, r, isNumber, isString, isTime) {
322+
if isDuration(l) && isDuration(r) {
323+
return boolNature
324+
}
325+
if or(l, r, isNumber, isString, isTime, isDuration) {
323326
return boolNature
324327
}
325328

@@ -333,6 +336,9 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
333336
if isTime(l) && isDuration(r) {
334337
return timeNature
335338
}
339+
if isDuration(l) && isDuration(r) {
340+
return durationNature
341+
}
336342
if or(l, r, isNumber, isTime, isDuration) {
337343
return unknown
338344
}
@@ -341,7 +347,16 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
341347
if isNumber(l) && isNumber(r) {
342348
return combined(l, r)
343349
}
344-
if or(l, r, isNumber) {
350+
if isNumber(l) && isDuration(r) {
351+
return durationNature
352+
}
353+
if isDuration(l) && isNumber(r) {
354+
return durationNature
355+
}
356+
if isDuration(l) && isDuration(r) {
357+
return durationNature
358+
}
359+
if or(l, r, isNumber, isDuration) {
345360
return unknown
346361
}
347362

@@ -382,6 +397,9 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
382397
if isDuration(l) && isTime(r) {
383398
return timeNature
384399
}
400+
if isDuration(l) && isDuration(r) {
401+
return durationNature
402+
}
385403
if or(l, r, isNumber, isString, isTime, isDuration) {
386404
return unknown
387405
}

checker/nature/nature.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ func (n Nature) FieldByName(name string) (Nature, bool) {
161161
return Nature{Type: field.Type, FieldIndex: field.Index}, ok
162162
}
163163

164+
func (n Nature) PkgPath() string {
165+
if n.Type == nil {
166+
return ""
167+
}
168+
return n.Type.PkgPath()
169+
}
170+
164171
func (n Nature) IsFastMap() bool {
165172
if n.Type == nil {
166173
return false

checker/types.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@ func isInteger(nt Nature) bool {
8989
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
9090
fallthrough
9191
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
92-
return true
92+
return nt.PkgPath() == ""
9393
}
9494
return false
9595
}
9696

9797
func isFloat(nt Nature) bool {
9898
switch nt.Kind() {
9999
case reflect.Float32, reflect.Float64:
100-
return true
100+
return nt.PkgPath() == ""
101101
}
102102
return false
103103
}
@@ -178,15 +178,23 @@ func kind(t reflect.Type) reflect.Kind {
178178
}
179179

180180
func isComparable(l, r Nature) bool {
181-
switch {
182-
case l.Kind() == r.Kind():
181+
if isUnknown(l) || isUnknown(r) {
182+
return true
183+
}
184+
if isNil(l) || isNil(r) {
183185
return true
184-
case isNumber(l) && isNumber(r):
186+
}
187+
if isNumber(l) && isNumber(r) {
188+
return true
189+
}
190+
if isDuration(l) && isDuration(r) {
185191
return true
186-
case isNil(l) || isNil(r):
192+
}
193+
if isTime(l) && isTime(r) {
187194
return true
188-
case isUnknown(l) || isUnknown(r):
195+
}
196+
if isArray(l) && isArray(r) {
189197
return true
190198
}
191-
return false
199+
return l.AssignableTo(r)
192200
}

expr_test.go

Lines changed: 4 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,10 @@ func TestExpr(t *testing.T) {
11191119
`duration("1h") + duration("1m")`,
11201120
time.Hour + time.Minute,
11211121
},
1122+
{
1123+
`duration("1h") - duration("1m")`,
1124+
time.Hour - time.Minute,
1125+
},
11221126
{
11231127
`7 * duration("1h")`,
11241128
7 * time.Hour,
@@ -2328,86 +2332,6 @@ func TestIssue432(t *testing.T) {
23282332
assert.Equal(t, float64(10), out)
23292333
}
23302334

2331-
func TestIssue461(t *testing.T) {
2332-
type EnvStr string
2333-
type EnvField struct {
2334-
S EnvStr
2335-
Str string
2336-
}
2337-
type Env struct {
2338-
S EnvStr
2339-
Str string
2340-
EnvField EnvField
2341-
}
2342-
var tests = []struct {
2343-
input string
2344-
env Env
2345-
want bool
2346-
}{
2347-
{
2348-
input: "Str == S",
2349-
env: Env{S: "string", Str: "string"},
2350-
want: false,
2351-
},
2352-
{
2353-
input: "Str == Str",
2354-
env: Env{Str: "string"},
2355-
want: true,
2356-
},
2357-
{
2358-
input: "S == S",
2359-
env: Env{Str: "string"},
2360-
want: true,
2361-
},
2362-
{
2363-
input: `Str == "string"`,
2364-
env: Env{Str: "string"},
2365-
want: true,
2366-
},
2367-
{
2368-
input: `S == "string"`,
2369-
env: Env{Str: "string"},
2370-
want: false,
2371-
},
2372-
{
2373-
input: "EnvField.Str == EnvField.S",
2374-
env: Env{EnvField: EnvField{S: "string", Str: "string"}},
2375-
want: false,
2376-
},
2377-
{
2378-
input: "EnvField.Str == EnvField.Str",
2379-
env: Env{EnvField: EnvField{Str: "string"}},
2380-
want: true,
2381-
},
2382-
{
2383-
input: "EnvField.S == EnvField.S",
2384-
env: Env{EnvField: EnvField{Str: "string"}},
2385-
want: true,
2386-
},
2387-
{
2388-
input: `EnvField.Str == "string"`,
2389-
env: Env{EnvField: EnvField{Str: "string"}},
2390-
want: true,
2391-
},
2392-
{
2393-
input: `EnvField.S == "string"`,
2394-
env: Env{EnvField: EnvField{Str: "string"}},
2395-
want: false,
2396-
},
2397-
}
2398-
2399-
for _, tt := range tests {
2400-
t.Run(tt.input, func(t *testing.T) {
2401-
program, err := expr.Compile(tt.input, expr.Env(tt.env), expr.AsBool())
2402-
2403-
out, err := expr.Run(program, tt.env)
2404-
require.NoError(t, err)
2405-
2406-
require.Equal(t, tt.want, out)
2407-
})
2408-
}
2409-
}
2410-
24112335
func TestIssue462(t *testing.T) {
24122336
env := map[string]any{
24132337
"foo": func() (string, error) {

test/issues/461/issue_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package issue_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
"github.com/expr-lang/expr/internal/testify/require"
8+
)
9+
10+
func TestIssue461(t *testing.T) {
11+
type EnvStr string
12+
type EnvField struct {
13+
S EnvStr
14+
Str string
15+
}
16+
type Env struct {
17+
S EnvStr
18+
Str string
19+
EnvField EnvField
20+
}
21+
var tests = []struct {
22+
input string
23+
env Env
24+
want bool
25+
err string
26+
}{
27+
{
28+
input: "Str == S",
29+
env: Env{S: "string", Str: "string"},
30+
err: "invalid operation: == (mismatched types string and issue_test.EnvStr)",
31+
},
32+
{
33+
input: "Str == Str",
34+
env: Env{Str: "string"},
35+
want: true,
36+
},
37+
{
38+
input: "S == S",
39+
env: Env{Str: "string"},
40+
want: true,
41+
},
42+
{
43+
input: `Str == "string"`,
44+
env: Env{Str: "string"},
45+
want: true,
46+
},
47+
{
48+
input: `S == "string"`,
49+
env: Env{Str: "string"},
50+
err: "invalid operation: == (mismatched types issue_test.EnvStr and string)",
51+
},
52+
{
53+
input: "EnvField.Str == EnvField.S",
54+
env: Env{EnvField: EnvField{S: "string", Str: "string"}},
55+
err: "invalid operation: == (mismatched types string and issue_test.EnvStr)",
56+
},
57+
{
58+
input: "EnvField.Str == EnvField.Str",
59+
env: Env{EnvField: EnvField{Str: "string"}},
60+
want: true,
61+
},
62+
{
63+
input: "EnvField.S == EnvField.S",
64+
env: Env{EnvField: EnvField{Str: "string"}},
65+
want: true,
66+
},
67+
{
68+
input: `EnvField.Str == "string"`,
69+
env: Env{EnvField: EnvField{Str: "string"}},
70+
want: true,
71+
},
72+
{
73+
input: `EnvField.S == "string"`,
74+
env: Env{EnvField: EnvField{Str: "string"}},
75+
err: "invalid operation: == (mismatched types issue_test.EnvStr and string)",
76+
},
77+
}
78+
79+
for _, tt := range tests {
80+
t.Run(tt.input, func(t *testing.T) {
81+
program, err := expr.Compile(tt.input, expr.Env(tt.env), expr.AsBool())
82+
83+
if tt.err != "" {
84+
require.Error(t, err)
85+
require.Contains(t, err.Error(), tt.err)
86+
} else {
87+
out, err := expr.Run(program, tt.env)
88+
require.NoError(t, err)
89+
require.Equal(t, tt.want, out)
90+
}
91+
})
92+
}
93+
}

test/issues/730/issue_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package issue_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
"github.com/expr-lang/expr/internal/testify/require"
8+
)
9+
10+
type ModeEnum int
11+
12+
const (
13+
ModeEnumA ModeEnum = 1
14+
)
15+
16+
type Env struct {
17+
Mode *ModeEnum
18+
}
19+
20+
func TestIssue730(t *testing.T) {
21+
code := `int(Mode) == 1`
22+
23+
tmp := ModeEnumA
24+
25+
env := map[string]any{
26+
"Mode": &tmp,
27+
}
28+
29+
program, err := expr.Compile(code, expr.Env(env))
30+
require.NoError(t, err)
31+
32+
output, err := expr.Run(program, env)
33+
require.NoError(t, err)
34+
require.True(t, output.(bool))
35+
}
36+
37+
func TestIssue730_warn_about_different_types(t *testing.T) {
38+
code := `Mode == 1`
39+
40+
_, err := expr.Compile(code, expr.Env(Env{}))
41+
require.Error(t, err)
42+
require.Contains(t, err.Error(), "invalid operation: == (mismatched types issue_test.ModeEnum and int)")
43+
}
44+
45+
func TestIssue730_eval(t *testing.T) {
46+
code := `Mode == 1`
47+
48+
tmp := ModeEnumA
49+
50+
env := map[string]any{
51+
"Mode": &tmp,
52+
}
53+
54+
// Golang also does not allow this:
55+
// _ = ModeEnumA == int(1) // will not compile
56+
57+
out, err := expr.Eval(code, env)
58+
require.NoError(t, err)
59+
require.False(t, out.(bool))
60+
}

0 commit comments

Comments
 (0)