From 359dd99f9f1812aaf7db66391e100ca0de9dabf6 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Thu, 28 Aug 2025 18:44:21 -0300 Subject: [PATCH 01/24] export the checker and ensure no methods or fields are exported --- checker/checker.go | 112 ++++++++++++++++++++++----------------------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index f4923413..98362d06 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -77,7 +77,7 @@ func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { config = conf.New(nil) } - v := &checker{config: config} + v := &Checker{config: config} nt := v.visit(tree.Node) @@ -113,7 +113,7 @@ func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { return t, nil } -type checker struct { +type Checker struct { config *conf.Config predicateScopes []predicateScope varScopes []varScope @@ -141,53 +141,53 @@ type info struct { elem reflect.Type } -func (v *checker) visit(node ast.Node) Nature { +func (v *Checker) visit(node ast.Node) Nature { var nt Nature switch n := node.(type) { case *ast.NilNode: - nt = v.NilNode(n) + nt = v.nilNode(n) case *ast.IdentifierNode: - nt = v.IdentifierNode(n) + nt = v.identifierNode(n) case *ast.IntegerNode: - nt = v.IntegerNode(n) + nt = v.integerNode(n) case *ast.FloatNode: - nt = v.FloatNode(n) + nt = v.floatNode(n) case *ast.BoolNode: - nt = v.BoolNode(n) + nt = v.boolNode(n) case *ast.StringNode: - nt = v.StringNode(n) + nt = v.stringNode(n) case *ast.ConstantNode: - nt = v.ConstantNode(n) + nt = v.constantNode(n) case *ast.UnaryNode: - nt = v.UnaryNode(n) + nt = v.unaryNode(n) case *ast.BinaryNode: - nt = v.BinaryNode(n) + nt = v.binaryNode(n) case *ast.ChainNode: - nt = v.ChainNode(n) + nt = v.chainNode(n) case *ast.MemberNode: - nt = v.MemberNode(n) + nt = v.memberNode(n) case *ast.SliceNode: - nt = v.SliceNode(n) + nt = v.sliceNode(n) case *ast.CallNode: - nt = v.CallNode(n) + nt = v.callNode(n) case *ast.BuiltinNode: - nt = v.BuiltinNode(n) + nt = v.builtinNode(n) case *ast.PredicateNode: - nt = v.PredicateNode(n) + nt = v.predicateNode(n) case *ast.PointerNode: - nt = v.PointerNode(n) + nt = v.pointerNode(n) case *ast.VariableDeclaratorNode: - nt = v.VariableDeclaratorNode(n) + nt = v.variableDeclaratorNode(n) case *ast.SequenceNode: - nt = v.SequenceNode(n) + nt = v.sequenceNode(n) case *ast.ConditionalNode: - nt = v.ConditionalNode(n) + nt = v.conditionalNode(n) case *ast.ArrayNode: - nt = v.ArrayNode(n) + nt = v.arrayNode(n) case *ast.MapNode: - nt = v.MapNode(n) + nt = v.mapNode(n) case *ast.PairNode: - nt = v.PairNode(n) + nt = v.pairNode(n) default: panic(fmt.Sprintf("undefined node type (%T)", node)) } @@ -195,7 +195,7 @@ func (v *checker) visit(node ast.Node) Nature { return nt } -func (v *checker) error(node ast.Node, format string, args ...any) Nature { +func (v *Checker) error(node ast.Node, format string, args ...any) Nature { if v.err == nil { // show first error v.err = &file.Error{ Location: node.Location(), @@ -205,11 +205,11 @@ func (v *checker) error(node ast.Node, format string, args ...any) Nature { return unknown } -func (v *checker) NilNode(*ast.NilNode) Nature { +func (v *Checker) nilNode(*ast.NilNode) Nature { return nilNature } -func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature { +func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { if variable, ok := v.lookupVariable(node.Value); ok { return variable.nature } @@ -221,7 +221,7 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature { } // ident method returns type of environment variable, builtin or function. -func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Nature { +func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Nature { if nt, ok := v.config.Env.Get(name); ok { return nt } @@ -239,27 +239,27 @@ func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Natur return unknown } -func (v *checker) IntegerNode(*ast.IntegerNode) Nature { +func (v *Checker) integerNode(*ast.IntegerNode) Nature { return integerNature } -func (v *checker) FloatNode(*ast.FloatNode) Nature { +func (v *Checker) floatNode(*ast.FloatNode) Nature { return floatNature } -func (v *checker) BoolNode(*ast.BoolNode) Nature { +func (v *Checker) boolNode(*ast.BoolNode) Nature { return boolNature } -func (v *checker) StringNode(*ast.StringNode) Nature { +func (v *Checker) stringNode(*ast.StringNode) Nature { return stringNature } -func (v *checker) ConstantNode(node *ast.ConstantNode) Nature { +func (v *Checker) constantNode(node *ast.ConstantNode) Nature { return Nature{Type: reflect.TypeOf(node.Value)} } -func (v *checker) UnaryNode(node *ast.UnaryNode) Nature { +func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { nt := v.visit(node.Node) nt = nt.Deref() @@ -288,7 +288,7 @@ func (v *checker) UnaryNode(node *ast.UnaryNode) Nature { return v.error(node, `invalid operation: %v (mismatched type %s)`, node.Operator, nt) } -func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { +func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { l := v.visit(node.Left) r := v.visit(node.Right) @@ -480,11 +480,11 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r) } -func (v *checker) ChainNode(node *ast.ChainNode) Nature { +func (v *Checker) chainNode(node *ast.ChainNode) Nature { return v.visit(node.Node) } -func (v *checker) MemberNode(node *ast.MemberNode) Nature { +func (v *Checker) memberNode(node *ast.MemberNode) Nature { // $env variable if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" { if name, ok := node.Property.(*ast.StringNode); ok { @@ -566,7 +566,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { return v.error(node, "type %v[%v] is undefined", base, prop) } -func (v *checker) SliceNode(node *ast.SliceNode) Nature { +func (v *Checker) sliceNode(node *ast.SliceNode) Nature { nt := v.visit(node.Node) if isUnknown(nt) { @@ -597,7 +597,7 @@ func (v *checker) SliceNode(node *ast.SliceNode) Nature { return nt } -func (v *checker) CallNode(node *ast.CallNode) Nature { +func (v *Checker) callNode(node *ast.CallNode) Nature { nt := v.functionReturnType(node) // Check if type was set on node (for example, by patcher) @@ -618,7 +618,7 @@ func (v *checker) CallNode(node *ast.CallNode) Nature { return nt } -func (v *checker) functionReturnType(node *ast.CallNode) Nature { +func (v *Checker) functionReturnType(node *ast.CallNode) Nature { nt := v.visit(node.Callee) if nt.Func != nil { @@ -657,7 +657,7 @@ func (v *checker) functionReturnType(node *ast.CallNode) Nature { return v.error(node, "%s is not callable", nt) } -func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { +func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { switch node.Name { case "all", "none", "any", "one": collection := v.visit(node.Arguments[0]).Deref() @@ -893,7 +893,7 @@ type scopeVar struct { varNature Nature } -func (v *checker) begin(collectionNature Nature, vars ...scopeVar) { +func (v *Checker) begin(collectionNature Nature, vars ...scopeVar) { scope := predicateScope{collection: collectionNature, vars: make(map[string]Nature)} for _, v := range vars { scope.vars[v.varName] = v.varNature @@ -901,11 +901,11 @@ func (v *checker) begin(collectionNature Nature, vars ...scopeVar) { v.predicateScopes = append(v.predicateScopes, scope) } -func (v *checker) end() { +func (v *Checker) end() { v.predicateScopes = v.predicateScopes[:len(v.predicateScopes)-1] } -func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { +func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if len(node.Arguments) != 2 { return v.error(node, "invalid number of arguments (expected 2, got %d)", len(node.Arguments)) } @@ -941,7 +941,7 @@ func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "type %v does not support indexing", base) } -func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { +func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { if f.Validate != nil { args := make([]reflect.Type, len(arguments)) for i, arg := range arguments { @@ -995,7 +995,7 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] return v.error(node, "no matching overload for %v", f.Name) } -func (v *checker) checkArguments( +func (v *Checker) checkArguments( name string, fn Nature, arguments []ast.Node, @@ -1152,7 +1152,7 @@ func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Na } } -func (v *checker) PredicateNode(node *ast.PredicateNode) Nature { +func (v *Checker) predicateNode(node *ast.PredicateNode) Nature { nt := v.visit(node.Node) var out []reflect.Type if isUnknown(nt) { @@ -1166,7 +1166,7 @@ func (v *checker) PredicateNode(node *ast.PredicateNode) Nature { } } -func (v *checker) PointerNode(node *ast.PointerNode) Nature { +func (v *Checker) pointerNode(node *ast.PointerNode) Nature { if len(v.predicateScopes) == 0 { return v.error(node, "cannot use pointer accessor outside predicate") } @@ -1189,7 +1189,7 @@ func (v *checker) PointerNode(node *ast.PointerNode) Nature { return v.error(node, "unknown pointer #%v", node.Name) } -func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { +func (v *Checker) variableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { if _, ok := v.config.Env.Get(node.Name); ok { return v.error(node, "cannot redeclare %v", node.Name) } @@ -1209,7 +1209,7 @@ func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Natur return exprNature } -func (v *checker) SequenceNode(node *ast.SequenceNode) Nature { +func (v *Checker) sequenceNode(node *ast.SequenceNode) Nature { if len(node.Nodes) == 0 { return v.error(node, "empty sequence expression") } @@ -1220,7 +1220,7 @@ func (v *checker) SequenceNode(node *ast.SequenceNode) Nature { return last } -func (v *checker) lookupVariable(name string) (varScope, bool) { +func (v *Checker) lookupVariable(name string) (varScope, bool) { for i := len(v.varScopes) - 1; i >= 0; i-- { if v.varScopes[i].name == name { return v.varScopes[i], true @@ -1229,7 +1229,7 @@ func (v *checker) lookupVariable(name string) (varScope, bool) { return varScope{}, false } -func (v *checker) ConditionalNode(node *ast.ConditionalNode) Nature { +func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { c := v.visit(node.Cond) if !isBool(c) && !isUnknown(c) { return v.error(node.Cond, "non-bool expression (type %v) used as condition", c) @@ -1253,7 +1253,7 @@ func (v *checker) ConditionalNode(node *ast.ConditionalNode) Nature { return unknown } -func (v *checker) ArrayNode(node *ast.ArrayNode) Nature { +func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { var prev Nature allElementsAreSameType := true for i, node := range node.Nodes { @@ -1271,14 +1271,14 @@ func (v *checker) ArrayNode(node *ast.ArrayNode) Nature { return arrayNature } -func (v *checker) MapNode(node *ast.MapNode) Nature { +func (v *Checker) mapNode(node *ast.MapNode) Nature { for _, pair := range node.Pairs { v.visit(pair) } return mapNature } -func (v *checker) PairNode(node *ast.PairNode) Nature { +func (v *Checker) pairNode(node *ast.PairNode) Nature { v.visit(node.Key) v.visit(node.Value) return nilNature From 360157aaabbc28c5a8aed11b3a00a9c110296e55 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Thu, 28 Aug 2025 19:45:20 -0300 Subject: [PATCH 02/24] make Checker reusable --- checker/checker.go | 84 ++++++++++++++++++++++++----------------- checker/checker_test.go | 23 +++++++---- parser/lexer/lexer.go | 7 ---- 3 files changed, 64 insertions(+), 50 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 98362d06..225a9e10 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -73,13 +73,39 @@ func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { // Check checks types of the expression tree. It returns type of the expression // and error if any. If config is nil, then default configuration will be used. func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + return new(Checker).Check(tree, config) +} + +type Checker struct { + config *conf.Config + predicateScopes []predicateScope + varScopes []varScope + err *file.Error + needsReset bool +} + +type predicateScope struct { + collection Nature + vars map[string]Nature +} + +type varScope struct { + name string + nature Nature +} + +func (c *Checker) Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + if c.needsReset { + c.reset() + } + c.needsReset = true + if config == nil { config = conf.New(nil) } + c.config = config - v := &Checker{config: config} - - nt := v.visit(tree.Node) + nt := c.visit(tree.Node) // To keep compatibility with previous versions, we should return any, if nature is unknown. t := nt.Type @@ -87,25 +113,25 @@ func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { t = anyType } - if v.err != nil { - return t, v.err.Bind(tree.Source) + if c.err != nil { + return t, c.err.Bind(tree.Source) } - if v.config.Expect != reflect.Invalid { - if v.config.ExpectAny { + if c.config.Expect != reflect.Invalid { + if c.config.ExpectAny { if isUnknown(nt) { return t, nil } } - switch v.config.Expect { + switch c.config.Expect { case reflect.Int, reflect.Int64, reflect.Float64: if !isNumber(nt) { - return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, nt) + return nil, fmt.Errorf("expected %v, but got %v", c.config.Expect, nt) } default: - if nt.Kind() != v.config.Expect { - return nil, fmt.Errorf("expected %v, but got %s", v.config.Expect, nt) + if nt.Kind() != c.config.Expect { + return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt) } } } @@ -113,32 +139,20 @@ func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { return t, nil } -type Checker struct { - config *conf.Config - predicateScopes []predicateScope - varScopes []varScope - err *file.Error +func (c *Checker) reset() { + clearSlice(c.predicateScopes) + clearSlice(c.varScopes) + c.predicateScopes = c.predicateScopes[:0] + c.varScopes = c.varScopes[:0] + c.err = nil + c.config = nil } -type predicateScope struct { - collection Nature - vars map[string]Nature -} - -type varScope struct { - name string - nature Nature -} - -type info struct { - method bool - fn *builtin.Function - - // elem is element type of array or map. - // Arrays created with type []any, but - // we would like to detect expressions - // like `42 in ["a"]` as invalid. - elem reflect.Type +func clearSlice[S ~[]E, E any](s S) { + var zero E + for i := range s { + s[i] = zero + } } func (v *Checker) visit(node ast.Node) Nature { diff --git a/checker/checker_test.go b/checker/checker_test.go index 2ec7c7cf..bf2b812d 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -136,6 +136,7 @@ func TestCheck(t *testing.T) { {"(Embed).EmbedPointerEmbedInt > 0"}, } + c := new(checker.Checker) for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { var err error @@ -145,7 +146,7 @@ func TestCheck(t *testing.T) { config := conf.New(mock.Env{}) expr.AsBool()(config) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) assert.NoError(t, err) }) } @@ -682,12 +683,13 @@ invalid operation: + (mismatched types int and bool) (1:6) }, } + c := new(checker.Checker) for _, tt := range errorTests { t.Run(tt.code, func(t *testing.T) { tree, err := parser.Parse(tt.code) require.NoError(t, err) - _, err = checker.Check(tree, conf.New(mock.Env{})) + _, err = c.Check(tree, conf.New(mock.Env{})) if err == nil { err = fmt.Errorf("") } @@ -839,6 +841,7 @@ func TestCheck_TypeWeights(t *testing.T) { "Float32": float32(11), "Float64": float64(12), } + c := new(checker.Checker) for a := range types { for b := range types { tree, err := parser.Parse(fmt.Sprintf("%s + %s", a, b)) @@ -846,7 +849,7 @@ func TestCheck_TypeWeights(t *testing.T) { config := conf.New(types) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) require.NoError(t, err) } } @@ -926,6 +929,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { config := conf.CreateNew() add(config) + c := new(checker.Checker) tests := []string{ "add(1)", @@ -938,7 +942,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { tree, err := parser.Parse(test) require.NoError(t, err) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) require.NoError(t, err) require.Equal(t, reflect.Int, tree.Node.Type().Kind()) }) @@ -948,7 +952,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { tree, err := parser.Parse("add(1, '2')") require.NoError(t, err) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) require.Error(t, err) require.Equal(t, "cannot use string as argument (type int) to call add (1:8)\n | add(1, '2')\n | .......^", err.Error()) }) @@ -1046,12 +1050,13 @@ func TestCheck_env_keyword(t *testing.T) { {`$env[name]`, reflect.Interface}, } + c := new(checker.Checker) for _, test := range tests { t.Run(test.input, func(t *testing.T) { tree, err := parser.Parse(test.input) require.NoError(t, err) - rtype, err := checker.Check(tree, conf.New(env)) + rtype, err := c.Check(tree, conf.New(env)) require.NoError(t, err) require.True(t, rtype.Kind() == test.want, fmt.Sprintf("expected %s, got %s", test.want, rtype.Kind())) }) @@ -1067,12 +1072,13 @@ func TestCheck_builtin_without_call(t *testing.T) { {`string.A`, "type func(interface {}) string has no field A (1:8)\n | string.A\n | .......^"}, } + c := new(checker.Checker) for _, test := range tests { t.Run(test.input, func(t *testing.T) { tree, err := parser.Parse(test.input) require.NoError(t, err) - _, err = checker.Check(tree, conf.New(nil)) + _, err = c.Check(tree, conf.New(nil)) require.Error(t, err) require.Equal(t, test.err, err.Error()) }) @@ -1134,13 +1140,14 @@ func TestCheck_types(t *testing.T) { {`arr | filter(.value contains "a") | filter(.value == 0)`, `invalid operation: == (mismatched types string and int)`}, } + c := new(checker.Checker) for _, test := range tests { t.Run(test.code, func(t *testing.T) { tree, err := parser.Parse(test.code) require.NoError(t, err) config := conf.New(env) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) if test.err == noerr { require.NoError(t, err) } else { diff --git a/parser/lexer/lexer.go b/parser/lexer/lexer.go index a8cf6a5b..fe41e824 100644 --- a/parser/lexer/lexer.go +++ b/parser/lexer/lexer.go @@ -92,13 +92,6 @@ func (l *Lexer) peek() rune { return eof } -func (l *Lexer) peekByte() (byte, bool) { - if l.end.byte >= 0 && l.end.byte < len(l.source.String()) { - return l.source.String()[l.end.byte], true - } - return 0, false -} - func (l *Lexer) backup() { if l.eof { l.eof = false From d234ae3b4cab371ab42c1d724c0bc2a41c41bdf2 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Thu, 28 Aug 2025 20:00:37 -0300 Subject: [PATCH 03/24] add PatchAndCheck method to Checker, and reuse it in visitors --- checker/checker.go | 120 ++++++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 56 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 225a9e10..55da8cf3 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -13,40 +13,6 @@ import ( "github.com/expr-lang/expr/parser" ) -// Run visitors in a given config over the given tree -// runRepeatable controls whether to filter for only vistors that require multiple passes or not -func runVisitors(tree *parser.Tree, config *conf.Config, runRepeatable bool) { - for { - more := false - for _, v := range config.Visitors { - // We need to perform types check, because some visitors may rely on - // types information available in the tree. - _, _ = Check(tree, config) - - r, repeatable := v.(interface { - Reset() - ShouldRepeat() bool - }) - - if repeatable { - if runRepeatable { - r.Reset() - ast.Walk(&tree.Node, v) - more = more || r.ShouldRepeat() - } - } else { - if !runRepeatable { - ast.Walk(&tree.Node, v) - } - } - } - - if !more { - break - } - } -} - // ParseCheck parses input expression and checks its types. Also, it applies // all provided patchers. In case of error, it returns error with a tree. func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { @@ -55,14 +21,7 @@ func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { return tree, err } - if len(config.Visitors) > 0 { - // Run all patchers that dont support being run repeatedly first - runVisitors(tree, config, false) - - // Run patchers that require multiple passes next (currently only Operator patching) - runVisitors(tree, config, true) - } - _, err = Check(tree, config) + _, err = new(Checker).PatchAndCheck(tree, config) if err != nil { return tree, err } @@ -94,17 +53,59 @@ type varScope struct { nature Nature } -func (c *Checker) Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { - if c.needsReset { - c.reset() +// PatchAndCheck applies all patchers and checks the tree. +func (c *Checker) PatchAndCheck(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + c.reset(config) + if len(config.Visitors) > 0 { + // Run all patchers that dont support being run repeatedly first + c.runVisitors(tree, false) + + // Run patchers that require multiple passes next (currently only Operator patching) + c.runVisitors(tree, true) } - c.needsReset = true + return c.Check(tree, config) +} - if config == nil { - config = conf.New(nil) +func (c *Checker) Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + c.reset(config) + return c.check(tree) +} + +// Run visitors in a given config over the given tree +// runRepeatable controls whether to filter for only vistors that require multiple passes or not +func (c *Checker) runVisitors(tree *parser.Tree, runRepeatable bool) { + for { + more := false + for _, v := range c.config.Visitors { + // We need to perform types check, because some visitors may rely on + // types information available in the tree. + _, _ = c.Check(tree, c.config) + + r, repeatable := v.(interface { + Reset() + ShouldRepeat() bool + }) + + if repeatable { + if runRepeatable { + r.Reset() + ast.Walk(&tree.Node, v) + more = more || r.ShouldRepeat() + } + } else { + if !runRepeatable { + ast.Walk(&tree.Node, v) + } + } + } + + if !more { + break + } } - c.config = config +} +func (c *Checker) check(tree *parser.Tree) (reflect.Type, error) { nt := c.visit(tree.Node) // To keep compatibility with previous versions, we should return any, if nature is unknown. @@ -139,13 +140,20 @@ func (c *Checker) Check(tree *parser.Tree, config *conf.Config) (reflect.Type, e return t, nil } -func (c *Checker) reset() { - clearSlice(c.predicateScopes) - clearSlice(c.varScopes) - c.predicateScopes = c.predicateScopes[:0] - c.varScopes = c.varScopes[:0] - c.err = nil - c.config = nil +func (c *Checker) reset(config *conf.Config) { + if c.needsReset { + clearSlice(c.predicateScopes) + clearSlice(c.varScopes) + c.predicateScopes = c.predicateScopes[:0] + c.varScopes = c.varScopes[:0] + c.err = nil + } + c.needsReset = true + + if config == nil { + config = conf.New(nil) + } + c.config = config } func clearSlice[S ~[]E, E any](s S) { From edce1b0c2457dc9c5207d9398d5afc376333853d Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Thu, 28 Aug 2025 22:36:38 -0300 Subject: [PATCH 04/24] avoid unnecessary map and consolidate structs --- checker/checker.go | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 55da8cf3..0e79601b 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -29,8 +29,7 @@ func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { return tree, nil } -// Check checks types of the expression tree. It returns type of the expression -// and error if any. If config is nil, then default configuration will be used. +// Check calls Check on a disposable Checker. func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { return new(Checker).Check(tree, config) } @@ -45,7 +44,7 @@ type Checker struct { type predicateScope struct { collection Nature - vars map[string]Nature + vars []varScope } type varScope struct { @@ -66,6 +65,8 @@ func (c *Checker) PatchAndCheck(tree *parser.Tree, config *conf.Config) (reflect return c.Check(tree, config) } +// Check checks types of the expression tree. It returns type of the expression +// and error if any. If config is nil, then default configuration will be used. func (c *Checker) Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { c.reset(config) return c.check(tree) @@ -732,7 +733,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } - v.begin(collection, scopeVar{"index", integerNature}) + v.begin(collection, varScope{"index", integerNature}) predicate := v.visit(node.Arguments[1]) v.end() @@ -884,7 +885,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } - v.begin(collection, scopeVar{"index", integerNature}, scopeVar{"acc", unknown}) + v.begin(collection, varScope{"index", integerNature}, varScope{"acc", unknown}) predicate := v.visit(node.Arguments[1]) v.end() @@ -910,17 +911,11 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node, "unknown builtin %v", node.Name) } -type scopeVar struct { - varName string - varNature Nature -} - -func (v *Checker) begin(collectionNature Nature, vars ...scopeVar) { - scope := predicateScope{collection: collectionNature, vars: make(map[string]Nature)} - for _, v := range vars { - scope.vars[v.varName] = v.varNature - } - v.predicateScopes = append(v.predicateScopes, scope) +func (v *Checker) begin(collectionNature Nature, vars ...varScope) { + v.predicateScopes = append(v.predicateScopes, predicateScope{ + collection: collectionNature, + vars: vars, + }) } func (v *Checker) end() { @@ -1204,8 +1199,10 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { return v.error(node, "cannot use %v as array", scope) } if scope.vars != nil { - if t, ok := scope.vars[node.Name]; ok { - return t + for i := range scope.vars { + if node.Name == scope.vars[i].name { + return scope.vars[i].nature + } } } return v.error(node, "unknown pointer #%v", node.Name) From 80afb7648b02bdd9a799392ef5a7ce69e9d4fb56 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 03:24:53 -0300 Subject: [PATCH 05/24] make all Nature methods use pointer-receiver --- checker/checker.go | 135 ++++++++++++++++++++++----------------- checker/nature/nature.go | 42 ++++++------ compiler/compiler.go | 3 +- docgen/docgen.go | 3 +- 4 files changed, 103 insertions(+), 80 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 0e79601b..38eafd90 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -129,11 +129,11 @@ func (c *Checker) check(tree *parser.Tree) (reflect.Type, error) { switch c.config.Expect { case reflect.Int, reflect.Int64, reflect.Float64: if !isNumber(nt) { - return nil, fmt.Errorf("expected %v, but got %v", c.config.Expect, nt) + return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } default: if nt.Kind() != c.config.Expect { - return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt) + return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } } } @@ -257,7 +257,7 @@ func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Natur } } if v.config.Strict && strict { - return v.error(node, "unknown name %v", name) + return v.error(node, "unknown name %s", name) } return unknown } @@ -305,10 +305,10 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { } default: - return v.error(node, "unknown operator (%v)", node.Operator) + return v.error(node, "unknown operator (%s)", node.Operator) } - return v.error(node, `invalid operation: %v (mismatched type %s)`, node.Operator, nt) + return v.error(node, `invalid operation: %s (mismatched type %s)`, node.Operator, nt.String()) } func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { @@ -432,14 +432,16 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return boolNature } if isMap(r) { - if !isUnknown(l) && !l.AssignableTo(r.Key()) { - return v.error(node, "cannot use %v as type %v in map key", l, r.Key()) + rKey := r.Key() + if !isUnknown(l) && !l.AssignableTo(rKey) { + return v.error(node, "cannot use %s as type %s in map key", l.String(), rKey.String()) } return boolNature } if isArray(r) { - if !isComparable(l, r.Elem()) { - return v.error(node, "cannot use %v as type %v in array", l, r.Elem()) + rElem := r.Elem() + if !isComparable(l, rElem) { + return v.error(node, "cannot use %s as type %s in array", l.String(), rElem.String()) } return boolNature } @@ -496,11 +498,11 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return unknown default: - return v.error(node, "unknown operator (%v)", node.Operator) + return v.error(node, "unknown operator (%s)", node.Operator) } - return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r) + return v.error(node, `invalid operation: %s (mismatched types %s and %s)`, node.Operator, l.String(), r.String()) } func (v *Checker) chainNode(node *ast.ChainNode) Nature { @@ -533,7 +535,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { if name, ok := node.Property.(*ast.StringNode); ok { if isNil(base) { - return v.error(node, "type nil has no field %v", name.Value) + return v.error(node, "type nil has no field %s", name.Value) } // First, check methods defined on base type itself, @@ -548,20 +550,20 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { switch base.Kind() { case reflect.Map: if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { - return v.error(node.Property, "cannot use %v to get an element from %v", prop, base) + return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String()) } if prop, ok := node.Property.(*ast.StringNode); ok { if field, ok := base.Fields[prop.Value]; ok { return field } else if base.Strict { - return v.error(node.Property, "unknown field %v", prop.Value) + return v.error(node.Property, "unknown field %s", prop.Value) } } return base.Elem() case reflect.Array, reflect.Slice: if !isInteger(prop) && !isUnknown(prop) { - return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop) + return v.error(node.Property, "array elements can only be selected using an integer (got %s)", prop.String()) } return base.Elem() @@ -572,9 +574,9 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { return Nature{Type: field.Type} } if node.Method { - return v.error(node, "type %v has no method %v", base, propertyName) + return v.error(node, "type %v has no method %v", base.String(), propertyName) } - return v.error(node, "type %v has no field %v", base, propertyName) + return v.error(node, "type %v has no field %v", base.String(), propertyName) } } @@ -582,11 +584,11 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { if name, ok := node.Property.(*ast.StringNode); ok { if node.Method { - return v.error(node, "type %v has no method %v", base, name.Value) + return v.error(node, "type %v has no method %v", base.String(), name.Value) } - return v.error(node, "type %v has no field %v", base, name.Value) + return v.error(node, "type %v has no field %v", base.String(), name.Value) } - return v.error(node, "type %v[%v] is undefined", base, prop) + return v.error(node, "type %v[%v] is undefined", base.String(), prop.String()) } func (v *Checker) sliceNode(node *ast.SliceNode) Nature { @@ -600,20 +602,20 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { case reflect.String, reflect.Array, reflect.Slice: // ok default: - return v.error(node, "cannot slice %s", nt) + return v.error(node, "cannot slice %s", nt.String()) } if node.From != nil { from := v.visit(node.From) if !isInteger(from) && !isUnknown(from) { - return v.error(node.From, "non-integer slice index %v", from) + return v.error(node.From, "non-integer slice index %v", from.String()) } } if node.To != nil { to := v.visit(node.To) if !isInteger(to) && !isUnknown(to) { - return v.error(node.To, "non-integer slice index %v", to) + return v.error(node.To, "non-integer slice index %v", to.String()) } } @@ -677,15 +679,16 @@ func (v *Checker) functionReturnType(node *ast.CallNode) Nature { } return outType } - return v.error(node, "%s is not callable", nt) + return v.error(node, "%s is not callable", nt.String()) } func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { switch node.Name { case "all", "none", "any", "one": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -697,16 +700,18 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + out := predicate.Out(0) + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) } return boolNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "filter": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -718,7 +723,8 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + out := predicate.Out(0) + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) } if isUnknown(collection) { return arrayNature @@ -728,9 +734,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should has one input and one output param") case "map": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection, varScope{"index", integerNature}) @@ -746,9 +753,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should has one input and one output param") case "count": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } if len(node.Arguments) == 1 { @@ -763,7 +771,8 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate.NumOut() == 1 && predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + out := predicate.Out(0) + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) } return integerNature @@ -771,9 +780,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sum": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } if len(node.Arguments) == 2 { @@ -794,9 +804,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } case "find", "findLast": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -808,7 +819,8 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + out := predicate.Out(0) + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) } if isUnknown(collection) { return unknown @@ -818,9 +830,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should has one input and one output param") case "findIndex", "findLastIndex": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -832,16 +845,18 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + out := predicate.Out(0) + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) } return integerNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "groupBy": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -858,9 +873,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sortBy": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -880,9 +896,10 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should has one input and one output param") case "reduce": - collection := v.visit(node.Arguments[0]).Deref() + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection, varScope{"index", integerNature}, varScope{"acc", unknown}) @@ -946,16 +963,16 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { switch base.Kind() { case reflect.Slice, reflect.Array: if !isInteger(prop) && !isUnknown(prop) { - return v.error(node.Arguments[1], "non-integer slice index %s", prop) + return v.error(node.Arguments[1], "non-integer slice index %s", prop.String()) } return base.Elem() case reflect.Map: if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { - return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop, base) + return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop.String(), base.String()) } return base.Elem() } - return v.error(node.Arguments[0], "type %v does not support indexing", base) + return v.error(node.Arguments[0], "type %v does not support indexing", base.String()) } func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { @@ -1086,7 +1103,8 @@ func (v *Checker) checkArguments( if fn.IsVariadic() && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. - in = fn.In(fn.NumIn() - 1).Elem() + in = fn.In(fn.NumIn() - 1) + in = in.Elem() } else { in = fn.In(i + fnInOffset) } @@ -1107,7 +1125,7 @@ func (v *Checker) checkArguments( } return unknown, &file.Error{ Location: arg.Location(), - Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in, name), + Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in.String(), name), } } @@ -1119,12 +1137,15 @@ func (v *Checker) checkArguments( // We also need to check if dereference arg type is assignable to the function input type. // For example, func(int) and argument *int. In this case we will add OpDeref to the argument, // so we can call the function with *int argument. - assignable = assignable || argNature.Deref().AssignableTo(in) + if !assignable { + nt := argNature.Deref() + assignable = nt.AssignableTo(in) + } if !assignable && !isUnknown(argNature) { return unknown, &file.Error{ Location: arg.Location(), - Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature, in, name), + Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature.String(), in.String(), name), } } } @@ -1251,7 +1272,7 @@ func (v *Checker) lookupVariable(name string) (varScope, bool) { func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { c := v.visit(node.Cond) if !isBool(c) && !isUnknown(c) { - return v.error(node.Cond, "non-bool expression (type %v) used as condition", c) + return v.error(node.Cond, "non-bool expression (type %v) used as condition", c.String()) } t1 := v.visit(node.Exp1) diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 993c9fcf..f5ee74dc 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -25,11 +25,11 @@ type Nature struct { FieldIndex []int // Index of field in type. } -func (n Nature) IsAny() bool { +func (n *Nature) IsAny() bool { return n.Kind() == reflect.Interface && n.NumMethods() == 0 } -func (n Nature) IsUnknown() bool { +func (n *Nature) IsUnknown() bool { switch { case n.Type == nil && !n.Nil: return true @@ -39,35 +39,35 @@ func (n Nature) IsUnknown() bool { return false } -func (n Nature) String() string { +func (n *Nature) String() string { if n.Type != nil { return n.Type.String() } return "unknown" } -func (n Nature) Deref() Nature { +func (n *Nature) Deref() Nature { if n.Type != nil { n.Type = deref.Type(n.Type) } - return n + return *n } -func (n Nature) Kind() reflect.Kind { +func (n *Nature) Kind() reflect.Kind { if n.Type != nil { return n.Type.Kind() } return reflect.Invalid } -func (n Nature) Key() Nature { +func (n *Nature) Key() Nature { if n.Kind() == reflect.Map { return Nature{Type: n.Type.Key()} } return unknown } -func (n Nature) Elem() Nature { +func (n *Nature) Elem() Nature { switch n.Kind() { case reflect.Ptr: return Nature{Type: n.Type.Elem()} @@ -85,7 +85,7 @@ func (n Nature) Elem() Nature { return unknown } -func (n Nature) AssignableTo(nt Nature) bool { +func (n *Nature) AssignableTo(nt Nature) bool { if n.Nil { // Untyped nil is assignable to any interface, but implements only the empty interface. if nt.IsAny() { @@ -98,14 +98,14 @@ func (n Nature) AssignableTo(nt Nature) bool { return n.Type.AssignableTo(nt.Type) } -func (n Nature) NumMethods() int { +func (n *Nature) NumMethods() int { if n.Type == nil { return 0 } return n.Type.NumMethod() } -func (n Nature) MethodByName(name string) (Nature, bool) { +func (n *Nature) MethodByName(name string) (Nature, bool) { if n.Type == nil { return unknown, false } @@ -132,42 +132,42 @@ func (n Nature) MethodByName(name string) (Nature, bool) { } } -func (n Nature) NumIn() int { +func (n *Nature) NumIn() int { if n.Type == nil { return 0 } return n.Type.NumIn() } -func (n Nature) In(i int) Nature { +func (n *Nature) In(i int) Nature { if n.Type == nil { return unknown } return Nature{Type: n.Type.In(i)} } -func (n Nature) NumOut() int { +func (n *Nature) NumOut() int { if n.Type == nil { return 0 } return n.Type.NumOut() } -func (n Nature) Out(i int) Nature { +func (n *Nature) Out(i int) Nature { if n.Type == nil { return unknown } return Nature{Type: n.Type.Out(i)} } -func (n Nature) IsVariadic() bool { +func (n *Nature) IsVariadic() bool { if n.Type == nil { return false } return n.Type.IsVariadic() } -func (n Nature) FieldByName(name string) (Nature, bool) { +func (n *Nature) FieldByName(name string) (Nature, bool) { if n.Type == nil { return unknown, false } @@ -175,14 +175,14 @@ func (n Nature) FieldByName(name string) (Nature, bool) { return Nature{Type: field.Type, FieldIndex: field.Index}, ok } -func (n Nature) PkgPath() string { +func (n *Nature) PkgPath() string { if n.Type == nil { return "" } return n.Type.PkgPath() } -func (n Nature) IsFastMap() bool { +func (n *Nature) IsFastMap() bool { if n.Type == nil { return false } @@ -194,7 +194,7 @@ func (n Nature) IsFastMap() bool { return false } -func (n Nature) Get(name string) (Nature, bool) { +func (n *Nature) Get(name string) (Nature, bool) { if n.Type == nil { return unknown, false } @@ -221,7 +221,7 @@ func (n Nature) Get(name string) (Nature, bool) { return unknown, false } -func (n Nature) All() map[string]Nature { +func (n *Nature) All() map[string]Nature { table := make(map[string]Nature) if n.Type == nil { diff --git a/compiler/compiler.go b/compiler/compiler.go index 595355d2..4f7d82ca 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1080,7 +1080,8 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { for i, arg := range node.Arguments { c.compile(arg) argType := arg.Type() - if argType.Kind() == reflect.Ptr || arg.Nature().IsUnknown() { + argNature := arg.Nature() + if argType.Kind() == reflect.Ptr || argNature.IsUnknown() { if f.Deref == nil { // By default, builtins expect arguments to be dereferenced. c.emit(OpDeref) diff --git a/docgen/docgen.go b/docgen/docgen.go index 1844f23b..e93dbfa2 100644 --- a/docgen/docgen.go +++ b/docgen/docgen.go @@ -85,7 +85,8 @@ func CreateDoc(i any) *Context { PkgPath: deref.Type(reflect.TypeOf(i)).PkgPath(), } - for name, t := range conf.Env(i).All() { + env := conf.Env(i) + for name, t := range env.All() { if _, ok := c.Variables[Identifier(name)]; ok { continue } From 1c36c13d77b1cac1279154cc802d2e3546b5b7ed Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 04:06:26 -0300 Subject: [PATCH 06/24] added cache for type and nature information --- checker/checker.go | 10 +-- checker/info.go | 10 +-- checker/nature/nature.go | 114 ++++++++++++++++++++++++++++------- compiler/compiler.go | 24 +++++--- conf/config.go | 3 +- expr.go | 1 + patcher/operator_override.go | 5 +- 7 files changed, 125 insertions(+), 42 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 38eafd90..5ad41700 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -245,7 +245,7 @@ func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { // ident method returns type of environment variable, builtin or function. func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Nature { - if nt, ok := v.config.Env.Get(name); ok { + if nt, ok := v.config.Env.Get(&v.config.NtCache, name); ok { return nt } if builtins { @@ -540,7 +540,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { // First, check methods defined on base type itself, // independent of which type it is. Without dereferencing. - if m, ok := base.MethodByName(name.Value); ok { + if m, ok := base.MethodByName(&v.config.NtCache, name.Value); ok { return m } } @@ -570,7 +570,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value - if field, ok := base.FieldByName(propertyName); ok { + if field, ok := base.FieldByName(&v.config.NtCache, propertyName); ok { return Nature{Type: field.Type} } if node.Method { @@ -949,7 +949,7 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" { if s, ok := node.Arguments[1].(*ast.StringNode); ok { - if nt, ok := v.config.Env.Get(s.Value); ok { + if nt, ok := v.config.Env.Get(&v.config.NtCache, s.Value); ok { return nt } } @@ -1230,7 +1230,7 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { } func (v *Checker) variableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { - if _, ok := v.config.Env.Get(node.Name); ok { + if _, ok := v.config.Env.Get(&v.config.NtCache, node.Name); ok { return v.error(node, "cannot redeclare %v", node.Name) } if _, ok := v.config.Functions[node.Name]; ok { diff --git a/checker/info.go b/checker/info.go index f1cc92eb..41e21d6b 100644 --- a/checker/info.go +++ b/checker/info.go @@ -8,11 +8,11 @@ import ( "github.com/expr-lang/expr/vm" ) -func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { +func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: if env.Kind() == reflect.Struct { - if field, ok := env.Get(n.Value); ok && len(field.FieldIndex) > 0 { + if field, ok := env.Get(c, n.Value); ok && len(field.FieldIndex) > 0 { return true, field.FieldIndex, n.Value } } @@ -22,7 +22,7 @@ func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { if base.Kind() == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { name := prop.Value - if field, ok := base.FieldByName(name); ok { + if field, ok := base.FieldByName(c, name); ok { return true, field.FieldIndex, name } } @@ -31,11 +31,11 @@ func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { return false, nil, "" } -func MethodIndex(env Nature, node ast.Node) (bool, int, string) { +func MethodIndex(c *Cache, env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: if env.Kind() == reflect.Struct { - if m, ok := env.Get(n.Value); ok { + if m, ok := env.Get(c, n.Value); ok { return m.Method, m.MethodIndex, n.Value } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index f5ee74dc..91c2594c 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -7,9 +7,7 @@ import ( "github.com/expr-lang/expr/internal/deref" ) -var ( - unknown = Nature{} -) +var unknown = Nature{} type Nature struct { Type reflect.Type // Type of the value. If nil, then value is unknown. @@ -25,6 +23,17 @@ type Nature struct { FieldIndex []int // Index of field in type. } +type Cache struct { + methodByName map[rTypeWithKey]*Nature + fieldByName map[rTypeWithKey]*Nature + get map[rTypeWithKey]*Nature +} + +type rTypeWithKey struct { + t reflect.Type + key string +} + func (n *Nature) IsAny() bool { return n.Kind() == reflect.Interface && n.NumMethods() == 0 } @@ -105,13 +114,36 @@ func (n *Nature) NumMethods() int { return n.Type.NumMethod() } -func (n *Nature) MethodByName(name string) (Nature, bool) { +func (n *Nature) MethodByName(c *Cache, name string) (Nature, bool) { + if ntPtr := n.methodByNamePtr(c, name); ntPtr != nil { + return *ntPtr, true + } + return unknown, false +} + +func (n *Nature) methodByNamePtr(c *Cache, name string) *Nature { + var ntPtr *Nature + var cacheHit bool + key := rTypeWithKey{n.Type, name} + if c.methodByName == nil { + c.methodByName = map[rTypeWithKey]*Nature{} + } else { + ntPtr, cacheHit = c.methodByName[key] + } + if !cacheHit { + ntPtr = n.methodByNameSlow(name) + c.methodByName[key] = ntPtr + } + return ntPtr +} + +func (n *Nature) methodByNameSlow(name string) *Nature { if n.Type == nil { - return unknown, false + return nil } method, ok := n.Type.MethodByName(name) if !ok { - return unknown, false + return nil } if n.Type.Kind() == reflect.Interface { @@ -122,13 +154,13 @@ func (n *Nature) MethodByName(name string) (Nature, bool) { // Also, we can not use m.Index here, because it will be // different indexes for different types which implement // the same interface. - return Nature{Type: method.Type}, true + return &Nature{Type: method.Type} } else { - return Nature{ + return &Nature{ Type: method.Type, Method: true, MethodIndex: method.Index, - }, true + } } } @@ -167,12 +199,33 @@ func (n *Nature) IsVariadic() bool { return n.Type.IsVariadic() } -func (n *Nature) FieldByName(name string) (Nature, bool) { +func (n *Nature) FieldByName(c *Cache, name string) (Nature, bool) { + var ntPtr *Nature + var cacheHit bool + key := rTypeWithKey{n.Type, name} + if c.fieldByName == nil { + c.fieldByName = map[rTypeWithKey]*Nature{} + } else { + ntPtr, cacheHit = c.fieldByName[key] + } + if !cacheHit { + ntPtr = n.fieldByNameSlow(name) + c.fieldByName[key] = ntPtr + } + if ntPtr != nil { + return *ntPtr, true + } + return unknown, false +} + +func (n *Nature) fieldByNameSlow(name string) *Nature { if n.Type == nil { - return unknown, false + return nil } - field, ok := fetchField(n.Type, name) - return Nature{Type: field.Type, FieldIndex: field.Index}, ok + if field, ok := fetchField(n.Type, name); ok { + return &Nature{Type: field.Type, FieldIndex: field.Index} + } + return nil } func (n *Nature) PkgPath() string { @@ -194,13 +247,32 @@ func (n *Nature) IsFastMap() bool { return false } -func (n *Nature) Get(name string) (Nature, bool) { +func (n *Nature) Get(c *Cache, name string) (Nature, bool) { + var ntPtr *Nature + var cacheHit bool + key := rTypeWithKey{n.Type, name} + if c.get == nil { + c.get = map[rTypeWithKey]*Nature{} + } else { + ntPtr, cacheHit = c.get[key] + } + if !cacheHit { + ntPtr = n.getSlow(c, name) + c.get[key] = ntPtr + } + if ntPtr != nil { + return *ntPtr, true + } + return unknown, false +} + +func (n *Nature) getSlow(c *Cache, name string) *Nature { if n.Type == nil { - return unknown, false + return nil } - if m, ok := n.MethodByName(name); ok { - return m, true + if m := n.methodByNamePtr(c, name); m != nil { + return m } t := deref.Type(n.Type) @@ -208,17 +280,17 @@ func (n *Nature) Get(name string) (Nature, bool) { switch t.Kind() { case reflect.Struct: if f, ok := fetchField(t, name); ok { - return Nature{ + return &Nature{ Type: f.Type, FieldIndex: f.Index, - }, true + } } case reflect.Map: if f, ok := n.Fields[name]; ok { - return f, true + return &f } } - return unknown, false + return nil } func (n *Nature) All() map[string]Nature { diff --git a/compiler/compiler.go b/compiler/compiler.go index 4f7d82ca..b0fb45df 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -9,6 +9,7 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/checker" + "github.com/expr-lang/expr/checker/nature" . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" @@ -36,6 +37,12 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro debugInfo: make(map[string]string), } + if config != nil { + c.ntCache = &c.config.NtCache + } else { + c.ntCache = new(nature.Cache) + } + c.compile(tree.Node) if c.config != nil { @@ -74,6 +81,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro type compiler struct { config *conf.Config + ntCache *nature.Cache locations []file.Location bytecode []Opcode variables int @@ -302,12 +310,12 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { if env.IsFastMap() { c.emit(OpLoadFast, c.addConstant(node.Value)) - } else if ok, index, name := checker.FieldIndex(env, node); ok { + } else if ok, index, name := checker.FieldIndex(c.ntCache, env, node); ok { c.emit(OpLoadField, c.addConstant(&runtime.Field{ Index: index, Path: []string{name}, })) - } else if ok, index, name := checker.MethodIndex(env, node); ok { + } else if ok, index, name := checker.MethodIndex(c.ntCache, env, node); ok { c.emit(OpLoadMethod, c.addConstant(&runtime.Method{ Name: name, Index: index, @@ -653,7 +661,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { env = c.config.Env } - if ok, index, name := checker.MethodIndex(env, node); ok { + if ok, index, name := checker.MethodIndex(c.ntCache, env, node); ok { c.compile(node.Node) c.emit(OpMethod, c.addConstant(&runtime.Method{ Name: name, @@ -664,14 +672,14 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { op := OpFetch base := node.Node - ok, index, nodeName := checker.FieldIndex(env, node) + ok, index, nodeName := checker.FieldIndex(c.ntCache, env, node) path := []string{nodeName} if ok { op = OpFetchField for !node.Optional { if ident, isIdent := base.(*ast.IdentifierNode); isIdent { - if ok, identIndex, name := checker.FieldIndex(env, ident); ok { + if ok, identIndex, name := checker.FieldIndex(c.ntCache, env, ident); ok { index = append(identIndex, index...) path = append([]string{name}, path...) c.emitLocation(ident.Location(), OpLoadField, c.addConstant( @@ -682,7 +690,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { } if member, isMember := base.(*ast.MemberNode); isMember { - if ok, memberIndex, name := checker.FieldIndex(env, member); ok { + if ok, memberIndex, name := checker.FieldIndex(c.ntCache, env, member); ok { index = append(memberIndex, index...) path = append([]string{name}, path...) node = member @@ -743,7 +751,7 @@ func (c *compiler) CallNode(node *ast.CallNode) { } } case *ast.IdentifierNode: - if t, ok := c.config.Env.MethodByName(callee.Value); ok && t.Method { + if t, ok := c.config.Env.MethodByName(c.ntCache, callee.Value); ok && t.Method { fnInOffset = 1 fnNumIn-- } @@ -777,7 +785,7 @@ func (c *compiler) CallNode(node *ast.CallNode) { c.compile(node.Callee) if c.config != nil { - isMethod, _, _ := checker.MethodIndex(c.config.Env, node.Callee) + isMethod, _, _ := checker.MethodIndex(c.ntCache, c.config.Env, node.Callee) if index, ok := checker.TypedFuncIndex(node.Callee.Type(), isMethod); ok { c.emit(OpCallTyped, index) return diff --git a/conf/config.go b/conf/config.go index 2c14d988..6de19c7d 100644 --- a/conf/config.go +++ b/conf/config.go @@ -34,6 +34,7 @@ type Config struct { Functions FunctionsTable Builtins FunctionsTable Disabled map[string]bool // disabled builtins + NtCache nature.Cache } // CreateNew creates new config with default values. @@ -92,7 +93,7 @@ func (c *Config) IsOverridden(name string) bool { if _, ok := c.Functions[name]; ok { return true } - if _, ok := c.Env.Get(name); ok { + if _, ok := c.Env.Get(&c.NtCache, name); ok { return true } return false diff --git a/expr.go b/expr.go index 48298fe7..e8f4eb64 100644 --- a/expr.go +++ b/expr.go @@ -48,6 +48,7 @@ func Operator(operator string, fn ...string) Option { Overloads: fn, Env: &c.Env, Functions: c.Functions, + NtCache: &c.NtCache, } c.Visitors = append(c.Visitors, p) } diff --git a/patcher/operator_override.go b/patcher/operator_override.go index 308cbdba..cf4287c2 100644 --- a/patcher/operator_override.go +++ b/patcher/operator_override.go @@ -16,6 +16,7 @@ type OperatorOverloading struct { Env *nature.Nature // Env type. Functions conf.FunctionsTable // Env functions. applied bool // Flag to indicate if any changes were made to the tree. + NtCache *nature.Cache } func (p *OperatorOverloading) Visit(node *ast.Node) { @@ -62,7 +63,7 @@ func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (r func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) { for _, fn := range p.Overloads { - fnType, ok := p.Env.Get(fn) + fnType, ok := p.Env.Get(p.NtCache, fn) if !ok { continue } @@ -109,7 +110,7 @@ func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex func (p *OperatorOverloading) Check() { for _, fn := range p.Overloads { - fnType, foundType := p.Env.Get(fn) + fnType, foundType := p.Env.Get(p.NtCache, fn) fnFunc, foundFunc := p.Functions[fn] if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) { panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator)) From e9b245b16591890e679a2775026b32c43309a3b9 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 06:23:55 -0300 Subject: [PATCH 07/24] optimize type and nature checks --- checker/checker.go | 292 ++++++++++++++++++++------------------- checker/info.go | 2 +- checker/nature/nature.go | 181 +++++++++++++++++++++--- checker/types.go | 190 ------------------------- 4 files changed, 311 insertions(+), 354 deletions(-) delete mode 100644 checker/types.go diff --git a/checker/checker.go b/checker/checker.go index 5ad41700..8f65727a 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "regexp" + "time" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" @@ -13,6 +14,21 @@ import ( "github.com/expr-lang/expr/parser" ) +var ( + anyType = reflect.TypeOf(new(any)).Elem() + + unknown = Nature{} + nilNature = Nature{Nil: true} + boolNature = Nature{Type: reflect.TypeOf(true)} + integerNature = Nature{Type: reflect.TypeOf(0)} + floatNature = Nature{Type: reflect.TypeOf(float64(0))} + stringNature = Nature{Type: reflect.TypeOf("")} + arrayNature = Nature{Type: reflect.TypeOf([]any{})} + mapNature = Nature{Type: reflect.TypeOf(map[string]any{})} + timeNature = Nature{Type: reflect.TypeOf(time.Time{})} + durationNature = Nature{Type: reflect.TypeOf(time.Duration(0))} +) + // ParseCheck parses input expression and checks its types. Also, it applies // all provided patchers. In case of error, it returns error with a tree. func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { @@ -121,14 +137,14 @@ func (c *Checker) check(tree *parser.Tree) (reflect.Type, error) { if c.config.Expect != reflect.Invalid { if c.config.ExpectAny { - if isUnknown(nt) { + if nt.IsUnknown() { return t, nil } } switch c.config.Expect { case reflect.Int, reflect.Int64, reflect.Float64: - if !isNumber(nt) { + if !nt.IsNumber() { return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } default: @@ -289,18 +305,18 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { switch node.Operator { case "!", "not": - if isBool(nt) { + if nt.IsBool() { return boolNature } - if isUnknown(nt) { + if nt.IsUnknown() { return boolNature } case "+", "-": - if isNumber(nt) { + if nt.IsNumber() { return nt } - if isUnknown(nt) { + if nt.IsUnknown() { return unknown } @@ -320,135 +336,135 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { switch node.Operator { case "==", "!=": - if isComparable(l, r) { + if l.ComparableTo(r) { return boolNature } case "or", "||", "and", "&&": - if isBool(l) && isBool(r) { + if l.IsBool() && r.IsBool() { return boolNature } - if or(l, r, isBool) { + if l.MaybeCompatible(r, BoolCheck) { return boolNature } case "<", ">", ">=", "<=": - if isNumber(l) && isNumber(r) { + if l.IsNumber() && r.IsNumber() { return boolNature } - if isString(l) && isString(r) { + if l.IsString() && r.IsString() { return boolNature } - if isTime(l) && isTime(r) { + if l.IsTime() && r.IsTime() { return boolNature } - if isDuration(l) && isDuration(r) { + if l.IsDuration() && r.IsDuration() { return boolNature } - if or(l, r, isNumber, isString, isTime, isDuration) { + if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { return boolNature } case "-": - if isNumber(l) && isNumber(r) { - return combined(l, r) + if l.IsNumber() && r.IsNumber() { + return l.PromoteNumericNature(r) } - if isTime(l) && isTime(r) { + if l.IsTime() && r.IsTime() { return durationNature } - if isTime(l) && isDuration(r) { + if l.IsTime() && r.IsDuration() { return timeNature } - if isDuration(l) && isDuration(r) { + if l.IsDuration() && r.IsDuration() { return durationNature } - if or(l, r, isNumber, isTime, isDuration) { + if l.MaybeCompatible(r, NumberCheck, TimeCheck, DurationCheck) { return unknown } case "*": - if isNumber(l) && isNumber(r) { - return combined(l, r) + if l.IsNumber() && r.IsNumber() { + return l.PromoteNumericNature(r) } - if isNumber(l) && isDuration(r) { + if l.IsNumber() && r.IsDuration() { return durationNature } - if isDuration(l) && isNumber(r) { + if l.IsDuration() && r.IsNumber() { return durationNature } - if isDuration(l) && isDuration(r) { + if l.IsDuration() && r.IsDuration() { return durationNature } - if or(l, r, isNumber, isDuration) { + if l.MaybeCompatible(r, NumberCheck, DurationCheck) { return unknown } case "/": - if isNumber(l) && isNumber(r) { + if l.IsNumber() && r.IsNumber() { return floatNature } - if or(l, r, isNumber) { + if l.MaybeCompatible(r, NumberCheck) { return floatNature } case "**", "^": - if isNumber(l) && isNumber(r) { + if l.IsNumber() && r.IsNumber() { return floatNature } - if or(l, r, isNumber) { + if l.MaybeCompatible(r, NumberCheck) { return floatNature } case "%": - if isInteger(l) && isInteger(r) { + if l.IsInteger() && r.IsInteger() { return integerNature } - if or(l, r, isInteger) { + if l.MaybeCompatible(r, IntegerCheck) { return integerNature } case "+": - if isNumber(l) && isNumber(r) { - return combined(l, r) + if l.IsNumber() && r.IsNumber() { + return l.PromoteNumericNature(r) } - if isString(l) && isString(r) { + if l.IsString() && r.IsString() { return stringNature } - if isTime(l) && isDuration(r) { + if l.IsTime() && r.IsDuration() { return timeNature } - if isDuration(l) && isTime(r) { + if l.IsDuration() && r.IsTime() { return timeNature } - if isDuration(l) && isDuration(r) { + if l.IsDuration() && r.IsDuration() { return durationNature } - if or(l, r, isNumber, isString, isTime, isDuration) { + if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { return unknown } case "in": - if (isString(l) || isUnknown(l)) && isStruct(r) { + if (l.IsString() || l.IsUnknown()) && r.IsStruct() { return boolNature } - if isMap(r) { + if r.IsMap() { rKey := r.Key() - if !isUnknown(l) && !l.AssignableTo(rKey) { + if !l.IsUnknown() && !l.AssignableTo(rKey) { return v.error(node, "cannot use %s as type %s in map key", l.String(), rKey.String()) } return boolNature } - if isArray(r) { + if r.IsArray() { rElem := r.Elem() - if !isComparable(l, rElem) { + if !l.ComparableTo(rElem) { return v.error(node, "cannot use %s as type %s in array", l.String(), rElem.String()) } return boolNature } - if isUnknown(l) && anyOf(r, isString, isArray, isMap) { + if l.IsUnknown() && r.IsAnyOf(StringCheck, ArrayCheck, MapCheck) { return boolNature } - if isUnknown(r) { + if r.IsUnknown() { return boolNature } @@ -459,37 +475,37 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return v.error(node, err.Error()) } } - if isString(l) && isString(r) { + if l.IsString() && r.IsString() { return boolNature } - if or(l, r, isString) { + if l.MaybeCompatible(r, StringCheck) { return boolNature } case "contains", "startsWith", "endsWith": - if isString(l) && isString(r) { + if l.IsString() && r.IsString() { return boolNature } - if or(l, r, isString) { + if l.MaybeCompatible(r, StringCheck) { return boolNature } case "..": - if isInteger(l) && isInteger(r) { - return arrayOf(integerNature) + if l.IsInteger() && r.IsInteger() { + return integerNature.MakeArrayOf() } - if or(l, r, isInteger) { - return arrayOf(integerNature) + if l.MaybeCompatible(r, IntegerCheck) { + return integerNature.MakeArrayOf() } case "??": - if isNil(l) && !isNil(r) { + if l.Nil && !r.Nil { return r } - if !isNil(l) && isNil(r) { + if !l.Nil && r.Nil { return l } - if isNil(l) && isNil(r) { + if l.Nil && r.Nil { return nilNature } if r.AssignableTo(l) { @@ -529,12 +545,12 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { base := v.visit(node.Node) prop := v.visit(node.Property) - if isUnknown(base) { + if base.IsUnknown() { return unknown } if name, ok := node.Property.(*ast.StringNode); ok { - if isNil(base) { + if base.Nil { return v.error(node, "type nil has no field %s", name.Value) } @@ -549,7 +565,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { switch base.Kind() { case reflect.Map: - if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { + if !prop.AssignableTo(base.Key()) && !prop.IsUnknown() { return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String()) } if prop, ok := node.Property.(*ast.StringNode); ok { @@ -562,7 +578,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { return base.Elem() case reflect.Array, reflect.Slice: - if !isInteger(prop) && !isUnknown(prop) { + if !prop.IsInteger() && !prop.IsUnknown() { return v.error(node.Property, "array elements can only be selected using an integer (got %s)", prop.String()) } return base.Elem() @@ -594,7 +610,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { func (v *Checker) sliceNode(node *ast.SliceNode) Nature { nt := v.visit(node.Node) - if isUnknown(nt) { + if nt.IsUnknown() { return unknown } @@ -607,14 +623,14 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { if node.From != nil { from := v.visit(node.From) - if !isInteger(from) && !isUnknown(from) { + if !from.IsInteger() && !from.IsUnknown() { return v.error(node.From, "non-integer slice index %v", from.String()) } } if node.To != nil { to := v.visit(node.To) - if !isInteger(to) && !isUnknown(to) { + if !to.IsInteger() && !to.IsUnknown() { return v.error(node.To, "non-integer slice index %v", to.String()) } } @@ -660,11 +676,11 @@ func (v *Checker) functionReturnType(node *ast.CallNode) Nature { } } - if isUnknown(nt) { + if nt.IsUnknown() { return unknown } - if isNil(nt) { + if nt.Nil { return v.error(node, "%v is nil; cannot call nil as function", fnName) } @@ -687,7 +703,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "all", "none", "any", "one": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -695,13 +711,12 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - out := predicate.Out(0) - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) + if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) } return boolNature } @@ -710,7 +725,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "filter": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -718,25 +733,25 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - out := predicate.Out(0) - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) + if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) } - if isUnknown(collection) { + if collection.IsUnknown() { return arrayNature } - return arrayOf(collection.Elem()) + collection = collection.Elem() + return collection.MakeArrayOf() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "map": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -744,18 +759,18 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - return arrayOf(*predicate.PredicateOut) + return predicate.PredicateOut.MakeArrayOf() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "count": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -767,12 +782,11 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - out := predicate.Out(0) - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) } return integerNature @@ -782,7 +796,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "sum": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -791,13 +805,13 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return predicate.Out(0) + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + return *predicate.Out(0) } } else { - if isUnknown(collection) { + if collection.IsUnknown() { return unknown } return collection.Elem() @@ -806,7 +820,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "find", "findLast": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -814,15 +828,14 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - out := predicate.Out(0) - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) + if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) } - if isUnknown(collection) { + if collection.IsUnknown() { return unknown } return collection.Elem() @@ -832,7 +845,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "findIndex", "findLastIndex": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -840,13 +853,12 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - out := predicate.Out(0) - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", out.String()) + if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) } return integerNature } @@ -855,7 +867,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "groupBy": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -863,19 +875,20 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - groups := arrayOf(collection.Elem()) - return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &groups} + collection = collection.Elem() + collection = collection.MakeArrayOf() + return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &collection} } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sortBy": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -887,9 +900,9 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { _ = v.visit(node.Arguments[2]) } - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { return collection } @@ -898,7 +911,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { case "reduce": collection := v.visit(node.Arguments[0]) collection = collection.Deref() - if !isArray(collection) && !isUnknown(collection) { + if !collection.IsArray() && !collection.IsUnknown() { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } @@ -910,7 +923,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { _ = v.visit(node.Arguments[2]) } - if isFunc(predicate) && predicate.NumOut() == 1 { + if predicate.IsFunc() && predicate.NumOut() == 1 { return *predicate.PredicateOut } return v.error(node.Arguments[1], "predicate should has two input and one output param") @@ -956,18 +969,18 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { return unknown } - if isUnknown(base) { + if base.IsUnknown() { return unknown } switch base.Kind() { case reflect.Slice, reflect.Array: - if !isInteger(prop) && !isUnknown(prop) { + if !prop.IsInteger() && !prop.IsUnknown() { return v.error(node.Arguments[1], "non-integer slice index %s", prop.String()) } return base.Elem() case reflect.Map: - if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { + if !prop.AssignableTo(base.Key()) && !prop.IsUnknown() { return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop.String(), base.String()) } return base.Elem() @@ -980,7 +993,7 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] args := make([]reflect.Type, len(arguments)) for i, arg := range arguments { argNature := v.visit(arg) - if isUnknown(argNature) { + if argNature.IsUnknown() { args[i] = anyType } else { args[i] = argNature.Type @@ -1035,7 +1048,7 @@ func (v *Checker) checkArguments( arguments []ast.Node, node ast.Node, ) (Nature, *file.Error) { - if isUnknown(fn) { + if fn.IsUnknown() { return unknown, nil } @@ -1093,7 +1106,7 @@ func (v *Checker) checkArguments( for _, arg := range arguments { _ = v.visit(arg) } - return fn.Out(0), err + return *fn.Out(0), err } for i, arg := range arguments { @@ -1103,23 +1116,22 @@ func (v *Checker) checkArguments( if fn.IsVariadic() && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. - in = fn.In(fn.NumIn() - 1) - in = in.Elem() + in = fn.In(fn.NumIn() - 1).Elem() } else { - in = fn.In(i + fnInOffset) + in = *fn.In(i + fnInOffset) } - if isFloat(in) && isInteger(argNature) { + if in.IsFloat() && argNature.IsInteger() { traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in) continue } - if isInteger(in) && isInteger(argNature) && argNature.Kind() != in.Kind() { + if in.IsInteger() && argNature.IsInteger() && argNature.Kind() != in.Kind() { traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) continue } - if isNil(argNature) { + if argNature.Nil { if in.Kind() == reflect.Ptr || in.Kind() == reflect.Interface { continue } @@ -1142,7 +1154,7 @@ func (v *Checker) checkArguments( assignable = nt.AssignableTo(in) } - if !assignable && !isUnknown(argNature) { + if !assignable && !argNature.IsUnknown() { return unknown, &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature.String(), in.String(), name), @@ -1150,7 +1162,7 @@ func (v *Checker) checkArguments( } } - return fn.Out(0), nil + return *fn.Out(0), nil } func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newNature Nature) { @@ -1193,9 +1205,9 @@ func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Na func (v *Checker) predicateNode(node *ast.PredicateNode) Nature { nt := v.visit(node.Node) var out []reflect.Type - if isUnknown(nt) { + if nt.IsUnknown() { out = append(out, anyType) - } else if !isNil(nt) { + } else if !nt.Nil { out = append(out, nt.Type) } return Nature{ @@ -1210,7 +1222,7 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { } scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { - if isUnknown(scope.collection) { + if scope.collection.IsUnknown() { return unknown } switch scope.collection.Kind() { @@ -1271,20 +1283,20 @@ func (v *Checker) lookupVariable(name string) (varScope, bool) { func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { c := v.visit(node.Cond) - if !isBool(c) && !isUnknown(c) { + if !c.IsBool() && !c.IsUnknown() { return v.error(node.Cond, "non-bool expression (type %v) used as condition", c.String()) } t1 := v.visit(node.Exp1) t2 := v.visit(node.Exp2) - if isNil(t1) && !isNil(t2) { + if t1.Nil && !t2.Nil { return t2 } - if !isNil(t1) && isNil(t2) { + if !t1.Nil && t2.Nil { return t1 } - if isNil(t1) && isNil(t2) { + if t1.Nil && t2.Nil { return nilNature } if t1.AssignableTo(t2) { @@ -1306,7 +1318,7 @@ func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { prev = curr } if allElementsAreSameType { - return arrayOf(prev) + return prev.MakeArrayOf() } return arrayNature } diff --git a/checker/info.go b/checker/info.go index 41e21d6b..5684933b 100644 --- a/checker/info.go +++ b/checker/info.go @@ -121,7 +121,7 @@ func IsFastFunc(fn reflect.Type, method bool) bool { fn.NumOut() == 1 && fn.Out(0).Kind() == reflect.Interface { rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods - if kind(rest) == reflect.Slice && rest.Elem().Kind() == reflect.Interface { + if rest != nil && rest.Kind() == reflect.Slice && rest.Elem().Kind() == reflect.Interface { return true } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 91c2594c..ab2a1707 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -1,13 +1,37 @@ package nature import ( + "fmt" "reflect" + "time" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/internal/deref" ) -var unknown = Nature{} +var ( + arrayType = reflect.TypeOf([]any{}) + timeType = reflect.TypeOf(time.Time{}) + durationType = reflect.TypeOf(time.Duration(0)) + + unknown = Nature{} + floatNature = Nature{Type: reflect.TypeOf(float64(0))} + integerNature = Nature{Type: reflect.TypeOf(0)} +) + +type NatureCheck int + +const ( + _ NatureCheck = iota + BoolCheck + StringCheck + IntegerCheck + NumberCheck + MapCheck + ArrayCheck + TimeCheck + DurationCheck +) type Nature struct { Type reflect.Type // Type of the value. If nil, then value is unknown. @@ -39,13 +63,7 @@ func (n *Nature) IsAny() bool { } func (n *Nature) IsUnknown() bool { - switch { - case n.Type == nil && !n.Nil: - return true - case n.IsAny(): - return true - } - return false + return n.Type == nil && !n.Nil || n.IsAny() } func (n *Nature) String() string { @@ -124,15 +142,14 @@ func (n *Nature) MethodByName(c *Cache, name string) (Nature, bool) { func (n *Nature) methodByNamePtr(c *Cache, name string) *Nature { var ntPtr *Nature var cacheHit bool - key := rTypeWithKey{n.Type, name} if c.methodByName == nil { c.methodByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = c.methodByName[key] + ntPtr, cacheHit = c.methodByName[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.methodByNameSlow(name) - c.methodByName[key] = ntPtr + c.methodByName[rTypeWithKey{n.Type, name}] = ntPtr } return ntPtr } @@ -171,11 +188,11 @@ func (n *Nature) NumIn() int { return n.Type.NumIn() } -func (n *Nature) In(i int) Nature { +func (n *Nature) In(i int) *Nature { if n.Type == nil { - return unknown + return &unknown } - return Nature{Type: n.Type.In(i)} + return &Nature{Type: n.Type.In(i)} } func (n *Nature) NumOut() int { @@ -185,11 +202,11 @@ func (n *Nature) NumOut() int { return n.Type.NumOut() } -func (n *Nature) Out(i int) Nature { +func (n *Nature) Out(i int) *Nature { if n.Type == nil { - return unknown + return &unknown } - return Nature{Type: n.Type.Out(i)} + return &Nature{Type: n.Type.Out(i)} } func (n *Nature) IsVariadic() bool { @@ -202,15 +219,14 @@ func (n *Nature) IsVariadic() bool { func (n *Nature) FieldByName(c *Cache, name string) (Nature, bool) { var ntPtr *Nature var cacheHit bool - key := rTypeWithKey{n.Type, name} if c.fieldByName == nil { c.fieldByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = c.fieldByName[key] + ntPtr, cacheHit = c.fieldByName[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.fieldByNameSlow(name) - c.fieldByName[key] = ntPtr + c.fieldByName[rTypeWithKey{n.Type, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true @@ -250,15 +266,14 @@ func (n *Nature) IsFastMap() bool { func (n *Nature) Get(c *Cache, name string) (Nature, bool) { var ntPtr *Nature var cacheHit bool - key := rTypeWithKey{n.Type, name} if c.get == nil { c.get = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = c.get[key] + ntPtr, cacheHit = c.get[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.getSlow(c, name) - c.get[key] = ntPtr + c.get[rTypeWithKey{n.Type, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true @@ -331,3 +346,123 @@ func (n *Nature) All() map[string]Nature { return table } + +func (n *Nature) IsNumber() bool { + return n.IsInteger() || n.IsFloat() +} + +func (n *Nature) IsInteger() bool { + if n.PkgPath() == "" { + switch n.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + } + } + return false +} + +func (n *Nature) IsFloat() bool { + if n.PkgPath() == "" { + switch n.Kind() { + case reflect.Float32, reflect.Float64: + return true + } + } + return false +} + +func (n *Nature) PromoteNumericNature(rhs Nature) Nature { + if n.IsUnknown() || rhs.IsUnknown() { + return unknown + } + if n.IsFloat() || rhs.IsFloat() { + return floatNature + } + return integerNature +} + +func (n *Nature) IsTime() bool { + return n.Type == timeType +} + +func (n *Nature) IsDuration() bool { + return n.Type == durationType +} + +func (n *Nature) IsBool() bool { + return n.Kind() == reflect.Bool +} + +func (n *Nature) IsString() bool { + return n.Kind() == reflect.String +} + +func (n *Nature) IsArray() bool { + k := n.Kind() + return k == reflect.Slice || k == reflect.Array +} + +func (n *Nature) IsMap() bool { + return n.Kind() == reflect.Map +} + +func (n *Nature) IsStruct() bool { + return n.Kind() == reflect.Struct +} + +func (n *Nature) IsFunc() bool { + return n.Kind() == reflect.Func +} + +func (n *Nature) IsAnyOf(cs ...NatureCheck) bool { + var result bool + for i := 0; i < len(cs) && !result; i++ { + switch cs[i] { + case BoolCheck: + result = n.IsBool() + case StringCheck: + result = n.IsString() + case IntegerCheck: + result = n.IsInteger() + case NumberCheck: + result = n.IsNumber() + case MapCheck: + result = n.IsMap() + case ArrayCheck: + result = n.IsArray() + case TimeCheck: + result = n.IsTime() + case DurationCheck: + result = n.IsDuration() + default: + panic(fmt.Sprintf("unknown check value %d", cs[i])) + } + } + return result +} + +func (n *Nature) ComparableTo(rhs Nature) bool { + return n.IsUnknown() || rhs.IsUnknown() || + n.Nil || rhs.Nil || + n.IsNumber() && rhs.IsNumber() || + n.IsDuration() && rhs.IsDuration() || + n.IsTime() && rhs.IsTime() || + n.IsArray() && rhs.IsArray() || + n.AssignableTo(rhs) +} + +func (n *Nature) MaybeCompatible(rhs Nature, cs ...NatureCheck) bool { + nIsUnknown := n.IsUnknown() + rshIsUnknown := rhs.IsUnknown() + return nIsUnknown && rshIsUnknown || + nIsUnknown && rhs.IsAnyOf(cs...) || + rshIsUnknown && n.IsAnyOf(cs...) +} + +func (n *Nature) MakeArrayOf() Nature { + return Nature{ + Type: arrayType, + ArrayOf: n, + } +} diff --git a/checker/types.go b/checker/types.go deleted file mode 100644 index 09896de5..00000000 --- a/checker/types.go +++ /dev/null @@ -1,190 +0,0 @@ -package checker - -import ( - "reflect" - "time" - - . "github.com/expr-lang/expr/checker/nature" -) - -var ( - unknown = Nature{} - nilNature = Nature{Nil: true} - boolNature = Nature{Type: reflect.TypeOf(true)} - integerNature = Nature{Type: reflect.TypeOf(0)} - floatNature = Nature{Type: reflect.TypeOf(float64(0))} - stringNature = Nature{Type: reflect.TypeOf("")} - arrayNature = Nature{Type: reflect.TypeOf([]any{})} - mapNature = Nature{Type: reflect.TypeOf(map[string]any{})} - timeNature = Nature{Type: reflect.TypeOf(time.Time{})} - durationNature = Nature{Type: reflect.TypeOf(time.Duration(0))} -) - -var ( - anyType = reflect.TypeOf(new(any)).Elem() - timeType = reflect.TypeOf(time.Time{}) - durationType = reflect.TypeOf(time.Duration(0)) - arrayType = reflect.TypeOf([]any{}) -) - -func arrayOf(nt Nature) Nature { - return Nature{ - Type: arrayType, - ArrayOf: &nt, - } -} - -func isNil(nt Nature) bool { - return nt.Nil -} - -func combined(l, r Nature) Nature { - if isUnknown(l) || isUnknown(r) { - return unknown - } - if isFloat(l) || isFloat(r) { - return floatNature - } - return integerNature -} - -func anyOf(nt Nature, fns ...func(Nature) bool) bool { - for _, fn := range fns { - if fn(nt) { - return true - } - } - return false -} - -func or(l, r Nature, fns ...func(Nature) bool) bool { - if isUnknown(l) && isUnknown(r) { - return true - } - if isUnknown(l) && anyOf(r, fns...) { - return true - } - if isUnknown(r) && anyOf(l, fns...) { - return true - } - return false -} - -func isUnknown(nt Nature) bool { - return nt.IsUnknown() -} - -func isInteger(nt Nature) bool { - switch nt.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fallthrough - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return nt.PkgPath() == "" - } - return false -} - -func isFloat(nt Nature) bool { - switch nt.Kind() { - case reflect.Float32, reflect.Float64: - return nt.PkgPath() == "" - } - return false -} - -func isNumber(nt Nature) bool { - return isInteger(nt) || isFloat(nt) -} - -func isTime(nt Nature) bool { - switch nt.Type { - case timeType: - return true - } - return false -} - -func isDuration(nt Nature) bool { - switch nt.Type { - case durationType: - return true - } - return false -} - -func isBool(nt Nature) bool { - switch nt.Kind() { - case reflect.Bool: - return true - } - return false -} - -func isString(nt Nature) bool { - switch nt.Kind() { - case reflect.String: - return true - } - return false -} - -func isArray(nt Nature) bool { - switch nt.Kind() { - case reflect.Slice, reflect.Array: - return true - } - return false -} - -func isMap(nt Nature) bool { - switch nt.Kind() { - case reflect.Map: - return true - } - return false -} - -func isStruct(nt Nature) bool { - switch nt.Kind() { - case reflect.Struct: - return true - } - return false -} - -func isFunc(nt Nature) bool { - switch nt.Kind() { - case reflect.Func: - return true - } - return false -} - -func kind(t reflect.Type) reflect.Kind { - if t == nil { - return reflect.Invalid - } - return t.Kind() -} - -func isComparable(l, r Nature) bool { - if isUnknown(l) || isUnknown(r) { - return true - } - if isNil(l) || isNil(r) { - return true - } - if isNumber(l) && isNumber(r) { - return true - } - if isDuration(l) && isDuration(r) { - return true - } - if isTime(l) && isTime(r) { - return true - } - if isArray(l) && isArray(r) { - return true - } - return l.AssignableTo(r) -} From 879544e42e04f3aeeff70254639d4dd611925529 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 07:56:43 -0300 Subject: [PATCH 08/24] more aggressively cache some partial results and add cheap checks --- checker/checker.go | 14 ++++++++------ checker/nature/nature.go | 11 ++++++++--- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 8f65727a..3291280b 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -1052,13 +1052,14 @@ func (v *Checker) checkArguments( return unknown, nil } - if fn.NumOut() == 0 { + numOut := fn.NumOut() + if numOut == 0 { return unknown, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } - if numOut := fn.NumOut(); numOut > 2 { + if numOut > 2 { return unknown, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), @@ -1078,7 +1079,8 @@ func (v *Checker) checkArguments( } var err *file.Error - if fn.IsVariadic() { + isVariadic := fn.IsVariadic() + if isVariadic { if len(arguments) < fnNumIn-1 { err = &file.Error{ Location: node.Location(), @@ -1113,10 +1115,10 @@ func (v *Checker) checkArguments( argNature := v.visit(arg) var in Nature - if fn.IsVariadic() && i >= fnNumIn-1 { + if isVariadic && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. - in = fn.In(fn.NumIn() - 1).Elem() + in = fn.In(fnNumIn - 1).Elem() } else { in = *fn.In(i + fnInOffset) } @@ -1149,7 +1151,7 @@ func (v *Checker) checkArguments( // We also need to check if dereference arg type is assignable to the function input type. // For example, func(int) and argument *int. In this case we will add OpDeref to the argument, // so we can call the function with *int argument. - if !assignable { + if !assignable && argNature.IsPointer() { nt := argNature.Deref() assignable = nt.AssignableTo(in) } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index ab2a1707..29013573 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -74,10 +74,11 @@ func (n *Nature) String() string { } func (n *Nature) Deref() Nature { - if n.Type != nil { - n.Type = deref.Type(n.Type) + ret := *n + if ret.Type != nil { + ret.Type = deref.Type(ret.Type) } - return *n + return ret } func (n *Nature) Kind() reflect.Kind { @@ -415,6 +416,10 @@ func (n *Nature) IsFunc() bool { return n.Kind() == reflect.Func } +func (n *Nature) IsPointer() bool { + return n.Kind() == reflect.Ptr +} + func (n *Nature) IsAnyOf(cs ...NatureCheck) bool { var result bool for i := 0; i < len(cs) && !result; i++ { From 8f6e2667acbf352e3d6f68b48d6b2a13f7ba03d5 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 08:21:12 -0300 Subject: [PATCH 09/24] cache reused values --- checker/checker.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 3291280b..52d1281f 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -639,8 +639,6 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { } func (v *Checker) callNode(node *ast.CallNode) Nature { - nt := v.functionReturnType(node) - // Check if type was set on node (for example, by patcher) // and use node type instead of function return type. // @@ -652,15 +650,18 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { // fix `errCall()` to return proper type, so on second // checker pass we should replace anyType on method node // with new correct function return type. - if node.Type() != nil && node.Type() != anyType { + if typ := node.Type(); typ != nil && typ != anyType { return node.Nature() } - return nt + return v.functionReturnType(node) } func (v *Checker) functionReturnType(node *ast.CallNode) Nature { nt := v.visit(node.Callee) + if nt.IsUnknown() { + return unknown + } if nt.Func != nil { return v.checkFunction(nt.Func, node, node.Arguments) @@ -676,16 +677,11 @@ func (v *Checker) functionReturnType(node *ast.CallNode) Nature { } } - if nt.IsUnknown() { - return unknown - } - if nt.Nil { return v.error(node, "%v is nil; cannot call nil as function", fnName) } - switch nt.Kind() { - case reflect.Func: + if nt.Kind() == reflect.Func { outType, err := v.checkArguments(fnName, nt, node.Arguments, node) if err != nil { if v.err == nil { From 1eb8ab5d62b3b028a8fdd13048670ca7f49509dd Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 10:09:43 -0300 Subject: [PATCH 10/24] reduce copying in Nature by breaking down optional fields --- checker/checker.go | 7 +-- checker/info.go | 8 ++-- checker/nature/nature.go | 98 ++++++++++++++++++++++++---------------- checker/nature/utils.go | 9 +++- conf/env.go | 6 ++- types/types.go | 15 ++++-- 6 files changed, 89 insertions(+), 54 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 52d1281f..d9d5eace 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -568,7 +568,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { if !prop.AssignableTo(base.Key()) && !prop.IsUnknown() { return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String()) } - if prop, ok := node.Property.(*ast.StringNode); ok { + if prop, ok := node.Property.(*ast.StringNode); ok && base.MapData != nil { if field, ok := base.Fields[prop.Value]; ok { return field } else if base.Strict { @@ -1124,13 +1124,14 @@ func (v *Checker) checkArguments( continue } - if in.IsInteger() && argNature.IsInteger() && argNature.Kind() != in.Kind() { + inKind := in.Kind() + if in.IsInteger() && argNature.IsInteger() && argNature.Kind() != inKind { traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) continue } if argNature.Nil { - if in.Kind() == reflect.Ptr || in.Kind() == reflect.Interface { + if inKind == reflect.Ptr || inKind == reflect.Interface { continue } return unknown, &file.Error{ diff --git a/checker/info.go b/checker/info.go index 5684933b..4e47df8a 100644 --- a/checker/info.go +++ b/checker/info.go @@ -12,7 +12,8 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: if env.Kind() == reflect.Struct { - if field, ok := env.Get(c, n.Value); ok && len(field.FieldIndex) > 0 { + field, ok := env.Get(c, n.Value) + if ok && field.StructData != nil && len(field.FieldIndex) > 0 { return true, field.FieldIndex, n.Value } } @@ -22,7 +23,8 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { if base.Kind() == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { name := prop.Value - if field, ok := base.FieldByName(c, name); ok { + field, ok := base.FieldByName(c, name) + if ok && field.StructData != nil { return true, field.FieldIndex, name } } @@ -35,7 +37,7 @@ func MethodIndex(c *Cache, env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: if env.Kind() == reflect.Struct { - if m, ok := env.Get(c, n.Value); ok { + if m, ok := env.Get(c, n.Value); ok && m.StructData != nil { return m.Method, m.MethodIndex, n.Value } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 29013573..5cb68d39 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -34,17 +34,25 @@ const ( ) type Nature struct { - Type reflect.Type // Type of the value. If nil, then value is unknown. - Func *builtin.Function // Used to pass function type from callee to CallNode. - ArrayOf *Nature // Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). - PredicateOut *Nature // Out nature of predicate. + Type reflect.Type // Type of the value. If nil, then value is unknown. + ArrayOf *Nature // Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). + *MapData + *StructData + Func *builtin.Function // Used to pass function type from callee to CallNode. + PredicateOut *Nature // Out nature of predicate. + Nil bool // If value is nil. + Strict bool // If map is types.StrictMap. + Method bool // If value retrieved from method. Usually used to determine amount of in arguments. +} + +type StructData struct { + FieldIndex []int // Index of field in type. + MethodIndex int // Index of method in type. +} + +type MapData struct { Fields map[string]Nature // Fields of map type. DefaultMapValue *Nature // Default value of map type. - Strict bool // If map is types.StrictMap. - Nil bool // If value is nil. - Method bool // If value retrieved from method. Usually used to determine amount of in arguments. - MethodIndex int // Index of method in type. - FieldIndex []int // Index of field in type. } type Cache struct { @@ -59,7 +67,7 @@ type rTypeWithKey struct { } func (n *Nature) IsAny() bool { - return n.Kind() == reflect.Interface && n.NumMethods() == 0 + return n.Type != nil && n.Type.Kind() == reflect.Interface && n.NumMethods() == 0 } func (n *Nature) IsUnknown() bool { @@ -100,7 +108,7 @@ func (n *Nature) Elem() Nature { case reflect.Ptr: return Nature{Type: n.Type.Elem()} case reflect.Map: - if n.DefaultMapValue != nil { + if n.MapData != nil && n.DefaultMapValue != nil { return *n.DefaultMapValue } return Nature{Type: n.Type.Elem()} @@ -141,6 +149,7 @@ func (n *Nature) MethodByName(c *Cache, name string) (Nature, bool) { } func (n *Nature) methodByNamePtr(c *Cache, name string) *Nature { + return n.methodByNameSlow(name) var ntPtr *Nature var cacheHit bool if c.methodByName == nil { @@ -173,12 +182,13 @@ func (n *Nature) methodByNameSlow(name string) *Nature { // different indexes for different types which implement // the same interface. return &Nature{Type: method.Type} - } else { - return &Nature{ - Type: method.Type, - Method: true, + } + return &Nature{ + Type: method.Type, + Method: true, + StructData: &StructData{ MethodIndex: method.Index, - } + }, } } @@ -240,7 +250,12 @@ func (n *Nature) fieldByNameSlow(name string) *Nature { return nil } if field, ok := fetchField(n.Type, name); ok { - return &Nature{Type: field.Type, FieldIndex: field.Index} + return &Nature{ + Type: field.Type, + StructData: &StructData{ + FieldIndex: field.Index, + }, + } } return nil } @@ -292,18 +307,21 @@ func (n *Nature) getSlow(c *Cache, name string) *Nature { } t := deref.Type(n.Type) - switch t.Kind() { case reflect.Struct: if f, ok := fetchField(t, name); ok { return &Nature{ - Type: f.Type, - FieldIndex: f.Index, + Type: f.Type, + StructData: &StructData{ + FieldIndex: f.Index, + }, } } case reflect.Map: - if f, ok := n.Fields[name]; ok { - return &f + if n.MapData != nil { + if f, ok := n.Fields[name]; ok { + return &f + } } } return nil @@ -319,9 +337,11 @@ func (n *Nature) All() map[string]Nature { for i := 0; i < n.Type.NumMethod(); i++ { method := n.Type.Method(i) table[method.Name] = Nature{ - Type: method.Type, - Method: true, - MethodIndex: method.Index, + Type: method.Type, + Method: true, + StructData: &StructData{ + MethodIndex: method.Index, + }, } } @@ -337,11 +357,13 @@ func (n *Nature) All() map[string]Nature { } case reflect.Map: - for key, nt := range n.Fields { - if _, ok := table[key]; ok { - continue + if n.MapData != nil { + for key, nt := range n.Fields { + if _, ok := table[key]; ok { + continue + } + table[key] = nt } - table[key] = nt } } @@ -353,22 +375,18 @@ func (n *Nature) IsNumber() bool { } func (n *Nature) IsInteger() bool { - if n.PkgPath() == "" { - switch n.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true - } + switch n.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return n.PkgPath() == "" } return false } func (n *Nature) IsFloat() bool { - if n.PkgPath() == "" { - switch n.Kind() { - case reflect.Float32, reflect.Float64: - return true - } + switch n.Kind() { + case reflect.Float32, reflect.Float64: + return n.PkgPath() == "" } return false } diff --git a/checker/nature/utils.go b/checker/nature/utils.go index c1551546..a8296407 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -68,6 +68,9 @@ func StructFields(t reflect.Type) map[string]Nature { if _, ok := table[name]; ok { continue } + if typ.StructData == nil { + typ.StructData = new(StructData) + } typ.FieldIndex = append(f.Index, typ.FieldIndex...) table[name] = typ } @@ -78,8 +81,10 @@ func StructFields(t reflect.Type) map[string]Nature { continue } table[name] = Nature{ - Type: f.Type, - FieldIndex: f.Index, + Type: f.Type, + StructData: &StructData{ + FieldIndex: f.Index, + }, } } diff --git a/conf/env.go b/conf/env.go index 8b13df1e..063fca81 100644 --- a/conf/env.go +++ b/conf/env.go @@ -34,8 +34,10 @@ func Env(env any) Nature { case reflect.Map: n := Nature{ - Type: v.Type(), - Fields: make(map[string]Nature, v.Len()), + Type: v.Type(), + MapData: &MapData{ + Fields: make(map[string]Nature, v.Len()), + }, Strict: true, } diff --git a/types/types.go b/types/types.go index bb1cbe5f..72f18da2 100644 --- a/types/types.go +++ b/types/types.go @@ -101,12 +101,17 @@ const Extra = "[[__extra_keys__]]" func (m Map) Nature() Nature { nt := Nature{ - Type: reflect.TypeOf(map[string]any{}), - Fields: make(map[string]Nature, len(m)), + Type: reflect.TypeOf(map[string]any{}), + MapData: &MapData{ + Fields: make(map[string]Nature, len(m)), + }, Strict: true, } for k, v := range m { if k == Extra { + if nt.MapData == nil { + nt.MapData = new(MapData) + } nt.Strict = false natureOfDefaultValue := v.Nature() nt.DefaultMapValue = &natureOfDefaultValue @@ -156,8 +161,10 @@ type array struct { func (a array) Nature() Nature { of := a.of.Nature() return Nature{ - Type: reflect.TypeOf([]any{}), - Fields: make(map[string]Nature, 1), + Type: reflect.TypeOf([]any{}), + MapData: &MapData{ + Fields: make(map[string]Nature, 1), + }, ArrayOf: &of, } } From 06a425beca3632058a0f62da25ed370cd21ee7fa Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 21:55:17 -0300 Subject: [PATCH 11/24] enforce use of cache --- ast/node.go | 2 +- checker/checker.go | 307 ++++++++++++++++++----------------- checker/info.go | 17 +- checker/nature/nature.go | 292 ++++++++++++++++++++------------- checker/nature/utils.go | 18 +- compiler/compiler.go | 3 +- conf/config.go | 5 +- conf/env.go | 43 ++--- docgen/docgen.go | 5 +- patcher/operator_override.go | 4 +- types/types.go | 33 ++-- 11 files changed, 405 insertions(+), 324 deletions(-) diff --git a/ast/node.go b/ast/node.go index 02923ac5..8e44b7f0 100644 --- a/ast/node.go +++ b/ast/node.go @@ -66,7 +66,7 @@ func (n *base) Type() reflect.Type { // SetType sets the type of the node. func (n *base) SetType(t reflect.Type) { - n.nature.Type = t + n.nature = nature.FromType(nil, t) } // NilNode represents nil. diff --git a/checker/checker.go b/checker/checker.go index d9d5eace..ae2629ba 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -15,18 +15,17 @@ import ( ) var ( - anyType = reflect.TypeOf(new(any)).Elem() - - unknown = Nature{} - nilNature = Nature{Nil: true} - boolNature = Nature{Type: reflect.TypeOf(true)} - integerNature = Nature{Type: reflect.TypeOf(0)} - floatNature = Nature{Type: reflect.TypeOf(float64(0))} - stringNature = Nature{Type: reflect.TypeOf("")} - arrayNature = Nature{Type: reflect.TypeOf([]any{})} - mapNature = Nature{Type: reflect.TypeOf(map[string]any{})} - timeNature = Nature{Type: reflect.TypeOf(time.Time{})} - durationNature = Nature{Type: reflect.TypeOf(time.Duration(0))} + anyType = reflect.TypeOf(new(any)).Elem() + boolType = reflect.TypeOf(true) + intType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) + stringType = reflect.TypeOf("") + arrayType = reflect.TypeOf([]any{}) + mapType = reflect.TypeOf(map[string]any{}) + timeType = reflect.TypeOf(time.Time{}) + durationType = reflect.TypeOf(time.Duration(0)) + + anyTypeSlice = []reflect.Type{anyType} ) // ParseCheck parses input expression and checks its types. Also, it applies @@ -148,7 +147,7 @@ func (c *Checker) check(tree *parser.Tree) (reflect.Type, error) { return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } default: - if nt.Kind() != c.config.Expect { + if nt.Kind != c.config.Expect { return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } } @@ -241,11 +240,11 @@ func (v *Checker) error(node ast.Node, format string, args ...any) Nature { Message: fmt.Sprintf(format, args...), } } - return unknown + return FromType(&v.config.NtCache, nil) } func (v *Checker) nilNode(*ast.NilNode) Nature { - return nilNature + return NatureOf(&v.config.NtCache, nil) } func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { @@ -253,7 +252,7 @@ func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { return variable.nature } if node.Value == "$env" { - return unknown + return FromType(&v.config.NtCache, nil) } return v.ident(node, node.Value, v.config.Strict, true) @@ -261,41 +260,45 @@ func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { // ident method returns type of environment variable, builtin or function. func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Nature { - if nt, ok := v.config.Env.Get(&v.config.NtCache, name); ok { + if nt, ok := v.config.Env.Get(name); ok { return nt } if builtins { if fn, ok := v.config.Functions[name]; ok { - return Nature{Type: fn.Type(), Func: fn} + nt := FromType(&v.config.NtCache, fn.Type()) + nt.Func = fn + return nt } if fn, ok := v.config.Builtins[name]; ok { - return Nature{Type: fn.Type(), Func: fn} + nt := FromType(&v.config.NtCache, fn.Type()) + nt.Func = fn + return nt } } if v.config.Strict && strict { return v.error(node, "unknown name %s", name) } - return unknown + return FromType(&v.config.NtCache, nil) } func (v *Checker) integerNode(*ast.IntegerNode) Nature { - return integerNature + return FromType(&v.config.NtCache, intType) } func (v *Checker) floatNode(*ast.FloatNode) Nature { - return floatNature + return FromType(&v.config.NtCache, floatType) } func (v *Checker) boolNode(*ast.BoolNode) Nature { - return boolNature + return FromType(&v.config.NtCache, boolType) } func (v *Checker) stringNode(*ast.StringNode) Nature { - return stringNature + return FromType(&v.config.NtCache, stringType) } func (v *Checker) constantNode(node *ast.ConstantNode) Nature { - return Nature{Type: reflect.TypeOf(node.Value)} + return FromType(&v.config.NtCache, reflect.TypeOf(node.Value)) } func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { @@ -306,10 +309,10 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { case "!", "not": if nt.IsBool() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if nt.IsUnknown() { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "+", "-": @@ -317,7 +320,7 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { return nt } if nt.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } default: @@ -337,32 +340,32 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { switch node.Operator { case "==", "!=": if l.ComparableTo(r) { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "or", "||", "and", "&&": if l.IsBool() && r.IsBool() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.MaybeCompatible(r, BoolCheck) { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "<", ">", ">=", "<=": if l.IsNumber() && r.IsNumber() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.IsString() && r.IsString() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.IsTime() && r.IsTime() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.IsDuration() && r.IsDuration() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "-": @@ -370,16 +373,16 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l.PromoteNumericNature(r) } if l.IsTime() && r.IsTime() { - return durationNature + return FromType(&v.config.NtCache, durationType) } if l.IsTime() && r.IsDuration() { - return timeNature + return FromType(&v.config.NtCache, timeType) } if l.IsDuration() && r.IsDuration() { - return durationNature + return FromType(&v.config.NtCache, durationType) } if l.MaybeCompatible(r, NumberCheck, TimeCheck, DurationCheck) { - return unknown + return FromType(&v.config.NtCache, nil) } case "*": @@ -387,40 +390,40 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l.PromoteNumericNature(r) } if l.IsNumber() && r.IsDuration() { - return durationNature + return FromType(&v.config.NtCache, durationType) } if l.IsDuration() && r.IsNumber() { - return durationNature + return FromType(&v.config.NtCache, durationType) } if l.IsDuration() && r.IsDuration() { - return durationNature + return FromType(&v.config.NtCache, durationType) } if l.MaybeCompatible(r, NumberCheck, DurationCheck) { - return unknown + return FromType(&v.config.NtCache, nil) } case "/": if l.IsNumber() && r.IsNumber() { - return floatNature + return FromType(&v.config.NtCache, floatType) } if l.MaybeCompatible(r, NumberCheck) { - return floatNature + return FromType(&v.config.NtCache, floatType) } case "**", "^": if l.IsNumber() && r.IsNumber() { - return floatNature + return FromType(&v.config.NtCache, floatType) } if l.MaybeCompatible(r, NumberCheck) { - return floatNature + return FromType(&v.config.NtCache, floatType) } case "%": if l.IsInteger() && r.IsInteger() { - return integerNature + return FromType(&v.config.NtCache, intType) } if l.MaybeCompatible(r, IntegerCheck) { - return integerNature + return FromType(&v.config.NtCache, intType) } case "+": @@ -428,44 +431,44 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l.PromoteNumericNature(r) } if l.IsString() && r.IsString() { - return stringNature + return FromType(&v.config.NtCache, stringType) } if l.IsTime() && r.IsDuration() { - return timeNature + return FromType(&v.config.NtCache, timeType) } if l.IsDuration() && r.IsTime() { - return timeNature + return FromType(&v.config.NtCache, timeType) } if l.IsDuration() && r.IsDuration() { - return durationNature + return FromType(&v.config.NtCache, durationType) } if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { - return unknown + return FromType(&v.config.NtCache, nil) } case "in": if (l.IsString() || l.IsUnknown()) && r.IsStruct() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if r.IsMap() { rKey := r.Key() if !l.IsUnknown() && !l.AssignableTo(rKey) { return v.error(node, "cannot use %s as type %s in map key", l.String(), rKey.String()) } - return boolNature + return FromType(&v.config.NtCache, boolType) } if r.IsArray() { rElem := r.Elem() if !l.ComparableTo(rElem) { return v.error(node, "cannot use %s as type %s in array", l.String(), rElem.String()) } - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.IsUnknown() && r.IsAnyOf(StringCheck, ArrayCheck, MapCheck) { - return boolNature + return FromType(&v.config.NtCache, boolType) } if r.IsUnknown() { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "matches": @@ -476,26 +479,23 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { } } if l.IsString() && r.IsString() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.MaybeCompatible(r, StringCheck) { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "contains", "startsWith", "endsWith": if l.IsString() && r.IsString() { - return boolNature + return FromType(&v.config.NtCache, boolType) } if l.MaybeCompatible(r, StringCheck) { - return boolNature + return FromType(&v.config.NtCache, boolType) } case "..": - if l.IsInteger() && r.IsInteger() { - return integerNature.MakeArrayOf() - } - if l.MaybeCompatible(r, IntegerCheck) { - return integerNature.MakeArrayOf() + if l.IsInteger() && r.IsInteger() || l.MaybeCompatible(r, IntegerCheck) { + return ArrayFromType(&v.config.NtCache, intType) } case "??": @@ -506,12 +506,12 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l } if l.Nil && r.Nil { - return nilNature + return NatureOf(&v.config.NtCache, nil) } if r.AssignableTo(l) { return l } - return unknown + return FromType(&v.config.NtCache, nil) default: return v.error(node, "unknown operator (%s)", node.Operator) @@ -539,14 +539,14 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { } return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } - return unknown + return FromType(&v.config.NtCache, nil) } base := v.visit(node.Node) prop := v.visit(node.Property) if base.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } if name, ok := node.Property.(*ast.StringNode); ok { @@ -556,19 +556,19 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { // First, check methods defined on base type itself, // independent of which type it is. Without dereferencing. - if m, ok := base.MethodByName(&v.config.NtCache, name.Value); ok { + if m, ok := base.MethodByName(name.Value); ok { return m } } base = base.Deref() - switch base.Kind() { + switch base.Kind { case reflect.Map: if !prop.AssignableTo(base.Key()) && !prop.IsUnknown() { return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String()) } - if prop, ok := node.Property.(*ast.StringNode); ok && base.MapData != nil { + if prop, ok := node.Property.(*ast.StringNode); ok && base.Optional != nil { if field, ok := base.Fields[prop.Value]; ok { return field } else if base.Strict { @@ -586,8 +586,8 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value - if field, ok := base.FieldByName(&v.config.NtCache, propertyName); ok { - return Nature{Type: field.Type} + if field, ok := base.FieldByName(propertyName); ok { + return FromType(&v.config.NtCache, field.Type) } if node.Method { return v.error(node, "type %v has no method %v", base.String(), propertyName) @@ -611,10 +611,10 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { nt := v.visit(node.Node) if nt.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } - switch nt.Kind() { + switch nt.Kind { case reflect.String, reflect.Array, reflect.Slice: // ok default: @@ -651,7 +651,9 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { // checker pass we should replace anyType on method node // with new correct function return type. if typ := node.Type(); typ != nil && typ != anyType { - return node.Nature() + nt := node.Nature() + nt.Cache = &v.config.NtCache // AST doesn't cache nature info + return nt } return v.functionReturnType(node) @@ -660,7 +662,7 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { func (v *Checker) functionReturnType(node *ast.CallNode) Nature { nt := v.visit(node.Callee) if nt.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } if nt.Func != nil { @@ -681,13 +683,13 @@ func (v *Checker) functionReturnType(node *ast.CallNode) Nature { return v.error(node, "%v is nil; cannot call nil as function", fnName) } - if nt.Kind() == reflect.Func { + if nt.Kind == reflect.Func { outType, err := v.checkArguments(fnName, nt, node.Arguments, node) if err != nil { if v.err == nil { v.err = err } - return unknown + return FromType(&v.config.NtCache, nil) } return outType } @@ -709,12 +711,13 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { - return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return boolNature + return FromType(&v.config.NtCache, boolType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -731,13 +734,14 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { - return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } if collection.IsUnknown() { - return arrayNature + return FromType(&v.config.NtCache, arrayType) } collection = collection.Elem() return collection.MakeArrayOf() @@ -751,15 +755,15 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, varScope{"index", integerNature}) + v.begin(collection, varScope{"index", FromType(&v.config.NtCache, intType)}) predicate := v.visit(node.Arguments[1]) v.end() if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - return predicate.PredicateOut.MakeArrayOf() + return predicate.Ref.MakeArrayOf() } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -771,7 +775,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } if len(node.Arguments) == 1 { - return integerNature + return FromType(&v.config.NtCache, intType) } v.begin(collection) @@ -780,12 +784,13 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { - return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return integerNature + return FromType(&v.config.NtCache, intType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -803,12 +808,12 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { - return *predicate.Out(0) + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { + return predicate.Out(0) } } else { if collection.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } return collection.Elem() } @@ -826,13 +831,14 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { - return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } if collection.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } return collection.Elem() } @@ -851,12 +857,13 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !predicate.Out(0).IsBool() && !predicate.Out(0).IsUnknown() { - return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicate.Out(0)) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return integerNature + return FromType(&v.config.NtCache, intType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -873,11 +880,13 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { collection = collection.Elem() collection = collection.MakeArrayOf() - return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &collection} + nt := NatureOf(&v.config.NtCache, map[any][]any{}) + nt.Ref = &collection + return nt } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -898,7 +907,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && predicate.In(0).IsUnknown() { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { return collection } @@ -911,7 +920,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, varScope{"index", integerNature}, varScope{"acc", unknown}) + v.begin(collection, varScope{"index", FromType(&v.config.NtCache, intType)}, varScope{"acc", FromType(&v.config.NtCache, nil)}) predicate := v.visit(node.Arguments[1]) v.end() @@ -920,7 +929,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } if predicate.IsFunc() && predicate.NumOut() == 1 { - return *predicate.PredicateOut + return *predicate.Ref } return v.error(node.Arguments[1], "predicate should has two input and one output param") @@ -958,18 +967,18 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" { if s, ok := node.Arguments[1].(*ast.StringNode); ok { - if nt, ok := v.config.Env.Get(&v.config.NtCache, s.Value); ok { + if nt, ok := v.config.Env.Get(s.Value); ok { return nt } } - return unknown + return FromType(&v.config.NtCache, nil) } if base.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } - switch base.Kind() { + switch base.Kind { case reflect.Slice, reflect.Array: if !prop.IsInteger() && !prop.IsUnknown() { return v.error(node.Arguments[1], "non-integer slice index %s", prop.String()) @@ -999,21 +1008,21 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if err != nil { return v.error(node, "%v", err) } - return Nature{Type: t} + return FromType(&v.config.NtCache, t) } else if len(f.Types) == 0 { - nt, err := v.checkArguments(f.Name, Nature{Type: f.Type()}, arguments, node) + nt, err := v.checkArguments(f.Name, FromType(&v.config.NtCache, f.Type()), arguments, node) if err != nil { if v.err == nil { v.err = err } - return unknown + return FromType(&v.config.NtCache, nil) } // No type was specified, so we assume the function returns any. return nt } var lastErr *file.Error for _, t := range f.Types { - outNature, err := v.checkArguments(f.Name, Nature{Type: t}, arguments, node) + outNature, err := v.checkArguments(f.Name, FromType(&v.config.NtCache, t), arguments, node) if err != nil { lastErr = err continue @@ -1032,7 +1041,7 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if v.err == nil { v.err = lastErr } - return unknown + return FromType(&v.config.NtCache, nil) } return v.error(node, "no matching overload for %v", f.Name) @@ -1045,18 +1054,18 @@ func (v *Checker) checkArguments( node ast.Node, ) (Nature, *file.Error) { if fn.IsUnknown() { - return unknown, nil + return FromType(&v.config.NtCache, nil), nil } numOut := fn.NumOut() if numOut == 0 { - return unknown, &file.Error{ + return FromType(&v.config.NtCache, nil), &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } if numOut > 2 { - return unknown, &file.Error{ + return FromType(&v.config.NtCache, nil), &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), } @@ -1104,7 +1113,7 @@ func (v *Checker) checkArguments( for _, arg := range arguments { _ = v.visit(arg) } - return *fn.Out(0), err + return fn.Out(0), err } for i, arg := range arguments { @@ -1114,9 +1123,9 @@ func (v *Checker) checkArguments( if isVariadic && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. - in = fn.In(fnNumIn - 1).Elem() + in = fn.InElem(fnNumIn - 1) } else { - in = *fn.In(i + fnInOffset) + in = fn.In(i + fnInOffset) } if in.IsFloat() && argNature.IsInteger() { @@ -1124,17 +1133,16 @@ func (v *Checker) checkArguments( continue } - inKind := in.Kind() - if in.IsInteger() && argNature.IsInteger() && argNature.Kind() != inKind { + if in.IsInteger() && argNature.IsInteger() && argNature.Kind != in.Kind { traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) continue } if argNature.Nil { - if inKind == reflect.Ptr || inKind == reflect.Interface { + if in.Kind == reflect.Ptr || in.Kind == reflect.Interface { continue } - return unknown, &file.Error{ + return FromType(&v.config.NtCache, nil), &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in.String(), name), } @@ -1154,14 +1162,14 @@ func (v *Checker) checkArguments( } if !assignable && !argNature.IsUnknown() { - return unknown, &file.Error{ + return FromType(&v.config.NtCache, nil), &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature.String(), in.String(), name), } } } - return *fn.Out(0), nil + return fn.Out(0), nil } func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newNature Nature) { @@ -1209,10 +1217,9 @@ func (v *Checker) predicateNode(node *ast.PredicateNode) Nature { } else if !nt.Nil { out = append(out, nt.Type) } - return Nature{ - Type: reflect.FuncOf([]reflect.Type{anyType}, out, false), - PredicateOut: &nt, - } + n := FromType(&v.config.NtCache, reflect.FuncOf(anyTypeSlice, out, false)) + n.Ref = &nt + return n } func (v *Checker) pointerNode(node *ast.PointerNode) Nature { @@ -1222,9 +1229,9 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { if scope.collection.IsUnknown() { - return unknown + return FromType(&v.config.NtCache, nil) } - switch scope.collection.Kind() { + switch scope.collection.Kind { case reflect.Array, reflect.Slice: return scope.collection.Elem() } @@ -1241,7 +1248,7 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { } func (v *Checker) variableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { - if _, ok := v.config.Env.Get(&v.config.NtCache, node.Name); ok { + if _, ok := v.config.Env.Get(node.Name); ok { return v.error(node, "cannot redeclare %v", node.Name) } if _, ok := v.config.Functions[node.Name]; ok { @@ -1296,12 +1303,12 @@ func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { return t1 } if t1.Nil && t2.Nil { - return nilNature + return NatureOf(&v.config.NtCache, nil) } if t1.AssignableTo(t2) { return t1 } - return unknown + return FromType(&v.config.NtCache, nil) } func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { @@ -1310,7 +1317,7 @@ func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { for i, node := range node.Nodes { curr := v.visit(node) if i > 0 { - if curr.Kind() != prev.Kind() { + if curr.Kind != prev.Kind { allElementsAreSameType = false } } @@ -1319,18 +1326,18 @@ func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { if allElementsAreSameType { return prev.MakeArrayOf() } - return arrayNature + return FromType(&v.config.NtCache, arrayType) } func (v *Checker) mapNode(node *ast.MapNode) Nature { for _, pair := range node.Pairs { v.visit(pair) } - return mapNature + return FromType(&v.config.NtCache, mapType) } func (v *Checker) pairNode(node *ast.PairNode) Nature { v.visit(node.Key) v.visit(node.Value) - return nilNature + return NatureOf(&v.config.NtCache, nil) } diff --git a/checker/info.go b/checker/info.go index 4e47df8a..8bd38b3a 100644 --- a/checker/info.go +++ b/checker/info.go @@ -11,20 +11,21 @@ import ( func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if env.Kind() == reflect.Struct { - field, ok := env.Get(c, n.Value) - if ok && field.StructData != nil && len(field.FieldIndex) > 0 { + if env.Kind == reflect.Struct { + field, ok := env.Get(n.Value) + if ok && field.Optional != nil && len(field.FieldIndex) > 0 { return true, field.FieldIndex, n.Value } } case *ast.MemberNode: base := n.Node.Nature() + base.Cache = c // AST doesn't cache nature info base = base.Deref() - if base.Kind() == reflect.Struct { + if base.Kind == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { name := prop.Value - field, ok := base.FieldByName(c, name) - if ok && field.StructData != nil { + field, ok := base.FieldByName(name) + if ok && field.Optional != nil { return true, field.FieldIndex, name } } @@ -36,8 +37,8 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { func MethodIndex(c *Cache, env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if env.Kind() == reflect.Struct { - if m, ok := env.Get(c, n.Value); ok && m.StructData != nil { + if env.Kind == reflect.Struct { + if m, ok := env.Get(n.Value); ok && m.Optional != nil { return m.Method, m.MethodIndex, n.Value } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 5cb68d39..1ed3d4c5 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -10,13 +10,11 @@ import ( ) var ( + intType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) arrayType = reflect.TypeOf([]any{}) timeType = reflect.TypeOf(time.Time{}) durationType = reflect.TypeOf(time.Duration(0)) - - unknown = Nature{} - floatNature = Nature{Type: reflect.TypeOf(float64(0))} - integerNature = Nature{Type: reflect.TypeOf(0)} ) type NatureCheck int @@ -34,27 +32,44 @@ const ( ) type Nature struct { - Type reflect.Type // Type of the value. If nil, then value is unknown. - ArrayOf *Nature // Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). - *MapData - *StructData - Func *builtin.Function // Used to pass function type from callee to CallNode. - PredicateOut *Nature // Out nature of predicate. - Nil bool // If value is nil. - Strict bool // If map is types.StrictMap. - Method bool // If value retrieved from method. Usually used to determine amount of in arguments. -} - -type StructData struct { + // The order of the fields matter, check alignment before making changes. + + Type reflect.Type // Type of the value. If nil, then value is unknown. + Kind reflect.Kind // Kind of the value. + + *Cache + *Optional + Func *builtin.Function // Used to pass function type from callee to CallNode. + + // Ref is a reference used for multiple, disjoint purposes. When the Nature + // is for a: + // - Predicate: then Ref is the nature of the Out of the predicate. + // - Array-like types: then Ref is the Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). + Ref *Nature + + Nil bool // If value is nil. + Strict bool // If map is types.StrictMap. + Method bool // If value retrieved from method. Usually used to determine amount of in arguments. +} + +type Optional struct { + // struct-only data FieldIndex []int // Index of field in type. MethodIndex int // Index of method in type. -} -type MapData struct { + // map-only data Fields map[string]Nature // Fields of map type. DefaultMapValue *Nature // Default value of map type. + + // func-only data + inElem, outZero *Nature } +// Cache is a shared cache of type information. It is only used in the stages +// where type information becomes relevant, so packages like ast, parser, types, +// and lexer do not need to use the cache because they don't need any service +// from the Nature type, they only describe. However, when receiving a Nature +// from one of those packages, the cache must be set immediately. type Cache struct { methodByName map[rTypeWithKey]*Nature fieldByName map[rTypeWithKey]*Nature @@ -66,8 +81,36 @@ type rTypeWithKey struct { key string } +func NatureOf(c *Cache, i any) Nature { + // reflect.TypeOf(nil) returns nil, but in FromType we want to differentiate + // what nil means for us + if i == nil { + return Nature{Cache: c, Nil: true} + } + return FromType(c, reflect.TypeOf(i)) +} + +func FromType(c *Cache, t reflect.Type) Nature { + if t != nil { + k := t.Kind() + var opt *Optional + if k == reflect.Func { + opt = new(Optional) + } + return Nature{Type: t, Kind: k, Optional: opt, Cache: c} + } + return Nature{Cache: c} +} + +func ArrayFromType(c *Cache, t reflect.Type) Nature { + elem := FromType(c, t) + nt := FromType(c, arrayType) + nt.Ref = &elem + return nt +} + func (n *Nature) IsAny() bool { - return n.Type != nil && n.Type.Kind() == reflect.Interface && n.NumMethods() == 0 + return n.Type != nil && n.Kind == reflect.Interface && n.NumMethods() == 0 } func (n *Nature) IsUnknown() bool { @@ -85,40 +128,34 @@ func (n *Nature) Deref() Nature { ret := *n if ret.Type != nil { ret.Type = deref.Type(ret.Type) + ret.Kind = ret.Type.Kind() } return ret } -func (n *Nature) Kind() reflect.Kind { - if n.Type != nil { - return n.Type.Kind() - } - return reflect.Invalid -} - func (n *Nature) Key() Nature { - if n.Kind() == reflect.Map { - return Nature{Type: n.Type.Key()} + if n.Kind == reflect.Map { + return FromType(n.Cache, n.Type.Key()) } - return unknown + return FromType(n.Cache, nil) } func (n *Nature) Elem() Nature { - switch n.Kind() { + switch n.Kind { case reflect.Ptr: - return Nature{Type: n.Type.Elem()} + return FromType(n.Cache, n.Type.Elem()) case reflect.Map: - if n.MapData != nil && n.DefaultMapValue != nil { + if n.Optional != nil && n.DefaultMapValue != nil { return *n.DefaultMapValue } - return Nature{Type: n.Type.Elem()} - case reflect.Array, reflect.Slice: - if n.ArrayOf != nil { - return *n.ArrayOf + return FromType(n.Cache, n.Type.Elem()) + case reflect.Slice, reflect.Array: + if n.Ref != nil { + return *n.Ref } - return Nature{Type: n.Type.Elem()} + return FromType(n.Cache, n.Type.Elem()) } - return unknown + return FromType(n.Cache, nil) } func (n *Nature) AssignableTo(nt Nature) bool { @@ -141,25 +178,25 @@ func (n *Nature) NumMethods() int { return n.Type.NumMethod() } -func (n *Nature) MethodByName(c *Cache, name string) (Nature, bool) { - if ntPtr := n.methodByNamePtr(c, name); ntPtr != nil { +func (n *Nature) MethodByName(name string) (Nature, bool) { + if ntPtr := n.methodByNamePtr(name); ntPtr != nil { return *ntPtr, true } - return unknown, false + return FromType(n.Cache, nil), false } -func (n *Nature) methodByNamePtr(c *Cache, name string) *Nature { +func (n *Nature) methodByNamePtr(name string) *Nature { return n.methodByNameSlow(name) var ntPtr *Nature var cacheHit bool - if c.methodByName == nil { - c.methodByName = map[rTypeWithKey]*Nature{} + if n.Cache.methodByName == nil { + n.Cache.methodByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = c.methodByName[rTypeWithKey{n.Type, name}] + ntPtr, cacheHit = n.Cache.methodByName[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.methodByNameSlow(name) - c.methodByName[rTypeWithKey{n.Type, name}] = ntPtr + n.Cache.methodByName[rTypeWithKey{n.Type, name}] = ntPtr } return ntPtr } @@ -173,7 +210,8 @@ func (n *Nature) methodByNameSlow(name string) *Nature { return nil } - if n.Type.Kind() == reflect.Interface { + nt := FromType(n.Cache, method.Type) + if n.Kind == reflect.Interface { // In case of interface type method will not have a receiver, // and to prevent checker decreasing numbers of in arguments // return method type as not method (second argument is false). @@ -181,15 +219,14 @@ func (n *Nature) methodByNameSlow(name string) *Nature { // Also, we can not use m.Index here, because it will be // different indexes for different types which implement // the same interface. - return &Nature{Type: method.Type} + return &nt } - return &Nature{ - Type: method.Type, - Method: true, - StructData: &StructData{ - MethodIndex: method.Index, - }, + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.Method = true + nt.MethodIndex = method.Index + return &nt } func (n *Nature) NumIn() int { @@ -199,11 +236,33 @@ func (n *Nature) NumIn() int { return n.Type.NumIn() } -func (n *Nature) In(i int) *Nature { +func (n *Nature) InElem(i int) Nature { + if n.inElem == nil { + if n.Type == nil { + n2 := FromType(n.Cache, nil) + n.inElem = &n2 + } else { + n2 := FromType(n.Cache, n.Type.In(i)) + n2 = n2.Elem() + n.inElem = &n2 + } + } + return *n.inElem +} + +func (n *Nature) In(i int) Nature { if n.Type == nil { - return &unknown + return FromType(n.Cache, nil) + } + return FromType(n.Cache, n.Type.In(i)) +} + +func (n *Nature) IsFirstArgUnknown() bool { + if n.Type != nil { + n2 := FromType(n.Cache, n.Type.In(0)) + return n2.IsUnknown() } - return &Nature{Type: n.Type.In(i)} + return false } func (n *Nature) NumOut() int { @@ -213,11 +272,23 @@ func (n *Nature) NumOut() int { return n.Type.NumOut() } -func (n *Nature) Out(i int) *Nature { +func (n *Nature) Out(i int) Nature { + if i != 0 { + return n.out(i) + } + if n.outZero != nil { + return *n.outZero + } + nt := n.out(0) + n.outZero = &nt + return nt +} + +func (n *Nature) out(i int) Nature { if n.Type == nil { - return &unknown + return FromType(n.Cache, nil) } - return &Nature{Type: n.Type.Out(i)} + return FromType(n.Cache, n.Type.Out(i)) } func (n *Nature) IsVariadic() bool { @@ -227,22 +298,22 @@ func (n *Nature) IsVariadic() bool { return n.Type.IsVariadic() } -func (n *Nature) FieldByName(c *Cache, name string) (Nature, bool) { +func (n *Nature) FieldByName(name string) (Nature, bool) { var ntPtr *Nature var cacheHit bool - if c.fieldByName == nil { - c.fieldByName = map[rTypeWithKey]*Nature{} + if n.Cache.fieldByName == nil { + n.Cache.fieldByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = c.fieldByName[rTypeWithKey{n.Type, name}] + ntPtr, cacheHit = n.Cache.fieldByName[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.fieldByNameSlow(name) - c.fieldByName[rTypeWithKey{n.Type, name}] = ntPtr + n.Cache.fieldByName[rTypeWithKey{n.Type, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true } - return unknown, false + return FromType(n.Cache, nil), false } func (n *Nature) fieldByNameSlow(name string) *Nature { @@ -250,12 +321,12 @@ func (n *Nature) fieldByNameSlow(name string) *Nature { return nil } if field, ok := fetchField(n.Type, name); ok { - return &Nature{ - Type: field.Type, - StructData: &StructData{ - FieldIndex: field.Index, - }, + nt := FromType(n.Cache, field.Type) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.FieldIndex = field.Index + return &nt } return nil } @@ -279,30 +350,30 @@ func (n *Nature) IsFastMap() bool { return false } -func (n *Nature) Get(c *Cache, name string) (Nature, bool) { +func (n *Nature) Get(name string) (Nature, bool) { var ntPtr *Nature var cacheHit bool - if c.get == nil { - c.get = map[rTypeWithKey]*Nature{} + if n.Cache.get == nil { + n.Cache.get = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = c.get[rTypeWithKey{n.Type, name}] + ntPtr, cacheHit = n.Cache.get[rTypeWithKey{n.Type, name}] } if !cacheHit { - ntPtr = n.getSlow(c, name) - c.get[rTypeWithKey{n.Type, name}] = ntPtr + ntPtr = n.getSlow(name) + n.Cache.get[rTypeWithKey{n.Type, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true } - return unknown, false + return FromType(n.Cache, nil), false } -func (n *Nature) getSlow(c *Cache, name string) *Nature { +func (n *Nature) getSlow(name string) *Nature { if n.Type == nil { return nil } - if m := n.methodByNamePtr(c, name); m != nil { + if m := n.methodByNamePtr(name); m != nil { return m } @@ -310,15 +381,15 @@ func (n *Nature) getSlow(c *Cache, name string) *Nature { switch t.Kind() { case reflect.Struct: if f, ok := fetchField(t, name); ok { - return &Nature{ - Type: f.Type, - StructData: &StructData{ - FieldIndex: f.Index, - }, + nt := FromType(n.Cache, f.Type) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.FieldIndex = f.Index + return &nt } case reflect.Map: - if n.MapData != nil { + if n.Optional != nil { if f, ok := n.Fields[name]; ok { return &f } @@ -334,22 +405,22 @@ func (n *Nature) All() map[string]Nature { return table } - for i := 0; i < n.Type.NumMethod(); i++ { + for i := 0; i < n.NumMethods(); i++ { method := n.Type.Method(i) - table[method.Name] = Nature{ - Type: method.Type, - Method: true, - StructData: &StructData{ - MethodIndex: method.Index, - }, + nt := FromType(n.Cache, method.Type) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.Method = true + nt.MethodIndex = method.Index + table[method.Name] = nt } t := deref.Type(n.Type) switch t.Kind() { case reflect.Struct: - for name, nt := range StructFields(t) { + for name, nt := range StructFields(n.Cache, t) { if _, ok := table[name]; ok { continue } @@ -357,7 +428,7 @@ func (n *Nature) All() map[string]Nature { } case reflect.Map: - if n.MapData != nil { + if n.Optional != nil { for key, nt := range n.Fields { if _, ok := table[key]; ok { continue @@ -375,7 +446,7 @@ func (n *Nature) IsNumber() bool { } func (n *Nature) IsInteger() bool { - switch n.Kind() { + switch n.Kind { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return n.PkgPath() == "" @@ -384,7 +455,7 @@ func (n *Nature) IsInteger() bool { } func (n *Nature) IsFloat() bool { - switch n.Kind() { + switch n.Kind { case reflect.Float32, reflect.Float64: return n.PkgPath() == "" } @@ -393,12 +464,12 @@ func (n *Nature) IsFloat() bool { func (n *Nature) PromoteNumericNature(rhs Nature) Nature { if n.IsUnknown() || rhs.IsUnknown() { - return unknown + return FromType(n.Cache, nil) } if n.IsFloat() || rhs.IsFloat() { - return floatNature + return FromType(n.Cache, floatType) } - return integerNature + return FromType(n.Cache, intType) } func (n *Nature) IsTime() bool { @@ -410,32 +481,32 @@ func (n *Nature) IsDuration() bool { } func (n *Nature) IsBool() bool { - return n.Kind() == reflect.Bool + return n.Kind == reflect.Bool } func (n *Nature) IsString() bool { - return n.Kind() == reflect.String + return n.Kind == reflect.String } func (n *Nature) IsArray() bool { - k := n.Kind() + k := n.Kind return k == reflect.Slice || k == reflect.Array } func (n *Nature) IsMap() bool { - return n.Kind() == reflect.Map + return n.Kind == reflect.Map } func (n *Nature) IsStruct() bool { - return n.Kind() == reflect.Struct + return n.Kind == reflect.Struct } func (n *Nature) IsFunc() bool { - return n.Kind() == reflect.Func + return n.Kind == reflect.Func } func (n *Nature) IsPointer() bool { - return n.Kind() == reflect.Ptr + return n.Kind == reflect.Ptr } func (n *Nature) IsAnyOf(cs ...NatureCheck) bool { @@ -484,8 +555,7 @@ func (n *Nature) MaybeCompatible(rhs Nature, cs ...NatureCheck) bool { } func (n *Nature) MakeArrayOf() Nature { - return Nature{ - Type: arrayType, - ArrayOf: n, - } + nt := FromType(n.Cache, arrayType) + nt.Ref = n + return nt } diff --git a/checker/nature/utils.go b/checker/nature/utils.go index a8296407..f72acc5f 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -50,7 +50,7 @@ func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { return reflect.StructField{}, false } -func StructFields(t reflect.Type) map[string]Nature { +func StructFields(c *Cache, t reflect.Type) map[string]Nature { table := make(map[string]Nature) t = deref.Type(t) @@ -64,12 +64,12 @@ func StructFields(t reflect.Type) map[string]Nature { f := t.Field(i) if f.Anonymous { - for name, typ := range StructFields(f.Type) { + for name, typ := range StructFields(c, f.Type) { if _, ok := table[name]; ok { continue } - if typ.StructData == nil { - typ.StructData = new(StructData) + if typ.Optional == nil { + typ.Optional = new(Optional) } typ.FieldIndex = append(f.Index, typ.FieldIndex...) table[name] = typ @@ -80,12 +80,12 @@ func StructFields(t reflect.Type) map[string]Nature { if !ok { continue } - table[name] = Nature{ - Type: f.Type, - StructData: &StructData{ - FieldIndex: f.Index, - }, + nt := FromType(c, f.Type) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.FieldIndex = f.Index + table[name] = nt } } diff --git a/compiler/compiler.go b/compiler/compiler.go index b0fb45df..6829fc00 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -751,7 +751,7 @@ func (c *compiler) CallNode(node *ast.CallNode) { } } case *ast.IdentifierNode: - if t, ok := c.config.Env.MethodByName(c.ntCache, callee.Value); ok && t.Method { + if t, ok := c.config.Env.MethodByName(callee.Value); ok && t.Method { fnInOffset = 1 fnNumIn-- } @@ -1089,6 +1089,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.compile(arg) argType := arg.Type() argNature := arg.Nature() + argNature.Cache = c.ntCache // AST doesn't cache nature info if argType.Kind() == reflect.Ptr || argNature.IsUnknown() { if f.Deref == nil { // By default, builtins expect arguments to be dereferenced. diff --git a/conf/config.go b/conf/config.go index 6de19c7d..841fcdfc 100644 --- a/conf/config.go +++ b/conf/config.go @@ -50,6 +50,7 @@ func CreateNew() *Config { for _, f := range builtin.Builtins { c.Builtins[f.Name] = f } + c.Env.Cache = &c.NtCache return c } @@ -62,7 +63,7 @@ func New(env any) *Config { func (c *Config) WithEnv(env any) { c.EnvObject = env - c.Env = Env(env) + c.Env = Env(&c.NtCache, env) c.Strict = c.Env.Strict } @@ -93,7 +94,7 @@ func (c *Config) IsOverridden(name string) bool { if _, ok := c.Functions[name]; ok { return true } - if _, ok := c.Env.Get(&c.NtCache, name); ok { + if _, ok := c.Env.Get(name); ok { return true } return false diff --git a/conf/env.go b/conf/env.go index 063fca81..9fad8ef7 100644 --- a/conf/env.go +++ b/conf/env.go @@ -9,37 +9,36 @@ import ( "github.com/expr-lang/expr/types" ) -func Env(env any) Nature { +func Env(c *Cache, env any) Nature { if env == nil { - return Nature{ - Type: reflect.TypeOf(map[string]any{}), - Strict: true, - } + n := NatureOf(c, map[string]any{}) + n.Strict = true + return n } switch env := env.(type) { case types.Map: - return env.Nature() + nt := env.Nature() + nt.Cache = c + return nt } v := reflect.ValueOf(env) - d := deref.Value(v) + t := v.Type() - switch d.Kind() { + switch deref.Value(v).Kind() { case reflect.Struct: - return Nature{ - Type: v.Type(), - Strict: true, - } + n := FromType(c, t) + n.Strict = true + return n case reflect.Map: - n := Nature{ - Type: v.Type(), - MapData: &MapData{ - Fields: make(map[string]Nature, v.Len()), - }, - Strict: true, + n := FromType(c, v.Type()) + if n.Optional == nil { + n.Optional = new(Optional) } + n.Strict = true + n.Fields = make(map[string]Nature, v.Len()) for _, key := range v.MapKeys() { elem := v.MapIndex(key) @@ -51,14 +50,16 @@ func Env(env any) Nature { switch face := face.(type) { case types.Map: - n.Fields[key.String()] = face.Nature() + nt := face.Nature() + nt.Cache = c + n.Fields[key.String()] = nt default: if face == nil { - n.Fields[key.String()] = Nature{Nil: true} + n.Fields[key.String()] = NatureOf(c, nil) continue } - n.Fields[key.String()] = Nature{Type: reflect.TypeOf(face)} + n.Fields[key.String()] = NatureOf(c, face) } } diff --git a/docgen/docgen.go b/docgen/docgen.go index e93dbfa2..6d1ec59d 100644 --- a/docgen/docgen.go +++ b/docgen/docgen.go @@ -85,7 +85,7 @@ func CreateDoc(i any) *Context { PkgPath: deref.Type(reflect.TypeOf(i)).PkgPath(), } - env := conf.Env(i) + env := conf.Env(new(nature.Cache), i) for name, t := range env.All() { if _, ok := c.Variables[Identifier(name)]; ok { continue @@ -222,7 +222,8 @@ appendix: c.Types[name] = a } - for name, field := range nature.StructFields(t) { + ntCache := new(nature.Cache) + for name, field := range nature.StructFields(ntCache, t) { if isPrivate(name) || isProtobuf(name) { continue } diff --git a/patcher/operator_override.go b/patcher/operator_override.go index cf4287c2..e0a00593 100644 --- a/patcher/operator_override.go +++ b/patcher/operator_override.go @@ -63,7 +63,7 @@ func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (r func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) { for _, fn := range p.Overloads { - fnType, ok := p.Env.Get(p.NtCache, fn) + fnType, ok := p.Env.Get(fn) if !ok { continue } @@ -110,7 +110,7 @@ func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex func (p *OperatorOverloading) Check() { for _, fn := range p.Overloads { - fnType, foundType := p.Env.Get(p.NtCache, fn) + fnType, foundType := p.Env.Get(fn) fnFunc, foundFunc := p.Functions[fn] if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) { panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator)) diff --git a/types/types.go b/types/types.go index 72f18da2..a1517994 100644 --- a/types/types.go +++ b/types/types.go @@ -44,7 +44,7 @@ func TypeOf(v any) Type { type anyType struct{} func (anyType) Nature() Nature { - return Nature{Type: nil} + return FromType(nil, nil) } func (anyType) Equal(t Type) bool { @@ -58,7 +58,7 @@ func (anyType) String() string { type nilType struct{} func (nilType) Nature() Nature { - return Nature{Nil: true} + return NatureOf(nil, nil) } func (nilType) Equal(t Type) bool { @@ -77,7 +77,7 @@ type rtype struct { } func (r rtype) Nature() Nature { - return Nature{Type: r.t} + return FromType(nil, r.t) } func (r rtype) Equal(t Type) bool { @@ -100,17 +100,16 @@ type Map map[string]Type const Extra = "[[__extra_keys__]]" func (m Map) Nature() Nature { - nt := Nature{ - Type: reflect.TypeOf(map[string]any{}), - MapData: &MapData{ - Fields: make(map[string]Nature, len(m)), - }, - Strict: true, + nt := NatureOf(nil, map[string]any{}) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.Fields = make(map[string]Nature, len(m)) + nt.Strict = true for k, v := range m { if k == Extra { - if nt.MapData == nil { - nt.MapData = new(MapData) + if nt.Optional == nil { + nt.Optional = new(Optional) } nt.Strict = false natureOfDefaultValue := v.Nature() @@ -160,13 +159,13 @@ type array struct { func (a array) Nature() Nature { of := a.of.Nature() - return Nature{ - Type: reflect.TypeOf([]any{}), - MapData: &MapData{ - Fields: make(map[string]Nature, 1), - }, - ArrayOf: &of, + nt := NatureOf(nil, []any{}) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.Fields = make(map[string]Nature, 1) + nt.Ref = &of + return nt } func (a array) Equal(t Type) bool { From d9ba735c068e8a8f9054cf8934398ccca4ee5c18 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Fri, 29 Aug 2025 22:20:36 -0300 Subject: [PATCH 12/24] improve ergonomics of FromType and NatureOf --- ast/node.go | 2 +- checker/checker.go | 178 +++++++++++++++++++-------------------- checker/nature/nature.go | 74 +++++++++------- checker/nature/utils.go | 2 +- conf/env.go | 10 +-- types/types.go | 10 +-- 6 files changed, 144 insertions(+), 132 deletions(-) diff --git a/ast/node.go b/ast/node.go index 8e44b7f0..40c507aa 100644 --- a/ast/node.go +++ b/ast/node.go @@ -66,7 +66,7 @@ func (n *base) Type() reflect.Type { // SetType sets the type of the node. func (n *base) SetType(t reflect.Type) { - n.nature = nature.FromType(nil, t) + n.nature = nature.FromType(t) } // NilNode represents nil. diff --git a/checker/checker.go b/checker/checker.go index ae2629ba..3708e455 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -183,7 +183,7 @@ func (v *Checker) visit(node ast.Node) Nature { var nt Nature switch n := node.(type) { case *ast.NilNode: - nt = v.nilNode(n) + nt = v.config.NtCache.NatureOf(nil) case *ast.IdentifierNode: nt = v.identifierNode(n) case *ast.IntegerNode: @@ -240,11 +240,7 @@ func (v *Checker) error(node ast.Node, format string, args ...any) Nature { Message: fmt.Sprintf(format, args...), } } - return FromType(&v.config.NtCache, nil) -} - -func (v *Checker) nilNode(*ast.NilNode) Nature { - return NatureOf(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { @@ -252,7 +248,7 @@ func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { return variable.nature } if node.Value == "$env" { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } return v.ident(node, node.Value, v.config.Strict, true) @@ -265,12 +261,12 @@ func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Natur } if builtins { if fn, ok := v.config.Functions[name]; ok { - nt := FromType(&v.config.NtCache, fn.Type()) + nt := v.config.NtCache.FromType(fn.Type()) nt.Func = fn return nt } if fn, ok := v.config.Builtins[name]; ok { - nt := FromType(&v.config.NtCache, fn.Type()) + nt := v.config.NtCache.FromType(fn.Type()) nt.Func = fn return nt } @@ -278,27 +274,27 @@ func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Natur if v.config.Strict && strict { return v.error(node, "unknown name %s", name) } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } func (v *Checker) integerNode(*ast.IntegerNode) Nature { - return FromType(&v.config.NtCache, intType) + return v.config.NtCache.FromType(intType) } func (v *Checker) floatNode(*ast.FloatNode) Nature { - return FromType(&v.config.NtCache, floatType) + return v.config.NtCache.FromType(floatType) } func (v *Checker) boolNode(*ast.BoolNode) Nature { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } func (v *Checker) stringNode(*ast.StringNode) Nature { - return FromType(&v.config.NtCache, stringType) + return v.config.NtCache.FromType(stringType) } func (v *Checker) constantNode(node *ast.ConstantNode) Nature { - return FromType(&v.config.NtCache, reflect.TypeOf(node.Value)) + return v.config.NtCache.FromType(reflect.TypeOf(node.Value)) } func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { @@ -309,10 +305,10 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { case "!", "not": if nt.IsBool() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if nt.IsUnknown() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "+", "-": @@ -320,7 +316,7 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { return nt } if nt.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } default: @@ -340,32 +336,32 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { switch node.Operator { case "==", "!=": if l.ComparableTo(r) { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "or", "||", "and", "&&": if l.IsBool() && r.IsBool() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.MaybeCompatible(r, BoolCheck) { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "<", ">", ">=", "<=": if l.IsNumber() && r.IsNumber() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.IsString() && r.IsString() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.IsTime() && r.IsTime() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.IsDuration() && r.IsDuration() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "-": @@ -373,16 +369,16 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l.PromoteNumericNature(r) } if l.IsTime() && r.IsTime() { - return FromType(&v.config.NtCache, durationType) + return v.config.NtCache.FromType(durationType) } if l.IsTime() && r.IsDuration() { - return FromType(&v.config.NtCache, timeType) + return v.config.NtCache.FromType(timeType) } if l.IsDuration() && r.IsDuration() { - return FromType(&v.config.NtCache, durationType) + return v.config.NtCache.FromType(durationType) } if l.MaybeCompatible(r, NumberCheck, TimeCheck, DurationCheck) { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } case "*": @@ -390,40 +386,40 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l.PromoteNumericNature(r) } if l.IsNumber() && r.IsDuration() { - return FromType(&v.config.NtCache, durationType) + return v.config.NtCache.FromType(durationType) } if l.IsDuration() && r.IsNumber() { - return FromType(&v.config.NtCache, durationType) + return v.config.NtCache.FromType(durationType) } if l.IsDuration() && r.IsDuration() { - return FromType(&v.config.NtCache, durationType) + return v.config.NtCache.FromType(durationType) } if l.MaybeCompatible(r, NumberCheck, DurationCheck) { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } case "/": if l.IsNumber() && r.IsNumber() { - return FromType(&v.config.NtCache, floatType) + return v.config.NtCache.FromType(floatType) } if l.MaybeCompatible(r, NumberCheck) { - return FromType(&v.config.NtCache, floatType) + return v.config.NtCache.FromType(floatType) } case "**", "^": if l.IsNumber() && r.IsNumber() { - return FromType(&v.config.NtCache, floatType) + return v.config.NtCache.FromType(floatType) } if l.MaybeCompatible(r, NumberCheck) { - return FromType(&v.config.NtCache, floatType) + return v.config.NtCache.FromType(floatType) } case "%": if l.IsInteger() && r.IsInteger() { - return FromType(&v.config.NtCache, intType) + return v.config.NtCache.FromType(intType) } if l.MaybeCompatible(r, IntegerCheck) { - return FromType(&v.config.NtCache, intType) + return v.config.NtCache.FromType(intType) } case "+": @@ -431,44 +427,44 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l.PromoteNumericNature(r) } if l.IsString() && r.IsString() { - return FromType(&v.config.NtCache, stringType) + return v.config.NtCache.FromType(stringType) } if l.IsTime() && r.IsDuration() { - return FromType(&v.config.NtCache, timeType) + return v.config.NtCache.FromType(timeType) } if l.IsDuration() && r.IsTime() { - return FromType(&v.config.NtCache, timeType) + return v.config.NtCache.FromType(timeType) } if l.IsDuration() && r.IsDuration() { - return FromType(&v.config.NtCache, durationType) + return v.config.NtCache.FromType(durationType) } if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } case "in": if (l.IsString() || l.IsUnknown()) && r.IsStruct() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if r.IsMap() { rKey := r.Key() if !l.IsUnknown() && !l.AssignableTo(rKey) { return v.error(node, "cannot use %s as type %s in map key", l.String(), rKey.String()) } - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if r.IsArray() { rElem := r.Elem() if !l.ComparableTo(rElem) { return v.error(node, "cannot use %s as type %s in array", l.String(), rElem.String()) } - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.IsUnknown() && r.IsAnyOf(StringCheck, ArrayCheck, MapCheck) { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if r.IsUnknown() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "matches": @@ -479,18 +475,18 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { } } if l.IsString() && r.IsString() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.MaybeCompatible(r, StringCheck) { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "contains", "startsWith", "endsWith": if l.IsString() && r.IsString() { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } if l.MaybeCompatible(r, StringCheck) { - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } case "..": @@ -506,12 +502,12 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return l } if l.Nil && r.Nil { - return NatureOf(&v.config.NtCache, nil) + return v.config.NtCache.NatureOf(nil) } if r.AssignableTo(l) { return l } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) default: return v.error(node, "unknown operator (%s)", node.Operator) @@ -539,14 +535,14 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { } return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } base := v.visit(node.Node) prop := v.visit(node.Property) if base.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } if name, ok := node.Property.(*ast.StringNode); ok { @@ -587,7 +583,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value if field, ok := base.FieldByName(propertyName); ok { - return FromType(&v.config.NtCache, field.Type) + return v.config.NtCache.FromType(field.Type) } if node.Method { return v.error(node, "type %v has no method %v", base.String(), propertyName) @@ -611,7 +607,7 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { nt := v.visit(node.Node) if nt.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } switch nt.Kind { @@ -662,7 +658,7 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { func (v *Checker) functionReturnType(node *ast.CallNode) Nature { nt := v.visit(node.Callee) if nt.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } if nt.Func != nil { @@ -689,7 +685,7 @@ func (v *Checker) functionReturnType(node *ast.CallNode) Nature { if v.err == nil { v.err = err } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } return outType } @@ -717,7 +713,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if !predicateOut.IsBool() && !predicateOut.IsUnknown() { return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return FromType(&v.config.NtCache, boolType) + return v.config.NtCache.FromType(boolType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -741,7 +737,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } if collection.IsUnknown() { - return FromType(&v.config.NtCache, arrayType) + return v.config.NtCache.FromType(arrayType) } collection = collection.Elem() return collection.MakeArrayOf() @@ -755,7 +751,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, varScope{"index", FromType(&v.config.NtCache, intType)}) + v.begin(collection, varScope{"index", v.config.NtCache.FromType(intType)}) predicate := v.visit(node.Arguments[1]) v.end() @@ -775,7 +771,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } if len(node.Arguments) == 1 { - return FromType(&v.config.NtCache, intType) + return v.config.NtCache.FromType(intType) } v.begin(collection) @@ -790,7 +786,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return FromType(&v.config.NtCache, intType) + return v.config.NtCache.FromType(intType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -813,7 +809,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } } else { if collection.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } return collection.Elem() } @@ -838,7 +834,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } if collection.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } return collection.Elem() } @@ -863,7 +859,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { if !predicateOut.IsBool() && !predicateOut.IsUnknown() { return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return FromType(&v.config.NtCache, intType) + return v.config.NtCache.FromType(intType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -884,7 +880,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { collection = collection.Elem() collection = collection.MakeArrayOf() - nt := NatureOf(&v.config.NtCache, map[any][]any{}) + nt := v.config.NtCache.NatureOf(map[any][]any{}) nt.Ref = &collection return nt } @@ -920,7 +916,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, varScope{"index", FromType(&v.config.NtCache, intType)}, varScope{"acc", FromType(&v.config.NtCache, nil)}) + v.begin(collection, varScope{"index", v.config.NtCache.FromType(intType)}, varScope{"acc", v.config.NtCache.FromType(nil)}) predicate := v.visit(node.Arguments[1]) v.end() @@ -971,11 +967,11 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { return nt } } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } if base.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } switch base.Kind { @@ -1008,21 +1004,21 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if err != nil { return v.error(node, "%v", err) } - return FromType(&v.config.NtCache, t) + return v.config.NtCache.FromType(t) } else if len(f.Types) == 0 { - nt, err := v.checkArguments(f.Name, FromType(&v.config.NtCache, f.Type()), arguments, node) + nt, err := v.checkArguments(f.Name, v.config.NtCache.FromType(f.Type()), arguments, node) if err != nil { if v.err == nil { v.err = err } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } // No type was specified, so we assume the function returns any. return nt } var lastErr *file.Error for _, t := range f.Types { - outNature, err := v.checkArguments(f.Name, FromType(&v.config.NtCache, t), arguments, node) + outNature, err := v.checkArguments(f.Name, v.config.NtCache.FromType(t), arguments, node) if err != nil { lastErr = err continue @@ -1041,7 +1037,7 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if v.err == nil { v.err = lastErr } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } return v.error(node, "no matching overload for %v", f.Name) @@ -1054,18 +1050,18 @@ func (v *Checker) checkArguments( node ast.Node, ) (Nature, *file.Error) { if fn.IsUnknown() { - return FromType(&v.config.NtCache, nil), nil + return v.config.NtCache.FromType(nil), nil } numOut := fn.NumOut() if numOut == 0 { - return FromType(&v.config.NtCache, nil), &file.Error{ + return v.config.NtCache.FromType(nil), &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } if numOut > 2 { - return FromType(&v.config.NtCache, nil), &file.Error{ + return v.config.NtCache.FromType(nil), &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), } @@ -1142,7 +1138,7 @@ func (v *Checker) checkArguments( if in.Kind == reflect.Ptr || in.Kind == reflect.Interface { continue } - return FromType(&v.config.NtCache, nil), &file.Error{ + return v.config.NtCache.FromType(nil), &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in.String(), name), } @@ -1162,7 +1158,7 @@ func (v *Checker) checkArguments( } if !assignable && !argNature.IsUnknown() { - return FromType(&v.config.NtCache, nil), &file.Error{ + return v.config.NtCache.FromType(nil), &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature.String(), in.String(), name), } @@ -1217,7 +1213,7 @@ func (v *Checker) predicateNode(node *ast.PredicateNode) Nature { } else if !nt.Nil { out = append(out, nt.Type) } - n := FromType(&v.config.NtCache, reflect.FuncOf(anyTypeSlice, out, false)) + n := v.config.NtCache.FromType(reflect.FuncOf(anyTypeSlice, out, false)) n.Ref = &nt return n } @@ -1229,7 +1225,7 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { if scope.collection.IsUnknown() { - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } switch scope.collection.Kind { case reflect.Array, reflect.Slice: @@ -1303,12 +1299,12 @@ func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { return t1 } if t1.Nil && t2.Nil { - return NatureOf(&v.config.NtCache, nil) + return v.config.NtCache.NatureOf(nil) } if t1.AssignableTo(t2) { return t1 } - return FromType(&v.config.NtCache, nil) + return v.config.NtCache.FromType(nil) } func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { @@ -1326,18 +1322,18 @@ func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { if allElementsAreSameType { return prev.MakeArrayOf() } - return FromType(&v.config.NtCache, arrayType) + return v.config.NtCache.FromType(arrayType) } func (v *Checker) mapNode(node *ast.MapNode) Nature { for _, pair := range node.Pairs { v.visit(pair) } - return FromType(&v.config.NtCache, mapType) + return v.config.NtCache.FromType(mapType) } func (v *Checker) pairNode(node *ast.PairNode) Nature { v.visit(node.Key) v.visit(node.Value) - return NatureOf(&v.config.NtCache, nil) + return v.config.NtCache.NatureOf(nil) } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 1ed3d4c5..d85691c3 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -81,16 +81,20 @@ type rTypeWithKey struct { key string } -func NatureOf(c *Cache, i any) Nature { +// NatureOf returns a Nature describing "i". If "i" is nil then it returns a +// Nature describing the value "nil". +func (c *Cache) NatureOf(i any) Nature { // reflect.TypeOf(nil) returns nil, but in FromType we want to differentiate // what nil means for us if i == nil { return Nature{Cache: c, Nil: true} } - return FromType(c, reflect.TypeOf(i)) + return c.FromType(reflect.TypeOf(i)) } -func FromType(c *Cache, t reflect.Type) Nature { +// FromType returns a Nature describing a value of type "t". If "t" is nil then +// it returns a Nature describing an unknown value. +func (c *Cache) FromType(t reflect.Type) Nature { if t != nil { k := t.Kind() var opt *Optional @@ -102,9 +106,21 @@ func FromType(c *Cache, t reflect.Type) Nature { return Nature{Cache: c} } +// NatureOf calls NatureOf on a nil *Cache. See the comment on Cache. +func NatureOf(i any) Nature { + var c *Cache + return c.NatureOf(i) +} + +// FromType calls FromType on a nil *Cache. See the comment on Cache. +func FromType(t reflect.Type) Nature { + var c *Cache + return c.FromType(t) +} + func ArrayFromType(c *Cache, t reflect.Type) Nature { - elem := FromType(c, t) - nt := FromType(c, arrayType) + elem := c.FromType(t) + nt := c.FromType(arrayType) nt.Ref = &elem return nt } @@ -135,27 +151,27 @@ func (n *Nature) Deref() Nature { func (n *Nature) Key() Nature { if n.Kind == reflect.Map { - return FromType(n.Cache, n.Type.Key()) + return n.Cache.FromType(n.Type.Key()) } - return FromType(n.Cache, nil) + return n.Cache.FromType(nil) } func (n *Nature) Elem() Nature { switch n.Kind { case reflect.Ptr: - return FromType(n.Cache, n.Type.Elem()) + return n.Cache.FromType(n.Type.Elem()) case reflect.Map: if n.Optional != nil && n.DefaultMapValue != nil { return *n.DefaultMapValue } - return FromType(n.Cache, n.Type.Elem()) + return n.Cache.FromType(n.Type.Elem()) case reflect.Slice, reflect.Array: if n.Ref != nil { return *n.Ref } - return FromType(n.Cache, n.Type.Elem()) + return n.Cache.FromType(n.Type.Elem()) } - return FromType(n.Cache, nil) + return n.Cache.FromType(nil) } func (n *Nature) AssignableTo(nt Nature) bool { @@ -182,7 +198,7 @@ func (n *Nature) MethodByName(name string) (Nature, bool) { if ntPtr := n.methodByNamePtr(name); ntPtr != nil { return *ntPtr, true } - return FromType(n.Cache, nil), false + return n.Cache.FromType(nil), false } func (n *Nature) methodByNamePtr(name string) *Nature { @@ -210,7 +226,7 @@ func (n *Nature) methodByNameSlow(name string) *Nature { return nil } - nt := FromType(n.Cache, method.Type) + nt := n.Cache.FromType(method.Type) if n.Kind == reflect.Interface { // In case of interface type method will not have a receiver, // and to prevent checker decreasing numbers of in arguments @@ -239,10 +255,10 @@ func (n *Nature) NumIn() int { func (n *Nature) InElem(i int) Nature { if n.inElem == nil { if n.Type == nil { - n2 := FromType(n.Cache, nil) + n2 := n.Cache.FromType(nil) n.inElem = &n2 } else { - n2 := FromType(n.Cache, n.Type.In(i)) + n2 := n.Cache.FromType(n.Type.In(i)) n2 = n2.Elem() n.inElem = &n2 } @@ -252,14 +268,14 @@ func (n *Nature) InElem(i int) Nature { func (n *Nature) In(i int) Nature { if n.Type == nil { - return FromType(n.Cache, nil) + return n.Cache.FromType(nil) } - return FromType(n.Cache, n.Type.In(i)) + return n.Cache.FromType(n.Type.In(i)) } func (n *Nature) IsFirstArgUnknown() bool { if n.Type != nil { - n2 := FromType(n.Cache, n.Type.In(0)) + n2 := n.Cache.FromType(n.Type.In(0)) return n2.IsUnknown() } return false @@ -286,9 +302,9 @@ func (n *Nature) Out(i int) Nature { func (n *Nature) out(i int) Nature { if n.Type == nil { - return FromType(n.Cache, nil) + return n.Cache.FromType(nil) } - return FromType(n.Cache, n.Type.Out(i)) + return n.Cache.FromType(n.Type.Out(i)) } func (n *Nature) IsVariadic() bool { @@ -313,7 +329,7 @@ func (n *Nature) FieldByName(name string) (Nature, bool) { if ntPtr != nil { return *ntPtr, true } - return FromType(n.Cache, nil), false + return n.Cache.FromType(nil), false } func (n *Nature) fieldByNameSlow(name string) *Nature { @@ -321,7 +337,7 @@ func (n *Nature) fieldByNameSlow(name string) *Nature { return nil } if field, ok := fetchField(n.Type, name); ok { - nt := FromType(n.Cache, field.Type) + nt := n.Cache.FromType(field.Type) if nt.Optional == nil { nt.Optional = new(Optional) } @@ -365,7 +381,7 @@ func (n *Nature) Get(name string) (Nature, bool) { if ntPtr != nil { return *ntPtr, true } - return FromType(n.Cache, nil), false + return n.Cache.FromType(nil), false } func (n *Nature) getSlow(name string) *Nature { @@ -381,7 +397,7 @@ func (n *Nature) getSlow(name string) *Nature { switch t.Kind() { case reflect.Struct: if f, ok := fetchField(t, name); ok { - nt := FromType(n.Cache, f.Type) + nt := n.Cache.FromType(f.Type) if nt.Optional == nil { nt.Optional = new(Optional) } @@ -407,7 +423,7 @@ func (n *Nature) All() map[string]Nature { for i := 0; i < n.NumMethods(); i++ { method := n.Type.Method(i) - nt := FromType(n.Cache, method.Type) + nt := n.Cache.FromType(method.Type) if nt.Optional == nil { nt.Optional = new(Optional) } @@ -464,12 +480,12 @@ func (n *Nature) IsFloat() bool { func (n *Nature) PromoteNumericNature(rhs Nature) Nature { if n.IsUnknown() || rhs.IsUnknown() { - return FromType(n.Cache, nil) + return n.Cache.FromType(nil) } if n.IsFloat() || rhs.IsFloat() { - return FromType(n.Cache, floatType) + return n.Cache.FromType(floatType) } - return FromType(n.Cache, intType) + return n.Cache.FromType(intType) } func (n *Nature) IsTime() bool { @@ -555,7 +571,7 @@ func (n *Nature) MaybeCompatible(rhs Nature, cs ...NatureCheck) bool { } func (n *Nature) MakeArrayOf() Nature { - nt := FromType(n.Cache, arrayType) + nt := n.Cache.FromType(arrayType) nt.Ref = n return nt } diff --git a/checker/nature/utils.go b/checker/nature/utils.go index f72acc5f..782992a0 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -80,7 +80,7 @@ func StructFields(c *Cache, t reflect.Type) map[string]Nature { if !ok { continue } - nt := FromType(c, f.Type) + nt := c.FromType(f.Type) if nt.Optional == nil { nt.Optional = new(Optional) } diff --git a/conf/env.go b/conf/env.go index 9fad8ef7..63dc5255 100644 --- a/conf/env.go +++ b/conf/env.go @@ -11,7 +11,7 @@ import ( func Env(c *Cache, env any) Nature { if env == nil { - n := NatureOf(c, map[string]any{}) + n := c.NatureOf(map[string]any{}) n.Strict = true return n } @@ -28,12 +28,12 @@ func Env(c *Cache, env any) Nature { switch deref.Value(v).Kind() { case reflect.Struct: - n := FromType(c, t) + n := c.FromType(t) n.Strict = true return n case reflect.Map: - n := FromType(c, v.Type()) + n := c.FromType(v.Type()) if n.Optional == nil { n.Optional = new(Optional) } @@ -56,10 +56,10 @@ func Env(c *Cache, env any) Nature { default: if face == nil { - n.Fields[key.String()] = NatureOf(c, nil) + n.Fields[key.String()] = c.NatureOf(nil) continue } - n.Fields[key.String()] = NatureOf(c, face) + n.Fields[key.String()] = c.NatureOf(face) } } diff --git a/types/types.go b/types/types.go index a1517994..05e0ae70 100644 --- a/types/types.go +++ b/types/types.go @@ -44,7 +44,7 @@ func TypeOf(v any) Type { type anyType struct{} func (anyType) Nature() Nature { - return FromType(nil, nil) + return FromType(nil) } func (anyType) Equal(t Type) bool { @@ -58,7 +58,7 @@ func (anyType) String() string { type nilType struct{} func (nilType) Nature() Nature { - return NatureOf(nil, nil) + return NatureOf(nil) } func (nilType) Equal(t Type) bool { @@ -77,7 +77,7 @@ type rtype struct { } func (r rtype) Nature() Nature { - return FromType(nil, r.t) + return FromType(r.t) } func (r rtype) Equal(t Type) bool { @@ -100,7 +100,7 @@ type Map map[string]Type const Extra = "[[__extra_keys__]]" func (m Map) Nature() Nature { - nt := NatureOf(nil, map[string]any{}) + nt := NatureOf(map[string]any{}) if nt.Optional == nil { nt.Optional = new(Optional) } @@ -159,7 +159,7 @@ type array struct { func (a array) Nature() Nature { of := a.of.Nature() - nt := NatureOf(nil, []any{}) + nt := NatureOf([]any{}) if nt.Optional == nil { nt.Optional = new(Optional) } From aa33cf99b80d2cd462cfdc8b6f12f10569e77e09 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sat, 30 Aug 2025 05:59:35 -0300 Subject: [PATCH 13/24] improve struct method fetching --- checker/checker.go | 4 -- checker/nature/nature.go | 65 +++++++++++++++++++---------- checker/nature/utils.go | 90 ++++++++++++++++++++++++++++------------ 3 files changed, 105 insertions(+), 54 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 3708e455..76cc4d0e 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -652,10 +652,6 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { return nt } - return v.functionReturnType(node) -} - -func (v *Checker) functionReturnType(node *ast.CallNode) Nature { nt := v.visit(node.Callee) if nt.IsUnknown() { return v.config.NtCache.FromType(nil) diff --git a/checker/nature/nature.go b/checker/nature/nature.go index d85691c3..f9a057d4 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -74,6 +74,7 @@ type Cache struct { methodByName map[rTypeWithKey]*Nature fieldByName map[rTypeWithKey]*Nature get map[rTypeWithKey]*Nature + structs map[reflect.Type]map[string]Nature } type rTypeWithKey struct { @@ -81,6 +82,23 @@ type rTypeWithKey struct { key string } +type structInfo struct { + fields map[string]Nature +} + +func newStructInfo(t reflect.Type) *structInfo { + numField := t.NumField() + return &structInfo{ + fields: make(map[string]Nature, numField), + } +} + +func (si *structInfo) set(name string, nt Nature) { + if _, ok := si.fields[name]; !ok { + si.fields[name] = nt + } +} + // NatureOf returns a Nature describing "i". If "i" is nil then it returns a // Nature describing the value "nil". func (c *Cache) NatureOf(i any) Nature { @@ -315,16 +333,21 @@ func (n *Nature) IsVariadic() bool { } func (n *Nature) FieldByName(name string) (Nature, bool) { + t, k := derefTypeKind(n.Type, n.Kind) + if k != reflect.Struct { + return n.Cache.FromType(nil), false + } + var ntPtr *Nature var cacheHit bool if n.Cache.fieldByName == nil { n.Cache.fieldByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = n.Cache.fieldByName[rTypeWithKey{n.Type, name}] + ntPtr, cacheHit = n.Cache.fieldByName[rTypeWithKey{t, name}] } if !cacheHit { ntPtr = n.fieldByNameSlow(name) - n.Cache.fieldByName[rTypeWithKey{n.Type, name}] = ntPtr + n.Cache.fieldByName[rTypeWithKey{t, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true @@ -336,13 +359,12 @@ func (n *Nature) fieldByNameSlow(name string) *Nature { if n.Type == nil { return nil } - if field, ok := fetchField(n.Type, name); ok { - nt := n.Cache.FromType(field.Type) - if nt.Optional == nil { - nt.Optional = new(Optional) + t, k := derefTypeKind(n.Type, n.Kind) + if k == reflect.Struct { + nt, ok := n.Cache.fetchField(nil, t, name) + if ok { + return &nt } - nt.FieldIndex = field.Index - return &nt } return nil } @@ -367,6 +389,13 @@ func (n *Nature) IsFastMap() bool { } func (n *Nature) Get(name string) (Nature, bool) { + if n.Kind == reflect.Map && n.Optional != nil { + if f, ok := n.Fields[name]; ok { + return f, true + } + return n.Cache.FromType(nil), false + } + var ntPtr *Nature var cacheHit bool if n.Cache.get == nil { @@ -393,24 +422,14 @@ func (n *Nature) getSlow(name string) *Nature { return m } - t := deref.Type(n.Type) - switch t.Kind() { - case reflect.Struct: - if f, ok := fetchField(t, name); ok { - nt := n.Cache.FromType(f.Type) - if nt.Optional == nil { - nt.Optional = new(Optional) - } - nt.FieldIndex = f.Index + t, k := derefTypeKind(n.Type, n.Kind) + if k == reflect.Struct { + nt, ok := n.Cache.fetchField(nil, t, name) + if ok { return &nt } - case reflect.Map: - if n.Optional != nil { - if f, ok := n.Fields[name]; ok { - return &f - } - } } + return nil } diff --git a/checker/nature/utils.go b/checker/nature/utils.go index 782992a0..830fab73 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -6,48 +6,84 @@ import ( "github.com/expr-lang/expr/internal/deref" ) -func fieldName(field reflect.StructField) (string, bool) { - switch taggedName := field.Tag.Get("expr"); taggedName { +func derefTypeKind(t reflect.Type, k reflect.Kind) (reflect.Type, reflect.Kind) { + for k == reflect.Pointer { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { + switch taggedName := tag.Get("expr"); taggedName { case "-": return "", false case "": - return field.Name, true + return fieldName, true default: return taggedName, true } } -func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { - // If t is not a struct, early return. - if t.Kind() != reflect.Struct { - return reflect.StructField{}, false +func (c *Cache) fetchField(info map[string]Nature, t reflect.Type, name string) (Nature, bool) { + numField := t.NumField() + switch { + case info != nil: + case c.structs == nil: + c.structs = map[reflect.Type]map[string]Nature{} + fallthrough + case c.structs[t] == nil: + info = make(map[string]Nature, numField) + c.structs[t] = info + default: + info = c.structs[t] } - // First check all structs fields. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - // Search all fields, even embedded structs. - if n, ok := fieldName(field); ok && n == name { - return field, true + // Lookup own fields first. Cache all that is possible + for i := 0; i < numField; i++ { + sf := t.Field(i) + // BUG: we should skip if !sf.IsExported() + fName, ok := fieldName(sf.Name, sf.Tag) + if !ok || fName == "" { + // name can still be empty for a type created at runtime with + // reflect + continue + } + nt := c.FromType(sf.Type) + if nt.Optional == nil { + nt.Optional = new(Optional) + } + nt.FieldIndex = sf.Index + if _, ok := info[fName]; !ok { + // avoid overwriting fields that could potentially be own fields of + // a parent struct + info[fName] = nt + } + if fName == name { + return nt, true } } - // Second check fields of embedded structs. - for i := 0; i < t.NumField(); i++ { - anon := t.Field(i) - if anon.Anonymous { - anonType := anon.Type - if anonType.Kind() == reflect.Pointer { - anonType = anonType.Elem() - } - if field, ok := fetchField(anonType, name); ok { - field.Index = append(anon.Index, field.Index...) - return field, true - } + // Lookup embedded fields + for i := 0; i < numField; i++ { + sf := t.Field(i) + // we do enter embedded non-exported types because they could contain + // exported fields + if !sf.Anonymous { + continue + } + t, k := derefTypeKind(sf.Type, sf.Type.Kind()) + if k != reflect.Struct { + continue + } + nt, ok := c.fetchField(info, t, name) + if ok { + nt.FieldIndex = append(sf.Index, nt.FieldIndex...) + return nt, true } } - return reflect.StructField{}, false + return c.FromType(nil), false } func StructFields(c *Cache, t reflect.Type) map[string]Nature { @@ -76,7 +112,7 @@ func StructFields(c *Cache, t reflect.Type) map[string]Nature { } } - name, ok := fieldName(f) + name, ok := fieldName(f.Name, f.Tag) if !ok { continue } From ab994ae928ccef532c9580b4696d6bb6cd4aad1e Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sat, 30 Aug 2025 06:13:39 -0300 Subject: [PATCH 14/24] make cache private to Nature --- checker/checker.go | 2 +- checker/info.go | 2 +- checker/nature/nature.go | 93 +++++++++++++++++++++------------------- compiler/compiler.go | 2 +- conf/config.go | 2 +- conf/env.go | 4 +- 6 files changed, 55 insertions(+), 50 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 76cc4d0e..fd87a922 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -648,7 +648,7 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { // with new correct function return type. if typ := node.Type(); typ != nil && typ != anyType { nt := node.Nature() - nt.Cache = &v.config.NtCache // AST doesn't cache nature info + nt.SetCache(&v.config.NtCache) // AST doesn't cache nature info return nt } diff --git a/checker/info.go b/checker/info.go index 8bd38b3a..77a21536 100644 --- a/checker/info.go +++ b/checker/info.go @@ -19,7 +19,7 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { } case *ast.MemberNode: base := n.Node.Nature() - base.Cache = c // AST doesn't cache nature info + base.SetCache(c) // AST doesn't cache nature info base = base.Deref() if base.Kind == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { diff --git a/checker/nature/nature.go b/checker/nature/nature.go index f9a057d4..68b03817 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -37,7 +37,7 @@ type Nature struct { Type reflect.Type // Type of the value. If nil, then value is unknown. Kind reflect.Kind // Kind of the value. - *Cache + cache *Cache *Optional Func *builtin.Function // Used to pass function type from callee to CallNode. @@ -105,7 +105,7 @@ func (c *Cache) NatureOf(i any) Nature { // reflect.TypeOf(nil) returns nil, but in FromType we want to differentiate // what nil means for us if i == nil { - return Nature{Cache: c, Nil: true} + return Nature{cache: c, Nil: true} } return c.FromType(reflect.TypeOf(i)) } @@ -116,12 +116,13 @@ func (c *Cache) FromType(t reflect.Type) Nature { if t != nil { k := t.Kind() var opt *Optional - if k == reflect.Func { + switch k { + case reflect.Func, reflect.Struct: opt = new(Optional) } - return Nature{Type: t, Kind: k, Optional: opt, Cache: c} + return Nature{Type: t, Kind: k, Optional: opt, cache: c} } - return Nature{Cache: c} + return Nature{cache: c} } // NatureOf calls NatureOf on a nil *Cache. See the comment on Cache. @@ -143,6 +144,10 @@ func ArrayFromType(c *Cache, t reflect.Type) Nature { return nt } +func (n *Nature) SetCache(c *Cache) { + n.cache = c +} + func (n *Nature) IsAny() bool { return n.Type != nil && n.Kind == reflect.Interface && n.NumMethods() == 0 } @@ -169,27 +174,27 @@ func (n *Nature) Deref() Nature { func (n *Nature) Key() Nature { if n.Kind == reflect.Map { - return n.Cache.FromType(n.Type.Key()) + return n.cache.FromType(n.Type.Key()) } - return n.Cache.FromType(nil) + return n.cache.FromType(nil) } func (n *Nature) Elem() Nature { switch n.Kind { case reflect.Ptr: - return n.Cache.FromType(n.Type.Elem()) + return n.cache.FromType(n.Type.Elem()) case reflect.Map: if n.Optional != nil && n.DefaultMapValue != nil { return *n.DefaultMapValue } - return n.Cache.FromType(n.Type.Elem()) + return n.cache.FromType(n.Type.Elem()) case reflect.Slice, reflect.Array: if n.Ref != nil { return *n.Ref } - return n.Cache.FromType(n.Type.Elem()) + return n.cache.FromType(n.Type.Elem()) } - return n.Cache.FromType(nil) + return n.cache.FromType(nil) } func (n *Nature) AssignableTo(nt Nature) bool { @@ -216,21 +221,21 @@ func (n *Nature) MethodByName(name string) (Nature, bool) { if ntPtr := n.methodByNamePtr(name); ntPtr != nil { return *ntPtr, true } - return n.Cache.FromType(nil), false + return n.cache.FromType(nil), false } func (n *Nature) methodByNamePtr(name string) *Nature { return n.methodByNameSlow(name) var ntPtr *Nature var cacheHit bool - if n.Cache.methodByName == nil { - n.Cache.methodByName = map[rTypeWithKey]*Nature{} + if n.cache.methodByName == nil { + n.cache.methodByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = n.Cache.methodByName[rTypeWithKey{n.Type, name}] + ntPtr, cacheHit = n.cache.methodByName[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.methodByNameSlow(name) - n.Cache.methodByName[rTypeWithKey{n.Type, name}] = ntPtr + n.cache.methodByName[rTypeWithKey{n.Type, name}] = ntPtr } return ntPtr } @@ -244,7 +249,7 @@ func (n *Nature) methodByNameSlow(name string) *Nature { return nil } - nt := n.Cache.FromType(method.Type) + nt := n.cache.FromType(method.Type) if n.Kind == reflect.Interface { // In case of interface type method will not have a receiver, // and to prevent checker decreasing numbers of in arguments @@ -273,10 +278,10 @@ func (n *Nature) NumIn() int { func (n *Nature) InElem(i int) Nature { if n.inElem == nil { if n.Type == nil { - n2 := n.Cache.FromType(nil) + n2 := n.cache.FromType(nil) n.inElem = &n2 } else { - n2 := n.Cache.FromType(n.Type.In(i)) + n2 := n.cache.FromType(n.Type.In(i)) n2 = n2.Elem() n.inElem = &n2 } @@ -286,14 +291,14 @@ func (n *Nature) InElem(i int) Nature { func (n *Nature) In(i int) Nature { if n.Type == nil { - return n.Cache.FromType(nil) + return n.cache.FromType(nil) } - return n.Cache.FromType(n.Type.In(i)) + return n.cache.FromType(n.Type.In(i)) } func (n *Nature) IsFirstArgUnknown() bool { if n.Type != nil { - n2 := n.Cache.FromType(n.Type.In(0)) + n2 := n.cache.FromType(n.Type.In(0)) return n2.IsUnknown() } return false @@ -320,9 +325,9 @@ func (n *Nature) Out(i int) Nature { func (n *Nature) out(i int) Nature { if n.Type == nil { - return n.Cache.FromType(nil) + return n.cache.FromType(nil) } - return n.Cache.FromType(n.Type.Out(i)) + return n.cache.FromType(n.Type.Out(i)) } func (n *Nature) IsVariadic() bool { @@ -335,24 +340,24 @@ func (n *Nature) IsVariadic() bool { func (n *Nature) FieldByName(name string) (Nature, bool) { t, k := derefTypeKind(n.Type, n.Kind) if k != reflect.Struct { - return n.Cache.FromType(nil), false + return n.cache.FromType(nil), false } var ntPtr *Nature var cacheHit bool - if n.Cache.fieldByName == nil { - n.Cache.fieldByName = map[rTypeWithKey]*Nature{} + if n.cache.fieldByName == nil { + n.cache.fieldByName = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = n.Cache.fieldByName[rTypeWithKey{t, name}] + ntPtr, cacheHit = n.cache.fieldByName[rTypeWithKey{t, name}] } if !cacheHit { ntPtr = n.fieldByNameSlow(name) - n.Cache.fieldByName[rTypeWithKey{t, name}] = ntPtr + n.cache.fieldByName[rTypeWithKey{t, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true } - return n.Cache.FromType(nil), false + return n.cache.FromType(nil), false } func (n *Nature) fieldByNameSlow(name string) *Nature { @@ -361,7 +366,7 @@ func (n *Nature) fieldByNameSlow(name string) *Nature { } t, k := derefTypeKind(n.Type, n.Kind) if k == reflect.Struct { - nt, ok := n.Cache.fetchField(nil, t, name) + nt, ok := n.cache.fetchField(nil, t, name) if ok { return &nt } @@ -393,24 +398,24 @@ func (n *Nature) Get(name string) (Nature, bool) { if f, ok := n.Fields[name]; ok { return f, true } - return n.Cache.FromType(nil), false + return n.cache.FromType(nil), false } var ntPtr *Nature var cacheHit bool - if n.Cache.get == nil { - n.Cache.get = map[rTypeWithKey]*Nature{} + if n.cache.get == nil { + n.cache.get = map[rTypeWithKey]*Nature{} } else { - ntPtr, cacheHit = n.Cache.get[rTypeWithKey{n.Type, name}] + ntPtr, cacheHit = n.cache.get[rTypeWithKey{n.Type, name}] } if !cacheHit { ntPtr = n.getSlow(name) - n.Cache.get[rTypeWithKey{n.Type, name}] = ntPtr + n.cache.get[rTypeWithKey{n.Type, name}] = ntPtr } if ntPtr != nil { return *ntPtr, true } - return n.Cache.FromType(nil), false + return n.cache.FromType(nil), false } func (n *Nature) getSlow(name string) *Nature { @@ -424,7 +429,7 @@ func (n *Nature) getSlow(name string) *Nature { t, k := derefTypeKind(n.Type, n.Kind) if k == reflect.Struct { - nt, ok := n.Cache.fetchField(nil, t, name) + nt, ok := n.cache.fetchField(nil, t, name) if ok { return &nt } @@ -442,7 +447,7 @@ func (n *Nature) All() map[string]Nature { for i := 0; i < n.NumMethods(); i++ { method := n.Type.Method(i) - nt := n.Cache.FromType(method.Type) + nt := n.cache.FromType(method.Type) if nt.Optional == nil { nt.Optional = new(Optional) } @@ -455,7 +460,7 @@ func (n *Nature) All() map[string]Nature { switch t.Kind() { case reflect.Struct: - for name, nt := range StructFields(n.Cache, t) { + for name, nt := range StructFields(n.cache, t) { if _, ok := table[name]; ok { continue } @@ -499,12 +504,12 @@ func (n *Nature) IsFloat() bool { func (n *Nature) PromoteNumericNature(rhs Nature) Nature { if n.IsUnknown() || rhs.IsUnknown() { - return n.Cache.FromType(nil) + return n.cache.FromType(nil) } if n.IsFloat() || rhs.IsFloat() { - return n.Cache.FromType(floatType) + return n.cache.FromType(floatType) } - return n.Cache.FromType(intType) + return n.cache.FromType(intType) } func (n *Nature) IsTime() bool { @@ -590,7 +595,7 @@ func (n *Nature) MaybeCompatible(rhs Nature, cs ...NatureCheck) bool { } func (n *Nature) MakeArrayOf() Nature { - nt := n.Cache.FromType(arrayType) + nt := n.cache.FromType(arrayType) nt.Ref = n return nt } diff --git a/compiler/compiler.go b/compiler/compiler.go index 6829fc00..a4bd5871 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1089,7 +1089,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.compile(arg) argType := arg.Type() argNature := arg.Nature() - argNature.Cache = c.ntCache // AST doesn't cache nature info + argNature.SetCache(c.ntCache) // AST doesn't cache nature info if argType.Kind() == reflect.Ptr || argNature.IsUnknown() { if f.Deref == nil { // By default, builtins expect arguments to be dereferenced. diff --git a/conf/config.go b/conf/config.go index 841fcdfc..d6274389 100644 --- a/conf/config.go +++ b/conf/config.go @@ -50,7 +50,7 @@ func CreateNew() *Config { for _, f := range builtin.Builtins { c.Builtins[f.Name] = f } - c.Env.Cache = &c.NtCache + c.Env.SetCache(&c.NtCache) return c } diff --git a/conf/env.go b/conf/env.go index 63dc5255..43fb19eb 100644 --- a/conf/env.go +++ b/conf/env.go @@ -19,7 +19,7 @@ func Env(c *Cache, env any) Nature { switch env := env.(type) { case types.Map: nt := env.Nature() - nt.Cache = c + nt.SetCache(c) return nt } @@ -51,7 +51,7 @@ func Env(c *Cache, env any) Nature { switch face := face.(type) { case types.Map: nt := face.Nature() - nt.Cache = c + nt.SetCache(c) n.Fields[key.String()] = nt default: From 720d2b6c85c5ca5338cdc1c861983294ee40a602 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sat, 30 Aug 2025 21:38:44 -0300 Subject: [PATCH 15/24] first iteration improving struct reflect caching --- checker/nature/nature.go | 123 ++++++++++++------------- checker/nature/utils.go | 193 ++++++++++++++++++++++++++++++++++++++- compiler/compiler.go | 3 +- 3 files changed, 250 insertions(+), 69 deletions(-) diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 68b03817..04ead0be 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -54,7 +54,8 @@ type Nature struct { type Optional struct { // struct-only data - FieldIndex []int // Index of field in type. + *structData + FieldIndex []int // Index of field in type. TODO: deprecate MethodIndex int // Index of method in type. // map-only data @@ -73,8 +74,9 @@ type Optional struct { type Cache struct { methodByName map[rTypeWithKey]*Nature fieldByName map[rTypeWithKey]*Nature - get map[rTypeWithKey]*Nature - structs map[reflect.Type]map[string]Nature + structs map[reflect.Type]Nature + + xxxStructs map[reflect.Type]map[string]Nature // TODO: deprecate } type rTypeWithKey struct { @@ -82,23 +84,6 @@ type rTypeWithKey struct { key string } -type structInfo struct { - fields map[string]Nature -} - -func newStructInfo(t reflect.Type) *structInfo { - numField := t.NumField() - return &structInfo{ - fields: make(map[string]Nature, numField), - } -} - -func (si *structInfo) set(name string, nt Nature) { - if _, ok := si.fields[name]; !ok { - si.fields[name] = nt - } -} - // NatureOf returns a Nature describing "i". If "i" is nil then it returns a // Nature describing the value "nil". func (c *Cache) NatureOf(i any) Nature { @@ -113,16 +98,47 @@ func (c *Cache) NatureOf(i any) Nature { // FromType returns a Nature describing a value of type "t". If "t" is nil then // it returns a Nature describing an unknown value. func (c *Cache) FromType(t reflect.Type) Nature { - if t != nil { - k := t.Kind() - var opt *Optional - switch k { - case reflect.Func, reflect.Struct: - opt = new(Optional) + if t == nil { + return Nature{cache: c} + } + var opt *Optional + k := t.Kind() + switch k { + case reflect.Struct: + return c.getStruct(t) + case reflect.Func: + opt = new(Optional) + } + return Nature{Type: t, Kind: k, Optional: opt, cache: c} +} + +func (c *Cache) getStruct(t reflect.Type) Nature { + if c != nil { + if c.structs == nil { + c.structs = map[reflect.Type]Nature{} + } else if nt, ok := c.structs[t]; ok { + return nt } - return Nature{Type: t, Kind: k, Optional: opt, cache: c} } - return Nature{cache: c} + numField := t.NumField() + opt := &Optional{ + structData: &structData{ + Cache: c, + rType: t, + fields: make(map[string]structField, numField), + numField: numField, + anonIdx: -1, // do not lookup embedded fields yet + }, + } + nt := Nature{ + Type: t, + Kind: reflect.Struct, + Optional: opt, + } + if c != nil { + nt.SetCache(c) + } + return nt } // NatureOf calls NatureOf on a nil *Cache. See the comment on Cache. @@ -146,6 +162,13 @@ func ArrayFromType(c *Cache, t reflect.Type) Nature { func (n *Nature) SetCache(c *Cache) { n.cache = c + if n.Kind == reflect.Struct { + n.structData.Cache = c + if c.structs == nil { + c.structs = map[reflect.Type]Nature{} + } + c.structs[n.Type] = *n + } } func (n *Nature) IsAny() bool { @@ -221,7 +244,7 @@ func (n *Nature) MethodByName(name string) (Nature, bool) { if ntPtr := n.methodByNamePtr(name); ntPtr != nil { return *ntPtr, true } - return n.cache.FromType(nil), false + return Nature{}, false } func (n *Nature) methodByNamePtr(name string) *Nature { @@ -361,12 +384,9 @@ func (n *Nature) FieldByName(name string) (Nature, bool) { } func (n *Nature) fieldByNameSlow(name string) *Nature { - if n.Type == nil { - return nil - } t, k := derefTypeKind(n.Type, n.Kind) if k == reflect.Struct { - nt, ok := n.cache.fetchField(nil, t, name) + nt, ok := n.cache.getStruct(t).structField(nil, name) if ok { return &nt } @@ -398,44 +418,21 @@ func (n *Nature) Get(name string) (Nature, bool) { if f, ok := n.Fields[name]; ok { return f, true } - return n.cache.FromType(nil), false - } - - var ntPtr *Nature - var cacheHit bool - if n.cache.get == nil { - n.cache.get = map[rTypeWithKey]*Nature{} - } else { - ntPtr, cacheHit = n.cache.get[rTypeWithKey{n.Type, name}] - } - if !cacheHit { - ntPtr = n.getSlow(name) - n.cache.get[rTypeWithKey{n.Type, name}] = ntPtr - } - if ntPtr != nil { - return *ntPtr, true + return Nature{}, false } - return n.cache.FromType(nil), false -} -func (n *Nature) getSlow(name string) *Nature { - if n.Type == nil { - return nil + if nt, ok := n.MethodByName(name); ok { + return nt, true } - if m := n.methodByNamePtr(name); m != nil { - return m - } - - t, k := derefTypeKind(n.Type, n.Kind) - if k == reflect.Struct { - nt, ok := n.cache.fetchField(nil, t, name) + if n.Kind == reflect.Struct { + nt, ok := n.structField(nil, name) if ok { - return &nt + return nt, true } } - return nil + return Nature{}, false } func (n *Nature) All() map[string]Nature { diff --git a/checker/nature/utils.go b/checker/nature/utils.go index 830fab73..a55df11e 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -1,6 +1,7 @@ package nature import ( + "fmt" "reflect" "github.com/expr-lang/expr/internal/deref" @@ -25,18 +26,200 @@ func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { } } +type structData struct { + *Cache + rType reflect.Type + fields map[string]structField + numField, ownIdx, anonIdx int + + curParent, curChild *structData + curChildIndex []int +} + +type structField struct { + Nature + Index []int +} + +func (s *structData) finished() bool { + return s.ownIdx >= s.numField && // no own fields left to visit + s.anonIdx >= s.numField && // no embedded fields to visit + s.curChild == nil // no child in process of visiting +} + +func (s *structData) structDebug(prefix string) { // TODO: DEBUG + if len(s.fields) == 0 { + fmt.Printf("%s[%s]\n", prefix, s.rType) + return + } + fmt.Printf("%s[%s] fields:\n", prefix, s.rType) + prefix += " " + for k, v := range s.fields { + fmt.Printf("%s%s at index %v:", prefix, k, v.Index) + if v.Nature.Kind == reflect.Struct { + v.Nature.structDebug(prefix) + } else { + fmt.Printf("%s[%s]\n", prefix, v.Nature.Type) + } + } +} + +func (s *structData) structField(parentEmbed *structData, name string) (Nature, bool) { + if f, ok := s.fields[name]; ok { + return f.Nature, true + } + if s.finished() { + return Nature{}, false + } + + // Lookup own fields first. + for ; s.ownIdx < s.numField; s.ownIdx++ { + sf := s.rType.Field(s.ownIdx) + // BUG: we should skip if !sf.IsExported() here + + if sf.Anonymous && s.anonIdx < 0 { + // start iterating anon fields on the first instead of zero + s.anonIdx = s.ownIdx + } + fName, ok := fieldName(sf.Name, sf.Tag) + if !ok || fName == "" { + // name can still be empty for a type created at runtime with + // reflect + continue + } + nt := s.FromType(sf.Type) + + // TODO BEGIN: deprecate + opt := new(Optional) + if nt.Optional != nil { + *opt = *nt.Optional + } + nt.Optional = opt + nt.FieldIndex = sf.Index + // TODO END: deprecate + + s.fields[fName] = structField{ + Nature: nt, + Index: sf.Index, + } + if parentEmbed != nil { + parentEmbed.trySet(fName, nt, sf.Index) + } + if fName == name { + return nt, true + } + } + + if s.curChild != nil { + nt, ok := s.findInEmbedded(parentEmbed, s.curChild, s.curChildIndex, name) + if ok { + return nt, true + } + } + + // Lookup embedded fields through anon own fields + for ; s.anonIdx >= 0 && s.anonIdx < s.numField; s.anonIdx++ { + sf := s.rType.Field(s.anonIdx) + // we do enter embedded non-exported types because they could contain + // exported fields + if !sf.Anonymous { + continue + } + t, k := derefTypeKind(sf.Type, sf.Type.Kind()) + if k != reflect.Struct { + continue + } + + childEmbed := s.Cache.getStruct(t).structData + nt, ok := s.findInEmbedded(parentEmbed, childEmbed, sf.Index, name) + if ok { + return nt, true + } + } + + return Nature{}, false +} + +func (s *structData) findInEmbedded( + parentEmbed, childEmbed *structData, + childIndex []int, + name string, +) (Nature, bool) { + // Set current parent/child data. This allows trySet to handle child fields + // and add them to our struct and to the parent as well if needed + s.curParent = parentEmbed + s.curChild = childEmbed + s.curChildIndex = childIndex + defer func() { + // Ensure to cleanup references + s.curParent = nil + if childEmbed.finished() { + // If the child can still have more fields to explore then keep it + // referened to look it up again if we need to + s.curChild = nil + s.curChildIndex = nil + } + }() + + // See if the child has already cached its fields. This is still important + // to check even if it's the s.unfinishedEmbedded because it may have + // explored new fields since the last time we visited it + for name, sf := range childEmbed.fields { + s.trySet(name, sf.Nature, sf.Index) + } + + // Recheck if we have what we needed from the above sync + if sf, ok := s.fields[name]; ok { + return sf.Nature, true + } + + // Try finding in the child again in case it hasn't finished + if !childEmbed.finished() { + if _, ok := childEmbed.structField(s, name); ok { + return s.fields[name].Nature, true + } + } + + return Nature{}, false +} + +func (s *structData) trySet(name string, nt Nature, idx []int) { + if _, ok := s.fields[name]; ok { + return + } + idx = append(s.curChildIndex, idx...) + + // TODO BEGIN: deprecate + opt := new(Optional) + if nt.Optional != nil { + *opt = *nt.Optional + } + nt.Optional = opt + nt.FieldIndex = idx + // TODO END: deprecate + + s.fields[name] = structField{ + Nature: nt, + Index: idx, + } + if s.curParent != nil { + s.curParent.trySet(name, nt, idx) + } +} + +// TODO: deprecate func (c *Cache) fetchField(info map[string]Nature, t reflect.Type, name string) (Nature, bool) { numField := t.NumField() switch { case info != nil: - case c.structs == nil: - c.structs = map[reflect.Type]map[string]Nature{} + case c.xxxStructs == nil: + c.xxxStructs = map[reflect.Type]map[string]Nature{} fallthrough - case c.structs[t] == nil: + case c.xxxStructs[t] == nil: info = make(map[string]Nature, numField) - c.structs[t] = info + c.xxxStructs[t] = info default: - info = c.structs[t] + info = c.xxxStructs[t] } // Lookup own fields first. Cache all that is possible diff --git a/compiler/compiler.go b/compiler/compiler.go index a4bd5871..9d3992ff 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -5,6 +5,7 @@ import ( "math" "reflect" "regexp" + "runtime/debug" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" @@ -25,7 +26,7 @@ const ( func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err error) { defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v", r) + err = fmt.Errorf("%v\n%s", r, debug.Stack()) } }() From 35e3c1e18ee2b22ac9a69d283cb1872d94a309f6 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sat, 30 Aug 2025 22:35:51 -0300 Subject: [PATCH 16/24] cleaned up struct fields data --- checker/info.go | 13 ++-- checker/nature/nature.go | 103 +++++++++++---------------- checker/nature/utils.go | 150 +++++++++------------------------------ 3 files changed, 78 insertions(+), 188 deletions(-) diff --git a/checker/info.go b/checker/info.go index 77a21536..acbf0e04 100644 --- a/checker/info.go +++ b/checker/info.go @@ -11,11 +11,8 @@ import ( func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if env.Kind == reflect.Struct { - field, ok := env.Get(n.Value) - if ok && field.Optional != nil && len(field.FieldIndex) > 0 { - return true, field.FieldIndex, n.Value - } + if idx, ok := env.FieldIndex_(n.Value); ok { + return true, idx, n.Value } case *ast.MemberNode: base := n.Node.Nature() @@ -23,10 +20,8 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { base = base.Deref() if base.Kind == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { - name := prop.Value - field, ok := base.FieldByName(name) - if ok && field.Optional != nil { - return true, field.FieldIndex, name + if idx, ok := base.FieldIndex_(prop.Value); ok { + return true, idx, prop.Value } } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 04ead0be..f905c408 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -55,8 +55,7 @@ type Nature struct { type Optional struct { // struct-only data *structData - FieldIndex []int // Index of field in type. TODO: deprecate - MethodIndex int // Index of method in type. + MethodIndex int // Index of method in type. // map-only data Fields map[string]Nature // Fields of map type. @@ -73,10 +72,7 @@ type Optional struct { // from one of those packages, the cache must be set immediately. type Cache struct { methodByName map[rTypeWithKey]*Nature - fieldByName map[rTypeWithKey]*Nature structs map[reflect.Type]Nature - - xxxStructs map[reflect.Type]map[string]Nature // TODO: deprecate } type rTypeWithKey struct { @@ -121,19 +117,18 @@ func (c *Cache) getStruct(t reflect.Type) Nature { } } numField := t.NumField() - opt := &Optional{ - structData: &structData{ - Cache: c, - rType: t, - fields: make(map[string]structField, numField), - numField: numField, - anonIdx: -1, // do not lookup embedded fields yet - }, - } nt := Nature{ - Type: t, - Kind: reflect.Struct, - Optional: opt, + Type: t, + Kind: reflect.Struct, + Optional: &Optional{ + structData: &structData{ + Cache: c, + rType: t, + fields: make(map[string]structField, numField), + numField: numField, + anonIdx: -1, // do not lookup embedded fields yet + }, + }, } if c != nil { nt.SetCache(c) @@ -187,12 +182,11 @@ func (n *Nature) String() string { } func (n *Nature) Deref() Nature { - ret := *n - if ret.Type != nil { - ret.Type = deref.Type(ret.Type) - ret.Kind = ret.Type.Kind() + t, _, changed := derefTypeKind(n.Type, n.Kind) + if !changed { + return *n } - return ret + return n.cache.FromType(t) } func (n *Nature) Key() Nature { @@ -361,37 +355,19 @@ func (n *Nature) IsVariadic() bool { } func (n *Nature) FieldByName(name string) (Nature, bool) { - t, k := derefTypeKind(n.Type, n.Kind) - if k != reflect.Struct { - return n.cache.FromType(nil), false + if n.Kind != reflect.Struct { + return Nature{}, false } - - var ntPtr *Nature - var cacheHit bool - if n.cache.fieldByName == nil { - n.cache.fieldByName = map[rTypeWithKey]*Nature{} + var sd *structData + if n.Optional != nil && n.structData != nil { + sd = n.structData } else { - ntPtr, cacheHit = n.cache.fieldByName[rTypeWithKey{t, name}] - } - if !cacheHit { - ntPtr = n.fieldByNameSlow(name) - n.cache.fieldByName[rTypeWithKey{t, name}] = ntPtr + sd = n.cache.getStruct(n.Type).structData } - if ntPtr != nil { - return *ntPtr, true - } - return n.cache.FromType(nil), false -} - -func (n *Nature) fieldByNameSlow(name string) *Nature { - t, k := derefTypeKind(n.Type, n.Kind) - if k == reflect.Struct { - nt, ok := n.cache.getStruct(t).structField(nil, name) - if ok { - return &nt - } + if sf, ok := sd.structField(nil, name); ok { + return sf.Nature, true } - return nil + return Nature{}, false } func (n *Nature) PkgPath() string { @@ -402,15 +378,10 @@ func (n *Nature) PkgPath() string { } func (n *Nature) IsFastMap() bool { - if n.Type == nil { - return false - } - if n.Type.Kind() == reflect.Map && + return n.Type != nil && + n.Type.Kind() == reflect.Map && n.Type.Key().Kind() == reflect.String && - n.Type.Elem().Kind() == reflect.Interface { - return true - } - return false + n.Type.Elem().Kind() == reflect.Interface } func (n *Nature) Get(name string) (Nature, bool) { @@ -420,21 +391,27 @@ func (n *Nature) Get(name string) (Nature, bool) { } return Nature{}, false } - if nt, ok := n.MethodByName(name); ok { return nt, true } - if n.Kind == reflect.Struct { - nt, ok := n.structField(nil, name) - if ok { - return nt, true + if sf, ok := n.structField(nil, name); ok { + return sf.Nature, true } } - return Nature{}, false } +func (n *Nature) FieldIndex_(name string) ([]int, bool) { + if n.Kind != reflect.Struct { + return nil, false + } + if sf, ok := n.structField(nil, name); ok { + return sf.Index, true + } + return nil, false +} + func (n *Nature) All() map[string]Nature { table := make(map[string]Nature) diff --git a/checker/nature/utils.go b/checker/nature/utils.go index a55df11e..65465a88 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -7,12 +7,13 @@ import ( "github.com/expr-lang/expr/internal/deref" ) -func derefTypeKind(t reflect.Type, k reflect.Kind) (reflect.Type, reflect.Kind) { +func derefTypeKind(t reflect.Type, k reflect.Kind) (_ reflect.Type, _ reflect.Kind, changed bool) { for k == reflect.Pointer { + changed = true t = t.Elem() k = t.Kind() } - return t, k + return t, k, changed } func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { @@ -64,87 +65,78 @@ func (s *structData) structDebug(prefix string) { // TODO: DEBUG } } -func (s *structData) structField(parentEmbed *structData, name string) (Nature, bool) { +func (s *structData) structField(parentEmbed *structData, name string) (structField, bool) { if f, ok := s.fields[name]; ok { - return f.Nature, true + return f, true } if s.finished() { - return Nature{}, false + return structField{}, false } // Lookup own fields first. for ; s.ownIdx < s.numField; s.ownIdx++ { - sf := s.rType.Field(s.ownIdx) - // BUG: we should skip if !sf.IsExported() here + field := s.rType.Field(s.ownIdx) + // BUG: we should skip if !field.IsExported() here - if sf.Anonymous && s.anonIdx < 0 { + if field.Anonymous && s.anonIdx < 0 { // start iterating anon fields on the first instead of zero s.anonIdx = s.ownIdx } - fName, ok := fieldName(sf.Name, sf.Tag) + fName, ok := fieldName(field.Name, field.Tag) if !ok || fName == "" { // name can still be empty for a type created at runtime with // reflect continue } - nt := s.FromType(sf.Type) - - // TODO BEGIN: deprecate - opt := new(Optional) - if nt.Optional != nil { - *opt = *nt.Optional - } - nt.Optional = opt - nt.FieldIndex = sf.Index - // TODO END: deprecate - - s.fields[fName] = structField{ + nt := s.FromType(field.Type) + sf := structField{ Nature: nt, - Index: sf.Index, + Index: field.Index, } + s.fields[fName] = sf if parentEmbed != nil { - parentEmbed.trySet(fName, nt, sf.Index) + parentEmbed.trySet(fName, sf) } if fName == name { - return nt, true + return sf, true } } if s.curChild != nil { - nt, ok := s.findInEmbedded(parentEmbed, s.curChild, s.curChildIndex, name) + sf, ok := s.findInEmbedded(parentEmbed, s.curChild, s.curChildIndex, name) if ok { - return nt, true + return sf, true } } // Lookup embedded fields through anon own fields for ; s.anonIdx >= 0 && s.anonIdx < s.numField; s.anonIdx++ { - sf := s.rType.Field(s.anonIdx) + field := s.rType.Field(s.anonIdx) // we do enter embedded non-exported types because they could contain // exported fields - if !sf.Anonymous { + if !field.Anonymous { continue } - t, k := derefTypeKind(sf.Type, sf.Type.Kind()) + t, k, _ := derefTypeKind(field.Type, field.Type.Kind()) if k != reflect.Struct { continue } childEmbed := s.Cache.getStruct(t).structData - nt, ok := s.findInEmbedded(parentEmbed, childEmbed, sf.Index, name) + sf, ok := s.findInEmbedded(parentEmbed, childEmbed, field.Index, name) if ok { - return nt, true + return sf, true } } - return Nature{}, false + return structField{}, false } func (s *structData) findInEmbedded( parentEmbed, childEmbed *structData, childIndex []int, name string, -) (Nature, bool) { +) (structField, bool) { // Set current parent/child data. This allows trySet to handle child fields // and add them to our struct and to the parent as well if needed s.curParent = parentEmbed @@ -165,110 +157,38 @@ func (s *structData) findInEmbedded( // to check even if it's the s.unfinishedEmbedded because it may have // explored new fields since the last time we visited it for name, sf := range childEmbed.fields { - s.trySet(name, sf.Nature, sf.Index) + s.trySet(name, sf) } // Recheck if we have what we needed from the above sync if sf, ok := s.fields[name]; ok { - return sf.Nature, true + return sf, true } // Try finding in the child again in case it hasn't finished if !childEmbed.finished() { if _, ok := childEmbed.structField(s, name); ok { - return s.fields[name].Nature, true + return s.fields[name], true } } - return Nature{}, false + return structField{}, false } -func (s *structData) trySet(name string, nt Nature, idx []int) { +func (s *structData) trySet(name string, sf structField) { if _, ok := s.fields[name]; ok { return } - idx = append(s.curChildIndex, idx...) - - // TODO BEGIN: deprecate - opt := new(Optional) - if nt.Optional != nil { - *opt = *nt.Optional - } - nt.Optional = opt - nt.FieldIndex = idx - // TODO END: deprecate - + sf.Index = append(s.curChildIndex, sf.Index...) s.fields[name] = structField{ - Nature: nt, - Index: idx, + Nature: sf.Nature, + Index: sf.Index, } if s.curParent != nil { - s.curParent.trySet(name, nt, idx) + s.curParent.trySet(name, sf) } } -// TODO: deprecate -func (c *Cache) fetchField(info map[string]Nature, t reflect.Type, name string) (Nature, bool) { - numField := t.NumField() - switch { - case info != nil: - case c.xxxStructs == nil: - c.xxxStructs = map[reflect.Type]map[string]Nature{} - fallthrough - case c.xxxStructs[t] == nil: - info = make(map[string]Nature, numField) - c.xxxStructs[t] = info - default: - info = c.xxxStructs[t] - } - - // Lookup own fields first. Cache all that is possible - for i := 0; i < numField; i++ { - sf := t.Field(i) - // BUG: we should skip if !sf.IsExported() - fName, ok := fieldName(sf.Name, sf.Tag) - if !ok || fName == "" { - // name can still be empty for a type created at runtime with - // reflect - continue - } - nt := c.FromType(sf.Type) - if nt.Optional == nil { - nt.Optional = new(Optional) - } - nt.FieldIndex = sf.Index - if _, ok := info[fName]; !ok { - // avoid overwriting fields that could potentially be own fields of - // a parent struct - info[fName] = nt - } - if fName == name { - return nt, true - } - } - - // Lookup embedded fields - for i := 0; i < numField; i++ { - sf := t.Field(i) - // we do enter embedded non-exported types because they could contain - // exported fields - if !sf.Anonymous { - continue - } - t, k := derefTypeKind(sf.Type, sf.Type.Kind()) - if k != reflect.Struct { - continue - } - nt, ok := c.fetchField(info, t, name) - if ok { - nt.FieldIndex = append(sf.Index, nt.FieldIndex...) - return nt, true - } - } - - return c.FromType(nil), false -} - func StructFields(c *Cache, t reflect.Type) map[string]Nature { table := make(map[string]Nature) @@ -290,7 +210,6 @@ func StructFields(c *Cache, t reflect.Type) map[string]Nature { if typ.Optional == nil { typ.Optional = new(Optional) } - typ.FieldIndex = append(f.Index, typ.FieldIndex...) table[name] = typ } } @@ -303,7 +222,6 @@ func StructFields(c *Cache, t reflect.Type) map[string]Nature { if nt.Optional == nil { nt.Optional = new(Optional) } - nt.FieldIndex = f.Index table[name] = nt } From e68fbc783fc0fd7737958f6d3c5dffc5da4c19d3 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 03:25:26 -0300 Subject: [PATCH 17/24] improve caching of functions data --- checker/checker.go | 68 ++++++------ checker/info.go | 6 +- checker/nature/nature.go | 227 ++++++++++++++++++++++----------------- checker/nature/utils.go | 95 ++++++++++------ types/types.go | 3 - 5 files changed, 229 insertions(+), 170 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index fd87a922..f053f36e 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -240,7 +240,7 @@ func (v *Checker) error(node ast.Node, format string, args ...any) Nature { Message: fmt.Sprintf(format, args...), } } - return v.config.NtCache.FromType(nil) + return Nature{} } func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { @@ -248,7 +248,7 @@ func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { return variable.nature } if node.Value == "$env" { - return v.config.NtCache.FromType(nil) + return Nature{} } return v.ident(node, node.Value, v.config.Strict, true) @@ -262,19 +262,25 @@ func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Natur if builtins { if fn, ok := v.config.Functions[name]; ok { nt := v.config.NtCache.FromType(fn.Type()) - nt.Func = fn + if nt.FuncData == nil { + nt.FuncData = new(FuncData) + } + nt.FuncData.Func = fn return nt } if fn, ok := v.config.Builtins[name]; ok { nt := v.config.NtCache.FromType(fn.Type()) - nt.Func = fn + if nt.FuncData == nil { + nt.FuncData = new(FuncData) + } + nt.FuncData.Func = fn return nt } } if v.config.Strict && strict { return v.error(node, "unknown name %s", name) } - return v.config.NtCache.FromType(nil) + return Nature{} } func (v *Checker) integerNode(*ast.IntegerNode) Nature { @@ -316,7 +322,7 @@ func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { return nt } if nt.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } default: @@ -378,7 +384,7 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return v.config.NtCache.FromType(durationType) } if l.MaybeCompatible(r, NumberCheck, TimeCheck, DurationCheck) { - return v.config.NtCache.FromType(nil) + return Nature{} } case "*": @@ -395,7 +401,7 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return v.config.NtCache.FromType(durationType) } if l.MaybeCompatible(r, NumberCheck, DurationCheck) { - return v.config.NtCache.FromType(nil) + return Nature{} } case "/": @@ -439,7 +445,7 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { return v.config.NtCache.FromType(durationType) } if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { - return v.config.NtCache.FromType(nil) + return Nature{} } case "in": @@ -507,7 +513,7 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { if r.AssignableTo(l) { return l } - return v.config.NtCache.FromType(nil) + return Nature{} default: return v.error(node, "unknown operator (%s)", node.Operator) @@ -535,14 +541,14 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { } return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } - return v.config.NtCache.FromType(nil) + return Nature{} } base := v.visit(node.Node) prop := v.visit(node.Property) if base.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } if name, ok := node.Property.(*ast.StringNode); ok { @@ -607,7 +613,7 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { nt := v.visit(node.Node) if nt.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } switch nt.Kind { @@ -654,11 +660,11 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { nt := v.visit(node.Callee) if nt.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } - if nt.Func != nil { - return v.checkFunction(nt.Func, node, node.Arguments) + if nt.FuncData != nil && nt.FuncData.Func != nil { + return v.checkFunction(nt.FuncData.Func, node, node.Arguments) } fnName := "function" @@ -681,7 +687,7 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { if v.err == nil { v.err = err } - return v.config.NtCache.FromType(nil) + return Nature{} } return outType } @@ -805,7 +811,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } } else { if collection.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } return collection.Elem() } @@ -830,7 +836,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } if collection.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } return collection.Elem() } @@ -912,7 +918,7 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, varScope{"index", v.config.NtCache.FromType(intType)}, varScope{"acc", v.config.NtCache.FromType(nil)}) + v.begin(collection, varScope{"index", v.config.NtCache.FromType(intType)}, varScope{"acc", Nature{}}) predicate := v.visit(node.Arguments[1]) v.end() @@ -963,11 +969,11 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { return nt } } - return v.config.NtCache.FromType(nil) + return Nature{} } if base.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } switch base.Kind { @@ -1007,7 +1013,7 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if v.err == nil { v.err = err } - return v.config.NtCache.FromType(nil) + return Nature{} } // No type was specified, so we assume the function returns any. return nt @@ -1033,7 +1039,7 @@ func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if v.err == nil { v.err = lastErr } - return v.config.NtCache.FromType(nil) + return Nature{} } return v.error(node, "no matching overload for %v", f.Name) @@ -1046,18 +1052,18 @@ func (v *Checker) checkArguments( node ast.Node, ) (Nature, *file.Error) { if fn.IsUnknown() { - return v.config.NtCache.FromType(nil), nil + return Nature{}, nil } numOut := fn.NumOut() if numOut == 0 { - return v.config.NtCache.FromType(nil), &file.Error{ + return Nature{}, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } if numOut > 2 { - return v.config.NtCache.FromType(nil), &file.Error{ + return Nature{}, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), } @@ -1134,7 +1140,7 @@ func (v *Checker) checkArguments( if in.Kind == reflect.Ptr || in.Kind == reflect.Interface { continue } - return v.config.NtCache.FromType(nil), &file.Error{ + return Nature{}, &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in.String(), name), } @@ -1154,7 +1160,7 @@ func (v *Checker) checkArguments( } if !assignable && !argNature.IsUnknown() { - return v.config.NtCache.FromType(nil), &file.Error{ + return Nature{}, &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature.String(), in.String(), name), } @@ -1221,7 +1227,7 @@ func (v *Checker) pointerNode(node *ast.PointerNode) Nature { scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { if scope.collection.IsUnknown() { - return v.config.NtCache.FromType(nil) + return Nature{} } switch scope.collection.Kind { case reflect.Array, reflect.Slice: @@ -1300,7 +1306,7 @@ func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { if t1.AssignableTo(t2) { return t1 } - return v.config.NtCache.FromType(nil) + return Nature{} } func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { diff --git a/checker/info.go b/checker/info.go index acbf0e04..74b9bd38 100644 --- a/checker/info.go +++ b/checker/info.go @@ -11,7 +11,7 @@ import ( func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if idx, ok := env.FieldIndex_(n.Value); ok { + if idx, ok := env.FieldIndex(n.Value); ok { return true, idx, n.Value } case *ast.MemberNode: @@ -20,7 +20,7 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { base = base.Deref() if base.Kind == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { - if idx, ok := base.FieldIndex_(prop.Value); ok { + if idx, ok := base.FieldIndex(prop.Value); ok { return true, idx, prop.Value } } @@ -33,7 +33,7 @@ func MethodIndex(c *Cache, env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: if env.Kind == reflect.Struct { - if m, ok := env.Get(n.Value); ok && m.Optional != nil { + if m, ok := env.Get(n.Value); ok && m.FuncData != nil { return m.Method, m.MethodIndex, n.Value } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index f905c408..d385bc2d 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -39,7 +39,7 @@ type Nature struct { cache *Cache *Optional - Func *builtin.Function // Used to pass function type from callee to CallNode. + *FuncData // Ref is a reference used for multiple, disjoint purposes. When the Nature // is for a: @@ -53,16 +53,29 @@ type Nature struct { } type Optional struct { - // struct-only data + pkgPath string + methodset *methodset // optional to avoid the map in *Cache + *structData - MethodIndex int // Index of method in type. // map-only data Fields map[string]Nature // Fields of map type. DefaultMapValue *Nature // Default value of map type. - // func-only data + pkgPathSet bool +} + +type FuncData struct { + Func *builtin.Function // Used to pass function type from callee to CallNode. + MethodIndex int // Index of method in type. + inElem, outZero *Nature + numIn, numOut int + + isVariadic bool + isVariadicSet bool + numInSet bool + numOutSet bool } // Cache is a shared cache of type information. It is only used in the stages @@ -71,8 +84,8 @@ type Optional struct { // from the Nature type, they only describe. However, when receiving a Nature // from one of those packages, the cache must be set immediately. type Cache struct { - methodByName map[rTypeWithKey]*Nature - structs map[reflect.Type]Nature + methods map[reflect.Type]*methodset + structs map[reflect.Type]Nature } type rTypeWithKey struct { @@ -95,17 +108,17 @@ func (c *Cache) NatureOf(i any) Nature { // it returns a Nature describing an unknown value. func (c *Cache) FromType(t reflect.Type) Nature { if t == nil { - return Nature{cache: c} + return Nature{} } - var opt *Optional + var fd *FuncData k := t.Kind() switch k { case reflect.Struct: return c.getStruct(t) case reflect.Func: - opt = new(Optional) + fd = new(FuncData) } - return Nature{Type: t, Kind: k, Optional: opt, cache: c} + return Nature{Type: t, Kind: k, FuncData: fd, cache: c} } func (c *Cache) getStruct(t reflect.Type) Nature { @@ -116,16 +129,14 @@ func (c *Cache) getStruct(t reflect.Type) Nature { return nt } } - numField := t.NumField() nt := Nature{ Type: t, Kind: reflect.Struct, Optional: &Optional{ structData: &structData{ - Cache: c, + cache: c, rType: t, - fields: make(map[string]structField, numField), - numField: numField, + numField: t.NumField(), anonIdx: -1, // do not lookup embedded fields yet }, }, @@ -136,6 +147,32 @@ func (c *Cache) getStruct(t reflect.Type) Nature { return nt } +func (c *Cache) getMethodset(t reflect.Type, k reflect.Kind) *methodset { + if t == nil || c == nil { + return nil + } + if c.methods == nil { + c.methods = map[reflect.Type]*methodset{ + t: nil, + } + } else if s, ok := c.methods[t]; ok { + return s + } + numMethod := t.NumMethod() + if numMethod < 1 { + c.methods[t] = nil // negative cache + return nil + } + s := &methodset{ + cache: c, + rType: t, + kind: k, + numMethod: numMethod, + } + c.methods[t] = s + return s +} + // NatureOf calls NatureOf on a nil *Cache. See the comment on Cache. func NatureOf(i any) Nature { var c *Cache @@ -158,11 +195,25 @@ func ArrayFromType(c *Cache, t reflect.Type) Nature { func (n *Nature) SetCache(c *Cache) { n.cache = c if n.Kind == reflect.Struct { - n.structData.Cache = c + n.structData.cache = c if c.structs == nil { - c.structs = map[reflect.Type]Nature{} + c.structs = map[reflect.Type]Nature{ + n.Type: *n, + } + } else if nt, ok := c.structs[n.Type]; ok { + // invalidate local, use shared from cache + n.Optional.structData = nt.Optional.structData + } else { + c.structs[n.Type] = *n + } + } + if n.Optional != nil { + if s, ok := c.methods[n.Type]; ok { + // invalidate local if set, use shared from cache + n.Optional.methodset = s + } else if n.Optional.methodset != nil { + c.methods[n.Type] = n.Optional.methodset } - c.structs[n.Type] = *n } } @@ -193,7 +244,7 @@ func (n *Nature) Key() Nature { if n.Kind == reflect.Map { return n.cache.FromType(n.Type.Key()) } - return n.cache.FromType(nil) + return Nature{} } func (n *Nature) Elem() Nature { @@ -211,105 +262,69 @@ func (n *Nature) Elem() Nature { } return n.cache.FromType(n.Type.Elem()) } - return n.cache.FromType(nil) + return Nature{} } func (n *Nature) AssignableTo(nt Nature) bool { if n.Nil { - // Untyped nil is assignable to any interface, but implements only the empty interface. - if nt.IsAny() { + switch nt.Kind { + case reflect.Pointer, reflect.Interface: return true } } - if n.Type == nil || nt.Type == nil { + if n.Type == nil || nt.Type == nil || + n.Kind != nt.Kind && nt.Kind != reflect.Interface { return false } return n.Type.AssignableTo(nt.Type) } -func (n *Nature) NumMethods() int { - if n.Type == nil { - return 0 +func (n *Nature) getMethodset() *methodset { + if n.Optional != nil && n.Optional.methodset != nil { + return n.Optional.methodset } - return n.Type.NumMethod() -} - -func (n *Nature) MethodByName(name string) (Nature, bool) { - if ntPtr := n.methodByNamePtr(name); ntPtr != nil { - return *ntPtr, true + s := n.cache.getMethodset(n.Type, n.Kind) + if n.Optional != nil { + n.Optional.methodset = s // cache locally if possible } - return Nature{}, false + return s } -func (n *Nature) methodByNamePtr(name string) *Nature { - return n.methodByNameSlow(name) - var ntPtr *Nature - var cacheHit bool - if n.cache.methodByName == nil { - n.cache.methodByName = map[rTypeWithKey]*Nature{} - } else { - ntPtr, cacheHit = n.cache.methodByName[rTypeWithKey{n.Type, name}] - } - if !cacheHit { - ntPtr = n.methodByNameSlow(name) - n.cache.methodByName[rTypeWithKey{n.Type, name}] = ntPtr +func (n *Nature) NumMethods() int { + if s := n.getMethodset(); s != nil { + return s.numMethod } - return ntPtr + return 0 } -func (n *Nature) methodByNameSlow(name string) *Nature { - if n.Type == nil { - return nil - } - method, ok := n.Type.MethodByName(name) - if !ok { - return nil - } - - nt := n.cache.FromType(method.Type) - if n.Kind == reflect.Interface { - // In case of interface type method will not have a receiver, - // and to prevent checker decreasing numbers of in arguments - // return method type as not method (second argument is false). - - // Also, we can not use m.Index here, because it will be - // different indexes for different types which implement - // the same interface. - return &nt - } - if nt.Optional == nil { - nt.Optional = new(Optional) +func (n *Nature) MethodByName(name string) (Nature, bool) { + if s := n.getMethodset(); s != nil { + if m, ok := s.method(name); ok { + return m.nature, true + } } - nt.Method = true - nt.MethodIndex = method.Index - return &nt + return Nature{}, false } func (n *Nature) NumIn() int { - if n.Type == nil { - return 0 + if n.numInSet { + return n.numIn } - return n.Type.NumIn() + n.numInSet = true + n.numIn = n.Type.NumIn() + return n.numIn } func (n *Nature) InElem(i int) Nature { if n.inElem == nil { - if n.Type == nil { - n2 := n.cache.FromType(nil) - n.inElem = &n2 - } else { - n2 := n.cache.FromType(n.Type.In(i)) - n2 = n2.Elem() - n.inElem = &n2 - } + n2 := n.cache.FromType(n.Type.In(i)) + n2 = n2.Elem() + n.inElem = &n2 } return *n.inElem } func (n *Nature) In(i int) Nature { - if n.Type == nil { - return n.cache.FromType(nil) - } return n.cache.FromType(n.Type.In(i)) } @@ -322,10 +337,12 @@ func (n *Nature) IsFirstArgUnknown() bool { } func (n *Nature) NumOut() int { - if n.Type == nil { - return 0 + if n.numOutSet { + return n.numOut } - return n.Type.NumOut() + n.numOutSet = true + n.numOut = n.Type.NumOut() + return n.numOut } func (n *Nature) Out(i int) Nature { @@ -342,16 +359,18 @@ func (n *Nature) Out(i int) Nature { func (n *Nature) out(i int) Nature { if n.Type == nil { - return n.cache.FromType(nil) + return Nature{} } return n.cache.FromType(n.Type.Out(i)) } func (n *Nature) IsVariadic() bool { - if n.Type == nil { - return false + if n.isVariadicSet { + return n.isVariadic } - return n.Type.IsVariadic() + n.isVariadicSet = true + n.isVariadic = n.Type.IsVariadic() + return n.isVariadic } func (n *Nature) FieldByName(name string) (Nature, bool) { @@ -374,7 +393,15 @@ func (n *Nature) PkgPath() string { if n.Type == nil { return "" } - return n.Type.PkgPath() + if n.Optional != nil && n.Optional.pkgPathSet { + return n.Optional.pkgPath + } + p := n.Type.PkgPath() + if n.Optional != nil { + n.Optional.pkgPathSet = true + n.Optional.pkgPath = p + } + return p } func (n *Nature) IsFastMap() bool { @@ -386,11 +413,13 @@ func (n *Nature) IsFastMap() bool { func (n *Nature) Get(name string) (Nature, bool) { if n.Kind == reflect.Map && n.Optional != nil { - if f, ok := n.Fields[name]; ok { - return f, true - } - return Nature{}, false + f, ok := n.Fields[name] + return f, ok } + return n.getSlow(name) +} + +func (n *Nature) getSlow(name string) (Nature, bool) { if nt, ok := n.MethodByName(name); ok { return nt, true } @@ -402,7 +431,7 @@ func (n *Nature) Get(name string) (Nature, bool) { return Nature{}, false } -func (n *Nature) FieldIndex_(name string) ([]int, bool) { +func (n *Nature) FieldIndex(name string) ([]int, bool) { if n.Kind != reflect.Struct { return nil, false } @@ -423,7 +452,7 @@ func (n *Nature) All() map[string]Nature { method := n.Type.Method(i) nt := n.cache.FromType(method.Type) if nt.Optional == nil { - nt.Optional = new(Optional) + nt.FuncData = new(FuncData) } nt.Method = true nt.MethodIndex = method.Index @@ -478,7 +507,7 @@ func (n *Nature) IsFloat() bool { func (n *Nature) PromoteNumericNature(rhs Nature) Nature { if n.IsUnknown() || rhs.IsUnknown() { - return n.cache.FromType(nil) + return Nature{} } if n.IsFloat() || rhs.IsFloat() { return n.cache.FromType(floatType) diff --git a/checker/nature/utils.go b/checker/nature/utils.go index 65465a88..27555ba3 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -3,8 +3,6 @@ package nature import ( "fmt" "reflect" - - "github.com/expr-lang/expr/internal/deref" ) func derefTypeKind(t reflect.Type, k reflect.Kind) (_ reflect.Type, _ reflect.Kind, changed bool) { @@ -28,7 +26,7 @@ func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { } type structData struct { - *Cache + cache *Cache rType reflect.Type fields map[string]structField numField, ownIdx, anonIdx int @@ -66,7 +64,11 @@ func (s *structData) structDebug(prefix string) { // TODO: DEBUG } func (s *structData) structField(parentEmbed *structData, name string) (structField, bool) { - if f, ok := s.fields[name]; ok { + if s.fields == nil { + if s.numField > 0 { + s.fields = make(map[string]structField, s.numField) + } + } else if f, ok := s.fields[name]; ok { return f, true } if s.finished() { @@ -88,7 +90,7 @@ func (s *structData) structField(parentEmbed *structData, name string) (structFi // reflect continue } - nt := s.FromType(field.Type) + nt := s.cache.FromType(field.Type) sf := structField{ Nature: nt, Index: field.Index, @@ -122,7 +124,7 @@ func (s *structData) structField(parentEmbed *structData, name string) (structFi continue } - childEmbed := s.Cache.getStruct(t).structData + childEmbed := s.cache.getStruct(t).structData sf, ok := s.findInEmbedded(parentEmbed, childEmbed, field.Index, name) if ok { return sf, true @@ -191,41 +193,66 @@ func (s *structData) trySet(name string, sf structField) { func StructFields(c *Cache, t reflect.Type) map[string]Nature { table := make(map[string]Nature) - - t = deref.Type(t) if t == nil { return table } + t, k, _ := derefTypeKind(t, t.Kind()) + if k == reflect.Struct { + // lookup for a field with an empty name, which will cause to never find a + // match, meaning everything will have been cached. + sd := c.getStruct(t).structData + sd.structField(nil, "") + for name, sf := range sd.fields { + table[name] = sf.Nature + } + } + return table +} - switch t.Kind() { - case reflect.Struct: - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) +type methodset struct { + cache *Cache + rType reflect.Type + kind reflect.Kind + methods map[string]method + numMethod, idx int +} - if f.Anonymous { - for name, typ := range StructFields(c, f.Type) { - if _, ok := table[name]; ok { - continue - } - if typ.Optional == nil { - typ.Optional = new(Optional) - } - table[name] = typ - } - } +type method struct { + reflect.Method + nature Nature +} - name, ok := fieldName(f.Name, f.Tag) - if !ok { - continue - } - nt := c.FromType(f.Type) - if nt.Optional == nil { - nt.Optional = new(Optional) - } - table[name] = nt +func (s *methodset) method(name string) (method, bool) { + if s.methods == nil { + s.methods = make(map[string]method, s.numMethod) + } else if m, ok := s.methods[name]; ok { + return m, true + } + for ; s.idx < s.numMethod; s.idx++ { + rm := s.rType.Method(s.idx) + if !rm.IsExported() { + continue + } + nt := s.cache.FromType(rm.Type) + if s.rType.Kind() != reflect.Interface { + nt.Method = true + nt.MethodIndex = rm.Index + // In case of interface type method will not have a receiver, + // and to prevent checker decreasing numbers of in arguments + // return method type as not method (second argument is false). + // Also, we can not use m.Index here, because it will be + // different indexes for different types which implement + // the same interface. + } + m := method{ + Method: rm, + nature: nt, + } + s.methods[rm.Name] = m + if rm.Name == name { + return m, true } } - - return table + return method{}, false } diff --git a/types/types.go b/types/types.go index 05e0ae70..232f9b9b 100644 --- a/types/types.go +++ b/types/types.go @@ -108,9 +108,6 @@ func (m Map) Nature() Nature { nt.Strict = true for k, v := range m { if k == Extra { - if nt.Optional == nil { - nt.Optional = new(Optional) - } nt.Strict = false natureOfDefaultValue := v.Nature() nt.DefaultMapValue = &natureOfDefaultValue From 3d7e30e85912c18274c399fb41e1e9f6fb794e46 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 03:43:25 -0300 Subject: [PATCH 18/24] simplify checker code --- checker/checker.go | 51 ++++++++++++---------------------------------- 1 file changed, 13 insertions(+), 38 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index f053f36e..cf6359b6 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -187,15 +187,15 @@ func (v *Checker) visit(node ast.Node) Nature { case *ast.IdentifierNode: nt = v.identifierNode(n) case *ast.IntegerNode: - nt = v.integerNode(n) + nt = v.config.NtCache.FromType(intType) case *ast.FloatNode: - nt = v.floatNode(n) + nt = v.config.NtCache.FromType(floatType) case *ast.BoolNode: - nt = v.boolNode(n) + nt = v.config.NtCache.FromType(boolType) case *ast.StringNode: - nt = v.stringNode(n) + nt = v.config.NtCache.FromType(stringType) case *ast.ConstantNode: - nt = v.constantNode(n) + nt = v.config.NtCache.FromType(reflect.TypeOf(n.Value)) case *ast.UnaryNode: nt = v.unaryNode(n) case *ast.BinaryNode: @@ -244,8 +244,10 @@ func (v *Checker) error(node ast.Node, format string, args ...any) Nature { } func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { - if variable, ok := v.lookupVariable(node.Value); ok { - return variable.nature + for i := len(v.varScopes) - 1; i >= 0; i-- { + if v.varScopes[i].name == node.Value { + return v.varScopes[i].nature + } } if node.Value == "$env" { return Nature{} @@ -283,26 +285,6 @@ func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Natur return Nature{} } -func (v *Checker) integerNode(*ast.IntegerNode) Nature { - return v.config.NtCache.FromType(intType) -} - -func (v *Checker) floatNode(*ast.FloatNode) Nature { - return v.config.NtCache.FromType(floatType) -} - -func (v *Checker) boolNode(*ast.BoolNode) Nature { - return v.config.NtCache.FromType(boolType) -} - -func (v *Checker) stringNode(*ast.StringNode) Nature { - return v.config.NtCache.FromType(stringType) -} - -func (v *Checker) constantNode(node *ast.ConstantNode) Nature { - return v.config.NtCache.FromType(reflect.TypeOf(node.Value)) -} - func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { nt := v.visit(node.Node) nt = nt.Deref() @@ -1255,8 +1237,10 @@ func (v *Checker) variableDeclaratorNode(node *ast.VariableDeclaratorNode) Natur if _, ok := v.config.Builtins[node.Name]; ok { return v.error(node, "cannot redeclare builtin %v", node.Name) } - if _, ok := v.lookupVariable(node.Name); ok { - return v.error(node, "cannot redeclare variable %v", node.Name) + for i := len(v.varScopes) - 1; i >= 0; i-- { + if v.varScopes[i].name == node.Name { + return v.error(node, "cannot redeclare variable %v", node.Name) + } } varNature := v.visit(node.Value) v.varScopes = append(v.varScopes, varScope{node.Name, varNature}) @@ -1276,15 +1260,6 @@ func (v *Checker) sequenceNode(node *ast.SequenceNode) Nature { return last } -func (v *Checker) lookupVariable(name string) (varScope, bool) { - for i := len(v.varScopes) - 1; i >= 0; i-- { - if v.varScopes[i].name == name { - return v.varScopes[i], true - } - } - return varScope{}, false -} - func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { c := v.visit(node.Cond) if !c.IsBool() && !c.IsUnknown() { From 163df86abdc9e761ca91d7c94a5539874ca42d4a Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 03:50:49 -0300 Subject: [PATCH 19/24] simplify and remove dead code --- checker/nature/nature.go | 5 ----- checker/nature/utils.go | 18 ------------------ compiler/compiler.go | 5 ++--- parser/parser.go | 5 ++--- 4 files changed, 4 insertions(+), 29 deletions(-) diff --git a/checker/nature/nature.go b/checker/nature/nature.go index d385bc2d..e20c8ea0 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -88,11 +88,6 @@ type Cache struct { structs map[reflect.Type]Nature } -type rTypeWithKey struct { - t reflect.Type - key string -} - // NatureOf returns a Nature describing "i". If "i" is nil then it returns a // Nature describing the value "nil". func (c *Cache) NatureOf(i any) Nature { diff --git a/checker/nature/utils.go b/checker/nature/utils.go index 27555ba3..56076b29 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -1,7 +1,6 @@ package nature import ( - "fmt" "reflect" ) @@ -46,23 +45,6 @@ func (s *structData) finished() bool { s.curChild == nil // no child in process of visiting } -func (s *structData) structDebug(prefix string) { // TODO: DEBUG - if len(s.fields) == 0 { - fmt.Printf("%s[%s]\n", prefix, s.rType) - return - } - fmt.Printf("%s[%s] fields:\n", prefix, s.rType) - prefix += " " - for k, v := range s.fields { - fmt.Printf("%s%s at index %v:", prefix, k, v.Index) - if v.Nature.Kind == reflect.Struct { - v.Nature.structDebug(prefix) - } else { - fmt.Printf("%s[%s]\n", prefix, v.Nature.Type) - } - } -} - func (s *structData) structField(parentEmbed *structData, name string) (structField, bool) { if s.fields == nil { if s.numField > 0 { diff --git a/compiler/compiler.go b/compiler/compiler.go index 9d3992ff..30f6e86f 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -10,7 +10,6 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/checker" - "github.com/expr-lang/expr/checker/nature" . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" @@ -41,7 +40,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro if config != nil { c.ntCache = &c.config.NtCache } else { - c.ntCache = new(nature.Cache) + c.ntCache = new(Cache) } c.compile(tree.Node) @@ -82,7 +81,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro type compiler struct { config *conf.Config - ntCache *nature.Cache + ntCache *Cache locations []file.Location bytecode []Opcode variables int diff --git a/parser/parser.go b/parser/parser.go index e1dd111f..6034403b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -12,7 +12,6 @@ import ( "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" - "github.com/expr-lang/expr/parser/lexer" . "github.com/expr-lang/expr/parser/lexer" "github.com/expr-lang/expr/parser/operator" "github.com/expr-lang/expr/parser/utils" @@ -49,7 +48,7 @@ var predicates = map[string]struct { // Parser is a reusable parser. The zero value is ready for use. type Parser struct { - lexer *lexer.Lexer + lexer *Lexer current, stashed Token hasStash bool err *file.Error @@ -60,7 +59,7 @@ type Parser struct { func (p *Parser) Parse(input string, config *conf.Config) (*Tree, error) { if p.lexer == nil { - p.lexer = lexer.New() + p.lexer = New() } p.config = config source := file.NewSource(input) From 40d2c1453aab44f8458dfb4d03e6cdf02e7947c5 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 03:52:25 -0300 Subject: [PATCH 20/24] add checker benchmarks --- checker/checker_bench_test.go | 82 +++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 checker/checker_bench_test.go diff --git a/checker/checker_bench_test.go b/checker/checker_bench_test.go new file mode 100644 index 00000000..e2e2fca8 --- /dev/null +++ b/checker/checker_bench_test.go @@ -0,0 +1,82 @@ +package checker_test + +import ( + "testing" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/checker" + "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/parser" +) + +func BenchmarkChecker(b *testing.B) { + cases := []struct { + name, input string + }{ + {"function calls", ` +func( + func( + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + ), + func( + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + ), + func( + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + ) +) + `}, + {"unary and binary operations", ` +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 + `}, + {"deep struct access", ` +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a. +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a. +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a. +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a + `}, + } + + f := func(params ...any) (any, error) { return nil, nil } + env := map[string]any{ + "a": new(recursive), + "b": true, + "func": f, + } + config := conf.New(env) + expr.Function("func", f, f) + expr.ConstExpr("func") + + for _, c := range cases { + b.Run("name="+c.name, func(b *testing.B) { + tree, err := parser.ParseWithConfig(c.input, config) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err = checker.Check(tree, config) + if err != nil { + b.Fatal(err) + } + } + }) + } + +} + +type recursive struct { + Inner *recursive `expr:"a"` +} From 247a6a66dd87ac9ffb9f231a12b304f9fd3f06a1 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 14:31:37 -0300 Subject: [PATCH 21/24] address PR comments and minor improvements --- checker/nature/nature.go | 33 +++++++++----- checker/nature/utils.go | 95 +++++++++++++++++++++------------------- conf/config.go | 2 +- conf/env.go | 9 +++- docgen/docgen.go | 2 +- internal/deref/deref.go | 9 ++++ 6 files changed, 93 insertions(+), 57 deletions(-) diff --git a/checker/nature/nature.go b/checker/nature/nature.go index e20c8ea0..84185c36 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -190,8 +190,8 @@ func ArrayFromType(c *Cache, t reflect.Type) Nature { func (n *Nature) SetCache(c *Cache) { n.cache = c if n.Kind == reflect.Struct { - n.structData.cache = c if c.structs == nil { + n.structData.setCache(c) c.structs = map[reflect.Type]Nature{ n.Type: *n, } @@ -199,14 +199,27 @@ func (n *Nature) SetCache(c *Cache) { // invalidate local, use shared from cache n.Optional.structData = nt.Optional.structData } else { + n.structData.setCache(c) c.structs[n.Type] = *n } } - if n.Optional != nil { - if s, ok := c.methods[n.Type]; ok { - // invalidate local if set, use shared from cache + hasMethodset := n.Optional != nil && n.Optional.methodset != nil + if c.methods != nil || hasMethodset { + if c.methods == nil { + // Cache is new and the type already gathered some methods + n.Optional.methodset.setCache(c) + c.methods = map[reflect.Type]*methodset{ + n.Type: n.Optional.methodset, + } + } else if s, ok := c.methods[n.Type]; ok { + if n.Optional == nil { + n.Optional = new(Optional) + } + // Cache is not new. Invalidate local if set n.Optional.methodset = s - } else if n.Optional.methodset != nil { + } else if hasMethodset { + // Cache miss and the type already gathered some methods + n.Optional.methodset.setCache(c) c.methods[n.Type] = n.Optional.methodset } } @@ -228,7 +241,7 @@ func (n *Nature) String() string { } func (n *Nature) Deref() Nature { - t, _, changed := derefTypeKind(n.Type, n.Kind) + t, _, changed := deref.TypeKind(n.Type, n.Kind) if !changed { return *n } @@ -294,7 +307,7 @@ func (n *Nature) NumMethods() int { func (n *Nature) MethodByName(name string) (Nature, bool) { if s := n.getMethodset(); s != nil { - if m, ok := s.method(name); ok { + if m := s.method(name); m != nil { return m.nature, true } } @@ -378,7 +391,7 @@ func (n *Nature) FieldByName(name string) (Nature, bool) { } else { sd = n.cache.getStruct(n.Type).structData } - if sf, ok := sd.structField(nil, name); ok { + if sf := sd.structField(nil, name); sf != nil { return sf.Nature, true } return Nature{}, false @@ -419,7 +432,7 @@ func (n *Nature) getSlow(name string) (Nature, bool) { return nt, true } if n.Kind == reflect.Struct { - if sf, ok := n.structField(nil, name); ok { + if sf := n.structField(nil, name); sf != nil { return sf.Nature, true } } @@ -430,7 +443,7 @@ func (n *Nature) FieldIndex(name string) ([]int, bool) { if n.Kind != reflect.Struct { return nil, false } - if sf, ok := n.structField(nil, name); ok { + if sf := n.structField(nil, name); sf != nil { return sf.Index, true } return nil, false diff --git a/checker/nature/utils.go b/checker/nature/utils.go index 56076b29..620297d5 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -2,16 +2,9 @@ package nature import ( "reflect" -) -func derefTypeKind(t reflect.Type, k reflect.Kind) (_ reflect.Type, _ reflect.Kind, changed bool) { - for k == reflect.Pointer { - changed = true - t = t.Elem() - k = t.Kind() - } - return t, k, changed -} + "github.com/expr-lang/expr/internal/deref" +) func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { switch taggedName := tag.Get("expr"); taggedName { @@ -27,7 +20,7 @@ func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { type structData struct { cache *Cache rType reflect.Type - fields map[string]structField + fields map[string]*structField numField, ownIdx, anonIdx int curParent, curChild *structData @@ -39,22 +32,29 @@ type structField struct { Index []int } +func (s *structData) setCache(c *Cache) { + s.cache = c + for _, sf := range s.fields { + sf.SetCache(c) + } +} + func (s *structData) finished() bool { return s.ownIdx >= s.numField && // no own fields left to visit s.anonIdx >= s.numField && // no embedded fields to visit s.curChild == nil // no child in process of visiting } -func (s *structData) structField(parentEmbed *structData, name string) (structField, bool) { +func (s *structData) structField(parentEmbed *structData, name string) *structField { if s.fields == nil { if s.numField > 0 { - s.fields = make(map[string]structField, s.numField) + s.fields = make(map[string]*structField, s.numField) } - } else if f, ok := s.fields[name]; ok { - return f, true + } else if f := s.fields[name]; f != nil { + return f } if s.finished() { - return structField{}, false + return nil } // Lookup own fields first. @@ -73,7 +73,7 @@ func (s *structData) structField(parentEmbed *structData, name string) (structFi continue } nt := s.cache.FromType(field.Type) - sf := structField{ + sf := &structField{ Nature: nt, Index: field.Index, } @@ -82,14 +82,14 @@ func (s *structData) structField(parentEmbed *structData, name string) (structFi parentEmbed.trySet(fName, sf) } if fName == name { - return sf, true + return sf } } if s.curChild != nil { - sf, ok := s.findInEmbedded(parentEmbed, s.curChild, s.curChildIndex, name) - if ok { - return sf, true + sf := s.findInEmbedded(parentEmbed, s.curChild, s.curChildIndex, name) + if sf != nil { + return sf } } @@ -101,26 +101,26 @@ func (s *structData) structField(parentEmbed *structData, name string) (structFi if !field.Anonymous { continue } - t, k, _ := derefTypeKind(field.Type, field.Type.Kind()) + t, k, _ := deref.TypeKind(field.Type, field.Type.Kind()) if k != reflect.Struct { continue } childEmbed := s.cache.getStruct(t).structData - sf, ok := s.findInEmbedded(parentEmbed, childEmbed, field.Index, name) - if ok { - return sf, true + sf := s.findInEmbedded(parentEmbed, childEmbed, field.Index, name) + if sf != nil { + return sf } } - return structField{}, false + return nil } func (s *structData) findInEmbedded( parentEmbed, childEmbed *structData, childIndex []int, name string, -) (structField, bool) { +) *structField { // Set current parent/child data. This allows trySet to handle child fields // and add them to our struct and to the parent as well if needed s.curParent = parentEmbed @@ -145,29 +145,29 @@ func (s *structData) findInEmbedded( } // Recheck if we have what we needed from the above sync - if sf, ok := s.fields[name]; ok { - return sf, true + if sf := s.fields[name]; sf != nil { + return sf } // Try finding in the child again in case it hasn't finished if !childEmbed.finished() { - if _, ok := childEmbed.structField(s, name); ok { - return s.fields[name], true + if childEmbed.structField(s, name) != nil { + return s.fields[name] } } - return structField{}, false + return nil } -func (s *structData) trySet(name string, sf structField) { +func (s *structData) trySet(name string, sf *structField) { if _, ok := s.fields[name]; ok { return } - sf.Index = append(s.curChildIndex, sf.Index...) - s.fields[name] = structField{ + sf = &structField{ Nature: sf.Nature, - Index: sf.Index, + Index: append(s.curChildIndex, sf.Index...), } + s.fields[name] = sf if s.curParent != nil { s.curParent.trySet(name, sf) } @@ -178,7 +178,7 @@ func StructFields(c *Cache, t reflect.Type) map[string]Nature { if t == nil { return table } - t, k, _ := derefTypeKind(t, t.Kind()) + t, k, _ := deref.TypeKind(t, t.Kind()) if k == reflect.Struct { // lookup for a field with an empty name, which will cause to never find a // match, meaning everything will have been cached. @@ -195,7 +195,7 @@ type methodset struct { cache *Cache rType reflect.Type kind reflect.Kind - methods map[string]method + methods map[string]*method numMethod, idx int } @@ -204,11 +204,18 @@ type method struct { nature Nature } -func (s *methodset) method(name string) (method, bool) { +func (s *methodset) setCache(c *Cache) { + s.cache = c + for _, m := range s.methods { + m.nature.SetCache(c) + } +} + +func (s *methodset) method(name string) *method { if s.methods == nil { - s.methods = make(map[string]method, s.numMethod) - } else if m, ok := s.methods[name]; ok { - return m, true + s.methods = make(map[string]*method, s.numMethod) + } else if m := s.methods[name]; m != nil { + return m } for ; s.idx < s.numMethod; s.idx++ { rm := s.rType.Method(s.idx) @@ -227,14 +234,14 @@ func (s *methodset) method(name string) (method, bool) { // different indexes for different types which implement // the same interface. } - m := method{ + m := &method{ Method: rm, nature: nt, } s.methods[rm.Name] = m if rm.Name == name { - return m, true + return m } } - return method{}, false + return nil } diff --git a/conf/config.go b/conf/config.go index d6274389..80b42c96 100644 --- a/conf/config.go +++ b/conf/config.go @@ -63,7 +63,7 @@ func New(env any) *Config { func (c *Config) WithEnv(env any) { c.EnvObject = env - c.Env = Env(&c.NtCache, env) + c.Env = EnvWithCache(&c.NtCache, env) c.Strict = c.Env.Strict } diff --git a/conf/env.go b/conf/env.go index 43fb19eb..74b65df1 100644 --- a/conf/env.go +++ b/conf/env.go @@ -9,7 +9,14 @@ import ( "github.com/expr-lang/expr/types" ) -func Env(c *Cache, env any) Nature { +// Env returns the Nature of the given environment. +// +// Deprecated: use EnvWithCache instead. +func Env(env any) Nature { + return EnvWithCache(new(Cache), env) +} + +func EnvWithCache(c *Cache, env any) Nature { if env == nil { n := c.NatureOf(map[string]any{}) n.Strict = true diff --git a/docgen/docgen.go b/docgen/docgen.go index 6d1ec59d..5070d14f 100644 --- a/docgen/docgen.go +++ b/docgen/docgen.go @@ -85,7 +85,7 @@ func CreateDoc(i any) *Context { PkgPath: deref.Type(reflect.TypeOf(i)).PkgPath(), } - env := conf.Env(new(nature.Cache), i) + env := conf.EnvWithCache(new(nature.Cache), i) for name, t := range env.All() { if _, ok := c.Variables[Identifier(name)]; ok { continue diff --git a/internal/deref/deref.go b/internal/deref/deref.go index da3e28ce..4ad7877f 100644 --- a/internal/deref/deref.go +++ b/internal/deref/deref.go @@ -45,3 +45,12 @@ func Value(v reflect.Value) reflect.Value { } return v } + +func TypeKind(t reflect.Type, k reflect.Kind) (_ reflect.Type, _ reflect.Kind, changed bool) { + for k == reflect.Pointer { + changed = true + t = t.Elem() + k = t.Kind() + } + return t, k, changed +} From 32aebcd4096ac20d1fcaa94b9cc3cdb41dc5dc54 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 20:12:35 -0300 Subject: [PATCH 22/24] rename from SetCache to Bind --- checker/checker.go | 2 +- checker/info.go | 2 +- checker/nature/nature.go | 15 +++++++++------ checker/nature/utils.go | 8 ++++---- compiler/compiler.go | 2 +- conf/config.go | 2 +- conf/env.go | 4 ++-- 7 files changed, 19 insertions(+), 16 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index cf6359b6..ea44ecfa 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -636,7 +636,7 @@ func (v *Checker) callNode(node *ast.CallNode) Nature { // with new correct function return type. if typ := node.Type(); typ != nil && typ != anyType { nt := node.Nature() - nt.SetCache(&v.config.NtCache) // AST doesn't cache nature info + nt.Bind(&v.config.NtCache) // AST doesn't cache nature info return nt } diff --git a/checker/info.go b/checker/info.go index 74b9bd38..6b2fba3e 100644 --- a/checker/info.go +++ b/checker/info.go @@ -16,7 +16,7 @@ func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { } case *ast.MemberNode: base := n.Node.Nature() - base.SetCache(c) // AST doesn't cache nature info + base.Bind(c) // AST doesn't cache nature info base = base.Deref() if base.Kind == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 84185c36..9f453187 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -137,7 +137,7 @@ func (c *Cache) getStruct(t reflect.Type) Nature { }, } if c != nil { - nt.SetCache(c) + nt.Bind(c) } return nt } @@ -187,11 +187,14 @@ func ArrayFromType(c *Cache, t reflect.Type) Nature { return nt } -func (n *Nature) SetCache(c *Cache) { +func (n *Nature) Bind(c *Cache) { + if n.cache == c { + return + } n.cache = c if n.Kind == reflect.Struct { if c.structs == nil { - n.structData.setCache(c) + n.structData.bind(c) c.structs = map[reflect.Type]Nature{ n.Type: *n, } @@ -199,7 +202,7 @@ func (n *Nature) SetCache(c *Cache) { // invalidate local, use shared from cache n.Optional.structData = nt.Optional.structData } else { - n.structData.setCache(c) + n.structData.bind(c) c.structs[n.Type] = *n } } @@ -207,7 +210,7 @@ func (n *Nature) SetCache(c *Cache) { if c.methods != nil || hasMethodset { if c.methods == nil { // Cache is new and the type already gathered some methods - n.Optional.methodset.setCache(c) + n.Optional.methodset.bind(c) c.methods = map[reflect.Type]*methodset{ n.Type: n.Optional.methodset, } @@ -219,7 +222,7 @@ func (n *Nature) SetCache(c *Cache) { n.Optional.methodset = s } else if hasMethodset { // Cache miss and the type already gathered some methods - n.Optional.methodset.setCache(c) + n.Optional.methodset.bind(c) c.methods[n.Type] = n.Optional.methodset } } diff --git a/checker/nature/utils.go b/checker/nature/utils.go index 620297d5..3c7a33c2 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -32,10 +32,10 @@ type structField struct { Index []int } -func (s *structData) setCache(c *Cache) { +func (s *structData) bind(c *Cache) { s.cache = c for _, sf := range s.fields { - sf.SetCache(c) + sf.Bind(c) } } @@ -204,10 +204,10 @@ type method struct { nature Nature } -func (s *methodset) setCache(c *Cache) { +func (s *methodset) bind(c *Cache) { s.cache = c for _, m := range s.methods { - m.nature.SetCache(c) + m.nature.Bind(c) } } diff --git a/compiler/compiler.go b/compiler/compiler.go index 30f6e86f..7910aaf8 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1089,7 +1089,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.compile(arg) argType := arg.Type() argNature := arg.Nature() - argNature.SetCache(c.ntCache) // AST doesn't cache nature info + argNature.Bind(c.ntCache) // AST doesn't cache nature info if argType.Kind() == reflect.Ptr || argNature.IsUnknown() { if f.Deref == nil { // By default, builtins expect arguments to be dereferenced. diff --git a/conf/config.go b/conf/config.go index 80b42c96..aba35fea 100644 --- a/conf/config.go +++ b/conf/config.go @@ -50,7 +50,7 @@ func CreateNew() *Config { for _, f := range builtin.Builtins { c.Builtins[f.Name] = f } - c.Env.SetCache(&c.NtCache) + c.Env.Bind(&c.NtCache) return c } diff --git a/conf/env.go b/conf/env.go index 74b65df1..ef280c20 100644 --- a/conf/env.go +++ b/conf/env.go @@ -26,7 +26,7 @@ func EnvWithCache(c *Cache, env any) Nature { switch env := env.(type) { case types.Map: nt := env.Nature() - nt.SetCache(c) + nt.Bind(c) return nt } @@ -58,7 +58,7 @@ func EnvWithCache(c *Cache, env any) Nature { switch face := face.(type) { case types.Map: nt := face.Nature() - nt.SetCache(c) + nt.Bind(c) n.Fields[key.String()] = nt default: From f6e2d87193452a5c7bf3c2fd8c5d390d09d36731 Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 22:02:06 -0300 Subject: [PATCH 23/24] make the benchmarks more fair --- checker/checker_bench_test.go | 44 ++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/checker/checker_bench_test.go b/checker/checker_bench_test.go index e2e2fca8..840514fb 100644 --- a/checker/checker_bench_test.go +++ b/checker/checker_bench_test.go @@ -1,10 +1,13 @@ package checker_test import ( + "runtime" "testing" "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/checker" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/parser" ) @@ -59,22 +62,51 @@ a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a expr.ConstExpr("func") for _, c := range cases { - b.Run("name="+c.name, func(b *testing.B) { + batchSize := 100_000 + if batchSize > b.N { + batchSize = b.N + } + trees := make([]*parser.Tree, 0, batchSize) + for i := 0; i < batchSize; i++ { tree, err := parser.ParseWithConfig(c.input, config) if err != nil { b.Fatal(err) } - b.ReportAllocs() - b.ResetTimer() + trees = append(trees, tree) + } + runtime.GC() // try to cleanup the mess from the initialization + + b.Run("name="+c.name, func(b *testing.B) { + var err error for i := 0; i < b.N; i++ { - _, err = checker.Check(tree, config) - if err != nil { - b.Fatal(err) + j := i + if j < 0 || j >= len(trees) { + b.StopTimer() + invalidateTrees(trees...) + j = 0 + b.StartTimer() } + + _, err = checker.Check(trees[j], config) + } + b.StopTimer() + if err != nil { + b.Fatal(err) } }) } +} + +type visitorFunc func(*ast.Node) +func (f visitorFunc) Visit(node *ast.Node) { f(node) } + +func invalidateTrees(trees ...*parser.Tree) { + for _, tree := range trees { + ast.Walk(&tree.Node, visitorFunc(func(node *ast.Node) { + (*node).SetNature(nature.Nature{}) + })) + } } type recursive struct { From eafdb3c94ed26fa517f4849ed594368acd74566b Mon Sep 17 00:00:00 2001 From: Diego Augusto Molina Date: Sun, 31 Aug 2025 23:46:27 -0300 Subject: [PATCH 24/24] fix benchmark option not being applied --- checker/checker_bench_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/checker/checker_bench_test.go b/checker/checker_bench_test.go index 840514fb..0583d735 100644 --- a/checker/checker_bench_test.go +++ b/checker/checker_bench_test.go @@ -58,8 +58,8 @@ a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a "func": f, } config := conf.New(env) - expr.Function("func", f, f) - expr.ConstExpr("func") + expr.Function("func", f, f)(config) + expr.ConstExpr("func")(config) for _, c := range cases { batchSize := 100_000