diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index e8debaa83b..e169d0901c 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7887,6 +7887,226 @@ where }, }, }, + { + Name: "std, stdev, stddev_pop, variance, var_pop, var_samp tests", + Dialect: "mysql", + SetUpScript: []string{ + "create table t (i int);", + "create table tt (i int, j int);", + "insert into tt values (0, 1), (0, 2), (0, 3);", + "insert into tt values (1, 123), (1, 456), (1, 789);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", + Expected: []sql.Row{ + {nil, nil, nil, nil}, + }, + }, + { + Query: "select variance(i), var_pop(i), var_samp(i) from t;", + Expected: []sql.Row{ + {nil, nil, nil}, + }, + }, + { + Query: "insert into t values (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", + Expected: []sql.Row{ + {0.0, 0.0, 0.0, nil}, + }, + }, + { + Query: "select variance(i), var_pop(i), var_samp(i) from t;", + Expected: []sql.Row{ + {0.0, 0.0, nil}, + }, + }, + { + Query: "insert into t values (2);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", + Expected: []sql.Row{ + {0.5, 0.5, 0.5, 0.7071067811865476}, + }, + }, + { + Query: "select variance(i), var_pop(i), var_samp(i) from t;", + Expected: []sql.Row{ + {0.25, 0.25, 0.5}, + }, + }, + { + Query: "insert into t values (3);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", + Expected: []sql.Row{ + {0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0}, + }, + }, + { + Query: "select variance(i), var_pop(i), var_samp(i) from t;", + Expected: []sql.Row{ + {0.6666666666666666, 0.6666666666666666, 1.0}, + }, + }, + { + Query: "insert into t values (null), (null);", + Expected: []sql.Row{ + {types.NewOkResult(2)}, + }, + }, + { + Query: "select std(i), stddev(i), stddev_pop(i), stddev_samp(i) from t;", + Expected: []sql.Row{ + {0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0}, + }, + }, + { + Query: "select variance(i), var_pop(i), var_samp(i) from t;", + Expected: []sql.Row{ + {0.6666666666666666, 0.6666666666666666, 1.0}, + }, + }, + { + Query: "select i, std(j), stddev_samp(j) from tt group by i;", + Expected: []sql.Row{ + {0, 0.816496580927726, 1.0}, + {1, 271.89336144893275, 333.0}, + }, + }, + { + Query: "select i, variance(i), var_samp(i) from tt group by i;", + Expected: []sql.Row{ + {0, 0.0, 0.0}, + {1, 0.0, 0.0}, + }, + }, + { + Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", + Expected: []sql.Row{ + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + {0.5, 297.47660972475353, 325.86929895281634}, + }, + }, + { + Query: "select i, std(j) over(partition by i), stddev_samp(j) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, + }, + }, + { + Query: "select i, variance(i) over(), var_samp(i) over() from tt order by i;", + Expected: []sql.Row{ + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, + }, + }, + { + Query: "select i, variance(j) over(partition by i), var_samp(i) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, + }, + }, + { + Query: "insert into tt values (null, null);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select i, std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;", + Expected: []sql.Row{ + {nil, 0.5, 297.47660972475353, 325.86929895281634}, + {0, 0.5, 297.47660972475353, 325.86929895281634}, + {0, 0.5, 297.47660972475353, 325.86929895281634}, + {0, 0.5, 297.47660972475353, 325.86929895281634}, + {1, 0.5, 297.47660972475353, 325.86929895281634}, + {1, 0.5, 297.47660972475353, 325.86929895281634}, + {1, 0.5, 297.47660972475353, 325.86929895281634}, + }, + }, + { + Query: "select i, std(j) over(partition by i), stddev_samp(j) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {nil, nil, nil}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {0, 0.816496580927726, 1.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, + {1, 271.89336144893275, 333.0}, + }, + }, + { + Query: "select i, variance(i) over(), var_samp(i) over() from tt order by i;", + Expected: []sql.Row{ + {nil, 0.25, 0.3}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {0, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, + {1, 0.25, 0.3}, + }, + }, + { + Query: "select i, variance(j) over(partition by i), var_samp(i) over(partition by i) from tt order by i;", + Expected: []sql.Row{ + {nil, nil, nil}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {0, 0.6666666666666666, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, + {1, 73926.0, 0.0}, + }, + }, + { + Query: "select i, stddev_pop(j) over w, stddev_samp(j) over w, variance(j) over w, var_samp(i) over w from tt window w as (partition by i) order by i;", + Expected: []sql.Row{ + {nil, nil, nil, nil, nil}, + {0, 0.816496580927726, 1.0, 0.6666666666666666, 0.0}, + {0, 0.816496580927726, 1.0, 0.6666666666666666, 0.0}, + {0, 0.816496580927726, 1.0, 0.6666666666666666, 0.0}, + {1, 271.89336144893275, 333.0, 73926.0, 0.0}, + {1, 271.89336144893275, 333.0, 73926.0, 0.0}, + {1, 271.89336144893275, 333.0, 73926.0, 0.0}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/go.mod b/go.mod index a68cf4f068..aed4846e2d 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250409183615-d8335325e91c + github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 0ae08658e6..391e1a6454 100644 --- a/go.sum +++ b/go.sum @@ -58,10 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3 h1:euU+adNAYw46Zcp1HnoaSDWhqjfaL8s/1SPU+i16gYM= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250409183615-d8335325e91c h1:hl+yPanHdJML9aMB0MgrTCpzsd3jIf/o3r8pC6Tqx6E= -github.com/dolthub/vitess v0.0.0-20250409183615-d8335325e91c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4 h1:LGTt2LtYX8vaai32d+c9L0sMcP+Dg9w1kO6+lbsxxYg= +github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/optgen/cmd/source/unary_aggs.yaml b/optgen/cmd/source/unary_aggs.yaml index 6f58490e12..7e40fe762f 100644 --- a/optgen/cmd/source/unary_aggs.yaml +++ b/optgen/cmd/source/unary_aggs.yaml @@ -31,4 +31,12 @@ unaryAggs: desc: "returns the minimum value of expr in all rows." - name: "Sum" desc: "returns the sum of expr in all rows" - nullable: false \ No newline at end of file + nullable: false +- name: "StdDevPop" + desc: "returns the population standard deviation of expr" +- name: "StdDevSamp" + desc: "returns the sample standard deviation of expr" +- name: "VarPop" + desc: "returns the population variance of expr" +- name: "VarSamp" + desc: "returns the sample variance of expr" \ No newline at end of file diff --git a/sql/expression/function/aggregation/common.go b/sql/expression/function/aggregation/common.go index 6637685dfe..ee738fd261 100644 --- a/sql/expression/function/aggregation/common.go +++ b/sql/expression/function/aggregation/common.go @@ -125,11 +125,11 @@ func (a *unaryAggBase) WithChildren(children ...sql.Expression) (sql.Expression, return &na, nil } -func (a unaryAggBase) FunctionName() string { +func (a *unaryAggBase) FunctionName() string { return a.functionName } -func (a unaryAggBase) Description() string { +func (a *unaryAggBase) Description() string { return a.description } diff --git a/sql/expression/function/aggregation/std_test.go b/sql/expression/function/aggregation/std_test.go new file mode 100644 index 0000000000..9f694bfb4e --- /dev/null +++ b/sql/expression/function/aggregation/std_test.go @@ -0,0 +1,333 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregation + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" +) + +func isFloatEqual(a, b float64) bool { + return math.Abs(a-b) < 1e-9 +} + +func TestStd(t *testing.T) { + sum := NewStdDevPop(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.118033988749895, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 0.9601432184835761, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.0825317547305484, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 1.0, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 1.0, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 1.0, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) + }) + } +} + +func TestStdSamp(t *testing.T) { + sum := NewStdDevSamp(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.2909944487358056, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 1.1086778913041726, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.25, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 1.4142135623730951, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 1.4142135623730951, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 1.4142135623730951, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) + }) + } +} + +func TestVariance(t *testing.T) { + sum := NewVarPop(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.25, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 0.9218750000000001, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.171875, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 1.0, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 1.0, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 1.0, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) + }) + } +} + +func TestVarSamp(t *testing.T) { + sum := NewVarSamp(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + 1.6666666666666667, + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + 1.2291666666666667, + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + 1.5625, + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + 2.0, + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + 2.0, + }, + { + "int32 and nil values", + []sql.Row{{int32(1)}, {int32(3)}, {nil}}, + 2.0, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + buf, _ := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(buf.Update(ctx, row)) + } + + result, err := buf.Eval(sql.NewEmptyContext()) + require.NoError(err) + if tt.expected == nil { + require.Equal(tt.expected, nil) + return + } + require.True(isFloatEqual(tt.expected.(float64), result.(float64))) + }) + } +} diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index c06037530e..6ca26669cc 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -2,6 +2,7 @@ package aggregation import ( "fmt" + "math" "reflect" "github.com/cespare/xxhash/v2" @@ -666,3 +667,124 @@ func (j *jsonArrayBuffer) Eval(ctx *sql.Context) (interface{}, error) { // Dispose implements the Disposable interface. func (j *jsonArrayBuffer) Dispose() { } + +type varBaseBuffer struct { + vals []interface{} + expr sql.Expression + + count uint64 + mean float64 + std2 float64 +} + +// Update implements the AggregationBuffer interface. +func (vb *varBaseBuffer) Update(ctx *sql.Context, row sql.Row) error { + v, err := vb.expr.Eval(ctx, row) + if err != nil { + return err + } + v, _, err = types.Float64.Convert(ctx, v) + if err != nil { + v = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v) + } + if v == nil { + return nil + } + val := v.(float64) + + vb.count += 1 + if vb.count == 1 { + vb.mean = val + return nil + } + + newMean := vb.mean + (val-vb.mean)/float64(vb.count) + vb.std2 = vb.std2 + (val-vb.mean)*(val-newMean) + vb.mean = newMean + + return nil +} + +// Dispose implements the Disposable interface. +func (vb *varBaseBuffer) Dispose() {} + +type stdDevPopBuffer struct { + varBaseBuffer +} + +func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer { + return &stdDevPopBuffer{ + varBaseBuffer: varBaseBuffer{ + expr: child, + }, + } +} + +// Eval implements the AggregationBuffer interface. +func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if s.count == 0 { + return nil, nil + } + return math.Sqrt(s.std2 / float64(s.count)), nil +} + +type stdDevSampBuffer struct { + varBaseBuffer +} + +func NewStdDevSampBuffer(child sql.Expression) *stdDevSampBuffer { + return &stdDevSampBuffer{ + varBaseBuffer: varBaseBuffer{ + expr: child, + }, + } +} + +// Eval implements the AggregationBuffer interface. +func (s *stdDevSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if s.count <= 1 { + return nil, nil + } + return math.Sqrt(s.std2 / float64(s.count-1)), nil +} + +type varPopBuffer struct { + varBaseBuffer +} + +func NewVarPopBuffer(child sql.Expression) *varPopBuffer { + return &varPopBuffer{ + varBaseBuffer: varBaseBuffer{ + expr: child, + }, + } +} + +// Eval implements the AggregationBuffer interface. +func (vp *varPopBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if vp.count == 0 { + return nil, nil + } + return vp.std2 / float64(vp.count), nil +} + +type varSampBuffer struct { + varBaseBuffer +} + +func NewVarSampBuffer(child sql.Expression) *varSampBuffer { + return &varSampBuffer{ + varBaseBuffer: varBaseBuffer{ + expr: child, + }, + } +} + +// Eval implements the AggregationBuffer interface. +func (vp *varSampBuffer) Eval(ctx *sql.Context) (interface{}, error) { + if vp.count <= 1 { + return nil, nil + } + return vp.std2 / float64(vp.count-1), nil +} diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 7e3736a7e7..a5094cc975 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -958,3 +958,319 @@ func (a *Sum) NewWindowFunction() (sql.WindowFunction, error) { } return NewSumAgg(child).WithWindow(a.Window()) } + +type StdDevPop struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*StdDevPop)(nil) +var _ sql.Aggregation = (*StdDevPop)(nil) +var _ sql.WindowAdaptableExpression = (*StdDevPop)(nil) + +func NewStdDevPop(e sql.Expression) *StdDevPop { + return &StdDevPop{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "StdDevPop", + description: "returns the population standard deviation of expr", + }, + } +} + +func (a *StdDevPop) Type() sql.Type { + return a.Child.Type() +} + +func (a *StdDevPop) IsNullable() bool { + return false +} + +func (a *StdDevPop) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEVPOP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEVPOP(%s)", a.Child) +} + +func (a *StdDevPop) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEVPOP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEVPOP(%s)", sql.DebugString(a.Child)) +} + +func (a *StdDevPop) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &StdDevPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevPop) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &StdDevPop{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *StdDevPop) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &StdDevPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevPop) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevPopBuffer(child), nil +} + +func (a *StdDevPop) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevPopAgg(child).WithWindow(a.Window()) +} + +type StdDevSamp struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*StdDevSamp)(nil) +var _ sql.Aggregation = (*StdDevSamp)(nil) +var _ sql.WindowAdaptableExpression = (*StdDevSamp)(nil) + +func NewStdDevSamp(e sql.Expression) *StdDevSamp { + return &StdDevSamp{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "StdDevSamp", + description: "returns the sample standard deviation of expr", + }, + } +} + +func (a *StdDevSamp) Type() sql.Type { + return a.Child.Type() +} + +func (a *StdDevSamp) IsNullable() bool { + return false +} + +func (a *StdDevSamp) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEVSAMP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEVSAMP(%s)", a.Child) +} + +func (a *StdDevSamp) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("STDDEVSAMP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("STDDEVSAMP(%s)", sql.DebugString(a.Child)) +} + +func (a *StdDevSamp) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &StdDevSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevSamp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &StdDevSamp{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *StdDevSamp) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &StdDevSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *StdDevSamp) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevSampBuffer(child), nil +} + +func (a *StdDevSamp) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewStdDevSampAgg(child).WithWindow(a.Window()) +} + +type VarPop struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*VarPop)(nil) +var _ sql.Aggregation = (*VarPop)(nil) +var _ sql.WindowAdaptableExpression = (*VarPop)(nil) + +func NewVarPop(e sql.Expression) *VarPop { + return &VarPop{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "VarPop", + description: "returns the population variance of expr", + }, + } +} + +func (a *VarPop) Type() sql.Type { + return a.Child.Type() +} + +func (a *VarPop) IsNullable() bool { + return false +} + +func (a *VarPop) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARPOP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARPOP(%s)", a.Child) +} + +func (a *VarPop) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARPOP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARPOP(%s)", sql.DebugString(a.Child)) +} + +func (a *VarPop) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &VarPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarPop) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &VarPop{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *VarPop) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &VarPop{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarPop) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarPopBuffer(child), nil +} + +func (a *VarPop) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarPopAgg(child).WithWindow(a.Window()) +} + +type VarSamp struct { + unaryAggBase +} + +var _ sql.FunctionExpression = (*VarSamp)(nil) +var _ sql.Aggregation = (*VarSamp)(nil) +var _ sql.WindowAdaptableExpression = (*VarSamp)(nil) + +func NewVarSamp(e sql.Expression) *VarSamp { + return &VarSamp{ + unaryAggBase{ + UnaryExpression: expression.UnaryExpression{Child: e}, + functionName: "VarSamp", + description: "returns the sample variance of expr", + }, + } +} + +func (a *VarSamp) Type() sql.Type { + return a.Child.Type() +} + +func (a *VarSamp) IsNullable() bool { + return false +} + +func (a *VarSamp) String() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARSAMP") + children := []string{a.window.String(), a.Child.String()} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARSAMP(%s)", a.Child) +} + +func (a *VarSamp) DebugString() string { + if a.window != nil { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("VARSAMP") + children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} + pr.WriteChildren(children...) + return pr.String() + } + return fmt.Sprintf("VARSAMP(%s)", sql.DebugString(a.Child)) +} + +func (a *VarSamp) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { + res := a.unaryAggBase.WithWindow(window) + return &VarSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarSamp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + res, err := a.unaryAggBase.WithChildren(children...) + return &VarSamp{unaryAggBase: *res.(*unaryAggBase)}, err +} + +func (a *VarSamp) WithId(id sql.ColumnId) sql.IdExpression { + res := a.unaryAggBase.WithId(id) + return &VarSamp{unaryAggBase: *res.(*unaryAggBase)} +} + +func (a *VarSamp) NewBuffer() (sql.AggregationBuffer, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarSampBuffer(child), nil +} + +func (a *VarSamp) NewWindowFunction() (sql.WindowFunction, error) { + child, err := transform.Clone(a.Child) + if err != nil { + return nil, err + } + return NewVarSampAgg(child).WithWindow(a.Window()) +} diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 7fd0c8e530..3265781d75 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -15,6 +15,7 @@ package aggregation import ( + "math" "sort" "strings" @@ -1408,3 +1409,329 @@ func (a *leadLagBase) Compute(ctx *sql.Context, interval sql.WindowInterval, buf a.pos++ return res } + +type StdDevPopAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewStdDevPopAgg(e sql.Expression) *StdDevPopAgg { + return &StdDevPopAgg{ + expr: e, + } +} + +func (s *StdDevPopAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *s + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (s *StdDevPopAgg) Dispose() { + expression.Dispose(s.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (s *StdDevPopAgg) DefaultFramer() sql.WindowFramer { + if s.framer != nil { + return s.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + s.Dispose() + s.partitionStart = interval.Start + s.partitionEnd = interval.End + var err error + s.prefixSum, s.nullCnt, err = floatPrefixSum(ctx, interval, buf, s.expr) + return err +} + +func computeStd2(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer, expr sql.Expression, m float64) (float64, error) { + var v float64 + for i := interval.Start; i < interval.End; i++ { + row := buf[i] + val, err := expr.Eval(ctx, row) + if err != nil { + return 0, err + } + val, _, err = types.Float64.Convert(ctx, val) + if err != nil { + val = 0.0 + ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", val) + } + if val == nil { + continue + } + dv := val.(float64) - m + v += dv * dv + } + return v, nil +} + +func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - s.partitionStart - 1 + endIdx := interval.End - s.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= s.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += s.nullCnt[startIdx] + } + if nonNullCnt == 0 { + return nil + } + + m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) + s2, err := computeStd2(ctx, interval, buf, s.expr, m) + if err != nil { + return err + } + + return math.Sqrt(s2 / float64(nonNullCnt)) +} + +type StdDevSampAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewStdDevSampAgg(e sql.Expression) *StdDevSampAgg { + return &StdDevSampAgg{ + expr: e, + } +} + +func (s *StdDevSampAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *s + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (s *StdDevSampAgg) Dispose() { + expression.Dispose(s.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (s *StdDevSampAgg) DefaultFramer() sql.WindowFramer { + if s.framer != nil { + return s.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (s *StdDevSampAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + s.Dispose() + s.partitionStart = interval.Start + s.partitionEnd = interval.End + var err error + s.prefixSum, s.nullCnt, err = floatPrefixSum(ctx, interval, buf, s.expr) + return err +} + +func (s *StdDevSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - s.partitionStart - 1 + endIdx := interval.End - s.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= s.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += s.nullCnt[startIdx] + } + if nonNullCnt <= 1 { + return nil + } + + m := computePrefixSum(interval, s.partitionStart, s.prefixSum) / float64(nonNullCnt) + s2, err := computeStd2(ctx, interval, buf, s.expr, m) + if err != nil { + return err + } + + return math.Sqrt(s2 / float64(nonNullCnt-1)) +} + +type VarPopAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewVarPopAgg(e sql.Expression) *VarPopAgg { + return &VarPopAgg{ + expr: e, + } +} + +func (v *VarPopAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *v + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (v *VarPopAgg) Dispose() { + expression.Dispose(v.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (v *VarPopAgg) DefaultFramer() sql.WindowFramer { + if v.framer != nil { + return v.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (v *VarPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + v.Dispose() + v.partitionStart = interval.Start + v.partitionEnd = interval.End + var err error + v.prefixSum, v.nullCnt, err = floatPrefixSum(ctx, interval, buf, v.expr) + return err +} + +func (v *VarPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - v.partitionStart - 1 + endIdx := interval.End - v.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= v.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += v.nullCnt[startIdx] + } + if nonNullCnt <= 0 { + return nil + } + + m := computePrefixSum(interval, v.partitionStart, v.prefixSum) / float64(nonNullCnt) + s2, err := computeStd2(ctx, interval, buf, v.expr, m) + if err != nil { + return err + } + + return s2 / float64(nonNullCnt) +} + +type VarSampAgg struct { + expr sql.Expression + framer sql.WindowFramer + + partitionStart int + partitionEnd int + + prefixSum []float64 + nullCnt []int +} + +func NewVarSampAgg(e sql.Expression) *VarSampAgg { + return &VarSampAgg{ + expr: e, + } +} + +func (v *VarSampAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) { + ns := *v + if w.Frame != nil { + framer, err := w.Frame.NewFramer(w) + if err != nil { + return nil, err + } + ns.framer = framer + } + return &ns, nil +} + +func (v *VarSampAgg) Dispose() { + expression.Dispose(v.expr) +} + +// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer +func (v *VarSampAgg) DefaultFramer() sql.WindowFramer { + if v.framer != nil { + return v.framer + } + return NewUnboundedPrecedingToCurrentRowFramer() +} + +func (v *VarSampAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error { + v.Dispose() + v.partitionStart = interval.Start + v.partitionEnd = interval.End + var err error + v.prefixSum, v.nullCnt, err = floatPrefixSum(ctx, interval, buf, v.expr) + return err +} + +func (v *VarSampAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} { + startIdx := interval.Start - v.partitionStart - 1 + endIdx := interval.End - v.partitionStart - 1 + + var nonNullCnt int + if endIdx >= 0 { + nonNullCnt += endIdx + 1 + nonNullCnt -= v.nullCnt[endIdx] + } + if startIdx >= 0 { + nonNullCnt -= startIdx + 1 + nonNullCnt += v.nullCnt[startIdx] + } + if nonNullCnt <= 1 { + return nil + } + + m := computePrefixSum(interval, v.partitionStart, v.prefixSum) / float64(nonNullCnt) + s2, err := computeStd2(ctx, interval, buf, v.expr, m) + if err != nil { + return err + } + + return s2 / float64(nonNullCnt-1) +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 53050e846f..478d47e5ff 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -289,6 +289,10 @@ var BuiltIns = []sql.Function{ sql.FunctionN{Name: "substring", Fn: NewSubstring}, sql.Function3{Name: "substring_index", Fn: NewSubstringIndex}, sql.Function1{Name: "sum", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) }}, + sql.Function1{Name: "std", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, + sql.Function1{Name: "stddev", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, + sql.Function1{Name: "stddev_pop", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevPop(e) }}, + sql.Function1{Name: "stddev_samp", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewStdDevSamp(e) }}, sql.FunctionN{Name: "sysdate", Fn: NewSysdate}, sql.Function1{Name: "tan", Fn: NewTan}, sql.Function1{Name: "time", Fn: NewTime}, @@ -312,6 +316,9 @@ var BuiltIns = []sql.Function{ sql.FunctionN{Name: "week", Fn: NewWeek}, sql.Function1{Name: "values", Fn: NewValues}, sql.Function1{Name: "validate_password_strength", Fn: NewValidatePasswordStrength}, + sql.Function1{Name: "variance", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }}, + sql.Function1{Name: "var_pop", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }}, + sql.Function1{Name: "var_samp", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarSamp(e) }}, sql.Function2{Name: "vec_distance", Fn: vector.NewL2SquaredDistance}, sql.Function2{Name: "vec_distance_l2_squared", Fn: vector.NewL2SquaredDistance}, sql.Function1{Name: "weekday", Fn: NewWeekday}, diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 1459c6e32c..c5f0b13d9f 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -504,7 +504,9 @@ func isWindowFunc(name string) bool { "avg", "max", "min", "count_distinct", "json_arrayagg", "row_number", "percent_rank", "lead", "lag", "first_value", "last_value", - "rank", "dense_rank": + "rank", "dense_rank", + "std", "stddev", "stddev_pop", "stddev_samp", + "variance", "var_pop", "var_samp": return true default: return false