Skip to content

Commit 0b7b47f

Browse files
committed
Add expr.Function feature
1 parent 1888db0 commit 0b7b47f

File tree

10 files changed

+287
-84
lines changed

10 files changed

+287
-84
lines changed

ast/node.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ type IdentifierNode struct {
5151
Value string
5252
Deref bool
5353
FieldIndex []int
54-
Method bool
55-
MethodIndex int
54+
Method bool // true if method, false if field
55+
MethodIndex int // index of method, set only if Method is true
5656
}
5757

5858
type IntegerNode struct {
@@ -124,6 +124,7 @@ type CallNode struct {
124124
Arguments []Node
125125
Typed int
126126
Fast bool
127+
Func func(params ...interface{}) (interface{}, error)
127128
}
128129

129130
type BuiltinNode struct {

checker/checker.go

Lines changed: 119 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type visitor struct {
6060

6161
type info struct {
6262
method bool
63+
fn *conf.Function
6364
}
6465

6566
func (v *visitor) visit(node ast.Node) (reflect.Type, info) {
@@ -134,6 +135,12 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
134135
node.Deref = true
135136
return anyType, info{}
136137
}
138+
if fn, ok := v.config.Functions[node.Value]; ok {
139+
// Return anyType instead of func type as we don't know the arguments yet.
140+
// The func type can be one of the fn.Types. The type will be resolved
141+
// when the arguments are known in CallNode.
142+
return anyType, info{fn: fn}
143+
}
137144
if t, ok := v.config.Types[node.Value]; ok {
138145
if t.Ambiguous {
139146
return v.error(node, "ambiguous identifier %v", node.Value)
@@ -466,6 +473,32 @@ func (v *visitor) SliceNode(node *ast.SliceNode) (reflect.Type, info) {
466473
func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
467474
fn, fnInfo := v.visit(node.Callee)
468475

476+
if fnInfo.fn != nil {
477+
f := fnInfo.fn
478+
node.Func = f.Func
479+
if len(f.Types) == 0 {
480+
// No type was specified, so we assume the function returns any.
481+
return anyType, info{}
482+
}
483+
var firstErr *file.Error
484+
for _, t := range f.Types {
485+
outType, err := v.checkFunc(f.Name, t, false, node)
486+
if err != nil {
487+
if firstErr == nil {
488+
firstErr = err
489+
}
490+
continue
491+
}
492+
return outType, info{}
493+
}
494+
if firstErr != nil {
495+
if v.err == nil {
496+
v.err = firstErr
497+
}
498+
return anyType, info{}
499+
}
500+
}
501+
469502
fnName := "function"
470503
if identifier, ok := node.Callee.(*ast.IdentifierNode); ok {
471504
fnName = identifier.Value
@@ -475,7 +508,6 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
475508
fnName = name.Value
476509
}
477510
}
478-
479511
switch fn.Kind() {
480512
case reflect.Interface:
481513
return anyType, info{}
@@ -484,37 +516,50 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
484516
if fnInfo.method {
485517
inputParamsCount = 2 // for methods
486518
}
487-
519+
// TODO: Deprecate OpCallFast and move fn(...any) any to TypedFunc list.
520+
// To do this we need add support for variadic arguments in OpCallTyped.
488521
if !isAny(fn) &&
489522
fn.IsVariadic() &&
490523
fn.NumIn() == inputParamsCount &&
491-
((fn.NumOut() == 1 && // Function with one return value
492-
fn.Out(0).Kind() == reflect.Interface) ||
493-
(fn.NumOut() == 2 && // Function with one return value and an error
494-
fn.Out(0).Kind() == reflect.Interface &&
495-
fn.Out(1) == errorType)) {
524+
fn.NumOut() == 1 &&
525+
fn.Out(0).Kind() == reflect.Interface {
496526
rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods
497527
if rest.Kind() == reflect.Slice && rest.Elem().Kind() == reflect.Interface {
498528
node.Fast = true
499529
}
500530
}
501531

502-
return v.checkFunc(fn, fnInfo.method, node, fnName, node.Arguments)
532+
outType, err := v.checkFunc(fnName, fn, fnInfo.method, node)
533+
if err != nil {
534+
if v.err == nil {
535+
v.err = err
536+
}
537+
return anyType, info{}
538+
}
539+
540+
v.findTypedFunc(node, fn, fnInfo.method)
541+
542+
return outType, info{}
503543
}
504544
return v.error(node, "%v is not callable", fn)
505545
}
506546

507-
// checkFunc checks func arguments and returns "return type" of func or method.
508-
func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, name string, arguments []ast.Node) (reflect.Type, info) {
547+
func (v *visitor) checkFunc(name string, fn reflect.Type, method bool, node *ast.CallNode) (reflect.Type, *file.Error) {
509548
if isAny(fn) {
510-
return anyType, info{}
549+
return anyType, nil
511550
}
512551

513552
if fn.NumOut() == 0 {
514-
return v.error(node, "func %v doesn't return value", name)
553+
return anyType, &file.Error{
554+
Location: node.Location(),
555+
Message: fmt.Sprintf("func %v doesn't return value", name),
556+
}
515557
}
516558
if numOut := fn.NumOut(); numOut > 2 {
517-
return v.error(node, "func %v returns more then two values", name)
559+
return anyType, &file.Error{
560+
Location: node.Location(),
561+
Message: fmt.Sprintf("func %v returns more then two values", name),
562+
}
518563
}
519564

520565
// If func is method on an env, first argument should be a receiver,
@@ -530,19 +575,28 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, na
530575
}
531576

532577
if fn.IsVariadic() {
533-
if len(arguments) < fnNumIn-1 {
534-
return v.error(node, "not enough arguments to call %v", name)
578+
if len(node.Arguments) < fnNumIn-1 {
579+
return anyType, &file.Error{
580+
Location: node.Location(),
581+
Message: fmt.Sprintf("not enough arguments to call %v", name),
582+
}
535583
}
536584
} else {
537-
if len(arguments) > fnNumIn {
538-
return v.error(node, "too many arguments to call %v", name)
585+
if len(node.Arguments) > fnNumIn {
586+
return anyType, &file.Error{
587+
Location: node.Location(),
588+
Message: fmt.Sprintf("too many arguments to call %v", name),
589+
}
539590
}
540-
if len(arguments) < fnNumIn {
541-
return v.error(node, "not enough arguments to call %v", name)
591+
if len(node.Arguments) < fnNumIn {
592+
return anyType, &file.Error{
593+
Location: node.Location(),
594+
Message: fmt.Sprintf("not enough arguments to call %v", name),
595+
}
542596
}
543597
}
544598

545-
for i, arg := range arguments {
599+
for i, arg := range node.Arguments {
546600
t, _ := v.visit(arg)
547601

548602
var in reflect.Type
@@ -564,44 +618,14 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, na
564618
}
565619

566620
if !t.AssignableTo(in) && t.Kind() != reflect.Interface {
567-
return v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name)
568-
}
569-
}
570-
571-
// OnCallTyped doesn't work for functions with variadic arguments,
572-
// and doesn't work named function, like `type MyFunc func() int`.
573-
// In PkgPath() is an empty string, it's unnamed function.
574-
if !fn.IsVariadic() && fn.PkgPath() == "" {
575-
funcTypes:
576-
for i := range vm.FuncTypes {
577-
if i == 0 {
578-
continue
579-
}
580-
typed := reflect.ValueOf(vm.FuncTypes[i]).Elem().Type()
581-
if typed.Kind() != reflect.Func {
582-
continue
583-
}
584-
if typed.NumOut() != fn.NumOut() {
585-
continue
621+
return anyType, &file.Error{
622+
Location: arg.Location(),
623+
Message: fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name),
586624
}
587-
for j := 0; j < typed.NumOut(); j++ {
588-
if typed.Out(j) != fn.Out(j) {
589-
continue funcTypes
590-
}
591-
}
592-
if typed.NumIn() != fnNumIn {
593-
continue
594-
}
595-
for j := 0; j < typed.NumIn(); j++ {
596-
if typed.In(j) != fn.In(j+fnInOffset) {
597-
continue funcTypes
598-
}
599-
}
600-
node.Typed = i
601625
}
602626
}
603627

604-
return fn.Out(0), info{}
628+
return fn.Out(0), nil
605629
}
606630

607631
func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
@@ -769,3 +793,44 @@ func (v *visitor) PairNode(node *ast.PairNode) (reflect.Type, info) {
769793
v.visit(node.Value)
770794
return nilType, info{}
771795
}
796+
797+
func (v *visitor) findTypedFunc(node *ast.CallNode, fn reflect.Type, method bool) {
798+
// OnCallTyped doesn't work for functions with variadic arguments,
799+
// and doesn't work named function, like `type MyFunc func() int`.
800+
// In PkgPath() is an empty string, it's unnamed function.
801+
if !fn.IsVariadic() && fn.PkgPath() == "" {
802+
fnNumIn := fn.NumIn()
803+
fnInOffset := 0
804+
if method {
805+
fnNumIn--
806+
fnInOffset = 1
807+
}
808+
funcTypes:
809+
for i := range vm.FuncTypes {
810+
if i == 0 {
811+
continue
812+
}
813+
typed := reflect.ValueOf(vm.FuncTypes[i]).Elem().Type()
814+
if typed.Kind() != reflect.Func {
815+
continue
816+
}
817+
if typed.NumOut() != fn.NumOut() {
818+
continue
819+
}
820+
for j := 0; j < typed.NumOut(); j++ {
821+
if typed.Out(j) != fn.Out(j) {
822+
continue funcTypes
823+
}
824+
}
825+
if typed.NumIn() != fnNumIn {
826+
continue
827+
}
828+
for j := 0; j < typed.NumIn(); j++ {
829+
if typed.In(j) != fn.In(j+fnInOffset) {
830+
continue funcTypes
831+
}
832+
}
833+
node.Typed = i
834+
}
835+
}
836+
}

compiler/compiler.go

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package compiler
22

33
import (
44
"fmt"
5-
"math"
65
"reflect"
76

87
"github.com/antonmedv/expr/ast"
@@ -25,8 +24,9 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro
2524
}()
2625

2726
c := &compiler{
28-
index: make(map[interface{}]int),
29-
locations: make([]file.Location, 0),
27+
locations: make([]file.Location, 0),
28+
constantsIndex: make(map[interface{}]int),
29+
functionsIndex: make(map[string]int),
3030
}
3131

3232
if config != nil {
@@ -52,20 +52,23 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro
5252
Constants: c.constants,
5353
Bytecode: c.bytecode,
5454
Arguments: c.arguments,
55+
Functions: c.functions,
5556
}
5657
return
5758
}
5859

5960
type compiler struct {
60-
locations []file.Location
61-
constants []interface{}
62-
bytecode []Opcode
63-
index map[interface{}]int
64-
mapEnv bool
65-
cast reflect.Kind
66-
nodes []ast.Node
67-
chains [][]int
68-
arguments []int
61+
locations []file.Location
62+
bytecode []Opcode
63+
constants []interface{}
64+
constantsIndex map[interface{}]int
65+
functions []Function
66+
functionsIndex map[string]int
67+
mapEnv bool
68+
cast reflect.Kind
69+
nodes []ast.Node
70+
chains [][]int
71+
arguments []int
6972
}
7073

7174
func (c *compiler) emitLocation(loc file.Location, op Opcode, arg int) int {
@@ -110,25 +113,30 @@ func (c *compiler) addConstant(constant interface{}) int {
110113
indexable = true
111114
hash = fmt.Sprintf("%v", method)
112115
}
113-
114116
if indexable {
115-
if p, ok := c.index[hash]; ok {
117+
if p, ok := c.constantsIndex[hash]; ok {
116118
return p
117119
}
118120
}
119-
120121
c.constants = append(c.constants, constant)
121-
if len(c.constants) > math.MaxUint16 {
122-
panic("exceeded constants max space limit")
123-
}
124-
125122
p := len(c.constants) - 1
126123
if indexable {
127-
c.index[hash] = p
124+
c.constantsIndex[hash] = p
128125
}
129126
return p
130127
}
131128

129+
func (c *compiler) addFunction(f Function) int {
130+
addr := fmt.Sprintf("%v", f)
131+
if p, ok := c.functionsIndex[addr]; ok {
132+
return p
133+
}
134+
p := len(c.functions)
135+
c.functions = append(c.functions, f)
136+
c.functionsIndex[addr] = p
137+
return p
138+
}
139+
132140
func (c *compiler) patchJump(placeholder int) {
133141
offset := len(c.bytecode) - placeholder
134142
c.arguments[placeholder-1] = offset
@@ -516,6 +524,22 @@ func (c *compiler) CallNode(node *ast.CallNode) {
516524
for _, arg := range node.Arguments {
517525
c.compile(arg)
518526
}
527+
if node.Func != nil {
528+
switch len(node.Arguments) {
529+
case 0:
530+
c.emit(OpCall0, c.addFunction(node.Func))
531+
case 1:
532+
c.emit(OpCall1, c.addFunction(node.Func))
533+
case 2:
534+
c.emit(OpCall2, c.addFunction(node.Func))
535+
case 3:
536+
c.emit(OpCall3, c.addFunction(node.Func))
537+
default:
538+
c.emit(OpLoadFunc, c.addFunction(node.Func))
539+
c.emit(OpCallN, len(node.Arguments))
540+
}
541+
return
542+
}
519543
c.compile(node.Callee)
520544
if node.Typed > 0 {
521545
c.emit(OpCallTyped, node.Typed)

0 commit comments

Comments
 (0)