Skip to content

Commit 312d3b5

Browse files
committed
Statistics functions.
1 parent b71cd29 commit 312d3b5

File tree

6 files changed

+258
-8
lines changed

6 files changed

+258
-8
lines changed

ext/stats/stats.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Package stats provides aggregate functions for statistics.
2+
//
3+
// Functions:
4+
// - var_samp: sample variance
5+
// - var_pop: population variance
6+
// - stddev_samp: sample standard deviation
7+
// - stddev_pop: population standard deviation
8+
//
9+
// See: [ANSI SQL Aggregate Functions]
10+
//
11+
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
12+
package stats
13+
14+
import "github.com/ncruces/go-sqlite3"
15+
16+
// Register registers statistics functions.
17+
func Register(db *sqlite3.Conn) {
18+
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
19+
db.CreateWindowFunction("var_pop", 1, flags, create(var_pop))
20+
db.CreateWindowFunction("var_samp", 1, flags, create(var_samp))
21+
db.CreateWindowFunction("stddev_pop", 1, flags, create(stddev_pop))
22+
db.CreateWindowFunction("stddev_samp", 1, flags, create(stddev_samp))
23+
}
24+
25+
// Selectors stored in state.kind; Value switches on them to pick the statistic
// to report.
const (
	var_pop     = iota // population variance
	var_samp           // sample variance
	stddev_pop         // population standard deviation
	stddev_samp        // sample standard deviation
)
31+
32+
func create(kind int) func() sqlite3.AggregateFunction {
33+
return func() sqlite3.AggregateFunction { return &state{kind: kind} }
34+
}
35+
36+
type state struct {
37+
kind int
38+
welford
39+
}
40+
41+
func (f *state) Value(ctx sqlite3.Context) {
42+
var r float64
43+
switch f.kind {
44+
case var_pop:
45+
r = f.var_pop()
46+
case var_samp:
47+
r = f.var_samp()
48+
case stddev_pop:
49+
r = f.stddev_pop()
50+
case stddev_samp:
51+
r = f.stddev_samp()
52+
}
53+
ctx.ResultFloat(r)
54+
}
55+
56+
func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
57+
if a := arg[0]; a.Type() != sqlite3.NULL {
58+
f.enqueue(a.Float())
59+
}
60+
}
61+
62+
func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
63+
if a := arg[0]; a.Type() != sqlite3.NULL {
64+
f.dequeue(a.Float())
65+
}
66+
}

ext/stats/stats_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package stats
2+
3+
import (
4+
"math"
5+
"testing"
6+
7+
"github.com/ncruces/go-sqlite3"
8+
_ "github.com/ncruces/go-sqlite3/embed"
9+
)
10+
11+
func TestRegister(t *testing.T) {
12+
t.Parallel()
13+
14+
db, err := sqlite3.Open(":memory:")
15+
if err != nil {
16+
t.Fatal(err)
17+
}
18+
defer db.Close()
19+
20+
Register(db)
21+
22+
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`)
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
27+
err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`)
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
32+
stmt, _, err := db.Prepare(`SELECT
33+
sum(col), avg(col),
34+
var_samp(col), var_pop(col),
35+
stddev_samp(col), stddev_pop(col) FROM data`)
36+
if err != nil {
37+
t.Fatal(err)
38+
}
39+
defer stmt.Close()
40+
41+
if stmt.Step() {
42+
if got := stmt.ColumnFloat(0); got != 40 {
43+
t.Errorf("got %v, want 40", got)
44+
}
45+
if got := stmt.ColumnFloat(1); got != 10 {
46+
t.Errorf("got %v, want 10", got)
47+
}
48+
if got := stmt.ColumnFloat(2); got != 30 {
49+
t.Errorf("got %v, want 30", got)
50+
}
51+
if got := stmt.ColumnFloat(3); got != 22.5 {
52+
t.Errorf("got %v, want 22.5", got)
53+
}
54+
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
55+
t.Errorf("got %v, want √30", got)
56+
}
57+
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
58+
t.Errorf("got %v, want √22.5", got)
59+
}
60+
}
61+
62+
{
63+
stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`)
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
defer stmt.Close()
68+
69+
want := [...]float64{0, 4.5, 18, 0, 0}
70+
for i := 0; stmt.Step(); i++ {
71+
if got := stmt.ColumnFloat(0); got != want[i] {
72+
t.Errorf("got %v, want %v", got, want[i])
73+
}
74+
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
75+
t.Errorf("got %v, want %v", got, want[i])
76+
}
77+
}
78+
}
79+
}

ext/stats/welford.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package stats
2+
3+
import "math"
4+
5+
// Welford's algorithm with Kahan summation:
6+
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
7+
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
8+
9+
// welford is an online accumulator for mean and variance. Values can be
// added with enqueue and removed again with dequeue, which makes it usable
// for sliding windows.
type welford struct {
	m1, m2 kahan  // m1: running mean; m2: running sum of squared deviations from the mean
	n      uint64 // number of values currently accumulated
}

// average returns the running mean of the accumulated values.
func (w welford) average() float64 {
	return w.m1.hi
}

// var_pop returns the population variance, or NaN when no values are
// accumulated.
func (w welford) var_pop() float64 {
	if w.n == 0 {
		return math.NaN()
	}
	return w.m2.hi / float64(w.n)
}

// var_samp returns the sample variance, or NaN when fewer than two values are
// accumulated.
func (w welford) var_samp() float64 {
	// Guard explicitly: n is unsigned, so n-1 would wrap around for n == 0,
	// and for n == 1 the division by zero could yield ±Inf instead of NaN if
	// m2 carries any rounding residue.
	if w.n < 2 {
		return math.NaN()
	}
	return w.m2.hi / float64(w.n-1)
}

// stddev_pop returns the population standard deviation (square root of
// var_pop); NaN propagates.
func (w welford) stddev_pop() float64 {
	return math.Sqrt(w.var_pop())
}

// stddev_samp returns the sample standard deviation (square root of
// var_samp); NaN propagates.
func (w welford) stddev_samp() float64 {
	return math.Sqrt(w.var_samp())
}

// enqueue adds x to the accumulator.
func (w *welford) enqueue(x float64) {
	w.n++
	d1 := x - w.m1.hi // deviation from the mean before the update
	w.m1.add(d1 / float64(w.n))
	d2 := x - w.m1.hi // deviation from the mean after the update
	w.m2.add(d1 * d2)
}

// dequeue removes a previously enqueued x from the accumulator.
//
// Removing the last remaining value (or calling dequeue on an empty
// accumulator) resets the state to its zero value.
func (w *welford) dequeue(x float64) {
	if w.n <= 1 {
		// The general update below would divide by zero here, leaving NaN/Inf
		// in the mean and corrupting every later enqueue; it would also wrap
		// n around when called on an empty accumulator.
		*w = welford{}
		return
	}
	w.n--
	d1 := x - w.m1.hi // deviation from the mean before the update
	w.m1.sub(d1 / float64(w.n))
	d2 := x - w.m1.hi // deviation from the mean after the update
	w.m2.sub(d1 * d2)
}

// kahan is a compensated (Kahan) sum: hi holds the running total and lo the
// rounding error still to be folded into it.
type kahan struct{ hi, lo float64 }

// add adds x to the sum, carrying the rounding error forward in lo.
func (k *kahan) add(x float64) {
	y := k.lo + x
	t := k.hi + y
	k.lo = y - (t - k.hi)
	k.hi = t
}

// sub subtracts x from the sum, carrying the rounding error forward in lo.
func (k *kahan) sub(x float64) {
	y := k.lo - x
	t := k.hi + y
	k.lo = y - (t - k.hi)
	k.hi = t
}

ext/stats/welford_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package stats
2+
3+
import (
4+
"math"
5+
"testing"
6+
)
7+
8+
func Test_welford(t *testing.T) {
9+
var s1, s2 welford
10+
11+
s1.enqueue(4)
12+
s1.enqueue(7)
13+
s1.enqueue(13)
14+
s1.enqueue(16)
15+
if got := s1.average(); got != 10 {
16+
t.Errorf("got %v, want 10", got)
17+
}
18+
if got := s1.var_samp(); got != 30 {
19+
t.Errorf("got %v, want 30", got)
20+
}
21+
if got := s1.var_pop(); got != 22.5 {
22+
t.Errorf("got %v, want 22.5", got)
23+
}
24+
if got := s1.stddev_samp(); got != math.Sqrt(30) {
25+
t.Errorf("got %v, want √30", got)
26+
}
27+
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
28+
t.Errorf("got %v, want √22.5", got)
29+
}
30+
31+
s1.dequeue(4)
32+
s2.enqueue(7)
33+
s2.enqueue(13)
34+
s2.enqueue(16)
35+
if s1 != s2 {
36+
t.Errorf("got %v, want %v", s1, s2)
37+
}
38+
}

ext/unicode/unicode.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
// Package unicode provides an alternative to the SQLite ICU extension.
22
//
3-
// Provides Unicode aware:
4-
// - upper and lower functions,
3+
// Like the [ICU extension], it provides Unicode aware:
4+
// - upper() and lower() functions,
55
// - LIKE and REGEXP operators,
66
// - collation sequences.
77
//
8-
// This package is not 100% compatible with the ICU extension:
9-
// - upper and lower use [strings.ToUpper], [strings.ToLower] and [cases];
8+
// The implementation is not 100% compatible with the [ICU extension]:
9+
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
1010
// - the LIKE operator follows [strings.EqualFold] rules;
1111
// - the REGEXP operator uses Go [regex/syntax];
1212
// - collation sequences use [collate].
1313
//
1414
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
15+
//
16+
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
1517
package unicode
1618

1719
import (
@@ -45,16 +47,17 @@ func Register(db *sqlite3.Conn) {
4547
return
4648
}
4749

48-
err := RegisterCollation(db, name, arg[0].Text())
50+
err := RegisterCollation(db, arg[0].Text(), name)
4951
if err != nil {
5052
ctx.ResultError(err)
5153
return
5254
}
5355
})
5456
}
5557

56-
func RegisterCollation(db *sqlite3.Conn, name, lang string) error {
57-
tag, err := language.Parse(lang)
58+
// RegisterCollation registers a Unicode collation sequence for a database connection.
59+
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
60+
tag, err := language.Parse(locale)
5861
if err != nil {
5962
return err
6063
}

func_win_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func ExampleConn_CreateWindowFunction() {
2626
log.Fatal(err)
2727
}
2828

29-
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
29+
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
3030
if err != nil {
3131
log.Fatal(err)
3232
}

0 commit comments

Comments
 (0)