Skip to content

Commit a4e8d45

Browse files
author
James Cor
committed
implment variance and tests
1 parent 0d79ae9 commit a4e8d45

File tree

8 files changed

+499
-61
lines changed

8 files changed

+499
-61
lines changed

enginetest/queries/script_queries.go

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7888,7 +7888,7 @@ where
78887888
},
78897889
},
78907890
{
7891-
Name: "std, stdev, stddev_pop tests",
7891+
Name: "std, stdev, stddev_pop, variance, var_pop, var_samp tests",
78927892
Dialect: "mysql",
78937893
SetUpScript: []string{
78947894
"create table t (i int);",
@@ -7903,6 +7903,12 @@ where
79037903
{nil, nil, nil, nil},
79047904
},
79057905
},
7906+
{
7907+
Query: "select variance(i), var_pop(i) from t;",
7908+
Expected: []sql.Row{
7909+
{nil, nil},
7910+
},
7911+
},
79067912
{
79077913
Query: "insert into t values (1);",
79087914
Expected: []sql.Row{
@@ -7915,6 +7921,12 @@ where
79157921
{0.0, 0.0, 0.0, nil},
79167922
},
79177923
},
7924+
{
7925+
Query: "select variance(i), var_pop(i) from t;",
7926+
Expected: []sql.Row{
7927+
{0.0, 0.0},
7928+
},
7929+
},
79187930
{
79197931
Query: "insert into t values (2);",
79207932
Expected: []sql.Row{
@@ -7927,6 +7939,12 @@ where
79277939
{0.5, 0.5, 0.5, 0.7071067811865476},
79287940
},
79297941
},
7942+
{
7943+
Query: "select variance(i), var_pop(i) from t;",
7944+
Expected: []sql.Row{
7945+
{0.25, 0.25},
7946+
},
7947+
},
79307948
{
79317949
Query: "insert into t values (3);",
79327950
Expected: []sql.Row{
@@ -7939,6 +7957,12 @@ where
79397957
{0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0},
79407958
},
79417959
},
7960+
{
7961+
Query: "select variance(i), var_pop(i) from t;",
7962+
Expected: []sql.Row{
7963+
{0.6666666666666666, 0.6666666666666666},
7964+
},
7965+
},
79427966
{
79437967
Query: "insert into t values (null), (null);",
79447968
Expected: []sql.Row{
@@ -7951,13 +7975,26 @@ where
79517975
{0.816496580927726, 0.816496580927726, 0.816496580927726, 1.0},
79527976
},
79537977
},
7978+
{
7979+
Query: "select variance(i), var_pop(i) from t;",
7980+
Expected: []sql.Row{
7981+
{0.6666666666666666, 0.6666666666666666},
7982+
},
7983+
},
79547984
{
79557985
Query: "select i, std(j), stddev_samp(j) from tt group by i;",
79567986
Expected: []sql.Row{
79577987
{0, 0.816496580927726, 1.0},
79587988
{1, 271.89336144893275, 333.0},
79597989
},
79607990
},
7991+
{
7992+
Query: "select i, variance(i) from tt group by i;",
7993+
Expected: []sql.Row{
7994+
{0, 0.0},
7995+
{1, 0.0},
7996+
},
7997+
},
79617998
{
79627999
Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;",
79638000
Expected: []sql.Row{
@@ -7980,22 +8017,44 @@ where
79808017
{1, 271.89336144893275, 333.0},
79818018
},
79828019
},
8020+
{
8021+
Query: "select i, variance(i) over() from tt order by i;",
8022+
Expected: []sql.Row{
8023+
{0, 0.25},
8024+
{0, 0.25},
8025+
{0, 0.25},
8026+
{1, 0.25},
8027+
{1, 0.25},
8028+
{1, 0.25},
8029+
},
8030+
},
8031+
{
8032+
Query: "select i, variance(j) over(partition by i) from tt order by i;",
8033+
Expected: []sql.Row{
8034+
{0, 0.6666666666666666},
8035+
{0, 0.6666666666666666},
8036+
{0, 0.6666666666666666},
8037+
{1, 73926.0},
8038+
{1, 73926.0},
8039+
{1, 73926.0},
8040+
},
8041+
},
79838042
{
79848043
Query: "insert into tt values (null, null);",
79858044
Expected: []sql.Row{
79868045
{types.NewOkResult(1)},
79878046
},
79888047
},
79898048
{
7990-
Query: "select std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;",
8049+
Query: "select i, std(i) over(), std(j) over(), stddev_samp(j) over() from tt order by i;",
79918050
Expected: []sql.Row{
7992-
{0.5, 297.47660972475353, 325.86929895281634},
7993-
{0.5, 297.47660972475353, 325.86929895281634},
7994-
{0.5, 297.47660972475353, 325.86929895281634},
7995-
{0.5, 297.47660972475353, 325.86929895281634},
7996-
{0.5, 297.47660972475353, 325.86929895281634},
7997-
{0.5, 297.47660972475353, 325.86929895281634},
7998-
{0.5, 297.47660972475353, 325.86929895281634},
8051+
{nil, 0.5, 297.47660972475353, 325.86929895281634},
8052+
{0, 0.5, 297.47660972475353, 325.86929895281634},
8053+
{0, 0.5, 297.47660972475353, 325.86929895281634},
8054+
{0, 0.5, 297.47660972475353, 325.86929895281634},
8055+
{1, 0.5, 297.47660972475353, 325.86929895281634},
8056+
{1, 0.5, 297.47660972475353, 325.86929895281634},
8057+
{1, 0.5, 297.47660972475353, 325.86929895281634},
79998058
},
80008059
},
80018060
{
@@ -8010,6 +8069,30 @@ where
80108069
{1, 271.89336144893275, 333.0},
80118070
},
80128071
},
8072+
{
8073+
Query: "select i, variance(i) over() from tt order by i;",
8074+
Expected: []sql.Row{
8075+
{nil, 0.25},
8076+
{0, 0.25},
8077+
{0, 0.25},
8078+
{0, 0.25},
8079+
{1, 0.25},
8080+
{1, 0.25},
8081+
{1, 0.25},
8082+
},
8083+
},
8084+
{
8085+
Query: "select i, variance(j) over(partition by i) from tt order by i;",
8086+
Expected: []sql.Row{
8087+
{nil, nil},
8088+
{0, 0.6666666666666666},
8089+
{0, 0.6666666666666666},
8090+
{0, 0.6666666666666666},
8091+
{1, 73926.0},
8092+
{1, 73926.0},
8093+
{1, 73926.0},
8094+
},
8095+
},
80138096
},
80148097
},
80158098
}

optgen/cmd/source/unary_aggs.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,6 @@ unaryAggs:
3535
- name: "StdDevPop"
3636
desc: "returns the population standard deviation of expr"
3737
- name: "StdDevSamp"
38-
desc: "returns the sample standard deviation of expr"
38+
desc: "returns the sample standard deviation of expr"
39+
- name: "VarPop"
40+
desc: "returns the population variance of expr"

sql/expression/function/aggregation/std_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,147 @@ func TestStd(t *testing.T) {
9494
})
9595
}
9696
}
97+
98+
func TestStdSamp(t *testing.T) {
99+
sum := NewStdDevSamp(expression.NewGetField(0, nil, "", false))
100+
101+
testCases := []struct {
102+
name string
103+
rows []sql.Row
104+
expected interface{}
105+
}{
106+
{
107+
"string int values",
108+
[]sql.Row{{"1"}, {"2"}, {"3"}, {"4"}},
109+
1.2909944487358056,
110+
},
111+
{
112+
"string float values",
113+
[]sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}},
114+
1.1086778913041726,
115+
},
116+
{
117+
"string non-int values",
118+
[]sql.Row{{"a"}, {"b"}, {"c"}, {"d"}},
119+
float64(0),
120+
},
121+
{
122+
"float values",
123+
[]sql.Row{{1.}, {2.5}, {3.}, {4.}},
124+
1.25,
125+
},
126+
{
127+
"no rows",
128+
[]sql.Row{},
129+
nil,
130+
},
131+
{
132+
"nil values",
133+
[]sql.Row{{nil}, {nil}},
134+
nil,
135+
},
136+
{
137+
"int64 values",
138+
[]sql.Row{{int64(1)}, {int64(3)}},
139+
1.4142135623730951,
140+
},
141+
{
142+
"int32 values",
143+
[]sql.Row{{int32(1)}, {int32(3)}},
144+
1.4142135623730951,
145+
},
146+
{
147+
"int32 and nil values",
148+
[]sql.Row{{int32(1)}, {int32(3)}, {nil}},
149+
1.4142135623730951,
150+
},
151+
}
152+
153+
for _, tt := range testCases {
154+
t.Run(tt.name, func(t *testing.T) {
155+
require := require.New(t)
156+
157+
ctx := sql.NewEmptyContext()
158+
buf, _ := sum.NewBuffer()
159+
for _, row := range tt.rows {
160+
require.NoError(buf.Update(ctx, row))
161+
}
162+
163+
result, err := buf.Eval(sql.NewEmptyContext())
164+
require.NoError(err)
165+
require.Equal(tt.expected, result)
166+
})
167+
}
168+
}
169+
170+
func TestVariance(t *testing.T) {
171+
sum := NewVarPop(expression.NewGetField(0, nil, "", false))
172+
173+
testCases := []struct {
174+
name string
175+
rows []sql.Row
176+
expected interface{}
177+
}{
178+
{
179+
"string int values",
180+
[]sql.Row{{"1"}, {"2"}, {"3"}, {"4"}},
181+
1.25,
182+
},
183+
{
184+
"string float values",
185+
[]sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}},
186+
0.9218750000000001,
187+
},
188+
{
189+
"string non-int values",
190+
[]sql.Row{{"a"}, {"b"}, {"c"}, {"d"}},
191+
float64(0),
192+
},
193+
{
194+
"float values",
195+
[]sql.Row{{1.}, {2.5}, {3.}, {4.}},
196+
1.171875,
197+
},
198+
{
199+
"no rows",
200+
[]sql.Row{},
201+
nil,
202+
},
203+
{
204+
"nil values",
205+
[]sql.Row{{nil}, {nil}},
206+
nil,
207+
},
208+
{
209+
"int64 values",
210+
[]sql.Row{{int64(1)}, {int64(3)}},
211+
1.0,
212+
},
213+
{
214+
"int32 values",
215+
[]sql.Row{{int32(1)}, {int32(3)}},
216+
1.0,
217+
},
218+
{
219+
"int32 and nil values",
220+
[]sql.Row{{int32(1)}, {int32(3)}, {nil}},
221+
1.0,
222+
},
223+
}
224+
225+
for _, tt := range testCases {
226+
t.Run(tt.name, func(t *testing.T) {
227+
require := require.New(t)
228+
229+
ctx := sql.NewEmptyContext()
230+
buf, _ := sum.NewBuffer()
231+
for _, row := range tt.rows {
232+
require.NoError(buf.Update(ctx, row))
233+
}
234+
235+
result, err := buf.Eval(sql.NewEmptyContext())
236+
require.NoError(err)
237+
require.Equal(tt.expected, result)
238+
})
239+
}
240+
}

0 commit comments

Comments
 (0)