Skip to content

Commit 4bf719a

Browse files
hugo.rutbasvanbeek
authored andcommitted
implement additional driver.Rows interfaces
1 parent 6684c3f commit 4bf719a

File tree

2 files changed

+223
-4
lines changed

2 files changed

+223
-4
lines changed

driver.go

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"reflect"
1011
"strconv"
1112
"sync"
1213

@@ -249,7 +250,7 @@ func (c ocConn) Query(query string, args []driver.Value) (rows driver.Rows, err
249250
return nil, err
250251
}
251252

252-
return ocRows{parent: rows, ctx: ctx, options: c.options}, nil
253+
return wrapRows(rows, ctx, c.options), nil
253254
}
254255

255256
return nil, driver.ErrSkip
@@ -287,7 +288,7 @@ func (c ocConn) QueryContext(ctx context.Context, query string, args []driver.Na
287288
return nil, err
288289
}
289290

290-
return ocRows{parent: rows, ctx: ctx, options: c.options}, nil
291+
return wrapRows(rows, ctx, c.options), nil
291292
}
292293

293294
return nil, driver.ErrSkip
@@ -529,7 +530,7 @@ func (s ocStmt) Query(args []driver.Value) (rows driver.Rows, err error) {
529530
if err != nil {
530531
return nil, err
531532
}
532-
rows, err = ocRows{parent: rows, ctx: ctx, options: s.options}, nil
533+
rows, err = wrapRows(rows, ctx, s.options), nil
533534
return
534535
}
535536

@@ -603,17 +604,92 @@ func (s ocStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (row
603604
if err != nil {
604605
return nil, err
605606
}
606-
rows, err = ocRows{parent: rows, ctx: ctx, options: s.options}, nil
607+
rows, err = wrapRows(rows, ctx, s.options), nil
607608
return
608609
}
609610

611+
612+
// RowsColumnTypeScanType is a duplicate interface for driver.RowsColumnTypeScanType but
613+
// with the driver.Rows composition removed.
614+
//
615+
// This is used to embed a anonymous struct without running into ambiguous method errors.
616+
type RowsColumnTypeScanType interface {
617+
ColumnTypeScanType(index int) reflect.Type
618+
}
619+
610620
// ocRows implements driver.Rows.
611621
type ocRows struct {
612622
parent driver.Rows
613623
ctx context.Context
614624
options TraceOptions
615625
}
616626

627+
// HasNextResultSet calls the implements the driver.RowsNextResultSet for ocRows.
628+
// It returns the the underlying result of HasNextResultSet from the ocRows.parent
629+
// if the parent implements driver.RowsNextResultSet.
630+
func (r ocRows) HasNextResultSet() bool {
631+
if v, ok := r.parent.(driver.RowsNextResultSet); ok {
632+
return v.HasNextResultSet()
633+
}
634+
635+
return false
636+
}
637+
638+
// NextResultsSet calls the implements the driver.RowsNextResultSet for ocRows.
639+
// It returns the the underlying result of NextResultSet from the ocRows.parent
640+
// if the parent implements driver.RowsNextResultSet.
641+
func (r ocRows) NextResultSet() error {
642+
if v, ok := r.parent.(driver.RowsNextResultSet); ok {
643+
return v.NextResultSet()
644+
}
645+
646+
return io.EOF
647+
}
648+
649+
// ColumnTypeDatabaseTypeName calls the implements the driver.RowsColumnTypeDatabaseTypeName for ocRows.
650+
// It returns the the underlying result of ColumnTypeDatabaseTypeName from the ocRows.parent
651+
// if the parent implements driver.RowsColumnTypeDatabaseTypeName.
652+
func (r ocRows) ColumnTypeDatabaseTypeName(index int) string {
653+
if v, ok := r.parent.(driver.RowsColumnTypeDatabaseTypeName); ok {
654+
return v.ColumnTypeDatabaseTypeName(index)
655+
}
656+
657+
return ""
658+
}
659+
660+
// ColumnTypeLength calls the implements the driver.RowsColumnTypeLength for ocRows.
661+
// It returns the the underlying result of ColumnTypeLength from the ocRows.parent
662+
// if the parent implements driver.RowsColumnTypeLength.
663+
func (r ocRows) ColumnTypeLength(index int) (length int64, ok bool) {
664+
if v, ok := r.parent.(driver.RowsColumnTypeLength); ok {
665+
return v.ColumnTypeLength(index)
666+
}
667+
668+
return 0, false
669+
}
670+
671+
// ColumnTypeNullable calls the implements the driver.RowsColumnTypeNullable for ocRows.
672+
// It returns the the underlying result of ColumnTypeNullable from the ocRows.parent
673+
// if the parent implements driver.RowsColumnTypeNullable.
674+
func (r ocRows) ColumnTypeNullable(index int) (nullable, ok bool) {
675+
if v, ok := r.parent.(driver.RowsColumnTypeNullable); ok {
676+
return v.ColumnTypeNullable(index)
677+
}
678+
679+
return false, false
680+
}
681+
682+
// ColumnTypePrecisionScale calls the implements the driver.RowsColumnTypePrecisionScale for ocRows.
683+
// It returns the the underlying result of ColumnTypePrecisionScale from the ocRows.parent
684+
// if the parent implements driver.RowsColumnTypePrecisionScale.
685+
func (r ocRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
686+
if v, ok := r.parent.(driver.RowsColumnTypePrecisionScale); ok {
687+
return v.ColumnTypePrecisionScale(index)
688+
}
689+
690+
return 0, 0, false
691+
}
692+
617693
func (r ocRows) Columns() []string {
618694
return r.parent.Columns()
619695
}
@@ -655,6 +731,30 @@ func (r ocRows) Next(dest []driver.Value) (err error) {
655731
return
656732
}
657733

734+
// wrapRows returns a struct which conforms to the driver.Rows interface.
735+
// It checks if the parent adheres to any additional driver interfaces and returns a matching
736+
// implementation accordingly.
737+
func wrapRows(parent driver.Rows, ctx context.Context, options TraceOptions) driver.Rows {
738+
var (
739+
ts, hasColumnTypeScan = parent.(driver.RowsColumnTypeScanType)
740+
)
741+
742+
r := ocRows{
743+
parent: parent,
744+
ctx: ctx,
745+
options: options,
746+
}
747+
748+
if hasColumnTypeScan {
749+
return struct {
750+
driver.Rows
751+
RowsColumnTypeScanType
752+
}{r, ts}
753+
}
754+
755+
return r
756+
}
757+
658758
// ocTx implemens driver.Tx
659759
type ocTx struct {
660760
parent driver.Tx

driver_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package ocsql_test
2+
3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
"reflect"
7+
"testing"
8+
9+
"github.com/opencensus-integrations/ocsql"
10+
)
11+
12+
type stubRows struct{}
13+
14+
func (stubRows) Columns() []string { return []string{} }
15+
func (stubRows) Close() error { return nil }
16+
func (stubRows) Next(dest []driver.Value) error { return nil }
17+
18+
type stubScanType struct {
19+
toReturn reflect.Type
20+
}
21+
22+
func (s stubScanType) ColumnTypeScanType(index int) reflect.Type { return s.toReturn }
23+
24+
type stubDriver struct {
25+
rows driver.Rows
26+
}
27+
28+
func (d stubDriver) Open(name string) (driver.Conn, error) {
29+
return stubConnection{rows: d.rows}, nil
30+
}
31+
32+
type stubConnection struct {
33+
rows driver.Rows
34+
}
35+
36+
func (c stubConnection) Prepare(query string) (driver.Stmt, error) {
37+
return stubStmt{rows: c.rows}, nil
38+
}
39+
40+
func (stubConnection) Close() error { return nil }
41+
func (stubConnection) Begin() (driver.Tx, error) { return &sql.Tx{}, nil }
42+
43+
type stubStmt struct {
44+
rows driver.Rows
45+
}
46+
47+
func (stubStmt) Close() error { return nil }
48+
func (stubStmt) NumInput() int { return 0 }
49+
func (stubStmt) Exec(args []driver.Value) (driver.Result, error) { return stubResult{}, nil }
50+
51+
func (s stubStmt) Query(args []driver.Value) (driver.Rows, error) {
52+
return s.rows, nil
53+
}
54+
55+
type stubResult struct{}
56+
57+
func (stubResult) LastInsertId() (int64, error) { return 0, nil }
58+
func (stubResult) RowsAffected() (int64, error) { return 0, nil }
59+
60+
type testFunc func(t *testing.T, rows driver.Rows)
61+
62+
var testNotAssignableToScanTypeInterface testFunc = func(t *testing.T, rows driver.Rows) {
63+
if _, ok := rows.(driver.RowsColumnTypeScanType); ok {
64+
t.Error("expected output to not be assignable to type: RowsColumnTypeLength")
65+
}
66+
}
67+
68+
var testAssignableToScanTypeInterface testFunc = func(t *testing.T, rows driver.Rows) {
69+
if _, ok := rows.(driver.RowsColumnTypeScanType); !ok {
70+
t.Error("expected output to be assignable to type: RowsColumnTypeLength")
71+
}
72+
}
73+
74+
func TestRowsAreWrappedWithCorrectInterfaceType(t *testing.T) {
75+
type test struct {
76+
name string
77+
input driver.Rows
78+
testFunc testFunc
79+
}
80+
81+
tests := []test{
82+
{
83+
input: stubRows{},
84+
name: "test non scan type parent is not wrapped with scan type interface",
85+
testFunc: testNotAssignableToScanTypeInterface,
86+
},
87+
{
88+
input: struct {
89+
driver.Rows
90+
ocsql.RowsColumnTypeScanType
91+
}{ stubRows{}, stubScanType{}},
92+
name: "test wraps rows with scan type interface",
93+
testFunc: testAssignableToScanTypeInterface,
94+
},
95+
}
96+
97+
for _, tt := range tests {
98+
t.Run(tt.name, func(t *testing.T) {
99+
var d = ocsql.Wrap(stubDriver{
100+
rows: tt.input,
101+
}, ocsql.WithAllTraceOptions())
102+
var c, _ = d.Open("fake-connection")
103+
104+
s, err := c.Prepare("SELECT * FROM test;")
105+
if err != nil {
106+
t.Errorf("connection.Prepare returned unexpected err: %v", err)
107+
return
108+
}
109+
110+
rows, err := s.Query([]driver.Value{})
111+
if err != nil {
112+
t.Errorf("stmt.Query returned unexpected err: %v", err)
113+
return
114+
}
115+
116+
tt.testFunc(t, rows)
117+
})
118+
}
119+
}

0 commit comments

Comments
 (0)