@@ -192,8 +192,7 @@ func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error
192192// QueryContext returns the results of the query as a sequence of rows.
193193//
194194// The returned function automatically closes the unerlying sql.Rows value when
195- // it completes its iteration. The function can only be iterated once, it will
196- // not retain the values that it has seen.
195+ // it completes its iteration.
197196//
198197// A typical use of QueryContext is:
199198//
@@ -210,21 +209,20 @@ func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error
210209// See Scan for more information about how the rows are mapped to the row type
211210// parameter Row.
212211func QueryContext [Row any ](ctx context.Context , q Queryable , query string , args ... any ) iter.Seq2 [Row , error ] {
213- rows , err := q .QueryContext (ctx , query , args ... )
214- if err != nil {
215- return func (yield func (Row , error ) bool ) {
212+ return func (yield func (Row , error ) bool ) {
213+ if rows , err := q .QueryContext (ctx , query , args ... ); err != nil {
216214 var zero Row
217215 yield (zero , err )
216+ } else {
217+ scan [Row ](yield , rows )
218218 }
219219 }
220- return Scan [Row ](rows )
221220}
222221
223222// Scan returns a sequence of rows from a sql.Rows value.
224223//
225224// The returned function automatically closes the rows passed as argument when
226- // it completes its iteration. The function can only be iterated once, it will
227- // not retain the values that it has seen.
225+ // it completes its iteration.
228226//
229227// A typical use of Scan is:
230228//
@@ -254,40 +252,42 @@ func QueryContext[Row any](ctx context.Context, q Queryable, query string, args
254252// Ranging over the returned function will panic if the type parameter is not a
255253// struct.
256254func Scan [Row any ](rows * sql.Rows ) iter.Seq2 [Row , error ] {
257- return func (yield func (Row , error ) bool ) {
258- defer rows .Close ()
259- var zero Row
255+ return func (yield func (Row , error ) bool ) { scan (yield , rows ) }
256+ }
260257
261- columns , err := rows .Columns ()
262- if err != nil {
263- yield (zero , err )
264- return
265- }
258+ func scan [Row any ](yield func (Row , error ) bool , rows * sql.Rows ) {
259+ defer rows .Close ()
260+ var zero Row
266261
267- scanArgs := make ([]any , len (columns ))
268- row := new (Row )
269- val := reflect .ValueOf (row ).Elem ()
262+ columns , err := rows .Columns ()
263+ if err != nil {
264+ yield (zero , err )
265+ return
266+ }
270267
271- for columnName , structField := range Fields (val .Type ()) {
272- if columnIndex := slices .Index (columns , columnName ); columnIndex >= 0 {
273- scanArgs [columnIndex ] = val .FieldByIndex (structField .Index ).Addr ().Interface ()
274- }
275- }
268+ scanArgs := make ([]any , len (columns ))
269+ row := new (Row )
270+ val := reflect .ValueOf (row ).Elem ()
276271
277- for rows .Next () {
278- if err := rows .Scan (scanArgs ... ); err != nil {
279- yield (zero , err )
280- return
281- }
282- if ! yield (* row , nil ) {
283- return
284- }
285- * row = zero
272+ for columnName , structField := range Fields (val .Type ()) {
273+ if columnIndex := slices .Index (columns , columnName ); columnIndex >= 0 {
274+ scanArgs [columnIndex ] = val .FieldByIndex (structField .Index ).Addr ().Interface ()
286275 }
276+ }
287277
288- if err := rows .Err (); err != nil {
278+ for rows .Next () {
279+ if err := rows .Scan (scanArgs ... ); err != nil {
289280 yield (zero , err )
281+ return
282+ }
283+ if ! yield (* row , nil ) {
284+ return
290285 }
286+ * row = zero
287+ }
288+
289+ if err := rows .Err (); err != nil {
290+ yield (zero , err )
291291 }
292292}
293293
0 commit comments