Skip to content

Commit 3213ba9

Browse files
support calling query sequences multiples times
Signed-off-by: Achille Roussel <[email protected]>
1 parent 6698ca7 commit 3213ba9

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

sqlrange.go

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,7 @@ func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error
192192
// QueryContext returns the results of the query as a sequence of rows.
193193
//
194194
// The returned function automatically closes the unerlying sql.Rows value when
195-
// it completes its iteration. The function can only be iterated once, it will
196-
// not retain the values that it has seen.
195+
// it completes its iteration.
197196
//
198197
// A typical use of QueryContext is:
199198
//
@@ -210,21 +209,20 @@ func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error
210209
// See Scan for more information about how the rows are mapped to the row type
211210
// parameter Row.
212211
func QueryContext[Row any](ctx context.Context, q Queryable, query string, args ...any) iter.Seq2[Row, error] {
213-
rows, err := q.QueryContext(ctx, query, args...)
214-
if err != nil {
215-
return func(yield func(Row, error) bool) {
212+
return func(yield func(Row, error) bool) {
213+
if rows, err := q.QueryContext(ctx, query, args...); err != nil {
216214
var zero Row
217215
yield(zero, err)
216+
} else {
217+
scan[Row](yield, rows)
218218
}
219219
}
220-
return Scan[Row](rows)
221220
}
222221

223222
// Scan returns a sequence of rows from a sql.Rows value.
224223
//
225224
// The returned function automatically closes the rows passed as argument when
226-
// it completes its iteration. The function can only be iterated once, it will
227-
// not retain the values that it has seen.
225+
// it completes its iteration.
228226
//
229227
// A typical use of Scan is:
230228
//
@@ -254,40 +252,42 @@ func QueryContext[Row any](ctx context.Context, q Queryable, query string, args
254252
// Ranging over the returned function will panic if the type parameter is not a
255253
// struct.
256254
func Scan[Row any](rows *sql.Rows) iter.Seq2[Row, error] {
257-
return func(yield func(Row, error) bool) {
258-
defer rows.Close()
259-
var zero Row
255+
return func(yield func(Row, error) bool) { scan(yield, rows) }
256+
}
260257

261-
columns, err := rows.Columns()
262-
if err != nil {
263-
yield(zero, err)
264-
return
265-
}
258+
func scan[Row any](yield func(Row, error) bool, rows *sql.Rows) {
259+
defer rows.Close()
260+
var zero Row
266261

267-
scanArgs := make([]any, len(columns))
268-
row := new(Row)
269-
val := reflect.ValueOf(row).Elem()
262+
columns, err := rows.Columns()
263+
if err != nil {
264+
yield(zero, err)
265+
return
266+
}
270267

271-
for columnName, structField := range Fields(val.Type()) {
272-
if columnIndex := slices.Index(columns, columnName); columnIndex >= 0 {
273-
scanArgs[columnIndex] = val.FieldByIndex(structField.Index).Addr().Interface()
274-
}
275-
}
268+
scanArgs := make([]any, len(columns))
269+
row := new(Row)
270+
val := reflect.ValueOf(row).Elem()
276271

277-
for rows.Next() {
278-
if err := rows.Scan(scanArgs...); err != nil {
279-
yield(zero, err)
280-
return
281-
}
282-
if !yield(*row, nil) {
283-
return
284-
}
285-
*row = zero
272+
for columnName, structField := range Fields(val.Type()) {
273+
if columnIndex := slices.Index(columns, columnName); columnIndex >= 0 {
274+
scanArgs[columnIndex] = val.FieldByIndex(structField.Index).Addr().Interface()
286275
}
276+
}
287277

288-
if err := rows.Err(); err != nil {
278+
for rows.Next() {
279+
if err := rows.Scan(scanArgs...); err != nil {
289280
yield(zero, err)
281+
return
282+
}
283+
if !yield(*row, nil) {
284+
return
290285
}
286+
*row = zero
287+
}
288+
289+
if err := rows.Err(); err != nil {
290+
yield(zero, err)
291291
}
292292
}
293293

0 commit comments

Comments
 (0)