Skip to content

Commit 848ff39

Browse files
authored
Add OpCallTyped (#269)
1 parent 07e6b41 commit 848ff39

File tree

8 files changed

+376
-6
lines changed

8 files changed

+376
-6
lines changed

ast/node.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ type CallNode struct {
128128
base
129129
Callee Node
130130
Arguments []Node
131+
Typed int
131132
Fast bool
132133
}
133134

checker/checker.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/antonmedv/expr/conf"
99
"github.com/antonmedv/expr/file"
1010
"github.com/antonmedv/expr/parser"
11+
"github.com/antonmedv/expr/vm"
1112
)
1213

1314
func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
@@ -498,7 +499,7 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
498499
}
499500

500501
// checkFunc checks func arguments and returns "return type" of func or method.
501-
func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name string, arguments []ast.Node) (reflect.Type, info) {
502+
func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, name string, arguments []ast.Node) (reflect.Type, info) {
502503
if isAny(fn) {
503504
return anyType, info{}
504505
}
@@ -564,6 +565,36 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
564565
}
565566
}
566567

568+
if !fn.IsVariadic() {
569+
funcTypes:
570+
for i := range vm.FuncTypes {
571+
if i == 0 {
572+
continue
573+
}
574+
typed := reflect.ValueOf(vm.FuncTypes[i]).Elem().Type()
575+
if typed.Kind() != reflect.Func {
576+
continue
577+
}
578+
if typed.NumOut() != fn.NumOut() {
579+
continue
580+
}
581+
for j := 0; j < typed.NumOut(); j++ {
582+
if typed.Out(j) != fn.Out(j) {
583+
continue funcTypes
584+
}
585+
}
586+
if typed.NumIn() != len(arguments) {
587+
continue
588+
}
589+
for j, arg := range arguments {
590+
if typed.In(j) != arg.Type() {
591+
continue funcTypes
592+
}
593+
}
594+
node.Typed = i
595+
}
596+
}
597+
567598
return fn.Out(0), info{}
568599
}
569600

compiler/compiler.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,12 +523,15 @@ func (c *compiler) CallNode(node *ast.CallNode) {
523523
for _, arg := range node.Arguments {
524524
c.compile(arg)
525525
}
526-
op := OpCall
527-
if node.Fast {
528-
op = OpCallFast
529-
}
530526
c.compile(node.Callee)
531-
c.emit(op, len(node.Arguments))
527+
if node.Typed > 0 {
528+
c.emit(OpCallTyped, node.Typed)
529+
return
530+
} else if node.Fast {
531+
c.emit(OpCallFast, len(node.Arguments))
532+
} else {
533+
c.emit(OpCall, len(node.Arguments))
534+
}
532535
}
533536

534537
func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {

vm/func_types/main.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"go/format"
7+
"reflect"
8+
"strings"
9+
"text/template"
10+
. "time"
11+
)
12+
13+
// Keep sorted.
14+
var types = []interface{}{
15+
nil,
16+
new(func() Duration),
17+
new(func() Month),
18+
new(func() Time),
19+
new(func() Weekday),
20+
new(func() []byte),
21+
new(func() []interface{}),
22+
new(func() bool),
23+
new(func() byte),
24+
new(func() float64),
25+
new(func() int),
26+
new(func() int64),
27+
new(func() interface{}),
28+
new(func() map[string]interface{}),
29+
new(func() rune),
30+
new(func() string),
31+
new(func() uint),
32+
new(func() uint64),
33+
new(func(Duration) Duration),
34+
new(func(Duration) Time),
35+
new(func(Time) Duration),
36+
new(func(Time) bool),
37+
new(func([]interface{}, string) string),
38+
new(func([]string, string) string),
39+
new(func(bool) bool),
40+
new(func(bool) float64),
41+
new(func(bool) int),
42+
new(func(bool) string),
43+
new(func(float64) bool),
44+
new(func(float64) float64),
45+
new(func(float64) int),
46+
new(func(float64) string),
47+
new(func(int) bool),
48+
new(func(int) float64),
49+
new(func(int) int),
50+
new(func(int) string),
51+
new(func(int, int) int),
52+
new(func(int, int) string),
53+
new(func(int64) Time),
54+
new(func(string) []string),
55+
new(func(string) bool),
56+
new(func(string) float64),
57+
new(func(string) int),
58+
new(func(string) string),
59+
new(func(string, byte) int),
60+
new(func(string, int) int),
61+
new(func(string, rune) int),
62+
new(func(string, string) bool),
63+
new(func(string, string) string),
64+
}
65+
66+
func main() {
67+
data := struct {
68+
Index string
69+
Code string
70+
}{}
71+
72+
for i, t := range types {
73+
if i == 0 {
74+
continue
75+
}
76+
fn := reflect.ValueOf(t).Elem().Type()
77+
data.Index += fmt.Sprintf("%v: new(%v),\n", i, fn)
78+
data.Code += fmt.Sprintf("case %d:\n", i)
79+
args := make([]string, fn.NumIn())
80+
for j := fn.NumIn() - 1; j >= 0; j-- {
81+
data.Code += fmt.Sprintf("arg%v := vm.pop().(%v)\n", j+1, fn.In(j))
82+
args[j] = fmt.Sprintf("arg%v", j+1)
83+
}
84+
data.Code += fmt.Sprintf("return fn.(%v)(%v)\n", fn, strings.Join(args, ", "))
85+
}
86+
87+
var b bytes.Buffer
88+
err := template.Must(
89+
template.New("func_types").
90+
Parse(source),
91+
).Execute(&b, data)
92+
if err != nil {
93+
panic(err)
94+
}
95+
96+
formatted, err := format.Source(b.Bytes())
97+
if err != nil {
98+
panic(err)
99+
}
100+
fmt.Print(string(formatted))
101+
}
102+
103+
const source = `// Code generated by vm/func_types/main.go. DO NOT EDIT.
104+
105+
package vm
106+
107+
import (
108+
"fmt"
109+
"time"
110+
)
111+
112+
var FuncTypes = []interface{}{
113+
{{ .Index }}
114+
}
115+
116+
func (vm *VM) call(fn interface{}, kind int) interface{} {
117+
switch kind {
118+
{{ .Code }}
119+
}
120+
panic(fmt.Sprintf("unknown function kind (%v)", kind))
121+
}
122+
`

0 commit comments

Comments
 (0)