Skip to content

Commit 51fc977

Browse files
authored
Fix scalar aggregation with literals in empty result sets (vitessio#18477)
Signed-off-by: Andres Taylor <[email protected]>
1 parent f9c10b9 commit 51fc977

24 files changed

+331
-183
lines changed

go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ func TestAggrWithLimit(t *testing.T) {
8080
mcmp.Exec(fmt.Sprintf("insert into aggr_test(id, val1, val2) values(%d, 'a', %d)", i, r))
8181
}
8282
mcmp.Exec("select val2, count(*) from aggr_test group by val2 order by count(*), val2 limit 10")
83+
if utils.BinaryIsAtLeastAtVersion(23, "vtgate") {
84+
mcmp.Exec("SELECT 1 AS `id`, COUNT(*) FROM (SELECT `id` FROM aggr_test WHERE val1 = 1 LIMIT 100) `t`")
85+
}
8386
}
8487

8588
func TestAggregateTypes(t *testing.T) {

go/vt/sqlparser/analyzer.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package sqlparser
1919
// analyzer.go contains utility analysis functions.
2020

2121
import (
22+
"errors"
2223
"fmt"
2324
"strings"
2425
"unicode"
@@ -379,6 +380,20 @@ func IsColName(node Expr) bool {
379380
return ok
380381
}
381382

383+
var errNotStatic = errors.New("not static")
384+
385+
// IsConstant returns true if the Expr can be evaluated without input or access to tables.
386+
func IsConstant(node Expr) bool {
387+
err := Walk(func(node SQLNode) (kontinue bool, err error) {
388+
switch node.(type) {
389+
case *ColName, *Subquery:
390+
return false, errNotStatic
391+
}
392+
return true, nil
393+
}, node)
394+
return err == nil
395+
}
396+
382397
// IsValue returns true if the Expr is a string, integral or value arg.
383398
// NULL is not considered to be a value.
384399
func IsValue(node Expr) bool {

go/vt/vtgate/engine/aggregations.go

Lines changed: 109 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ import (
3535
// It contains the opcode and input column number.
3636
type AggregateParams struct {
3737
Opcode opcode.AggregateOpcode
38-
Col int
38+
39+
// Input source specification - exactly one of these should be set:
40+
// Col: Column index for simple column references (e.g., SUM(column_name))
41+
// EExpr: Evaluated expression for literals, parameters
42+
Col int
43+
EExpr evalengine.Expr
3944

4045
// These are used only for distinct opcodes.
4146
KeyCol int
@@ -53,15 +58,26 @@ type AggregateParams struct {
5358
CollationEnv *collations.Environment
5459
}
5560

56-
func NewAggregateParam(opcode opcode.AggregateOpcode, col int, alias string, collationEnv *collations.Environment) *AggregateParams {
61+
// NewAggregateParam creates a new aggregate param
62+
func NewAggregateParam(
63+
oc opcode.AggregateOpcode,
64+
col int,
65+
expr evalengine.Expr,
66+
alias string,
67+
collationEnv *collations.Environment,
68+
) *AggregateParams {
69+
if expr != nil && oc != opcode.AggregateConstant {
70+
panic(vterrors.VT13001("expr should be nil"))
71+
}
5772
out := &AggregateParams{
58-
Opcode: opcode,
73+
Opcode: oc,
5974
Col: col,
75+
EExpr: expr,
6076
Alias: alias,
6177
WCol: -1,
6278
CollationEnv: collationEnv,
6379
}
64-
if opcode.NeedsComparableValues() {
80+
if oc.NeedsComparableValues() {
6581
out.KeyCol = col
6682
}
6783
return out
@@ -73,6 +89,9 @@ func (ap *AggregateParams) WAssigned() bool {
7389

7490
func (ap *AggregateParams) String() string {
7591
keyCol := strconv.Itoa(ap.Col)
92+
if ap.EExpr != nil {
93+
keyCol = sqlparser.String(ap.EExpr)
94+
}
7695
if ap.WAssigned() {
7796
keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol)
7897
}
@@ -89,7 +108,14 @@ func (ap *AggregateParams) String() string {
89108
return fmt.Sprintf("%s%s(%s)", ap.Opcode.String(), dispOrigOp, keyCol)
90109
}
91110

92-
func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
111+
func (ap *AggregateParams) typ(inputType querypb.Type, env *evalengine.ExpressionEnv, collID collations.ID) querypb.Type {
112+
if ap.EExpr != nil {
113+
value, err := eval(env, ap.EExpr, collID)
114+
if err != nil {
115+
return sqltypes.Unknown
116+
}
117+
return value.Type()
118+
}
93119
if ap.OrigOpcode != opcode.AggregateUnassigned {
94120
return ap.OrigOpcode.SQLType(inputType)
95121
}
@@ -98,7 +124,7 @@ func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
98124

99125
type aggregator interface {
100126
add(row []sqltypes.Value) error
101-
finish() sqltypes.Value
127+
finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error)
102128
reset()
103129
}
104130

@@ -151,8 +177,8 @@ func (a *aggregatorCount) add(row []sqltypes.Value) error {
151177
return nil
152178
}
153179

154-
func (a *aggregatorCount) finish() sqltypes.Value {
155-
return sqltypes.NewInt64(a.n)
180+
func (a *aggregatorCount) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
181+
return sqltypes.NewInt64(a.n), nil
156182
}
157183

158184
func (a *aggregatorCount) reset() {
@@ -164,13 +190,13 @@ type aggregatorCountStar struct {
164190
n int64
165191
}
166192

167-
func (a *aggregatorCountStar) add(_ []sqltypes.Value) error {
193+
func (a *aggregatorCountStar) add([]sqltypes.Value) error {
168194
a.n++
169195
return nil
170196
}
171197

172-
func (a *aggregatorCountStar) finish() sqltypes.Value {
173-
return sqltypes.NewInt64(a.n)
198+
func (a *aggregatorCountStar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
199+
return sqltypes.NewInt64(a.n), nil
174200
}
175201

176202
func (a *aggregatorCountStar) reset() {
@@ -198,8 +224,8 @@ func (a *aggregatorMax) add(row []sqltypes.Value) (err error) {
198224
return a.minmax.Max(row[a.from])
199225
}
200226

201-
func (a *aggregatorMinMax) finish() sqltypes.Value {
202-
return a.minmax.Result()
227+
func (a *aggregatorMinMax) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
228+
return a.minmax.Result(), nil
203229
}
204230

205231
func (a *aggregatorMinMax) reset() {
@@ -222,8 +248,8 @@ func (a *aggregatorSum) add(row []sqltypes.Value) error {
222248
return a.sum.Add(row[a.from])
223249
}
224250

225-
func (a *aggregatorSum) finish() sqltypes.Value {
226-
return a.sum.Result()
251+
func (a *aggregatorSum) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
252+
return a.sum.Result(), nil
227253
}
228254

229255
func (a *aggregatorSum) reset() {
@@ -232,28 +258,51 @@ func (a *aggregatorSum) reset() {
232258
}
233259

234260
type aggregatorScalar struct {
235-
from int
236-
current sqltypes.Value
237-
init bool
261+
from int
262+
current sqltypes.Value
263+
hasValue bool
238264
}
239265

240266
func (a *aggregatorScalar) add(row []sqltypes.Value) error {
241-
if !a.init {
267+
if !a.hasValue {
242268
a.current = row[a.from]
243-
a.init = true
269+
a.hasValue = true
244270
}
245271
return nil
246272
}
247273

248-
func (a *aggregatorScalar) finish() sqltypes.Value {
249-
return a.current
274+
func (a *aggregatorScalar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
275+
return a.current, nil
250276
}
251277

252278
func (a *aggregatorScalar) reset() {
253279
a.current = sqltypes.NULL
254-
a.init = false
280+
a.hasValue = false
281+
}
282+
283+
type aggregatorConstant struct {
284+
expr evalengine.Expr
285+
}
286+
287+
func (*aggregatorConstant) add([]sqltypes.Value) error {
288+
return nil
255289
}
256290

291+
func (a *aggregatorConstant) finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error) {
292+
return eval(env, a.expr, coll)
293+
}
294+
295+
func eval(env *evalengine.ExpressionEnv, eexpr evalengine.Expr, coll collations.ID) (sqltypes.Value, error) {
296+
v, err := env.Evaluate(eexpr)
297+
if err != nil {
298+
return sqltypes.Value{}, err
299+
}
300+
301+
return v.Value(coll), nil
302+
}
303+
304+
func (*aggregatorConstant) reset() {}
305+
257306
type aggregatorGroupConcat struct {
258307
from int
259308
type_ sqltypes.Type
@@ -275,11 +324,11 @@ func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error {
275324
return nil
276325
}
277326

278-
func (a *aggregatorGroupConcat) finish() sqltypes.Value {
327+
func (a *aggregatorGroupConcat) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
279328
if a.n == 0 {
280-
return sqltypes.NULL
329+
return sqltypes.NULL, nil
281330
}
282-
return sqltypes.MakeTrusted(a.type_, a.concat)
331+
return sqltypes.MakeTrusted(a.type_, a.concat), nil
283332
}
284333

285334
func (a *aggregatorGroupConcat) reset() {
@@ -301,36 +350,44 @@ func (a *aggregatorGtid) add(row []sqltypes.Value) error {
301350
return nil
302351
}
303352

304-
func (a *aggregatorGtid) finish() sqltypes.Value {
353+
func (a *aggregatorGtid) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
305354
gtid := binlogdatapb.VGtid{ShardGtids: a.shards}
306-
return sqltypes.NewVarChar(gtid.String())
355+
return sqltypes.NewVarChar(gtid.String()), nil
307356
}
308357

309358
func (a *aggregatorGtid) reset() {
310359
a.shards = a.shards[:0] // safe to reuse because only the serialized form of a.shards is returned
311360
}
312361

313-
type aggregationState []aggregator
362+
type aggregationState struct {
363+
env *evalengine.ExpressionEnv
364+
aggregators []aggregator
365+
coll collations.ID
366+
}
314367

315-
func (a aggregationState) add(row []sqltypes.Value) error {
316-
for _, st := range a {
368+
func (a *aggregationState) add(row []sqltypes.Value) error {
369+
for _, st := range a.aggregators {
317370
if err := st.add(row); err != nil {
318371
return err
319372
}
320373
}
321374
return nil
322375
}
323376

324-
func (a aggregationState) finish() (row []sqltypes.Value) {
325-
row = make([]sqltypes.Value, 0, len(a))
326-
for _, st := range a {
327-
row = append(row, st.finish())
377+
func (a *aggregationState) finish() ([]sqltypes.Value, error) {
378+
row := make([]sqltypes.Value, 0, len(a.aggregators))
379+
for _, st := range a.aggregators {
380+
v, err := st.finish(a.env, a.coll)
381+
if err != nil {
382+
return nil, err
383+
}
384+
row = append(row, v)
328385
}
329-
return
386+
return row, nil
330387
}
331388

332-
func (a aggregationState) reset() {
333-
for _, st := range a {
389+
func (a *aggregationState) reset() {
390+
for _, st := range a.aggregators {
334391
st.reset()
335392
}
336393
}
@@ -354,13 +411,16 @@ func isComparable(typ sqltypes.Type) bool {
354411
return false
355412
}
356413

357-
func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (aggregationState, []*querypb.Field, error) {
414+
func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env *evalengine.ExpressionEnv, collation collations.ID) (*aggregationState, []*querypb.Field, error) {
358415
fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { return from.CloneVT() })
359416

360-
agstate := make([]aggregator, len(fields))
417+
aggregators := make([]aggregator, len(fields))
361418
for _, aggr := range aggregates {
362-
sourceType := fields[aggr.Col].Type
363-
targetType := aggr.typ(sourceType)
419+
var sourceType querypb.Type
420+
if aggr.Col < len(fields) {
421+
sourceType = fields[aggr.Col].Type
422+
}
423+
targetType := aggr.typ(sourceType, env, collation)
364424

365425
var ag aggregator
366426
var distinct = -1
@@ -444,22 +504,25 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
444504
separator: separator,
445505
}
446506

507+
case opcode.AggregateConstant:
508+
ag = &aggregatorConstant{expr: aggr.EExpr}
509+
447510
default:
448511
panic("BUG: unexpected Aggregation opcode")
449512
}
450513

451-
agstate[aggr.Col] = ag
514+
aggregators[aggr.Col] = ag
452515
fields[aggr.Col].Type = targetType
453516
if aggr.Alias != "" {
454517
fields[aggr.Col].Name = aggr.Alias
455518
}
456519
}
457520

458-
for i, a := range agstate {
521+
for i, a := range aggregators {
459522
if a == nil {
460-
agstate[i] = &aggregatorScalar{from: i}
523+
aggregators[i] = &aggregatorScalar{from: i}
461524
}
462525
}
463526

464-
return agstate, fields, nil
527+
return &aggregationState{aggregators: aggregators, env: env, coll: collation}, fields, nil
465528
}

go/vt/vtgate/engine/cached_size.go

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)