Skip to content

Commit d78a53a

Browse files
committed
Multiple quantiles.
1 parent 19bc6e3 commit d78a53a

File tree

2 files changed

+66
-25
lines changed

2 files changed

+66
-25
lines changed

ext/stats/quantile.go

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package stats
22

33
import (
4+
"encoding/json"
5+
"fmt"
46
"math"
57
"slices"
68

@@ -21,39 +23,67 @@ func newQuantile(kind int) func() sqlite3.AggregateFunction {
2123

2224
type quantile struct {
2325
kind int
24-
pos float64
25-
list []float64
26+
nums []float64
27+
arg1 []byte
2628
}
2729

2830
func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
2931
if a := arg[0]; a.NumericType() != sqlite3.NULL {
30-
q.list = append(q.list, a.Float())
32+
q.nums = append(q.nums, a.Float())
3133
}
3234
if q.kind != median {
33-
q.pos = arg[1].Float()
35+
q.arg1 = arg[1].Blob(q.arg1[:0])
3436
}
3537
}
3638

3739
func (q *quantile) Value(ctx sqlite3.Context) {
38-
if len(q.list) == 0 {
40+
if len(q.nums) == 0 {
3941
return
4042
}
43+
44+
var (
45+
err error
46+
float float64
47+
floats []float64
48+
)
4149
if q.kind == median {
42-
q.pos = 0.5
50+
float, err = getQuantile(q.nums, 0.5, false)
51+
ctx.ResultFloat(float)
52+
} else if err = json.Unmarshal(q.arg1, &float); err == nil {
53+
float, err = getQuantile(q.nums, float, q.kind == quant_disc)
54+
ctx.ResultFloat(float)
55+
} else if err = json.Unmarshal(q.arg1, &floats); err == nil {
56+
err = getQuantiles(q.nums, floats, q.kind == quant_disc)
57+
ctx.ResultJSON(floats)
4358
}
44-
if q.pos < 0 || q.pos > 1 {
45-
ctx.ResultError(util.ErrorString("quantile: invalid pos"))
46-
return
59+
if err != nil {
60+
ctx.ResultError(fmt.Errorf("quantile: %w", err))
4761
}
62+
}
4863

49-
i, f := math.Modf(q.pos * float64(len(q.list)-1))
50-
m0 := quick.Select(q.list, int(i))
64+
func getQuantile(nums []float64, pos float64, disc bool) (float64, error) {
65+
if pos < 0 || pos > 1 {
66+
return 0, util.ErrorString("invalid pos")
67+
}
5168

52-
if f == 0 || q.kind == quant_disc {
53-
ctx.ResultFloat(m0)
54-
return
69+
i, f := math.Modf(pos * float64(len(nums)-1))
70+
m0 := quick.Select(nums, int(i))
71+
72+
if f == 0 || disc {
73+
return m0, nil
5574
}
5675

57-
m1 := slices.Min(q.list[int(i)+1:])
58-
ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0)))
76+
m1 := slices.Min(nums[int(i)+1:])
77+
return math.FMA(f, m1, -math.FMA(f, m0, -m0)), nil
78+
}
79+
80+
func getQuantiles(nums []float64, pos []float64, disc bool) error {
81+
for i := range pos {
82+
v, err := getQuantile(nums, pos[i], disc)
83+
if err != nil {
84+
return err
85+
}
86+
pos[i] = v
87+
}
88+
return nil
5989
}

ext/stats/quantile_test.go

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package stats_test
22

33
import (
4+
"slices"
45
"testing"
56

67
"github.com/ncruces/go-sqlite3"
@@ -34,7 +35,7 @@ func TestRegister_quantile(t *testing.T) {
3435
SELECT
3536
median(x),
3637
quantile_disc(x, 0.5),
37-
quantile_cont(x, 0.25)
38+
quantile_cont(x, '[0.25, 0.5, 0.75]')
3839
FROM data`)
3940
if err != nil {
4041
t.Fatal(err)
@@ -46,8 +47,12 @@ func TestRegister_quantile(t *testing.T) {
4647
if got := stmt.ColumnFloat(1); got != 7 {
4748
t.Errorf("got %v, want 7", got)
4849
}
49-
if got := stmt.ColumnFloat(2); got != 6.25 {
50-
t.Errorf("got %v, want 6.25", got)
50+
var got []float64
51+
if err := stmt.ColumnJSON(2, &got); err != nil {
52+
t.Error(err)
53+
}
54+
if !slices.Equal(got, []float64{6.25, 10, 13.75}) {
55+
t.Errorf("got %v, want [6.25 10 13.75]", got)
5156
}
5257
}
5358
stmt.Close()
@@ -56,7 +61,7 @@ func TestRegister_quantile(t *testing.T) {
5661
SELECT
5762
median(x),
5863
quantile_disc(x, 0.5),
59-
quantile_cont(x, 0.25)
64+
quantile_cont(x, '[0.25, 0.5, 0.75]')
6065
FROM data
6166
WHERE x < 5`)
6267
if err != nil {
@@ -69,8 +74,12 @@ func TestRegister_quantile(t *testing.T) {
6974
if got := stmt.ColumnFloat(1); got != 4 {
7075
t.Errorf("got %v, want 4", got)
7176
}
72-
if got := stmt.ColumnFloat(2); got != 4 {
73-
t.Errorf("got %v, want 4", got)
77+
var got []float64
78+
if err := stmt.ColumnJSON(2, &got); err != nil {
79+
t.Error(err)
80+
}
81+
if !slices.Equal(got, []float64{4, 4, 4}) {
82+
t.Errorf("got %v, want [4 4 4]", got)
7483
}
7584
}
7685
stmt.Close()
@@ -79,7 +88,7 @@ func TestRegister_quantile(t *testing.T) {
7988
SELECT
8089
median(x),
8190
quantile_disc(x, 0.5),
82-
quantile_cont(x, 0.25)
91+
quantile_cont(x, '[0.25, 0.5, 0.75]')
8392
FROM data
8493
WHERE x < 0`)
8594
if err != nil {
@@ -101,13 +110,15 @@ func TestRegister_quantile(t *testing.T) {
101110
stmt, _, err = db.Prepare(`
102111
SELECT
103112
quantile_disc(x, -2),
104-
quantile_cont(x, +2)
113+
quantile_cont(x, +2),
114+
quantile_cont(x, ''),
115+
quantile_cont(x, '[100]')
105116
FROM data`)
106117
if err != nil {
107118
t.Fatal(err)
108119
}
109120
if stmt.Step() {
110-
t.Fatal("want error")
121+
t.Error("want error")
111122
}
112123
stmt.Close()
113124
}

0 commit comments

Comments
 (0)