Skip to content

Commit cdb861a

Browse files
authored
Merge pull request #27 from opencensus-integrations/implement-extra-row-interfaces
Implement additional driver.Rows interfaces
2 parents 6684c3f + 49b922c commit cdb861a

File tree

3 files changed

+260
-10
lines changed

3 files changed

+260
-10
lines changed

driver.go

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"reflect"
1011
"strconv"
1112
"sync"
1213

@@ -30,10 +31,18 @@ var (
3031
attrDeprecated = trace.StringAttribute("ocsql.warning", "database driver uses deprecated features")
3132

3233
// Compile time assertions
33-
_ driver.Driver = &ocDriver{}
34-
_ conn = &ocConn{}
35-
_ driver.Result = &ocResult{}
36-
_ driver.Rows = &ocRows{}
34+
_ driver.Driver = &ocDriver{}
35+
_ conn = &ocConn{}
36+
_ driver.Result = &ocResult{}
37+
_ driver.Stmt = &ocStmt{}
38+
_ driver.StmtExecContext = &ocStmt{}
39+
_ driver.StmtQueryContext = &ocStmt{}
40+
_ driver.Rows = &ocRows{}
41+
_ driver.RowsNextResultSet = &ocRows{}
42+
_ driver.RowsColumnTypeDatabaseTypeName = &ocRows{}
43+
_ driver.RowsColumnTypeLength = &ocRows{}
44+
_ driver.RowsColumnTypeNullable = &ocRows{}
45+
_ driver.RowsColumnTypePrecisionScale = &ocRows{}
3746
)
3847

3948
// Register initializes and registers our ocsql wrapped database driver
@@ -249,7 +258,7 @@ func (c ocConn) Query(query string, args []driver.Value) (rows driver.Rows, err
249258
return nil, err
250259
}
251260

252-
return ocRows{parent: rows, ctx: ctx, options: c.options}, nil
261+
return wrapRows(ctx, rows, c.options), nil
253262
}
254263

255264
return nil, driver.ErrSkip
@@ -287,7 +296,7 @@ func (c ocConn) QueryContext(ctx context.Context, query string, args []driver.Na
287296
return nil, err
288297
}
289298

290-
return ocRows{parent: rows, ctx: ctx, options: c.options}, nil
299+
return wrapRows(ctx, rows, c.options), nil
291300
}
292301

293302
return nil, driver.ErrSkip
@@ -529,7 +538,7 @@ func (s ocStmt) Query(args []driver.Value) (rows driver.Rows, err error) {
529538
if err != nil {
530539
return nil, err
531540
}
532-
rows, err = ocRows{parent: rows, ctx: ctx, options: s.options}, nil
541+
rows, err = wrapRows(ctx, rows, s.options), nil
533542
return
534543
}
535544

@@ -603,17 +612,93 @@ func (s ocStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (row
603612
if err != nil {
604613
return nil, err
605614
}
606-
rows, err = ocRows{parent: rows, ctx: ctx, options: s.options}, nil
615+
rows, err = wrapRows(ctx, rows, s.options), nil
607616
return
608617
}
609618

610-
// ocRows implements driver.Rows.
619+
// withRowsColumnTypeScanType is the same as the driver.RowsColumnTypeScanType
620+
// interface except it omits the driver.Rows embedded interface.
621+
// If the original driver.Rows implementation wrapped by ocsql supports
622+
// RowsColumnTypeScanType we enable the original method implementation in the
623+
// returned driver.Rows from wrapRows by doing a composition with ocRows.
624+
type withRowsColumnTypeScanType interface {
625+
ColumnTypeScanType(index int) reflect.Type
626+
}
627+
628+
// ocRows implements driver.Rows and all enhancement interfaces except
629+
// driver.RowsColumnTypeScanType.
611630
type ocRows struct {
612631
parent driver.Rows
613632
ctx context.Context
614633
options TraceOptions
615634
}
616635

636+
// HasNextResultSet calls the implements the driver.RowsNextResultSet for ocRows.
637+
// It returns the the underlying result of HasNextResultSet from the ocRows.parent
638+
// if the parent implements driver.RowsNextResultSet.
639+
func (r ocRows) HasNextResultSet() bool {
640+
if v, ok := r.parent.(driver.RowsNextResultSet); ok {
641+
return v.HasNextResultSet()
642+
}
643+
644+
return false
645+
}
646+
647+
// NextResultsSet calls the implements the driver.RowsNextResultSet for ocRows.
648+
// It returns the the underlying result of NextResultSet from the ocRows.parent
649+
// if the parent implements driver.RowsNextResultSet.
650+
func (r ocRows) NextResultSet() error {
651+
if v, ok := r.parent.(driver.RowsNextResultSet); ok {
652+
return v.NextResultSet()
653+
}
654+
655+
return io.EOF
656+
}
657+
658+
// ColumnTypeDatabaseTypeName calls the implements the driver.RowsColumnTypeDatabaseTypeName for ocRows.
659+
// It returns the the underlying result of ColumnTypeDatabaseTypeName from the ocRows.parent
660+
// if the parent implements driver.RowsColumnTypeDatabaseTypeName.
661+
func (r ocRows) ColumnTypeDatabaseTypeName(index int) string {
662+
if v, ok := r.parent.(driver.RowsColumnTypeDatabaseTypeName); ok {
663+
return v.ColumnTypeDatabaseTypeName(index)
664+
}
665+
666+
return ""
667+
}
668+
669+
// ColumnTypeLength calls the implements the driver.RowsColumnTypeLength for ocRows.
670+
// It returns the the underlying result of ColumnTypeLength from the ocRows.parent
671+
// if the parent implements driver.RowsColumnTypeLength.
672+
func (r ocRows) ColumnTypeLength(index int) (length int64, ok bool) {
673+
if v, ok := r.parent.(driver.RowsColumnTypeLength); ok {
674+
return v.ColumnTypeLength(index)
675+
}
676+
677+
return 0, false
678+
}
679+
680+
// ColumnTypeNullable calls the implements the driver.RowsColumnTypeNullable for ocRows.
681+
// It returns the the underlying result of ColumnTypeNullable from the ocRows.parent
682+
// if the parent implements driver.RowsColumnTypeNullable.
683+
func (r ocRows) ColumnTypeNullable(index int) (nullable, ok bool) {
684+
if v, ok := r.parent.(driver.RowsColumnTypeNullable); ok {
685+
return v.ColumnTypeNullable(index)
686+
}
687+
688+
return false, false
689+
}
690+
691+
// ColumnTypePrecisionScale calls the implements the driver.RowsColumnTypePrecisionScale for ocRows.
692+
// It returns the the underlying result of ColumnTypePrecisionScale from the ocRows.parent
693+
// if the parent implements driver.RowsColumnTypePrecisionScale.
694+
func (r ocRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
695+
if v, ok := r.parent.(driver.RowsColumnTypePrecisionScale); ok {
696+
return v.ColumnTypePrecisionScale(index)
697+
}
698+
699+
return 0, 0, false
700+
}
701+
617702
func (r ocRows) Columns() []string {
618703
return r.parent.Columns()
619704
}
@@ -655,6 +740,33 @@ func (r ocRows) Next(dest []driver.Value) (err error) {
655740
return
656741
}
657742

743+
// wrapRows returns a struct which conforms to the driver.Rows interface.
744+
// ocRows implements all enhancement interfaces that have no effect on
745+
// sql/database logic in case the underlying parent implementation lacks them.
746+
// Currently the one exception is RowsColumnTypeScanType which does not have a
747+
// valid zero value. This interface is tested for and only enabled in case the
748+
// parent implementation supports it.
749+
func wrapRows(ctx context.Context, parent driver.Rows, options TraceOptions) driver.Rows {
750+
var (
751+
ts, hasColumnTypeScan = parent.(driver.RowsColumnTypeScanType)
752+
)
753+
754+
r := ocRows{
755+
parent: parent,
756+
ctx: ctx,
757+
options: options,
758+
}
759+
760+
if hasColumnTypeScan {
761+
return struct {
762+
ocRows
763+
withRowsColumnTypeScanType
764+
}{r, ts}
765+
}
766+
767+
return r
768+
}
769+
658770
// ocTx implemens driver.Tx
659771
type ocTx struct {
660772
parent driver.Tx

driver_go1.10.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ import (
1111
var errConnDone = sql.ErrConnDone
1212

1313
// Compile time assertion
14-
var _ driver.DriverContext = &ocDriver{}
14+
var (
15+
_ driver.DriverContext = &ocDriver{}
16+
_ driver.Connector = &ocDriver{}
17+
)
1518

1619
// ocDriver implements driver.Driver
1720
type ocDriver struct {

driver_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package ocsql
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
"errors"
7+
"io"
8+
"reflect"
9+
"testing"
10+
)
11+
12+
var errDummy = errors.New("dummy")
13+
14+
type stubRows struct{}
15+
16+
func (stubRows) Columns() []string { return []string{"dummy"} }
17+
func (stubRows) Close() error { return errDummy }
18+
func (stubRows) Next([]driver.Value) error { return errDummy }
19+
func (stubRows) HasNextResultSet() bool { return true }
20+
func (stubRows) NextResultSet() error { return errDummy }
21+
func (stubRows) ColumnTypeScanType(int) reflect.Type { return reflect.TypeOf(stubRows{}) }
22+
func (stubRows) ColumnTypeDatabaseTypeName(index int) string { return "dummy" }
23+
func (stubRows) ColumnTypeLength(index int) (length int64, ok bool) { return 1, true }
24+
func (stubRows) ColumnTypeNullable(index int) (nullable, ok bool) { return true, true }
25+
func (stubRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
26+
return 1, 1, true
27+
}
28+
29+
func TestWrappingTransparency(t *testing.T) {
30+
var (
31+
ctx = context.Background()
32+
oRows = &stubRows{}
33+
wRows = wrapRows(ctx, oRows, AllTraceOptions)
34+
)
35+
36+
if want, have := oRows.Columns(), wRows.Columns(); len(want) != len(have) {
37+
t.Errorf("rows.Column want: %v, have: %v", want, have)
38+
}
39+
40+
if want, have := oRows.Close(), wRows.Close(); want != have {
41+
t.Errorf("rows.Close want: %v, have: %v", want, have)
42+
}
43+
44+
if want, have := oRows.Next(nil), wRows.Next(nil); want != have {
45+
t.Errorf("rows.Next want: %v, have: %v", want, have)
46+
}
47+
48+
if want, have := oRows.HasNextResultSet(), wRows.(driver.RowsNextResultSet).HasNextResultSet(); want != have {
49+
t.Errorf("rows.HasNextResultSet want: %t, have: %t", want, have)
50+
}
51+
52+
if want, have := oRows.NextResultSet(), wRows.(driver.RowsNextResultSet).NextResultSet(); want != have {
53+
t.Errorf("rows.NextResultSet want: %v, have: %v", want, have)
54+
}
55+
56+
if want, have := oRows.ColumnTypeScanType(1), wRows.(driver.RowsColumnTypeScanType).ColumnTypeScanType(1); want != have {
57+
t.Errorf("rows.ColumnTypeScanType want: %v, have: %v", want, have)
58+
}
59+
60+
if want, have := oRows.ColumnTypeDatabaseTypeName(1), wRows.(driver.RowsColumnTypeDatabaseTypeName).ColumnTypeDatabaseTypeName(1); want != have {
61+
t.Errorf("rows.ColumnTypeDatabaseTypeName want: %s, have: %s", want, have)
62+
}
63+
64+
oLength, oOk := oRows.ColumnTypeLength(1)
65+
wLength, wOk := wRows.(driver.RowsColumnTypeLength).ColumnTypeLength(1)
66+
if oLength != wLength || oOk != wOk {
67+
t.Errorf("rows.ColumnTypeLength want: %d:%t, have %d:%t", oLength, oOk, wLength, wOk)
68+
}
69+
70+
oNullable, oOk := oRows.ColumnTypeNullable(1)
71+
wNullable, wOk := wRows.(driver.RowsColumnTypeNullable).ColumnTypeNullable(1)
72+
if oNullable != wNullable || oOk != wOk {
73+
t.Errorf("rows.ColumnTypeNullable want: %t:%t, have %t:%t", oNullable, oOk, wNullable, wOk)
74+
}
75+
76+
oPrecision, oScale, oOk := oRows.ColumnTypePrecisionScale(1)
77+
wPrecision, wScale, wOk := wRows.(driver.RowsColumnTypePrecisionScale).ColumnTypePrecisionScale(1)
78+
if oPrecision != wPrecision || oScale != wScale || oOk != wOk {
79+
t.Errorf("rows.ColumnTypePrecisionScale want: %d:%d:%t, have %d:%d:%t", oPrecision, oScale, oOk, wPrecision, wScale, wOk)
80+
}
81+
}
82+
83+
func TestWrappingFallback(t *testing.T) {
84+
var (
85+
ctx = context.Background()
86+
oRows = struct{ driver.Rows }{&stubRows{}}
87+
wRows = wrapRows(ctx, oRows, AllTraceOptions)
88+
)
89+
90+
if want, have := oRows.Columns(), wRows.Columns(); len(want) != len(have) {
91+
t.Errorf("rows.Column want: %v, have: %v", want, have)
92+
}
93+
94+
if want, have := oRows.Close(), wRows.Close(); want != have {
95+
t.Errorf("rows.Close want: %v, have: %v", want, have)
96+
}
97+
98+
if want, have := oRows.Next(nil), wRows.Next(nil); want != have {
99+
t.Errorf("rows.Next want: %v, have: %v", want, have)
100+
}
101+
102+
if want, have := false, wRows.(driver.RowsNextResultSet).HasNextResultSet(); want != have {
103+
t.Errorf("rows.HasNextResultSet want: %t, have: %t", want, have)
104+
}
105+
106+
if want, have := io.EOF, wRows.(driver.RowsNextResultSet).NextResultSet(); want != have {
107+
t.Errorf("rows.NextResultSet want: %v, have: %v", want, have)
108+
}
109+
110+
if _, ok := wRows.(driver.RowsColumnTypeScanType); ok {
111+
t.Error("rows.ColumnTypeScanType unexpected interface implementation found")
112+
}
113+
114+
if want, have := "", wRows.(driver.RowsColumnTypeDatabaseTypeName).ColumnTypeDatabaseTypeName(1); want != have {
115+
t.Errorf("rows.ColumnTypeDatabaseTypeName want: %s, have: %s", want, have)
116+
}
117+
118+
oLength, oOk := int64(0), false
119+
wLength, wOk := wRows.(driver.RowsColumnTypeLength).ColumnTypeLength(1)
120+
if oLength != wLength || oOk != wOk {
121+
t.Errorf("rows.ColumnTypeLength want: %d:%t, have %d:%t", oLength, oOk, wLength, wOk)
122+
}
123+
124+
oNullable, oOk := false, false
125+
wNullable, wOk := wRows.(driver.RowsColumnTypeNullable).ColumnTypeNullable(1)
126+
if oNullable != wNullable || oOk != wOk {
127+
t.Errorf("rows.ColumnTypeNullable want: %t:%t, have %t:%t", oNullable, oOk, wNullable, wOk)
128+
}
129+
130+
oPrecision, oScale, oOk := int64(0), int64(0), false
131+
wPrecision, wScale, wOk := wRows.(driver.RowsColumnTypePrecisionScale).ColumnTypePrecisionScale(1)
132+
if oPrecision != wPrecision || oScale != wScale || oOk != wOk {
133+
t.Errorf("rows.ColumnTypePrecisionScale want: %d:%d:%t, have %d:%d:%t", oPrecision, oScale, oOk, wPrecision, wScale, wOk)
134+
}
135+
}

0 commit comments

Comments
 (0)