Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ require (
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad
github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545
github.com/dolthub/go-mysql-server v0.20.1-0.20250531000817-b7b74d41e84e
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c
github.com/dolthub/vitess v0.0.0-20250530231040-bfd522856394
github.com/fatih/color v1.13.0
github.com/goccy/go-json v0.10.2
github.com/gogo/protobuf v1.3.2
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszNu37VPQckdhX1BPPVzREsGgNxQeefnlm3g=
github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA=
github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545 h1:O+/sjRQJadYzyVr89Zh9yCnhZJ0NlHwiDYsXHnj3LsU=
github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk=
github.com/dolthub/go-mysql-server v0.20.1-0.20250531000817-b7b74d41e84e h1:mZHcAqI2JsoAJbJr8lWsxuIoNpx7NMmbIakioitEHu4=
github.com/dolthub/go-mysql-server v0.20.1-0.20250531000817-b7b74d41e84e/go.mod h1:nzF9N8zhb7MhYypvwvHfKrN/MaDfuX4K5zXsiK0XvDg=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
Expand All @@ -276,8 +276,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70=
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA=
github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c h1:23KvsBtJk2GmHpXwQ/RkwIkdNpWL8tWdHRCiidhnaUA=
github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
github.com/dolthub/vitess v0.0.0-20250530231040-bfd522856394 h1:sMwntvk7O9dttaJLqnOvy8zgk0ah9qnyWkAahfOgnIo=
github.com/dolthub/vitess v0.0.0-20250530231040-bfd522856394/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
Expand Down
31 changes: 24 additions & 7 deletions server/ast/func_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgexprs "github.com/dolthub/doltgresql/server/expression"
)

// nodeFuncExpr handles *tree.FuncExpr nodes.
Expand All @@ -41,13 +42,12 @@ func nodeFuncExpr(ctx *Context, node *tree.FuncExpr) (vitess.Expr, error) {
case *tree.FunctionDefinition:
name = vitess.NewColIdent(funcRef.Name)
case *tree.UnresolvedName:
if funcRef.NumParts > 2 {
return nil, errors.Errorf("referencing items outside the schema or database is not yet supported")
colName, err := unresolvedNameToColName(funcRef)
if err != nil {
return nil, err
}
if funcRef.NumParts == 2 {
qualifier = vitess.NewTableIdent(funcRef.Parts[1])
}
name = vitess.NewColIdent(funcRef.Parts[0])

name = colName.Name
default:
return nil, errors.Errorf("unknown function reference")
}
Expand All @@ -69,8 +69,8 @@ func nodeFuncExpr(ctx *Context, node *tree.FuncExpr) (vitess.Expr, error) {
return nil, err
}

// special case for string_agg, which maps to the mysql aggregate function group_concat
switch strings.ToLower(name.String()) {
// special case for string_agg, which maps to the mysql aggregate function group_concat
case "string_agg":
if len(node.Exprs) != 2 {
return nil, errors.Errorf("string_agg requires two arguments")
Expand All @@ -96,6 +96,23 @@ func nodeFuncExpr(ctx *Context, node *tree.FuncExpr) (vitess.Expr, error) {
},
OrderBy: orderBy,
}, nil
case "array_agg":
var orderBy vitess.OrderBy
if len(node.OrderBy) > 0 {
orderBy, err = nodeOrderBy(ctx, node.OrderBy)
if err != nil {
return nil, err
}
}

return &vitess.OrderedInjectedExpr{
InjectedExpr: vitess.InjectedExpr{
Expression: &pgexprs.ArrayAgg{},
SelectExprChildren: exprs,
Auth: vitess.AuthInformation{},
},
OrderBy: orderBy,
}, nil
}

if len(node.OrderBy) > 0 {
Expand Down
211 changes: 211 additions & 0 deletions server/expression/array_agg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"sort"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/server/types"
)

type ArrayAgg struct {
selectExprs []sql.Expression
orderBy sql.SortFields
id sql.ColumnId
}

var _ sql.Aggregation = (*ArrayAgg)(nil)
var _ vitess.Injectable = (*ArrayAgg)(nil)

// WithResolvedChildren returns a new ArrayAgg with the provided children as its select expressions.
// The last child is expected to be the order by expressions.
func (a *ArrayAgg) WithResolvedChildren(children []any) (any, error) {
a.selectExprs = make([]sql.Expression, len(children)-1)
for i := 0; i < len(children)-1; i++ {
a.selectExprs[i] = children[i].(sql.Expression)
}

a.orderBy = children[len(children)-1].(sql.SortFields)
return a, nil
}

// Resolved implements sql.Expression
func (a *ArrayAgg) Resolved() bool {
return expression.ExpressionsResolved(a.selectExprs...) && expression.ExpressionsResolved(a.orderBy.ToExpressions()...)
}

// String implements sql.Expression
func (a *ArrayAgg) String() string {
sb := strings.Builder{}
sb.WriteString("array_agg(")

if a.selectExprs != nil {
var exprs = make([]string, len(a.selectExprs))
for i, expr := range a.selectExprs {
exprs[i] = expr.String()
}

sb.WriteString(strings.Join(exprs, ", "))
}

if len(a.orderBy) > 0 {
sb.WriteString(" order by ")
for i, ob := range a.orderBy {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(ob.String())
}
}

sb.WriteString(")")
return sb.String()
}

// Type implements sql.Expression
func (a *ArrayAgg) Type() sql.Type {
dt := a.selectExprs[0].Type().(*types.DoltgresType)
return dt.ToArrayType()
}

// IsNullable implements sql.Expression
func (a *ArrayAgg) IsNullable() bool {
return true
}

// Eval implements sql.Expression
func (a *ArrayAgg) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
panic("eval should never be called on an aggregation function")
}

// Children implements sql.Expression
func (a *ArrayAgg) Children() []sql.Expression {
return append(a.selectExprs, a.orderBy.ToExpressions()...)
}

// WithChildren implements sql.Expression
func (a ArrayAgg) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != len(a.selectExprs)+len(a.orderBy) {
return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), len(a.selectExprs)+len(a.orderBy))
}

a.selectExprs = children[:len(a.selectExprs)]
a.orderBy = a.orderBy.FromExpressions(children[len(a.selectExprs):]...)
return &a, nil
}

// Id implements sql.IdExpression
func (a *ArrayAgg) Id() sql.ColumnId {
return a.id
}

// WithId implements sql.IdExpression
func (a ArrayAgg) WithId(id sql.ColumnId) sql.IdExpression {
a.id = id
return &a
}

// NewWindowFunction implements sql.WindowAdaptableExpression
func (a *ArrayAgg) NewWindowFunction() (sql.WindowFunction, error) {
panic("window functions not yet supported for array_agg")
}

// WithWindow implements sql.WindowAdaptableExpression
func (a *ArrayAgg) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression {
panic("window functions not yet supported for array_agg")
}

// Window implements sql.WindowAdaptableExpression
func (a *ArrayAgg) Window() *sql.WindowDefinition {
return nil
}

// NewBuffer implements sql.Aggregation
func (a *ArrayAgg) NewBuffer() (sql.AggregationBuffer, error) {
return &arrayAggBuffer{
elements: make([]sql.Row, 0),
a: a,
}, nil
}

// arrayAggBuffer is the buffer used to accumulate values for the array_agg aggregation function.
type arrayAggBuffer struct {
elements []sql.Row
a *ArrayAgg
}

// Dispose implements sql.AggregationBuffer
func (a *arrayAggBuffer) Dispose() {}

// Eval implements sql.AggregationBuffer
func (a *arrayAggBuffer) Eval(ctx *sql.Context) (interface{}, error) {
if len(a.elements) == 0 {
return nil, nil
}

if a.a.orderBy != nil {
sorter := &expression.Sorter{
SortFields: a.a.orderBy,
Rows: a.elements,
Ctx: ctx,
}

sort.Stable(sorter)
if sorter.LastError != nil {
return nil, sorter.LastError
}
}

// convert to []interface for return. The last element in each row is the one we want to return, the rest are sort fields.
result := make([]interface{}, len(a.elements))
for i, row := range a.elements {
result[i] = row[(len(row) - 1)]
}

return result, nil
}

// Update implements sql.AggregationBuffer
func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error {
evalRow, err := evalExprs(ctx, a.a.selectExprs, row)
if err != nil {
return err
}

// TODO: unwrap values as necessary
// Append the current value to the end of the row. We want to preserve the row's original structure
// for sort ordering in the final step.
a.elements = append(a.elements, append(row, nil, evalRow[0]))
return nil
}

// evalExprs evaluates the provided expressions against the given row and returns the results as a new row.
func evalExprs(ctx *sql.Context, exprs []sql.Expression, row sql.Row) (sql.Row, error) {
result := make(sql.Row, len(exprs))
for i, expr := range exprs {
var err error
result[i], err = expr.Eval(ctx, row)
if err != nil {
return nil, err
}
}

return result, nil
}
Loading