diff --git a/go.mod b/go.mod index a21b946bca..c20be68328 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,9 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad - github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677 + github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c + github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index a97a1b5a39..7c402b7624 100644 --- a/go.sum +++ b/go.sum @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszNu37VPQckdhX1BPPVzREsGgNxQeefnlm3g= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677 h1:Wn0v7xBxkdzYqDN/4ksI34jstZOskMccT2SYFrvUw4c= -github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677/go.mod h1:5ZdrW0fHZbz+8CngT9gksqSX4H3y+7v1pns7tJCEpu0= +github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545 h1:O+/sjRQJadYzyVr89Zh9yCnhZJ0NlHwiDYsXHnj3LsU= +github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -276,8 +276,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c h1:imdag6PPCHAO2rZNsFoQoR4I/vIVTmO/czoOl5rUnbk= -github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c h1:23KvsBtJk2GmHpXwQ/RkwIkdNpWL8tWdHRCiidhnaUA= +github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 2babfa310e..20f5fc0fbf 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -17,6 +17,7 @@ package analyzer import ( "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/planbuilder" ) // IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs @@ -101,6 +102,23 @@ func initEngine() { // place to put it. Our foreign key validation logic is different from MySQL's, and since it's not an analyzer rule // we can't swap out a rule like the rest of the logic in this package, we have to do a function swap. plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition + + planbuilder.IsAggregateFunc = IsAggregateFunc +} + +// IsAggregateFunc checks if the given function name is an aggregate function. This is the entire set supported by +// MySQL plus some postgres specific ones. +func IsAggregateFunc(name string) bool { + if planbuilder.IsMySQLAggregateFuncName(name) { + return true + } + + switch name { + case "array_agg": + return true + } + + return false } // insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice. diff --git a/server/functions/aggregate/array_agg.go b/server/functions/aggregate/array_agg.go new file mode 100755 index 0000000000..2a0457cb8b --- /dev/null +++ b/server/functions/aggregate/array_agg.go @@ -0,0 +1,66 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +import ( + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// initArrayAgg registers the functions to the catalog. +func initArrayAgg() { + framework.RegisterAggregateFunction(array_agg) +} + +// array_agg represents the PostgreSQL array_agg function. +var array_agg = framework.Func1Aggregate{ + Function1: framework.Function1{ + Name: "array_agg", + Return: pgtypes.AnyArray, + Parameters: [1]*pgtypes.DoltgresType{ + pgtypes.AnyElement, + }, + Callable: func(ctx *sql.Context, paramsAndReturn [2]*pgtypes.DoltgresType, val1 any) (any, error) { + return nil, nil + }, + }, + NewAggBuffer: newArrayAggBuffer, +} + +type arrayAggBuffer struct { + elements []any +} + +func newArrayAggBuffer() (sql.AggregationBuffer, error) { + return &arrayAggBuffer{ + elements: make([]any, 0), + }, nil +} + +func (a *arrayAggBuffer) Dispose() {} + +func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) { + if len(a.elements) == 0 { + return nil, nil + } + return a.elements, nil +} + +func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { + a.elements = append(a.elements, row[0]) + return nil +} diff --git a/server/functions/aggregate/init.go b/server/functions/aggregate/init.go new file mode 100755 index 0000000000..469a0d52af --- /dev/null +++ b/server/functions/aggregate/init.go @@ -0,0 +1,19 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +func Init() { + initArrayAgg() +} diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index ff92b6645a..32480b52ea 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -15,6 +15,7 @@ package framework import ( + "fmt" "strings" "github.com/cockroachdb/errors" @@ -28,6 +29,9 @@ import ( // Catalog contains all of the PostgreSQL functions. var Catalog = map[string][]FunctionInterface{} +// AggregateCatalog contains all of the PostgreSQL aggregate functions. +var AggregateCatalog = map[string][]AggregateFunctionInterface{} + // initializedFunctions simply states whether Initialize has been called yet. var initializedFunctions = false @@ -61,6 +65,21 @@ func RegisterFunction(f FunctionInterface) { } } +// RegisterAggregateFunction registers the given function, so that it will be usable from a running server. This should be called +// from within an init(). +func RegisterAggregateFunction(f AggregateFunctionInterface) { + if initializedFunctions { + panic("attempted to register a function after the init() phase") + } + switch f := f.(type) { + case Func1Aggregate: + name := strings.ToLower(f.Name) + AggregateCatalog[name] = append(AggregateCatalog[name], f) + default: + panic(fmt.Sprintf("unhandled function type %T", f)) + } +} + // Initialize handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not // apply to PostgreSQL (and functions of the same name often have different behavior). func Initialize() { @@ -76,6 +95,7 @@ func Initialize() { replaceGmsBuiltIns() validateFunctions() compileFunctions() + compileAggs() } // replaceGmsBuiltIns replaces all GMS built-ins that have conflicting names with PostgreSQL functions. @@ -151,7 +171,29 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface) Fn: createFunc, }) compiledCatalog[funcName] = createFunc - namedCatalog[funcName] = overloads +} + +// compileNonOperatorFunction creates a CompiledFunction for each overload of the given function. +func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) { + var newBuffer func() (sql.AggregationBuffer, error) + overloadTree := NewOverloads() + for _, functionOverload := range overloads { + newBuffer = functionOverload.NewBuffer + if err := overloadTree.Add(functionOverload); err != nil { + panic(err) + } + } + + // Store the compiled function into the engine's built-in functions + // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog + createFunc := func(params ...sql.Expression) (sql.Expression, error) { + return NewCompiledAggregateFunction(funcName, params, overloadTree, newBuffer), nil + } + function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ + Name: funcName, + Fn: createFunc, + }) + compiledCatalog[funcName] = createFunc } // compileFunctions creates a CompiledFunction for each overload of each function in the catalog. @@ -165,10 +207,10 @@ func compileFunctions() { // special rules, so it's far more efficient to reuse it for operators. Operators are also a special case since they // all have different names, while standard overload deducers work on a function-name basis. for signature, functionOverload := range unaryFunctions { - overloads, ok := unaryAggregateOverloads[signature.Operator] + overloads, ok := unaryOperatorOverloads[signature.Operator] if !ok { overloads = NewOverloads() - unaryAggregateOverloads[signature.Operator] = overloads + unaryOperatorOverloads[signature.Operator] = overloads } if err := overloads.Add(functionOverload); err != nil { panic(err) @@ -176,10 +218,10 @@ func compileFunctions() { } for signature, functionOverload := range binaryFunctions { - overloads, ok := binaryAggregateOverloads[signature.Operator] + overloads, ok := binaryOperatorOverloads[signature.Operator] if !ok { overloads = NewOverloads() - binaryAggregateOverloads[signature.Operator] = overloads + binaryOperatorOverloads[signature.Operator] = overloads } if err := overloads.Add(functionOverload); err != nil { panic(err) @@ -187,10 +229,16 @@ func compileFunctions() { } // Add all permutations for the unary and binary operators - for operator, overload := range unaryAggregateOverloads { - unaryAggregatePermutations[operator] = overload.overloadsForParams(1) + for operator, overload := range unaryOperatorOverloads { + unaryOperatorPermutations[operator] = overload.overloadsForParams(1) + } + for operator, overload := range binaryOperatorOverloads { + binaryOperatorPermutations[operator] = overload.overloadsForParams(2) } - for operator, overload := range binaryAggregateOverloads { - binaryAggregatePermutations[operator] = overload.overloadsForParams(2) +} + +func compileAggs() { + for funcName, overloads := range AggregateCatalog { + compileAggFunction(funcName, overloads) } } diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go new file mode 100644 index 0000000000..cbf610b6fb --- /dev/null +++ b/server/functions/framework/compiled_aggregate_function.go @@ -0,0 +1,131 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package framework + +import ( + "strings" + + cerrors "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" +) + +// AggregateFunction is an expression that represents CompiledAggregateFunction +type AggregateFunction interface { + sql.FunctionExpression + sql.Aggregation + specificFuncImpl() +} + +// CompiledAggregateFunction is an expression that represents a fully-analyzed PostgreSQL aggregate function. +type CompiledAggregateFunction struct { + *CompiledFunction + aggId sql.ColumnId + newBuffer func() (sql.AggregationBuffer, error) +} + +var _ AggregateFunction = (*CompiledAggregateFunction)(nil) + +// NewCompiledAggregateFunction returns a newly compiled function. +// TODO: newBuffer probably needs to be parameterized in the overloads +func NewCompiledAggregateFunction(name string, args []sql.Expression, functions *Overloads, newBuffer func() (sql.AggregationBuffer, error)) *CompiledAggregateFunction { + return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), newBuffer) +} + +// newCompiledAggregateFunctionInternal is called internally, which skips steps that may have already been processed. +func newCompiledAggregateFunctionInternal(name string, args []sql.Expression, overloads *Overloads, fnOverloads []Overload, newBuffer func() (sql.AggregationBuffer, error)) *CompiledAggregateFunction { + cf := newCompiledFunctionInternal(name, args, overloads, fnOverloads, false, nil) + c := &CompiledAggregateFunction{ + CompiledFunction: cf, + newBuffer: newBuffer, + } + + return c +} + +// Eval implements the interface sql.Expression. +func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return nil, cerrors.New("Eval should not be called on CompiledAggregateFunction") +} + +// WithChildren implements the interface sql.Expression. +func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(c.Arguments) { + return nil, sql.ErrInvalidChildrenNumber.New(len(children), len(c.Arguments)) + } + + // We have to re-resolve here, since the change in children may require it (e.g. we have more type info than we did) + return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.newBuffer), nil +} + +// SetStatementRunner implements the interface analyzer.Interpreter. +func (c *CompiledAggregateFunction) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Expression { + nc := *c + nc.runner = runner + return &nc +} + +// specificFuncImpl implements the interface sql.Expression. +func (*CompiledAggregateFunction) specificFuncImpl() {} + +func (c *CompiledAggregateFunction) DebugString() string { + sb := strings.Builder{} + sb.WriteString("CompiledAggregateFunction:") + sb.WriteString(c.Name + "(") + for i, param := range c.Arguments { + // Aliases will output the string "x as x", which is an artifact of how we build the AST, so we'll bypass it + if alias, ok := param.(*expression.Alias); ok { + param = alias.Child + } + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(sql.DebugString(param)) + } + sb.WriteString(")") + return sb.String() +} + +// NewBuffer implements the interface sql.Aggregation. +func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { + return c.newBuffer() +} + +// Id implements the interface sql.Aggregation. +func (c *CompiledAggregateFunction) Id() sql.ColumnId { + return c.aggId +} + +// WithId implements the interface sql.Aggregation. +func (c *CompiledAggregateFunction) WithId(id sql.ColumnId) sql.IdExpression { + nc := *c + nc.aggId = id + return &nc +} + +// NewWindowFunction implements the interface sql.WindowAdaptableExpression. +func (c *CompiledAggregateFunction) NewWindowFunction() (sql.WindowFunction, error) { + panic("windows are not implemented yet") +} + +// WithWindow implements the interface sql.WindowAdaptableExpression. +func (c *CompiledAggregateFunction) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + panic("windows are not implemented yet") +} + +// Window implements the interface sql.WindowAdaptableExpression. +func (c *CompiledAggregateFunction) Window() *sql.WindowDefinition { + panic("windows are not implemented yet") +} diff --git a/server/functions/framework/compiled_catalog.go b/server/functions/framework/compiled_catalog.go index ab0a5f4196..7706d7d096 100644 --- a/server/functions/framework/compiled_catalog.go +++ b/server/functions/framework/compiled_catalog.go @@ -23,9 +23,6 @@ import ( // compiledCatalog contains all of PostgreSQL functions in their compiled forms. var compiledCatalog = map[string]sql.CreateFuncNArgs{} -// namedCatalog contains the definitions of every PostgreSQL function associated with the given name. -var namedCatalog = map[string][]FunctionInterface{} - // GetFunction returns the compiled function with the given name and parameters. Returns false if the function could not // be found. func GetFunction(functionName string, params ...sql.Expression) (*CompiledFunction, bool, error) { diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index a3942918fb..e82ad050a8 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -307,7 +307,7 @@ func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, err case InterpretedFunction: return plpgsql.Call(ctx, f, c.runner, c.callResolved, args) default: - return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval") + return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval %T", f) } } @@ -328,8 +328,9 @@ func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Express // SetStatementRunner implements the interface analyzer.Interpreter. func (c *CompiledFunction) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Expression { - c.runner = runner - return c + nc := *c + nc.runner = runner + return &nc } // GetQuickFunction returns the QuickFunction form of this function, if it exists. If one does not exist, then this diff --git a/server/functions/framework/functions.go b/server/functions/framework/functions.go index ebbafa3b4a..a794b95ff0 100644 --- a/server/functions/framework/functions.go +++ b/server/functions/framework/functions.go @@ -45,6 +45,13 @@ type FunctionInterface interface { enforceInterfaceInheritance(error) } +// AggregateFunction is an interface for PostgreSQL aggregate functions +type AggregateFunctionInterface interface { + FunctionInterface + // TODO: this maybe needs to take the place of the Callable function + NewBuffer() (sql.AggregationBuffer, error) +} + // Function0 is a function that does not take any parameters. type Function0 struct { Name string @@ -280,3 +287,15 @@ func (f Function4) InternalID() id.Id { // enforceInterfaceInheritance implements the FunctionInterface interface. func (f Function4) enforceInterfaceInheritance(error) {} + +// Func1Aggregate is a function that takes one parameter and is an aggregate function. +type Func1Aggregate struct { + Function1 + NewAggBuffer func() (sql.AggregationBuffer, error) +} + +func (f Func1Aggregate) NewBuffer() (sql.AggregationBuffer, error) { + return f.NewAggBuffer() +} + +var _ AggregateFunctionInterface = Func1Aggregate{} diff --git a/server/functions/framework/operators.go b/server/functions/framework/operators.go index 9d29cd66a2..77b15c2ed5 100644 --- a/server/functions/framework/operators.go +++ b/server/functions/framework/operators.go @@ -72,16 +72,16 @@ var ( unaryFunctions = map[unaryFunction]Function1{} // binaryFunctions is a map from a binaryFunction signature to the associated function. binaryFunctions = map[binaryFunction]Function2{} - // unaryAggregateOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions + // unaryOperatorOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions // for that operator. - unaryAggregateOverloads = map[Operator]*Overloads{} - // binaryAggregateOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions + unaryOperatorOverloads = map[Operator]*Overloads{} + // binaryOperatorOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions // for that operator. - binaryAggregateOverloads = map[Operator]*Overloads{} - // unaryAggregatePermutations contains all of the permutations for each unary operator. - unaryAggregatePermutations = map[Operator][]Overload{} - // unaryAggregatePermutations contains all of the permutations for each binary operator. - binaryAggregatePermutations = map[Operator][]Overload{} + binaryOperatorOverloads = map[Operator]*Overloads{} + // unaryOperatorPermutations contains all of the permutations for each unary operator. + unaryOperatorPermutations = map[Operator][]Overload{} + // unaryOperatorPermutations contains all of the permutations for each binary operator. + binaryOperatorPermutations = map[Operator][]Overload{} ) // RegisterUnaryFunction registers the given function, so that it will be usable from a running server. This should @@ -127,8 +127,8 @@ func RegisterBinaryFunction(operator Operator, f Function2) { func GetUnaryFunction(operator Operator) IntermediateFunction { // Returns nil if not found, which is fine as IntermediateFunction will handle the nil deducer return IntermediateFunction{ - Functions: unaryAggregateOverloads[operator], - AllOverloads: unaryAggregatePermutations[operator], + Functions: unaryOperatorOverloads[operator], + AllOverloads: unaryOperatorPermutations[operator], IsOperator: true, } } @@ -137,8 +137,8 @@ func GetUnaryFunction(operator Operator) IntermediateFunction { func GetBinaryFunction(operator Operator) IntermediateFunction { // Returns nil if not found, which is fine as IntermediateFunction will handle the nil deducer return IntermediateFunction{ - Functions: binaryAggregateOverloads[operator], - AllOverloads: binaryAggregatePermutations[operator], + Functions: binaryOperatorOverloads[operator], + AllOverloads: binaryOperatorPermutations[operator], IsOperator: true, } } diff --git a/server/initialization/initialization.go b/server/initialization/initialization.go index 1627926dd3..6f62ac1dd7 100644 --- a/server/initialization/initialization.go +++ b/server/initialization/initialization.go @@ -29,6 +29,7 @@ import ( "github.com/dolthub/doltgresql/server/cast" "github.com/dolthub/doltgresql/server/config" "github.com/dolthub/doltgresql/server/functions" + "github.com/dolthub/doltgresql/server/functions/aggregate" "github.com/dolthub/doltgresql/server/functions/binary" "github.com/dolthub/doltgresql/server/functions/framework" "github.com/dolthub/doltgresql/server/functions/unary" @@ -53,6 +54,7 @@ func Initialize(dEnv *env.DoltEnv) { binary.Init() unary.Init() functions.Init() + aggregate.Init() cast.Init() framework.Initialize() sql.GlobalParser = pgsql.NewPostgresParser() diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 70aa685919..0b529e46bb 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -20,6 +20,48 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) +func TestAggregateFunctions(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "array_agg", + SetUpScript: []string{ + `CREATE TABLE t1 (pk INT primary key, t timestamp, v varchar, f float[]);`, + `INSERT INTO t1 VALUES + (1, '2023-01-01 00:00:00', 'a', '{1.0, 2.0}'), + (2, '2023-01-02 00:00:00', 'b', '{3.0, 4.0}'), + (3, '2023-01-03 00:00:00', 'c', '{5.0, 6.0}');`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT array_agg(pk) FROM t1;`, + Expected: []sql.Row{ + {"{1,2,3}"}, + }, + }, + { + Query: `SELECT array_agg(t) FROM t1;`, + Expected: []sql.Row{ + {`{"2023-01-01 00:00:00","2023-01-02 00:00:00","2023-01-03 00:00:00"}`}, + }, + }, + { + Query: `SELECT array_agg(v) FROM t1;`, + Expected: []sql.Row{ + {"{a,b,c}"}, + }, + }, + { + Skip: true, // Higher-level arrays don't work because they panic during output + Query: `SELECT array_agg(f) FROM t1;`, + Expected: []sql.Row{ + {"{{1.0,2.0},{3.0,4.0},{5.0,6.0}}"}, + }, + }, + }, + }, + }) +} + // https://www.postgresql.org/docs/15/functions-math.html func TestFunctionsMath(t *testing.T) { RunScripts(t, []ScriptTest{ diff --git a/testing/go/types_test.go b/testing/go/types_test.go index a646bd9a13..14b7a10d23 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -623,6 +623,22 @@ var typesTests = []ScriptTest{ }, }, { + Name: "2D array", + Skip: true, // multiple dimensions not supported yet + SetUpScript: []string{ + "CREATE TABLE t_varchar (id INTEGER primary key, v1 CHARACTER VARYING[][]);", + "INSERT INTO t_varchar VALUES (1, '{{abcdefghij, NULL}, {1234, abc}}'), (2, ARRAY['ab''cdef', 'what', 'is,hi', 'wh\"at', '}', '{', '{}']);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t_varchar ORDER BY id;", + Expected: []sql.Row{ + {1, "{abcdefghij,NULL}"}, + {2, `{ab'cdef,what,"is,hi","wh\"at","}","{","{}"}`}, + }, + }, + }, + }, { Name: "Cidr type", Skip: true, SetUpScript: []string{