Skip to content

Commit dbf764a

Browse files
committed
Boolean aggregates.
1 parent 8fd878a commit dbf764a

File tree

7 files changed

+163
-40
lines changed

7 files changed

+163
-40
lines changed

ext/stats/TODO.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,8 @@ https://sqlite.org/windowfunctions.html#builtins
4848

4949
## Boolean aggregates
5050

51-
- [ ] `ALL(boolean)`
52-
- [ ] `ANY(boolean)`
53-
- [ ] `EVERY(boolean)`
54-
- [ ] `SOME(boolean)`
51+
- [X] `EVERY(boolean)`
52+
- [X] `SOME(boolean)`
5553

5654
## Additional aggregates
5755

ext/stats/boolean.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package stats
2+
3+
import "github.com/ncruces/go-sqlite3"
4+
5+
const (
6+
every = iota
7+
some
8+
)
9+
10+
func newBoolean(kind int) func() sqlite3.AggregateFunction {
11+
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
12+
}
13+
14+
type boolean struct {
15+
count int
16+
total int
17+
kind int
18+
}
19+
20+
func (b *boolean) Value(ctx sqlite3.Context) {
21+
if b.kind == every {
22+
ctx.ResultBool(b.count == b.total)
23+
} else {
24+
ctx.ResultBool(b.count > 0)
25+
}
26+
}
27+
28+
func (b *boolean) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
29+
if arg[0].Type() == sqlite3.NULL {
30+
return
31+
}
32+
if arg[0].Bool() {
33+
b.count++
34+
}
35+
b.total++
36+
}
37+
38+
func (b *boolean) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
39+
if arg[0].Type() == sqlite3.NULL {
40+
return
41+
}
42+
if arg[0].Bool() {
43+
b.count--
44+
}
45+
b.total--
46+
}

ext/stats/boolean_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package stats_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/ncruces/go-sqlite3"
7+
_ "github.com/ncruces/go-sqlite3/embed"
8+
"github.com/ncruces/go-sqlite3/ext/stats"
9+
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
10+
)
11+
12+
func TestRegister_boolean(t *testing.T) {
13+
t.Parallel()
14+
15+
db, err := sqlite3.Open(":memory:")
16+
if err != nil {
17+
t.Fatal(err)
18+
}
19+
defer db.Close()
20+
21+
stats.Register(db)
22+
23+
err = db.Exec(`CREATE TABLE data (x)`)
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
28+
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), (13), (NULL), (16), (3.14)`)
29+
if err != nil {
30+
t.Fatal(err)
31+
}
32+
33+
stmt, _, err := db.Prepare(`
34+
SELECT
35+
every(x > 0),
36+
every(x > 10),
37+
some(x > 10),
38+
some(x > 20)
39+
FROM data`)
40+
if err != nil {
41+
t.Fatal(err)
42+
}
43+
if stmt.Step() {
44+
if got := stmt.ColumnBool(0); got != true {
45+
t.Errorf("got %v, want true", got)
46+
}
47+
if got := stmt.ColumnBool(1); got != false {
48+
t.Errorf("got %v, want false", got)
49+
}
50+
if got := stmt.ColumnBool(2); got != true {
51+
t.Errorf("got %v, want true", got)
52+
}
53+
if got := stmt.ColumnBool(3); got != false {
54+
t.Errorf("got %v, want false", got)
55+
}
56+
}
57+
stmt.Close()
58+
59+
stmt, _, err = db.Prepare(`SELECT every(x > 10) OVER (ROWS 1 PRECEDING) FROM data`)
60+
if err != nil {
61+
t.Fatal(err)
62+
}
63+
64+
want := [...]bool{false, false, false, true, true, false}
65+
for i := 0; stmt.Step(); i++ {
66+
if got := stmt.ColumnBool(0); got != want[i] {
67+
t.Errorf("got %v, want %v", got, want[i])
68+
}
69+
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
70+
t.Errorf("got %v, want INTEGER", got)
71+
}
72+
}
73+
stmt.Close()
74+
}

ext/stats/stats.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
// - quantile_disc: discrete quantile
2222
// - quantile_cont: continuous quantile
2323
// - median: median value
24+
// - every: boolean and
25+
// - some: boolean or
2426
//
2527
// These join the [Built-in Aggregate Functions]:
2628
// - count: count rows/values
@@ -29,9 +31,16 @@
2931
// - min: minimum value
3032
// - max: maximum value
3133
//
34+
// And the [Built-in Window Functions]:
35+
// - rank: rank of the current row with gaps
36+
// - dense_rank: rank of the current row without gaps
37+
// - percent_rank: relative rank of the row
38+
// - cume_dist: cumulative distribution
39+
//
3240
// See: [ANSI SQL Aggregate Functions], [DuckDB Aggregate Functions]
3341
//
3442
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
43+
// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins
3544
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
3645
// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html
3746
package stats
@@ -61,6 +70,8 @@ func Register(db *sqlite3.Conn) {
6170
db.CreateWindowFunction("median", 1, flags, newQuantile(median))
6271
db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont))
6372
db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc))
73+
db.CreateWindowFunction("every", 1, flags, newBoolean(every))
74+
db.CreateWindowFunction("some", 1, flags, newBoolean(some))
6475
}
6576

6677
const (

ext/stats/stats_test.go

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ func TestRegister_variance(t *testing.T) {
4040
if err != nil {
4141
t.Fatal(err)
4242
}
43-
defer stmt.Close()
44-
4543
if stmt.Step() {
4644
if got := stmt.ColumnFloat(0); got != 40 {
4745
t.Errorf("got %v, want 40", got)
@@ -62,24 +60,23 @@ func TestRegister_variance(t *testing.T) {
6260
t.Errorf("got %v, want √22.5", got)
6361
}
6462
}
63+
stmt.Close()
6564

66-
{
67-
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
68-
if err != nil {
69-
t.Fatal(err)
70-
}
71-
defer stmt.Close()
65+
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
66+
if err != nil {
67+
t.Fatal(err)
68+
}
7269

73-
want := [...]float64{0, 4.5, 18, 0, 0}
74-
for i := 0; stmt.Step(); i++ {
75-
if got := stmt.ColumnFloat(0); got != want[i] {
76-
t.Errorf("got %v, want %v", got, want[i])
77-
}
78-
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
79-
t.Errorf("got %v, want %v", got, want[i])
80-
}
70+
want := [...]float64{0, 4.5, 18, 0, 0}
71+
for i := 0; stmt.Step(); i++ {
72+
if got := stmt.ColumnFloat(0); got != want[i] {
73+
t.Errorf("got %v, want %v", got, want[i])
74+
}
75+
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
76+
t.Errorf("got %v, want %v", got, want[i])
8177
}
8278
}
79+
stmt.Close()
8380
}
8481

8582
func TestRegister_covariance(t *testing.T) {
@@ -113,8 +110,6 @@ func TestRegister_covariance(t *testing.T) {
113110
if err != nil {
114111
t.Fatal(err)
115112
}
116-
defer stmt.Close()
117-
118113
if stmt.Step() {
119114
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
120115
t.Errorf("got %v, want 0.9881049293224639", got)
@@ -159,24 +154,23 @@ func TestRegister_covariance(t *testing.T) {
159154
t.Errorf("got %v, want 5", got)
160155
}
161156
}
157+
stmt.Close()
162158

163-
{
164-
stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
165-
if err != nil {
166-
t.Fatal(err)
167-
}
168-
defer stmt.Close()
159+
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
160+
if err != nil {
161+
t.Fatal(err)
162+
}
169163

170-
want := [...]float64{0, 10, 30, 75, 22.5}
171-
for i := 0; stmt.Step(); i++ {
172-
if got := stmt.ColumnFloat(0); got != want[i] {
173-
t.Errorf("got %v, want %v", got, want[i])
174-
}
175-
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
176-
t.Errorf("got %v, want %v", got, want[i])
177-
}
164+
want := [...]float64{0, 10, 30, 75, 22.5}
165+
for i := 0; stmt.Step(); i++ {
166+
if got := stmt.ColumnFloat(0); got != want[i] {
167+
t.Errorf("got %v, want %v", got, want[i])
168+
}
169+
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
170+
t.Errorf("got %v, want %v", got, want[i])
178171
}
179172
}
173+
stmt.Close()
180174
}
181175

182176
func Benchmark_average(b *testing.B) {

stmt.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,12 +441,12 @@ func (s *Stmt) ColumnOriginName(col int) string {
441441
// ColumnBool returns the value of the result column as a bool.
442442
// The leftmost column of the result set has the index 0.
443443
// SQLite does not have a separate boolean storage class.
444-
// Instead, boolean values are retrieved as integers,
444+
// Instead, boolean values are retrieved as numbers,
445445
// with 0 converted to false and any other value to true.
446446
//
447447
// https://sqlite.org/c3ref/column_blob.html
448448
func (s *Stmt) ColumnBool(col int) bool {
449-
return s.ColumnInt64(col) != 0
449+
return s.ColumnFloat(col) != 0
450450
}
451451

452452
// ColumnInt returns the value of the result column as an int.

value.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ func (v Value) NumericType() Datatype {
6868

6969
// Bool returns the value as a bool.
7070
// SQLite does not have a separate boolean storage class.
71-
// Instead, boolean values are retrieved as integers,
71+
// Instead, boolean values are retrieved as numbers,
7272
// with 0 converted to false and any other value to true.
7373
//
7474
// https://sqlite.org/c3ref/value_blob.html
7575
func (v Value) Bool() bool {
76-
return v.Int64() != 0
76+
return v.Float() != 0
7777
}
7878

7979
// Int returns the value as an int.

0 commit comments

Comments
 (0)