Skip to content

Commit fa7516c

Browse files
committed
Quantiles.
1 parent dbf93b2 commit fa7516c

File tree

6 files changed

+130
-0
lines changed

6 files changed

+130
-0
lines changed

ext/stats/quantile.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package stats
2+
3+
import (
4+
"math"
5+
"slices"
6+
7+
"github.com/ncruces/go-sqlite3"
8+
"github.com/ncruces/go-sqlite3/internal/util"
9+
"github.com/ncruces/sort/quick"
10+
)
11+
12+
const (
13+
median = iota
14+
quant_cont
15+
quant_disc
16+
)
17+
18+
func newQuantile(kind int) func() sqlite3.AggregateFunction {
19+
return func() sqlite3.AggregateFunction { return &quantile{kind: kind} }
20+
}
21+
22+
type quantile struct {
23+
kind int
24+
pos float64
25+
list []float64
26+
}
27+
28+
func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
29+
if a := arg[0]; a.NumericType() != sqlite3.NULL {
30+
q.list = append(q.list, a.Float())
31+
}
32+
if q.kind != median {
33+
q.pos = arg[1].Float()
34+
}
35+
}
36+
37+
func (q *quantile) Value(ctx sqlite3.Context) {
38+
if q.list == nil {
39+
return
40+
}
41+
if q.kind == median {
42+
q.pos = 0.5
43+
}
44+
if q.pos < 0 || q.pos > 1 {
45+
ctx.ResultError(util.ErrorString("quantile: invalid pos"))
46+
return
47+
}
48+
49+
i, f := math.Modf(q.pos * float64(len(q.list)-1))
50+
m0 := quick.Select(q.list, int(i))
51+
52+
if q.kind == quant_disc {
53+
ctx.ResultFloat(m0)
54+
return
55+
}
56+
57+
m1 := slices.Min(q.list[int(i)+1:])
58+
ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0)))
59+
}

ext/stats/quantile_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package stats_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/ncruces/go-sqlite3"
7+
_ "github.com/ncruces/go-sqlite3/embed"
8+
"github.com/ncruces/go-sqlite3/ext/stats"
9+
_ "github.com/ncruces/go-sqlite3/tests/testcfg"
10+
)
11+
12+
func TestRegister_quantile(t *testing.T) {
13+
t.Parallel()
14+
15+
db, err := sqlite3.Open(":memory:")
16+
if err != nil {
17+
t.Fatal(err)
18+
}
19+
defer db.Close()
20+
21+
stats.Register(db)
22+
23+
err = db.Exec(`CREATE TABLE data (x)`)
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
28+
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
29+
if err != nil {
30+
t.Fatal(err)
31+
}
32+
33+
stmt, _, err := db.Prepare(`
34+
SELECT
35+
median(x),
36+
quantile_disc(x, 0.5),
37+
quantile_cont(x, 0.3)
38+
FROM data`)
39+
if err != nil {
40+
t.Fatal(err)
41+
}
42+
defer stmt.Close()
43+
44+
if stmt.Step() {
45+
if got := stmt.ColumnFloat(0); got != 10 {
46+
t.Errorf("got %v, want 10", got)
47+
}
48+
if got := stmt.ColumnFloat(1); got != 7 {
49+
t.Errorf("got %v, want 7", got)
50+
}
51+
if got := stmt.ColumnFloat(2); got != 6.699999999999999 {
52+
t.Errorf("got %v, want 6.7", got)
53+
}
54+
}
55+
}

ext/stats/stats.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
// - regr_slope: slope of the least-squares-fit linear equation
1919
// - regr_intercept: y-intercept of the least-squares-fit linear equation
2020
// - regr_json: all regr stats in a JSON object
21+
// - median: median value
22+
// - quantile_cont: continuous quantile
23+
// - quantile_disc: discrete quantile
2124
//
2225
// These join the [Built-in Aggregate Functions]:
2326
// - count: count rows/values
@@ -27,9 +30,11 @@
2730
// - max: maximum value
2831
//
2932
// See: [ANSI SQL Aggregate Functions]
33+
// See: [DuckDB Aggregate Functions]
3034
//
3135
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
3236
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
37+
// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html
3338
package stats
3439

3540
import "github.com/ncruces/go-sqlite3"
@@ -54,6 +59,9 @@ func Register(db *sqlite3.Conn) {
5459
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
5560
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
5661
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
62+
db.CreateWindowFunction("median", 1, flags, newQuantile(median))
63+
db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont))
64+
db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc))
5765
}
5866

5967
const (

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ require (
1313
lukechampine.com/adiantum v1.1.1
1414
)
1515

16+
require github.com/ncruces/sort v0.1.2
17+
1618
retract v0.4.0 // tagged from the wrong branch

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
22
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
3+
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=
4+
github.com/ncruces/sort v0.1.2/go.mod h1:vEJUTBJtebIuCMmXD18GKo5GJGhsay+xZFOoBEIXFmE=
35
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
46
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
57
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=

go.work.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
2+
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
3+
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
4+
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=

0 commit comments

Comments
 (0)