diff --git a/ast/node.go b/ast/node.go index 02923ac5..40c507aa 100644 --- a/ast/node.go +++ b/ast/node.go @@ -66,7 +66,7 @@ func (n *base) Type() reflect.Type { // SetType sets the type of the node. func (n *base) SetType(t reflect.Type) { - n.nature.Type = t + n.nature = nature.FromType(t) } // NilNode represents nil. diff --git a/checker/checker.go b/checker/checker.go index f4923413..ea44ecfa 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "regexp" + "time" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" @@ -13,15 +14,88 @@ import ( "github.com/expr-lang/expr/parser" ) +var ( + anyType = reflect.TypeOf(new(any)).Elem() + boolType = reflect.TypeOf(true) + intType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) + stringType = reflect.TypeOf("") + arrayType = reflect.TypeOf([]any{}) + mapType = reflect.TypeOf(map[string]any{}) + timeType = reflect.TypeOf(time.Time{}) + durationType = reflect.TypeOf(time.Duration(0)) + + anyTypeSlice = []reflect.Type{anyType} +) + +// ParseCheck parses input expression and checks its types. Also, it applies +// all provided patchers. In case of error, it returns error with a tree. +func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { + tree, err := parser.ParseWithConfig(input, config) + if err != nil { + return tree, err + } + + _, err = new(Checker).PatchAndCheck(tree, config) + if err != nil { + return tree, err + } + + return tree, nil +} + +// Check calls Check on a disposable Checker. +func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + return new(Checker).Check(tree, config) +} + +type Checker struct { + config *conf.Config + predicateScopes []predicateScope + varScopes []varScope + err *file.Error + needsReset bool +} + +type predicateScope struct { + collection Nature + vars []varScope +} + +type varScope struct { + name string + nature Nature +} + +// PatchAndCheck applies all patchers and checks the tree. +func (c *Checker) PatchAndCheck(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + c.reset(config) + if len(config.Visitors) > 0 { + // Run all patchers that dont support being run repeatedly first + c.runVisitors(tree, false) + + // Run patchers that require multiple passes next (currently only Operator patching) + c.runVisitors(tree, true) + } + return c.Check(tree, config) +} + +// Check checks types of the expression tree. It returns type of the expression +// and error if any. If config is nil, then default configuration will be used. +func (c *Checker) Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { + c.reset(config) + return c.check(tree) +} + // Run visitors in a given config over the given tree // runRepeatable controls whether to filter for only vistors that require multiple passes or not -func runVisitors(tree *parser.Tree, config *conf.Config, runRepeatable bool) { +func (c *Checker) runVisitors(tree *parser.Tree, runRepeatable bool) { for { more := false - for _, v := range config.Visitors { + for _, v := range c.config.Visitors { // We need to perform types check, because some visitors may rely on // types information available in the tree. - _, _ = Check(tree, config) + _, _ = c.Check(tree, c.config) r, repeatable := v.(interface { Reset() @@ -47,39 +121,8 @@ func runVisitors(tree *parser.Tree, config *conf.Config, runRepeatable bool) { } } -// ParseCheck parses input expression and checks its types. Also, it applies -// all provided patchers. In case of error, it returns error with a tree. -func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { - tree, err := parser.ParseWithConfig(input, config) - if err != nil { - return tree, err - } - - if len(config.Visitors) > 0 { - // Run all patchers that dont support being run repeatedly first - runVisitors(tree, config, false) - - // Run patchers that require multiple passes next (currently only Operator patching) - runVisitors(tree, config, true) - } - _, err = Check(tree, config) - if err != nil { - return tree, err - } - - return tree, nil -} - -// Check checks types of the expression tree. It returns type of the expression -// and error if any. If config is nil, then default configuration will be used. -func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { - if config == nil { - config = conf.New(nil) - } - - v := &checker{config: config} - - nt := v.visit(tree.Node) +func (c *Checker) check(tree *parser.Tree) (reflect.Type, error) { + nt := c.visit(tree.Node) // To keep compatibility with previous versions, we should return any, if nature is unknown. t := nt.Type @@ -87,25 +130,25 @@ func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { t = anyType } - if v.err != nil { - return t, v.err.Bind(tree.Source) + if c.err != nil { + return t, c.err.Bind(tree.Source) } - if v.config.Expect != reflect.Invalid { - if v.config.ExpectAny { - if isUnknown(nt) { + if c.config.Expect != reflect.Invalid { + if c.config.ExpectAny { + if nt.IsUnknown() { return t, nil } } - switch v.config.Expect { + switch c.config.Expect { case reflect.Int, reflect.Int64, reflect.Float64: - if !isNumber(nt) { - return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, nt) + if !nt.IsNumber() { + return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } default: - if nt.Kind() != v.config.Expect { - return nil, fmt.Errorf("expected %v, but got %s", v.config.Expect, nt) + if nt.Kind != c.config.Expect { + return nil, fmt.Errorf("expected %v, but got %s", c.config.Expect, nt.String()) } } } @@ -113,81 +156,76 @@ func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { return t, nil } -type checker struct { - config *conf.Config - predicateScopes []predicateScope - varScopes []varScope - err *file.Error -} - -type predicateScope struct { - collection Nature - vars map[string]Nature -} +func (c *Checker) reset(config *conf.Config) { + if c.needsReset { + clearSlice(c.predicateScopes) + clearSlice(c.varScopes) + c.predicateScopes = c.predicateScopes[:0] + c.varScopes = c.varScopes[:0] + c.err = nil + } + c.needsReset = true -type varScope struct { - name string - nature Nature + if config == nil { + config = conf.New(nil) + } + c.config = config } -type info struct { - method bool - fn *builtin.Function - - // elem is element type of array or map. - // Arrays created with type []any, but - // we would like to detect expressions - // like `42 in ["a"]` as invalid. - elem reflect.Type +func clearSlice[S ~[]E, E any](s S) { + var zero E + for i := range s { + s[i] = zero + } } -func (v *checker) visit(node ast.Node) Nature { +func (v *Checker) visit(node ast.Node) Nature { var nt Nature switch n := node.(type) { case *ast.NilNode: - nt = v.NilNode(n) + nt = v.config.NtCache.NatureOf(nil) case *ast.IdentifierNode: - nt = v.IdentifierNode(n) + nt = v.identifierNode(n) case *ast.IntegerNode: - nt = v.IntegerNode(n) + nt = v.config.NtCache.FromType(intType) case *ast.FloatNode: - nt = v.FloatNode(n) + nt = v.config.NtCache.FromType(floatType) case *ast.BoolNode: - nt = v.BoolNode(n) + nt = v.config.NtCache.FromType(boolType) case *ast.StringNode: - nt = v.StringNode(n) + nt = v.config.NtCache.FromType(stringType) case *ast.ConstantNode: - nt = v.ConstantNode(n) + nt = v.config.NtCache.FromType(reflect.TypeOf(n.Value)) case *ast.UnaryNode: - nt = v.UnaryNode(n) + nt = v.unaryNode(n) case *ast.BinaryNode: - nt = v.BinaryNode(n) + nt = v.binaryNode(n) case *ast.ChainNode: - nt = v.ChainNode(n) + nt = v.chainNode(n) case *ast.MemberNode: - nt = v.MemberNode(n) + nt = v.memberNode(n) case *ast.SliceNode: - nt = v.SliceNode(n) + nt = v.sliceNode(n) case *ast.CallNode: - nt = v.CallNode(n) + nt = v.callNode(n) case *ast.BuiltinNode: - nt = v.BuiltinNode(n) + nt = v.builtinNode(n) case *ast.PredicateNode: - nt = v.PredicateNode(n) + nt = v.predicateNode(n) case *ast.PointerNode: - nt = v.PointerNode(n) + nt = v.pointerNode(n) case *ast.VariableDeclaratorNode: - nt = v.VariableDeclaratorNode(n) + nt = v.variableDeclaratorNode(n) case *ast.SequenceNode: - nt = v.SequenceNode(n) + nt = v.sequenceNode(n) case *ast.ConditionalNode: - nt = v.ConditionalNode(n) + nt = v.conditionalNode(n) case *ast.ArrayNode: - nt = v.ArrayNode(n) + nt = v.arrayNode(n) case *ast.MapNode: - nt = v.MapNode(n) + nt = v.mapNode(n) case *ast.PairNode: - nt = v.PairNode(n) + nt = v.pairNode(n) default: panic(fmt.Sprintf("undefined node type (%T)", node)) } @@ -195,100 +233,88 @@ func (v *checker) visit(node ast.Node) Nature { return nt } -func (v *checker) error(node ast.Node, format string, args ...any) Nature { +func (v *Checker) error(node ast.Node, format string, args ...any) Nature { if v.err == nil { // show first error v.err = &file.Error{ Location: node.Location(), Message: fmt.Sprintf(format, args...), } } - return unknown + return Nature{} } -func (v *checker) NilNode(*ast.NilNode) Nature { - return nilNature -} - -func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature { - if variable, ok := v.lookupVariable(node.Value); ok { - return variable.nature +func (v *Checker) identifierNode(node *ast.IdentifierNode) Nature { + for i := len(v.varScopes) - 1; i >= 0; i-- { + if v.varScopes[i].name == node.Value { + return v.varScopes[i].nature + } } if node.Value == "$env" { - return unknown + return Nature{} } return v.ident(node, node.Value, v.config.Strict, true) } // ident method returns type of environment variable, builtin or function. -func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Nature { +func (v *Checker) ident(node ast.Node, name string, strict, builtins bool) Nature { if nt, ok := v.config.Env.Get(name); ok { return nt } if builtins { if fn, ok := v.config.Functions[name]; ok { - return Nature{Type: fn.Type(), Func: fn} + nt := v.config.NtCache.FromType(fn.Type()) + if nt.FuncData == nil { + nt.FuncData = new(FuncData) + } + nt.FuncData.Func = fn + return nt } if fn, ok := v.config.Builtins[name]; ok { - return Nature{Type: fn.Type(), Func: fn} + nt := v.config.NtCache.FromType(fn.Type()) + if nt.FuncData == nil { + nt.FuncData = new(FuncData) + } + nt.FuncData.Func = fn + return nt } } if v.config.Strict && strict { - return v.error(node, "unknown name %v", name) + return v.error(node, "unknown name %s", name) } - return unknown + return Nature{} } -func (v *checker) IntegerNode(*ast.IntegerNode) Nature { - return integerNature -} - -func (v *checker) FloatNode(*ast.FloatNode) Nature { - return floatNature -} - -func (v *checker) BoolNode(*ast.BoolNode) Nature { - return boolNature -} - -func (v *checker) StringNode(*ast.StringNode) Nature { - return stringNature -} - -func (v *checker) ConstantNode(node *ast.ConstantNode) Nature { - return Nature{Type: reflect.TypeOf(node.Value)} -} - -func (v *checker) UnaryNode(node *ast.UnaryNode) Nature { +func (v *Checker) unaryNode(node *ast.UnaryNode) Nature { nt := v.visit(node.Node) nt = nt.Deref() switch node.Operator { case "!", "not": - if isBool(nt) { - return boolNature + if nt.IsBool() { + return v.config.NtCache.FromType(boolType) } - if isUnknown(nt) { - return boolNature + if nt.IsUnknown() { + return v.config.NtCache.FromType(boolType) } case "+", "-": - if isNumber(nt) { + if nt.IsNumber() { return nt } - if isUnknown(nt) { - return unknown + if nt.IsUnknown() { + return Nature{} } default: - return v.error(node, "unknown operator (%v)", node.Operator) + return v.error(node, "unknown operator (%s)", node.Operator) } - return v.error(node, `invalid operation: %v (mismatched type %s)`, node.Operator, nt) + return v.error(node, `invalid operation: %s (mismatched type %s)`, node.Operator, nt.String()) } -func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { +func (v *Checker) binaryNode(node *ast.BinaryNode) Nature { l := v.visit(node.Left) r := v.visit(node.Right) @@ -297,134 +323,136 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { switch node.Operator { case "==", "!=": - if isComparable(l, r) { - return boolNature + if l.ComparableTo(r) { + return v.config.NtCache.FromType(boolType) } case "or", "||", "and", "&&": - if isBool(l) && isBool(r) { - return boolNature + if l.IsBool() && r.IsBool() { + return v.config.NtCache.FromType(boolType) } - if or(l, r, isBool) { - return boolNature + if l.MaybeCompatible(r, BoolCheck) { + return v.config.NtCache.FromType(boolType) } case "<", ">", ">=", "<=": - if isNumber(l) && isNumber(r) { - return boolNature + if l.IsNumber() && r.IsNumber() { + return v.config.NtCache.FromType(boolType) } - if isString(l) && isString(r) { - return boolNature + if l.IsString() && r.IsString() { + return v.config.NtCache.FromType(boolType) } - if isTime(l) && isTime(r) { - return boolNature + if l.IsTime() && r.IsTime() { + return v.config.NtCache.FromType(boolType) } - if isDuration(l) && isDuration(r) { - return boolNature + if l.IsDuration() && r.IsDuration() { + return v.config.NtCache.FromType(boolType) } - if or(l, r, isNumber, isString, isTime, isDuration) { - return boolNature + if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { + return v.config.NtCache.FromType(boolType) } case "-": - if isNumber(l) && isNumber(r) { - return combined(l, r) + if l.IsNumber() && r.IsNumber() { + return l.PromoteNumericNature(r) } - if isTime(l) && isTime(r) { - return durationNature + if l.IsTime() && r.IsTime() { + return v.config.NtCache.FromType(durationType) } - if isTime(l) && isDuration(r) { - return timeNature + if l.IsTime() && r.IsDuration() { + return v.config.NtCache.FromType(timeType) } - if isDuration(l) && isDuration(r) { - return durationNature + if l.IsDuration() && r.IsDuration() { + return v.config.NtCache.FromType(durationType) } - if or(l, r, isNumber, isTime, isDuration) { - return unknown + if l.MaybeCompatible(r, NumberCheck, TimeCheck, DurationCheck) { + return Nature{} } case "*": - if isNumber(l) && isNumber(r) { - return combined(l, r) + if l.IsNumber() && r.IsNumber() { + return l.PromoteNumericNature(r) } - if isNumber(l) && isDuration(r) { - return durationNature + if l.IsNumber() && r.IsDuration() { + return v.config.NtCache.FromType(durationType) } - if isDuration(l) && isNumber(r) { - return durationNature + if l.IsDuration() && r.IsNumber() { + return v.config.NtCache.FromType(durationType) } - if isDuration(l) && isDuration(r) { - return durationNature + if l.IsDuration() && r.IsDuration() { + return v.config.NtCache.FromType(durationType) } - if or(l, r, isNumber, isDuration) { - return unknown + if l.MaybeCompatible(r, NumberCheck, DurationCheck) { + return Nature{} } case "/": - if isNumber(l) && isNumber(r) { - return floatNature + if l.IsNumber() && r.IsNumber() { + return v.config.NtCache.FromType(floatType) } - if or(l, r, isNumber) { - return floatNature + if l.MaybeCompatible(r, NumberCheck) { + return v.config.NtCache.FromType(floatType) } case "**", "^": - if isNumber(l) && isNumber(r) { - return floatNature + if l.IsNumber() && r.IsNumber() { + return v.config.NtCache.FromType(floatType) } - if or(l, r, isNumber) { - return floatNature + if l.MaybeCompatible(r, NumberCheck) { + return v.config.NtCache.FromType(floatType) } case "%": - if isInteger(l) && isInteger(r) { - return integerNature + if l.IsInteger() && r.IsInteger() { + return v.config.NtCache.FromType(intType) } - if or(l, r, isInteger) { - return integerNature + if l.MaybeCompatible(r, IntegerCheck) { + return v.config.NtCache.FromType(intType) } case "+": - if isNumber(l) && isNumber(r) { - return combined(l, r) + if l.IsNumber() && r.IsNumber() { + return l.PromoteNumericNature(r) } - if isString(l) && isString(r) { - return stringNature + if l.IsString() && r.IsString() { + return v.config.NtCache.FromType(stringType) } - if isTime(l) && isDuration(r) { - return timeNature + if l.IsTime() && r.IsDuration() { + return v.config.NtCache.FromType(timeType) } - if isDuration(l) && isTime(r) { - return timeNature + if l.IsDuration() && r.IsTime() { + return v.config.NtCache.FromType(timeType) } - if isDuration(l) && isDuration(r) { - return durationNature + if l.IsDuration() && r.IsDuration() { + return v.config.NtCache.FromType(durationType) } - if or(l, r, isNumber, isString, isTime, isDuration) { - return unknown + if l.MaybeCompatible(r, NumberCheck, StringCheck, TimeCheck, DurationCheck) { + return Nature{} } case "in": - if (isString(l) || isUnknown(l)) && isStruct(r) { - return boolNature + if (l.IsString() || l.IsUnknown()) && r.IsStruct() { + return v.config.NtCache.FromType(boolType) } - if isMap(r) { - if !isUnknown(l) && !l.AssignableTo(r.Key()) { - return v.error(node, "cannot use %v as type %v in map key", l, r.Key()) + if r.IsMap() { + rKey := r.Key() + if !l.IsUnknown() && !l.AssignableTo(rKey) { + return v.error(node, "cannot use %s as type %s in map key", l.String(), rKey.String()) } - return boolNature + return v.config.NtCache.FromType(boolType) } - if isArray(r) { - if !isComparable(l, r.Elem()) { - return v.error(node, "cannot use %v as type %v in array", l, r.Elem()) + if r.IsArray() { + rElem := r.Elem() + if !l.ComparableTo(rElem) { + return v.error(node, "cannot use %s as type %s in array", l.String(), rElem.String()) } - return boolNature + return v.config.NtCache.FromType(boolType) } - if isUnknown(l) && anyOf(r, isString, isArray, isMap) { - return boolNature + if l.IsUnknown() && r.IsAnyOf(StringCheck, ArrayCheck, MapCheck) { + return v.config.NtCache.FromType(boolType) } - if isUnknown(r) { - return boolNature + if r.IsUnknown() { + return v.config.NtCache.FromType(boolType) } case "matches": @@ -434,57 +462,54 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { return v.error(node, err.Error()) } } - if isString(l) && isString(r) { - return boolNature + if l.IsString() && r.IsString() { + return v.config.NtCache.FromType(boolType) } - if or(l, r, isString) { - return boolNature + if l.MaybeCompatible(r, StringCheck) { + return v.config.NtCache.FromType(boolType) } case "contains", "startsWith", "endsWith": - if isString(l) && isString(r) { - return boolNature + if l.IsString() && r.IsString() { + return v.config.NtCache.FromType(boolType) } - if or(l, r, isString) { - return boolNature + if l.MaybeCompatible(r, StringCheck) { + return v.config.NtCache.FromType(boolType) } case "..": - if isInteger(l) && isInteger(r) { - return arrayOf(integerNature) - } - if or(l, r, isInteger) { - return arrayOf(integerNature) + if l.IsInteger() && r.IsInteger() || l.MaybeCompatible(r, IntegerCheck) { + return ArrayFromType(&v.config.NtCache, intType) } case "??": - if isNil(l) && !isNil(r) { + if l.Nil && !r.Nil { return r } - if !isNil(l) && isNil(r) { + if !l.Nil && r.Nil { return l } - if isNil(l) && isNil(r) { - return nilNature + if l.Nil && r.Nil { + return v.config.NtCache.NatureOf(nil) } if r.AssignableTo(l) { return l } - return unknown + return Nature{} default: - return v.error(node, "unknown operator (%v)", node.Operator) + return v.error(node, "unknown operator (%s)", node.Operator) } - return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r) + return v.error(node, `invalid operation: %s (mismatched types %s and %s)`, node.Operator, l.String(), r.String()) } -func (v *checker) ChainNode(node *ast.ChainNode) Nature { +func (v *Checker) chainNode(node *ast.ChainNode) Nature { return v.visit(node.Node) } -func (v *checker) MemberNode(node *ast.MemberNode) Nature { +func (v *Checker) memberNode(node *ast.MemberNode) Nature { // $env variable if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" { if name, ok := node.Property.(*ast.StringNode); ok { @@ -498,19 +523,19 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { } return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } - return unknown + return Nature{} } base := v.visit(node.Node) prop := v.visit(node.Property) - if isUnknown(base) { - return unknown + if base.IsUnknown() { + return Nature{} } if name, ok := node.Property.(*ast.StringNode); ok { - if isNil(base) { - return v.error(node, "type nil has no field %v", name.Value) + if base.Nil { + return v.error(node, "type nil has no field %s", name.Value) } // First, check methods defined on base type itself, @@ -522,23 +547,23 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { base = base.Deref() - switch base.Kind() { + switch base.Kind { case reflect.Map: - if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { - return v.error(node.Property, "cannot use %v to get an element from %v", prop, base) + if !prop.AssignableTo(base.Key()) && !prop.IsUnknown() { + return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String()) } - if prop, ok := node.Property.(*ast.StringNode); ok { + if prop, ok := node.Property.(*ast.StringNode); ok && base.Optional != nil { if field, ok := base.Fields[prop.Value]; ok { return field } else if base.Strict { - return v.error(node.Property, "unknown field %v", prop.Value) + return v.error(node.Property, "unknown field %s", prop.Value) } } return base.Elem() case reflect.Array, reflect.Slice: - if !isInteger(prop) && !isUnknown(prop) { - return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop) + if !prop.IsInteger() && !prop.IsUnknown() { + return v.error(node.Property, "array elements can only be selected using an integer (got %s)", prop.String()) } return base.Elem() @@ -546,12 +571,12 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value if field, ok := base.FieldByName(propertyName); ok { - return Nature{Type: field.Type} + return v.config.NtCache.FromType(field.Type) } if node.Method { - return v.error(node, "type %v has no method %v", base, propertyName) + return v.error(node, "type %v has no method %v", base.String(), propertyName) } - return v.error(node, "type %v has no field %v", base, propertyName) + return v.error(node, "type %v has no field %v", base.String(), propertyName) } } @@ -559,47 +584,45 @@ func (v *checker) MemberNode(node *ast.MemberNode) Nature { if name, ok := node.Property.(*ast.StringNode); ok { if node.Method { - return v.error(node, "type %v has no method %v", base, name.Value) + return v.error(node, "type %v has no method %v", base.String(), name.Value) } - return v.error(node, "type %v has no field %v", base, name.Value) + return v.error(node, "type %v has no field %v", base.String(), name.Value) } - return v.error(node, "type %v[%v] is undefined", base, prop) + return v.error(node, "type %v[%v] is undefined", base.String(), prop.String()) } -func (v *checker) SliceNode(node *ast.SliceNode) Nature { +func (v *Checker) sliceNode(node *ast.SliceNode) Nature { nt := v.visit(node.Node) - if isUnknown(nt) { - return unknown + if nt.IsUnknown() { + return Nature{} } - switch nt.Kind() { + switch nt.Kind { case reflect.String, reflect.Array, reflect.Slice: // ok default: - return v.error(node, "cannot slice %s", nt) + return v.error(node, "cannot slice %s", nt.String()) } if node.From != nil { from := v.visit(node.From) - if !isInteger(from) && !isUnknown(from) { - return v.error(node.From, "non-integer slice index %v", from) + if !from.IsInteger() && !from.IsUnknown() { + return v.error(node.From, "non-integer slice index %v", from.String()) } } if node.To != nil { to := v.visit(node.To) - if !isInteger(to) && !isUnknown(to) { - return v.error(node.To, "non-integer slice index %v", to) + if !to.IsInteger() && !to.IsUnknown() { + return v.error(node.To, "non-integer slice index %v", to.String()) } } return nt } -func (v *checker) CallNode(node *ast.CallNode) Nature { - nt := v.functionReturnType(node) - +func (v *Checker) callNode(node *ast.CallNode) Nature { // Check if type was set on node (for example, by patcher) // and use node type instead of function return type. // @@ -611,18 +634,19 @@ func (v *checker) CallNode(node *ast.CallNode) Nature { // fix `errCall()` to return proper type, so on second // checker pass we should replace anyType on method node // with new correct function return type. - if node.Type() != nil && node.Type() != anyType { - return node.Nature() + if typ := node.Type(); typ != nil && typ != anyType { + nt := node.Nature() + nt.Bind(&v.config.NtCache) // AST doesn't cache nature info + return nt } - return nt -} - -func (v *checker) functionReturnType(node *ast.CallNode) Nature { nt := v.visit(node.Callee) + if nt.IsUnknown() { + return Nature{} + } - if nt.Func != nil { - return v.checkFunction(nt.Func, node, node.Arguments) + if nt.FuncData != nil && nt.FuncData.Func != nil { + return v.checkFunction(nt.FuncData.Func, node, node.Arguments) } fnName := "function" @@ -635,122 +659,126 @@ func (v *checker) functionReturnType(node *ast.CallNode) Nature { } } - if isUnknown(nt) { - return unknown - } - - if isNil(nt) { + if nt.Nil { return v.error(node, "%v is nil; cannot call nil as function", fnName) } - switch nt.Kind() { - case reflect.Func: + if nt.Kind == reflect.Func { outType, err := v.checkArguments(fnName, nt, node.Arguments, node) if err != nil { if v.err == nil { v.err = err } - return unknown + return Nature{} } return outType } - return v.error(node, "%s is not callable", nt) + return v.error(node, "%s is not callable", nt.String()) } -func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { +func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { switch node.Name { case "all", "none", "any", "one": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return boolNature + return v.config.NtCache.FromType(boolType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "filter": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - if isUnknown(collection) { - return arrayNature + if collection.IsUnknown() { + return v.config.NtCache.FromType(arrayType) } - return arrayOf(collection.Elem()) + collection = collection.Elem() + return collection.MakeArrayOf() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "map": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, scopeVar{"index", integerNature}) + v.begin(collection, varScope{"index", v.config.NtCache.FromType(intType)}) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - return arrayOf(*predicate.PredicateOut) + return predicate.Ref.MakeArrayOf() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "count": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } if len(node.Arguments) == 1 { - return integerNature + return v.config.NtCache.FromType(intType) } v.begin(collection) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return integerNature + return v.config.NtCache.FromType(intType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sum": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } if len(node.Arguments) == 2 { @@ -758,86 +786,95 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { return predicate.Out(0) } } else { - if isUnknown(collection) { - return unknown + if collection.IsUnknown() { + return Nature{} } return collection.Elem() } case "find", "findLast": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - if isUnknown(collection) { - return unknown + if collection.IsUnknown() { + return Nature{} } return collection.Elem() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "findIndex", "findLastIndex": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) + predicateOut := predicate.Out(0) + if !predicateOut.IsBool() && !predicateOut.IsUnknown() { + return v.error(node.Arguments[1], "predicate should return boolean (got %s)", predicateOut.String()) } - return integerNature + return v.config.NtCache.FromType(intType) } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "groupBy": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { - groups := arrayOf(collection.Elem()) - return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &groups} + collection = collection.Elem() + collection = collection.MakeArrayOf() + nt := v.config.NtCache.NatureOf(map[any][]any{}) + nt.Ref = &collection + return nt } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sortBy": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } v.begin(collection) @@ -848,21 +885,22 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { _ = v.visit(node.Arguments[2]) } - if isFunc(predicate) && + if predicate.IsFunc() && predicate.NumOut() == 1 && - predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown() { return collection } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "reduce": - collection := v.visit(node.Arguments[0]).Deref() - if !isArray(collection) && !isUnknown(collection) { - return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + collection := v.visit(node.Arguments[0]) + collection = collection.Deref() + if !collection.IsArray() && !collection.IsUnknown() { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) } - v.begin(collection, scopeVar{"index", integerNature}, scopeVar{"acc", unknown}) + v.begin(collection, varScope{"index", v.config.NtCache.FromType(intType)}, varScope{"acc", Nature{}}) predicate := v.visit(node.Arguments[1]) v.end() @@ -870,8 +908,8 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { _ = v.visit(node.Arguments[2]) } - if isFunc(predicate) && predicate.NumOut() == 1 { - return *predicate.PredicateOut + if predicate.IsFunc() && predicate.NumOut() == 1 { + return *predicate.Ref } return v.error(node.Arguments[1], "predicate should has two input and one output param") @@ -888,24 +926,18 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { return v.error(node, "unknown builtin %v", node.Name) } -type scopeVar struct { - varName string - varNature Nature -} - -func (v *checker) begin(collectionNature Nature, vars ...scopeVar) { - scope := predicateScope{collection: collectionNature, vars: make(map[string]Nature)} - for _, v := range vars { - scope.vars[v.varName] = v.varNature - } - v.predicateScopes = append(v.predicateScopes, scope) +func (v *Checker) begin(collectionNature Nature, vars ...varScope) { + v.predicateScopes = append(v.predicateScopes, predicateScope{ + collection: collectionNature, + vars: vars, + }) } -func (v *checker) end() { +func (v *Checker) end() { v.predicateScopes = v.predicateScopes[:len(v.predicateScopes)-1] } -func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { +func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if len(node.Arguments) != 2 { return v.error(node, "invalid number of arguments (expected 2, got %d)", len(node.Arguments)) } @@ -919,34 +951,34 @@ func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { return nt } } - return unknown + return Nature{} } - if isUnknown(base) { - return unknown + if base.IsUnknown() { + return Nature{} } - switch base.Kind() { + switch base.Kind { case reflect.Slice, reflect.Array: - if !isInteger(prop) && !isUnknown(prop) { - return v.error(node.Arguments[1], "non-integer slice index %s", prop) + if !prop.IsInteger() && !prop.IsUnknown() { + return v.error(node.Arguments[1], "non-integer slice index %s", prop.String()) } return base.Elem() case reflect.Map: - if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { - return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop, base) + if !prop.AssignableTo(base.Key()) && !prop.IsUnknown() { + return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop.String(), base.String()) } return base.Elem() } - return v.error(node.Arguments[0], "type %v does not support indexing", base) + return v.error(node.Arguments[0], "type %v does not support indexing", base.String()) } -func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { +func (v *Checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { if f.Validate != nil { args := make([]reflect.Type, len(arguments)) for i, arg := range arguments { argNature := v.visit(arg) - if isUnknown(argNature) { + if argNature.IsUnknown() { args[i] = anyType } else { args[i] = argNature.Type @@ -956,21 +988,21 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if err != nil { return v.error(node, "%v", err) } - return Nature{Type: t} + return v.config.NtCache.FromType(t) } else if len(f.Types) == 0 { - nt, err := v.checkArguments(f.Name, Nature{Type: f.Type()}, arguments, node) + nt, err := v.checkArguments(f.Name, v.config.NtCache.FromType(f.Type()), arguments, node) if err != nil { if v.err == nil { v.err = err } - return unknown + return Nature{} } // No type was specified, so we assume the function returns any. return nt } var lastErr *file.Error for _, t := range f.Types { - outNature, err := v.checkArguments(f.Name, Nature{Type: t}, arguments, node) + outNature, err := v.checkArguments(f.Name, v.config.NtCache.FromType(t), arguments, node) if err != nil { lastErr = err continue @@ -989,30 +1021,31 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] if v.err == nil { v.err = lastErr } - return unknown + return Nature{} } return v.error(node, "no matching overload for %v", f.Name) } -func (v *checker) checkArguments( +func (v *Checker) checkArguments( name string, fn Nature, arguments []ast.Node, node ast.Node, ) (Nature, *file.Error) { - if isUnknown(fn) { - return unknown, nil + if fn.IsUnknown() { + return Nature{}, nil } - if fn.NumOut() == 0 { - return unknown, &file.Error{ + numOut := fn.NumOut() + if numOut == 0 { + return Nature{}, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } - if numOut := fn.NumOut(); numOut > 2 { - return unknown, &file.Error{ + if numOut > 2 { + return Nature{}, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), } @@ -1031,7 +1064,8 @@ func (v *checker) checkArguments( } var err *file.Error - if fn.IsVariadic() { + isVariadic := fn.IsVariadic() + if isVariadic { if len(arguments) < fnNumIn-1 { err = &file.Error{ Location: node.Location(), @@ -1066,31 +1100,31 @@ func (v *checker) checkArguments( argNature := v.visit(arg) var in Nature - if fn.IsVariadic() && i >= fnNumIn-1 { + if isVariadic && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. - in = fn.In(fn.NumIn() - 1).Elem() + in = fn.InElem(fnNumIn - 1) } else { in = fn.In(i + fnInOffset) } - if isFloat(in) && isInteger(argNature) { + if in.IsFloat() && argNature.IsInteger() { traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in) continue } - if isInteger(in) && isInteger(argNature) && argNature.Kind() != in.Kind() { + if in.IsInteger() && argNature.IsInteger() && argNature.Kind != in.Kind { traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) continue } - if isNil(argNature) { - if in.Kind() == reflect.Ptr || in.Kind() == reflect.Interface { + if argNature.Nil { + if in.Kind == reflect.Ptr || in.Kind == reflect.Interface { continue } - return unknown, &file.Error{ + return Nature{}, &file.Error{ Location: arg.Location(), - Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in, name), + Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in.String(), name), } } @@ -1102,12 +1136,15 @@ func (v *checker) checkArguments( // We also need to check if dereference arg type is assignable to the function input type. // For example, func(int) and argument *int. In this case we will add OpDeref to the argument, // so we can call the function with *int argument. - assignable = assignable || argNature.Deref().AssignableTo(in) + if !assignable && argNature.IsPointer() { + nt := argNature.Deref() + assignable = nt.AssignableTo(in) + } - if !assignable && !isUnknown(argNature) { - return unknown, &file.Error{ + if !assignable && !argNature.IsUnknown() { + return Nature{}, &file.Error{ Location: arg.Location(), - Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature, in, name), + Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature.String(), in.String(), name), } } } @@ -1152,44 +1189,45 @@ func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Na } } -func (v *checker) PredicateNode(node *ast.PredicateNode) Nature { +func (v *Checker) predicateNode(node *ast.PredicateNode) Nature { nt := v.visit(node.Node) var out []reflect.Type - if isUnknown(nt) { + if nt.IsUnknown() { out = append(out, anyType) - } else if !isNil(nt) { + } else if !nt.Nil { out = append(out, nt.Type) } - return Nature{ - Type: reflect.FuncOf([]reflect.Type{anyType}, out, false), - PredicateOut: &nt, - } + n := v.config.NtCache.FromType(reflect.FuncOf(anyTypeSlice, out, false)) + n.Ref = &nt + return n } -func (v *checker) PointerNode(node *ast.PointerNode) Nature { +func (v *Checker) pointerNode(node *ast.PointerNode) Nature { if len(v.predicateScopes) == 0 { return v.error(node, "cannot use pointer accessor outside predicate") } scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { - if isUnknown(scope.collection) { - return unknown + if scope.collection.IsUnknown() { + return Nature{} } - switch scope.collection.Kind() { + switch scope.collection.Kind { case reflect.Array, reflect.Slice: return scope.collection.Elem() } return v.error(node, "cannot use %v as array", scope) } if scope.vars != nil { - if t, ok := scope.vars[node.Name]; ok { - return t + for i := range scope.vars { + if node.Name == scope.vars[i].name { + return scope.vars[i].nature + } } } return v.error(node, "unknown pointer #%v", node.Name) } -func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { +func (v *Checker) variableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { if _, ok := v.config.Env.Get(node.Name); ok { return v.error(node, "cannot redeclare %v", node.Name) } @@ -1199,8 +1237,10 @@ func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Natur if _, ok := v.config.Builtins[node.Name]; ok { return v.error(node, "cannot redeclare builtin %v", node.Name) } - if _, ok := v.lookupVariable(node.Name); ok { - return v.error(node, "cannot redeclare variable %v", node.Name) + for i := len(v.varScopes) - 1; i >= 0; i-- { + if v.varScopes[i].name == node.Name { + return v.error(node, "cannot redeclare variable %v", node.Name) + } } varNature := v.visit(node.Value) v.varScopes = append(v.varScopes, varScope{node.Name, varNature}) @@ -1209,7 +1249,7 @@ func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Natur return exprNature } -func (v *checker) SequenceNode(node *ast.SequenceNode) Nature { +func (v *Checker) sequenceNode(node *ast.SequenceNode) Nature { if len(node.Nodes) == 0 { return v.error(node, "empty sequence expression") } @@ -1220,66 +1260,57 @@ func (v *checker) SequenceNode(node *ast.SequenceNode) Nature { return last } -func (v *checker) lookupVariable(name string) (varScope, bool) { - for i := len(v.varScopes) - 1; i >= 0; i-- { - if v.varScopes[i].name == name { - return v.varScopes[i], true - } - } - return varScope{}, false -} - -func (v *checker) ConditionalNode(node *ast.ConditionalNode) Nature { +func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { c := v.visit(node.Cond) - if !isBool(c) && !isUnknown(c) { - return v.error(node.Cond, "non-bool expression (type %v) used as condition", c) + if !c.IsBool() && !c.IsUnknown() { + return v.error(node.Cond, "non-bool expression (type %v) used as condition", c.String()) } t1 := v.visit(node.Exp1) t2 := v.visit(node.Exp2) - if isNil(t1) && !isNil(t2) { + if t1.Nil && !t2.Nil { return t2 } - if !isNil(t1) && isNil(t2) { + if !t1.Nil && t2.Nil { return t1 } - if isNil(t1) && isNil(t2) { - return nilNature + if t1.Nil && t2.Nil { + return v.config.NtCache.NatureOf(nil) } if t1.AssignableTo(t2) { return t1 } - return unknown + return Nature{} } -func (v *checker) ArrayNode(node *ast.ArrayNode) Nature { +func (v *Checker) arrayNode(node *ast.ArrayNode) Nature { var prev Nature allElementsAreSameType := true for i, node := range node.Nodes { curr := v.visit(node) if i > 0 { - if curr.Kind() != prev.Kind() { + if curr.Kind != prev.Kind { allElementsAreSameType = false } } prev = curr } if allElementsAreSameType { - return arrayOf(prev) + return prev.MakeArrayOf() } - return arrayNature + return v.config.NtCache.FromType(arrayType) } -func (v *checker) MapNode(node *ast.MapNode) Nature { +func (v *Checker) mapNode(node *ast.MapNode) Nature { for _, pair := range node.Pairs { v.visit(pair) } - return mapNature + return v.config.NtCache.FromType(mapType) } -func (v *checker) PairNode(node *ast.PairNode) Nature { +func (v *Checker) pairNode(node *ast.PairNode) Nature { v.visit(node.Key) v.visit(node.Value) - return nilNature + return v.config.NtCache.NatureOf(nil) } diff --git a/checker/checker_bench_test.go b/checker/checker_bench_test.go new file mode 100644 index 00000000..0583d735 --- /dev/null +++ b/checker/checker_bench_test.go @@ -0,0 +1,114 @@ +package checker_test + +import ( + "runtime" + "testing" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/checker" + "github.com/expr-lang/expr/checker/nature" + "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/parser" +) + +func BenchmarkChecker(b *testing.B) { + cases := []struct { + name, input string + }{ + {"function calls", ` +func( + func( + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + ), + func( + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + ), + func( + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + func(func(a, 'a', 1, nil), func(a, 'a', 1, nil), func(a, 'a', 1, nil)), + ) +) + `}, + {"unary and binary operations", ` +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 && +!b && !b || !b == !b && !b != !b || 1 < 1.0 && 0.1 > 1 || 0 <= 1.0 && 0.1 >= 1 + `}, + {"deep struct access", ` +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a. +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a. +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a. +a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a.a + `}, + } + + f := func(params ...any) (any, error) { return nil, nil } + env := map[string]any{ + "a": new(recursive), + "b": true, + "func": f, + } + config := conf.New(env) + expr.Function("func", f, f)(config) + expr.ConstExpr("func")(config) + + for _, c := range cases { + batchSize := 100_000 + if batchSize > b.N { + batchSize = b.N + } + trees := make([]*parser.Tree, 0, batchSize) + for i := 0; i < batchSize; i++ { + tree, err := parser.ParseWithConfig(c.input, config) + if err != nil { + b.Fatal(err) + } + trees = append(trees, tree) + } + runtime.GC() // try to cleanup the mess from the initialization + + b.Run("name="+c.name, func(b *testing.B) { + var err error + for i := 0; i < b.N; i++ { + j := i + if j < 0 || j >= len(trees) { + b.StopTimer() + invalidateTrees(trees...) + j = 0 + b.StartTimer() + } + + _, err = checker.Check(trees[j], config) + } + b.StopTimer() + if err != nil { + b.Fatal(err) + } + }) + } +} + +type visitorFunc func(*ast.Node) + +func (f visitorFunc) Visit(node *ast.Node) { f(node) } + +func invalidateTrees(trees ...*parser.Tree) { + for _, tree := range trees { + ast.Walk(&tree.Node, visitorFunc(func(node *ast.Node) { + (*node).SetNature(nature.Nature{}) + })) + } +} + +type recursive struct { + Inner *recursive `expr:"a"` +} diff --git a/checker/checker_test.go b/checker/checker_test.go index 2ec7c7cf..bf2b812d 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -136,6 +136,7 @@ func TestCheck(t *testing.T) { {"(Embed).EmbedPointerEmbedInt > 0"}, } + c := new(checker.Checker) for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { var err error @@ -145,7 +146,7 @@ func TestCheck(t *testing.T) { config := conf.New(mock.Env{}) expr.AsBool()(config) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) assert.NoError(t, err) }) } @@ -682,12 +683,13 @@ invalid operation: + (mismatched types int and bool) (1:6) }, } + c := new(checker.Checker) for _, tt := range errorTests { t.Run(tt.code, func(t *testing.T) { tree, err := parser.Parse(tt.code) require.NoError(t, err) - _, err = checker.Check(tree, conf.New(mock.Env{})) + _, err = c.Check(tree, conf.New(mock.Env{})) if err == nil { err = fmt.Errorf("") } @@ -839,6 +841,7 @@ func TestCheck_TypeWeights(t *testing.T) { "Float32": float32(11), "Float64": float64(12), } + c := new(checker.Checker) for a := range types { for b := range types { tree, err := parser.Parse(fmt.Sprintf("%s + %s", a, b)) @@ -846,7 +849,7 @@ func TestCheck_TypeWeights(t *testing.T) { config := conf.New(types) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) require.NoError(t, err) } } @@ -926,6 +929,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { config := conf.CreateNew() add(config) + c := new(checker.Checker) tests := []string{ "add(1)", @@ -938,7 +942,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { tree, err := parser.Parse(test) require.NoError(t, err) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) require.NoError(t, err) require.Equal(t, reflect.Int, tree.Node.Type().Kind()) }) @@ -948,7 +952,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { tree, err := parser.Parse("add(1, '2')") require.NoError(t, err) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) require.Error(t, err) require.Equal(t, "cannot use string as argument (type int) to call add (1:8)\n | add(1, '2')\n | .......^", err.Error()) }) @@ -1046,12 +1050,13 @@ func TestCheck_env_keyword(t *testing.T) { {`$env[name]`, reflect.Interface}, } + c := new(checker.Checker) for _, test := range tests { t.Run(test.input, func(t *testing.T) { tree, err := parser.Parse(test.input) require.NoError(t, err) - rtype, err := checker.Check(tree, conf.New(env)) + rtype, err := c.Check(tree, conf.New(env)) require.NoError(t, err) require.True(t, rtype.Kind() == test.want, fmt.Sprintf("expected %s, got %s", test.want, rtype.Kind())) }) @@ -1067,12 +1072,13 @@ func TestCheck_builtin_without_call(t *testing.T) { {`string.A`, "type func(interface {}) string has no field A (1:8)\n | string.A\n | .......^"}, } + c := new(checker.Checker) for _, test := range tests { t.Run(test.input, func(t *testing.T) { tree, err := parser.Parse(test.input) require.NoError(t, err) - _, err = checker.Check(tree, conf.New(nil)) + _, err = c.Check(tree, conf.New(nil)) require.Error(t, err) require.Equal(t, test.err, err.Error()) }) @@ -1134,13 +1140,14 @@ func TestCheck_types(t *testing.T) { {`arr | filter(.value contains "a") | filter(.value == 0)`, `invalid operation: == (mismatched types string and int)`}, } + c := new(checker.Checker) for _, test := range tests { t.Run(test.code, func(t *testing.T) { tree, err := parser.Parse(test.code) require.NoError(t, err) config := conf.New(env) - _, err = checker.Check(tree, config) + _, err = c.Check(tree, config) if test.err == noerr { require.NoError(t, err) } else { diff --git a/checker/info.go b/checker/info.go index f1cc92eb..6b2fba3e 100644 --- a/checker/info.go +++ b/checker/info.go @@ -8,22 +8,20 @@ import ( "github.com/expr-lang/expr/vm" ) -func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { +func FieldIndex(c *Cache, env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if env.Kind() == reflect.Struct { - if field, ok := env.Get(n.Value); ok && len(field.FieldIndex) > 0 { - return true, field.FieldIndex, n.Value - } + if idx, ok := env.FieldIndex(n.Value); ok { + return true, idx, n.Value } case *ast.MemberNode: base := n.Node.Nature() + base.Bind(c) // AST doesn't cache nature info base = base.Deref() - if base.Kind() == reflect.Struct { + if base.Kind == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { - name := prop.Value - if field, ok := base.FieldByName(name); ok { - return true, field.FieldIndex, name + if idx, ok := base.FieldIndex(prop.Value); ok { + return true, idx, prop.Value } } } @@ -31,11 +29,11 @@ func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { return false, nil, "" } -func MethodIndex(env Nature, node ast.Node) (bool, int, string) { +func MethodIndex(c *Cache, env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if env.Kind() == reflect.Struct { - if m, ok := env.Get(n.Value); ok { + if env.Kind == reflect.Struct { + if m, ok := env.Get(n.Value); ok && m.FuncData != nil { return m.Method, m.MethodIndex, n.Value } } @@ -121,7 +119,7 @@ func IsFastFunc(fn reflect.Type, method bool) bool { fn.NumOut() == 1 && fn.Out(0).Kind() == reflect.Interface { rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods - if kind(rest) == reflect.Slice && rest.Elem().Kind() == reflect.Interface { + if rest != nil && rest.Kind() == reflect.Slice && rest.Elem().Kind() == reflect.Interface { return true } } diff --git a/checker/nature/nature.go b/checker/nature/nature.go index 993c9fcf..9f453187 100644 --- a/checker/nature/nature.go +++ b/checker/nature/nature.go @@ -1,247 +1,480 @@ package nature import ( + "fmt" "reflect" + "time" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/internal/deref" ) var ( - unknown = Nature{} + intType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) + arrayType = reflect.TypeOf([]any{}) + timeType = reflect.TypeOf(time.Time{}) + durationType = reflect.TypeOf(time.Duration(0)) +) + +type NatureCheck int + +const ( + _ NatureCheck = iota + BoolCheck + StringCheck + IntegerCheck + NumberCheck + MapCheck + ArrayCheck + TimeCheck + DurationCheck ) type Nature struct { - Type reflect.Type // Type of the value. If nil, then value is unknown. - Func *builtin.Function // Used to pass function type from callee to CallNode. - ArrayOf *Nature // Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). - PredicateOut *Nature // Out nature of predicate. + // The order of the fields matter, check alignment before making changes. + + Type reflect.Type // Type of the value. If nil, then value is unknown. + Kind reflect.Kind // Kind of the value. + + cache *Cache + *Optional + *FuncData + + // Ref is a reference used for multiple, disjoint purposes. When the Nature + // is for a: + // - Predicate: then Ref is the nature of the Out of the predicate. + // - Array-like types: then Ref is the Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). + Ref *Nature + + Nil bool // If value is nil. + Strict bool // If map is types.StrictMap. + Method bool // If value retrieved from method. Usually used to determine amount of in arguments. +} + +type Optional struct { + pkgPath string + methodset *methodset // optional to avoid the map in *Cache + + *structData + + // map-only data Fields map[string]Nature // Fields of map type. DefaultMapValue *Nature // Default value of map type. - Strict bool // If map is types.StrictMap. - Nil bool // If value is nil. - Method bool // If value retrieved from method. Usually used to determine amount of in arguments. - MethodIndex int // Index of method in type. - FieldIndex []int // Index of field in type. + + pkgPathSet bool } -func (n Nature) IsAny() bool { - return n.Kind() == reflect.Interface && n.NumMethods() == 0 +type FuncData struct { + Func *builtin.Function // Used to pass function type from callee to CallNode. + MethodIndex int // Index of method in type. + + inElem, outZero *Nature + numIn, numOut int + + isVariadic bool + isVariadicSet bool + numInSet bool + numOutSet bool } -func (n Nature) IsUnknown() bool { - switch { - case n.Type == nil && !n.Nil: - return true - case n.IsAny(): - return true +// Cache is a shared cache of type information. It is only used in the stages +// where type information becomes relevant, so packages like ast, parser, types, +// and lexer do not need to use the cache because they don't need any service +// from the Nature type, they only describe. However, when receiving a Nature +// from one of those packages, the cache must be set immediately. +type Cache struct { + methods map[reflect.Type]*methodset + structs map[reflect.Type]Nature +} + +// NatureOf returns a Nature describing "i". If "i" is nil then it returns a +// Nature describing the value "nil". +func (c *Cache) NatureOf(i any) Nature { + // reflect.TypeOf(nil) returns nil, but in FromType we want to differentiate + // what nil means for us + if i == nil { + return Nature{cache: c, Nil: true} } - return false + return c.FromType(reflect.TypeOf(i)) } -func (n Nature) String() string { - if n.Type != nil { - return n.Type.String() +// FromType returns a Nature describing a value of type "t". If "t" is nil then +// it returns a Nature describing an unknown value. +func (c *Cache) FromType(t reflect.Type) Nature { + if t == nil { + return Nature{} } - return "unknown" + var fd *FuncData + k := t.Kind() + switch k { + case reflect.Struct: + return c.getStruct(t) + case reflect.Func: + fd = new(FuncData) + } + return Nature{Type: t, Kind: k, FuncData: fd, cache: c} } -func (n Nature) Deref() Nature { - if n.Type != nil { - n.Type = deref.Type(n.Type) +func (c *Cache) getStruct(t reflect.Type) Nature { + if c != nil { + if c.structs == nil { + c.structs = map[reflect.Type]Nature{} + } else if nt, ok := c.structs[t]; ok { + return nt + } + } + nt := Nature{ + Type: t, + Kind: reflect.Struct, + Optional: &Optional{ + structData: &structData{ + cache: c, + rType: t, + numField: t.NumField(), + anonIdx: -1, // do not lookup embedded fields yet + }, + }, + } + if c != nil { + nt.Bind(c) + } + return nt +} + +func (c *Cache) getMethodset(t reflect.Type, k reflect.Kind) *methodset { + if t == nil || c == nil { + return nil + } + if c.methods == nil { + c.methods = map[reflect.Type]*methodset{ + t: nil, + } + } else if s, ok := c.methods[t]; ok { + return s + } + numMethod := t.NumMethod() + if numMethod < 1 { + c.methods[t] = nil // negative cache + return nil + } + s := &methodset{ + cache: c, + rType: t, + kind: k, + numMethod: numMethod, + } + c.methods[t] = s + return s +} + +// NatureOf calls NatureOf on a nil *Cache. See the comment on Cache. +func NatureOf(i any) Nature { + var c *Cache + return c.NatureOf(i) +} + +// FromType calls FromType on a nil *Cache. See the comment on Cache. +func FromType(t reflect.Type) Nature { + var c *Cache + return c.FromType(t) +} + +func ArrayFromType(c *Cache, t reflect.Type) Nature { + elem := c.FromType(t) + nt := c.FromType(arrayType) + nt.Ref = &elem + return nt +} + +func (n *Nature) Bind(c *Cache) { + if n.cache == c { + return + } + n.cache = c + if n.Kind == reflect.Struct { + if c.structs == nil { + n.structData.bind(c) + c.structs = map[reflect.Type]Nature{ + n.Type: *n, + } + } else if nt, ok := c.structs[n.Type]; ok { + // invalidate local, use shared from cache + n.Optional.structData = nt.Optional.structData + } else { + n.structData.bind(c) + c.structs[n.Type] = *n + } + } + hasMethodset := n.Optional != nil && n.Optional.methodset != nil + if c.methods != nil || hasMethodset { + if c.methods == nil { + // Cache is new and the type already gathered some methods + n.Optional.methodset.bind(c) + c.methods = map[reflect.Type]*methodset{ + n.Type: n.Optional.methodset, + } + } else if s, ok := c.methods[n.Type]; ok { + if n.Optional == nil { + n.Optional = new(Optional) + } + // Cache is not new. Invalidate local if set + n.Optional.methodset = s + } else if hasMethodset { + // Cache miss and the type already gathered some methods + n.Optional.methodset.bind(c) + c.methods[n.Type] = n.Optional.methodset + } } - return n } -func (n Nature) Kind() reflect.Kind { +func (n *Nature) IsAny() bool { + return n.Type != nil && n.Kind == reflect.Interface && n.NumMethods() == 0 +} + +func (n *Nature) IsUnknown() bool { + return n.Type == nil && !n.Nil || n.IsAny() +} + +func (n *Nature) String() string { if n.Type != nil { - return n.Type.Kind() + return n.Type.String() + } + return "unknown" +} + +func (n *Nature) Deref() Nature { + t, _, changed := deref.TypeKind(n.Type, n.Kind) + if !changed { + return *n } - return reflect.Invalid + return n.cache.FromType(t) } -func (n Nature) Key() Nature { - if n.Kind() == reflect.Map { - return Nature{Type: n.Type.Key()} +func (n *Nature) Key() Nature { + if n.Kind == reflect.Map { + return n.cache.FromType(n.Type.Key()) } - return unknown + return Nature{} } -func (n Nature) Elem() Nature { - switch n.Kind() { +func (n *Nature) Elem() Nature { + switch n.Kind { case reflect.Ptr: - return Nature{Type: n.Type.Elem()} + return n.cache.FromType(n.Type.Elem()) case reflect.Map: - if n.DefaultMapValue != nil { + if n.Optional != nil && n.DefaultMapValue != nil { return *n.DefaultMapValue } - return Nature{Type: n.Type.Elem()} - case reflect.Array, reflect.Slice: - if n.ArrayOf != nil { - return *n.ArrayOf + return n.cache.FromType(n.Type.Elem()) + case reflect.Slice, reflect.Array: + if n.Ref != nil { + return *n.Ref } - return Nature{Type: n.Type.Elem()} + return n.cache.FromType(n.Type.Elem()) } - return unknown + return Nature{} } -func (n Nature) AssignableTo(nt Nature) bool { +func (n *Nature) AssignableTo(nt Nature) bool { if n.Nil { - // Untyped nil is assignable to any interface, but implements only the empty interface. - if nt.IsAny() { + switch nt.Kind { + case reflect.Pointer, reflect.Interface: return true } } - if n.Type == nil || nt.Type == nil { + if n.Type == nil || nt.Type == nil || + n.Kind != nt.Kind && nt.Kind != reflect.Interface { return false } return n.Type.AssignableTo(nt.Type) } -func (n Nature) NumMethods() int { - if n.Type == nil { - return 0 +func (n *Nature) getMethodset() *methodset { + if n.Optional != nil && n.Optional.methodset != nil { + return n.Optional.methodset + } + s := n.cache.getMethodset(n.Type, n.Kind) + if n.Optional != nil { + n.Optional.methodset = s // cache locally if possible } - return n.Type.NumMethod() + return s } -func (n Nature) MethodByName(name string) (Nature, bool) { - if n.Type == nil { - return unknown, false +func (n *Nature) NumMethods() int { + if s := n.getMethodset(); s != nil { + return s.numMethod } - method, ok := n.Type.MethodByName(name) - if !ok { - return unknown, false + return 0 +} + +func (n *Nature) MethodByName(name string) (Nature, bool) { + if s := n.getMethodset(); s != nil { + if m := s.method(name); m != nil { + return m.nature, true + } } + return Nature{}, false +} - if n.Type.Kind() == reflect.Interface { - // In case of interface type method will not have a receiver, - // and to prevent checker decreasing numbers of in arguments - // return method type as not method (second argument is false). +func (n *Nature) NumIn() int { + if n.numInSet { + return n.numIn + } + n.numInSet = true + n.numIn = n.Type.NumIn() + return n.numIn +} - // Also, we can not use m.Index here, because it will be - // different indexes for different types which implement - // the same interface. - return Nature{Type: method.Type}, true - } else { - return Nature{ - Type: method.Type, - Method: true, - MethodIndex: method.Index, - }, true +func (n *Nature) InElem(i int) Nature { + if n.inElem == nil { + n2 := n.cache.FromType(n.Type.In(i)) + n2 = n2.Elem() + n.inElem = &n2 } + return *n.inElem } -func (n Nature) NumIn() int { - if n.Type == nil { - return 0 +func (n *Nature) In(i int) Nature { + return n.cache.FromType(n.Type.In(i)) +} + +func (n *Nature) IsFirstArgUnknown() bool { + if n.Type != nil { + n2 := n.cache.FromType(n.Type.In(0)) + return n2.IsUnknown() } - return n.Type.NumIn() + return false } -func (n Nature) In(i int) Nature { - if n.Type == nil { - return unknown +func (n *Nature) NumOut() int { + if n.numOutSet { + return n.numOut } - return Nature{Type: n.Type.In(i)} + n.numOutSet = true + n.numOut = n.Type.NumOut() + return n.numOut } -func (n Nature) NumOut() int { - if n.Type == nil { - return 0 +func (n *Nature) Out(i int) Nature { + if i != 0 { + return n.out(i) + } + if n.outZero != nil { + return *n.outZero } - return n.Type.NumOut() + nt := n.out(0) + n.outZero = &nt + return nt } -func (n Nature) Out(i int) Nature { +func (n *Nature) out(i int) Nature { if n.Type == nil { - return unknown + return Nature{} } - return Nature{Type: n.Type.Out(i)} + return n.cache.FromType(n.Type.Out(i)) } -func (n Nature) IsVariadic() bool { - if n.Type == nil { - return false +func (n *Nature) IsVariadic() bool { + if n.isVariadicSet { + return n.isVariadic } - return n.Type.IsVariadic() + n.isVariadicSet = true + n.isVariadic = n.Type.IsVariadic() + return n.isVariadic } -func (n Nature) FieldByName(name string) (Nature, bool) { - if n.Type == nil { - return unknown, false +func (n *Nature) FieldByName(name string) (Nature, bool) { + if n.Kind != reflect.Struct { + return Nature{}, false + } + var sd *structData + if n.Optional != nil && n.structData != nil { + sd = n.structData + } else { + sd = n.cache.getStruct(n.Type).structData + } + if sf := sd.structField(nil, name); sf != nil { + return sf.Nature, true } - field, ok := fetchField(n.Type, name) - return Nature{Type: field.Type, FieldIndex: field.Index}, ok + return Nature{}, false } -func (n Nature) PkgPath() string { +func (n *Nature) PkgPath() string { if n.Type == nil { return "" } - return n.Type.PkgPath() + if n.Optional != nil && n.Optional.pkgPathSet { + return n.Optional.pkgPath + } + p := n.Type.PkgPath() + if n.Optional != nil { + n.Optional.pkgPathSet = true + n.Optional.pkgPath = p + } + return p } -func (n Nature) IsFastMap() bool { - if n.Type == nil { - return false - } - if n.Type.Kind() == reflect.Map && +func (n *Nature) IsFastMap() bool { + return n.Type != nil && + n.Type.Kind() == reflect.Map && n.Type.Key().Kind() == reflect.String && - n.Type.Elem().Kind() == reflect.Interface { - return true - } - return false + n.Type.Elem().Kind() == reflect.Interface } -func (n Nature) Get(name string) (Nature, bool) { - if n.Type == nil { - return unknown, false +func (n *Nature) Get(name string) (Nature, bool) { + if n.Kind == reflect.Map && n.Optional != nil { + f, ok := n.Fields[name] + return f, ok } + return n.getSlow(name) +} - if m, ok := n.MethodByName(name); ok { - return m, true +func (n *Nature) getSlow(name string) (Nature, bool) { + if nt, ok := n.MethodByName(name); ok { + return nt, true } - - t := deref.Type(n.Type) - - switch t.Kind() { - case reflect.Struct: - if f, ok := fetchField(t, name); ok { - return Nature{ - Type: f.Type, - FieldIndex: f.Index, - }, true - } - case reflect.Map: - if f, ok := n.Fields[name]; ok { - return f, true + if n.Kind == reflect.Struct { + if sf := n.structField(nil, name); sf != nil { + return sf.Nature, true } } - return unknown, false + return Nature{}, false +} + +func (n *Nature) FieldIndex(name string) ([]int, bool) { + if n.Kind != reflect.Struct { + return nil, false + } + if sf := n.structField(nil, name); sf != nil { + return sf.Index, true + } + return nil, false } -func (n Nature) All() map[string]Nature { +func (n *Nature) All() map[string]Nature { table := make(map[string]Nature) if n.Type == nil { return table } - for i := 0; i < n.Type.NumMethod(); i++ { + for i := 0; i < n.NumMethods(); i++ { method := n.Type.Method(i) - table[method.Name] = Nature{ - Type: method.Type, - Method: true, - MethodIndex: method.Index, + nt := n.cache.FromType(method.Type) + if nt.Optional == nil { + nt.FuncData = new(FuncData) } + nt.Method = true + nt.MethodIndex = method.Index + table[method.Name] = nt } t := deref.Type(n.Type) switch t.Kind() { case reflect.Struct: - for name, nt := range StructFields(t) { + for name, nt := range StructFields(n.cache, t) { if _, ok := table[name]; ok { continue } @@ -249,13 +482,134 @@ func (n Nature) All() map[string]Nature { } case reflect.Map: - for key, nt := range n.Fields { - if _, ok := table[key]; ok { - continue + if n.Optional != nil { + for key, nt := range n.Fields { + if _, ok := table[key]; ok { + continue + } + table[key] = nt } - table[key] = nt } } return table } + +func (n *Nature) IsNumber() bool { + return n.IsInteger() || n.IsFloat() +} + +func (n *Nature) IsInteger() bool { + switch n.Kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return n.PkgPath() == "" + } + return false +} + +func (n *Nature) IsFloat() bool { + switch n.Kind { + case reflect.Float32, reflect.Float64: + return n.PkgPath() == "" + } + return false +} + +func (n *Nature) PromoteNumericNature(rhs Nature) Nature { + if n.IsUnknown() || rhs.IsUnknown() { + return Nature{} + } + if n.IsFloat() || rhs.IsFloat() { + return n.cache.FromType(floatType) + } + return n.cache.FromType(intType) +} + +func (n *Nature) IsTime() bool { + return n.Type == timeType +} + +func (n *Nature) IsDuration() bool { + return n.Type == durationType +} + +func (n *Nature) IsBool() bool { + return n.Kind == reflect.Bool +} + +func (n *Nature) IsString() bool { + return n.Kind == reflect.String +} + +func (n *Nature) IsArray() bool { + k := n.Kind + return k == reflect.Slice || k == reflect.Array +} + +func (n *Nature) IsMap() bool { + return n.Kind == reflect.Map +} + +func (n *Nature) IsStruct() bool { + return n.Kind == reflect.Struct +} + +func (n *Nature) IsFunc() bool { + return n.Kind == reflect.Func +} + +func (n *Nature) IsPointer() bool { + return n.Kind == reflect.Ptr +} + +func (n *Nature) IsAnyOf(cs ...NatureCheck) bool { + var result bool + for i := 0; i < len(cs) && !result; i++ { + switch cs[i] { + case BoolCheck: + result = n.IsBool() + case StringCheck: + result = n.IsString() + case IntegerCheck: + result = n.IsInteger() + case NumberCheck: + result = n.IsNumber() + case MapCheck: + result = n.IsMap() + case ArrayCheck: + result = n.IsArray() + case TimeCheck: + result = n.IsTime() + case DurationCheck: + result = n.IsDuration() + default: + panic(fmt.Sprintf("unknown check value %d", cs[i])) + } + } + return result +} + +func (n *Nature) ComparableTo(rhs Nature) bool { + return n.IsUnknown() || rhs.IsUnknown() || + n.Nil || rhs.Nil || + n.IsNumber() && rhs.IsNumber() || + n.IsDuration() && rhs.IsDuration() || + n.IsTime() && rhs.IsTime() || + n.IsArray() && rhs.IsArray() || + n.AssignableTo(rhs) +} + +func (n *Nature) MaybeCompatible(rhs Nature, cs ...NatureCheck) bool { + nIsUnknown := n.IsUnknown() + rshIsUnknown := rhs.IsUnknown() + return nIsUnknown && rshIsUnknown || + nIsUnknown && rhs.IsAnyOf(cs...) || + rshIsUnknown && n.IsAnyOf(cs...) +} + +func (n *Nature) MakeArrayOf() Nature { + nt := n.cache.FromType(arrayType) + nt.Ref = n + return nt +} diff --git a/checker/nature/utils.go b/checker/nature/utils.go index c1551546..3c7a33c2 100644 --- a/checker/nature/utils.go +++ b/checker/nature/utils.go @@ -6,84 +6,242 @@ import ( "github.com/expr-lang/expr/internal/deref" ) -func fieldName(field reflect.StructField) (string, bool) { - switch taggedName := field.Tag.Get("expr"); taggedName { +func fieldName(fieldName string, tag reflect.StructTag) (string, bool) { + switch taggedName := tag.Get("expr"); taggedName { case "-": return "", false case "": - return field.Name, true + return fieldName, true default: return taggedName, true } } -func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { - // If t is not a struct, early return. - if t.Kind() != reflect.Struct { - return reflect.StructField{}, false +type structData struct { + cache *Cache + rType reflect.Type + fields map[string]*structField + numField, ownIdx, anonIdx int + + curParent, curChild *structData + curChildIndex []int +} + +type structField struct { + Nature + Index []int +} + +func (s *structData) bind(c *Cache) { + s.cache = c + for _, sf := range s.fields { + sf.Bind(c) + } +} + +func (s *structData) finished() bool { + return s.ownIdx >= s.numField && // no own fields left to visit + s.anonIdx >= s.numField && // no embedded fields to visit + s.curChild == nil // no child in process of visiting +} + +func (s *structData) structField(parentEmbed *structData, name string) *structField { + if s.fields == nil { + if s.numField > 0 { + s.fields = make(map[string]*structField, s.numField) + } + } else if f := s.fields[name]; f != nil { + return f + } + if s.finished() { + return nil + } + + // Lookup own fields first. + for ; s.ownIdx < s.numField; s.ownIdx++ { + field := s.rType.Field(s.ownIdx) + // BUG: we should skip if !field.IsExported() here + + if field.Anonymous && s.anonIdx < 0 { + // start iterating anon fields on the first instead of zero + s.anonIdx = s.ownIdx + } + fName, ok := fieldName(field.Name, field.Tag) + if !ok || fName == "" { + // name can still be empty for a type created at runtime with + // reflect + continue + } + nt := s.cache.FromType(field.Type) + sf := &structField{ + Nature: nt, + Index: field.Index, + } + s.fields[fName] = sf + if parentEmbed != nil { + parentEmbed.trySet(fName, sf) + } + if fName == name { + return sf + } } - // First check all structs fields. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - // Search all fields, even embedded structs. - if n, ok := fieldName(field); ok && n == name { - return field, true + if s.curChild != nil { + sf := s.findInEmbedded(parentEmbed, s.curChild, s.curChildIndex, name) + if sf != nil { + return sf } } - // Second check fields of embedded structs. - for i := 0; i < t.NumField(); i++ { - anon := t.Field(i) - if anon.Anonymous { - anonType := anon.Type - if anonType.Kind() == reflect.Pointer { - anonType = anonType.Elem() - } - if field, ok := fetchField(anonType, name); ok { - field.Index = append(anon.Index, field.Index...) - return field, true - } + // Lookup embedded fields through anon own fields + for ; s.anonIdx >= 0 && s.anonIdx < s.numField; s.anonIdx++ { + field := s.rType.Field(s.anonIdx) + // we do enter embedded non-exported types because they could contain + // exported fields + if !field.Anonymous { + continue + } + t, k, _ := deref.TypeKind(field.Type, field.Type.Kind()) + if k != reflect.Struct { + continue + } + + childEmbed := s.cache.getStruct(t).structData + sf := s.findInEmbedded(parentEmbed, childEmbed, field.Index, name) + if sf != nil { + return sf } } - return reflect.StructField{}, false + return nil } -func StructFields(t reflect.Type) map[string]Nature { - table := make(map[string]Nature) +func (s *structData) findInEmbedded( + parentEmbed, childEmbed *structData, + childIndex []int, + name string, +) *structField { + // Set current parent/child data. This allows trySet to handle child fields + // and add them to our struct and to the parent as well if needed + s.curParent = parentEmbed + s.curChild = childEmbed + s.curChildIndex = childIndex + defer func() { + // Ensure to cleanup references + s.curParent = nil + if childEmbed.finished() { + // If the child can still have more fields to explore then keep it + // referened to look it up again if we need to + s.curChild = nil + s.curChildIndex = nil + } + }() - t = deref.Type(t) + // See if the child has already cached its fields. This is still important + // to check even if it's the s.unfinishedEmbedded because it may have + // explored new fields since the last time we visited it + for name, sf := range childEmbed.fields { + s.trySet(name, sf) + } + + // Recheck if we have what we needed from the above sync + if sf := s.fields[name]; sf != nil { + return sf + } + + // Try finding in the child again in case it hasn't finished + if !childEmbed.finished() { + if childEmbed.structField(s, name) != nil { + return s.fields[name] + } + } + + return nil +} + +func (s *structData) trySet(name string, sf *structField) { + if _, ok := s.fields[name]; ok { + return + } + sf = &structField{ + Nature: sf.Nature, + Index: append(s.curChildIndex, sf.Index...), + } + s.fields[name] = sf + if s.curParent != nil { + s.curParent.trySet(name, sf) + } +} + +func StructFields(c *Cache, t reflect.Type) map[string]Nature { + table := make(map[string]Nature) if t == nil { return table } + t, k, _ := deref.TypeKind(t, t.Kind()) + if k == reflect.Struct { + // lookup for a field with an empty name, which will cause to never find a + // match, meaning everything will have been cached. + sd := c.getStruct(t).structData + sd.structField(nil, "") + for name, sf := range sd.fields { + table[name] = sf.Nature + } + } + return table +} - switch t.Kind() { - case reflect.Struct: - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) +type methodset struct { + cache *Cache + rType reflect.Type + kind reflect.Kind + methods map[string]*method + numMethod, idx int +} - if f.Anonymous { - for name, typ := range StructFields(f.Type) { - if _, ok := table[name]; ok { - continue - } - typ.FieldIndex = append(f.Index, typ.FieldIndex...) - table[name] = typ - } - } +type method struct { + reflect.Method + nature Nature +} - name, ok := fieldName(f) - if !ok { - continue - } - table[name] = Nature{ - Type: f.Type, - FieldIndex: f.Index, - } +func (s *methodset) bind(c *Cache) { + s.cache = c + for _, m := range s.methods { + m.nature.Bind(c) + } +} - } +func (s *methodset) method(name string) *method { + if s.methods == nil { + s.methods = make(map[string]*method, s.numMethod) + } else if m := s.methods[name]; m != nil { + return m } + for ; s.idx < s.numMethod; s.idx++ { + rm := s.rType.Method(s.idx) + if !rm.IsExported() { + continue + } + nt := s.cache.FromType(rm.Type) + if s.rType.Kind() != reflect.Interface { + nt.Method = true + nt.MethodIndex = rm.Index + // In case of interface type method will not have a receiver, + // and to prevent checker decreasing numbers of in arguments + // return method type as not method (second argument is false). - return table + // Also, we can not use m.Index here, because it will be + // different indexes for different types which implement + // the same interface. + } + m := &method{ + Method: rm, + nature: nt, + } + s.methods[rm.Name] = m + if rm.Name == name { + return m + } + } + return nil } diff --git a/checker/types.go b/checker/types.go deleted file mode 100644 index 09896de5..00000000 --- a/checker/types.go +++ /dev/null @@ -1,190 +0,0 @@ -package checker - -import ( - "reflect" - "time" - - . "github.com/expr-lang/expr/checker/nature" -) - -var ( - unknown = Nature{} - nilNature = Nature{Nil: true} - boolNature = Nature{Type: reflect.TypeOf(true)} - integerNature = Nature{Type: reflect.TypeOf(0)} - floatNature = Nature{Type: reflect.TypeOf(float64(0))} - stringNature = Nature{Type: reflect.TypeOf("")} - arrayNature = Nature{Type: reflect.TypeOf([]any{})} - mapNature = Nature{Type: reflect.TypeOf(map[string]any{})} - timeNature = Nature{Type: reflect.TypeOf(time.Time{})} - durationNature = Nature{Type: reflect.TypeOf(time.Duration(0))} -) - -var ( - anyType = reflect.TypeOf(new(any)).Elem() - timeType = reflect.TypeOf(time.Time{}) - durationType = reflect.TypeOf(time.Duration(0)) - arrayType = reflect.TypeOf([]any{}) -) - -func arrayOf(nt Nature) Nature { - return Nature{ - Type: arrayType, - ArrayOf: &nt, - } -} - -func isNil(nt Nature) bool { - return nt.Nil -} - -func combined(l, r Nature) Nature { - if isUnknown(l) || isUnknown(r) { - return unknown - } - if isFloat(l) || isFloat(r) { - return floatNature - } - return integerNature -} - -func anyOf(nt Nature, fns ...func(Nature) bool) bool { - for _, fn := range fns { - if fn(nt) { - return true - } - } - return false -} - -func or(l, r Nature, fns ...func(Nature) bool) bool { - if isUnknown(l) && isUnknown(r) { - return true - } - if isUnknown(l) && anyOf(r, fns...) { - return true - } - if isUnknown(r) && anyOf(l, fns...) { - return true - } - return false -} - -func isUnknown(nt Nature) bool { - return nt.IsUnknown() -} - -func isInteger(nt Nature) bool { - switch nt.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fallthrough - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return nt.PkgPath() == "" - } - return false -} - -func isFloat(nt Nature) bool { - switch nt.Kind() { - case reflect.Float32, reflect.Float64: - return nt.PkgPath() == "" - } - return false -} - -func isNumber(nt Nature) bool { - return isInteger(nt) || isFloat(nt) -} - -func isTime(nt Nature) bool { - switch nt.Type { - case timeType: - return true - } - return false -} - -func isDuration(nt Nature) bool { - switch nt.Type { - case durationType: - return true - } - return false -} - -func isBool(nt Nature) bool { - switch nt.Kind() { - case reflect.Bool: - return true - } - return false -} - -func isString(nt Nature) bool { - switch nt.Kind() { - case reflect.String: - return true - } - return false -} - -func isArray(nt Nature) bool { - switch nt.Kind() { - case reflect.Slice, reflect.Array: - return true - } - return false -} - -func isMap(nt Nature) bool { - switch nt.Kind() { - case reflect.Map: - return true - } - return false -} - -func isStruct(nt Nature) bool { - switch nt.Kind() { - case reflect.Struct: - return true - } - return false -} - -func isFunc(nt Nature) bool { - switch nt.Kind() { - case reflect.Func: - return true - } - return false -} - -func kind(t reflect.Type) reflect.Kind { - if t == nil { - return reflect.Invalid - } - return t.Kind() -} - -func isComparable(l, r Nature) bool { - if isUnknown(l) || isUnknown(r) { - return true - } - if isNil(l) || isNil(r) { - return true - } - if isNumber(l) && isNumber(r) { - return true - } - if isDuration(l) && isDuration(r) { - return true - } - if isTime(l) && isTime(r) { - return true - } - if isArray(l) && isArray(r) { - return true - } - return l.AssignableTo(r) -} diff --git a/compiler/compiler.go b/compiler/compiler.go index 595355d2..7910aaf8 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -5,6 +5,7 @@ import ( "math" "reflect" "regexp" + "runtime/debug" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" @@ -24,7 +25,7 @@ const ( func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err error) { defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v", r) + err = fmt.Errorf("%v\n%s", r, debug.Stack()) } }() @@ -36,6 +37,12 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro debugInfo: make(map[string]string), } + if config != nil { + c.ntCache = &c.config.NtCache + } else { + c.ntCache = new(Cache) + } + c.compile(tree.Node) if c.config != nil { @@ -74,6 +81,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro type compiler struct { config *conf.Config + ntCache *Cache locations []file.Location bytecode []Opcode variables int @@ -302,12 +310,12 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { if env.IsFastMap() { c.emit(OpLoadFast, c.addConstant(node.Value)) - } else if ok, index, name := checker.FieldIndex(env, node); ok { + } else if ok, index, name := checker.FieldIndex(c.ntCache, env, node); ok { c.emit(OpLoadField, c.addConstant(&runtime.Field{ Index: index, Path: []string{name}, })) - } else if ok, index, name := checker.MethodIndex(env, node); ok { + } else if ok, index, name := checker.MethodIndex(c.ntCache, env, node); ok { c.emit(OpLoadMethod, c.addConstant(&runtime.Method{ Name: name, Index: index, @@ -653,7 +661,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { env = c.config.Env } - if ok, index, name := checker.MethodIndex(env, node); ok { + if ok, index, name := checker.MethodIndex(c.ntCache, env, node); ok { c.compile(node.Node) c.emit(OpMethod, c.addConstant(&runtime.Method{ Name: name, @@ -664,14 +672,14 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { op := OpFetch base := node.Node - ok, index, nodeName := checker.FieldIndex(env, node) + ok, index, nodeName := checker.FieldIndex(c.ntCache, env, node) path := []string{nodeName} if ok { op = OpFetchField for !node.Optional { if ident, isIdent := base.(*ast.IdentifierNode); isIdent { - if ok, identIndex, name := checker.FieldIndex(env, ident); ok { + if ok, identIndex, name := checker.FieldIndex(c.ntCache, env, ident); ok { index = append(identIndex, index...) path = append([]string{name}, path...) c.emitLocation(ident.Location(), OpLoadField, c.addConstant( @@ -682,7 +690,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { } if member, isMember := base.(*ast.MemberNode); isMember { - if ok, memberIndex, name := checker.FieldIndex(env, member); ok { + if ok, memberIndex, name := checker.FieldIndex(c.ntCache, env, member); ok { index = append(memberIndex, index...) path = append([]string{name}, path...) node = member @@ -777,7 +785,7 @@ func (c *compiler) CallNode(node *ast.CallNode) { c.compile(node.Callee) if c.config != nil { - isMethod, _, _ := checker.MethodIndex(c.config.Env, node.Callee) + isMethod, _, _ := checker.MethodIndex(c.ntCache, c.config.Env, node.Callee) if index, ok := checker.TypedFuncIndex(node.Callee.Type(), isMethod); ok { c.emit(OpCallTyped, index) return @@ -1080,7 +1088,9 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { for i, arg := range node.Arguments { c.compile(arg) argType := arg.Type() - if argType.Kind() == reflect.Ptr || arg.Nature().IsUnknown() { + argNature := arg.Nature() + argNature.Bind(c.ntCache) // AST doesn't cache nature info + if argType.Kind() == reflect.Ptr || argNature.IsUnknown() { if f.Deref == nil { // By default, builtins expect arguments to be dereferenced. c.emit(OpDeref) diff --git a/conf/config.go b/conf/config.go index 2c14d988..aba35fea 100644 --- a/conf/config.go +++ b/conf/config.go @@ -34,6 +34,7 @@ type Config struct { Functions FunctionsTable Builtins FunctionsTable Disabled map[string]bool // disabled builtins + NtCache nature.Cache } // CreateNew creates new config with default values. @@ -49,6 +50,7 @@ func CreateNew() *Config { for _, f := range builtin.Builtins { c.Builtins[f.Name] = f } + c.Env.Bind(&c.NtCache) return c } @@ -61,7 +63,7 @@ func New(env any) *Config { func (c *Config) WithEnv(env any) { c.EnvObject = env - c.Env = Env(env) + c.Env = EnvWithCache(&c.NtCache, env) c.Strict = c.Env.Strict } diff --git a/conf/env.go b/conf/env.go index 8b13df1e..ef280c20 100644 --- a/conf/env.go +++ b/conf/env.go @@ -9,35 +9,43 @@ import ( "github.com/expr-lang/expr/types" ) +// Env returns the Nature of the given environment. +// +// Deprecated: use EnvWithCache instead. func Env(env any) Nature { + return EnvWithCache(new(Cache), env) +} + +func EnvWithCache(c *Cache, env any) Nature { if env == nil { - return Nature{ - Type: reflect.TypeOf(map[string]any{}), - Strict: true, - } + n := c.NatureOf(map[string]any{}) + n.Strict = true + return n } switch env := env.(type) { case types.Map: - return env.Nature() + nt := env.Nature() + nt.Bind(c) + return nt } v := reflect.ValueOf(env) - d := deref.Value(v) + t := v.Type() - switch d.Kind() { + switch deref.Value(v).Kind() { case reflect.Struct: - return Nature{ - Type: v.Type(), - Strict: true, - } + n := c.FromType(t) + n.Strict = true + return n case reflect.Map: - n := Nature{ - Type: v.Type(), - Fields: make(map[string]Nature, v.Len()), - Strict: true, + n := c.FromType(v.Type()) + if n.Optional == nil { + n.Optional = new(Optional) } + n.Strict = true + n.Fields = make(map[string]Nature, v.Len()) for _, key := range v.MapKeys() { elem := v.MapIndex(key) @@ -49,14 +57,16 @@ func Env(env any) Nature { switch face := face.(type) { case types.Map: - n.Fields[key.String()] = face.Nature() + nt := face.Nature() + nt.Bind(c) + n.Fields[key.String()] = nt default: if face == nil { - n.Fields[key.String()] = Nature{Nil: true} + n.Fields[key.String()] = c.NatureOf(nil) continue } - n.Fields[key.String()] = Nature{Type: reflect.TypeOf(face)} + n.Fields[key.String()] = c.NatureOf(face) } } diff --git a/docgen/docgen.go b/docgen/docgen.go index 1844f23b..5070d14f 100644 --- a/docgen/docgen.go +++ b/docgen/docgen.go @@ -85,7 +85,8 @@ func CreateDoc(i any) *Context { PkgPath: deref.Type(reflect.TypeOf(i)).PkgPath(), } - for name, t := range conf.Env(i).All() { + env := conf.EnvWithCache(new(nature.Cache), i) + for name, t := range env.All() { if _, ok := c.Variables[Identifier(name)]; ok { continue } @@ -221,7 +222,8 @@ appendix: c.Types[name] = a } - for name, field := range nature.StructFields(t) { + ntCache := new(nature.Cache) + for name, field := range nature.StructFields(ntCache, t) { if isPrivate(name) || isProtobuf(name) { continue } diff --git a/expr.go b/expr.go index 48298fe7..e8f4eb64 100644 --- a/expr.go +++ b/expr.go @@ -48,6 +48,7 @@ func Operator(operator string, fn ...string) Option { Overloads: fn, Env: &c.Env, Functions: c.Functions, + NtCache: &c.NtCache, } c.Visitors = append(c.Visitors, p) } diff --git a/internal/deref/deref.go b/internal/deref/deref.go index da3e28ce..4ad7877f 100644 --- a/internal/deref/deref.go +++ b/internal/deref/deref.go @@ -45,3 +45,12 @@ func Value(v reflect.Value) reflect.Value { } return v } + +func TypeKind(t reflect.Type, k reflect.Kind) (_ reflect.Type, _ reflect.Kind, changed bool) { + for k == reflect.Pointer { + changed = true + t = t.Elem() + k = t.Kind() + } + return t, k, changed +} diff --git a/parser/lexer/lexer.go b/parser/lexer/lexer.go index a8cf6a5b..fe41e824 100644 --- a/parser/lexer/lexer.go +++ b/parser/lexer/lexer.go @@ -92,13 +92,6 @@ func (l *Lexer) peek() rune { return eof } -func (l *Lexer) peekByte() (byte, bool) { - if l.end.byte >= 0 && l.end.byte < len(l.source.String()) { - return l.source.String()[l.end.byte], true - } - return 0, false -} - func (l *Lexer) backup() { if l.eof { l.eof = false diff --git a/parser/parser.go b/parser/parser.go index e1dd111f..6034403b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -12,7 +12,6 @@ import ( "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" - "github.com/expr-lang/expr/parser/lexer" . "github.com/expr-lang/expr/parser/lexer" "github.com/expr-lang/expr/parser/operator" "github.com/expr-lang/expr/parser/utils" @@ -49,7 +48,7 @@ var predicates = map[string]struct { // Parser is a reusable parser. The zero value is ready for use. type Parser struct { - lexer *lexer.Lexer + lexer *Lexer current, stashed Token hasStash bool err *file.Error @@ -60,7 +59,7 @@ type Parser struct { func (p *Parser) Parse(input string, config *conf.Config) (*Tree, error) { if p.lexer == nil { - p.lexer = lexer.New() + p.lexer = New() } p.config = config source := file.NewSource(input) diff --git a/patcher/operator_override.go b/patcher/operator_override.go index 308cbdba..e0a00593 100644 --- a/patcher/operator_override.go +++ b/patcher/operator_override.go @@ -16,6 +16,7 @@ type OperatorOverloading struct { Env *nature.Nature // Env type. Functions conf.FunctionsTable // Env functions. applied bool // Flag to indicate if any changes were made to the tree. + NtCache *nature.Cache } func (p *OperatorOverloading) Visit(node *ast.Node) { diff --git a/types/types.go b/types/types.go index bb1cbe5f..232f9b9b 100644 --- a/types/types.go +++ b/types/types.go @@ -44,7 +44,7 @@ func TypeOf(v any) Type { type anyType struct{} func (anyType) Nature() Nature { - return Nature{Type: nil} + return FromType(nil) } func (anyType) Equal(t Type) bool { @@ -58,7 +58,7 @@ func (anyType) String() string { type nilType struct{} func (nilType) Nature() Nature { - return Nature{Nil: true} + return NatureOf(nil) } func (nilType) Equal(t Type) bool { @@ -77,7 +77,7 @@ type rtype struct { } func (r rtype) Nature() Nature { - return Nature{Type: r.t} + return FromType(r.t) } func (r rtype) Equal(t Type) bool { @@ -100,11 +100,12 @@ type Map map[string]Type const Extra = "[[__extra_keys__]]" func (m Map) Nature() Nature { - nt := Nature{ - Type: reflect.TypeOf(map[string]any{}), - Fields: make(map[string]Nature, len(m)), - Strict: true, + nt := NatureOf(map[string]any{}) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.Fields = make(map[string]Nature, len(m)) + nt.Strict = true for k, v := range m { if k == Extra { nt.Strict = false @@ -155,11 +156,13 @@ type array struct { func (a array) Nature() Nature { of := a.of.Nature() - return Nature{ - Type: reflect.TypeOf([]any{}), - Fields: make(map[string]Nature, 1), - ArrayOf: &of, + nt := NatureOf([]any{}) + if nt.Optional == nil { + nt.Optional = new(Optional) } + nt.Fields = make(map[string]Nature, 1) + nt.Ref = &of + return nt } func (a array) Equal(t Type) bool {