Skip to content

Commit efbcac5

Browse files
[release-21.0] Fix scalar aggregation with literals in empty result sets (vitessio#18477) (vitessio#18490)
Signed-off-by: Dirkjan Bussink <[email protected]> Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com>
1 parent 53524ee commit efbcac5

24 files changed

+332
-182
lines changed

go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ func TestAggrWithLimit(t *testing.T) {
8080
mcmp.Exec(fmt.Sprintf("insert into aggr_test(id, val1, val2) values(%d, 'a', %d)", i, r))
8181
}
8282
mcmp.Exec("select val2, count(*) from aggr_test group by val2 order by count(*), val2 limit 10")
83+
if utils.BinaryIsAtLeastAtVersion(21, "vtgate") {
84+
mcmp.Exec("SELECT 1 AS `id`, COUNT(*) FROM (SELECT `id` FROM aggr_test WHERE val1 = 1 LIMIT 100) `t`")
85+
}
8386
}
8487

8588
func TestAggregateTypes(t *testing.T) {

go/vt/sqlparser/analyzer.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package sqlparser
1919
// analyzer.go contains utility analysis functions.
2020

2121
import (
22+
"errors"
2223
"fmt"
2324
"strings"
2425
"unicode"
@@ -383,6 +384,20 @@ func IsColName(node Expr) bool {
383384
return ok
384385
}
385386

387+
var errNotStatic = errors.New("not static")
388+
389+
// IsConstant returns true if the Expr can be evaluated without input or access to tables.
390+
func IsConstant(node Expr) bool {
391+
err := Walk(func(node SQLNode) (kontinue bool, err error) {
392+
switch node.(type) {
393+
case *ColName, *Subquery:
394+
return false, errNotStatic
395+
}
396+
return true, nil
397+
}, node)
398+
return err == nil
399+
}
400+
386401
// IsValue returns true if the Expr is a string, integral or value arg.
387402
// NULL is not considered to be a value.
388403
func IsValue(node Expr) bool {

go/vt/vtgate/engine/aggregations.go

Lines changed: 108 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ import (
3535
// It contains the opcode and input column number.
3636
type AggregateParams struct {
3737
Opcode AggregateOpcode
38-
Col int
38+
// Input source specification - exactly one of these should be set:
39+
// Col: Column index for simple column references (e.g., SUM(column_name))
40+
// EExpr: Evaluated expression for literals, parameters
41+
Col int
42+
EExpr evalengine.Expr
3943

4044
// These are used only for distinct opcodes.
4145
KeyCol int
@@ -53,15 +57,26 @@ type AggregateParams struct {
5357
CollationEnv *collations.Environment
5458
}
5559

56-
func NewAggregateParam(opcode AggregateOpcode, col int, alias string, collationEnv *collations.Environment) *AggregateParams {
60+
// NewAggregateParam creates a new aggregate param
61+
func NewAggregateParam(
62+
oc AggregateOpcode,
63+
col int,
64+
expr evalengine.Expr,
65+
alias string,
66+
collationEnv *collations.Environment,
67+
) *AggregateParams {
68+
if expr != nil && oc != AggregateConstant {
69+
panic(vterrors.VT13001("expr should be nil"))
70+
}
5771
out := &AggregateParams{
58-
Opcode: opcode,
72+
Opcode: oc,
5973
Col: col,
74+
EExpr: expr,
6075
Alias: alias,
6176
WCol: -1,
6277
CollationEnv: collationEnv,
6378
}
64-
if opcode.NeedsComparableValues() {
79+
if oc.NeedsComparableValues() {
6580
out.KeyCol = col
6681
}
6782
return out
@@ -73,6 +88,9 @@ func (ap *AggregateParams) WAssigned() bool {
7388

7489
func (ap *AggregateParams) String() string {
7590
keyCol := strconv.Itoa(ap.Col)
91+
if ap.EExpr != nil {
92+
keyCol = sqlparser.String(ap.EExpr)
93+
}
7694
if ap.WAssigned() {
7795
keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol)
7896
}
@@ -89,7 +107,14 @@ func (ap *AggregateParams) String() string {
89107
return fmt.Sprintf("%s%s(%s)", ap.Opcode.String(), dispOrigOp, keyCol)
90108
}
91109

92-
func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
110+
func (ap *AggregateParams) typ(inputType querypb.Type, env *evalengine.ExpressionEnv, collID collations.ID) querypb.Type {
111+
if ap.EExpr != nil {
112+
value, err := eval(env, ap.EExpr, collID)
113+
if err != nil {
114+
return sqltypes.Unknown
115+
}
116+
return value.Type()
117+
}
93118
if ap.OrigOpcode != AggregateUnassigned {
94119
return ap.OrigOpcode.SQLType(inputType)
95120
}
@@ -98,7 +123,7 @@ func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
98123

99124
type aggregator interface {
100125
add(row []sqltypes.Value) error
101-
finish() sqltypes.Value
126+
finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error)
102127
reset()
103128
}
104129

@@ -151,8 +176,8 @@ func (a *aggregatorCount) add(row []sqltypes.Value) error {
151176
return nil
152177
}
153178

154-
func (a *aggregatorCount) finish() sqltypes.Value {
155-
return sqltypes.NewInt64(a.n)
179+
func (a *aggregatorCount) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
180+
return sqltypes.NewInt64(a.n), nil
156181
}
157182

158183
func (a *aggregatorCount) reset() {
@@ -164,13 +189,13 @@ type aggregatorCountStar struct {
164189
n int64
165190
}
166191

167-
func (a *aggregatorCountStar) add(_ []sqltypes.Value) error {
192+
func (a *aggregatorCountStar) add([]sqltypes.Value) error {
168193
a.n++
169194
return nil
170195
}
171196

172-
func (a *aggregatorCountStar) finish() sqltypes.Value {
173-
return sqltypes.NewInt64(a.n)
197+
func (a *aggregatorCountStar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
198+
return sqltypes.NewInt64(a.n), nil
174199
}
175200

176201
func (a *aggregatorCountStar) reset() {
@@ -198,8 +223,8 @@ func (a *aggregatorMax) add(row []sqltypes.Value) (err error) {
198223
return a.minmax.Max(row[a.from])
199224
}
200225

201-
func (a *aggregatorMinMax) finish() sqltypes.Value {
202-
return a.minmax.Result()
226+
func (a *aggregatorMinMax) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
227+
return a.minmax.Result(), nil
203228
}
204229

205230
func (a *aggregatorMinMax) reset() {
@@ -222,8 +247,8 @@ func (a *aggregatorSum) add(row []sqltypes.Value) error {
222247
return a.sum.Add(row[a.from])
223248
}
224249

225-
func (a *aggregatorSum) finish() sqltypes.Value {
226-
return a.sum.Result()
250+
func (a *aggregatorSum) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
251+
return a.sum.Result(), nil
227252
}
228253

229254
func (a *aggregatorSum) reset() {
@@ -232,28 +257,51 @@ func (a *aggregatorSum) reset() {
232257
}
233258

234259
type aggregatorScalar struct {
235-
from int
236-
current sqltypes.Value
237-
init bool
260+
from int
261+
current sqltypes.Value
262+
hasValue bool
238263
}
239264

240265
func (a *aggregatorScalar) add(row []sqltypes.Value) error {
241-
if !a.init {
266+
if !a.hasValue {
242267
a.current = row[a.from]
243-
a.init = true
268+
a.hasValue = true
244269
}
245270
return nil
246271
}
247272

248-
func (a *aggregatorScalar) finish() sqltypes.Value {
249-
return a.current
273+
func (a *aggregatorScalar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
274+
return a.current, nil
250275
}
251276

252277
func (a *aggregatorScalar) reset() {
253278
a.current = sqltypes.NULL
254-
a.init = false
279+
a.hasValue = false
255280
}
256281

282+
type aggregatorConstant struct {
283+
expr evalengine.Expr
284+
}
285+
286+
func (*aggregatorConstant) add([]sqltypes.Value) error {
287+
return nil
288+
}
289+
290+
func (a *aggregatorConstant) finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error) {
291+
return eval(env, a.expr, coll)
292+
}
293+
294+
func eval(env *evalengine.ExpressionEnv, eexpr evalengine.Expr, coll collations.ID) (sqltypes.Value, error) {
295+
v, err := env.Evaluate(eexpr)
296+
if err != nil {
297+
return sqltypes.Value{}, err
298+
}
299+
300+
return v.Value(coll), nil
301+
}
302+
303+
func (*aggregatorConstant) reset() {}
304+
257305
type aggregatorGroupConcat struct {
258306
from int
259307
type_ sqltypes.Type
@@ -275,11 +323,11 @@ func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error {
275323
return nil
276324
}
277325

278-
func (a *aggregatorGroupConcat) finish() sqltypes.Value {
326+
func (a *aggregatorGroupConcat) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
279327
if a.n == 0 {
280-
return sqltypes.NULL
328+
return sqltypes.NULL, nil
281329
}
282-
return sqltypes.MakeTrusted(a.type_, a.concat)
330+
return sqltypes.MakeTrusted(a.type_, a.concat), nil
283331
}
284332

285333
func (a *aggregatorGroupConcat) reset() {
@@ -301,36 +349,44 @@ func (a *aggregatorGtid) add(row []sqltypes.Value) error {
301349
return nil
302350
}
303351

304-
func (a *aggregatorGtid) finish() sqltypes.Value {
352+
func (a *aggregatorGtid) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) {
305353
gtid := binlogdatapb.VGtid{ShardGtids: a.shards}
306-
return sqltypes.NewVarChar(gtid.String())
354+
return sqltypes.NewVarChar(gtid.String()), nil
307355
}
308356

309357
func (a *aggregatorGtid) reset() {
310358
a.shards = a.shards[:0] // safe to reuse because only the serialized form of a.shards is returned
311359
}
312360

313-
type aggregationState []aggregator
361+
type aggregationState struct {
362+
env *evalengine.ExpressionEnv
363+
aggregators []aggregator
364+
coll collations.ID
365+
}
314366

315-
func (a aggregationState) add(row []sqltypes.Value) error {
316-
for _, st := range a {
367+
func (a *aggregationState) add(row []sqltypes.Value) error {
368+
for _, st := range a.aggregators {
317369
if err := st.add(row); err != nil {
318370
return err
319371
}
320372
}
321373
return nil
322374
}
323375

324-
func (a aggregationState) finish() (row []sqltypes.Value) {
325-
row = make([]sqltypes.Value, 0, len(a))
326-
for _, st := range a {
327-
row = append(row, st.finish())
376+
func (a *aggregationState) finish() ([]sqltypes.Value, error) {
377+
row := make([]sqltypes.Value, 0, len(a.aggregators))
378+
for _, st := range a.aggregators {
379+
v, err := st.finish(a.env, a.coll)
380+
if err != nil {
381+
return nil, err
382+
}
383+
row = append(row, v)
328384
}
329-
return
385+
return row, nil
330386
}
331387

332-
func (a aggregationState) reset() {
333-
for _, st := range a {
388+
func (a *aggregationState) reset() {
389+
for _, st := range a.aggregators {
334390
st.reset()
335391
}
336392
}
@@ -354,13 +410,16 @@ func isComparable(typ sqltypes.Type) bool {
354410
return false
355411
}
356412

357-
func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (aggregationState, []*querypb.Field, error) {
413+
func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env *evalengine.ExpressionEnv, collation collations.ID) (*aggregationState, []*querypb.Field, error) {
358414
fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { return from.CloneVT() })
359415

360-
agstate := make([]aggregator, len(fields))
416+
aggregators := make([]aggregator, len(fields))
361417
for _, aggr := range aggregates {
362-
sourceType := fields[aggr.Col].Type
363-
targetType := aggr.typ(sourceType)
418+
var sourceType querypb.Type
419+
if aggr.Col < len(fields) {
420+
sourceType = fields[aggr.Col].Type
421+
}
422+
targetType := aggr.typ(sourceType, env, collation)
364423

365424
var ag aggregator
366425
var distinct = -1
@@ -444,22 +503,25 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
444503
separator: separator,
445504
}
446505

506+
case AggregateConstant:
507+
ag = &aggregatorConstant{expr: aggr.EExpr}
508+
447509
default:
448510
panic("BUG: unexpected Aggregation opcode")
449511
}
450512

451-
agstate[aggr.Col] = ag
513+
aggregators[aggr.Col] = ag
452514
fields[aggr.Col].Type = targetType
453515
if aggr.Alias != "" {
454516
fields[aggr.Col].Name = aggr.Alias
455517
}
456518
}
457519

458-
for i, a := range agstate {
520+
for i, a := range aggregators {
459521
if a == nil {
460-
agstate[i] = &aggregatorScalar{from: i}
522+
aggregators[i] = &aggregatorScalar{from: i}
461523
}
462524
}
463525

464-
return agstate, fields, nil
526+
return &aggregationState{aggregators: aggregators, env: env, coll: collation}, fields, nil
465527
}

go/vt/vtgate/engine/cached_size.go

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)