diff --git a/compiler/semantic/analyzer.go b/compiler/semantic/analyzer.go index f779d127e..a3a62ff78 100644 --- a/compiler/semantic/analyzer.go +++ b/compiler/semantic/analyzer.go @@ -3,7 +3,6 @@ package semantic import ( "context" "errors" - "strconv" "strings" "github.com/brimdata/super" @@ -40,16 +39,11 @@ func Analyze(ctx context.Context, p *parser.AST, env *exec.Environment, extInput seq.Prepend(&sem.NullScan{}) } } - resolver := newResolver(t) - semSeq, dagFuncs := resolver.resolve(seq) + newChecker(t).check(t.reporter, seq) if err := t.Error(); err != nil { return nil, err } - newChecker(t, dagFuncs).check(t.reporter, semSeq) - if err := t.Error(); err != nil { - return nil, err - } - main := newDagen(t.reporter).assemble(semSeq, dagFuncs) + main := newDagen(t.reporter).assemble(seq, t.resolver.funcs) return main, t.Error() } @@ -59,26 +53,25 @@ func Analyze(ctx context.Context, p *parser.AST, env *exec.Environment, extInput // to dataflow. type translator struct { reporter - ctx context.Context - opStack []*ast.OpDecl - cteStack []*ast.SQLCTE - env *exec.Environment - scope *Scope - sctx *super.Context - funcs map[string]*sem.FuncDef - funcDecls map[string]*funcDecl + ctx context.Context + resolver *resolver + opStack []*ast.OpDecl + cteStack []*ast.SQLCTE + env *exec.Environment + scope *Scope + sctx *super.Context } func newTranslator(ctx context.Context, r reporter, env *exec.Environment) *translator { - return &translator{ - reporter: r, - ctx: ctx, - env: env, - scope: NewScope(nil), - sctx: super.NewContext(), - funcs: make(map[string]*sem.FuncDef), - funcDecls: make(map[string]*funcDecl), + t := &translator{ + reporter: r, + ctx: ctx, + env: env, + scope: NewScope(nil), + sctx: super.NewContext(), } + t.resolver = newResolver(t) + return t } func HasSource(seq sem.Seq) bool { @@ -111,18 +104,6 @@ func (t *translator) exitScope() { t.scope = t.scope.parent } -func (t *translator) newFunc(body ast.Expr, name string, params []string, e sem.Expr) string 
{ - tag := strconv.Itoa(len(t.funcs)) - t.funcs[tag] = &sem.FuncDef{ - Node: body, - Tag: tag, - Name: name, - Params: params, - Body: e, - } - return tag -} - type opDecl struct { ast *ast.OpDecl scope *Scope // parent scope of op declaration. diff --git a/compiler/semantic/checker.go b/compiler/semantic/checker.go index 1d4a481ed..d2d7dc6c4 100644 --- a/compiler/semantic/checker.go +++ b/compiler/semantic/checker.go @@ -14,20 +14,14 @@ import ( type checker struct { t *translator - funcs map[string]*sem.FuncDef checked map[super.Type]super.Type unknown *super.TypeError estack []errlist } -func newChecker(t *translator, funcs []*sem.FuncDef) *checker { - funcMap := make(map[string]*sem.FuncDef) - for _, f := range funcs { - funcMap[f.Tag] = f - } +func newChecker(t *translator) *checker { return &checker{ t: t, - funcs: funcMap, unknown: t.sctx.LookupTypeError(t.sctx.MustLookupTypeRecord(nil)), checked: make(map[super.Type]super.Type), } @@ -466,15 +460,15 @@ func (c *checker) callBuiltin(call *sem.CallExpr, args []super.Type) super.Type } func (c *checker) callFunc(call *sem.CallExpr, args []super.Type) super.Type { - f := c.funcs[call.Tag] - if len(args) != len(f.Params) { + f := c.t.resolver.funcs[call.Tag] + if len(args) != len(f.params) { // The translator has already checked that len(args) is len(params) // but when there's an error, mismatches can still show up here so // we ignore these here. return c.unknown } fields := make([]super.Field, 0, len(args)) - for k, param := range f.Params { + for k, param := range f.params { fields = append(fields, super.Field{Name: param, Type: args[k]}) } argsType := c.t.sctx.MustLookupTypeRecord(fields) @@ -487,7 +481,7 @@ func (c *checker) callFunc(call *sem.CallExpr, args []super.Type) super.Type { // of all recursive functions. When we add (optional) type signatures to functions, // this problem will (partially) go away. 
c.checked[argsType] = c.unknown - typ := c.expr(argsType, f.Body) + typ := c.expr(argsType, f.body) c.checked[argsType] = typ return typ } diff --git a/compiler/semantic/dagen.go b/compiler/semantic/dagen.go index 3cbca5034..739e75cf1 100644 --- a/compiler/semantic/dagen.go +++ b/compiler/semantic/dagen.go @@ -23,7 +23,7 @@ func newDagen(r reporter) *dagen { } } -func (d *dagen) assemble(seq sem.Seq, funcs []*sem.FuncDef) *dag.Main { +func (d *dagen) assemble(seq sem.Seq, funcs map[string]*funcDef) *dag.Main { dagSeq := d.seq(seq) dagSeq = d.checkOutputs(true, dagSeq) dagFuncs := make([]*dag.FuncDef, 0, len(d.funcs)) @@ -37,7 +37,7 @@ func (d *dagen) assemble(seq sem.Seq, funcs []*sem.FuncDef) *dag.Main { return &dag.Main{Funcs: dagFuncs, Body: dagSeq} } -func (d *dagen) assembleExpr(e sem.Expr, funcs []*sem.FuncDef) *dag.MainExpr { +func (d *dagen) assembleExpr(e sem.Expr, funcs map[string]*funcDef) *dag.MainExpr { dagExpr := d.expr(e) dagFuncs := make([]*dag.FuncDef, 0, len(d.funcs)) for _, f := range funcs { @@ -523,13 +523,13 @@ func (d *dagen) call(c *sem.CallExpr) *dag.CallExpr { } } -func (d *dagen) fn(f *sem.FuncDef) *dag.FuncDef { +func (d *dagen) fn(f *funcDef) *dag.FuncDef { return &dag.FuncDef{ Kind: "FuncDef", - Tag: f.Tag, - Name: f.Name, - Params: f.Params, - Expr: d.expr(f.Body), + Tag: f.tag, + Name: f.name, + Params: f.params, + Expr: d.expr(f.body), } } diff --git a/compiler/semantic/evaluator.go b/compiler/semantic/evaluator.go index 03b233a67..273dbaa93 100644 --- a/compiler/semantic/evaluator.go +++ b/compiler/semantic/evaluator.go @@ -14,7 +14,7 @@ import ( type evaluator struct { translator *translator - in map[string]*sem.FuncDef + in map[string]*funcDef errs errlist constThis bool bad bool @@ -25,7 +25,7 @@ type errloc struct { err error } -func newEvaluator(t *translator, funcs map[string]*sem.FuncDef) *evaluator { +func newEvaluator(t *translator, funcs map[string]*funcDef) *evaluator { return &evaluator{ translator: t, in: funcs, @@ 
-47,29 +47,18 @@ func (e *evaluator) maybeEval(sctx *super.Context, expr sem.Expr) (super.Value, } return val, true } - // re-enter the semantic analyzer with just this expr by resolving - // all needed funcs then traversing the resulting sem tree and seeing - // if it will eval as a compile-time constant. If so, compile it the rest - // of the way and invoke rungen to get the result and return it. - // If an error is encountered, returns the error. If the expression - // isn't a compile-time const, then errors will accumulate. Note that - // no existing state in the translator is touched nor is the passed-in - // sem tree modified at all; instead, the process here creates copies - // of any needed sem tree and funcs. - r := newResolver(e.translator) - resolvedExpr, funcs := r.resolveExpr(expr) - e.expr(resolvedExpr) + e.expr(expr) if len(e.errs) > 0 || e.bad { return super.Value{}, false } - for _, f := range funcs { + for _, f := range e.translator.resolver.funcs { e.constThis = true - e.expr(f.Body) + e.expr(f.body) if len(e.errs) > 0 || e.bad { return super.Value{}, false } } - main := newDagen(e.translator.reporter).assembleExpr(resolvedExpr, funcs) + main := newDagen(e.translator.reporter).assembleExpr(expr, e.translator.resolver.funcs) val, err := rungen.EvalAtCompileTime(sctx, main) if err != nil { e.errs.error(expr, err) diff --git a/compiler/semantic/expr.go b/compiler/semantic/expr.go index ced73f206..2b73b33e9 100644 --- a/compiler/semantic/expr.go +++ b/compiler/semantic/expr.go @@ -95,15 +95,15 @@ func (t *translator) semExpr(e ast.Expr) sem.Expr { return t.semFString(e) case *ast.FuncNameExpr: // We get here for &refs that are in a call expression. e.g., - // an arg to another function. These are only built-ins as - // user functions should be referenced directly as an ID. - tag := e.Name - if boundTag, _ := t.scope.lookupFunc(e.Name); boundTag != "" { - tag = boundTag + // an arg to another function. 
It could be a built-in (as in &upper), + // or a user function (as in fn foo():... &foo)... + id := e.Name + if boundID, _ := t.scope.lookupFuncDeclOrParam(id); boundID != "" { + id = boundID } return &sem.FuncRef{ Node: e, - Tag: tag, + ID: id, } case *ast.GlobExpr: return &sem.RegexpSearchExpr{ @@ -144,10 +144,10 @@ func (t *translator) semExpr(e ast.Expr) sem.Expr { } return out case *ast.LambdaExpr: - tag := t.newFunc(e, "lambda", idsAsStrings(e.Params), t.semExpr(e.Expr)) + funcDecl := t.resolver.newFuncDecl("lambda", e, t.scope) return &sem.FuncRef{ Node: e, - Tag: tag, + ID: funcDecl.id, } case *ast.MapExpr: var entries []sem.Entry @@ -388,7 +388,7 @@ func (t *translator) semID(id *ast.IDExpr, lval bool) sem.Expr { // an error to avoid a rake when such a function is mistakenly passed // without "&" and otherwise turns into a field reference. if entry := t.scope.lookupEntry(id.Name); entry != nil { - if _, ok := entry.ref.(*sem.FuncDef); ok && !lval { + if _, ok := entry.ref.(*funcDef); ok && !lval { t.error(id, fmt.Errorf("function %q referenced but not called (consider &%s to create a function value)", id.Name, id.Name)) return badExpr() } @@ -653,8 +653,8 @@ func (t *translator) maybeSubquery(n ast.Node, name string) *sem.SubqueryExpr { } func (t *translator) semCallLambda(lambda *ast.LambdaExpr, args []sem.Expr) sem.Expr { - tag := t.newFunc(lambda, "lambda", idsAsStrings(lambda.Params), t.semExpr(lambda.Expr)) - return sem.NewCall(lambda, tag, args) + funcDecl := t.resolver.newFuncDecl("lambda", lambda, t.scope) + return t.resolver.mustResolveCall(lambda, funcDecl.id, args) } func (t *translator) semCallByName(call *ast.CallExpr, name string, args []sem.Expr) sem.Expr { @@ -664,23 +664,38 @@ func (t *translator) semCallByName(call *ast.CallExpr, name string, args []sem.E // Check if the name resolves to a symbol in scope. 
if entry := t.scope.lookupEntry(name); entry != nil { switch ref := entry.ref.(type) { - case param: - // Called name is a parameter inside of a function. We create a dummy - // CallParam that will be converted to a direct call to the passed-in - // function (we don't know it yet and there may be multiple variations - // that all land at this call site) in the next pass of semantic analysis. - return &sem.CallParam{ - Node: call, - Param: name, - Args: args, + case funcParamValue: + t.error(call, fmt.Errorf("function called via parameter %q is bound to a non-function", name)) + return badExpr() + case *funcParamLambda: + // Called name is a parameter inside of a function. We only end up here + // when actual values have been bound to the parameter (i.e., we're compiling + // a lambda-variant function each time it is called to create each variant), + // so we call the resolver here to create a new instance of the function being + // called. In the case of recursion, all the lambdas that are further passed + // as args are known (in terms of their decl IDs), so the resolver can + // look this up in the variants of the decl and stop the recursion even if the body + // of the called entity is not completed yet. We won't know the type but we + // can't know the type without function type signatures so when we integrate + // type checking here, we will use unknown for this corner case. + if isBuiltin(ref.id) { + // Check argument count here for builtin functions. 
+ if _, err := function.New(super.NewContext(), ref.id, len(args)); err != nil { + t.error(call, fmt.Errorf("function %q called via parameter %q: %w", ref.id, ref.param, err)) + return badExpr() + } + return sem.NewCall(call, ref.id, args) } + return t.resolver.mustResolveCall(call, ref.id, args) case *opDecl: t.error(call, fmt.Errorf("cannot call user operator %q in an expression (consider subquery syntax)", name)) return badExpr() case *sem.FuncRef: - return sem.NewCall(call, ref.Tag, args) - case *sem.FuncDef: - return sem.NewCall(call, ref.Tag, args) + // FuncRefs are put in the symbol table when passing stuff to user ops, e.g., + // a lambda as a parameter, a &func, or a builtin like &upper. + return t.resolver.mustResolveCall(ref, ref.ID, args) + case *funcDecl: + return t.resolver.mustResolveCall(call, ref.id, args) case *constDecl, *queryDecl: t.error(call, fmt.Errorf("%q is not a function", name)) return badExpr() @@ -691,22 +706,7 @@ func (t *translator) semCallByName(call *ast.CallExpr, name string, args []sem.E } panic(entry.ref) } - // Call could be to a user func. Check if we have a matching func in scope. - // When the name is a formal argument, the bindings will have been put - // in scope and will point to the right entity (a builtin function name or a FuncDef). - tag, _ := t.scope.lookupFunc(name) nargs := len(args) - // udf should be checked first since a udf can override builtin functions. 
- if f := t.funcs[tag]; f != nil { - if len(f.Params) != nargs { - t.error(call, fmt.Errorf("call expects %d argument(s)", len(f.Params))) - return badExpr() - } - return sem.NewCall(call, f.Tag, args) - } - if tag != "" { - name = tag - } nameLower := strings.ToLower(name) switch { case nameLower == "map": @@ -779,15 +779,19 @@ func (t *translator) semMapCall(call *ast.CallExpr, args []sem.Expr) sem.Expr { t.error(call, errors.New("map requires two arguments")) return badExpr() } - f, ok := args[1].(*sem.FuncRef) + ref, ok := args[1].(*sem.FuncRef) if !ok { t.error(call, errors.New("second argument to map must be a function")) return badExpr() } - e := &sem.MapCallExpr{ - Node: call, - Expr: args[0], - Lambda: sem.NewCall(call.Args[1], f.Tag, []sem.Expr{sem.NewThis(call.Args[1], nil)}), + mapArgs := []sem.Expr{sem.NewThis(call.Args[1], nil)} + e := t.resolver.resolveCall(call.Args[1], ref.ID, mapArgs) + if callExpr, ok := e.(*sem.CallExpr); ok { + return &sem.MapCallExpr{ + Node: call, + Expr: args[0], + Lambda: callExpr, + } } return e } diff --git a/compiler/semantic/fmt.go b/compiler/semantic/fmt.go index 787bb3179..a760d7ae6 100644 --- a/compiler/semantic/fmt.go +++ b/compiler/semantic/fmt.go @@ -2,22 +2,13 @@ package semantic import ( "github.com/brimdata/super/compiler/semantic/sem" - "github.com/brimdata/super/sup" ) -func Format(main *sem.Main) string { - clrSeq(main.Body) - for _, f := range main.Funcs { - f.Node = nil - clrExpr(f.Body) +func Clear(seq sem.Seq, funcs map[string]*funcDef) { + clrSeq(seq) + for _, f := range funcs { + clrExpr(f.body) } - m := sup.NewMarshaler() - m.Decorate(sup.StyleSimple) - s, err := m.Marshal(main) - if err != nil { - return err.Error() - } - return s } func clrSeq(seq sem.Seq) { diff --git a/compiler/semantic/op.go b/compiler/semantic/op.go index e7e71169f..f3b3d7143 100644 --- a/compiler/semantic/op.go +++ b/compiler/semantic/op.go @@ -1116,68 +1116,9 @@ func (t *translator) semTypeDecl(d *ast.TypeDecl) { } } -func 
idsAsStrings(ids []*ast.ID) []string { - out := make([]string, 0, len(ids)) - for _, p := range ids { - out = append(out, p.Name) - } - return out -} - -type funcDecl struct { - translator *translator - decl *ast.FuncDecl - funcDef *sem.FuncDef - scope *Scope - pending bool -} - -func newFuncDecl(t *translator, d *ast.FuncDecl, funcDef *sem.FuncDef, scope *Scope) *funcDecl { - return &funcDecl{ - translator: t, - decl: d, - funcDef: funcDef, - scope: scope, - } -} - -func (f *funcDecl) resolve() *sem.FuncDef { - t := f.translator - if f.funcDef.Body == nil { - if !f.pending { - f.pending = true - save := t.scope - t.scope = NewScope(f.scope) - defer func() { - f.pending = false - t.scope = save - }() - t.enterScope() - for _, p := range f.decl.Lambda.Params { - t.scope.BindSymbol(p.Name, param{}) - } - f.funcDef.Body = t.semExpr(f.decl.Lambda.Expr) - t.exitScope() - } else { - t.error(f.decl.Name, fmt.Errorf("function %q involved in cyclic dependency", f.funcDef.Name)) - f.funcDef.Body = badExpr() - } - } - return f.funcDef -} - -func (t *translator) resolveFunc(tag string) *sem.FuncDef { - if decl, ok := t.funcDecls[tag]; ok { - return decl.resolve() - } - return t.funcs[tag] -} - func (t *translator) semFuncDecl(d *ast.FuncDecl) { - tag := t.newFunc(d.Lambda, d.Name.Name, idsAsStrings(d.Lambda.Params), nil) - funcDef := t.funcs[tag] - t.funcDecls[tag] = newFuncDecl(t, d, funcDef, t.scope) - if err := t.scope.BindSymbol(d.Name.Name, funcDef); err != nil { + funcDecl := t.resolver.newFuncDecl(d.Name.Name, d.Lambda, t.scope) + if err := t.scope.BindSymbol(d.Name.Name, funcDecl); err != nil { t.error(d.Name, err) } } @@ -1279,8 +1220,8 @@ func (t *translator) isBool(e sem.Expr) bool { case *sem.CondExpr: return t.isBool(e.Then) && t.isBool(e.Else) case *sem.CallExpr: - if f := t.resolveFunc(e.Tag); f != nil { - return t.isBool(f.Body) + if funcDef, ok := t.resolver.funcs[e.Tag]; ok { + return t.isBool(funcDef.body) } if e.Tag == "cast" { if len(e.Args) != 2 { @@ 
-1502,11 +1443,11 @@ func (t *translator) mustEval(e sem.Expr) (super.Value, bool) { // and we'll compile this all the way to a DAG and rungen it. This is pretty // general because we need to handle things like subqueries that call // operator sequences that result in a constant value. - return newEvaluator(t, t.funcs).mustEval(t.sctx, e) + return newEvaluator(t, t.resolver.funcs).mustEval(t.sctx, e) } // maybeEVal leaves no errors behind and simply returns a value and bool // indicating if the eval was successful func (t *translator) maybeEval(e sem.Expr) (super.Value, bool) { - return newEvaluator(t, t.funcs).maybeEval(t.sctx, e) + return newEvaluator(t, t.resolver.funcs).maybeEval(t.sctx, e) } diff --git a/compiler/semantic/resolver.go b/compiler/semantic/resolver.go index c4790abf9..12c6a50ea 100644 --- a/compiler/semantic/resolver.go +++ b/compiler/semantic/resolver.go @@ -2,7 +2,6 @@ package semantic import ( "fmt" - "slices" "strconv" "github.com/brimdata/super" @@ -11,540 +10,236 @@ import ( "github.com/brimdata/super/runtime/sam/expr/function" ) +type funcParamLambda struct { + param string // name of parameter that this instance of an argument is bound to + id string // decl ID (either built-in name or ID of decl slot) +} +type funcParamValue struct{} + type resolver struct { - t *translator - fixed map[string]*sem.FuncDef - variants []*sem.FuncDef - params []map[string]string - ntag int + t *translator + decls map[string]*funcDecl // decl id to funcDecl + funcs map[string]*funcDef // call tag to funcDef + fixed map[string]string // decl id of fixed func to call tag + ntag int } func newResolver(t *translator) *resolver { return &resolver{ t: t, - fixed: make(map[string]*sem.FuncDef), - } -} - -func (r *resolver) resolve(seq sem.Seq) (sem.Seq, []*sem.FuncDef) { - out := r.seq(seq) - funcs := slices.Clone(r.variants) - for _, f := range r.fixed { - funcs = append(funcs, f) - } - return out, funcs -} - -func (r *resolver) resolveExpr(e sem.Expr) (sem.Expr, 
[]*sem.FuncDef) { - out := r.expr(e) - funcs := slices.Clone(r.variants) - for _, f := range r.fixed { - funcs = append(funcs, f) - } - return out, funcs -} - -// resolveSeq traverses a DAG and substitues all instances of CallParam by rewriting -// each called function that incudes one or more CallParams substituting the actual -// function called (which is also possibly resolved) as indicated by a FuncRef. -// Each function is also rewritten to remove the function parameters from its params. -// After resolve is done, all CallParam or FuncRef pseudo-expression nodes are -// removed from the sem. All functions are converted from the input function table -// to the output function table whether or not they needed to be resolved. -// Any FuncRefs that do not get bound to a CallParam (i.e., appear in random expressions) -// are found and reported as error as are any CallParam that are called with non-FuncRef -// arguments. -func (r *resolver) seq(seq sem.Seq) sem.Seq { - var out sem.Seq - for _, op := range seq { - out = append(out, r.op(op)) - } - return out -} - -func (r *resolver) op(op sem.Op) sem.Op { - switch op := op.(type) { - // - // Scanners first - // - case *sem.CommitMetaScan: - case *sem.DBMetaScan: - case *sem.DefaultScan: - case *sem.DeleteScan: - case *sem.FileScan: - case *sem.HTTPScan: - case *sem.NullScan: - case *sem.PoolMetaScan: - case *sem.PoolScan: - case *sem.RobotScan: - return &sem.RobotScan{ - Node: op.Node, - Expr: r.expr(op.Expr), - Format: op.Format, - } - // - // Ops in alphabetic order - // - case *sem.AggregateOp: - return &sem.AggregateOp{ - Node: op.Node, - Limit: op.Limit, - Keys: r.assignments(op.Keys), - Aggs: r.assignments(op.Aggs), - } - case *sem.BadOp: - case *sem.CutOp: - return &sem.CutOp{ - Node: op.Node, - Args: r.assignments(op.Args), - } - case *sem.DebugOp: - return &sem.DebugOp{ - Node: op.Node, - Expr: r.expr(op.Expr), - } - case *sem.DistinctOp: - return &sem.DistinctOp{ - Node: op.Node, - Expr: r.expr(op.Expr), - 
} - case *sem.DropOp: - return &sem.DropOp{ - Node: op.Node, - Args: r.exprs(op.Args), - } - case *sem.ExplodeOp: - return &sem.ExplodeOp{ - Node: op.Node, - Args: r.exprs(op.Args), - Type: op.Type, - As: op.As, - } - case *sem.FilterOp: - return &sem.FilterOp{ - Node: op.Node, - Expr: r.expr(op.Expr), - } - case *sem.ForkOp: - var paths []sem.Seq - for _, seq := range op.Paths { - paths = append(paths, r.seq(seq)) - } - return &sem.ForkOp{ - Node: op.Node, - Paths: paths, - } - case *sem.FuseOp: - case *sem.HeadOp: - case *sem.JoinOp: - return &sem.JoinOp{ - Node: op.Node, - Style: op.Style, - LeftAlias: op.LeftAlias, - RightAlias: op.RightAlias, - Cond: r.expr(op.Cond), - } - case *sem.LoadOp: - case *sem.MergeOp: - return &sem.MergeOp{ - Node: op.Node, - Exprs: r.sortExprs(op.Exprs), - } - case *sem.OutputOp: - case *sem.PassOp: - case *sem.PutOp: - return &sem.PutOp{ - Node: op.Node, - Args: r.assignments(op.Args), - } - case *sem.RenameOp: - return &sem.RenameOp{ - Node: op.Node, - Args: r.assignments(op.Args), - } - case *sem.SkipOp: - case *sem.SortOp: - return &sem.SortOp{ - Node: op.Node, - Exprs: r.sortExprs(op.Exprs), - Reverse: op.Reverse, - } - case *sem.SwitchOp: - var cases []sem.Case - for _, c := range op.Cases { - cases = append(cases, sem.Case{ - Expr: r.expr(c.Expr), - Path: r.seq(c.Path), - }) - } - return &sem.SwitchOp{ - Node: op.Node, - Expr: r.expr(op.Expr), - Cases: cases, - } - case *sem.TailOp: - case *sem.TopOp: - return &sem.TopOp{ - Node: op.Node, - Limit: op.Limit, - Exprs: r.sortExprs(op.Exprs), - } - case *sem.UniqOp: - case *sem.UnnestOp: - return &sem.UnnestOp{ - Node: op.Node, - Expr: r.expr(op.Expr), - Body: r.seq(op.Body), - } - case *sem.ValuesOp: - return &sem.ValuesOp{ - Node: op.Node, - Exprs: r.exprs(op.Exprs), - } - - default: - panic(op) - } - return op -} - -func (r *resolver) assignments(assignments []sem.Assignment) []sem.Assignment { - var out []sem.Assignment - for _, a := range assignments { - out = append(out, 
sem.Assignment{ - Node: a.Node, - LHS: r.expr(a.LHS), - RHS: r.expr(a.RHS), - }) - } - return out -} - -func (r *resolver) sortExprs(exprs []sem.SortExpr) []sem.SortExpr { - var out []sem.SortExpr - for _, e := range exprs { - out = append(out, sem.SortExpr{ - Node: e.Node, - Expr: r.expr(e.Expr), - Order: e.Order, - Nulls: e.Nulls}) - } - return out -} - -func (r *resolver) exprs(exprs []sem.Expr) []sem.Expr { - var out []sem.Expr - for _, e := range exprs { - out = append(out, r.expr(e)) - } - return out -} - -func (r *resolver) expr(e sem.Expr) sem.Expr { - switch e := e.(type) { - case nil: - return nil - case *sem.AggFunc: - return &sem.AggFunc{ - Node: e.Node, - Name: e.Name, - Distinct: e.Distinct, - Expr: r.expr(e.Expr), - Where: r.expr(e.Where), - } - case *sem.ArrayExpr: - return &sem.ArrayExpr{ - Node: e.Node, - Elems: r.arrayElems(e.Elems), - } - case *sem.BadExpr: - case *sem.BinaryExpr: - return sem.NewBinaryExpr(e.Node, e.Op, r.expr(e.LHS), r.expr(e.RHS)) - case *sem.CallExpr: - return r.resolveCall(e.Node, e.Tag, e.Args) - case *sem.CallParam: - // This is a call to a parameter. It must only appear enclosed in a FuncDef - // with params containing the name in the e.Param. The function being referenced - // or passed in can be lazily created. - return r.resolveCallParam(e) - case *sem.CondExpr: - return &sem.CondExpr{ - Node: e.Node, - Cond: r.expr(e.Cond), - Then: r.expr(e.Then), - Else: r.expr(e.Else), - } - case *sem.DotExpr: - return &sem.DotExpr{ - Node: e.Node, - LHS: r.expr(e.LHS), - RHS: e.RHS, - } - case *sem.FuncRef: - // This needs to be in an argument list and can't be anywhere else... 
bad DAG - panic(e) - case *sem.IndexExpr: - return &sem.IndexExpr{ - Node: e.Node, - Expr: r.expr(e.Expr), - Index: r.expr(e.Index), - } - case *sem.IsNullExpr: - return &sem.IsNullExpr{ - Node: e.Node, - Expr: r.expr(e.Expr), - } - case *sem.LiteralExpr: - case *sem.MapCallExpr: - call, ok := r.resolveCall(e.Node, e.Lambda.Tag, e.Lambda.Args).(*sem.CallExpr) - if !ok { - return badExpr() - } - return &sem.MapCallExpr{ - Node: e.Node, - Expr: r.expr(e.Expr), - Lambda: call, - } - case *sem.MapExpr: - var entries []sem.Entry - for _, entry := range e.Entries { - entries = append(entries, sem.Entry{ - Key: r.expr(entry.Key), - Value: r.expr(entry.Value), - }) - } - return &sem.MapExpr{ - Node: e.Node, - Entries: entries, - } - case *sem.RecordExpr: - return &sem.RecordExpr{ - Node: e.Node, - Elems: r.recordElems(e.Elems), - } - case *sem.RegexpMatchExpr: - return &sem.RegexpMatchExpr{ - Node: e.Node, - Pattern: e.Pattern, - Expr: r.expr(e.Expr), - } - case *sem.RegexpSearchExpr: - return &sem.RegexpSearchExpr{ - Node: e.Node, - Pattern: e.Pattern, - Expr: r.expr(e.Expr), - } - case *sem.SearchTermExpr: - return &sem.SearchTermExpr{ - Node: e.Node, - Text: e.Text, - Value: e.Value, - Expr: r.expr(e.Expr), - } - case *sem.SetExpr: - return &sem.SetExpr{ - Node: e.Node, - Elems: r.arrayElems(e.Elems), - } - case *sem.SliceExpr: - return &sem.SliceExpr{ - Node: e.Node, - Expr: r.expr(e.Expr), - From: r.expr(e.From), - To: r.expr(e.To), - } - case *sem.SubqueryExpr: - // We clear params before processing a subquery so you can't - // touch passed-in functions inside the first "this" of a correlated - // subquery. We can support this later if people are interested. - // It requires a bit of surgery. 
- r.pushParams(make(map[string]string)) - defer r.popParams() - return &sem.SubqueryExpr{ - Node: e.Node, - Correlated: e.Correlated, - Array: e.Array, - Body: r.seq(e.Body), - } - case *sem.ThisExpr: - case *sem.UnaryExpr: - return sem.NewUnaryExpr(e.Node, e.Op, r.expr(e.Operand)) - default: - panic(e) - } - return e -} - -func (r *resolver) arrayElems(elems []sem.ArrayElem) []sem.ArrayElem { - var out []sem.ArrayElem - for _, elem := range elems { - switch elem := elem.(type) { - case *sem.SpreadElem: - out = append(out, &sem.SpreadElem{ - Node: elem.Node, - Expr: r.expr(elem.Expr), - }) - case *sem.ExprElem: - out = append(out, &sem.ExprElem{ - Node: elem.Node, - Expr: r.expr(elem.Expr), - }) - default: - panic(elem) - } - } - return out -} - -func (r *resolver) recordElems(elems []sem.RecordElem) []sem.RecordElem { - var out []sem.RecordElem - for _, elem := range elems { - switch elem := elem.(type) { - case *sem.SpreadElem: - out = append(out, &sem.SpreadElem{ - Node: elem.Node, - Expr: r.expr(elem.Expr), - }) - case *sem.FieldElem: - out = append(out, &sem.FieldElem{ - Node: elem.Node, - Name: elem.Name, - Value: r.expr(elem.Value), - }) - default: - panic(elem) - } - } - return out -} - -func (r *resolver) resolveCallParam(call *sem.CallParam) sem.Expr { - oldTag := r.lookupParam(call.Param) - if oldTag == "" { - // This can happen when we go to resolve a parameter that wasn't bound to - // an actual function because some other value was bound to it so it didn't - // get put in the parameter table. - r.t.error(call.Node, fmt.Errorf("function called via parameter %q is bound to a non-function", call.Param)) - return badExpr() - } - if isBuiltin(oldTag) { + decls: make(map[string]*funcDecl), + funcs: make(map[string]*funcDef), + fixed: make(map[string]string), + } +} + +// There is a funcDef for every unique lambda-unraveled instance and +// exactly one for each such combination of decl IDs. 
The instances +// are unique up to the lambda params, which are in turn identified by +// their decl ID (or built-in name). Since lambda params originate only +// at a built-in or function declaration and they can't be modified +// (only passed by reference), the variants are determined by decl ID +// rather than call tag. This allows us to translate and unravel +// all the functions and lambda arguments in a single pass integrated +// into the translator logic, which in turn, allows us to carry out +// type checking since functions are resolved to actual sem.Expr instances +// from the beginning. +type funcDef struct { + tag string + name string // original name in decl or "lambda" (for errors) + params []string // params of args without any lambda args + lambdas []lambda + body sem.Expr +} + +type lambda struct { + param string // parameter name this lambda arg appeared as + pos int // position in the formal parameters list of the function + id string // declaration ID (or built-in) of the function value +} + +func (r *resolver) resolveCall(n ast.Node, id string, args []sem.Expr) sem.Expr { + if isBuiltin(id) { // Check argument count here for builtin functions.
- if _, err := function.New(super.NewContext(), oldTag, len(call.Args)); err != nil { - r.t.error(call.Node, fmt.Errorf("function %q called via parameter %q: %w", oldTag, call.Param, err)) + if _, err := function.New(super.NewContext(), id, len(args)); err != nil { + r.t.error(n, err) return badExpr() } + return sem.NewCall(n, id, args) } - return r.resolveCall(call.Node, oldTag, call.Args) + return r.mustResolveCall(n, id, args) } -func (r *resolver) resolveCall(n ast.Node, oldTag string, args []sem.Expr) sem.Expr { - if isBuiltin(oldTag) { - return &sem.CallExpr{ - Node: n, - Tag: oldTag, - Args: r.exprs(args), - } +func (r *resolver) mustResolveCall(n ast.Node, id string, args []sem.Expr) sem.Expr { + d, ok := r.decls[id] + if !ok { + panic(id) } - // Translate the tag to the new func table and convert any - // function refs passed as args to lookup-table removing + // Translate the decl ID to a func by converting any + // function refs passed as args to a key into the variants table removing // correponding args. var params []string var exprs []sem.Expr - funcDef := r.t.resolveFunc(oldTag) - if len(funcDef.Params) != len(args) { - r.t.error(n, fmt.Errorf("%q: expected %d params but called with %d", funcDef.Name, len(funcDef.Params), len(args))) + declParams := d.lambda.Params + if len(declParams) != len(args) { + r.t.error(n, fmt.Errorf("%q: expected %d params but called with %d", d.name, len(declParams), len(args))) return badExpr() } - bindings := make(map[string]string) + var lambdas []lambda for k, arg := range args { if f, ok := arg.(*sem.FuncRef); ok { - bindings[funcDef.Params[k]] = f.Tag + lambdas = append(lambdas, lambda{param: declParams[k].Name, pos: k, id: f.ID}) continue } - e := r.expr(arg) - if e, ok := e.(*sem.ThisExpr); ok { + if e, ok := arg.(*sem.ThisExpr); ok { if len(e.Path) == 1 { // Propagate a function passed as a function value inside of - // a function to another function. 
- if tag := r.lookupParam(e.Path[0]); tag != "" { - bindings[funcDef.Params[k]] = tag + // a function to another function as the new param name. + if id, ok := r.t.scope.lookupFuncParamLambda(e.Path[0]); ok { + lambdas = append(lambdas, lambda{param: declParams[k].Name, pos: k, id: id}) continue } } } - params = append(params, funcDef.Params[k]) - exprs = append(exprs, r.expr(arg)) + params = append(params, declParams[k].Name) + exprs = append(exprs, arg) } - if len(funcDef.Params) == len(params) { + if len(declParams) == len(params) { // No need to specialize this call since no function args are being passed. - newTag := r.lookupFixed(oldTag) return &sem.CallExpr{ Node: n, - Tag: newTag, + Tag: r.lookupFixed(id), Args: args, } } // Enter the new function scope and set up the bindings for the // values we retrieved above while evaluating args in the outer scope. - r.pushParams(bindings) - defer r.popParams() - newTag := r.lookupVariant(oldTag, params) return &sem.CallExpr{ Node: n, - Tag: newTag, + Tag: r.getVariant(id, params, lambdas).tag, Args: exprs, } } -func (r *resolver) lookupFixed(oldTag string) string { - if funcDef, ok := r.fixed[oldTag]; ok { - return funcDef.Tag +func (r *resolver) resolveVariant(d *funcDecl, variant *funcDef) { + save := r.t.scope + r.t.scope = NewScope(d.scope) + defer func() { + r.t.scope = save + }() + r.t.enterScope() + for _, lambda := range variant.lambdas { + r.t.scope.BindSymbol(lambda.param, &funcParamLambda{param: lambda.param, id: lambda.id}) } - funcDef := r.t.resolveFunc(oldTag) - newTag := r.nextTag() - newFuncDef := &sem.FuncDef{ - Node: funcDef.Node, - Tag: newTag, - Name: funcDef.Name, - Params: funcDef.Params, + for _, param := range variant.params { + r.t.scope.BindSymbol(param, funcParamValue{}) } - r.fixed[oldTag] = newFuncDef - newFuncDef.Body = r.expr(funcDef.Body) - return newTag + variant.body = r.t.semExpr(d.lambda.Expr) + r.t.exitScope() } -func (r *resolver) lookupVariant(oldTag string, params []string) 
string { - newTag := r.nextTag() - funcDef := r.t.resolveFunc(oldTag) - r.variants = append(r.variants, &sem.FuncDef{ - Node: funcDef.Node, - Tag: newTag, - Name: funcDef.Name, - Params: params, - Body: r.expr(funcDef.Body), // since params have been bound this will convert the CallParams - }) - return newTag +func (r *resolver) resolveFixed(d *funcDecl, tag string) *funcDef { + save := r.t.scope + r.t.scope = NewScope(d.scope) + defer func() { + r.t.scope = save + }() + r.t.enterScope() + params := idsToStrings(d.lambda.Params) + for _, param := range params { + r.t.scope.BindSymbol(param, funcParamValue{}) + } + body := r.t.semExpr(d.lambda.Expr) + r.t.exitScope() + return &funcDef{ + tag: tag, + name: d.name, + params: params, + body: body, + } } -func (r *resolver) nextTag() string { - tag := strconv.Itoa(r.ntag) - r.ntag++ +func (r *resolver) lookupFixed(id string) string { + if tag, ok := r.fixed[id]; ok { + return tag + } + tag := r.nextTag() + // Install this binding up front to block recursion. + r.fixed[id] = tag + decl := r.decls[id] + r.funcs[tag] = r.resolveFixed(decl, tag) return tag } -func (r *resolver) pushParams(scope map[string]string) { - r.params = append(r.params, scope) +func (r *resolver) lookupVariant(id string, lambdas []lambda) *funcDef { + d := r.decls[id] + for _, variant := range d.variants { + if ok := matchVariant(variant, lambdas); ok { + return variant + } + } + return nil +} + +func (r *resolver) getVariant(id string, params []string, lambdas []lambda) *funcDef { + if variant := r.lookupVariant(id, lambdas); variant != nil { + return variant + } + d := r.decls[id] + tag := r.nextTag() + variant := &funcDef{ + tag: tag, + name: d.name, + params: params, + lambdas: lambdas, + // body needs to be resolved + } + r.funcs[tag] = variant + // We put the variant in the decl before resolving the body to stop recursion. + d.variants = append(d.variants, variant) + // Resolve the body here. 
+ r.resolveVariant(d, variant) + return variant +} + +func matchVariant(def *funcDef, lambdas []lambda) bool { + // find match or detect change and error or false/nil + if len(def.lambdas) != len(lambdas) { + return false + } + for k, lambda := range lambdas { + if lambda.pos != def.lambdas[k].pos || lambda.id != def.lambdas[k].id { + return false + } + } + return true } -func (r *resolver) popParams() { - r.params = r.params[0 : len(r.params)-1] +type funcDecl struct { + id string + name string + lambda *ast.LambdaExpr + scope *Scope + variants []*funcDef } -func (r *resolver) lookupParam(param string) string { - if len(r.params) == 0 { - return "" +func (r *resolver) newFuncDecl(name string, lambda *ast.LambdaExpr, s *Scope) *funcDecl { + // decl IDs give us an id for sem.FuncRef, which can then be turned back + // into a *funcDecl with r.decls + id := strconv.Itoa(len(r.decls)) + r.decls[id] = &funcDecl{ + id: id, + name: name, + lambda: lambda, + scope: s, } - return r.params[len(r.params)-1][param] + return r.decls[id] +} + +func (r *resolver) nextTag() string { + tag := strconv.Itoa(r.ntag) + r.ntag++ + return tag } func isBuiltin(tag string) bool { diff --git a/compiler/semantic/scope.go b/compiler/semantic/scope.go index 55418076d..2167877fb 100644 --- a/compiler/semantic/scope.go +++ b/compiler/semantic/scope.go @@ -25,8 +25,6 @@ type entry struct { order int } -type param struct{} - func (s *Scope) BindSymbol(name string, e any) error { if _, ok := s.symbols[name]; ok { return fmt.Errorf("symbol %q redefined", name) @@ -69,7 +67,7 @@ func (s *Scope) lookupExpr(t *translator, name string) sem.Expr { // function parameters hide exteral definitions as you don't // want the this.param ref to be overriden by a const etc. 
switch entry := entry.ref.(type) { - case *sem.FuncDef, *ast.FuncNameExpr, param, *opDecl: + case *funcDecl, *ast.FuncNameExpr, *funcParamLambda, funcParamValue, *opDecl: return nil case *constDecl: return entry.resolve(t) @@ -79,20 +77,36 @@ func (s *Scope) lookupExpr(t *translator, name string) sem.Expr { return nil } -func (s *Scope) lookupFunc(name string) (string, error) { +// Returns the decl ID of a function declared with the given name in +// this scope or a function value passed as a lambda parameter and bound +// to formal parameter with the given name. +func (s *Scope) lookupFuncDeclOrParam(name string) (string, error) { entry := s.lookupEntry(name) if entry == nil { return "", nil } switch ref := entry.ref.(type) { - case *sem.FuncDef: - return ref.Tag, nil - case *sem.FuncRef: - return ref.Tag, nil + case *funcDecl: + return ref.id, nil + case *funcParamLambda: + return ref.id, nil } return "", fmt.Errorf("%q is not bound to a function", name) } +// See if there's a function value passed as a lambda of a formal parameter +// and if so, return the underlying decl ID of that lambda argument. +func (s *Scope) lookupFuncParamLambda(name string) (string, bool) { + entry := s.lookupEntry(name) + if entry == nil { + return "", false + } + if ref, ok := entry.ref.(*funcParamLambda); ok { + return ref.id, true + } + return "", false +} + // resolve paths based on SQL semantics in order of precedence // and replace with dag path with schemafied semantics. // In the case of unqualified col ref, check that it is not ambiguous diff --git a/compiler/semantic/sem/expr.go b/compiler/semantic/sem/expr.go index 95325f60e..f43d58007 100644 --- a/compiler/semantic/sem/expr.go +++ b/compiler/semantic/sem/expr.go @@ -176,28 +176,17 @@ func (*UnaryExpr) exprNode() {} // FuncRef is a pseudo-expression that represents a function reference as a value. // It is not used by the runtime (but could be if we wanted to support this). 
Instead, -// the semantic pass uses this in a first stage to represent lambda-parameterized functions -// then in a second stage it unrolls them all into regular calls by creating a unique -// new function for each combination of passed in lambdas. +// the semantic pass uses this to represent lambda-parameterized functions, e.g., +// functions that are passed to other functions as arguments. Whenever such values +// appear as function arguments, they are installed in the symbol table as bound to +// the function declaration's ID then each variation of lambda-invoked function is +// compiled to a unique function by the resolver. type FuncRef struct { ast.Node - Tag string + ID string } -// CallParam is a pseudo-expression that is like a call but represents the call -// of a FuncRef passed as an argument with the parameter name given by Param. -// It is not used by the runtime (but could be if we wanted to support this). Instead, -// the semantic pass uses this in a first stage to represent abstract calls to functions -// passed as parameters, then in a second stage it flattens them all into regular calls -// by creating a unique new function for each combination of passed-in lambdas. -type CallParam struct { - ast.Node - Param string - Args []Expr -} - -func (*FuncRef) exprNode() {} -func (*CallParam) exprNode() {} +func (*FuncRef) exprNode() {} func NewThis(n ast.Node, path []string) *ThisExpr { return &ThisExpr{Node: n, Path: path} diff --git a/compiler/semantic/sem/op.go b/compiler/semantic/sem/op.go index 6895839d7..94ce00ee3 100644 --- a/compiler/semantic/sem/op.go +++ b/compiler/semantic/sem/op.go @@ -14,11 +14,6 @@ import ( "github.com/segmentio/ksuid" ) -type Main struct { - Funcs []*FuncDef - Body Seq -} - // Op is the interface implemented by all AST operator nodes. 
type Op interface { opNode() @@ -91,14 +86,6 @@ func (*PoolMetaScan) opNode() {} func (*PoolScan) opNode() {} func (*RobotScan) opNode() {} -type FuncDef struct { - ast.Node - Tag string - Name string - Params []string - Body Expr -} - type Seq []Op func (s *Seq) Prepend(front Op) {