Skip to content

Commit 9d99755

Browse files
committed
Pearson correlation.
1 parent 9d75c39 commit 9d99755

File tree

4 files changed

+53
-27
lines changed

4 files changed

+53
-27
lines changed

ext/stats/stats.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// - var_samp: sample variance
88
// - covar_pop: population covariance
99
// - covar_samp: sample covariance
10+
// - corr: correlation coefficient
1011
//
1112
// See: [ANSI SQL Aggregate Functions]
1213
//
@@ -24,13 +25,15 @@ func Register(db *sqlite3.Conn) {
2425
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
2526
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
2627
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
28+
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
2729
}
2830

2931
const (
3032
var_pop = iota
3133
var_samp
3234
stddev_pop
3335
stddev_samp
36+
corr
3437
)
3538

3639
func newVariance(kind int) func() sqlite3.AggregateFunction {
@@ -85,6 +88,8 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
8588
r = fn.covar_pop()
8689
case var_samp:
8790
r = fn.covar_samp()
91+
case corr:
92+
r = fn.correlation()
8893
}
8994
ctx.ResultFloat(r)
9095
}

ext/stats/stats_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,20 @@ func TestRegister_covariance(t *testing.T) {
102102
}
103103

104104
stmt, _, err := db.Prepare(`SELECT
105-
covar_samp(x, y), covar_pop(x, y) FROM data`)
105+
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
106106
if err != nil {
107107
t.Fatal(err)
108108
}
109109
defer stmt.Close()
110110

111111
if stmt.Step() {
112-
if got := stmt.ColumnFloat(0); got != 21.25 {
112+
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
113+
t.Errorf("got %v, want 0.9881049293224639", got)
114+
}
115+
if got := stmt.ColumnFloat(1); got != 21.25 {
113116
t.Errorf("got %v, want 21.25", got)
114117
}
115-
if got := stmt.ColumnFloat(1); got != 17 {
118+
if got := stmt.ColumnFloat(2); got != 17 {
116119
t.Errorf("got %v, want 17", got)
117120
}
118121
}

ext/stats/welford.go

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,36 +48,48 @@ func (w *welford) dequeue(x float64) {
4848
}
4949

5050
type welford2 struct {
51-
x, y, c kahan
52-
n uint64
51+
m1x, m2x kahan
52+
m1y, m2y kahan
53+
cov kahan
54+
n uint64
5355
}
5456

5557
func (w welford2) covar_pop() float64 {
56-
return w.c.hi / float64(w.n)
58+
return w.cov.hi / float64(w.n)
5759
}
5860

5961
func (w welford2) covar_samp() float64 {
60-
return w.c.hi / float64(w.n-1) // Bessel's correction
62+
return w.cov.hi / float64(w.n-1) // Bessel's correction
63+
}
64+
65+
func (w welford2) correlation() float64 {
66+
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
6167
}
6268

6369
func (w *welford2) enqueue(x, y float64) {
6470
w.n++
65-
dx := x - w.x.hi - w.x.lo
66-
dy := y - w.y.hi - w.y.lo
67-
w.x.add(dx / float64(w.n))
68-
w.y.add(dy / float64(w.n))
69-
d2 := y - w.y.hi - w.y.lo
70-
w.c.add(dx * d2)
71+
d1x := x - w.m1x.hi - w.m1x.lo
72+
d1y := y - w.m1y.hi - w.m1y.lo
73+
w.m1x.add(d1x / float64(w.n))
74+
w.m1y.add(d1y / float64(w.n))
75+
d2x := x - w.m1x.hi - w.m1x.lo
76+
d2y := y - w.m1y.hi - w.m1y.lo
77+
w.m2x.add(d1x * d2x)
78+
w.m2y.add(d1y * d2y)
79+
w.cov.add(d1x * d2y)
7180
}
7281

7382
func (w *welford2) dequeue(x, y float64) {
7483
w.n--
75-
dx := x - w.x.hi - w.x.lo
76-
dy := y - w.y.hi - w.y.lo
77-
w.x.sub(dx / float64(w.n))
78-
w.y.sub(dy / float64(w.n))
79-
d2 := y - w.y.hi - w.y.lo
80-
w.c.sub(dx * d2)
84+
d1x := x - w.m1x.hi - w.m1x.lo
85+
d1y := y - w.m1y.hi - w.m1y.lo
86+
w.m1x.sub(d1x / float64(w.n))
87+
w.m1y.sub(d1y / float64(w.n))
88+
d2x := x - w.m1x.hi - w.m1x.lo
89+
d2y := y - w.m1y.hi - w.m1y.lo
90+
w.m2x.sub(d1x * d2x)
91+
w.m2y.sub(d1y * d2y)
92+
w.cov.sub(d1x * d2y)
8193
}
8294

8395
type kahan struct{ hi, lo float64 }

ext/stats/welford_test.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ func Test_welford(t *testing.T) {
3232
s2.enqueue(7)
3333
s2.enqueue(13)
3434
s2.enqueue(16)
35-
s1.m1.lo, s2.m1.lo = 0, 0
36-
s1.m2.lo, s2.m2.lo = 0, 0
37-
if s1 != s2 {
35+
if s1.var_pop() != s2.var_pop() {
3836
t.Errorf("got %v, want %v", s1, s2)
3937
}
4038
}
@@ -60,10 +58,18 @@ func Test_covar(t *testing.T) {
6058
c2.enqueue(2, 60)
6159
c2.enqueue(7, 90)
6260
c2.enqueue(4, 75)
63-
c1.x.lo, c2.x.lo = 0, 0
64-
c1.y.lo, c2.y.lo = 0, 0
65-
c1.c.lo, c2.c.lo = 0, 0
66-
if c1 != c2 {
67-
t.Errorf("got %v, want %v", c1, c2)
61+
if c1.covar_pop() != c2.covar_pop() {
62+
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
63+
}
64+
}
65+
66+
func Test_correlation(t *testing.T) {
67+
var c welford2
68+
c.enqueue(1, 3)
69+
c.enqueue(2, 2)
70+
c.enqueue(3, 1)
71+
72+
if got := c.correlation(); got != -1 {
73+
t.Errorf("got %v, want -1", got)
6874
}
6975
}

0 commit comments

Comments
 (0)