From 4bc7fb656afc2b0f40c7dcd187b2f225cefdcd45 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 8 Apr 2025 16:42:51 -0700 Subject: [PATCH 01/19] implment std and tests --- enginetest/queries/script_queries.go | 51 ++++++++++ optgen/cmd/source/unary_aggs.yaml | 4 +- sql/expression/function/aggregation/common.go | 4 +- .../function/aggregation/std_test.go | 97 +++++++++++++++++++ .../function/aggregation/unary_agg_buffers.go | 66 ++++++++++++- .../function/aggregation/unary_aggs.og.go | 79 +++++++++++++++ .../function/aggregation/window_functions.go | 52 ++++++++++ sql/expression/function/registry.go | 3 + 8 files changed, 352 insertions(+), 4 deletions(-) create mode 100644 sql/expression/function/aggregation/std_test.go diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index e8debaa83b..1fbf454ba2 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7887,6 +7887,57 @@ where }, }, }, + { + Name: "std, stdev, stddev_pop tests", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Expected: []sql.Row{ + {nil, nil, nil}, + }, + }, + { + Query: "insert into t values (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Expected: []sql.Row{ + {0.0, 0.0, 0.0}, + }, + }, + { + Query: "insert into t values (2);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Expected: []sql.Row{ + {0.5, 0.5, 0.5}, + }, + }, + { + Query: "insert into t values (3);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Expected: []sql.Row{ + {0.816496580927726, 0.816496580927726, 0.816496580927726}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/optgen/cmd/source/unary_aggs.yaml b/optgen/cmd/source/unary_aggs.yaml index 6f58490e12..a59929bd5e 100644 --- a/optgen/cmd/source/unary_aggs.yaml +++ b/optgen/cmd/source/unary_aggs.yaml @@ -31,4 +31,6 @@ unaryAggs: desc: "returns the minimum value of expr in all rows." - name: "Sum" desc: "returns the sum of expr in all rows" - nullable: false \ No newline at end of file + nullable: false +- name: "StdDevPop" + desc: "returns the population standard deviation of expr" \ No newline at end of file diff --git a/sql/expression/function/aggregation/common.go b/sql/expression/function/aggregation/common.go index 6637685dfe..ee738fd261 100644 --- a/sql/expression/function/aggregation/common.go +++ b/sql/expression/function/aggregation/common.go @@ -125,11 +125,11 @@ func (a *unaryAggBase) WithChildren(children ...sql.Expression) (sql.Expression, return &na, nil } -func (a unaryAggBase) FunctionName() string { +func (a *unaryAggBase) FunctionName() string { return a.functionName } -func (a unaryAggBase) Description() string { +func (a *unaryAggBase) Description() string { return a.description } diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go new file mode 100644 index 0000000000..5e6de0e6e3 --- /dev/null +++ b/sql/expression/function/aggregation/std_test.go @@ -0,0 +1,97 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregation + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" +) + +func TestStd(t *testing.T) { + sum := NewStdDevPop(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.118033988749895, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 0.9601432184835761, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.0825317547305484, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 1.0, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 1.0, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 1.0, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} + diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index c06037530e..080726d24c 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -2,7 +2,8 @@ package aggregation import ( "fmt" - "reflect" + "math" +"reflect" "github.com/cespare/xxhash/v2" "github.com/shopspring/decimal" @@ -666,3 +667,66 @@ func (j *jsonArrayBuffer) Eval(ctx *sql.Context) (interface{}, error) { // Dispose implements the Disposable interface. func (j *jsonArrayBuffer) Dispose() { } + +type stdDevPopBuffer struct { + vals []interface{} + expr sql.Expression + + count int64 + oldMean float64 + newMean float64 + oldVar float64 + newVar float64 +} + +func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { + return &stdDevPopBuffer{ + vals: nil, + expr: child, + } +} + +// Update implements the AggregationBuffer interface. +func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error { + v, err := s.expr.Eval(ctx, row) + if err != nil { + return err + } + + // TODO: convert val to appropriate type + v, _, err = types.Float64.Convert(ctx, v) + if err != nil { + v = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + } + if v == nil { + return nil + } + val := v.(float64) + + s.count += 1 + if s.count == 1 { + s.oldMean = val + s.newMean = val + return nil + } + + s.newMean = s.oldMean + (val - s.oldMean) / float64(s.count) + s.newVar = s.oldVar + (val - s.oldMean) * (val - s.newMean) + s.oldVar = s.newVar + s.oldMean = s.newMean + + return nil +} + +// Eval implements the AggregationBuffer interface. +func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if s.count == 0 { + return nil, nil + } + return math.Sqrt(s.newVar / float64(s.count)), nil // TODO: sqrt? +} + +// Dispose implements the Disposable interface. +func (s *stdDevPopBuffer) Dispose() { +} \ No newline at end of file diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 7e3736a7e7..be64d85ebb 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -958,3 +958,82 @@ func (a *Sum) NewWindowFunction() (sql.WindowFunction, error) { } return NewSumAgg(child).WithWindow(a.Window()) } + +type StdDevPop struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*StdDevPop)(nil) +var _ sql.Aggregation = (*StdDevPop)(nil) +var _ sql.WindowAdaptableExpression = (*StdDevPop)(nil) + +func NewStdDevPop(e sql.Expression) *StdDevPop { + return &StdDevPop{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "StdDevPop", + description: "returns the population standard deviation of expr", + }, + } +} + +func (a *StdDevPop) Type() sql.Type { + return a.Child.Type() +} + +func (a *StdDevPop) IsNullable() bool { + return false +} + +func (a *StdDevPop) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEV_POP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEV_POP(%s)", a.Child) +} + +func (a *StdDevPop) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEV_POP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEV_POP(%s)", sql.DebugString(a.Child)) +} + +func (a *StdDevPop) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &StdDevPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevPop) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &StdDevPop{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *StdDevPop) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &StdDevPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevPop) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevPopBuffer(child), nil +} + +func (a *StdDevPop) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevPopAgg(child).WithWindow(a.Window()) +} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 7fd0c8e530..21d859f59a 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1408,3 +1408,55 @@ func (a *leadLagBase) Compute(ctx *sql.Context, interval sql.WindowInterval, buf a.pos++ return res } + +type StdDevPopAgg struct { + expr sql.Expression + framer sql.WindowFramer +} + +func NewStdDevPopAgg(e sql.Expression) *StdDevPopAgg { + return &StdDevPopAgg{ + expr: e, + } +} + +func (s *StdDevPopAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *s + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (s *StdDevPopAgg) Dispose() { + expression.Dispose(s.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (s *StdDevPopAgg) DefaultFramer() sql.WindowFramer { + if s.framer != nil { + return s.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + s.Dispose() + return nil +} + +func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + for i := interval.Start; i < interval.End; i++ { + row := buf[i] + v, err := s.expr.Eval(ctx, row) + if err != nil { + return err + } + return v + } + return nil +} \ No newline at end of file diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 53050e846f..f5129d04a5 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -289,6 +289,9 @@ var BuiltIns = []sql.Function{ sql.FunctionN{Name: "substring", Fn: NewSubstring}, sql.Function3{Name: "substring_index", Fn: NewSubstringIndex}, sql.Function1{Name: "sum", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) }}, + sql.Function1{Name: "std", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, + sql.Function1{Name: "stddev", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, + sql.Function1{Name: "stddev_pop", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, sql.FunctionN{Name: "sysdate", Fn: NewSysdate}, sql.Function1{Name: "tan", Fn: NewTan}, sql.Function1{Name: "time", Fn: NewTime}, From e29872079371255b367b910522edd6413c07a806 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 8 Apr 2025 16:53:23 -0700 Subject: [PATCH 02/19] another test --- enginetest/queries/script_queries.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 1fbf454ba2..654a950053 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7892,6 +7892,9 @@ where Dialect: "mysql", SetUpScript: []string{ "create table t (i int);", + "create table tt (i int, j int);", + "insert into tt values (0, 1), (0, 2), (0, 3);", + "insert into tt values (1, 123), (1, 456), (1, 789);", }, Assertions: []ScriptTestAssertion{ { @@ -7936,6 +7939,13 @@ where {0.816496580927726, 0.816496580927726, 0.816496580927726}, }, }, + { + Query: "select i, std(j) from tt group by i;", + Expected: []sql.Row{ + {0, 0.816496580927726}, + {1, 271.89336144893275}, + }, + }, }, }, } From 366388ba35f7ab1b53065067728ac5f41b697773 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 9 Apr 2025 00:09:44 +0000 Subject: [PATCH 03/19] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/aggregation/std_test.go | 1 - .../function/aggregation/unary_agg_buffers.go | 10 +++++----- .../function/aggregation/window_functions.go | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go index 5e6de0e6e3..2065b18314 100644 --- a/sql/expression/function/aggregation/std_test.go +++ b/sql/expression/function/aggregation/std_test.go @@ -94,4 +94,3 @@ func TestStd(t *testing.T) { }) } } - diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 080726d24c..064bda1ecd 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -3,7 +3,7 @@ package aggregation import ( "fmt" "math" -"reflect" + "reflect" "github.com/cespare/xxhash/v2" "github.com/shopspring/decimal" @@ -711,9 +711,9 @@ func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } - s.newMean = s.oldMean + (val - s.oldMean) / float64(s.count) - s.newVar = s.oldVar + (val - s.oldMean) * (val - s.newMean) - s.oldVar = s.newVar + s.newMean = s.oldMean + (val-s.oldMean)/float64(s.count) + s.newVar = s.oldVar + (val-s.oldMean)*(val-s.newMean) + s.oldVar = s.newVar s.oldMean = s.newMean return nil @@ -729,4 +729,4 @@ func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { // Dispose implements the Disposable interface. func (s *stdDevPopBuffer) Dispose() { -} \ No newline at end of file +} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 21d859f59a..ad5c588255 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1459,4 +1459,4 @@ func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, bu return v } return nil -} \ No newline at end of file +} From 392284750c8f345761b9f1f4f4625e8e78f4ed8a Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 13:41:55 -0700 Subject: [PATCH 04/19] implement std and tests --- enginetest/queries/script_queries.go | 64 +++++++++++++++++++ .../function/aggregation/unary_agg_buffers.go | 3 +- .../function/aggregation/window_functions.go | 48 ++++++++++++-- sql/planbuilder/aggregates.go | 3 +- 4 files changed, 110 insertions(+), 8 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 654a950053..265fda982c 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7939,6 +7939,18 @@ where {0.816496580927726, 0.816496580927726, 0.816496580927726}, }, }, + { + Query: "insert into t values (null), (null);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Expected: []sql.Row{ + {0.816496580927726, 0.816496580927726, 0.816496580927726}, + }, + }, { Query: "select i, std(j) from tt group by i;", Expected: []sql.Row{ @@ -7946,6 +7958,58 @@ where {1, 271.89336144893275}, }, }, + { + Query: "select std(i) over(), std(j) over() from tt order by i;", + Expected: []sql.Row{ + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + }, + }, + { + Query: "select i, std(j) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {0, 0.816496580927726}, + {0, 0.816496580927726}, + {0, 0.816496580927726}, + {1, 271.89336144893275}, + {1, 271.89336144893275}, + {1, 271.89336144893275}, + }, + }, + { + Query: "insert into tt values (null, null);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i) over(), std(j) over() from tt order by i;", + Expected: []sql.Row{ + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + {0.5, 297.47660972475353}, + }, + }, + { + Query: "select i, std(j) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {nil, nil}, + {0, 0.816496580927726}, + {0, 0.816496580927726}, + {0, 0.816496580927726}, + {1, 271.89336144893275}, + {1, 271.89336144893275}, + {1, 271.89336144893275}, + }, + }, }, }, } diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 080726d24c..e1ce1691bc 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -693,7 +693,6 @@ func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error { return err } - // TODO: convert val to appropriate type v, _, err = types.Float64.Convert(ctx, v) if err != nil { v = 0.0 @@ -724,7 +723,7 @@ func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { if s.count == 0 { return nil, nil } - return math.Sqrt(s.newVar / float64(s.count)), nil // TODO: sqrt? + return math.Sqrt(s.newVar / float64(s.count)), nil } // Dispose implements the Disposable interface. diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 21d859f59a..6040851384 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -15,7 +15,8 @@ package aggregation import ( - "sort" + "math" +"sort" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -1412,6 +1413,12 @@ func (a *leadLagBase) Compute(ctx *sql.Context, interval sql.WindowInterval, buf type StdDevPopAgg struct { expr sql.Expression framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int } func NewStdDevPopAgg(e sql.Expression) *StdDevPopAgg { @@ -1446,17 +1453,48 @@ func (s *StdDevPopAgg) DefaultFramer() sql.WindowFramer { func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { s.Dispose() - return nil + s.partitionStart = interval.Start + s.partitionEnd = interval.End + var err error + s.prefixSum, s.nullCnt, err = floatPrefixSum(ctx, interval, buf, s.expr) + return err } func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - s.partitionStart - 1 + endIdx := interval.End - s.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= s.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += s.nullCnt[startIdx] + } + if nonNullCnt == 0 { + return nil + } + + m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) + v := 0.0 for i := interval.Start; i < interval.End; i++ { row := buf[i] - v, err := s.expr.Eval(ctx, row) + val, err := s.expr.Eval(ctx, row) if err != nil { return err } - return v + val, _, err = types.Float64.Convert(ctx, val) + if err != nil { + return nil + } + if val == nil { + continue + } + dv := val.(float64) - m + v += dv * dv } - return nil + + return math.Sqrt(v / float64(nonNullCnt)) } \ No newline at end of file diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 1459c6e32c..c8ddc91858 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -504,7 +504,8 @@ func isWindowFunc(name string) bool { "avg", "max", "min", "count_distinct", "json_arrayagg", "row_number", "percent_rank", "lead", "lag", "first_value", "last_value", - "rank", "dense_rank": + "rank", "dense_rank", + "std", "stddev", "stddev_pop": return true default: return false From 48b1214bb9816c00c20511279f55de949cb11490 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 13:53:18 -0700 Subject: [PATCH 05/19] refactor --- .../function/aggregation/window_functions.go | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 6040851384..5d0b9f265f 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1460,6 +1460,29 @@ func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInter return err } +func computeVariance(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer, expr sql.Expression, m float64) (float64, error) { + var v float64 + for i := interval.Start; i < interval.End; i++ { + row := buf[i] + val, err := expr.Eval(ctx, row) + if err != nil { + return 0, err + } + // TODO: consider saving conversions to avoid double Converts + val, _, err = types.Float64.Convert(ctx, val) + if err != nil { + val = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + } + if val == nil { + continue + } + dv := val.(float64) - m + v += dv * dv + } + return v, nil +} + func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { startIdx := interval.Start - s.partitionStart - 1 endIdx := interval.End - s.partitionStart - 1 @@ -1478,22 +1501,9 @@ func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, bu } m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) - v := 0.0 - for i := interval.Start; i < interval.End; i++ { - row := buf[i] - val, err := s.expr.Eval(ctx, row) - if err != nil { - return err - } - val, _, err = types.Float64.Convert(ctx, val) - if err != nil { - return nil - } - if val == nil { - continue - } - dv := val.(float64) - m - v += dv * dv + v, err := computeVariance(ctx, interval, buf, s.expr, m) + if err != nil { + return err } return math.Sqrt(v / float64(nonNullCnt)) From d1c4874bc6febb0dddaa8dd139ad240b4af0dc1b Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 14:30:42 -0700 Subject: [PATCH 06/19] implment std samp --- enginetest/queries/script_queries.go | 86 +++++++++--------- optgen/cmd/source/unary_aggs.yaml | 4 +- .../function/aggregation/unary_agg_buffers.go | 65 +++++++++++++- .../function/aggregation/unary_aggs.og.go | 87 ++++++++++++++++++- .../function/aggregation/window_functions.go | 76 ++++++++++++++++ sql/expression/function/registry.go | 1 + sql/planbuilder/aggregates.go | 2 +- 7 files changed, 270 insertions(+), 51 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 265fda982c..52e06b72b7 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7898,9 +7898,9 @@ where }, Assertions: []ScriptTestAssertion{ { - Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", Expected: []sql.Row{ - {nil, nil, nil}, + {nil, nil, nil, nil}, }, }, { @@ -7910,9 +7910,9 @@ where }, }, { - Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", Expected: []sql.Row{ - {0.0, 0.0, 0.0}, + {0.0, 0.0, 0.0, nil}, }, }, { @@ -7922,9 +7922,9 @@ where }, }, { - Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", Expected: []sql.Row{ - {0.5, 0.5, 0.5}, + {0.5, 0.5, 0.5, 0.7071067811865476}, }, }, { @@ -7934,9 +7934,9 @@ where }, }, { - Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", Expected: []sql.Row{ - {0.816496580927726, 0.816496580927726, 0.816496580927726}, + {0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0}, }, }, { @@ -7946,38 +7946,38 @@ where }, }, { - Query: "select std(i), stddev(i), stddev_pop(i) from t;", + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", Expected: []sql.Row{ - {0.816496580927726, 0.816496580927726, 0.816496580927726}, + {0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0}, }, }, { - Query: "select i, std(j) from tt group by i;", + Query: "select i, std(j), stddev_samp(j) from tt group by i;", Expected: []sql.Row{ - {0, 0.816496580927726}, - {1, 271.89336144893275}, + {0, 0.816496580927726, 1.0}, + {1, 271.89336144893275, 333.0}, }, }, { - Query: "select std(i) over(), std(j) over() from tt order by i;", + Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", Expected: []sql.Row{ - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, }, }, { - Query: "select i, std(j) over(partition by i) from tt order by i;", + Query: "select i, std(j) over(partition by i), stddev_samp(j) over(partition by i) from tt order by i;", Expected: []sql.Row{ - {0, 0.816496580927726}, - {0, 0.816496580927726}, - {0, 0.816496580927726}, - {1, 271.89336144893275}, - {1, 271.89336144893275}, - {1, 271.89336144893275}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, }, }, { @@ -7987,27 +7987,27 @@ where }, }, { - Query: "select std(i) over(), std(j) over() from tt order by i;", + Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", Expected: []sql.Row{ - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, - {0.5, 297.47660972475353}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, }, }, { - Query: "select i, std(j) over(partition by i) from tt order by i;", + Query: "select i, std(j) over(partition by i), stddev_samp(j) over(partition by i) from tt order by i;", Expected: []sql.Row{ - {nil, nil}, - {0, 0.816496580927726}, - {0, 0.816496580927726}, - {0, 0.816496580927726}, - {1, 271.89336144893275}, - {1, 271.89336144893275}, - {1, 271.89336144893275}, + {nil, nil, nil}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, }, }, }, diff --git a/optgen/cmd/source/unary_aggs.yaml b/optgen/cmd/source/unary_aggs.yaml index a59929bd5e..72af5361fc 100644 --- a/optgen/cmd/source/unary_aggs.yaml +++ b/optgen/cmd/source/unary_aggs.yaml @@ -33,4 +33,6 @@ unaryAggs: desc: "returns the sum of expr in all rows" nullable: false - name: "StdDevPop" - desc: "returns the population standard deviation of expr" \ No newline at end of file + desc: "returns the population standard deviation of expr" +- name: "StdDevSamp" + desc: "returns the sample standard deviation of expr" \ No newline at end of file diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index e1ce1691bc..d5219fa688 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -672,7 +672,7 @@ type stdDevPopBuffer struct { vals []interface{} expr sql.Expression - count int64 + count uint64 oldMean float64 newMean float64 oldVar float64 @@ -728,4 +728,65 @@ func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { // Dispose implements the Disposable interface. func (s *stdDevPopBuffer) Dispose() { -} \ No newline at end of file +} + +type stdDevSampBuffer struct { + vals []interface{} + expr sql.Expression + + count uint64 + oldMean float64 + newMean float64 + oldVar float64 + newVar float64 +} + +func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { + return &stdDevSampBuffer{ + vals: nil, + expr: child, + } +} + +// Update implements the AggregationBuffer interface. +func (s *stdDevSampBuffer) Update(ctx *sql.Context, row sql.Row) error { + v, err := s.expr.Eval(ctx, row) + if err != nil { + return err + } + + v, _, err = types.Float64.Convert(ctx, v) + if err != nil { + v = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + } + if v == nil { + return nil + } + val := v.(float64) + + s.count += 1 + if s.count == 1 { + s.oldMean = val + s.newMean = val + return nil + } + + s.newMean = s.oldMean + (val - s.oldMean) / float64(s.count) + s.newVar = s.oldVar + (val - s.oldMean) * (val - s.newMean) + s.oldVar = s.newVar + s.oldMean = s.newMean + + return nil +} + +// Eval implements the AggregationBuffer interface. +func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if s.count <= 1 { + return nil, nil + } + return math.Sqrt(s.newVar / float64(s.count - 1)), nil +} + +// Dispose implements the Disposable interface. +func (s *stdDevSampBuffer) Dispose() {} \ No newline at end of file diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index be64d85ebb..4207f348c7 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -988,23 +988,23 @@ func (a *StdDevPop) IsNullable() bool { func (a *StdDevPop) String() string { if a.window != nil { pr := sql.NewTreePrinter() - _ = pr.WriteNode("STDDEV_POP") + _ = pr.WriteNode("STDDEVPOP") children := []string{a.window.String(), a.Child.String()} pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("STDDEV_POP(%s)", a.Child) + return fmt.Sprintf("STDDEVPOP(%s)", a.Child) } func (a *StdDevPop) DebugString() string { if a.window != nil { pr := sql.NewTreePrinter() - _ = pr.WriteNode("STDDEV_POP") + _ = pr.WriteNode("STDDEVPOP") children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("STDDEV_POP(%s)", sql.DebugString(a.Child)) + return fmt.Sprintf("STDDEVPOP(%s)", sql.DebugString(a.Child)) } func (a *StdDevPop) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { @@ -1037,3 +1037,82 @@ func (a *StdDevPop) NewWindowFunction() (sql.WindowFunction, error) { } return NewStdDevPopAgg(child).WithWindow(a.Window()) } + +type StdDevSamp struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*StdDevSamp)(nil) +var _ sql.Aggregation = (*StdDevSamp)(nil) +var _ sql.WindowAdaptableExpression = (*StdDevSamp)(nil) + +func NewStdDevSamp(e sql.Expression) *StdDevSamp { + return &StdDevSamp{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "StdDevSamp", + description: "returns the sample standard deviation of expr", + }, + } +} + +func (a *StdDevSamp) Type() sql.Type { + return a.Child.Type() +} + +func (a *StdDevSamp) IsNullable() bool { + return false +} + +func (a *StdDevSamp) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEV_SAMP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEV_SAMP(%s)", a.Child) +} + +func (a *StdDevSamp) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEV_SAMP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEV_SAMP(%s)", sql.DebugString(a.Child)) +} + +func (a *StdDevSamp) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &StdDevSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevSamp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &StdDevSamp{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *StdDevSamp) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &StdDevSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevSamp) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevSampBuffer(child), nil +} + +func (a *StdDevSamp) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevSampAgg(child).WithWindow(a.Window()) +} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 5d0b9f265f..1e2c45bbed 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1507,4 +1507,80 @@ func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, bu } return math.Sqrt(v / float64(nonNullCnt)) +} + +type StdDevSampAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewStdDevSampAgg(e sql.Expression) *StdDevSampAgg { + return &StdDevSampAgg{ + expr: e, + } +} + +func (s *StdDevSampAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *s + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (s *StdDevSampAgg) Dispose() { + expression.Dispose(s.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (s *StdDevSampAgg) DefaultFramer() sql.WindowFramer { + if s.framer != nil { + return s.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (s *StdDevSampAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + s.Dispose() + s.partitionStart = interval.Start + s.partitionEnd = interval.End + var err error + s.prefixSum, s.nullCnt, err = floatPrefixSum(ctx, interval, buf, s.expr) + return err +} + +func (s *StdDevSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - s.partitionStart - 1 + endIdx := interval.End - s.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= s.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += s.nullCnt[startIdx] + } + if nonNullCnt <= 1 { + return nil + } + + m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) + v, err := computeVariance(ctx, interval, buf, s.expr, m) + if err != nil { + return err + } + + return math.Sqrt(v / float64(nonNullCnt - 1)) } \ No newline at end of file diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index f5129d04a5..f527684ab5 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -292,6 +292,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "std", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, sql.Function1{Name: "stddev", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, sql.Function1{Name: "stddev_pop", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, + sql.Function1{Name: "stddev_samp", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevSamp(e) }}, sql.FunctionN{Name: "sysdate", Fn: NewSysdate}, sql.Function1{Name: "tan", Fn: NewTan}, sql.Function1{Name: "time", Fn: NewTime}, diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index c8ddc91858..fe18ee4e63 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -505,7 +505,7 @@ func isWindowFunc(name string) bool { "row_number", "percent_rank", "lead", "lag", "first_value", "last_value", "rank", "dense_rank", - "std", "stddev", "stddev_pop": + "std", "stddev", "stddev_pop", "stddev_samp": return true default: return false From 0d79ae9c26c4474548bfa2b9da565a6e369c4298 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 9 Apr 2025 21:32:57 +0000 Subject: [PATCH 07/19] [ga-format-pr] Run ./format_repo.sh to fix formatting --- .../function/aggregation/unary_agg_buffers.go | 10 +++++----- .../function/aggregation/window_functions.go | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 92c9325883..8b029013e8 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -772,9 +772,9 @@ func (s *stdDevSampBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } - s.newMean = s.oldMean + (val - s.oldMean) / float64(s.count) - s.newVar = s.oldVar + (val - s.oldMean) * (val - s.newMean) - s.oldVar = s.newVar + s.newMean = s.oldMean + (val-s.oldMean)/float64(s.count) + s.newVar = s.oldVar + (val-s.oldMean)*(val-s.newMean) + s.oldVar = s.newVar s.oldMean = s.newMean return nil @@ -785,8 +785,8 @@ func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { if s.count <= 1 { return nil, nil } - return math.Sqrt(s.newVar / float64(s.count - 1)), nil + return math.Sqrt(s.newVar / float64(s.count-1)), nil } // Dispose implements the Disposable interface. -func (s *stdDevSampBuffer) Dispose() {} \ No newline at end of file +func (s *stdDevSampBuffer) Dispose() {} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 19ac1bb392..132be101de 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1417,8 +1417,8 @@ type StdDevPopAgg struct { partitionStart int partitionEnd int - prefixSum []float64 - nullCnt []int + prefixSum []float64 + nullCnt []int } func NewStdDevPopAgg(e sql.Expression) *StdDevPopAgg { @@ -1454,7 +1454,7 @@ func (s *StdDevPopAgg) DefaultFramer() sql.WindowFramer { func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { s.Dispose() s.partitionStart = interval.Start - s.partitionEnd = interval.End + s.partitionEnd = interval.End var err error s.prefixSum, s.nullCnt, err = floatPrefixSum(ctx, interval, buf, s.expr) return err @@ -1516,8 +1516,8 @@ type StdDevSampAgg struct { partitionStart int partitionEnd int - prefixSum []float64 - nullCnt []int + prefixSum []float64 + nullCnt []int } func NewStdDevSampAgg(e sql.Expression) *StdDevSampAgg { @@ -1553,7 +1553,7 @@ func (s *StdDevSampAgg) DefaultFramer() sql.WindowFramer { func (s *StdDevSampAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { s.Dispose() s.partitionStart = interval.Start - s.partitionEnd = interval.End + s.partitionEnd = interval.End var err error s.prefixSum, s.nullCnt, err = floatPrefixSum(ctx, interval, buf, s.expr) return err @@ -1582,5 +1582,5 @@ func (s *StdDevSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, b return err } - return math.Sqrt(v / float64(nonNullCnt - 1)) + return math.Sqrt(v / float64(nonNullCnt-1)) } From a4e8d455256b652c73a439e219dec72dba04e6e8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 16:44:53 -0700 Subject: [PATCH 08/19] implment variance and tests --- enginetest/queries/script_queries.go | 101 ++++++++++-- optgen/cmd/source/unary_aggs.yaml | 4 +- .../function/aggregation/std_test.go | 144 ++++++++++++++++++ .../function/aggregation/unary_agg_buffers.go | 131 +++++++++++----- .../function/aggregation/unary_aggs.og.go | 87 ++++++++++- .../function/aggregation/window_functions.go | 88 ++++++++++- sql/expression/function/registry.go | 2 + sql/planbuilder/aggregates.go | 3 +- 8 files changed, 499 insertions(+), 61 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 52e06b72b7..d9c3cd0e6b 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7888,7 +7888,7 @@ where }, }, { - Name: "std, stdev, stddev_pop tests", + Name: "std, stdev, stddev_pop, variance, var_pop, var_samp tests", Dialect: "mysql", SetUpScript: []string{ "create table t (i int);", @@ -7903,6 +7903,12 @@ where {nil, nil, nil, nil}, }, }, + { + Query: "select variance(i), var_pop(i) from t;", + Expected: []sql.Row{ + {nil, nil}, + }, + }, { Query: "insert into t values (1);", Expected: []sql.Row{ @@ -7915,6 +7921,12 @@ where {0.0, 0.0, 0.0, nil}, }, }, + { + Query: "select variance(i), var_pop(i) from t;", + Expected: []sql.Row{ + {0.0, 0.0}, + }, + }, { Query: "insert into t values (2);", Expected: []sql.Row{ @@ -7927,6 +7939,12 @@ where {0.5, 0.5, 0.5, 0.7071067811865476}, }, }, + { + Query: "select variance(i), var_pop(i) from t;", + Expected: []sql.Row{ + {0.25, 0.25}, + }, + }, { Query: "insert into t values (3);", Expected: []sql.Row{ @@ -7939,6 +7957,12 @@ where {0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0}, }, }, + { + Query: "select variance(i), var_pop(i) from t;", + Expected: []sql.Row{ + {0.6666666666666666, 0.6666666666666666}, + }, + }, { Query: "insert into t values (null), (null);", Expected: []sql.Row{ @@ -7951,6 +7975,12 @@ where {0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0}, }, }, + { + Query: "select variance(i), var_pop(i) from t;", + Expected: []sql.Row{ + {0.6666666666666666, 0.6666666666666666}, + }, + }, { Query: "select i, std(j), stddev_samp(j) from tt group by i;", Expected: []sql.Row{ @@ -7958,6 +7988,13 @@ where {1, 271.89336144893275, 333.0}, }, }, + { + Query: "select i, variance(i) from tt group by i;", + Expected: []sql.Row{ + {0, 0.0}, + {1, 0.0}, + }, + }, { Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", Expected: []sql.Row{ @@ -7980,6 +8017,28 @@ where {1, 271.89336144893275, 333.0}, }, }, + { + Query: "select i, variance(i) over() from tt order by i;", + Expected: []sql.Row{ + {0, 0.25}, + {0, 0.25}, + {0, 0.25}, + {1, 0.25}, + {1, 0.25}, + {1, 0.25}, + }, + }, + { + Query: "select i, variance(j) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {0, 0.6666666666666666}, + {0, 0.6666666666666666}, + {0, 0.6666666666666666}, + {1, 73926.0}, + {1, 73926.0}, + {1, 73926.0}, + }, + }, { Query: "insert into tt values (null, null);", Expected: []sql.Row{ @@ -7987,15 +8046,15 @@ where }, }, { - Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", + Query: "select i, std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", Expected: []sql.Row{ - {0.5, 297.47660972475353, 325.86929895281634}, - {0.5, 297.47660972475353, 325.86929895281634}, - {0.5, 297.47660972475353, 325.86929895281634}, - {0.5, 297.47660972475353, 325.86929895281634}, - {0.5, 297.47660972475353, 325.86929895281634}, - {0.5, 297.47660972475353, 325.86929895281634}, - {0.5, 297.47660972475353, 325.86929895281634}, + {nil, 0.5, 297.47660972475353, 325.86929895281634}, + {0, 0.5, 297.47660972475353, 325.86929895281634}, + {0, 0.5, 297.47660972475353, 325.86929895281634}, + {0, 0.5, 297.47660972475353, 325.86929895281634}, + {1, 0.5, 297.47660972475353, 325.86929895281634}, + {1, 0.5, 297.47660972475353, 325.86929895281634}, + {1, 0.5, 297.47660972475353, 325.86929895281634}, }, }, { @@ -8010,6 +8069,30 @@ where {1, 271.89336144893275, 333.0}, }, }, + { + Query: "select i, variance(i) over() from tt order by i;", + Expected: []sql.Row{ + {nil, 0.25}, + {0, 0.25}, + {0, 0.25}, + {0, 0.25}, + {1, 0.25}, + {1, 0.25}, + {1, 0.25}, + }, + }, + { + Query: "select i, variance(j) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {nil, nil}, + {0, 0.6666666666666666}, + {0, 0.6666666666666666}, + {0, 0.6666666666666666}, + {1, 73926.0}, + {1, 73926.0}, + {1, 73926.0}, + }, + }, }, }, } diff --git a/optgen/cmd/source/unary_aggs.yaml b/optgen/cmd/source/unary_aggs.yaml index 72af5361fc..ef3d317482 100644 --- a/optgen/cmd/source/unary_aggs.yaml +++ b/optgen/cmd/source/unary_aggs.yaml @@ -35,4 +35,6 @@ unaryAggs: - name: "StdDevPop" desc: "returns the population standard deviation of expr" - name: "StdDevSamp" - desc: "returns the sample standard deviation of expr" \ No newline at end of file + desc: "returns the sample standard deviation of expr" +- name: "VarPop" + desc: "returns the population variance of expr" \ No newline at end of file diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go index 2065b18314..9e20c79363 100644 --- a/sql/expression/function/aggregation/std_test.go +++ b/sql/expression/function/aggregation/std_test.go @@ -94,3 +94,147 @@ func TestStd(t *testing.T) { }) } } + +func TestStdSamp(t *testing.T) { + sum := NewStdDevSamp(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.2909944487358056, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 1.1086778913041726, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.25, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 1.4142135623730951, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 1.4142135623730951, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 1.4142135623730951, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} + +func TestVariance(t *testing.T) { + sum := NewVarPop(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.25, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 0.9218750000000001, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.171875, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 1.0, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 1.0, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 1.0, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 8b029013e8..bcbdd79ff0 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -668,15 +668,34 @@ func (j *jsonArrayBuffer) Eval(ctx *sql.Context) (interface{}, error) { func (j *jsonArrayBuffer) Dispose() { } +func evalFloat64(ctx *sql.Context, row sql.Row, expr sql.Expression) (any, error) { + v, err := expr.Eval(ctx, row) + if err != nil { + return nil, err + } + v, _, err = types.Float64.Convert(ctx, v) + if err != nil { + v = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + } + return v, nil +} + +func calcOnlineMean(oldMean float64, val float64, count uint64) float64 { + return oldMean + (val - oldMean) / float64(count) +} + +func calcOnlineVar2(oldMean, newMean, oldVar2, val float64) float64 { + return oldVar2 + (val - oldMean) * (val - newMean) +} + type stdDevPopBuffer struct { vals []interface{} expr sql.Expression - count uint64 - oldMean float64 - newMean float64 - oldVar float64 - newVar float64 + count uint64 + mean float64 + std2 float64 } func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { @@ -688,16 +707,10 @@ func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { // Update implements the AggregationBuffer interface. func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := s.expr.Eval(ctx, row) + v, err := evalFloat64(ctx, row, s.expr) if err != nil { return err } - - v, _, err = types.Float64.Convert(ctx, v) - if err != nil { - v = 0.0 - ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) - } if v == nil { return nil } @@ -705,15 +718,13 @@ func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error { s.count += 1 if s.count == 1 { - s.oldMean = val - s.newMean = val + s.mean = val return nil } - s.newMean = s.oldMean + (val-s.oldMean)/float64(s.count) - s.newVar = s.oldVar + (val-s.oldMean)*(val-s.newMean) - s.oldVar = s.newVar - s.oldMean = s.newMean + newMean := calcOnlineMean(s.mean, val, s.count) + s.std2 = calcOnlineVar2(s.mean, newMean, s.std2, val) + s.mean = newMean return nil } @@ -723,22 +734,19 @@ func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { if s.count == 0 { return nil, nil } - return math.Sqrt(s.newVar / float64(s.count)), nil + return math.Sqrt(s.std2 / float64(s.count)), nil } // Dispose implements the Disposable interface. -func (s *stdDevPopBuffer) Dispose() { -} +func (s *stdDevPopBuffer) Dispose() {} type stdDevSampBuffer struct { vals []interface{} expr sql.Expression - count uint64 - oldMean float64 - newMean float64 - oldVar float64 - newVar float64 + count uint64 + mean float64 + std2 float64 } func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { @@ -750,16 +758,10 @@ func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { // Update implements the AggregationBuffer interface. func (s *stdDevSampBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := s.expr.Eval(ctx, row) + v, err := evalFloat64(ctx, row, s.expr) if err != nil { return err } - - v, _, err = types.Float64.Convert(ctx, v) - if err != nil { - v = 0.0 - ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) - } if v == nil { return nil } @@ -767,15 +769,13 @@ func (s *stdDevSampBuffer) Update(ctx *sql.Context, row sql.Row) error { s.count += 1 if s.count == 1 { - s.oldMean = val - s.newMean = val + s.mean = val return nil } - s.newMean = s.oldMean + (val-s.oldMean)/float64(s.count) - s.newVar = s.oldVar + (val-s.oldMean)*(val-s.newMean) - s.oldVar = s.newVar - s.oldMean = s.newMean + newMean := calcOnlineMean(s.mean, val, s.count) + s.std2 = calcOnlineVar2(s.mean, newMean, s.std2, val) + s.mean = newMean return nil } @@ -785,8 +785,59 @@ func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { if s.count <= 1 { return nil, nil } - return math.Sqrt(s.newVar / float64(s.count-1)), nil + return math.Sqrt(s.std2 / float64(s.count-1)), nil } // Dispose implements the Disposable interface. func (s *stdDevSampBuffer) Dispose() {} + +type varPopBuffer struct { + vals []interface{} + expr sql.Expression + + count uint64 + mean float64 + std2 float64 +} + +func NewVarPopBuffer(child sql.Expression) *varPopBuffer { + return &varPopBuffer{ + vals: nil, + expr: child, + } +} + +// Update implements the AggregationBuffer interface. +func (s *varPopBuffer) Update(ctx *sql.Context, row sql.Row) error { + v, err := evalFloat64(ctx, row, s.expr) + if err != nil { + return err + } + if v == nil { + return nil + } + val := v.(float64) + + s.count += 1 + if s.count == 1 { + s.mean = val + return nil + } + + newMean := calcOnlineMean(s.mean, val, s.count) + s.std2 = calcOnlineVar2(s.mean, newMean, s.std2, val) + s.mean = newMean + + return nil +} + +// Eval implements the AggregationBuffer interface. +func (s *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if s.count == 0 { + return nil, nil + } + return s.std2 / float64(s.count), nil +} + +// Dispose implements the Disposable interface. +func (s *varPopBuffer) Dispose() {} diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 4207f348c7..1320639034 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -1067,23 +1067,23 @@ func (a *StdDevSamp) IsNullable() bool { func (a *StdDevSamp) String() string { if a.window != nil { pr := sql.NewTreePrinter() - _ = pr.WriteNode("STDDEV_SAMP") + _ = pr.WriteNode("STDDEVSAMP") children := []string{a.window.String(), a.Child.String()} pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("STDDEV_SAMP(%s)", a.Child) + return fmt.Sprintf("STDDEVSAMP(%s)", a.Child) } func (a *StdDevSamp) DebugString() string { if a.window != nil { pr := sql.NewTreePrinter() - _ = pr.WriteNode("STDDEV_SAMP") + _ = pr.WriteNode("STDDEVSAMP") children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("STDDEV_SAMP(%s)", sql.DebugString(a.Child)) + return fmt.Sprintf("STDDEVSAMP(%s)", sql.DebugString(a.Child)) } func (a *StdDevSamp) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { @@ -1116,3 +1116,82 @@ func (a *StdDevSamp) NewWindowFunction() (sql.WindowFunction, error) { } return NewStdDevSampAgg(child).WithWindow(a.Window()) } + +type VarPop struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*VarPop)(nil) +var _ sql.Aggregation = (*VarPop)(nil) +var _ sql.WindowAdaptableExpression = (*VarPop)(nil) + +func NewVarPop(e sql.Expression) *VarPop { + return &VarPop{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "VarPop", + description: "returns the population variance of expr", + }, + } +} + +func (a *VarPop) Type() sql.Type { + return a.Child.Type() +} + +func (a *VarPop) IsNullable() bool { + return false +} + +func (a *VarPop) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARPOP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARPOP(%s)", a.Child) +} + +func (a *VarPop) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARPOP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARPOP(%s)", sql.DebugString(a.Child)) +} + +func (a *VarPop) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &VarPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarPop) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &VarPop{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *VarPop) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &VarPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarPop) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarPopBuffer(child), nil +} + +func (a *VarPop) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarPopAgg(child).WithWindow(a.Window()) +} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 132be101de..859e5b18d6 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1460,7 +1460,7 @@ func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInter return err } -func computeVariance(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer, expr sql.Expression, m float64) (float64, error) { +func computeStd2(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer, expr sql.Expression, m float64) (float64, error) { var v float64 for i := interval.Start; i < interval.End; i++ { row := buf[i] @@ -1472,7 +1472,7 @@ func computeVariance(ctx *sql.Context, interval sql.WindowInterval, buf sql.Wind val, _, err = types.Float64.Convert(ctx, val) if err != nil { val = 0.0 - ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", val) } if val == nil { continue @@ -1501,12 +1501,12 @@ func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, bu } m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) - v, err := computeVariance(ctx, interval, buf, s.expr, m) + s2, err := computeStd2(ctx, interval, buf, s.expr, m) if err != nil { return err } - return math.Sqrt(v / float64(nonNullCnt)) + return math.Sqrt(s2 / float64(nonNullCnt)) } type StdDevSampAgg struct { @@ -1577,10 +1577,86 @@ func (s *StdDevSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, b } m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) - v, err := computeVariance(ctx, interval, buf, s.expr, m) + s2, err := computeStd2(ctx, interval, buf, s.expr, m) if err != nil { return err } - return math.Sqrt(v / float64(nonNullCnt-1)) + return math.Sqrt(s2 / float64(nonNullCnt - 1)) } + +type VarPopAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewVarPopAgg(e sql.Expression) *VarPopAgg { + return &VarPopAgg{ + expr: e, + } +} + +func (v *VarPopAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *v + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (v *VarPopAgg) Dispose() { + expression.Dispose(v.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (v *VarPopAgg) DefaultFramer() sql.WindowFramer { + if v.framer != nil { + return v.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (v *VarPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + v.Dispose() + v.partitionStart = interval.Start + v.partitionEnd = interval.End + var err error + v.prefixSum, v.nullCnt, err = floatPrefixSum(ctx, interval, buf, v.expr) + return err +} + +func (v *VarPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - v.partitionStart - 1 + endIdx := interval.End - v.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= v.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += v.nullCnt[startIdx] + } + if nonNullCnt <= 0 { + return nil + } + + m := computePrefixSum(interval, v.partitionStart, v.prefixSum) / float64(nonNullCnt) + s2, err := computeStd2(ctx, interval, buf, v.expr, m) + if err != nil { + return err + } + + return s2 / float64(nonNullCnt) +} \ No newline at end of file diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index f527684ab5..41e7b18296 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -316,6 +316,8 @@ var BuiltIns = []sql.Function{ sql.FunctionN{Name: "week", Fn: NewWeek}, sql.Function1{Name: "values", Fn: NewValues}, sql.Function1{Name: "validate_password_strength", Fn: NewValidatePasswordStrength}, + sql.Function1{Name: "variance", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }}, + sql.Function1{Name: "var_pop", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }}, sql.Function2{Name: "vec_distance", Fn: vector.NewL2SquaredDistance}, sql.Function2{Name: "vec_distance_l2_squared", Fn: vector.NewL2SquaredDistance}, sql.Function1{Name: "weekday", Fn: NewWeekday}, diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index fe18ee4e63..c5f0b13d9f 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -505,7 +505,8 @@ func isWindowFunc(name string) bool { "row_number", "percent_rank", "lead", "lag", "first_value", "last_value", "rank", "dense_rank", - "std", "stddev", "stddev_pop", "stddev_samp": + "std", "stddev", "stddev_pop", "stddev_samp", + "variance", "var_pop", "var_samp": return true default: return false From c3f2fab3559e67f04c263b4c2e46d8f5780418fa Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 9 Apr 2025 23:46:14 +0000 Subject: [PATCH 09/19] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/aggregation/unary_agg_buffers.go | 8 ++++---- sql/expression/function/aggregation/window_functions.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index bcbdd79ff0..27a6e4da88 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -682,11 +682,11 @@ func evalFloat64(ctx *sql.Context, row sql.Row, expr sql.Expression) (any, error } func calcOnlineMean(oldMean float64, val float64, count uint64) float64 { - return oldMean + (val - oldMean) / float64(count) + return oldMean + (val-oldMean)/float64(count) } func calcOnlineVar2(oldMean, newMean, oldVar2, val float64) float64 { - return oldVar2 + (val - oldMean) * (val - newMean) + return oldVar2 + (val-oldMean)*(val-newMean) } type stdDevPopBuffer struct { @@ -796,8 +796,8 @@ type varPopBuffer struct { expr sql.Expression count uint64 - mean float64 - std2 float64 + mean float64 + std2 float64 } func NewVarPopBuffer(child sql.Expression) *varPopBuffer { diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 859e5b18d6..cf21f38dad 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1582,7 +1582,7 @@ func (s *StdDevSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, b return err } - return math.Sqrt(s2 / float64(nonNullCnt - 1)) + return math.Sqrt(s2 / float64(nonNullCnt-1)) } type VarPopAgg struct { @@ -1659,4 +1659,4 @@ func (v *VarPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf s } return s2 / float64(nonNullCnt) -} \ No newline at end of file +} From e83585af2ee2b843194171ef518d2b29b110b4d0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 17:22:08 -0700 Subject: [PATCH 10/19] implement var samp --- enginetest/queries/script_queries.go | 86 +++++++++---------- optgen/cmd/source/unary_aggs.yaml | 4 +- .../function/aggregation/std_test.go | 72 ++++++++++++++++ .../function/aggregation/unary_agg_buffers.go | 75 +++++++++++++--- .../function/aggregation/unary_aggs.og.go | 79 +++++++++++++++++ .../function/aggregation/window_functions.go | 76 ++++++++++++++++ sql/expression/function/registry.go | 1 + 7 files changed, 337 insertions(+), 56 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index d9c3cd0e6b..34f91df0eb 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7904,9 +7904,9 @@ where }, }, { - Query: "select variance(i), var_pop(i) from t;", + Query: "select variance(i), var_pop(i), var_samp(i) from t;", Expected: []sql.Row{ - {nil, nil}, + {nil, nil, nil}, }, }, { @@ -7922,9 +7922,9 @@ where }, }, { - Query: "select variance(i), var_pop(i) from t;", + Query: "select variance(i), var_pop(i), var_samp(i) from t;", Expected: []sql.Row{ - {0.0, 0.0}, + {0.0, 0.0, nil}, }, }, { @@ -7940,9 +7940,9 @@ where }, }, { - Query: "select variance(i), var_pop(i) from t;", + Query: "select variance(i), var_pop(i), var_samp(i) from t;", Expected: []sql.Row{ - {0.25, 0.25}, + {0.25, 0.25, 0.5}, }, }, { @@ -7958,9 +7958,9 @@ where }, }, { - Query: "select variance(i), var_pop(i) from t;", + Query: "select variance(i), var_pop(i), var_samp(i) from t;", Expected: []sql.Row{ - {0.6666666666666666, 0.6666666666666666}, + {0.6666666666666666, 0.6666666666666666, 1.0}, }, }, { @@ -7976,9 +7976,9 @@ where }, }, { - Query: "select variance(i), var_pop(i) from t;", + Query: "select variance(i), var_pop(i), var_samp(i) from t;", Expected: []sql.Row{ - {0.6666666666666666, 0.6666666666666666}, + {0.6666666666666666, 0.6666666666666666, 1.0}, }, }, { @@ -7989,10 +7989,10 @@ where }, }, { - Query: "select i, variance(i) from tt group by i;", + Query: "select i, variance(i), var_samp(i) from tt group by i;", Expected: []sql.Row{ - {0, 0.0}, - {1, 0.0}, + {0, 0.0, 0.0}, + {1, 0.0, 0.0}, }, }, { @@ -8018,25 +8018,25 @@ where }, }, { - Query: "select i, variance(i) over() from tt order by i;", + Query: "select i, variance(i) over(), var_samp(i) over() from tt order by i;", Expected: []sql.Row{ - {0, 0.25}, - {0, 0.25}, - {0, 0.25}, - {1, 0.25}, - {1, 0.25}, - {1, 0.25}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, }, }, { - Query: "select i, variance(j) over(partition by i) from tt order by i;", + Query: "select i, variance(j) over(partition by i), var_samp(i) over(partition by i) from tt order by i;", Expected: []sql.Row{ - {0, 0.6666666666666666}, - {0, 0.6666666666666666}, - {0, 0.6666666666666666}, - {1, 73926.0}, - {1, 73926.0}, - {1, 73926.0}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, }, }, { @@ -8070,27 +8070,27 @@ where }, }, { - Query: "select i, variance(i) over() from tt order by i;", + Query: "select i, variance(i) over(), var_samp(i) over() from tt order by i;", Expected: []sql.Row{ - {nil, 0.25}, - {0, 0.25}, - {0, 0.25}, - {0, 0.25}, - {1, 0.25}, - {1, 0.25}, - {1, 0.25}, + {nil, 0.25, 0.3}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, }, }, { - Query: "select i, variance(j) over(partition by i) from tt order by i;", + Query: "select i, variance(j) over(partition by i), var_samp(i) over(partition by i) from tt order by i;", Expected: []sql.Row{ - {nil, nil}, - {0, 0.6666666666666666}, - {0, 0.6666666666666666}, - {0, 0.6666666666666666}, - {1, 73926.0}, - {1, 73926.0}, - {1, 73926.0}, + {nil, nil, nil}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, }, }, }, diff --git a/optgen/cmd/source/unary_aggs.yaml b/optgen/cmd/source/unary_aggs.yaml index ef3d317482..7e40fe762f 100644 --- a/optgen/cmd/source/unary_aggs.yaml +++ b/optgen/cmd/source/unary_aggs.yaml @@ -37,4 +37,6 @@ unaryAggs: - name: "StdDevSamp" desc: "returns the sample standard deviation of expr" - name: "VarPop" - desc: "returns the population variance of expr" \ No newline at end of file + desc: "returns the population variance of expr" +- name: "VarSamp" + desc: "returns the sample variance of expr" \ No newline at end of file diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go index 9e20c79363..32892cf619 100644 --- a/sql/expression/function/aggregation/std_test.go +++ b/sql/expression/function/aggregation/std_test.go @@ -238,3 +238,75 @@ func TestVariance(t *testing.T) { }) } } + +func TestVarSamp(t *testing.T) { + sum := NewVarSamp(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.6666666666666667, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 1.2291666666666667, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.5625, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 2.0, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 2.0, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 2.0, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index bcbdd79ff0..b990540405 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -808,8 +808,8 @@ func NewVarPopBuffer(child sql.Expression) *varPopBuffer { } // Update implements the AggregationBuffer interface. -func (s *varPopBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := evalFloat64(ctx, row, s.expr) +func (vp *varPopBuffer) Update(ctx *sql.Context, row sql.Row) error { + v, err := evalFloat64(ctx, row, vp.expr) if err != nil { return err } @@ -818,26 +818,77 @@ func (s *varPopBuffer) Update(ctx *sql.Context, row sql.Row) error { } val := v.(float64) - s.count += 1 - if s.count == 1 { - s.mean = val + vp.count += 1 + if vp.count == 1 { + vp.mean = val return nil } - newMean := calcOnlineMean(s.mean, val, s.count) - s.std2 = calcOnlineVar2(s.mean, newMean, s.std2, val) - s.mean = newMean + newMean := calcOnlineMean(vp.mean, val, vp.count) + vp.std2 = calcOnlineVar2(vp.mean, newMean, vp.std2, val) + vp.mean = newMean return nil } // Eval implements the AggregationBuffer interface. -func (s *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { - if s.count == 0 { +func (vp *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if vp.count == 0 { + return nil, nil + } + return vp.std2 / float64(vp.count), nil +} + +// Dispose implements the Disposable interface. +func (vp *varPopBuffer) Dispose() {} + +type varSampBuffer struct { + vals []interface{} + expr sql.Expression + + count uint64 + mean float64 + std2 float64 +} + +func NewVarSampBuffer(child sql.Expression) *varSampBuffer { + return &varSampBuffer{ + vals: nil, + expr: child, + } +} + +// Update implements the AggregationBuffer interface. +func (vp *varSampBuffer) Update(ctx *sql.Context, row sql.Row) error { + v, err := evalFloat64(ctx, row, vp.expr) + if err != nil { + return err + } + if v == nil { + return nil + } + val := v.(float64) + + vp.count += 1 + if vp.count == 1 { + vp.mean = val + return nil + } + + newMean := calcOnlineMean(vp.mean, val, vp.count) + vp.std2 = calcOnlineVar2(vp.mean, newMean, vp.std2, val) + vp.mean = newMean + + return nil +} + +// Eval implements the AggregationBuffer interface. +func (vp *varSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if vp.count <= 1 { return nil, nil } - return s.std2 / float64(s.count), nil + return vp.std2 / float64(vp.count - 1), nil } // Dispose implements the Disposable interface. -func (s *varPopBuffer) Dispose() {} +func (vp *varSampBuffer) Dispose() {} diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 1320639034..a5094cc975 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -1195,3 +1195,82 @@ func (a *VarPop) NewWindowFunction() (sql.WindowFunction, error) { } return NewVarPopAgg(child).WithWindow(a.Window()) } + +type VarSamp struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*VarSamp)(nil) +var _ sql.Aggregation = (*VarSamp)(nil) +var _ sql.WindowAdaptableExpression = (*VarSamp)(nil) + +func NewVarSamp(e sql.Expression) *VarSamp { + return &VarSamp{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "VarSamp", + description: "returns the sample variance of expr", + }, + } +} + +func (a *VarSamp) Type() sql.Type { + return a.Child.Type() +} + +func (a *VarSamp) IsNullable() bool { + return false +} + +func (a *VarSamp) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARSAMP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARSAMP(%s)", a.Child) +} + +func (a *VarSamp) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARSAMP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARSAMP(%s)", sql.DebugString(a.Child)) +} + +func (a *VarSamp) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &VarSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarSamp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &VarSamp{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *VarSamp) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &VarSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarSamp) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarSampBuffer(child), nil +} + +func (a *VarSamp) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarSampAgg(child).WithWindow(a.Window()) +} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 859e5b18d6..5b577e922d 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1659,4 +1659,80 @@ func (v *VarPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf s } return s2 / float64(nonNullCnt) +} + +type VarSampAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewVarSampAgg(e sql.Expression) *VarSampAgg { + return &VarSampAgg{ + expr: e, + } +} + +func (v *VarSampAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *v + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (v *VarSampAgg) Dispose() { + expression.Dispose(v.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (v *VarSampAgg) DefaultFramer() sql.WindowFramer { + if v.framer != nil { + return v.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (v *VarSampAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + v.Dispose() + v.partitionStart = interval.Start + v.partitionEnd = interval.End + var err error + v.prefixSum, v.nullCnt, err = floatPrefixSum(ctx, interval, buf, v.expr) + return err +} + +func (v *VarSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - v.partitionStart - 1 + endIdx := interval.End - v.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= v.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += v.nullCnt[startIdx] + } + if nonNullCnt <= 1 { + return nil + } + + m := computePrefixSum(interval, v.partitionStart, v.prefixSum) / float64(nonNullCnt) + s2, err := computeStd2(ctx, interval, buf, v.expr, m) + if err != nil { + return err + } + + return s2 / float64(nonNullCnt - 1) } \ No newline at end of file diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 41e7b18296..478d47e5ff 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -318,6 +318,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "validate_password_strength", Fn: NewValidatePasswordStrength}, sql.Function1{Name: "variance", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }}, sql.Function1{Name: "var_pop", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }}, + sql.Function1{Name: "var_samp", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarSamp(e) }}, sql.Function2{Name: "vec_distance", Fn: vector.NewL2SquaredDistance}, sql.Function2{Name: "vec_distance_l2_squared", Fn: vector.NewL2SquaredDistance}, sql.Function1{Name: "weekday", Fn: NewWeekday}, From 5fe38672cdfb345fbe2dd9eddda39201685eda8d Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 17:25:23 -0700 Subject: [PATCH 11/19] bump --- go.mod | 2 +- go.sum | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index a68cf4f068..2cc7a7c49b 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250409183615-d8335325e91c + github.com/dolthub/vitess v0.0.0-20250410002136-c7dbb492484f github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 0ae08658e6..274d0e96d4 100644 --- a/go.sum +++ b/go.sum @@ -58,10 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3 h1:euU+adNAYw46Zcp1HnoaSDWhqjfaL8s/1SPU+i16gYM= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250409183615-d8335325e91c h1:hl+yPanHdJML9aMB0MgrTCpzsd3jIf/o3r8pC6Tqx6E= -github.com/dolthub/vitess v0.0.0-20250409183615-d8335325e91c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250410002136-c7dbb492484f h1:0Qx6wljg/gF5NieSKOE4uJH1Ff0e87xYZDvHTpQFlg8= +github.com/dolthub/vitess v0.0.0-20250410002136-c7dbb492484f/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From ce10c83086a9f65bd33ff1e98b0daaa22d3e12ff Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 10 Apr 2025 00:26:54 +0000 Subject: [PATCH 12/19] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/aggregation/unary_agg_buffers.go | 6 +++--- sql/expression/function/aggregation/window_functions.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 0c1b8f9836..fa34d58e59 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -847,8 +847,8 @@ type varSampBuffer struct { expr sql.Expression count uint64 - mean float64 - std2 float64 + mean float64 + std2 float64 } func NewVarSampBuffer(child sql.Expression) *varSampBuffer { @@ -887,7 +887,7 @@ func (vp *varSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { if vp.count <= 1 { return nil, nil } - return vp.std2 / float64(vp.count - 1), nil + return vp.std2 / float64(vp.count-1), nil } // Dispose implements the Disposable interface. diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 0a19224b39..f591070f6b 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1734,5 +1734,5 @@ func (v *VarSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf return err } - return s2 / float64(nonNullCnt - 1) -} \ No newline at end of file + return s2 / float64(nonNullCnt-1) +} From 2a6f320f7953d4318a2e074b956a353de9d662f2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Apr 2025 23:08:08 -0700 Subject: [PATCH 13/19] test with epsilon --- .../function/aggregation/std_test.go | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go index 32892cf619..7c8974b255 100644 --- a/sql/expression/function/aggregation/std_test.go +++ b/sql/expression/function/aggregation/std_test.go @@ -15,6 +15,7 @@ package aggregation import ( + "math" "testing" "github.com/stretchr/testify/require" @@ -23,6 +24,10 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" ) +func isFloatEqual(a, b float64) bool { + return math.Abs(a - b) < 1e-9 +} + func TestStd(t *testing.T) { sum := NewStdDevPop(expression.NewGetField(0, nil, "", false)) @@ -90,7 +95,11 @@ func TestStd(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - require.Equal(tt.expected, result) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) }) } } @@ -162,7 +171,11 @@ func TestStdSamp(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - require.Equal(tt.expected, result) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) }) } } @@ -234,7 +247,11 @@ func TestVariance(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - require.Equal(tt.expected, result) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) }) } } @@ -306,7 +323,11 @@ func TestVarSamp(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - require.Equal(tt.expected, result) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) }) } } From 8d49dca25d165e6e0efe93f9e4db74b769bbe81a Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 10 Apr 2025 06:09:29 +0000 Subject: [PATCH 14/19] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/aggregation/std_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go index 7c8974b255..9f694bfb4e 100644 --- a/sql/expression/function/aggregation/std_test.go +++ b/sql/expression/function/aggregation/std_test.go @@ -25,7 +25,7 @@ import ( ) func isFloatEqual(a, b float64) bool { - return math.Abs(a - b) < 1e-9 + return math.Abs(a-b) < 1e-9 } func TestStd(t *testing.T) { From 2c58e08ea45147dc839ee30914bc1de5115bde47 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 10 Apr 2025 00:56:55 -0700 Subject: [PATCH 15/19] adding window tests --- enginetest/queries/script_queries.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 34f91df0eb..e169d0901c 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -8093,6 +8093,18 @@ where {1, 73926.0, 0.0}, }, }, + { + Query: "select i, stddev_pop(j) over w, stddev_samp(j) over w, variance(j) over w, var_samp(i) over w from tt window w as (partition by i) order by i;", + Expected: []sql.Row{ + {nil, nil, nil, nil, nil}, + {0, 0.816496580927726, 1.0, 0.6666666666666666, 0.0}, + {0, 0.816496580927726, 1.0, 0.6666666666666666, 0.0}, + {0, 0.816496580927726, 1.0, 0.6666666666666666, 0.0}, + {1, 271.89336144893275, 333.0, 73926.0, 0.0}, + {1, 271.89336144893275, 333.0, 73926.0, 0.0}, + {1, 271.89336144893275, 333.0, 73926.0, 0.0}, + }, + }, }, }, } From fd9a1c14a411f4a53a5967d04783f3b2faf214e7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 10 Apr 2025 02:04:09 -0700 Subject: [PATCH 16/19] bump --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2cc7a7c49b..aed4846e2d 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250410002136-c7dbb492484f + github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 274d0e96d4..391e1a6454 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250410002136-c7dbb492484f h1:0Qx6wljg/gF5NieSKOE4uJH1Ff0e87xYZDvHTpQFlg8= -github.com/dolthub/vitess v0.0.0-20250410002136-c7dbb492484f/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4 h1:LGTt2LtYX8vaai32d+c9L0sMcP+Dg9w1kO6+lbsxxYg= +github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From 95436da472624965d2a27c94b119ce33cb766646 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 10 Apr 2025 10:46:09 -0700 Subject: [PATCH 17/19] consolidate code --- .../function/aggregation/unary_agg_buffers.go | 186 ++++-------------- 1 file changed, 41 insertions(+), 145 deletions(-) diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index fa34d58e59..2d43025e5c 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -668,28 +668,7 @@ func (j *jsonArrayBuffer) Eval(ctx *sql.Context) (interface{}, error) { func (j *jsonArrayBuffer) Dispose() { } -func evalFloat64(ctx *sql.Context, row sql.Row, expr sql.Expression) (any, error) { - v, err := expr.Eval(ctx, row) - if err != nil { - return nil, err - } - v, _, err = types.Float64.Convert(ctx, v) - if err != nil { - v = 0.0 - ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) - } - return v, nil -} - -func calcOnlineMean(oldMean float64, val float64, count uint64) float64 { - return oldMean + (val-oldMean)/float64(count) -} - -func calcOnlineVar2(oldMean, newMean, oldVar2, val float64) float64 { - return oldVar2 + (val-oldMean)*(val-newMean) -} - -type stdDevPopBuffer struct { +type varAggBase struct { vals []interface{} expr sql.Expression @@ -698,37 +677,50 @@ type stdDevPopBuffer struct { std2 float64 } -func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { - return &stdDevPopBuffer{ - vals: nil, - expr: child, - } -} - // Update implements the AggregationBuffer interface. -func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := evalFloat64(ctx, row, s.expr) +func (vb *varAggBase) Update(ctx *sql.Context, row sql.Row) error { + v, err := vb.expr.Eval(ctx, row) if err != nil { return err } + v, _, err = types.Float64.Convert(ctx, v) + if err != nil { + v = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + } if v == nil { return nil } val := v.(float64) - s.count += 1 - if s.count == 1 { - s.mean = val + vb.count += 1 + if vb.count == 1 { + vb.mean = val return nil } - newMean := calcOnlineMean(s.mean, val, s.count) - s.std2 = calcOnlineVar2(s.mean, newMean, s.std2, val) - s.mean = newMean + newMean := vb.mean + (val - vb.mean) / float64(vb.count) + vb.std2 = vb.std2 + (val - vb.mean) * (val - newMean) + vb.mean = newMean return nil } +// Dispose implements the Disposable interface. +func (vb *varAggBase) Dispose() {} + +type stdDevPopBuffer struct { + varAggBase +} + +func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { + return &stdDevPopBuffer{ + varAggBase: varAggBase { + expr: child, + }, + } +} + // Eval implements the AggregationBuffer interface. func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { if s.count == 0 { @@ -737,49 +729,18 @@ func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { return math.Sqrt(s.std2 / float64(s.count)), nil } -// Dispose implements the Disposable interface. -func (s *stdDevPopBuffer) Dispose() {} - type stdDevSampBuffer struct { - vals []interface{} - expr sql.Expression - - count uint64 - mean float64 - std2 float64 + varAggBase } func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { return &stdDevSampBuffer{ - vals: nil, - expr: child, + varAggBase: varAggBase { + expr: child, + }, } } -// Update implements the AggregationBuffer interface. -func (s *stdDevSampBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := evalFloat64(ctx, row, s.expr) - if err != nil { - return err - } - if v == nil { - return nil - } - val := v.(float64) - - s.count += 1 - if s.count == 1 { - s.mean = val - return nil - } - - newMean := calcOnlineMean(s.mean, val, s.count) - s.std2 = calcOnlineVar2(s.mean, newMean, s.std2, val) - s.mean = newMean - - return nil -} - // Eval implements the AggregationBuffer interface. func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { if s.count <= 1 { @@ -788,49 +749,18 @@ func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { return math.Sqrt(s.std2 / float64(s.count-1)), nil } -// Dispose implements the Disposable interface. -func (s *stdDevSampBuffer) Dispose() {} - type varPopBuffer struct { - vals []interface{} - expr sql.Expression - - count uint64 - mean float64 - std2 float64 + varAggBase } func NewVarPopBuffer(child sql.Expression) *varPopBuffer { return &varPopBuffer{ - vals: nil, - expr: child, + varAggBase: varAggBase { + expr: child, + }, } } -// Update implements the AggregationBuffer interface. -func (vp *varPopBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := evalFloat64(ctx, row, vp.expr) - if err != nil { - return err - } - if v == nil { - return nil - } - val := v.(float64) - - vp.count += 1 - if vp.count == 1 { - vp.mean = val - return nil - } - - newMean := calcOnlineMean(vp.mean, val, vp.count) - vp.std2 = calcOnlineVar2(vp.mean, newMean, vp.std2, val) - vp.mean = newMean - - return nil -} - // Eval implements the AggregationBuffer interface. func (vp *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { if vp.count == 0 { @@ -839,49 +769,18 @@ func (vp *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { return vp.std2 / float64(vp.count), nil } -// Dispose implements the Disposable interface. -func (vp *varPopBuffer) Dispose() {} - type varSampBuffer struct { - vals []interface{} - expr sql.Expression - - count uint64 - mean float64 - std2 float64 + varAggBase } func NewVarSampBuffer(child sql.Expression) *varSampBuffer { return &varSampBuffer{ - vals: nil, - expr: child, + varAggBase: varAggBase{ + expr: child, + }, } } -// Update implements the AggregationBuffer interface. -func (vp *varSampBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := evalFloat64(ctx, row, vp.expr) - if err != nil { - return err - } - if v == nil { - return nil - } - val := v.(float64) - - vp.count += 1 - if vp.count == 1 { - vp.mean = val - return nil - } - - newMean := calcOnlineMean(vp.mean, val, vp.count) - vp.std2 = calcOnlineVar2(vp.mean, newMean, vp.std2, val) - vp.mean = newMean - - return nil -} - // Eval implements the AggregationBuffer interface. func (vp *varSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { if vp.count <= 1 { @@ -889,6 +788,3 @@ func (vp *varSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { } return vp.std2 / float64(vp.count-1), nil } - -// Dispose implements the Disposable interface. -func (vp *varSampBuffer) Dispose() {} From 925cc388aea21851f3da56682fbf3c5dd618d310 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 10 Apr 2025 10:54:21 -0700 Subject: [PATCH 18/19] clean up --- .../function/aggregation/unary_agg_buffers.go | 22 +++++++++---------- .../function/aggregation/window_functions.go | 1 - 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 2d43025e5c..451c916bf0 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -668,7 +668,7 @@ func (j *jsonArrayBuffer) Eval(ctx *sql.Context) (interface{}, error) { func (j *jsonArrayBuffer) Dispose() { } -type varAggBase struct { +type varBaseBuffer struct { vals []interface{} expr sql.Expression @@ -678,7 +678,7 @@ type varAggBase struct { } // Update implements the AggregationBuffer interface. -func (vb *varAggBase) Update(ctx *sql.Context, row sql.Row) error { +func (vb *varBaseBuffer) Update(ctx *sql.Context, row sql.Row) error { v, err := vb.expr.Eval(ctx, row) if err != nil { return err @@ -707,15 +707,15 @@ func (vb *varAggBase) Update(ctx *sql.Context, row sql.Row) error { } // Dispose implements the Disposable interface. -func (vb *varAggBase) Dispose() {} +func (vb *varBaseBuffer) Dispose() {} type stdDevPopBuffer struct { - varAggBase + varBaseBuffer } func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { return &stdDevPopBuffer{ - varAggBase: varAggBase { + varBaseBuffer: varBaseBuffer { expr: child, }, } @@ -730,12 +730,12 @@ func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { } type stdDevSampBuffer struct { - varAggBase + varBaseBuffer } func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { return &stdDevSampBuffer{ - varAggBase: varAggBase { + varBaseBuffer: varBaseBuffer { expr: child, }, } @@ -750,12 +750,12 @@ func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { } type varPopBuffer struct { - varAggBase + varBaseBuffer } func NewVarPopBuffer(child sql.Expression) *varPopBuffer { return &varPopBuffer{ - varAggBase: varAggBase { + varBaseBuffer: varBaseBuffer { expr: child, }, } @@ -770,12 +770,12 @@ func (vp *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { } type varSampBuffer struct { - varAggBase + varBaseBuffer } func NewVarSampBuffer(child sql.Expression) *varSampBuffer { return &varSampBuffer{ - varAggBase: varAggBase{ + varBaseBuffer: varBaseBuffer{ expr: child, }, } diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index f591070f6b..3265781d75 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1468,7 +1468,6 @@ func computeStd2(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBu if err != nil { return 0, err } - // TODO: consider saving conversions to avoid double Converts val, _, err = types.Float64.Convert(ctx, val) if err != nil { val = 0.0 From 6614a2a9ef5e96ad098f28625701bf8768a7537d Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 10 Apr 2025 17:55:44 +0000 Subject: [PATCH 19/19] [ga-format-pr] Run ./format_repo.sh to fix formatting --- .../function/aggregation/unary_agg_buffers.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 451c916bf0..6ca26669cc 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -699,8 +699,8 @@ func (vb *varBaseBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } - newMean := vb.mean + (val - vb.mean) / float64(vb.count) - vb.std2 = vb.std2 + (val - vb.mean) * (val - newMean) + newMean := vb.mean + (val-vb.mean)/float64(vb.count) + vb.std2 = vb.std2 + (val-vb.mean)*(val-newMean) vb.mean = newMean return nil @@ -715,7 +715,7 @@ type stdDevPopBuffer struct { func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { return &stdDevPopBuffer{ - varBaseBuffer: varBaseBuffer { + varBaseBuffer: varBaseBuffer{ expr: child, }, } @@ -735,7 +735,7 @@ type stdDevSampBuffer struct { func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { return &stdDevSampBuffer{ - varBaseBuffer: varBaseBuffer { + varBaseBuffer: varBaseBuffer{ expr: child, }, } @@ -755,7 +755,7 @@ type varPopBuffer struct { func NewVarPopBuffer(child sql.Expression) *varPopBuffer { return &varPopBuffer{ - varBaseBuffer: varBaseBuffer { + varBaseBuffer: varBaseBuffer{ expr: child, }, }