Skip to content

Commit 746a849

Browse files
committed
Covariance.
1 parent 312d3b5 commit 746a849

File tree

5 files changed

+194
-34
lines changed

5 files changed

+194
-34
lines changed

ext/stats/stats.go

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// Package stats provides aggregate functions for statistics.
22
//
33
// Functions:
4-
// - var_samp: sample variance
5-
// - var_pop: population variance
6-
// - stddev_samp: sample standard deviation
74
// - stddev_pop: population standard deviation
5+
// - stddev_samp: sample standard deviation
6+
// - var_pop: population variance
7+
// - var_samp: sample variance
8+
// - covar_pop: population covariance
9+
// - covar_samp: sample covariance
810
//
911
// See: [ANSI SQL Aggregate Functions]
1012
//
@@ -16,10 +18,12 @@ import "github.com/ncruces/go-sqlite3"
1618
// Register registers statistics functions.
1719
func Register(db *sqlite3.Conn) {
1820
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
19-
db.CreateWindowFunction("var_pop", 1, flags, create(var_pop))
20-
db.CreateWindowFunction("var_samp", 1, flags, create(var_samp))
21-
db.CreateWindowFunction("stddev_pop", 1, flags, create(stddev_pop))
22-
db.CreateWindowFunction("stddev_samp", 1, flags, create(stddev_samp))
21+
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
22+
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
23+
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
24+
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
25+
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
26+
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
2327
}
2428

2529
const (
@@ -29,38 +33,72 @@ const (
2933
stddev_samp
3034
)
3135

32-
func create(kind int) func() sqlite3.AggregateFunction {
33-
return func() sqlite3.AggregateFunction { return &state{kind: kind} }
36+
func newVariance(kind int) func() sqlite3.AggregateFunction {
37+
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
3438
}
3539

36-
type state struct {
40+
type variance struct {
3741
kind int
3842
welford
3943
}
4044

41-
func (f *state) Value(ctx sqlite3.Context) {
45+
func (fn *variance) Value(ctx sqlite3.Context) {
4246
var r float64
43-
switch f.kind {
47+
switch fn.kind {
4448
case var_pop:
45-
r = f.var_pop()
49+
r = fn.var_pop()
4650
case var_samp:
47-
r = f.var_samp()
51+
r = fn.var_samp()
4852
case stddev_pop:
49-
r = f.stddev_pop()
53+
r = fn.stddev_pop()
5054
case stddev_samp:
51-
r = f.stddev_samp()
55+
r = fn.stddev_samp()
5256
}
5357
ctx.ResultFloat(r)
5458
}
5559

56-
func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
60+
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
5761
if a := arg[0]; a.Type() != sqlite3.NULL {
58-
f.enqueue(a.Float())
62+
fn.enqueue(a.Float())
5963
}
6064
}
6165

62-
func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
66+
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
6367
if a := arg[0]; a.Type() != sqlite3.NULL {
64-
f.dequeue(a.Float())
68+
fn.dequeue(a.Float())
69+
}
70+
}
71+
72+
func newCovariance(kind int) func() sqlite3.AggregateFunction {
73+
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
74+
}
75+
76+
type covariance struct {
77+
kind int
78+
welford2
79+
}
80+
81+
func (fn *covariance) Value(ctx sqlite3.Context) {
82+
var r float64
83+
switch fn.kind {
84+
case var_pop:
85+
r = fn.covar_pop()
86+
case var_samp:
87+
r = fn.covar_samp()
88+
}
89+
ctx.ResultFloat(r)
90+
}
91+
92+
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
93+
a, b := arg[0], arg[1]
94+
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
95+
fn.enqueue(a.Float(), b.Float())
96+
}
97+
}
98+
99+
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
100+
a, b := arg[0], arg[1]
101+
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
102+
fn.dequeue(a.Float(), b.Float())
65103
}
66104
}

ext/stats/stats_test.go

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
_ "github.com/ncruces/go-sqlite3/embed"
99
)
1010

11-
func TestRegister(t *testing.T) {
11+
func TestRegister_variance(t *testing.T) {
1212
t.Parallel()
1313

1414
db, err := sqlite3.Open(":memory:")
@@ -19,20 +19,22 @@ func TestRegister(t *testing.T) {
1919

2020
Register(db)
2121

22-
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`)
22+
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
2323
if err != nil {
2424
t.Fatal(err)
2525
}
2626

27-
err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`)
27+
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
2828
if err != nil {
2929
t.Fatal(err)
3030
}
3131

32-
stmt, _, err := db.Prepare(`SELECT
33-
sum(col), avg(col),
34-
var_samp(col), var_pop(col),
35-
stddev_samp(col), stddev_pop(col) FROM data`)
32+
stmt, _, err := db.Prepare(`
33+
SELECT
34+
sum(x), avg(x),
35+
var_samp(x), var_pop(x),
36+
stddev_samp(x), stddev_pop(x)
37+
FROM data`)
3638
if err != nil {
3739
t.Fatal(err)
3840
}
@@ -60,7 +62,7 @@ func TestRegister(t *testing.T) {
6062
}
6163

6264
{
63-
stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`)
65+
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
6466
if err != nil {
6567
t.Fatal(err)
6668
}
@@ -77,3 +79,59 @@ func TestRegister(t *testing.T) {
7779
}
7880
}
7981
}
82+
83+
func TestRegister_covariance(t *testing.T) {
84+
t.Parallel()
85+
86+
db, err := sqlite3.Open(":memory:")
87+
if err != nil {
88+
t.Fatal(err)
89+
}
90+
defer db.Close()
91+
92+
Register(db)
93+
94+
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
95+
if err != nil {
96+
t.Fatal(err)
97+
}
98+
99+
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
100+
if err != nil {
101+
t.Fatal(err)
102+
}
103+
104+
stmt, _, err := db.Prepare(`SELECT
105+
covar_samp(x, y), covar_pop(x, y) FROM data`)
106+
if err != nil {
107+
t.Fatal(err)
108+
}
109+
defer stmt.Close()
110+
111+
if stmt.Step() {
112+
if got := stmt.ColumnFloat(0); got != 21.25 {
113+
t.Errorf("got %v, want 21.25", got)
114+
}
115+
if got := stmt.ColumnFloat(1); got != 17 {
116+
t.Errorf("got %v, want 17", got)
117+
}
118+
}
119+
120+
{
121+
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
122+
if err != nil {
123+
t.Fatal(err)
124+
}
125+
defer stmt.Close()
126+
127+
want := [...]float64{0, 10, 30, 75, 22.5}
128+
for i := 0; stmt.Step(); i++ {
129+
if got := stmt.ColumnFloat(0); got != want[i] {
130+
t.Errorf("got %v, want %v", got, want[i])
131+
}
132+
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
133+
t.Errorf("got %v, want %v", got, want[i])
134+
}
135+
}
136+
}
137+
}

ext/stats/welford.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func (w welford) var_pop() float64 {
2020
}
2121

2222
func (w welford) var_samp() float64 {
23-
return w.m2.hi / float64(w.n-1)
23+
return w.m2.hi / float64(w.n-1) // Bessel's correction
2424
}
2525

2626
func (w welford) stddev_pop() float64 {
@@ -33,20 +33,53 @@ func (w welford) stddev_samp() float64 {
3333

3434
func (w *welford) enqueue(x float64) {
3535
w.n++
36-
d1 := x - w.m1.hi
36+
d1 := x - w.m1.hi - w.m1.lo
3737
w.m1.add(d1 / float64(w.n))
38-
d2 := x - w.m1.hi
38+
d2 := x - w.m1.hi - w.m1.lo
3939
w.m2.add(d1 * d2)
4040
}
4141

4242
func (w *welford) dequeue(x float64) {
4343
w.n--
44-
d1 := x - w.m1.hi
44+
d1 := x - w.m1.hi - w.m1.lo
4545
w.m1.sub(d1 / float64(w.n))
46-
d2 := x - w.m1.hi
46+
d2 := x - w.m1.hi - w.m1.lo
4747
w.m2.sub(d1 * d2)
4848
}
4949

50+
type welford2 struct {
51+
x, y, c kahan
52+
n uint64
53+
}
54+
55+
func (w welford2) covar_pop() float64 {
56+
return w.c.hi / float64(w.n)
57+
}
58+
59+
func (w welford2) covar_samp() float64 {
60+
return w.c.hi / float64(w.n-1) // Bessel's correction
61+
}
62+
63+
func (w *welford2) enqueue(x, y float64) {
64+
w.n++
65+
dx := x - w.x.hi - w.x.lo
66+
dy := y - w.y.hi - w.y.lo
67+
w.x.add(dx / float64(w.n))
68+
w.y.add(dy / float64(w.n))
69+
d2 := y - w.y.hi - w.y.lo
70+
w.c.add(dx * d2)
71+
}
72+
73+
func (w *welford2) dequeue(x, y float64) {
74+
w.n--
75+
dx := x - w.x.hi - w.x.lo
76+
dy := y - w.y.hi - w.y.lo
77+
w.x.sub(dx / float64(w.n))
78+
w.y.sub(dy / float64(w.n))
79+
d2 := y - w.y.hi - w.y.lo
80+
w.c.sub(dx * d2)
81+
}
82+
5083
type kahan struct{ hi, lo float64 }
5184

5285
func (k *kahan) add(x float64) {

ext/stats/welford_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,38 @@ func Test_welford(t *testing.T) {
3232
s2.enqueue(7)
3333
s2.enqueue(13)
3434
s2.enqueue(16)
35+
s1.m1.lo, s2.m1.lo = 0, 0
36+
s1.m2.lo, s2.m2.lo = 0, 0
3537
if s1 != s2 {
3638
t.Errorf("got %v, want %v", s1, s2)
3739
}
3840
}
41+
42+
func Test_covar(t *testing.T) {
43+
var c1, c2 welford2
44+
45+
c1.enqueue(3, 70)
46+
c1.enqueue(5, 80)
47+
c1.enqueue(2, 60)
48+
c1.enqueue(7, 90)
49+
c1.enqueue(4, 75)
50+
51+
if got := c1.covar_samp(); got != 21.25 {
52+
t.Errorf("got %v, want 21.25", got)
53+
}
54+
if got := c1.covar_pop(); got != 17 {
55+
t.Errorf("got %v, want 17", got)
56+
}
57+
58+
c1.dequeue(3, 70)
59+
c2.enqueue(5, 80)
60+
c2.enqueue(2, 60)
61+
c2.enqueue(7, 90)
62+
c2.enqueue(4, 75)
63+
c1.x.lo, c2.x.lo = 0, 0
64+
c1.y.lo, c2.y.lo = 0, 0
65+
c1.c.lo, c2.c.lo = 0, 0
66+
if c1 != c2 {
67+
t.Errorf("got %v, want %v", c1, c2)
68+
}
69+
}

func.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
// for any unknown collating sequence.
1313
// The fake collating function works like BINARY.
1414
//
15-
// This extension can be used to load schemas that contain
15+
// This can be used to load schemas that contain
1616
// one or more unknown collating sequences.
1717
func (c *Conn) AnyCollationNeeded() {
1818
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)

0 commit comments

Comments
 (0)