Skip to content

Commit 3ccccfb

Browse files
authored
Support returning any from callbacks (#1046)
Support returning any from callbacks
1 parent b819467 commit 3ccccfb

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

callback.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,32 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
353353
return nil
354354
}
355355

356+
func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error {
357+
if v.IsNil() {
358+
C.sqlite3_result_null(ctx)
359+
return nil
360+
}
361+
362+
cb, err := callbackRet(v.Elem().Type())
363+
if err != nil {
364+
return err
365+
}
366+
367+
return cb(ctx, v.Elem())
368+
}
369+
356370
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
357371
switch typ.Kind() {
358372
case reflect.Interface:
359373
errorInterface := reflect.TypeOf((*error)(nil)).Elem()
360374
if typ.Implements(errorInterface) {
361375
return callbackRetNil, nil
362376
}
377+
378+
if typ.NumMethod() == 0 {
379+
return callbackRetGeneric, nil
380+
}
381+
363382
fallthrough
364383
case reflect.Slice:
365384
if typ.Elem().Kind() != reflect.Uint8 {

callback_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,15 @@ func TestCallbackConverters(t *testing.T) {
102102
}
103103
}
104104
}
105+
106+
func TestCallbackReturnAny(t *testing.T) {
107+
udf := func() interface{} {
108+
return 1
109+
}
110+
111+
typ := reflect.TypeOf(udf)
112+
_, err := callbackRet(typ.Out(0))
113+
if err != nil {
114+
t.Errorf("Expected valid callback for any return type, got: %s", err)
115+
}
116+
}

sqlite3_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) {
14491449
}
14501450
}
14511451

1452+
type mode struct {
1453+
counts map[interface{}]int
1454+
top interface{}
1455+
topCount int
1456+
}
1457+
1458+
func newMode() *mode {
1459+
return &mode{
1460+
counts: map[interface{}]int{},
1461+
}
1462+
}
1463+
1464+
func (m *mode) Step(x interface{}) {
1465+
m.counts[x]++
1466+
c := m.counts[x]
1467+
if c > m.topCount {
1468+
m.top = x
1469+
m.topCount = c
1470+
}
1471+
}
1472+
1473+
func (m *mode) Done() interface{} {
1474+
return m.top
1475+
}
1476+
1477+
func TestAggregatorRegistration_GenericReturn(t *testing.T) {
1478+
sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{
1479+
ConnectHook: func(conn *SQLiteConn) error {
1480+
return conn.RegisterAggregator("mode", newMode, true)
1481+
},
1482+
})
1483+
db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:")
1484+
if err != nil {
1485+
t.Fatal("Failed to open database:", err)
1486+
}
1487+
defer db.Close()
1488+
1489+
_, err = db.Exec("create table foo (department integer, profits integer)")
1490+
if err != nil {
1491+
t.Fatal("Failed to create table:", err)
1492+
}
1493+
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)")
1494+
if err != nil {
1495+
t.Fatal("Failed to insert records:", err)
1496+
}
1497+
1498+
var mode int
1499+
err = db.QueryRow("select mode(profits) from foo").Scan(&mode)
1500+
if err != nil {
1501+
t.Fatal("MODE query error:", err)
1502+
}
1503+
1504+
if mode != 20 {
1505+
t.Fatal("Got incorrect mode. Wanted 20, got: ", mode)
1506+
}
1507+
}
1508+
14521509
func rot13(r rune) rune {
14531510
switch {
14541511
case r >= 'A' && r <= 'Z':

0 commit comments

Comments
 (0)