Skip to content

Commit be852fe

Browse files
authored
Merge pull request #1497 from dolthub/zachmu/aggs
array_agg support and general framework for postgres aggregate functions
2 parents a4a0222 + e450b5c commit be852fe

File tree

14 files changed

+392
-33
lines changed

14 files changed

+392
-33
lines changed

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ require (
1010
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d
1111
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
1212
github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad
13-
github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677
13+
github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545
1414
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
15-
github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c
15+
github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c
1616
github.com/fatih/color v1.13.0
1717
github.com/goccy/go-json v0.10.2
1818
github.com/gogo/protobuf v1.3.2

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
266266
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
267267
github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszNu37VPQckdhX1BPPVzREsGgNxQeefnlm3g=
268268
github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA=
269-
github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677 h1:Wn0v7xBxkdzYqDN/4ksI34jstZOskMccT2SYFrvUw4c=
270-
github.com/dolthub/go-mysql-server v0.20.1-0.20250521012141-b56c7c6eb677/go.mod h1:5ZdrW0fHZbz+8CngT9gksqSX4H3y+7v1pns7tJCEpu0=
269+
github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545 h1:O+/sjRQJadYzyVr89Zh9yCnhZJ0NlHwiDYsXHnj3LsU=
270+
github.com/dolthub/go-mysql-server v0.20.1-0.20250527234113-f38274720545/go.mod h1:Zn9XK5KLYwPbyMpwfeUP+TgnhlgyID2vXf1WcF0M6Fk=
271271
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
272272
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
273273
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
@@ -276,8 +276,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
276276
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
277277
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70=
278278
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA=
279-
github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c h1:imdag6PPCHAO2rZNsFoQoR4I/vIVTmO/czoOl5rUnbk=
280-
github.com/dolthub/vitess v0.0.0-20250512224608-8fb9c6ea092c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
279+
github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c h1:23KvsBtJk2GmHpXwQ/RkwIkdNpWL8tWdHRCiidhnaUA=
280+
github.com/dolthub/vitess v0.0.0-20250523011542-0f6cf9472d1c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
281281
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
282282
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
283283
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=

server/analyzer/init.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package analyzer
1717
import (
1818
"github.com/dolthub/go-mysql-server/sql/analyzer"
1919
"github.com/dolthub/go-mysql-server/sql/plan"
20+
"github.com/dolthub/go-mysql-server/sql/planbuilder"
2021
)
2122

2223
// IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs
@@ -101,6 +102,23 @@ func initEngine() {
101102
// place to put it. Our foreign key validation logic is different from MySQL's, and since it's not an analyzer rule
102103
// we can't swap out a rule like the rest of the logic in this package, we have to do a function swap.
103104
plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition
105+
106+
planbuilder.IsAggregateFunc = IsAggregateFunc
107+
}
108+
109+
// IsAggregateFunc checks if the given function name is an aggregate function. This is the entire set supported by
110+
// MySQL, plus some PostgreSQL-specific ones.
111+
func IsAggregateFunc(name string) bool {
112+
if planbuilder.IsMySQLAggregateFuncName(name) {
113+
return true
114+
}
115+
116+
switch name {
117+
case "array_agg":
118+
return true
119+
}
120+
121+
return false
104122
}
105123

106124
// insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package aggregate
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
20+
"github.com/dolthub/doltgresql/server/functions/framework"
21+
pgtypes "github.com/dolthub/doltgresql/server/types"
22+
)
23+
24+
// initArrayAgg registers the array_agg aggregate function in the aggregate catalog.
25+
func initArrayAgg() {
26+
framework.RegisterAggregateFunction(array_agg)
27+
}
28+
29+
// array_agg represents the PostgreSQL array_agg function.
30+
var array_agg = framework.Func1Aggregate{
31+
Function1: framework.Function1{
32+
Name: "array_agg",
33+
Return: pgtypes.AnyArray,
34+
Parameters: [1]*pgtypes.DoltgresType{
35+
pgtypes.AnyElement,
36+
},
37+
Callable: func(ctx *sql.Context, paramsAndReturn [2]*pgtypes.DoltgresType, val1 any) (any, error) {
38+
return nil, nil
39+
},
40+
},
41+
NewAggBuffer: newArrayAggBuffer,
42+
}
43+
44+
type arrayAggBuffer struct {
45+
elements []any
46+
}
47+
48+
func newArrayAggBuffer() (sql.AggregationBuffer, error) {
49+
return &arrayAggBuffer{
50+
elements: make([]any, 0),
51+
}, nil
52+
}
53+
54+
func (a *arrayAggBuffer) Dispose() {}
55+
56+
func (a *arrayAggBuffer) Eval(context *sql.Context) (interface{}, error) {
57+
if len(a.elements) == 0 {
58+
return nil, nil
59+
}
60+
return a.elements, nil
61+
}
62+
63+
func (a *arrayAggBuffer) Update(ctx *sql.Context, row sql.Row) error {
64+
a.elements = append(a.elements, row[0])
65+
return nil
66+
}

server/functions/aggregate/init.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package aggregate
16+
17+
func Init() {
18+
initArrayAgg()
19+
}

server/functions/framework/catalog.go

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package framework
1616

1717
import (
18+
"fmt"
1819
"strings"
1920

2021
"github.com/cockroachdb/errors"
@@ -28,6 +29,9 @@ import (
2829
// Catalog contains all of the PostgreSQL functions.
2930
var Catalog = map[string][]FunctionInterface{}
3031

32+
// AggregateCatalog contains all of the PostgreSQL aggregate functions.
33+
var AggregateCatalog = map[string][]AggregateFunctionInterface{}
34+
3135
// initializedFunctions simply states whether Initialize has been called yet.
3236
var initializedFunctions = false
3337

@@ -61,6 +65,21 @@ func RegisterFunction(f FunctionInterface) {
6165
}
6266
}
6367

68+
// RegisterAggregateFunction registers the given aggregate function, so that it will be usable from a running server. This should be called
69+
// from within an init().
70+
func RegisterAggregateFunction(f AggregateFunctionInterface) {
71+
if initializedFunctions {
72+
panic("attempted to register a function after the init() phase")
73+
}
74+
switch f := f.(type) {
75+
case Func1Aggregate:
76+
name := strings.ToLower(f.Name)
77+
AggregateCatalog[name] = append(AggregateCatalog[name], f)
78+
default:
79+
panic(fmt.Sprintf("unhandled function type %T", f))
80+
}
81+
}
82+
6483
// Initialize handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not
6584
// apply to PostgreSQL (and functions of the same name often have different behavior).
6685
func Initialize() {
@@ -76,6 +95,7 @@ func Initialize() {
7695
replaceGmsBuiltIns()
7796
validateFunctions()
7897
compileFunctions()
98+
compileAggs()
7999
}
80100

81101
// replaceGmsBuiltIns replaces all GMS built-ins that have conflicting names with PostgreSQL functions.
@@ -151,7 +171,29 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface)
151171
Fn: createFunc,
152172
})
153173
compiledCatalog[funcName] = createFunc
154-
namedCatalog[funcName] = overloads
174+
}
175+
176+
// compileAggFunction creates a compiled aggregate function covering every overload of the given aggregate function.
177+
func compileAggFunction(funcName string, overloads []AggregateFunctionInterface) {
178+
var newBuffer func() (sql.AggregationBuffer, error)
179+
overloadTree := NewOverloads()
180+
for _, functionOverload := range overloads {
181+
newBuffer = functionOverload.NewBuffer
182+
if err := overloadTree.Add(functionOverload); err != nil {
183+
panic(err)
184+
}
185+
}
186+
187+
// Store the compiled function into the engine's built-in functions
188+
// TODO: don't do this, use an actual contract for communicating these functions to the engine catalog
189+
createFunc := func(params ...sql.Expression) (sql.Expression, error) {
190+
return NewCompiledAggregateFunction(funcName, params, overloadTree, newBuffer), nil
191+
}
192+
function.BuiltIns = append(function.BuiltIns, sql.FunctionN{
193+
Name: funcName,
194+
Fn: createFunc,
195+
})
196+
compiledCatalog[funcName] = createFunc
155197
}
156198

157199
// compileFunctions creates a CompiledFunction for each overload of each function in the catalog.
@@ -165,32 +207,38 @@ func compileFunctions() {
165207
// special rules, so it's far more efficient to reuse it for operators. Operators are also a special case since they
166208
// all have different names, while standard overload deducers work on a function-name basis.
167209
for signature, functionOverload := range unaryFunctions {
168-
overloads, ok := unaryAggregateOverloads[signature.Operator]
210+
overloads, ok := unaryOperatorOverloads[signature.Operator]
169211
if !ok {
170212
overloads = NewOverloads()
171-
unaryAggregateOverloads[signature.Operator] = overloads
213+
unaryOperatorOverloads[signature.Operator] = overloads
172214
}
173215
if err := overloads.Add(functionOverload); err != nil {
174216
panic(err)
175217
}
176218
}
177219

178220
for signature, functionOverload := range binaryFunctions {
179-
overloads, ok := binaryAggregateOverloads[signature.Operator]
221+
overloads, ok := binaryOperatorOverloads[signature.Operator]
180222
if !ok {
181223
overloads = NewOverloads()
182-
binaryAggregateOverloads[signature.Operator] = overloads
224+
binaryOperatorOverloads[signature.Operator] = overloads
183225
}
184226
if err := overloads.Add(functionOverload); err != nil {
185227
panic(err)
186228
}
187229
}
188230

189231
// Add all permutations for the unary and binary operators
190-
for operator, overload := range unaryAggregateOverloads {
191-
unaryAggregatePermutations[operator] = overload.overloadsForParams(1)
232+
for operator, overload := range unaryOperatorOverloads {
233+
unaryOperatorPermutations[operator] = overload.overloadsForParams(1)
234+
}
235+
for operator, overload := range binaryOperatorOverloads {
236+
binaryOperatorPermutations[operator] = overload.overloadsForParams(2)
192237
}
193-
for operator, overload := range binaryAggregateOverloads {
194-
binaryAggregatePermutations[operator] = overload.overloadsForParams(2)
238+
}
239+
240+
func compileAggs() {
241+
for funcName, overloads := range AggregateCatalog {
242+
compileAggFunction(funcName, overloads)
195243
}
196244
}

0 commit comments

Comments
 (0)