From 94989f7eee392c311643e80d39bbbd85964bd7e8 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Thu, 15 May 2025 17:00:16 -0700 Subject: [PATCH 01/18] Scaffolding for new aggregation funcs --- server/functions/framework/catalog.go | 33 + .../framework/compiled_aggregate_function.go | 658 ++++++++++++++++++ server/functions/framework/functions.go | 77 ++ testing/go/functions_test.go | 16 + 4 files changed, 784 insertions(+) create mode 100644 server/functions/framework/compiled_aggregate_function.go diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index ff92b6645a..b4609d6790 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -28,6 +28,9 @@ import ( // Catalog contains all of the PostgreSQL functions. var Catalog = map[string][]FunctionInterface{} +// AggregateCatalog contains all of the PostgreSQL aggregate functions. +var AggregateCatalog = map[string][]AggregateFunction{} + // initializedFunctions simply states whether Initialize has been called yet. var initializedFunctions = false @@ -61,6 +64,36 @@ func RegisterFunction(f FunctionInterface) { } } +// RegisterFunction registers the given function, so that it will be usable from a running server. This should be called +// from within an init(). +func RegisterAggregateFunction(f AggregateFunctionInterface) { + if initializedFunctions { + panic("attempted to register a function after the init() phase") + } + switch f := f.(type) { + case Function0: + name := strings.ToLower(f.Name) + Catalog[name] = append(Catalog[name], f) + case Function1: + name := strings.ToLower(f.Name) + Catalog[name] = append(Catalog[name], f) + case Function2: + name := strings.ToLower(f.Name) + Catalog[name] = append(Catalog[name], f) + case Function3: + name := strings.ToLower(f.Name) + Catalog[name] = append(Catalog[name], f) + case Function4: + name := strings.ToLower(f.Name) + Catalog[name] = append(Catalog[name], f) + case InterpretedFunction: + name := strings.ToLower(f.ID.FunctionName()) + Catalog[name] = append(Catalog[name], f) + default: + panic("unhandled function type") + } +} + // Initialize handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not // apply to PostgreSQL (and functions of the same name often have different behavior). func Initialize() { diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go new file mode 100644 index 0000000000..6a15d54709 --- /dev/null +++ b/server/functions/framework/compiled_aggregate_function.go @@ -0,0 +1,658 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package framework + +import ( + "fmt" + "strings" + + cerrors "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + + "github.com/dolthub/doltgresql/server/plpgsql" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// AggregateFunction is an expression that represents CompiledAggregateFunction +type AggregateFunction interface { + sql.FunctionExpression + sql.NonDeterministicExpression + sql.Aggregation + specificFuncImpl() +} + +// CompiledFunction is an expression that represents a fully-analyzed PostgreSQL function. +type CompiledAggregateFunction struct { + CompiledFunction +} + +func (c *CompiledAggregateFunction) Id() sql.ColumnId { + // TODO implement me + panic("implement me") +} + +func (c *CompiledAggregateFunction) WithId(id sql.ColumnId) sql.IdExpression { + // TODO implement me + panic("implement me") +} + +func (c *CompiledAggregateFunction) NewWindowFunction() (sql.WindowFunction, error) { + // TODO implement me + panic("implement me") +} + +func (c *CompiledAggregateFunction) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + // TODO implement me + panic("implement me") +} + +func (c *CompiledAggregateFunction) Window() *sql.WindowDefinition { + // TODO implement me + panic("implement me") +} + +func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { + // TODO implement me + panic("implement me") +} + +var _ sql.FunctionExpression = (*CompiledAggregateFunction)(nil) +var _ sql.NonDeterministicExpression = (*CompiledAggregateFunction)(nil) +var _ sql.Aggregation = (*CompiledAggregateFunction)(nil) + +// NewCompiledFunction returns a newly compiled function. +func NewCompiledAggregateFunction(name string, args []sql.Expression, functions *Overloads) *CompiledFunction { + return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), isOperator, nil) +} + +// newCompiledFunctionInternal is called internally, which skips steps that may have already been processed. +func newCompiledAggregateFunctionInternal( + name string, + args []sql.Expression, + overloads *Overloads, + fnOverloads []Overload, + isOperator bool, + runner sql.StatementRunner, +) *CompiledFunction { + c := &CompiledFunction{ + Name: name, + Arguments: args, + IsOperator: isOperator, + overloads: overloads, + fnOverloads: fnOverloads, + runner: runner, + } + // First we'll analyze all the parameters. + originalTypes, err := c.analyzeParameters() + if err != nil { + // Errors should be returned from the call to Eval, so we'll stash it for now + c.stashedErr = err + return c + } + // Next we'll resolve the overload based on the parameters given. + overload, err := c.resolve(overloads, fnOverloads, originalTypes) + if err != nil { + c.stashedErr = err + return c + } + // If we do not receive an overload, then the parameters given did not result in a valid match + if !overload.Valid() { + c.stashedErr = ErrFunctionDoesNotExist.New(c.OverloadString(originalTypes)) + return c + } + + fn := overload.Function() + + // Then we'll handle the polymorphic types + // https://www.postgresql.org/docs/15/extend-type-system.html#EXTEND-TYPES-POLYMORPHIC + functionParameterTypes := fn.GetParameters() + c.callResolved = make([]*pgtypes.DoltgresType, len(functionParameterTypes)+1) + hasPolymorphicParam := false + for i, param := range functionParameterTypes { + if param.IsPolymorphicType() { + // resolve will ensure that the parameter types are valid, so we can just assign them here + hasPolymorphicParam = true + c.callResolved[i] = originalTypes[i] + } else { + if d, ok := args[i].Type().(*pgtypes.DoltgresType); ok { + // `param` is a default type which does not have type modifier set + param = param.WithAttTypMod(d.GetAttTypMod()) + } + c.callResolved[i] = param + } + } + returnType := fn.GetReturn() + c.callResolved[len(c.callResolved)-1] = returnType + if returnType.IsPolymorphicType() { + if hasPolymorphicParam { + c.callResolved[len(c.callResolved)-1] = c.resolvePolymorphicReturnType(functionParameterTypes, originalTypes, returnType) + } else if c.Name == "array_in" || c.Name == "array_recv" || c.Name == "enum_in" || c.Name == "enum_recv" || c.Name == "anyenum_in" || c.Name == "anyenum_recv" { + // The return type should resolve to the type of OID value passed in as second argument. + // TODO: Possible that the oid type has a special property with polymorphic return types, + // in that perhaps their value will set the return type in the absence of another polymorphic type in the parameter list + } else { + c.stashedErr = cerrors.Errorf("A result of type %s requires at least one input of type anyelement, anyarray, anynonarray, anyenum, anyrange, or anymultirange.", returnType.String()) + return c + } + } + + // Lastly, we assign everything to the function struct + c.overload = overload + c.originalTypes = originalTypes + return c +} + +// FunctionName implements the interface sql.Expression. +func (c *CompiledAggregateFunction) FunctionName() string { + return c.Name +} + +// Description implements the interface sql.Expression. +func (c *CompiledAggregateFunction) Description() string { + return fmt.Sprintf("The PostgreSQL function `%s`", c.Name) +} + +// Resolved implements the interface sql.Expression. +func (c *CompiledAggregateFunction) Resolved() bool { + for _, param := range c.Arguments { + if !param.Resolved() { + return false + } + } + // We don't error until evaluation time, so we need to tell the engine we're resolved if there was a stashed error + return c.stashedErr != nil || c.overload.Valid() +} + +// StashedError returns the stashed error if one exists. Otherwise, returns nil. +func (c *CompiledAggregateFunction) StashedError() error { + if c == nil { + return nil + } + return c.stashedErr +} + +// String implements the interface sql.Expression. +func (c *CompiledAggregateFunction) String() string { + sb := strings.Builder{} + sb.WriteString(c.Name + "(") + for i, param := range c.Arguments { + // Aliases will output the string "x as x", which is an artifact of how we build the AST, so we'll bypass it + if alias, ok := param.(*expression.Alias); ok { + param = alias.Child + } + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(param.String()) + } + sb.WriteString(")") + return sb.String() +} + +// OverloadString returns the name of the function represented by the given overload. +func (c *CompiledAggregateFunction) OverloadString(types []*pgtypes.DoltgresType) string { + sb := strings.Builder{} + sb.WriteString(c.Name + "(") + for i, t := range types { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(t.String()) + } + sb.WriteString(")") + return sb.String() +} + +// Type implements the interface sql.Expression. +func (c *CompiledAggregateFunction) Type() sql.Type { + if len(c.callResolved) > 0 { + return c.callResolved[len(c.callResolved)-1] + } + // Compilation must have errored, so we'll return the unknown type + return pgtypes.Unknown +} + +// IsNullable implements the interface sql.Expression. +func (c *CompiledAggregateFunction) IsNullable() bool { + // All functions seem to return NULL when given a NULL value + return true +} + +// IsNonDeterministic implements the interface sql.NonDeterministicExpression. +func (c *CompiledAggregateFunction) IsNonDeterministic() bool { + if c.overload.Valid() { + return c.overload.Function().NonDeterministic() + } + // Compilation must have errored, so we'll just return true + return true +} + +// IsVariadic returns whether this function has any variadic parameters. +func (c *CompiledAggregateFunction) IsVariadic() bool { + if c.overload.Valid() { + return c.overload.params.variadic != -1 + } + // Compilation must have errored, so we'll just return true + return true +} + +// Eval implements the interface sql.Expression. +func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + // If we have a stashed error, then we should return that now. Errors are stashed when they're supposed to be + // returned during the call to Eval. This helps to ensure consistency with how errors are returned in Postgres. + if c.stashedErr != nil { + return nil, c.stashedErr + } + + // Evaluate all arguments, returning immediately if we encounter a null argument and the function is marked STRICT + var err error + isStrict := c.overload.Function().IsStrict() + args := make([]any, len(c.Arguments)) + for i, arg := range c.Arguments { + args[i], err = arg.Eval(ctx, row) + if err != nil { + return nil, err + } + // TODO: once we remove GMS types from all of our expressions, we can remove this step which ensures the correct type + if _, ok := arg.Type().(*pgtypes.DoltgresType); !ok { + dt, err := pgtypes.FromGmsTypeToDoltgresType(arg.Type()) + if err != nil { + return nil, err + } + args[i], _, _ = dt.Convert(ctx, args[i]) + } + if args[i] == nil && isStrict { + return nil, nil + } + } + + if len(c.overload.casts) > 0 { + targetParamTypes := c.overload.Function().GetParameters() + for i, arg := range args { + // For variadic params, we need to identify the corresponding target type + var targetType *pgtypes.DoltgresType + isVariadicArg := c.overload.params.variadic >= 0 && i >= len(c.overload.params.paramTypes)-1 + if isVariadicArg { + targetType = targetParamTypes[c.overload.params.variadic] + if !targetType.IsArrayType() { + // should be impossible, we check this at function compile time + return nil, cerrors.Errorf("variadic arguments must be array types, was %T", targetType) + } + targetType = targetType.ArrayBaseType() + } else { + targetType = targetParamTypes[i] + } + + if c.overload.casts[i] != nil { + args[i], err = c.overload.casts[i](ctx, arg, targetType) + if err != nil { + return nil, err + } + } else { + return nil, cerrors.Errorf("function %s is missing the appropriate implicit cast", c.OverloadString(c.originalTypes)) + } + } + } + + args = c.overload.params.coalesceVariadicValues(args) + + // Call the function + switch f := c.overload.Function().(type) { + case Function0: + return f.Callable(ctx) + case Function1: + return f.Callable(ctx, ([2]*pgtypes.DoltgresType)(c.callResolved), args[0]) + case Function2: + return f.Callable(ctx, ([3]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1]) + case Function3: + return f.Callable(ctx, ([4]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2]) + case Function4: + return f.Callable(ctx, ([5]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2], args[3]) + case InterpretedFunction: + return plpgsql.Call(ctx, f, c.runner, c.callResolved, args) + default: + return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval") + } +} + +// Children implements the interface sql.Expression. +func (c *CompiledAggregateFunction) Children() []sql.Expression { + return c.Arguments +} + +// WithChildren implements the interface sql.Expression. +func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(c.Arguments) { + return nil, sql.ErrInvalidChildrenNumber.New(len(children), len(c.Arguments)) + } + + // We have to re-resolve here, since the change in children may require it (e.g. we have more type info than we did) + return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.IsOperator, c.runner), nil +} + +// resolve returns an overloadMatch that either matches the given parameters exactly, or is a viable match after casting. +// Returns an invalid overloadMatch if a viable match is not found. +func (c *CompiledAggregateFunction) resolve(overloads *Overloads, fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) (overloadMatch, error) { + // First check for an exact match + exactMatch, found := overloads.ExactMatchForTypes(argTypes...) + if found { + return overloadMatch{ + params: Overload{ + function: exactMatch, + paramTypes: argTypes, + argTypes: argTypes, + variadic: -1, + }, + }, nil + } + // There are no exact matches, so now we'll look through all overloads to determine the best match. This is + // much more work, but there's a performance penalty for runtime overload resolution in Postgres as well. + return c.resolveFunction(argTypes, fnOverloads) +} + +// resolveFunction resolves a function according to the rules defined by Postgres. +// https://www.postgresql.org/docs/15/typeconv-func.html +func (c *CompiledAggregateFunction) resolveFunction(argTypes []*pgtypes.DoltgresType, overloads []Overload) (overloadMatch, error) { + // First we'll discard all overloads that do not have implicitly-convertible param types + compatibleOverloads := c.typeCompatibleOverloads(overloads, argTypes) + + // No compatible overloads available, return early + if len(compatibleOverloads) == 0 { + return overloadMatch{}, nil + } + + // If we've found exactly one match then we'll return that one + // TODO: we need to also prefer non-variadic functions here over variadic ones (no such conflict can exist for now) + // https://www.postgresql.org/docs/15/typeconv-func.html + if len(compatibleOverloads) == 1 { + return compatibleOverloads[0], nil + } + + // Next rank the candidates by the number of params whose types match exactly + closestMatches := c.closestTypeMatches(argTypes, compatibleOverloads) + + // Now check again for exactly one match + if len(closestMatches) == 1 { + return closestMatches[0], nil + } + + // If there was more than a single match, try to find the one with the most preferred type conversions + preferredOverloads := c.preferredTypeMatches(argTypes, closestMatches) + + // Check once more for exactly one match + if len(preferredOverloads) == 1 { + return preferredOverloads[0], nil + } + + // Next we'll check the type categories for `unknown` types + unknownOverloads, ok := c.unknownTypeCategoryMatches(argTypes, preferredOverloads) + if !ok { + return overloadMatch{}, nil + } + + // Check again for exactly one match + if len(unknownOverloads) == 1 { + return unknownOverloads[0], nil + } + + // No matching function overload found + return overloadMatch{}, nil +} + +// typeCompatibleOverloads returns all overloads that have a matching number of params whose types can be +// implicitly converted to the ones provided. This is the set of all possible overloads that could be used with the +// param types provided. +func (c *CompiledAggregateFunction) typeCompatibleOverloads(fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) []overloadMatch { + var compatible []overloadMatch + for _, overload := range fnOverloads { + isConvertible := true + overloadCasts := make([]pgtypes.TypeCastFunction, len(argTypes)) + // Polymorphic parameters must be gathered so that we can later verify that they all have matching base types + var polymorphicParameters []*pgtypes.DoltgresType + var polymorphicTargets []*pgtypes.DoltgresType + for i := range argTypes { + paramType := overload.argTypes[i] + if paramType.IsValidForPolymorphicType(argTypes[i]) { + overloadCasts[i] = identityCast + polymorphicParameters = append(polymorphicParameters, paramType) + polymorphicTargets = append(polymorphicTargets, argTypes[i]) + } else { + if overloadCasts[i] = GetImplicitCast(argTypes[i], paramType); overloadCasts[i] == nil { + if argTypes[i].ID == pgtypes.Unknown.ID { + overloadCasts[i] = UnknownLiteralCast + } else { + isConvertible = false + break + } + } + } + } + + if isConvertible && c.polymorphicTypesCompatible(polymorphicParameters, polymorphicTargets) { + compatible = append(compatible, overloadMatch{params: overload, casts: overloadCasts}) + } + } + return compatible +} + +// closestTypeMatches returns the set of overload candidates that have the most exact type matches for the arg types +// provided. +func (*CompiledAggregateFunction) closestTypeMatches(argTypes []*pgtypes.DoltgresType, candidates []overloadMatch) []overloadMatch { + matchCount := 0 + var matches []overloadMatch + for _, cand := range candidates { + currentMatchCount := 0 + for argIdx := range argTypes { + argType := cand.params.argTypes[argIdx] + if argTypes[argIdx].ID == argType.ID || argTypes[argIdx].ID == pgtypes.Unknown.ID { + currentMatchCount++ + } + } + if currentMatchCount > matchCount { + matchCount = currentMatchCount + matches = append([]overloadMatch{}, cand) + } else if currentMatchCount == matchCount { + matches = append(matches, cand) + } + } + return matches +} + +// preferredTypeMatches returns the overload candidates that have the most preferred types for args that require casts. +func (*CompiledAggregateFunction) preferredTypeMatches(argTypes []*pgtypes.DoltgresType, candidates []overloadMatch) []overloadMatch { + preferredCount := 0 + var preferredOverloads []overloadMatch + for _, cand := range candidates { + currentPreferredCount := 0 + for argIdx := range argTypes { + argType := cand.params.argTypes[argIdx] + if argTypes[argIdx].ID != argType.ID && argType.IsPreferred { + currentPreferredCount++ + } + } + + if currentPreferredCount > preferredCount { + preferredCount = currentPreferredCount + preferredOverloads = append([]overloadMatch{}, cand) + } else if currentPreferredCount == preferredCount { + preferredOverloads = append(preferredOverloads, cand) + } + } + return preferredOverloads +} + +// unknownTypeCategoryMatches checks the type categories of `unknown` types. These types have an inherent bias toward +// the string category since an `unknown` literal resembles a string. Returns false if the resolution should fail. +func (c *CompiledAggregateFunction) unknownTypeCategoryMatches(argTypes []*pgtypes.DoltgresType, candidates []overloadMatch) ([]overloadMatch, bool) { + matches := make([]overloadMatch, len(candidates)) + copy(matches, candidates) + // For our first loop, we'll filter matches based on whether they accept the string category + for argIdx := range argTypes { + // We're only concerned with `unknown` types + if argTypes[argIdx].ID != pgtypes.Unknown.ID { + continue + } + var newMatches []overloadMatch + for _, match := range matches { + if match.params.argTypes[argIdx].TypCategory == pgtypes.TypeCategory_StringTypes { + newMatches = append(newMatches, match) + } + } + // If we've found matches in this step, then we'll update our match set + if len(newMatches) > 0 { + matches = newMatches + } + } + // Return early if we've filtered down to a single match + if len(matches) == 1 { + return matches, true + } + // TODO: implement the remainder of step 4.e. from the documentation (following code assumes it has been implemented) + // ... + + // If we've discarded every function, then we'll actually return all original candidates + if len(matches) == 0 { + return candidates, true + } + // In this case, we've trimmed at least one candidate, so we'll return our new matches + return matches, true +} + +// polymorphicTypesCompatible returns whether any polymorphic types given are compatible with the expression types given +func (*CompiledAggregateFunction) polymorphicTypesCompatible(paramTypes []*pgtypes.DoltgresType, exprTypes []*pgtypes.DoltgresType) bool { + if len(paramTypes) != len(exprTypes) { + return false + } + // If there are less than two parameters then we don't even need to check + if len(paramTypes) < 2 { + return true + } + + // If one of the types is anyarray, then anyelement behaves as anynonarray, so we can convert them to anynonarray + for _, paramType := range paramTypes { + if paramType.ID == pgtypes.AnyArray.ID { + // At least one parameter is anyarray, so copy all parameters to a new slice and replace anyelement with anynonarray + newParamTypes := make([]*pgtypes.DoltgresType, len(paramTypes)) + copy(newParamTypes, paramTypes) + for i := range newParamTypes { + if paramTypes[i].ID == pgtypes.AnyElement.ID { + newParamTypes[i] = pgtypes.AnyNonArray + } + } + paramTypes = newParamTypes + break + } + } + + // The base type is the type that must match between all polymorphic types. + var baseType *pgtypes.DoltgresType + for i, paramType := range paramTypes { + if paramType.IsPolymorphicType() && exprTypes[i].ID != pgtypes.Unknown.ID { + // Although we do this check before we ever reach this function, we do it again as we may convert anyelement + // to anynonarray, which changes type validity + if !paramType.IsValidForPolymorphicType(exprTypes[i]) { + return false + } + // Get the base expression type that we'll compare against + baseExprType := exprTypes[i] + if baseExprType.IsArrayType() { + baseExprType = baseExprType.ArrayBaseType() + } + // TODO: handle range types + // Check that the base expression type matches the previously-found base type + if baseType.IsEmptyType() { + baseType = baseExprType + } else if baseType.ID != baseExprType.ID { + return false + } + } + } + return true +} + +// resolvePolymorphicReturnType returns the type that should be used for the return type. If the return type is not a +// polymorphic type, then the return type is directly returned. However, if the return type is a polymorphic type, then +// the type is determined using the expression types and parameter types. This makes the assumption that everything has +// already been validated. +func (c *CompiledAggregateFunction) resolvePolymorphicReturnType(functionInterfaceTypes []*pgtypes.DoltgresType, originalTypes []*pgtypes.DoltgresType, returnType *pgtypes.DoltgresType) *pgtypes.DoltgresType { + if !returnType.IsPolymorphicType() { + return returnType + } + // We can use the first polymorphic non-unknown type that we find, since we can morph it into any type that we need. + // We've verified that all polymorphic types are compatible in a previous step, so this is safe to do. + var firstPolymorphicType *pgtypes.DoltgresType + for i, functionInterfaceType := range functionInterfaceTypes { + if functionInterfaceType.IsPolymorphicType() && originalTypes[i].ID != pgtypes.Unknown.ID { + firstPolymorphicType = originalTypes[i] + break + } + } + + // if all types are `unknown`, use `text` type + if firstPolymorphicType.IsEmptyType() { + firstPolymorphicType = pgtypes.Text + } + + switch returnType.ID { + case pgtypes.AnyElement.ID, pgtypes.AnyNonArray.ID: + // For return types, anyelement behaves the same as anynonarray. + // This isn't explicitly in the documentation, however it does note that: + // "...anynonarray and anyenum do not represent separate type variables; they are the same type as anyelement..." + // The implication of this being that anyelement will always return the base type even for array types, + // just like anynonarray would. + if firstPolymorphicType.IsArrayType() { + return firstPolymorphicType.ArrayBaseType() + } else { + return firstPolymorphicType + } + case pgtypes.AnyArray.ID: + // Array types will return themselves, so this is safe + if firstPolymorphicType.IsArrayType() { + return firstPolymorphicType + } else if firstPolymorphicType.ID == pgtypes.Internal.ID { + return pgtypes.IDToBuiltInDoltgresType[firstPolymorphicType.BaseTypeForInternal] + } else { + return firstPolymorphicType.ToArrayType() + } + default: + panic(cerrors.Errorf("`%s` is not yet handled during function compilation", returnType.String())) + } +} + +// analyzeParameters analyzes the parameters within an Eval call. +func (c *CompiledAggregateFunction) analyzeParameters() (originalTypes []*pgtypes.DoltgresType, err error) { + originalTypes = make([]*pgtypes.DoltgresType, len(c.Arguments)) + for i, param := range c.Arguments { + returnType := param.Type() + if extendedType, ok := returnType.(*pgtypes.DoltgresType); ok && !extendedType.IsEmptyType() { + if extendedType.TypType == pgtypes.TypeType_Domain { + extendedType = extendedType.DomainUnderlyingBaseType() + } + originalTypes[i] = extendedType + } else { + // TODO: we need to remove GMS types from all of our expressions so that we can remove this + dt, err := pgtypes.FromGmsTypeToDoltgresType(param.Type()) + if err != nil { + return nil, err + } + originalTypes[i] = dt + } + } + return originalTypes, nil +} + +// specificFuncImpl implements the interface sql.Expression. +func (*CompiledAggregateFunction) specificFuncImpl() {} diff --git a/server/functions/framework/functions.go b/server/functions/framework/functions.go index ebbafa3b4a..fdd632afd7 100644 --- a/server/functions/framework/functions.go +++ b/server/functions/framework/functions.go @@ -45,6 +45,11 @@ type FunctionInterface interface { enforceInterfaceInheritance(error) } +type AggregateFunctionInterface interface { + FunctionInterface + sql.Aggregation +} + // Function0 is a function that does not take any parameters. type Function0 struct { Name string @@ -280,3 +285,75 @@ func (f Function4) InternalID() id.Id { // enforceInterfaceInheritance implements the FunctionInterface interface. func (f Function4) enforceInterfaceInheritance(error) {} + +// Func1Aggregate is a function that takes one parameter and is an aggregate function. +type Func1Aggregate struct { + Function0 +} + +func (f Func1Aggregate) Resolved() bool { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) String() string { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) Type() sql.Type { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) IsNullable() bool { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) Children() []sql.Expression { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) WithChildren(children ...sql.Expression) (sql.Expression, error) { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) Id() sql.ColumnId { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) WithId(columnId sql.ColumnId) sql.IdExpression { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) NewWindowFunction() (sql.WindowFunction, error) { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) Window() *sql.WindowDefinition { + // TODO implement me + panic("implement me") +} + +func (f Func1Aggregate) NewBuffer() (sql.AggregationBuffer, error) { + // TODO implement me + panic("implement me") +} + +var _ AggregateFunction = Func1Aggregate{} \ No newline at end of file diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 70aa685919..a7e5fd8fec 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -20,6 +20,22 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) +func TestSimple(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "left", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT left('abc', 1);`, + Expected: []sql.Row{ + {"a"}, + }, + }, + }, + }, + }) +} + // https://www.postgresql.org/docs/15/functions-math.html func TestFunctionsMath(t *testing.T) { RunScripts(t, []ScriptTest{ From 421bd81506b6dd476c29d73cffa8a07ad64e3319 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Tue, 20 May 2025 16:55:11 -0700 Subject: [PATCH 02/18] sketching out compiled agg funcs --- .../framework/compiled_aggregate_function.go | 446 ++---------------- 1 file changed, 47 insertions(+), 399 deletions(-) diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index 6a15d54709..bd6ff4f995 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -29,53 +29,71 @@ import ( // AggregateFunction is an expression that represents CompiledAggregateFunction type AggregateFunction interface { sql.FunctionExpression - sql.NonDeterministicExpression sql.Aggregation specificFuncImpl() } // CompiledFunction is an expression that represents a fully-analyzed PostgreSQL function. type CompiledAggregateFunction struct { - CompiledFunction + *CompiledFunction + aggId sql.ColumnId } func (c *CompiledAggregateFunction) Id() sql.ColumnId { - // TODO implement me - panic("implement me") + return c.aggId } func (c *CompiledAggregateFunction) WithId(id sql.ColumnId) sql.IdExpression { - // TODO implement me - panic("implement me") + nc := *c + nc.aggId = id + return &nc } func (c *CompiledAggregateFunction) NewWindowFunction() (sql.WindowFunction, error) { - // TODO implement me - panic("implement me") + panic("windows are not implemented yet") } func (c *CompiledAggregateFunction) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { - // TODO implement me - panic("implement me") + panic("windows are not implemented yet") } func (c *CompiledAggregateFunction) Window() *sql.WindowDefinition { - // TODO implement me - panic("implement me") + panic("windows are not implemented yet") +} + +type arrayAggBuffer struct { + elements []any +} + +func newArrayAggBuffer() *arrayAggBuffer { + return &arrayAggBuffer{ + elements: make([]any, 0), + } +} + +func (a *arrayAggBuffer) Dispose() {} + +func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) { + if len(a.elements) == 0 { + return nil, nil + } + return a.elements, nil +} + +func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { + a.elements = append(a.elements, row[0]) + return nil } func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { - // TODO implement me - panic("implement me") + return newArrayAggBuffer(), nil } -var _ sql.FunctionExpression = (*CompiledAggregateFunction)(nil) -var _ sql.NonDeterministicExpression = (*CompiledAggregateFunction)(nil) -var _ sql.Aggregation = (*CompiledAggregateFunction)(nil) +var _ AggregateFunction = (*CompiledAggregateFunction)(nil) -// NewCompiledFunction returns a newly compiled function. -func NewCompiledAggregateFunction(name string, args []sql.Expression, functions *Overloads) *CompiledFunction { - return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), isOperator, nil) +// NewCompiledAggregateFunction returns a newly compiled function. +func NewCompiledAggregateFunction(name string, args []sql.Expression, functions *Overloads) *CompiledAggregateFunction { + return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args))) } // newCompiledFunctionInternal is called internally, which skips steps that may have already been processed. @@ -84,74 +102,13 @@ func newCompiledAggregateFunctionInternal( args []sql.Expression, overloads *Overloads, fnOverloads []Overload, - isOperator bool, - runner sql.StatementRunner, -) *CompiledFunction { - c := &CompiledFunction{ - Name: name, - Arguments: args, - IsOperator: isOperator, - overloads: overloads, - fnOverloads: fnOverloads, - runner: runner, - } - // First we'll analyze all the parameters. - originalTypes, err := c.analyzeParameters() - if err != nil { - // Errors should be returned from the call to Eval, so we'll stash it for now - c.stashedErr = err - return c - } - // Next we'll resolve the overload based on the parameters given. - overload, err := c.resolve(overloads, fnOverloads, originalTypes) - if err != nil { - c.stashedErr = err - return c - } - // If we do not receive an overload, then the parameters given did not result in a valid match - if !overload.Valid() { - c.stashedErr = ErrFunctionDoesNotExist.New(c.OverloadString(originalTypes)) - return c +) *CompiledAggregateFunction { + + cf := newCompiledFunctionInternal(name, args, overloads, fnOverloads, false, nil) + c := &CompiledAggregateFunction{ + CompiledFunction: cf, } - - fn := overload.Function() - - // Then we'll handle the polymorphic types - // https://www.postgresql.org/docs/15/extend-type-system.html#EXTEND-TYPES-POLYMORPHIC - functionParameterTypes := fn.GetParameters() - c.callResolved = make([]*pgtypes.DoltgresType, len(functionParameterTypes)+1) - hasPolymorphicParam := false - for i, param := range functionParameterTypes { - if param.IsPolymorphicType() { - // resolve will ensure that the parameter types are valid, so we can just assign them here - hasPolymorphicParam = true - c.callResolved[i] = originalTypes[i] - } else { - if d, ok := args[i].Type().(*pgtypes.DoltgresType); ok { - // `param` is a default type which does not have type modifier set - param = param.WithAttTypMod(d.GetAttTypMod()) - } - c.callResolved[i] = param - } - } - returnType := fn.GetReturn() - c.callResolved[len(c.callResolved)-1] = returnType - if returnType.IsPolymorphicType() { - if hasPolymorphicParam { - c.callResolved[len(c.callResolved)-1] = c.resolvePolymorphicReturnType(functionParameterTypes, originalTypes, returnType) - } else if c.Name == "array_in" || c.Name == "array_recv" || c.Name == "enum_in" || c.Name == "enum_recv" || c.Name == "anyenum_in" || c.Name == "anyenum_recv" { - // The return type should resolve to the type of OID value passed in as second argument. - // TODO: Possible that the oid type has a special property with polymorphic return types, - // in that perhaps their value will set the return type in the absence of another polymorphic type in the parameter list - } else { - c.stashedErr = cerrors.Errorf("A result of type %s requires at least one input of type anyelement, anyarray, anynonarray, anyenum, anyrange, or anymultirange.", returnType.String()) - return c - } - } - - // Lastly, we assign everything to the function struct - c.overload = overload - c.originalTypes = originalTypes + return c } @@ -251,6 +208,8 @@ func (c *CompiledAggregateFunction) IsVariadic() bool { // Eval implements the interface sql.Expression. func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + // TODO: probably should be an error? + // If we have a stashed error, then we should return that now. Errors are stashed when they're supposed to be // returned during the call to Eval. This helps to ensure consistency with how errors are returned in Postgres. if c.stashedErr != nil { @@ -340,318 +299,7 @@ func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sq } // We have to re-resolve here, since the change in children may require it (e.g. we have more type info than we did) - return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.IsOperator, c.runner), nil -} - -// resolve returns an overloadMatch that either matches the given parameters exactly, or is a viable match after casting. -// Returns an invalid overloadMatch if a viable match is not found. -func (c *CompiledAggregateFunction) resolve(overloads *Overloads, fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) (overloadMatch, error) { - // First check for an exact match - exactMatch, found := overloads.ExactMatchForTypes(argTypes...) - if found { - return overloadMatch{ - params: Overload{ - function: exactMatch, - paramTypes: argTypes, - argTypes: argTypes, - variadic: -1, - }, - }, nil - } - // There are no exact matches, so now we'll look through all overloads to determine the best match. This is - // much more work, but there's a performance penalty for runtime overload resolution in Postgres as well. - return c.resolveFunction(argTypes, fnOverloads) -} - -// resolveFunction resolves a function according to the rules defined by Postgres. -// https://www.postgresql.org/docs/15/typeconv-func.html -func (c *CompiledAggregateFunction) resolveFunction(argTypes []*pgtypes.DoltgresType, overloads []Overload) (overloadMatch, error) { - // First we'll discard all overloads that do not have implicitly-convertible param types - compatibleOverloads := c.typeCompatibleOverloads(overloads, argTypes) - - // No compatible overloads available, return early - if len(compatibleOverloads) == 0 { - return overloadMatch{}, nil - } - - // If we've found exactly one match then we'll return that one - // TODO: we need to also prefer non-variadic functions here over variadic ones (no such conflict can exist for now) - // https://www.postgresql.org/docs/15/typeconv-func.html - if len(compatibleOverloads) == 1 { - return compatibleOverloads[0], nil - } - - // Next rank the candidates by the number of params whose types match exactly - closestMatches := c.closestTypeMatches(argTypes, compatibleOverloads) - - // Now check again for exactly one match - if len(closestMatches) == 1 { - return closestMatches[0], nil - } - - // If there was more than a single match, try to find the one with the most preferred type conversions - preferredOverloads := c.preferredTypeMatches(argTypes, closestMatches) - - // Check once more for exactly one match - if len(preferredOverloads) == 1 { - return preferredOverloads[0], nil - } - - // Next we'll check the type categories for `unknown` types - unknownOverloads, ok := c.unknownTypeCategoryMatches(argTypes, preferredOverloads) - if !ok { - return overloadMatch{}, nil - } - - // Check again for exactly one match - if len(unknownOverloads) == 1 { - return unknownOverloads[0], nil - } - - // No matching function overload found - return overloadMatch{}, nil -} - -// typeCompatibleOverloads returns all overloads that have a matching number of params whose types can be -// implicitly converted to the ones provided. This is the set of all possible overloads that could be used with the -// param types provided. -func (c *CompiledAggregateFunction) typeCompatibleOverloads(fnOverloads []Overload, argTypes []*pgtypes.DoltgresType) []overloadMatch { - var compatible []overloadMatch - for _, overload := range fnOverloads { - isConvertible := true - overloadCasts := make([]pgtypes.TypeCastFunction, len(argTypes)) - // Polymorphic parameters must be gathered so that we can later verify that they all have matching base types - var polymorphicParameters []*pgtypes.DoltgresType - var polymorphicTargets []*pgtypes.DoltgresType - for i := range argTypes { - paramType := overload.argTypes[i] - if paramType.IsValidForPolymorphicType(argTypes[i]) { - overloadCasts[i] = identityCast - polymorphicParameters = append(polymorphicParameters, paramType) - polymorphicTargets = append(polymorphicTargets, argTypes[i]) - } else { - if overloadCasts[i] = GetImplicitCast(argTypes[i], paramType); overloadCasts[i] == nil { - if argTypes[i].ID == pgtypes.Unknown.ID { - overloadCasts[i] = UnknownLiteralCast - } else { - isConvertible = false - break - } - } - } - } - - if isConvertible && c.polymorphicTypesCompatible(polymorphicParameters, polymorphicTargets) { - compatible = append(compatible, overloadMatch{params: overload, casts: overloadCasts}) - } - } - return compatible -} - -// closestTypeMatches returns the set of overload candidates that have the most exact type matches for the arg types -// provided. -func (*CompiledAggregateFunction) closestTypeMatches(argTypes []*pgtypes.DoltgresType, candidates []overloadMatch) []overloadMatch { - matchCount := 0 - var matches []overloadMatch - for _, cand := range candidates { - currentMatchCount := 0 - for argIdx := range argTypes { - argType := cand.params.argTypes[argIdx] - if argTypes[argIdx].ID == argType.ID || argTypes[argIdx].ID == pgtypes.Unknown.ID { - currentMatchCount++ - } - } - if currentMatchCount > matchCount { - matchCount = currentMatchCount - matches = append([]overloadMatch{}, cand) - } else if currentMatchCount == matchCount { - matches = append(matches, cand) - } - } - return matches -} - -// preferredTypeMatches returns the overload candidates that have the most preferred types for args that require casts. -func (*CompiledAggregateFunction) preferredTypeMatches(argTypes []*pgtypes.DoltgresType, candidates []overloadMatch) []overloadMatch { - preferredCount := 0 - var preferredOverloads []overloadMatch - for _, cand := range candidates { - currentPreferredCount := 0 - for argIdx := range argTypes { - argType := cand.params.argTypes[argIdx] - if argTypes[argIdx].ID != argType.ID && argType.IsPreferred { - currentPreferredCount++ - } - } - - if currentPreferredCount > preferredCount { - preferredCount = currentPreferredCount - preferredOverloads = append([]overloadMatch{}, cand) - } else if currentPreferredCount == preferredCount { - preferredOverloads = append(preferredOverloads, cand) - } - } - return preferredOverloads -} - -// unknownTypeCategoryMatches checks the type categories of `unknown` types. These types have an inherent bias toward -// the string category since an `unknown` literal resembles a string. Returns false if the resolution should fail. -func (c *CompiledAggregateFunction) unknownTypeCategoryMatches(argTypes []*pgtypes.DoltgresType, candidates []overloadMatch) ([]overloadMatch, bool) { - matches := make([]overloadMatch, len(candidates)) - copy(matches, candidates) - // For our first loop, we'll filter matches based on whether they accept the string category - for argIdx := range argTypes { - // We're only concerned with `unknown` types - if argTypes[argIdx].ID != pgtypes.Unknown.ID { - continue - } - var newMatches []overloadMatch - for _, match := range matches { - if match.params.argTypes[argIdx].TypCategory == pgtypes.TypeCategory_StringTypes { - newMatches = append(newMatches, match) - } - } - // If we've found matches in this step, then we'll update our match set - if len(newMatches) > 0 { - matches = newMatches - } - } - // Return early if we've filtered down to a single match - if len(matches) == 1 { - return matches, true - } - // TODO: implement the remainder of step 4.e. from the documentation (following code assumes it has been implemented) - // ... - - // If we've discarded every function, then we'll actually return all original candidates - if len(matches) == 0 { - return candidates, true - } - // In this case, we've trimmed at least one candidate, so we'll return our new matches - return matches, true -} - -// polymorphicTypesCompatible returns whether any polymorphic types given are compatible with the expression types given -func (*CompiledAggregateFunction) polymorphicTypesCompatible(paramTypes []*pgtypes.DoltgresType, exprTypes []*pgtypes.DoltgresType) bool { - if len(paramTypes) != len(exprTypes) { - return false - } - // If there are less than two parameters then we don't even need to check - if len(paramTypes) < 2 { - return true - } - - // If one of the types is anyarray, then anyelement behaves as anynonarray, so we can convert them to anynonarray - for _, paramType := range paramTypes { - if paramType.ID == pgtypes.AnyArray.ID { - // At least one parameter is anyarray, so copy all parameters to a new slice and replace anyelement with anynonarray - newParamTypes := make([]*pgtypes.DoltgresType, len(paramTypes)) - copy(newParamTypes, paramTypes) - for i := range newParamTypes { - if paramTypes[i].ID == pgtypes.AnyElement.ID { - newParamTypes[i] = pgtypes.AnyNonArray - } - } - paramTypes = newParamTypes - break - } - } - - // The base type is the type that must match between all polymorphic types. - var baseType *pgtypes.DoltgresType - for i, paramType := range paramTypes { - if paramType.IsPolymorphicType() && exprTypes[i].ID != pgtypes.Unknown.ID { - // Although we do this check before we ever reach this function, we do it again as we may convert anyelement - // to anynonarray, which changes type validity - if !paramType.IsValidForPolymorphicType(exprTypes[i]) { - return false - } - // Get the base expression type that we'll compare against - baseExprType := exprTypes[i] - if baseExprType.IsArrayType() { - baseExprType = baseExprType.ArrayBaseType() - } - // TODO: handle range types - // Check that the base expression type matches the previously-found base type - if baseType.IsEmptyType() { - baseType = baseExprType - } else if baseType.ID != baseExprType.ID { - return false - } - } - } - return true -} - -// resolvePolymorphicReturnType returns the type that should be used for the return type. If the return type is not a -// polymorphic type, then the return type is directly returned. However, if the return type is a polymorphic type, then -// the type is determined using the expression types and parameter types. This makes the assumption that everything has -// already been validated. -func (c *CompiledAggregateFunction) resolvePolymorphicReturnType(functionInterfaceTypes []*pgtypes.DoltgresType, originalTypes []*pgtypes.DoltgresType, returnType *pgtypes.DoltgresType) *pgtypes.DoltgresType { - if !returnType.IsPolymorphicType() { - return returnType - } - // We can use the first polymorphic non-unknown type that we find, since we can morph it into any type that we need. - // We've verified that all polymorphic types are compatible in a previous step, so this is safe to do. - var firstPolymorphicType *pgtypes.DoltgresType - for i, functionInterfaceType := range functionInterfaceTypes { - if functionInterfaceType.IsPolymorphicType() && originalTypes[i].ID != pgtypes.Unknown.ID { - firstPolymorphicType = originalTypes[i] - break - } - } - - // if all types are `unknown`, use `text` type - if firstPolymorphicType.IsEmptyType() { - firstPolymorphicType = pgtypes.Text - } - - switch returnType.ID { - case pgtypes.AnyElement.ID, pgtypes.AnyNonArray.ID: - // For return types, anyelement behaves the same as anynonarray. - // This isn't explicitly in the documentation, however it does note that: - // "...anynonarray and anyenum do not represent separate type variables; they are the same type as anyelement..." - // The implication of this being that anyelement will always return the base type even for array types, - // just like anynonarray would. - if firstPolymorphicType.IsArrayType() { - return firstPolymorphicType.ArrayBaseType() - } else { - return firstPolymorphicType - } - case pgtypes.AnyArray.ID: - // Array types will return themselves, so this is safe - if firstPolymorphicType.IsArrayType() { - return firstPolymorphicType - } else if firstPolymorphicType.ID == pgtypes.Internal.ID { - return pgtypes.IDToBuiltInDoltgresType[firstPolymorphicType.BaseTypeForInternal] - } else { - return firstPolymorphicType.ToArrayType() - } - default: - panic(cerrors.Errorf("`%s` is not yet handled during function compilation", returnType.String())) - } -} - -// analyzeParameters analyzes the parameters within an Eval call. -func (c *CompiledAggregateFunction) analyzeParameters() (originalTypes []*pgtypes.DoltgresType, err error) { - originalTypes = make([]*pgtypes.DoltgresType, len(c.Arguments)) - for i, param := range c.Arguments { - returnType := param.Type() - if extendedType, ok := returnType.(*pgtypes.DoltgresType); ok && !extendedType.IsEmptyType() { - if extendedType.TypType == pgtypes.TypeType_Domain { - extendedType = extendedType.DomainUnderlyingBaseType() - } - originalTypes[i] = extendedType - } else { - // TODO: we need to remove GMS types from all of our expressions so that we can remove this - dt, err := pgtypes.FromGmsTypeToDoltgresType(param.Type()) - if err != nil { - return nil, err - } - originalTypes[i] = dt - } - } - return originalTypes, nil + return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads), nil } // specificFuncImpl implements the interface sql.Expression. From c1e70b559cd92c2117cee5cf61952e22e713c551 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 21 May 2025 15:13:07 -0700 Subject: [PATCH 03/18] removing more unneeded override funcs --- .../framework/compiled_aggregate_function.go | 208 +++++------------- 1 file changed, 52 insertions(+), 156 deletions(-) diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index bd6ff4f995..74272b40c1 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -15,15 +15,10 @@ package framework import ( - "fmt" - "strings" - cerrors "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" + "github.com/dolthub/go-mysql-server/sql" ) // AggregateFunction is an expression that represents CompiledAggregateFunction @@ -33,62 +28,12 @@ type AggregateFunction interface { specificFuncImpl() } -// CompiledFunction is an expression that represents a fully-analyzed PostgreSQL function. +// CompiledAggregateFunction is an expression that represents a fully-analyzed PostgreSQL aggregate function. type CompiledAggregateFunction struct { *CompiledFunction aggId sql.ColumnId } -func (c *CompiledAggregateFunction) Id() sql.ColumnId { - return c.aggId -} - -func (c *CompiledAggregateFunction) WithId(id sql.ColumnId) sql.IdExpression { - nc := *c - nc.aggId = id - return &nc -} - -func (c *CompiledAggregateFunction) NewWindowFunction() (sql.WindowFunction, error) { - panic("windows are not implemented yet") -} - -func (c *CompiledAggregateFunction) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { - panic("windows are not implemented yet") -} - -func (c *CompiledAggregateFunction) Window() *sql.WindowDefinition { - panic("windows are not implemented yet") -} - -type arrayAggBuffer struct { - elements []any -} - -func newArrayAggBuffer() *arrayAggBuffer { - return &arrayAggBuffer{ - elements: make([]any, 0), - } -} - -func (a *arrayAggBuffer) Dispose() {} - -func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) { - if len(a.elements) == 0 { - return nil, nil - } - return a.elements, nil -} - -func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { - a.elements = append(a.elements, row[0]) - return nil -} - -func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { - return newArrayAggBuffer(), nil -} - var _ AggregateFunction = (*CompiledAggregateFunction)(nil) // NewCompiledAggregateFunction returns a newly compiled function. @@ -112,100 +57,6 @@ func newCompiledAggregateFunctionInternal( return c } -// FunctionName implements the interface sql.Expression. -func (c *CompiledAggregateFunction) FunctionName() string { - return c.Name -} - -// Description implements the interface sql.Expression. -func (c *CompiledAggregateFunction) Description() string { - return fmt.Sprintf("The PostgreSQL function `%s`", c.Name) -} - -// Resolved implements the interface sql.Expression. -func (c *CompiledAggregateFunction) Resolved() bool { - for _, param := range c.Arguments { - if !param.Resolved() { - return false - } - } - // We don't error until evaluation time, so we need to tell the engine we're resolved if there was a stashed error - return c.stashedErr != nil || c.overload.Valid() -} - -// StashedError returns the stashed error if one exists. Otherwise, returns nil. -func (c *CompiledAggregateFunction) StashedError() error { - if c == nil { - return nil - } - return c.stashedErr -} - -// String implements the interface sql.Expression. -func (c *CompiledAggregateFunction) String() string { - sb := strings.Builder{} - sb.WriteString(c.Name + "(") - for i, param := range c.Arguments { - // Aliases will output the string "x as x", which is an artifact of how we build the AST, so we'll bypass it - if alias, ok := param.(*expression.Alias); ok { - param = alias.Child - } - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(param.String()) - } - sb.WriteString(")") - return sb.String() -} - -// OverloadString returns the name of the function represented by the given overload. -func (c *CompiledAggregateFunction) OverloadString(types []*pgtypes.DoltgresType) string { - sb := strings.Builder{} - sb.WriteString(c.Name + "(") - for i, t := range types { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(t.String()) - } - sb.WriteString(")") - return sb.String() -} - -// Type implements the interface sql.Expression. -func (c *CompiledAggregateFunction) Type() sql.Type { - if len(c.callResolved) > 0 { - return c.callResolved[len(c.callResolved)-1] - } - // Compilation must have errored, so we'll return the unknown type - return pgtypes.Unknown -} - -// IsNullable implements the interface sql.Expression. -func (c *CompiledAggregateFunction) IsNullable() bool { - // All functions seem to return NULL when given a NULL value - return true -} - -// IsNonDeterministic implements the interface sql.NonDeterministicExpression. -func (c *CompiledAggregateFunction) IsNonDeterministic() bool { - if c.overload.Valid() { - return c.overload.Function().NonDeterministic() - } - // Compilation must have errored, so we'll just return true - return true -} - -// IsVariadic returns whether this function has any variadic parameters. -func (c *CompiledAggregateFunction) IsVariadic() bool { - if c.overload.Valid() { - return c.overload.params.variadic != -1 - } - // Compilation must have errored, so we'll just return true - return true -} - // Eval implements the interface sql.Expression. func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // TODO: probably should be an error? @@ -287,11 +138,6 @@ func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interfa } } -// Children implements the interface sql.Expression. -func (c *CompiledAggregateFunction) Children() []sql.Expression { - return c.Arguments -} - // WithChildren implements the interface sql.Expression. func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != len(c.Arguments) { @@ -304,3 +150,53 @@ func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sq // specificFuncImpl implements the interface sql.Expression. func (*CompiledAggregateFunction) specificFuncImpl() {} + +type arrayAggBuffer struct { + elements []any +} + +func newArrayAggBuffer() *arrayAggBuffer { + return &arrayAggBuffer{ + elements: make([]any, 0), + } +} + +func (a *arrayAggBuffer) Dispose() {} + +func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) { + if len(a.elements) == 0 { + return nil, nil + } + return a.elements, nil +} + +func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { + a.elements = append(a.elements, row[0]) + return nil +} + +func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { + return newArrayAggBuffer(), nil +} + +func (c *CompiledAggregateFunction) Id() sql.ColumnId { + return c.aggId +} + +func (c *CompiledAggregateFunction) WithId(id sql.ColumnId) sql.IdExpression { + nc := *c + nc.aggId = id + return &nc +} + +func (c *CompiledAggregateFunction) NewWindowFunction() (sql.WindowFunction, error) { + panic("windows are not implemented yet") +} + +func (c *CompiledAggregateFunction) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + panic("windows are not implemented yet") +} + +func (c *CompiledAggregateFunction) Window() *sql.WindowDefinition { + panic("windows are not implemented yet") +} From a2b9e74626e24fffd656fa5350a163f28244e375 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 21 May 2025 16:24:33 -0700 Subject: [PATCH 04/18] checkpoint --- server/functions/framework/catalog.go | 20 ++--- .../framework/compiled_aggregate_function.go | 6 +- .../functions/framework/compiled_function.go | 16 ++++ server/functions/framework/functions.go | 79 +++++-------------- testing/go/functions_test.go | 12 ++- 5 files changed, 54 insertions(+), 79 deletions(-) diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index b4609d6790..8c71a468b7 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -15,6 +15,7 @@ package framework import ( + "fmt" "strings" "github.com/cockroachdb/errors" @@ -64,33 +65,24 @@ func RegisterFunction(f FunctionInterface) { } } -// RegisterFunction registers the given function, so that it will be usable from a running server. This should be called +// RegisterAggregateFunction registers the given function, so that it will be usable from a running server. This should be called // from within an init(). func RegisterAggregateFunction(f AggregateFunctionInterface) { if initializedFunctions { panic("attempted to register a function after the init() phase") } switch f := f.(type) { - case Function0: - name := strings.ToLower(f.Name) - Catalog[name] = append(Catalog[name], f) - case Function1: - name := strings.ToLower(f.Name) - Catalog[name] = append(Catalog[name], f) - case Function2: + case Func1Aggregate: name := strings.ToLower(f.Name) Catalog[name] = append(Catalog[name], f) - case Function3: + case Func2Aggregate: name := strings.ToLower(f.Name) Catalog[name] = append(Catalog[name], f) - case Function4: + case Func3Aggregate: name := strings.ToLower(f.Name) Catalog[name] = append(Catalog[name], f) - case InterpretedFunction: - name := strings.ToLower(f.ID.FunctionName()) - Catalog[name] = append(Catalog[name], f) default: - panic("unhandled function type") + panic(fmt.Sprintf("unhandled function type %T", f)) } } diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index 74272b40c1..e6723fe1a7 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -155,10 +155,10 @@ type arrayAggBuffer struct { elements []any } -func newArrayAggBuffer() *arrayAggBuffer { +func newArrayAggBuffer() (sql.AggregationBuffer, error) { return &arrayAggBuffer{ elements: make([]any, 0), - } + }, nil } func (a *arrayAggBuffer) Dispose() {} @@ -176,7 +176,7 @@ func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { } func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { - return newArrayAggBuffer(), nil + return newArrayAggBuffer() } func (c *CompiledAggregateFunction) Id() sql.ColumnId { diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index a3942918fb..54ae98de1b 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -721,3 +721,19 @@ func (c *CompiledFunction) analyzeParameters() (originalTypes []*pgtypes.Doltgre // specificFuncImpl implements the interface sql.Expression. func (*CompiledFunction) specificFuncImpl() {} + +func init() { + RegisterAggregateFunction(Func1Aggregate{ + Function1: Function1{ + Name: "array_agg", + Return: pgtypes.AnyArray, + Parameters: [1]*pgtypes.DoltgresType{ + pgtypes.AnyElement, + }, + Callable: func(ctx *sql.Context, paramsAndReturn [2]*pgtypes.DoltgresType, val1 any) (any, error) { + return nil, nil + }, + }, + NewAggBuffer: newArrayAggBuffer, + }) +} diff --git a/server/functions/framework/functions.go b/server/functions/framework/functions.go index fdd632afd7..5c4a04d944 100644 --- a/server/functions/framework/functions.go +++ b/server/functions/framework/functions.go @@ -45,9 +45,10 @@ type FunctionInterface interface { enforceInterfaceInheritance(error) } +// AggregateFunction is an interface for PostgreSQL aggregate functions type AggregateFunctionInterface interface { FunctionInterface - sql.Aggregation + NewBuffer() (sql.AggregationBuffer, error) } // Function0 is a function that does not take any parameters. @@ -288,72 +289,34 @@ func (f Function4) enforceInterfaceInheritance(error) {} // Func1Aggregate is a function that takes one parameter and is an aggregate function. type Func1Aggregate struct { - Function0 + Function1 + NewAggBuffer func() (sql.AggregationBuffer, error) } -func (f Func1Aggregate) Resolved() bool { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) String() string { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) Type() sql.Type { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) IsNullable() bool { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) Children() []sql.Expression { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) WithChildren(children ...sql.Expression) (sql.Expression, error) { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) Id() sql.ColumnId { - // TODO implement me - panic("implement me") -} - -func (f Func1Aggregate) WithId(columnId sql.ColumnId) sql.IdExpression { - // TODO implement me - panic("implement me") +func (f Func1Aggregate) NewBuffer() (sql.AggregationBuffer, error) { + return f.NewAggBuffer() } -func (f Func1Aggregate) NewWindowFunction() (sql.WindowFunction, error) { - // TODO implement me - panic("implement me") +// Func2Aggregate is a function that takes one parameter and is an aggregate function. +type Func2Aggregate struct { + Function2 + NewAggBuffer func() (sql.AggregationBuffer, error) } -func (f Func1Aggregate) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { - // TODO implement me - panic("implement me") +func (f Func2Aggregate) NewBuffer() (sql.AggregationBuffer, error) { + return f.NewAggBuffer() } -func (f Func1Aggregate) Window() *sql.WindowDefinition { - // TODO implement me - panic("implement me") +// Func3Aggregate is a function that takes one parameter and is an aggregate function. +type Func3Aggregate struct { + Function3 + NewAggBuffer func() (sql.AggregationBuffer, error) } -func (f Func1Aggregate) NewBuffer() (sql.AggregationBuffer, error) { - // TODO implement me - panic("implement me") +func (f Func3Aggregate) NewBuffer() (sql.AggregationBuffer, error) { + return f.NewAggBuffer() } -var _ AggregateFunction = Func1Aggregate{} \ No newline at end of file +var _ AggregateFunctionInterface = Func1Aggregate{} +var _ AggregateFunctionInterface = Func2Aggregate{} +var _ AggregateFunctionInterface = Func3Aggregate{} diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index a7e5fd8fec..a0318d58c2 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -20,15 +20,19 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) -func TestSimple(t *testing.T) { +func TestAggregateFunctions(t *testing.T) { RunScripts(t, []ScriptTest{ { - Name: "left", + Name: "array_agg", + SetUpScript: []string{ + `CREATE TABLE test (pk INT primary key, v1 INT, v2 INT);`, + `INSERT INTO test VALUES (1, 1, 2), (2, 3, 4), (3, 5, 6);`, + }, Assertions: []ScriptTestAssertion{ { - Query: `SELECT left('abc', 1);`, + Query: `SELECT array_agg(v1) FROM test;`, Expected: []sql.Row{ - {"a"}, + {[]int64{1, 3, 5}}, }, }, }, From 8dad3eece16de60ae6ead274f9a3d6e9e045fa61 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 21 May 2025 16:33:03 -0700 Subject: [PATCH 05/18] Postgres agg func determination --- server/analyzer/init.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 2babfa310e..48d34bbf29 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -17,6 +17,7 @@ package analyzer import ( "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/planbuilder" ) // IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs @@ -101,6 +102,23 @@ func initEngine() { // place to put it. Our foreign key validation logic is different from MySQL's, and since it's not an analyzer rule // we can't swap out a rule like the rest of the logic in this package, we have to do a function swap. plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition + + planbuilder.IsAggregateFunc = IsAggregateFunc +} + +// IsAggregateFunc checks if the given function name is an aggregate function. This is the entire set supported by +// MySQL plus some postgres specific ones. +func IsAggregateFunc(name string) bool { + if planbuilder.IsMySQLAggregateFuncName(name) { + return true + } + + switch name { + case "array_agg": + return true + } + + return false } // insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice. From f7e4f97da8484cd519f4d3d7d7c8d5eced08b7f7 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 12:13:27 -0700 Subject: [PATCH 06/18] Pulled out more aggregate func methods, did some renaming of poorly named vars --- server/functions/framework/catalog.go | 87 +++++++++++++++++++++---- server/functions/framework/operators.go | 24 +++---- 2 files changed, 87 insertions(+), 24 deletions(-) diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 8c71a468b7..843fd66b54 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -30,7 +30,7 @@ import ( var Catalog = map[string][]FunctionInterface{} // AggregateCatalog contains all of the PostgreSQL aggregate functions. -var AggregateCatalog = map[string][]AggregateFunction{} +var AggregateCatalog = map[string][]AggregateFunctionInterface{} // initializedFunctions simply states whether Initialize has been called yet. var initializedFunctions = false @@ -74,13 +74,13 @@ func RegisterAggregateFunction(f AggregateFunctionInterface) { switch f := f.(type) { case Func1Aggregate: name := strings.ToLower(f.Name) - Catalog[name] = append(Catalog[name], f) + AggregateCatalog[name] = append(AggregateCatalog[name], f) case Func2Aggregate: name := strings.ToLower(f.Name) - Catalog[name] = append(Catalog[name], f) + AggregateCatalog[name] = append(AggregateCatalog[name], f) case Func3Aggregate: name := strings.ToLower(f.Name) - Catalog[name] = append(Catalog[name], f) + AggregateCatalog[name] = append(AggregateCatalog[name], f) default: panic(fmt.Sprintf("unhandled function type %T", f)) } @@ -101,6 +101,7 @@ func Initialize() { replaceGmsBuiltIns() validateFunctions() compileFunctions() + compileAggs() } // replaceGmsBuiltIns replaces all GMS built-ins that have conflicting names with PostgreSQL functions. @@ -179,6 +180,28 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface) namedCatalog[funcName] = overloads } +// compileNonOperatorFunction creates a CompiledFunction for each overload of the given function. +func compileNonOperatorAggFunction(funcName string, overloads []AggregateFunctionInterface) { + overloadTree := NewOverloads() + for _, functionOverload := range overloads { + if err := overloadTree.Add(functionOverload); err != nil { + panic(err) + } + } + + // Store the compiled function into the engine's built-in functions + // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog + createFunc := func(params ...sql.Expression) (sql.Expression, error) { + return NewCompiledFunction(funcName, params, overloadTree, false), nil + } + function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ + Name: funcName, + Fn: createFunc, + }) + compiledCatalog[funcName] = createFunc + namedCatalog[funcName] = overloads +} + // compileFunctions creates a CompiledFunction for each overload of each function in the catalog. func compileFunctions() { for funcName, overloads := range Catalog { @@ -190,10 +213,50 @@ func compileFunctions() { // special rules, so it's far more efficient to reuse it for operators. Operators are also a special case since they // all have different names, while standard overload deducers work on a function-name basis. for signature, functionOverload := range unaryFunctions { - overloads, ok := unaryAggregateOverloads[signature.Operator] + overloads, ok := unaryOperatorOverloads[signature.Operator] + if !ok { + overloads = NewOverloads() + unaryOperatorOverloads[signature.Operator] = overloads + } + if err := overloads.Add(functionOverload); err != nil { + panic(err) + } + } + + for signature, functionOverload := range binaryFunctions { + overloads, ok := binaryOperatorOverloads[signature.Operator] + if !ok { + overloads = NewOverloads() + binaryOperatorOverloads[signature.Operator] = overloads + } + if err := overloads.Add(functionOverload); err != nil { + panic(err) + } + } + + // Add all permutations for the unary and binary operators + for operator, overload := range unaryOperatorOverloads { + unaryOperatorPermutations[operator] = overload.overloadsForParams(1) + } + for operator, overload := range binaryOperatorOverloads { + binaryOperatorPermutations[operator] = overload.overloadsForParams(2) + } +} + +func compileAggs() { + for funcName, overloads := range AggregateCatalog { + compileNonOperatorFunction(funcName, overloads) + } + + // Build the overload for all unary and binary functions based on their operator. This will be used for fallback if + // an exact match is not found. Compiled functions (which wrap the overload deducer) handle upcasting and other + // special rules, so it's far more efficient to reuse it for operators. Operators are also a special case since they + // all have different names, while standard overload deducers work on a function-name basis. + for signature, functionOverload := range unaryFunctions { + overloads, ok := unaryOperatorOverloads[signature.Operator] if !ok { overloads = NewOverloads() - unaryAggregateOverloads[signature.Operator] = overloads + unaryOperatorOverloads[signature.Operator] = overloads } if err := overloads.Add(functionOverload); err != nil { panic(err) @@ -201,10 +264,10 @@ func compileFunctions() { } for signature, functionOverload := range binaryFunctions { - overloads, ok := binaryAggregateOverloads[signature.Operator] + overloads, ok := binaryOperatorOverloads[signature.Operator] if !ok { overloads = NewOverloads() - binaryAggregateOverloads[signature.Operator] = overloads + binaryOperatorOverloads[signature.Operator] = overloads } if err := overloads.Add(functionOverload); err != nil { panic(err) @@ -212,10 +275,10 @@ func compileFunctions() { } // Add all permutations for the unary and binary operators - for operator, overload := range unaryAggregateOverloads { - unaryAggregatePermutations[operator] = overload.overloadsForParams(1) + for operator, overload := range unaryOperatorOverloads { + unaryOperatorPermutations[operator] = overload.overloadsForParams(1) } - for operator, overload := range binaryAggregateOverloads { - binaryAggregatePermutations[operator] = overload.overloadsForParams(2) + for operator, overload := range binaryOperatorOverloads { + binaryOperatorPermutations[operator] = overload.overloadsForParams(2) } } diff --git a/server/functions/framework/operators.go b/server/functions/framework/operators.go index 9d29cd66a2..77b15c2ed5 100644 --- a/server/functions/framework/operators.go +++ b/server/functions/framework/operators.go @@ -72,16 +72,16 @@ var ( unaryFunctions = map[unaryFunction]Function1{} // binaryFunctions is a map from a binaryFunction signature to the associated function. binaryFunctions = map[binaryFunction]Function2{} - // unaryAggregateOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions + // unaryOperatorOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions // for that operator. - unaryAggregateOverloads = map[Operator]*Overloads{} - // binaryAggregateOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions + unaryOperatorOverloads = map[Operator]*Overloads{} + // binaryOperatorOverloads is a map from an operator to an Overload deducer that is the aggregate of all functions // for that operator. - binaryAggregateOverloads = map[Operator]*Overloads{} - // unaryAggregatePermutations contains all of the permutations for each unary operator. - unaryAggregatePermutations = map[Operator][]Overload{} - // unaryAggregatePermutations contains all of the permutations for each binary operator. - binaryAggregatePermutations = map[Operator][]Overload{} + binaryOperatorOverloads = map[Operator]*Overloads{} + // unaryOperatorPermutations contains all of the permutations for each unary operator. + unaryOperatorPermutations = map[Operator][]Overload{} + // unaryOperatorPermutations contains all of the permutations for each binary operator. + binaryOperatorPermutations = map[Operator][]Overload{} ) // RegisterUnaryFunction registers the given function, so that it will be usable from a running server. This should @@ -127,8 +127,8 @@ func RegisterBinaryFunction(operator Operator, f Function2) { func GetUnaryFunction(operator Operator) IntermediateFunction { // Returns nil if not found, which is fine as IntermediateFunction will handle the nil deducer return IntermediateFunction{ - Functions: unaryAggregateOverloads[operator], - AllOverloads: unaryAggregatePermutations[operator], + Functions: unaryOperatorOverloads[operator], + AllOverloads: unaryOperatorPermutations[operator], IsOperator: true, } } @@ -137,8 +137,8 @@ func GetUnaryFunction(operator Operator) IntermediateFunction { func GetBinaryFunction(operator Operator) IntermediateFunction { // Returns nil if not found, which is fine as IntermediateFunction will handle the nil deducer return IntermediateFunction{ - Functions: binaryAggregateOverloads[operator], - AllOverloads: binaryAggregatePermutations[operator], + Functions: binaryOperatorOverloads[operator], + AllOverloads: binaryOperatorPermutations[operator], IsOperator: true, } } From c50eb4c5e86f2a1ff699d91518403581ca85dadf Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 12:17:49 -0700 Subject: [PATCH 07/18] Removed unused var and functionality --- server/functions/framework/catalog.go | 40 +------------------ .../functions/framework/compiled_catalog.go | 3 -- 2 files changed, 2 insertions(+), 41 deletions(-) diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 843fd66b54..6f75d33926 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -177,11 +177,10 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface) Fn: createFunc, }) compiledCatalog[funcName] = createFunc - namedCatalog[funcName] = overloads } // compileNonOperatorFunction creates a CompiledFunction for each overload of the given function. -func compileNonOperatorAggFunction(funcName string, overloads []AggregateFunctionInterface) { +func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) { overloadTree := NewOverloads() for _, functionOverload := range overloads { if err := overloadTree.Add(functionOverload); err != nil { @@ -199,7 +198,6 @@ func compileNonOperatorAggFunction(funcName string, overloads []AggregateFunctio Fn: createFunc, }) compiledCatalog[funcName] = createFunc - namedCatalog[funcName] = overloads } // compileFunctions creates a CompiledFunction for each overload of each function in the catalog. @@ -245,40 +243,6 @@ func compileFunctions() { func compileAggs() { for funcName, overloads := range AggregateCatalog { - compileNonOperatorFunction(funcName, overloads) - } - - // Build the overload for all unary and binary functions based on their operator. This will be used for fallback if - // an exact match is not found. Compiled functions (which wrap the overload deducer) handle upcasting and other - // special rules, so it's far more efficient to reuse it for operators. Operators are also a special case since they - // all have different names, while standard overload deducers work on a function-name basis. - for signature, functionOverload := range unaryFunctions { - overloads, ok := unaryOperatorOverloads[signature.Operator] - if !ok { - overloads = NewOverloads() - unaryOperatorOverloads[signature.Operator] = overloads - } - if err := overloads.Add(functionOverload); err != nil { - panic(err) - } - } - - for signature, functionOverload := range binaryFunctions { - overloads, ok := binaryOperatorOverloads[signature.Operator] - if !ok { - overloads = NewOverloads() - binaryOperatorOverloads[signature.Operator] = overloads - } - if err := overloads.Add(functionOverload); err != nil { - panic(err) - } - } - - // Add all permutations for the unary and binary operators - for operator, overload := range unaryOperatorOverloads { - unaryOperatorPermutations[operator] = overload.overloadsForParams(1) - } - for operator, overload := range binaryOperatorOverloads { - binaryOperatorPermutations[operator] = overload.overloadsForParams(2) + compileAggFunction(funcName, overloads) } } diff --git a/server/functions/framework/compiled_catalog.go b/server/functions/framework/compiled_catalog.go index ab0a5f4196..7706d7d096 100644 --- a/server/functions/framework/compiled_catalog.go +++ b/server/functions/framework/compiled_catalog.go @@ -23,9 +23,6 @@ import ( // compiledCatalog contains all of PostgreSQL functions in their compiled forms. var compiledCatalog = map[string]sql.CreateFuncNArgs{} -// namedCatalog contains the definitions of every PostgreSQL function associated with the given name. -var namedCatalog = map[string][]FunctionInterface{} - // GetFunction returns the compiled function with the given name and parameters. Returns false if the function could not // be found. func GetFunction(functionName string, params ...sql.Expression) (*CompiledFunction, bool, error) { From 6923b1ca074e583b00b30570d672fb091c25072a Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 12:41:36 -0700 Subject: [PATCH 08/18] Couple typos, almost working --- server/functions/framework/catalog.go | 2 +- server/functions/framework/compiled_function.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 6f75d33926..1ecd556400 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -191,7 +191,7 @@ func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) // Store the compiled function into the engine's built-in functions // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog createFunc := func(params ...sql.Expression) (sql.Expression, error) { - return NewCompiledFunction(funcName, params, overloadTree, false), nil + return NewCompiledAggregateFunction(funcName, params, overloadTree), nil } function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ Name: funcName, diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 54ae98de1b..13879f25ff 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -307,7 +307,7 @@ func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, err case InterpretedFunction: return plpgsql.Call(ctx, f, c.runner, c.callResolved, args) default: - return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval") + return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval %T", f) } } From 20c57b9225fa28a4c0594c32df891ed02ee5fb84 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 14:12:59 -0700 Subject: [PATCH 09/18] Bug fix --- .../framework/compiled_aggregate_function.go | 34 +++++++++++++++++-- .../functions/framework/compiled_function.go | 5 +-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index e6723fe1a7..2cb070cfa9 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -15,10 +15,13 @@ package framework import ( + "strings" + cerrors "github.com/cockroachdb/errors" "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" ) // AggregateFunction is an expression that represents CompiledAggregateFunction @@ -48,19 +51,19 @@ func newCompiledAggregateFunctionInternal( overloads *Overloads, fnOverloads []Overload, ) *CompiledAggregateFunction { - + cf := newCompiledFunctionInternal(name, args, overloads, fnOverloads, false, nil) c := &CompiledAggregateFunction{ CompiledFunction: cf, } - + return c } // Eval implements the interface sql.Expression. func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // TODO: probably should be an error? - + // If we have a stashed error, then we should return that now. Errors are stashed when they're supposed to be // returned during the call to Eval. This helps to ensure consistency with how errors are returned in Postgres. if c.stashedErr != nil { @@ -148,9 +151,34 @@ func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sq return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads), nil } +// SetStatementRunner implements the interface analyzer.Interpreter. +func (c *CompiledAggregateFunction) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Expression { + nc := *c + nc.runner = runner + return &nc +} + // specificFuncImpl implements the interface sql.Expression. func (*CompiledAggregateFunction) specificFuncImpl() {} +func (c *CompiledAggregateFunction) DebugString() string { + sb := strings.Builder{} + sb.WriteString("CompiledAggregateFunction:") + sb.WriteString(c.Name + "(") + for i, param := range c.Arguments { + // Aliases will output the string "x as x", which is an artifact of how we build the AST, so we'll bypass it + if alias, ok := param.(*expression.Alias); ok { + param = alias.Child + } + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(sql.DebugString(param)) + } + sb.WriteString(")") + return sb.String() +} + type arrayAggBuffer struct { elements []any } diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 13879f25ff..5f099ac66a 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -328,8 +328,9 @@ func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Express // SetStatementRunner implements the interface analyzer.Interpreter. func (c *CompiledFunction) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Expression { - c.runner = runner - return c + nc := *c + nc.runner = runner + return &nc } // GetQuickFunction returns the QuickFunction form of this function, if it exists. If one does not exist, then this From 8f91f1fddec7b782318f0377e2834405ee8b8303 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 14:19:03 -0700 Subject: [PATCH 10/18] Fully working POC --- testing/go/functions_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index a0318d58c2..36dc1a2b18 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -32,7 +32,7 @@ func TestAggregateFunctions(t *testing.T) { { Query: `SELECT array_agg(v1) FROM test;`, Expected: []sql.Row{ - {[]int64{1, 3, 5}}, + {"{1,3,5}"}, }, }, }, From 9cbe7b830c6be8a19057c681a9af5dfce4dc85ed Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 16:15:11 -0700 Subject: [PATCH 11/18] Moved array_agg into a new package --- server/functions/aggregate/array_agg.go | 66 +++++++++ server/functions/aggregate/init.go | 19 +++ server/functions/framework/catalog.go | 4 +- .../framework/compiled_aggregate_function.go | 126 ++---------------- .../functions/framework/compiled_function.go | 16 --- server/functions/framework/functions.go | 1 + server/initialization/initialization.go | 2 + 7 files changed, 101 insertions(+), 133 deletions(-) create mode 100755 server/functions/aggregate/array_agg.go create mode 100755 server/functions/aggregate/init.go diff --git a/server/functions/aggregate/array_agg.go b/server/functions/aggregate/array_agg.go new file mode 100755 index 0000000000..94cb2e15f5 --- /dev/null +++ b/server/functions/aggregate/array_agg.go @@ -0,0 +1,66 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +import ( + "github.com/dolthub/doltgresql/server/functions/framework" + pgtypes "github.com/dolthub/doltgresql/server/types" + "github.com/dolthub/go-mysql-server/sql" +) + +// initArrayAgg registers the functions to the catalog. +func initArrayAgg() { + framework.RegisterAggregateFunction(array_agg) +} + +// array_agg represents the PostgreSQL array_agg function. +var array_agg = framework.Func1Aggregate{ + Function1: framework.Function1{ + Name: "array_agg", + Return: pgtypes.AnyArray, + Parameters: [1]*pgtypes.DoltgresType{ + pgtypes.AnyElement, + }, + Callable: func(ctx *sql.Context, paramsAndReturn [2]*pgtypes.DoltgresType, val1 any) (any, error) { + return nil, nil + }, + }, + NewAggBuffer: newArrayAggBuffer, +} + +type arrayAggBuffer struct { + elements []any +} + +func newArrayAggBuffer() (sql.AggregationBuffer, error) { + return &arrayAggBuffer{ + elements: make([]any, 0), + }, nil +} + +func (a *arrayAggBuffer) Dispose() {} + +func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) { + if len(a.elements) == 0 { + return nil, nil + } + return a.elements, nil +} + +func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { + a.elements = append(a.elements, row[0]) + return nil +} + diff --git a/server/functions/aggregate/init.go b/server/functions/aggregate/init.go new file mode 100755 index 0000000000..469a0d52af --- /dev/null +++ b/server/functions/aggregate/init.go @@ -0,0 +1,19 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +func Init() { + initArrayAgg() +} diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 1ecd556400..3ff4e9858e 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -181,8 +181,10 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface) // compileNonOperatorFunction creates a CompiledFunction for each overload of the given function. func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) { + var newBuffer func()(sql.AggregationBuffer, error) overloadTree := NewOverloads() for _, functionOverload := range overloads { + newBuffer = functionOverload.NewBuffer if err := overloadTree.Add(functionOverload); err != nil { panic(err) } @@ -191,7 +193,7 @@ func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) // Store the compiled function into the engine's built-in functions // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog createFunc := func(params ...sql.Expression) (sql.Expression, error) { - return NewCompiledAggregateFunction(funcName, params, overloadTree), nil + return NewCompiledAggregateFunction(funcName, params, overloadTree, newBuffer), nil } function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ Name: funcName, diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index 2cb070cfa9..4ed2426d7a 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -18,8 +18,6 @@ import ( "strings" cerrors "github.com/cockroachdb/errors" - "github.com/dolthub/doltgresql/server/plpgsql" - pgtypes "github.com/dolthub/doltgresql/server/types" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" ) @@ -34,27 +32,24 @@ type AggregateFunction interface { // CompiledAggregateFunction is an expression that represents a fully-analyzed PostgreSQL aggregate function. type CompiledAggregateFunction struct { *CompiledFunction - aggId sql.ColumnId + aggId sql.ColumnId + newBuffer func() (sql.AggregationBuffer, error) } var _ AggregateFunction = (*CompiledAggregateFunction)(nil) // NewCompiledAggregateFunction returns a newly compiled function. -func NewCompiledAggregateFunction(name string, args []sql.Expression, functions *Overloads) *CompiledAggregateFunction { - return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args))) +// TODO: newBuffer probably needs to be parameterized in the overloads +func NewCompiledAggregateFunction(name string, args []sql.Expression, functions *Overloads, newBuffer func() (sql.AggregationBuffer, error)) *CompiledAggregateFunction { + return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), newBuffer) } // newCompiledFunctionInternal is called internally, which skips steps that may have already been processed. -func newCompiledAggregateFunctionInternal( - name string, - args []sql.Expression, - overloads *Overloads, - fnOverloads []Overload, -) *CompiledAggregateFunction { - +func newCompiledAggregateFunctionInternal(name string, args []sql.Expression, overloads *Overloads, fnOverloads []Overload, newBuffer func() (sql.AggregationBuffer, error)) *CompiledAggregateFunction { cf := newCompiledFunctionInternal(name, args, overloads, fnOverloads, false, nil) c := &CompiledAggregateFunction{ CompiledFunction: cf, + newBuffer: newBuffer, } return c @@ -62,83 +57,7 @@ func newCompiledAggregateFunctionInternal( // Eval implements the interface sql.Expression. func (c *CompiledAggregateFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - // TODO: probably should be an error? - - // If we have a stashed error, then we should return that now. Errors are stashed when they're supposed to be - // returned during the call to Eval. This helps to ensure consistency with how errors are returned in Postgres. - if c.stashedErr != nil { - return nil, c.stashedErr - } - - // Evaluate all arguments, returning immediately if we encounter a null argument and the function is marked STRICT - var err error - isStrict := c.overload.Function().IsStrict() - args := make([]any, len(c.Arguments)) - for i, arg := range c.Arguments { - args[i], err = arg.Eval(ctx, row) - if err != nil { - return nil, err - } - // TODO: once we remove GMS types from all of our expressions, we can remove this step which ensures the correct type - if _, ok := arg.Type().(*pgtypes.DoltgresType); !ok { - dt, err := pgtypes.FromGmsTypeToDoltgresType(arg.Type()) - if err != nil { - return nil, err - } - args[i], _, _ = dt.Convert(ctx, args[i]) - } - if args[i] == nil && isStrict { - return nil, nil - } - } - - if len(c.overload.casts) > 0 { - targetParamTypes := c.overload.Function().GetParameters() - for i, arg := range args { - // For variadic params, we need to identify the corresponding target type - var targetType *pgtypes.DoltgresType - isVariadicArg := c.overload.params.variadic >= 0 && i >= len(c.overload.params.paramTypes)-1 - if isVariadicArg { - targetType = targetParamTypes[c.overload.params.variadic] - if !targetType.IsArrayType() { - // should be impossible, we check this at function compile time - return nil, cerrors.Errorf("variadic arguments must be array types, was %T", targetType) - } - targetType = targetType.ArrayBaseType() - } else { - targetType = targetParamTypes[i] - } - - if c.overload.casts[i] != nil { - args[i], err = c.overload.casts[i](ctx, arg, targetType) - if err != nil { - return nil, err - } - } else { - return nil, cerrors.Errorf("function %s is missing the appropriate implicit cast", c.OverloadString(c.originalTypes)) - } - } - } - - args = c.overload.params.coalesceVariadicValues(args) - - // Call the function - switch f := c.overload.Function().(type) { - case Function0: - return f.Callable(ctx) - case Function1: - return f.Callable(ctx, ([2]*pgtypes.DoltgresType)(c.callResolved), args[0]) - case Function2: - return f.Callable(ctx, ([3]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1]) - case Function3: - return f.Callable(ctx, ([4]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2]) - case Function4: - return f.Callable(ctx, ([5]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2], args[3]) - case InterpretedFunction: - return plpgsql.Call(ctx, f, c.runner, c.callResolved, args) - default: - return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval") - } + return nil, cerrors.New("Eval should not be called on CompiledAggregateFunction") } // WithChildren implements the interface sql.Expression. @@ -148,7 +67,7 @@ func (c *CompiledAggregateFunction) WithChildren(children ...sql.Expression) (sq } // We have to re-resolve here, since the change in children may require it (e.g. we have more type info than we did) - return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads), nil + return newCompiledAggregateFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.newBuffer), nil } // SetStatementRunner implements the interface analyzer.Interpreter. @@ -178,33 +97,8 @@ func (c *CompiledAggregateFunction) DebugString() string { sb.WriteString(")") return sb.String() } - -type arrayAggBuffer struct { - elements []any -} - -func newArrayAggBuffer() (sql.AggregationBuffer, error) { - return &arrayAggBuffer{ - elements: make([]any, 0), - }, nil -} - -func (a *arrayAggBuffer) Dispose() {} - -func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) { - if len(a.elements) == 0 { - return nil, nil - } - return a.elements, nil -} - -func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { - a.elements = append(a.elements, row[0]) - return nil -} - func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { - return newArrayAggBuffer() + return c.newBuffer() } func (c *CompiledAggregateFunction) Id() sql.ColumnId { diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 5f099ac66a..e82ad050a8 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -722,19 +722,3 @@ func (c *CompiledFunction) analyzeParameters() (originalTypes []*pgtypes.Doltgre // specificFuncImpl implements the interface sql.Expression. func (*CompiledFunction) specificFuncImpl() {} - -func init() { - RegisterAggregateFunction(Func1Aggregate{ - Function1: Function1{ - Name: "array_agg", - Return: pgtypes.AnyArray, - Parameters: [1]*pgtypes.DoltgresType{ - pgtypes.AnyElement, - }, - Callable: func(ctx *sql.Context, paramsAndReturn [2]*pgtypes.DoltgresType, val1 any) (any, error) { - return nil, nil - }, - }, - NewAggBuffer: newArrayAggBuffer, - }) -} diff --git a/server/functions/framework/functions.go b/server/functions/framework/functions.go index 5c4a04d944..21977311ba 100644 --- a/server/functions/framework/functions.go +++ b/server/functions/framework/functions.go @@ -48,6 +48,7 @@ type FunctionInterface interface { // AggregateFunction is an interface for PostgreSQL aggregate functions type AggregateFunctionInterface interface { FunctionInterface + // TODO: this maybe needs to take the place of the Callable function NewBuffer() (sql.AggregationBuffer, error) } diff --git a/server/initialization/initialization.go b/server/initialization/initialization.go index 1627926dd3..6f62ac1dd7 100644 --- a/server/initialization/initialization.go +++ b/server/initialization/initialization.go @@ -29,6 +29,7 @@ import ( "github.com/dolthub/doltgresql/server/cast" "github.com/dolthub/doltgresql/server/config" "github.com/dolthub/doltgresql/server/functions" + "github.com/dolthub/doltgresql/server/functions/aggregate" "github.com/dolthub/doltgresql/server/functions/binary" "github.com/dolthub/doltgresql/server/functions/framework" "github.com/dolthub/doltgresql/server/functions/unary" @@ -53,6 +54,7 @@ func Initialize(dEnv *env.DoltEnv) { binary.Init() unary.Init() functions.Init() + aggregate.Init() cast.Init() framework.Initialize() sql.GlobalParser = pgsql.NewPostgresParser() From acb996503002ae895f3e72baca0166360670eb81 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 16:22:35 -0700 Subject: [PATCH 12/18] new test methods --- testing/go/functions_test.go | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 36dc1a2b18..03a97ecad1 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -25,14 +25,35 @@ func TestAggregateFunctions(t *testing.T) { { Name: "array_agg", SetUpScript: []string{ - `CREATE TABLE test (pk INT primary key, v1 INT, v2 INT);`, - `INSERT INTO test VALUES (1, 1, 2), (2, 3, 4), (3, 5, 6);`, + `CREATE TABLE t1 (pk INT primary key, t timestamp, v varchar, f float[]);`, + `INSERT INTO t1 VALUES + (1, '2023-01-01 00:00:00', 'a', '{1.0, 2.0}'), + (2, '2023-01-02 00:00:00', 'b', '{3.0, 4.0}'), + (3, '2023-01-03 00:00:00', 'c', '{5.0, 6.0}');`, }, Assertions: []ScriptTestAssertion{ { - Query: `SELECT array_agg(v1) FROM test;`, + Query: `SELECT array_agg(pk) FROM t1;`, Expected: []sql.Row{ - {"{1,3,5}"}, + {"{1,2,3}"}, + }, + }, + { + Query: `SELECT array_agg(t) FROM t1;`, + Expected: []sql.Row{ + {`{"2023-01-01 00:00:00","2023-01-02 00:00:00","2023-01-03 00:00:00"}`}, + }, + }, + { + Query: `SELECT array_agg(v) FROM t1;`, + Expected: []sql.Row{ + {"{a,b,c}"}, + }, + }, + { + Query: `SELECT array_agg(f) FROM t1;`, + Expected: []sql.Row{ + {"{{1.0,2.0},{3.0,4.0},{5.0,6.0}}"}, }, }, }, From 8e2150fc4fa95b713ec3a4edfdf6b73fc0400847 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 16:34:31 -0700 Subject: [PATCH 13/18] Added skipped test --- testing/go/functions_test.go | 1 + testing/go/types_test.go | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 03a97ecad1..43ff76e180 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -51,6 +51,7 @@ func TestAggregateFunctions(t *testing.T) { }, }, { + Skip: true, // Higher-level arrays don't work because they panic during output Query: `SELECT array_agg(f) FROM t1;`, Expected: []sql.Row{ {"{{1.0,2.0},{3.0,4.0},{5.0,6.0}}"}, diff --git a/testing/go/types_test.go b/testing/go/types_test.go index a646bd9a13..caaadfbe9d 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -623,6 +623,22 @@ var typesTests = []ScriptTest{ }, }, { + Name: "2D array", + Skip: true, // multiple dimensions not supported yet + SetUpScript: []string{ + "CREATE TABLE t_varchar (id INTEGER primary key, v1 CHARACTER VARYING[][]);", + "INSERT INTO t_varchar VALUES (1, '{{abcdefghij, NULL}, {1234, abc}}'), (2, ARRAY['ab''cdef', 'what', 'is,hi', 'wh\"at', '}', '{', '{}']);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t_varchar ORDER BY id;", + Expected: []sql.Row{ + {1, "{abcdefghij,NULL}"}, + {2, `{ab'cdef,what,"is,hi","wh\"at","}","{","{}"}`}, + }, + }, + }, + }, { Name: "Cidr type", Skip: true, SetUpScript: []string{ From f7d39ac692892dc1028d17a4b0f5a8d671122838 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 16:35:55 -0700 Subject: [PATCH 14/18] Formatting --- server/analyzer/init.go | 6 +++--- server/functions/aggregate/array_agg.go | 4 ++-- server/functions/framework/catalog.go | 2 +- server/functions/framework/functions.go | 2 +- testing/go/functions_test.go | 2 +- testing/go/types_test.go | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 48d34bbf29..20f5fc0fbf 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -102,7 +102,7 @@ func initEngine() { // place to put it. Our foreign key validation logic is different from MySQL's, and since it's not an analyzer rule // we can't swap out a rule like the rest of the logic in this package, we have to do a function swap. plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition - + planbuilder.IsAggregateFunc = IsAggregateFunc } @@ -112,12 +112,12 @@ func IsAggregateFunc(name string) bool { if planbuilder.IsMySQLAggregateFuncName(name) { return true } - + switch name { case "array_agg": return true } - + return false } diff --git a/server/functions/aggregate/array_agg.go b/server/functions/aggregate/array_agg.go index 94cb2e15f5..2a0457cb8b 100755 --- a/server/functions/aggregate/array_agg.go +++ b/server/functions/aggregate/array_agg.go @@ -15,9 +15,10 @@ package aggregate import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" - "github.com/dolthub/go-mysql-server/sql" ) // initArrayAgg registers the functions to the catalog. @@ -63,4 +64,3 @@ func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error { a.elements = append(a.elements, row[0]) return nil } - diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 3ff4e9858e..61eef30891 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -181,7 +181,7 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface) // compileNonOperatorFunction creates a CompiledFunction for each overload of the given function. func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) { - var newBuffer func()(sql.AggregationBuffer, error) + var newBuffer func() (sql.AggregationBuffer, error) overloadTree := NewOverloads() for _, functionOverload := range overloads { newBuffer = functionOverload.NewBuffer diff --git a/server/functions/framework/functions.go b/server/functions/framework/functions.go index 21977311ba..e8f41dfa85 100644 --- a/server/functions/framework/functions.go +++ b/server/functions/framework/functions.go @@ -291,7 +291,7 @@ func (f Function4) enforceInterfaceInheritance(error) {} // Func1Aggregate is a function that takes one parameter and is an aggregate function. type Func1Aggregate struct { Function1 - NewAggBuffer func() (sql.AggregationBuffer, error) + NewAggBuffer func() (sql.AggregationBuffer, error) } func (f Func1Aggregate) NewBuffer() (sql.AggregationBuffer, error) { diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 43ff76e180..0b529e46bb 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -51,7 +51,7 @@ func TestAggregateFunctions(t *testing.T) { }, }, { - Skip: true, // Higher-level arrays don't work because they panic during output + Skip: true, // Higher-level arrays don't work because they panic during output Query: `SELECT array_agg(f) FROM t1;`, Expected: []sql.Row{ {"{{1.0,2.0},{3.0,4.0},{5.0,6.0}}"}, diff --git a/testing/go/types_test.go b/testing/go/types_test.go index caaadfbe9d..14b7a10d23 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -638,7 +638,7 @@ var typesTests = []ScriptTest{ }, }, }, - }, { + }, { Name: "Cidr type", Skip: true, SetUpScript: []string{ From c75142b8424221c4870f96e0f64033271e6348f2 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 16:40:08 -0700 Subject: [PATCH 15/18] new gms --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 4eee7751b8..12b65437d4 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,9 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad - github.com/dolthub/go-mysql-server v0.20.1-0.20250514213318-e116aa682aaf + github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c + github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index 505c0e6082..6c795cb55f 100644 --- a/go.sum +++ b/go.sum @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszNu37VPQckdhX1BPPVzREsGgNxQeefnlm3g= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.20.1-0.20250514213318-e116aa682aaf h1:sYBhIXdvdBww8Xw6K1QMjmGbuFEfYefNBie65TOxdm0= -github.com/dolthub/go-mysql-server v0.20.1-0.20250514213318-e116aa682aaf/go.mod h1:5ZdrW0fHZbz+8CngT9gksqSX4H3y+7v1pns7tJCEpu0= +github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496 h1:A1CFUqYWyJTjRQrK8k6LbjAUwEqtEkHrLhdcBgRpDpY= +github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -276,8 +276,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c h1:imdag6PPCHAO2rZNsFoQoR4I/vIVTmO/czoOl5rUnbk= -github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c h1:23KvsBtJk2GmHpXwQ/RkwIkdNpWL8tWdHRCiidhnaUA= +github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= From 5bfa72ae254c09b7a2eed60d39b73e852f7473be Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 23 May 2025 16:41:54 -0700 Subject: [PATCH 16/18] latest gms --- go.mod | 4 ++-- go.sum | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index a21b946bca..63a4e6a398 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,9 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad - github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677 + github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c + github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index a97a1b5a39..eff716c837 100644 --- a/go.sum +++ b/go.sum @@ -268,6 +268,8 @@ github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszN github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677 h1:Wn0v7xBxkdzYqDN/4ksI34jstZOskMccT2SYFrvUw4c= github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677/go.mod h1:5ZdrW0fHZbz+8CngT9gksqSX4H3y+7v1pns7tJCEpu0= +github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496 h1:A1CFUqYWyJTjRQrK8k6LbjAUwEqtEkHrLhdcBgRpDpY= +github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -278,6 +280,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c h1:imdag6PPCHAO2rZNsFoQoR4I/vIVTmO/czoOl5rUnbk= github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c h1:23KvsBtJk2GmHpXwQ/RkwIkdNpWL8tWdHRCiidhnaUA= +github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= From 85abf36448b1366295573287edad923bb44290d2 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Tue, 27 May 2025 16:43:20 -0700 Subject: [PATCH 17/18] Some missing docs and simplifications --- server/functions/framework/catalog.go | 6 ----- .../framework/compiled_aggregate_function.go | 9 +++++++- server/functions/framework/functions.go | 22 ------------------- 3 files changed, 8 insertions(+), 29 deletions(-) diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 61eef30891..32480b52ea 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -75,12 +75,6 @@ func RegisterAggregateFunction(f AggregateFunctionInterface) { case Func1Aggregate: name := strings.ToLower(f.Name) AggregateCatalog[name] = append(AggregateCatalog[name], f) - case Func2Aggregate: - name := strings.ToLower(f.Name) - AggregateCatalog[name] = append(AggregateCatalog[name], f) - case Func3Aggregate: - name := strings.ToLower(f.Name) - AggregateCatalog[name] = append(AggregateCatalog[name], f) default: panic(fmt.Sprintf("unhandled function type %T", f)) } diff --git a/server/functions/framework/compiled_aggregate_function.go b/server/functions/framework/compiled_aggregate_function.go index 4ed2426d7a..cbf610b6fb 100644 --- a/server/functions/framework/compiled_aggregate_function.go +++ b/server/functions/framework/compiled_aggregate_function.go @@ -44,7 +44,7 @@ func NewCompiledAggregateFunction(name string, args []sql.Expression, functions return newCompiledAggregateFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), newBuffer) } -// newCompiledFunctionInternal is called internally, which skips steps that may have already been processed. +// newCompiledAggregateFunctionInternal is called internally, which skips steps that may have already been processed. func newCompiledAggregateFunctionInternal(name string, args []sql.Expression, overloads *Overloads, fnOverloads []Overload, newBuffer func() (sql.AggregationBuffer, error)) *CompiledAggregateFunction { cf := newCompiledFunctionInternal(name, args, overloads, fnOverloads, false, nil) c := &CompiledAggregateFunction{ @@ -97,28 +97,35 @@ func (c *CompiledAggregateFunction) DebugString() string { sb.WriteString(")") return sb.String() } + +// NewBuffer implements the interface sql.Aggregation. func (c *CompiledAggregateFunction) NewBuffer() (sql.AggregationBuffer, error) { return c.newBuffer() } +// Id implements the interface sql.Aggregation. func (c *CompiledAggregateFunction) Id() sql.ColumnId { return c.aggId } +// WithId implements the interface sql.Aggregation. func (c *CompiledAggregateFunction) WithId(id sql.ColumnId) sql.IdExpression { nc := *c nc.aggId = id return &nc } +// NewWindowFunction implements the interface sql.WindowAdaptableExpression. func (c *CompiledAggregateFunction) NewWindowFunction() (sql.WindowFunction, error) { panic("windows are not implemented yet") } +// WithWindow implements the interface sql.WindowAdaptableExpression. func (c *CompiledAggregateFunction) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { panic("windows are not implemented yet") } +// Window implements the interface sql.WindowAdaptableExpression. func (c *CompiledAggregateFunction) Window() *sql.WindowDefinition { panic("windows are not implemented yet") } diff --git a/server/functions/framework/functions.go b/server/functions/framework/functions.go index e8f41dfa85..a794b95ff0 100644 --- a/server/functions/framework/functions.go +++ b/server/functions/framework/functions.go @@ -298,26 +298,4 @@ func (f Func1Aggregate) NewBuffer() (sql.AggregationBuffer, error) { return f.NewAggBuffer() } -// Func2Aggregate is a function that takes one parameter and is an aggregate function. -type Func2Aggregate struct { - Function2 - NewAggBuffer func() (sql.AggregationBuffer, error) -} - -func (f Func2Aggregate) NewBuffer() (sql.AggregationBuffer, error) { - return f.NewAggBuffer() -} - -// Func3Aggregate is a function that takes one parameter and is an aggregate function. -type Func3Aggregate struct { - Function3 - NewAggBuffer func() (sql.AggregationBuffer, error) -} - -func (f Func3Aggregate) NewBuffer() (sql.AggregationBuffer, error) { - return f.NewAggBuffer() -} - var _ AggregateFunctionInterface = Func1Aggregate{} -var _ AggregateFunctionInterface = Func2Aggregate{} -var _ AggregateFunctionInterface = Func3Aggregate{} From e450b5c4a1aca01c3677d73270d25404410b479c Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Tue, 27 May 2025 16:44:53 -0700 Subject: [PATCH 18/18] new gms --- go.mod | 2 +- go.sum | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 63a4e6a398..c20be68328 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad - github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496 + github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index eff716c837..7c402b7624 100644 --- a/go.sum +++ b/go.sum @@ -266,10 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszNu37VPQckdhX1BPPVzREsGgNxQeefnlm3g= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677 h1:Wn0v7xBxkdzYqDN/4ksI34jstZOskMccT2SYFrvUw4c= -github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677/go.mod h1:5ZdrW0fHZbz+8CngT9gksqSX4H3y+7v1pns7tJCEpu0= -github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496 h1:A1CFUqYWyJTjRQrK8k6LbjAUwEqtEkHrLhdcBgRpDpY= -github.com/dolthub/go-mysql-server v0.20.1-0.20250523233748-82a4fd886496/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk= +github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545 h1:O+/sjRQJadYzyVr89Zh9yCnhZJ0NlHwiDYsXHnj3LsU= +github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -278,8 +276,6 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c h1:imdag6PPCHAO2rZNsFoQoR4I/vIVTmO/czoOl5rUnbk= -github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c h1:23KvsBtJk2GmHpXwQ/RkwIkdNpWL8tWdHRCiidhnaUA= github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=