Skip to content

Commit 9366b1c

Browse files
committed
Injected order by
1 parent fa6d02d commit 9366b1c

File tree

2 files changed

+78
-18
lines changed

2 files changed

+78
-18
lines changed

sql/planbuilder/aggregates.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.E
492492
sortFields = append(sortFields, sf)
493493
}
494494

495-
//TODO: this should be acquired at runtime, not at parse time, so fix this
495+
// TODO: this should be acquired at runtime, not at parse time, so fix this
496496
gcml, err := b.ctx.GetSessionVariable(b.ctx, "group_concat_max_len")
497497
if err != nil {
498498
b.handleErr(err)
@@ -515,6 +515,55 @@ func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.E
515515
return col.scalarGf()
516516
}
517517

518+
// buildOrderedInjectedExpr builds an InjectedExpr with an ORDER BY dependency
519+
func (b *Builder) buildOrderedInjectedExpr(inScope *scope, e *ast.OrderedInjectedExpr) sql.Expression {
520+
inScope.initGroupBy()
521+
gb := inScope.groupBy
522+
523+
resolvedChildren := make([]any, len(e.Children))
524+
for i, child := range e.Children {
525+
resolvedChildren[i] = b.buildScalar(inScope, child)
526+
}
527+
528+
orderByScope := b.analyzeOrderBy(inScope, inScope, e.OrderBy)
529+
var sortFields sql.SortFields
530+
for _, c := range orderByScope.cols {
531+
so := sql.Ascending
532+
if c.descending {
533+
so = sql.Descending
534+
}
535+
scalar := c.scalar
536+
if scalar == nil {
537+
scalar = c.scalarGf()
538+
}
539+
sf := sql.SortField{
540+
Column: scalar,
541+
Order: so,
542+
}
543+
sortFields = append(sortFields, sf)
544+
}
545+
546+
resolvedChildren = append(resolvedChildren, sortFields)
547+
548+
expr := b.buildInjectedExpressionFromResolvedChildren(e.InjectedExpr, resolvedChildren)
549+
agg, ok := expr.(sql.Aggregation)
550+
if !ok {
551+
b.handleErr(fmt.Errorf("expected sql.Aggregation, got %T", expr))
552+
}
553+
554+
aggName := strings.ToLower(plan.AliasSubqueryString(agg))
555+
col := scopeColumn{col: aggName, scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
556+
id := gb.outScope.newColumn(col)
557+
558+
agg = agg.WithId(sql.ColumnId(id)).(*aggregation.GroupConcat)
559+
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
560+
col.scalar = agg
561+
562+
gb.addAggStr(col)
563+
col.id = id
564+
return col.scalarGf()
565+
}
566+
518567
func isWindowFunc(name string) bool {
519568
switch name {
520569
case "first", "last", "count", "sum", "any_value",

sql/planbuilder/scalar.go

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
206206
case *ast.GroupConcatExpr:
207207
// TODO this is an aggregation
208208
return b.buildGroupConcat(inScope, v)
209+
case *ast.OrderedInjectedExpr:
210+
// TODO this is an aggregation
211+
return b.buildOrderedInjectedExpr(inScope, v)
209212
case *ast.ParenExpr:
210213
return b.buildScalar(inScope, v.Expr)
211214
case *ast.AndExpr:
@@ -272,23 +275,7 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
272275
}
273276
return ret
274277
case ast.InjectedExpr:
275-
if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, v.Auth); err != nil && b.authEnabled {
276-
b.handleErr(err)
277-
}
278-
resolvedChildren := make([]any, len(v.Children))
279-
for i, child := range v.Children {
280-
resolvedChildren[i] = b.buildScalar(inScope, child)
281-
}
282-
expr, err := v.Expression.WithResolvedChildren(resolvedChildren)
283-
if err != nil {
284-
b.handleErr(err)
285-
return nil
286-
}
287-
if sqlExpr, ok := expr.(sql.Expression); ok {
288-
return sqlExpr
289-
}
290-
b.handleErr(fmt.Errorf("Injected expression does not resolve to a valid expression"))
291-
return nil
278+
return b.buildInjectedExpr(inScope, v)
292279
case *ast.RangeCond:
293280
val := b.buildScalar(inScope, v.Left)
294281
lower := b.buildScalar(inScope, v.From)
@@ -422,6 +409,30 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
422409
return nil
423410
}
424411

412+
func (b *Builder) buildInjectedExpr(inScope *scope, v ast.InjectedExpr) sql.Expression {
413+
if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, v.Auth); err != nil && b.authEnabled {
414+
b.handleErr(err)
415+
}
416+
resolvedChildren := make([]any, len(v.Children))
417+
for i, child := range v.Children {
418+
resolvedChildren[i] = b.buildScalar(inScope, child)
419+
}
420+
return b.buildInjectedExpressionFromResolvedChildren(v, resolvedChildren)
421+
}
422+
423+
func (b *Builder) buildInjectedExpressionFromResolvedChildren(v ast.InjectedExpr, resolvedChildren []any) sql.Expression {
424+
expr, err := v.Expression.WithResolvedChildren(resolvedChildren)
425+
if err != nil {
426+
b.handleErr(err)
427+
return nil
428+
}
429+
if sqlExpr, ok := expr.(sql.Expression); ok {
430+
return sqlExpr
431+
}
432+
b.handleErr(fmt.Errorf("injected expression should resolve to sql.Expression, got %T", expr))
433+
return nil
434+
}
435+
425436
func (b *Builder) getOrigTblName(node sql.Node, alias string) string {
426437
if node == nil {
427438
return ""

0 commit comments

Comments
 (0)