Skip to content

Commit a039cc0

Browse files
committed
Add iterator method for accounts and update tests
1 parent 65583d2 commit a039cc0

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

table.go

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package pgkit
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"iter"
68
"slices"
79
"time"
810

@@ -156,20 +158,25 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error {
156158
return nil
157159
}
158160

159-
// Get returns the first record matching the condition.
160-
func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) {
161+
// getListQuery builds a base select query for listing records.
162+
func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder {
161163
if len(orderBy) == 0 {
162164
orderBy = []string{t.IDColumn}
163165
}
164166

165-
record := new(T)
166-
167167
q := t.SQL.
168168
Select("*").
169169
From(t.Name).
170170
Where(where).
171-
Limit(1).
172171
OrderBy(orderBy...)
172+
return q
173+
}
174+
175+
// Get returns the first record matching the condition.
176+
func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) {
177+
record := new(T)
178+
179+
q := t.getListQuery(where, orderBy).Limit(1)
173180

174181
if err := t.Query.GetOne(ctx, q, record); err != nil {
175182
return nil, fmt.Errorf("get record: %w", err)
@@ -180,16 +187,7 @@ func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy [
180187

181188
// List returns all records matching the condition.
182189
func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]PT, error) {
183-
if len(orderBy) == 0 {
184-
orderBy = []string{t.IDColumn}
185-
}
186-
187-
q := t.SQL.
188-
Select("*").
189-
From(t.Name).
190-
Where(where).
191-
OrderBy(orderBy...)
192-
190+
q := t.getListQuery(where, orderBy)
193191
var records []PT
194192
if err := t.Query.GetAll(ctx, q, &records); err != nil {
195193
return nil, err
@@ -198,6 +196,31 @@ func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy
198196
return records, nil
199197
}
200198

199+
// Iter returns an iterator for records matching the condition.
200+
func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[PT, error], error) {
201+
q := t.getListQuery(where, orderBy)
202+
rows, err := t.Query.QueryRows(ctx, q)
203+
if err != nil {
204+
return nil, fmt.Errorf("query rows: %w", err)
205+
}
206+
207+
return func(yield func(PT, error) bool) {
208+
defer rows.Close()
209+
for rows.Next() {
210+
var record T
211+
if err := t.Query.Scan.ScanOne(&record, rows); err != nil {
212+
if !errors.Is(err, pgx.ErrNoRows) {
213+
yield(nil, err)
214+
}
215+
return
216+
}
217+
if !yield(&record, nil) {
218+
return
219+
}
220+
}
221+
}, nil
222+
}
223+
201224
// GetByID returns a record by its ID.
202225
func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) {
203226
return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn})

tests/table_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ func TestTable(t *testing.T) {
5454
count, err = db.Accounts.Count(ctx, nil)
5555
require.NoError(t, err, "FindAll failed")
5656
require.Equal(t, uint64(1), count, "Expected 1 account")
57+
58+
// Iterate all accounts.
59+
iter, err := db.Accounts.Iter(ctx, nil, nil)
60+
require.NoError(t, err, "Iter failed")
61+
var accounts []Account
62+
for account, err := range iter {
63+
require.NoError(t, err, "Iter error")
64+
accounts = append(accounts, *account)
65+
}
5766
})
5867

5968
t.Run("Save multiple", func(t *testing.T) {

0 commit comments

Comments
 (0)