@@ -60,6 +60,7 @@ type visitor struct {
60
60
61
61
type info struct {
62
62
method bool
63
+ fn * conf.Function
63
64
}
64
65
65
66
func (v * visitor ) visit (node ast.Node ) (reflect.Type , info ) {
@@ -134,6 +135,12 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
134
135
node .Deref = true
135
136
return anyType , info {}
136
137
}
138
+ if fn , ok := v .config .Functions [node .Value ]; ok {
139
+ // Return anyType instead of func type as we don't know the arguments yet.
140
+ // The func type can be one of the fn.Types. The type will be resolved
141
+ // when the arguments are known in CallNode.
142
+ return anyType , info {fn : fn }
143
+ }
137
144
if t , ok := v .config .Types [node .Value ]; ok {
138
145
if t .Ambiguous {
139
146
return v .error (node , "ambiguous identifier %v" , node .Value )
@@ -466,6 +473,32 @@ func (v *visitor) SliceNode(node *ast.SliceNode) (reflect.Type, info) {
466
473
func (v * visitor ) CallNode (node * ast.CallNode ) (reflect.Type , info ) {
467
474
fn , fnInfo := v .visit (node .Callee )
468
475
476
+ if fnInfo .fn != nil {
477
+ f := fnInfo .fn
478
+ node .Func = f .Func
479
+ if len (f .Types ) == 0 {
480
+ // No type was specified, so we assume the function returns any.
481
+ return anyType , info {}
482
+ }
483
+ var firstErr * file.Error
484
+ for _ , t := range f .Types {
485
+ outType , err := v .checkFunc (f .Name , t , false , node )
486
+ if err != nil {
487
+ if firstErr == nil {
488
+ firstErr = err
489
+ }
490
+ continue
491
+ }
492
+ return outType , info {}
493
+ }
494
+ if firstErr != nil {
495
+ if v .err == nil {
496
+ v .err = firstErr
497
+ }
498
+ return anyType , info {}
499
+ }
500
+ }
501
+
469
502
fnName := "function"
470
503
if identifier , ok := node .Callee .(* ast.IdentifierNode ); ok {
471
504
fnName = identifier .Value
@@ -475,7 +508,6 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
475
508
fnName = name .Value
476
509
}
477
510
}
478
-
479
511
switch fn .Kind () {
480
512
case reflect .Interface :
481
513
return anyType , info {}
@@ -484,37 +516,50 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
484
516
if fnInfo .method {
485
517
inputParamsCount = 2 // for methods
486
518
}
487
-
519
+ // TODO: Deprecate OpCallFast and move fn(...any) any to TypedFunc list.
520
+ // To do this we need add support for variadic arguments in OpCallTyped.
488
521
if ! isAny (fn ) &&
489
522
fn .IsVariadic () &&
490
523
fn .NumIn () == inputParamsCount &&
491
- ((fn .NumOut () == 1 && // Function with one return value
492
- fn .Out (0 ).Kind () == reflect .Interface ) ||
493
- (fn .NumOut () == 2 && // Function with one return value and an error
494
- fn .Out (0 ).Kind () == reflect .Interface &&
495
- fn .Out (1 ) == errorType )) {
524
+ fn .NumOut () == 1 &&
525
+ fn .Out (0 ).Kind () == reflect .Interface {
496
526
rest := fn .In (fn .NumIn () - 1 ) // function has only one param for functions and two for methods
497
527
if rest .Kind () == reflect .Slice && rest .Elem ().Kind () == reflect .Interface {
498
528
node .Fast = true
499
529
}
500
530
}
501
531
502
- return v .checkFunc (fn , fnInfo .method , node , fnName , node .Arguments )
532
+ outType , err := v .checkFunc (fnName , fn , fnInfo .method , node )
533
+ if err != nil {
534
+ if v .err == nil {
535
+ v .err = err
536
+ }
537
+ return anyType , info {}
538
+ }
539
+
540
+ v .findTypedFunc (node , fn , fnInfo .method )
541
+
542
+ return outType , info {}
503
543
}
504
544
return v .error (node , "%v is not callable" , fn )
505
545
}
506
546
507
- // checkFunc checks func arguments and returns "return type" of func or method.
508
- func (v * visitor ) checkFunc (fn reflect.Type , method bool , node * ast.CallNode , name string , arguments []ast.Node ) (reflect.Type , info ) {
547
+ func (v * visitor ) checkFunc (name string , fn reflect.Type , method bool , node * ast.CallNode ) (reflect.Type , * file.Error ) {
509
548
if isAny (fn ) {
510
- return anyType , info {}
549
+ return anyType , nil
511
550
}
512
551
513
552
if fn .NumOut () == 0 {
514
- return v .error (node , "func %v doesn't return value" , name )
553
+ return anyType , & file.Error {
554
+ Location : node .Location (),
555
+ Message : fmt .Sprintf ("func %v doesn't return value" , name ),
556
+ }
515
557
}
516
558
if numOut := fn .NumOut (); numOut > 2 {
517
- return v .error (node , "func %v returns more then two values" , name )
559
+ return anyType , & file.Error {
560
+ Location : node .Location (),
561
+ Message : fmt .Sprintf ("func %v returns more then two values" , name ),
562
+ }
518
563
}
519
564
520
565
// If func is method on an env, first argument should be a receiver,
@@ -530,19 +575,28 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, na
530
575
}
531
576
532
577
if fn .IsVariadic () {
533
- if len (arguments ) < fnNumIn - 1 {
534
- return v .error (node , "not enough arguments to call %v" , name )
578
+ if len (node .Arguments ) < fnNumIn - 1 {
579
+ return anyType , & file.Error {
580
+ Location : node .Location (),
581
+ Message : fmt .Sprintf ("not enough arguments to call %v" , name ),
582
+ }
535
583
}
536
584
} else {
537
- if len (arguments ) > fnNumIn {
538
- return v .error (node , "too many arguments to call %v" , name )
585
+ if len (node .Arguments ) > fnNumIn {
586
+ return anyType , & file.Error {
587
+ Location : node .Location (),
588
+ Message : fmt .Sprintf ("too many arguments to call %v" , name ),
589
+ }
539
590
}
540
- if len (arguments ) < fnNumIn {
541
- return v .error (node , "not enough arguments to call %v" , name )
591
+ if len (node .Arguments ) < fnNumIn {
592
+ return anyType , & file.Error {
593
+ Location : node .Location (),
594
+ Message : fmt .Sprintf ("not enough arguments to call %v" , name ),
595
+ }
542
596
}
543
597
}
544
598
545
- for i , arg := range arguments {
599
+ for i , arg := range node . Arguments {
546
600
t , _ := v .visit (arg )
547
601
548
602
var in reflect.Type
@@ -564,44 +618,14 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, na
564
618
}
565
619
566
620
if ! t .AssignableTo (in ) && t .Kind () != reflect .Interface {
567
- return v .error (arg , "cannot use %v as argument (type %v) to call %v " , t , in , name )
568
- }
569
- }
570
-
571
- // OnCallTyped doesn't work for functions with variadic arguments,
572
- // and doesn't work named function, like `type MyFunc func() int`.
573
- // In PkgPath() is an empty string, it's unnamed function.
574
- if ! fn .IsVariadic () && fn .PkgPath () == "" {
575
- funcTypes:
576
- for i := range vm .FuncTypes {
577
- if i == 0 {
578
- continue
579
- }
580
- typed := reflect .ValueOf (vm .FuncTypes [i ]).Elem ().Type ()
581
- if typed .Kind () != reflect .Func {
582
- continue
583
- }
584
- if typed .NumOut () != fn .NumOut () {
585
- continue
621
+ return anyType , & file.Error {
622
+ Location : arg .Location (),
623
+ Message : fmt .Sprintf ("cannot use %v as argument (type %v) to call %v " , t , in , name ),
586
624
}
587
- for j := 0 ; j < typed .NumOut (); j ++ {
588
- if typed .Out (j ) != fn .Out (j ) {
589
- continue funcTypes
590
- }
591
- }
592
- if typed .NumIn () != fnNumIn {
593
- continue
594
- }
595
- for j := 0 ; j < typed .NumIn (); j ++ {
596
- if typed .In (j ) != fn .In (j + fnInOffset ) {
597
- continue funcTypes
598
- }
599
- }
600
- node .Typed = i
601
625
}
602
626
}
603
627
604
- return fn .Out (0 ), info {}
628
+ return fn .Out (0 ), nil
605
629
}
606
630
607
631
func (v * visitor ) BuiltinNode (node * ast.BuiltinNode ) (reflect.Type , info ) {
@@ -769,3 +793,44 @@ func (v *visitor) PairNode(node *ast.PairNode) (reflect.Type, info) {
769
793
v .visit (node .Value )
770
794
return nilType , info {}
771
795
}
796
+
797
+ func (v * visitor ) findTypedFunc (node * ast.CallNode , fn reflect.Type , method bool ) {
798
+ // OnCallTyped doesn't work for functions with variadic arguments,
799
+ // and doesn't work named function, like `type MyFunc func() int`.
800
+ // In PkgPath() is an empty string, it's unnamed function.
801
+ if ! fn .IsVariadic () && fn .PkgPath () == "" {
802
+ fnNumIn := fn .NumIn ()
803
+ fnInOffset := 0
804
+ if method {
805
+ fnNumIn --
806
+ fnInOffset = 1
807
+ }
808
+ funcTypes:
809
+ for i := range vm .FuncTypes {
810
+ if i == 0 {
811
+ continue
812
+ }
813
+ typed := reflect .ValueOf (vm .FuncTypes [i ]).Elem ().Type ()
814
+ if typed .Kind () != reflect .Func {
815
+ continue
816
+ }
817
+ if typed .NumOut () != fn .NumOut () {
818
+ continue
819
+ }
820
+ for j := 0 ; j < typed .NumOut (); j ++ {
821
+ if typed .Out (j ) != fn .Out (j ) {
822
+ continue funcTypes
823
+ }
824
+ }
825
+ if typed .NumIn () != fnNumIn {
826
+ continue
827
+ }
828
+ for j := 0 ; j < typed .NumIn (); j ++ {
829
+ if typed .In (j ) != fn .In (j + fnInOffset ) {
830
+ continue funcTypes
831
+ }
832
+ }
833
+ node .Typed = i
834
+ }
835
+ }
836
+ }
0 commit comments