@@ -31,10 +31,18 @@ var (
3131 attrDeprecated = trace .StringAttribute ("ocsql.warning" , "database driver uses deprecated features" )
3232
3333 // Compile time assertions
34- _ driver.Driver = & ocDriver {}
35- _ conn = & ocConn {}
36- _ driver.Result = & ocResult {}
37- _ driver.Rows = & ocRows {}
34+ _ driver.Driver = & ocDriver {}
35+ _ conn = & ocConn {}
36+ _ driver.Result = & ocResult {}
37+ _ driver.Stmt = & ocStmt {}
38+ _ driver.StmtExecContext = & ocStmt {}
39+ _ driver.StmtQueryContext = & ocStmt {}
40+ _ driver.Rows = & ocRows {}
41+ _ driver.RowsNextResultSet = & ocRows {}
42+ _ driver.RowsColumnTypeDatabaseTypeName = & ocRows {}
43+ _ driver.RowsColumnTypeLength = & ocRows {}
44+ _ driver.RowsColumnTypeNullable = & ocRows {}
45+ _ driver.RowsColumnTypePrecisionScale = & ocRows {}
3846)
3947
4048// Register initializes and registers our ocsql wrapped database driver
@@ -250,7 +258,7 @@ func (c ocConn) Query(query string, args []driver.Value) (rows driver.Rows, err
250258 return nil , err
251259 }
252260
253- return wrapRows (rows , ctx , c .options ), nil
261+ return wrapRows (ctx , rows , c .options ), nil
254262 }
255263
256264 return nil , driver .ErrSkip
@@ -288,7 +296,7 @@ func (c ocConn) QueryContext(ctx context.Context, query string, args []driver.Na
288296 return nil , err
289297 }
290298
291- return wrapRows (rows , ctx , c .options ), nil
299+ return wrapRows (ctx , rows , c .options ), nil
292300 }
293301
294302 return nil , driver .ErrSkip
@@ -530,7 +538,7 @@ func (s ocStmt) Query(args []driver.Value) (rows driver.Rows, err error) {
530538 if err != nil {
531539 return nil , err
532540 }
533- rows , err = wrapRows (rows , ctx , s .options ), nil
541+ rows , err = wrapRows (ctx , rows , s .options ), nil
534542 return
535543}
536544
@@ -604,20 +612,21 @@ func (s ocStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (row
604612 if err != nil {
605613 return nil , err
606614 }
607- rows , err = wrapRows (rows , ctx , s .options ), nil
615+ rows , err = wrapRows (ctx , rows , s .options ), nil
608616 return
609617}
610618
611-
612- // RowsColumnTypeScanType is a duplicate interface for driver.RowsColumnTypeScanType but
613- // with the driver.Rows composition removed.
614- //
615- // This is used to embed a anonymous struct without running into ambiguous method errors .
616- type RowsColumnTypeScanType interface {
619+ // withRowsColumnTypeScanType is the same as the driver.RowsColumnTypeScanType
620+ // interface except it omits the driver.Rows embedded interface.
621+ // If the original driver.Rows implementation wrapped by ocsql supports
622+ // RowsColumnTypeScanType we enable the original method implementation in the
623+ // returned driver.Rows from wrapRows by doing a composition with ocRows .
624+ type withRowsColumnTypeScanType interface {
617625 ColumnTypeScanType (index int ) reflect.Type
618626}
619627
620- // ocRows implements driver.Rows.
628+ // ocRows implements driver.Rows and all enhancement interfaces except
629+ // driver.RowsColumnTypeScanType.
621630type ocRows struct {
622631 parent driver.Rows
623632 ctx context.Context
@@ -732,9 +741,12 @@ func (r ocRows) Next(dest []driver.Value) (err error) {
732741}
733742
734743// wrapRows returns a struct which conforms to the driver.Rows interface.
735- // It checks if the parent adheres to any additional driver interfaces and returns a matching
736- // implementation accordingly.
737- func wrapRows (parent driver.Rows , ctx context.Context , options TraceOptions ) driver.Rows {
744+ // ocRows implements all enhancement interfaces that have no effect on
745+ // sql/database logic in case the underlying parent implementation lacks them.
746+ // Currently the one exception is RowsColumnTypeScanType which does not have a
747+ // valid zero value. This interface is tested for and only enabled in case the
748+ // parent implementation supports it.
749+ func wrapRows (ctx context.Context , parent driver.Rows , options TraceOptions ) driver.Rows {
738750 var (
739751 ts , hasColumnTypeScan = parent .(driver.RowsColumnTypeScanType )
740752 )
@@ -747,8 +759,8 @@ func wrapRows(parent driver.Rows, ctx context.Context, options TraceOptions) dri
747759
748760 if hasColumnTypeScan {
749761 return struct {
750- driver. Rows
751- RowsColumnTypeScanType
762+ ocRows
763+ withRowsColumnTypeScanType
752764 }{r , ts }
753765 }
754766
0 commit comments