1515package framework
1616
1717import (
18+ "fmt"
1819 "strings"
1920
2021 "github.com/cockroachdb/errors"
@@ -28,6 +29,9 @@ import (
2829// Catalog contains all of the PostgreSQL functions.
2930var Catalog = map [string ][]FunctionInterface {}
3031
32+ // AggregateCatalog contains all of the PostgreSQL aggregate functions.
33+ var AggregateCatalog = map [string ][]AggregateFunctionInterface {}
34+
3135// initializedFunctions simply states whether Initialize has been called yet.
3236var initializedFunctions = false
3337
@@ -61,6 +65,21 @@ func RegisterFunction(f FunctionInterface) {
6165 }
6266}
6367
68+ // RegisterAggregateFunction registers the given function, so that it will be usable from a running server. This should be called
69+ // from within an init().
70+ func RegisterAggregateFunction (f AggregateFunctionInterface ) {
71+ if initializedFunctions {
72+ panic ("attempted to register a function after the init() phase" )
73+ }
74+ switch f := f .(type ) {
75+ case Func1Aggregate :
76+ name := strings .ToLower (f .Name )
77+ AggregateCatalog [name ] = append (AggregateCatalog [name ], f )
78+ default :
79+ panic (fmt .Sprintf ("unhandled function type %T" , f ))
80+ }
81+ }
82+
6483// Initialize handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not
6584// apply to PostgreSQL (and functions of the same name often have different behavior).
6685func Initialize () {
@@ -76,6 +95,7 @@ func Initialize() {
7695 replaceGmsBuiltIns ()
7796 validateFunctions ()
7897 compileFunctions ()
98+ compileAggs ()
7999}
80100
81101// replaceGmsBuiltIns replaces all GMS built-ins that have conflicting names with PostgreSQL functions.
@@ -151,7 +171,29 @@ func compileNonOperatorFunction(funcName string, overloads []FunctionInterface)
151171 Fn : createFunc ,
152172 })
153173 compiledCatalog [funcName ] = createFunc
154- namedCatalog [funcName ] = overloads
174+ }
175+
176+ // compileNonOperatorFunction creates a CompiledFunction for each overload of the given function.
177+ func compileAggFunction (funcName string , overloads []AggregateFunctionInterface ) {
178+ var newBuffer func () (sql.AggregationBuffer , error )
179+ overloadTree := NewOverloads ()
180+ for _ , functionOverload := range overloads {
181+ newBuffer = functionOverload .NewBuffer
182+ if err := overloadTree .Add (functionOverload ); err != nil {
183+ panic (err )
184+ }
185+ }
186+
187+ // Store the compiled function into the engine's built-in functions
188+ // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog
189+ createFunc := func (params ... sql.Expression ) (sql.Expression , error ) {
190+ return NewCompiledAggregateFunction (funcName , params , overloadTree , newBuffer ), nil
191+ }
192+ function .BuiltIns = append (function .BuiltIns , sql.FunctionN {
193+ Name : funcName ,
194+ Fn : createFunc ,
195+ })
196+ compiledCatalog [funcName ] = createFunc
155197}
156198
157199// compileFunctions creates a CompiledFunction for each overload of each function in the catalog.
@@ -165,32 +207,38 @@ func compileFunctions() {
165207 // special rules, so it's far more efficient to reuse it for operators. Operators are also a special case since they
166208 // all have different names, while standard overload deducers work on a function-name basis.
167209 for signature , functionOverload := range unaryFunctions {
168- overloads , ok := unaryAggregateOverloads [signature .Operator ]
210+ overloads , ok := unaryOperatorOverloads [signature .Operator ]
169211 if ! ok {
170212 overloads = NewOverloads ()
171- unaryAggregateOverloads [signature .Operator ] = overloads
213+ unaryOperatorOverloads [signature .Operator ] = overloads
172214 }
173215 if err := overloads .Add (functionOverload ); err != nil {
174216 panic (err )
175217 }
176218 }
177219
178220 for signature , functionOverload := range binaryFunctions {
179- overloads , ok := binaryAggregateOverloads [signature .Operator ]
221+ overloads , ok := binaryOperatorOverloads [signature .Operator ]
180222 if ! ok {
181223 overloads = NewOverloads ()
182- binaryAggregateOverloads [signature .Operator ] = overloads
224+ binaryOperatorOverloads [signature .Operator ] = overloads
183225 }
184226 if err := overloads .Add (functionOverload ); err != nil {
185227 panic (err )
186228 }
187229 }
188230
189231 // Add all permutations for the unary and binary operators
190- for operator , overload := range unaryAggregateOverloads {
191- unaryAggregatePermutations [operator ] = overload .overloadsForParams (1 )
232+ for operator , overload := range unaryOperatorOverloads {
233+ unaryOperatorPermutations [operator ] = overload .overloadsForParams (1 )
234+ }
235+ for operator , overload := range binaryOperatorOverloads {
236+ binaryOperatorPermutations [operator ] = overload .overloadsForParams (2 )
192237 }
193- for operator , overload := range binaryAggregateOverloads {
194- binaryAggregatePermutations [operator ] = overload .overloadsForParams (2 )
238+ }
239+
240+ func compileAggs () {
241+ for funcName , overloads := range AggregateCatalog {
242+ compileAggFunction (funcName , overloads )
195243 }
196244}
0 commit comments