Skip to content

Commit 4bc7fb6

Browse files
author
James Cor
committed
implment std and tests
1 parent e7b7ae1 commit 4bc7fb6

File tree

8 files changed

+352
-4
lines changed

8 files changed

+352
-4
lines changed

enginetest/queries/script_queries.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7887,6 +7887,57 @@ where
78877887
},
78887888
},
78897889
},
7890+
{
7891+
Name: "std, stdev, stddev_pop tests",
7892+
Dialect: "mysql",
7893+
SetUpScript: []string{
7894+
"create table t (i int);",
7895+
},
7896+
Assertions: []ScriptTestAssertion{
7897+
{
7898+
Query: "select std(i), stddev(i), stddev_pop(i) from t;",
7899+
Expected: []sql.Row{
7900+
{nil, nil, nil},
7901+
},
7902+
},
7903+
{
7904+
Query: "insert into t values (1);",
7905+
Expected: []sql.Row{
7906+
{types.NewOkResult(1)},
7907+
},
7908+
},
7909+
{
7910+
Query: "select std(i), stddev(i), stddev_pop(i) from t;",
7911+
Expected: []sql.Row{
7912+
{0.0, 0.0, 0.0},
7913+
},
7914+
},
7915+
{
7916+
Query: "insert into t values (2);",
7917+
Expected: []sql.Row{
7918+
{types.NewOkResult(1)},
7919+
},
7920+
},
7921+
{
7922+
Query: "select std(i), stddev(i), stddev_pop(i) from t;",
7923+
Expected: []sql.Row{
7924+
{0.5, 0.5, 0.5},
7925+
},
7926+
},
7927+
{
7928+
Query: "insert into t values (3);",
7929+
Expected: []sql.Row{
7930+
{types.NewOkResult(1)},
7931+
},
7932+
},
7933+
{
7934+
Query: "select std(i), stddev(i), stddev_pop(i) from t;",
7935+
Expected: []sql.Row{
7936+
{0.816496580927726, 0.816496580927726, 0.816496580927726},
7937+
},
7938+
},
7939+
},
7940+
},
78907941
}
78917942

78927943
var SpatialScriptTests = []ScriptTest{

optgen/cmd/source/unary_aggs.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@ unaryAggs:
3131
desc: "returns the minimum value of expr in all rows."
3232
- name: "Sum"
3333
desc: "returns the sum of expr in all rows"
34-
nullable: false
34+
nullable: false
35+
- name: "StdDevPop"
36+
desc: "returns the population standard deviation of expr"

sql/expression/function/aggregation/common.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ func (a *unaryAggBase) WithChildren(children ...sql.Expression) (sql.Expression,
125125
return &na, nil
126126
}
127127

128-
func (a unaryAggBase) FunctionName() string {
128+
func (a *unaryAggBase) FunctionName() string {
129129
return a.functionName
130130
}
131131

132-
func (a unaryAggBase) Description() string {
132+
func (a *unaryAggBase) Description() string {
133133
return a.description
134134
}
135135

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package aggregation
16+
17+
import (
18+
"testing"
19+
20+
"github.com/stretchr/testify/require"
21+
22+
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/expression"
24+
)
25+
26+
func TestStd(t *testing.T) {
27+
sum := NewStdDevPop(expression.NewGetField(0, nil, "", false))
28+
29+
testCases := []struct {
30+
name string
31+
rows []sql.Row
32+
expected interface{}
33+
}{
34+
{
35+
"string int values",
36+
[]sql.Row{{"1"}, {"2"}, {"3"}, {"4"}},
37+
1.118033988749895,
38+
},
39+
{
40+
"string float values",
41+
[]sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}},
42+
0.9601432184835761,
43+
},
44+
{
45+
"string non-int values",
46+
[]sql.Row{{"a"}, {"b"}, {"c"}, {"d"}},
47+
float64(0),
48+
},
49+
{
50+
"float values",
51+
[]sql.Row{{1.}, {2.5}, {3.}, {4.}},
52+
1.0825317547305484,
53+
},
54+
{
55+
"no rows",
56+
[]sql.Row{},
57+
nil,
58+
},
59+
{
60+
"nil values",
61+
[]sql.Row{{nil}, {nil}},
62+
nil,
63+
},
64+
{
65+
"int64 values",
66+
[]sql.Row{{int64(1)}, {int64(3)}},
67+
1.0,
68+
},
69+
{
70+
"int32 values",
71+
[]sql.Row{{int32(1)}, {int32(3)}},
72+
1.0,
73+
},
74+
{
75+
"int32 and nil values",
76+
[]sql.Row{{int32(1)}, {int32(3)}, {nil}},
77+
1.0,
78+
},
79+
}
80+
81+
for _, tt := range testCases {
82+
t.Run(tt.name, func(t *testing.T) {
83+
require := require.New(t)
84+
85+
ctx := sql.NewEmptyContext()
86+
buf, _ := sum.NewBuffer()
87+
for _, row := range tt.rows {
88+
require.NoError(buf.Update(ctx, row))
89+
}
90+
91+
result, err := buf.Eval(sql.NewEmptyContext())
92+
require.NoError(err)
93+
require.Equal(tt.expected, result)
94+
})
95+
}
96+
}
97+

sql/expression/function/aggregation/unary_agg_buffers.go

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ package aggregation
22

33
import (
44
"fmt"
5-
"reflect"
5+
"math"
6+
"reflect"
67

78
"github.com/cespare/xxhash/v2"
89
"github.com/shopspring/decimal"
@@ -666,3 +667,66 @@ func (j *jsonArrayBuffer) Eval(ctx *sql.Context) (interface{}, error) {
666667
// Dispose implements the Disposable interface.
667668
func (j *jsonArrayBuffer) Dispose() {
668669
}
670+
671+
type stdDevPopBuffer struct {
672+
vals []interface{}
673+
expr sql.Expression
674+
675+
count int64
676+
oldMean float64
677+
newMean float64
678+
oldVar float64
679+
newVar float64
680+
}
681+
682+
func NewStdDevPopBuffer(child sql.Expression) *stdDevPopBuffer {
683+
return &stdDevPopBuffer{
684+
vals: nil,
685+
expr: child,
686+
}
687+
}
688+
689+
// Update implements the AggregationBuffer interface.
690+
func (s *stdDevPopBuffer) Update(ctx *sql.Context, row sql.Row) error {
691+
v, err := s.expr.Eval(ctx, row)
692+
if err != nil {
693+
return err
694+
}
695+
696+
// TODO: convert val to appropriate type
697+
v, _, err = types.Float64.Convert(ctx, v)
698+
if err != nil {
699+
v = 0.0
700+
ctx.Warn(1292, "Truncated incorrect DOUBLE value: %s", v)
701+
}
702+
if v == nil {
703+
return nil
704+
}
705+
val := v.(float64)
706+
707+
s.count += 1
708+
if s.count == 1 {
709+
s.oldMean = val
710+
s.newMean = val
711+
return nil
712+
}
713+
714+
s.newMean = s.oldMean + (val - s.oldMean) / float64(s.count)
715+
s.newVar = s.oldVar + (val - s.oldMean) * (val - s.newMean)
716+
s.oldVar = s.newVar
717+
s.oldMean = s.newMean
718+
719+
return nil
720+
}
721+
722+
// Eval implements the AggregationBuffer interface.
723+
func (s *stdDevPopBuffer) Eval(ctx *sql.Context) (interface{}, error) {
724+
if s.count == 0 {
725+
return nil, nil
726+
}
727+
return math.Sqrt(s.newVar / float64(s.count)), nil // TODO: sqrt?
728+
}
729+
730+
// Dispose implements the Disposable interface.
731+
func (s *stdDevPopBuffer) Dispose() {
732+
}

sql/expression/function/aggregation/unary_aggs.og.go

Lines changed: 79 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sql/expression/function/aggregation/window_functions.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,3 +1408,55 @@ func (a *leadLagBase) Compute(ctx *sql.Context, interval sql.WindowInterval, buf
14081408
a.pos++
14091409
return res
14101410
}
1411+
1412+
type StdDevPopAgg struct {
1413+
expr sql.Expression
1414+
framer sql.WindowFramer
1415+
}
1416+
1417+
func NewStdDevPopAgg(e sql.Expression) *StdDevPopAgg {
1418+
return &StdDevPopAgg{
1419+
expr: e,
1420+
}
1421+
}
1422+
1423+
func (s *StdDevPopAgg) WithWindow(w *sql.WindowDefinition) (sql.WindowFunction, error) {
1424+
ns := *s
1425+
if w.Frame != nil {
1426+
framer, err := w.Frame.NewFramer(w)
1427+
if err != nil {
1428+
return nil, err
1429+
}
1430+
ns.framer = framer
1431+
}
1432+
return &ns, nil
1433+
}
1434+
1435+
func (s *StdDevPopAgg) Dispose() {
1436+
expression.Dispose(s.expr)
1437+
}
1438+
1439+
// DefaultFramer returns a NewUnboundedPrecedingToCurrentRowFramer
1440+
func (s *StdDevPopAgg) DefaultFramer() sql.WindowFramer {
1441+
if s.framer != nil {
1442+
return s.framer
1443+
}
1444+
return NewUnboundedPrecedingToCurrentRowFramer()
1445+
}
1446+
1447+
func (s *StdDevPopAgg) StartPartition(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) error {
1448+
s.Dispose()
1449+
return nil
1450+
}
1451+
1452+
func (s *StdDevPopAgg) Compute(ctx *sql.Context, interval sql.WindowInterval, buf sql.WindowBuffer) interface{} {
1453+
for i := interval.Start; i < interval.End; i++ {
1454+
row := buf[i]
1455+
v, err := s.expr.Eval(ctx, row)
1456+
if err != nil {
1457+
return err
1458+
}
1459+
return v
1460+
}
1461+
return nil
1462+
}

0 commit comments

Comments
 (0)