Skip to content

Commit 2ab2b56

Browse files
committed
Refactor database test functions for improved readability and reuse
1 parent a311739 commit 2ab2b56

File tree

2 files changed

+89
-133
lines changed

2 files changed

+89
-133
lines changed

itests/db_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"testing"
88

99
"github.com/jackc/pgx/v5/pgxpool"
10+
sq "github.com/n-r-w/squirrel"
1011
"github.com/n-r-w/testdock/v2"
1112
"github.com/stretchr/testify/require"
1213
)
@@ -16,7 +17,6 @@ func newTestPool(t *testing.T) (*pgxpool.Pool, context.Context) {
1617

1718
ctx := context.Background()
1819
pool, _ := testdock.GetPgxPool(t, testdock.DefaultPostgresDSN)
19-
t.Cleanup(pool.Close)
2020

2121
return pool, ctx
2222
}
@@ -27,3 +27,51 @@ func execSetup(t *testing.T, pool *pgxpool.Pool, ctx context.Context, setupSQL s
2727
_, err := pool.Exec(ctx, setupSQL)
2828
require.NoError(t, err)
2929
}
30+
31+
func queryInt64s(t *testing.T, pool *pgxpool.Pool, ctx context.Context, q sq.Sqlizer) []int64 {
32+
t.Helper()
33+
34+
sql, args, err := q.ToSql()
35+
require.NoError(t, err)
36+
37+
rows, err := pool.Query(ctx, sql, args...)
38+
require.NoError(t, err)
39+
t.Cleanup(rows.Close)
40+
41+
var ids []int64
42+
for rows.Next() {
43+
var id int64
44+
err := rows.Scan(&id)
45+
require.NoError(t, err)
46+
ids = append(ids, id)
47+
}
48+
require.NoError(t, rows.Err())
49+
50+
return ids
51+
}
52+
53+
func queryInt64StringPairs(t *testing.T, pool *pgxpool.Pool, ctx context.Context, q sq.Sqlizer) ([]int64, []string) {
54+
t.Helper()
55+
56+
sql, args, err := q.ToSql()
57+
require.NoError(t, err)
58+
59+
rows, err := pool.Query(ctx, sql, args...)
60+
require.NoError(t, err)
61+
t.Cleanup(rows.Close)
62+
63+
var ids []int64
64+
var names []string
65+
for rows.Next() {
66+
var id int64
67+
var name string
68+
err := rows.Scan(&id, &name)
69+
require.NoError(t, err)
70+
71+
ids = append(ids, id)
72+
names = append(names, name)
73+
}
74+
require.NoError(t, rows.Err())
75+
76+
return ids, names
77+
}

itests/select_test.go

Lines changed: 40 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -381,45 +381,41 @@ INSERT INTO user_groups_all (user_id, group_id) VALUES
381381
avgAmount float64
382382
}
383383

384-
results := make([]selectResult, 0)
385-
for rows.Next() {
386-
var row selectResult
387-
err := rows.Scan(
388-
&row.prefID,
389-
&row.prefName,
390-
&row.displayName,
391-
&row.emailLabel,
392-
&row.statusRank,
393-
&row.hasRefunds,
394-
&row.noChargebacks,
395-
&row.sumAmount,
396-
&row.ordersCount,
397-
&row.minAmount,
398-
&row.maxAmount,
399-
&row.avgAmount,
400-
)
401-
require.NoError(t, err)
402-
results = append(results, row)
403-
}
384+
require.True(t, rows.Next())
385+
386+
var got selectResult
387+
err = rows.Scan(
388+
&got.prefID,
389+
&got.prefName,
390+
&got.displayName,
391+
&got.emailLabel,
392+
&got.statusRank,
393+
&got.hasRefunds,
394+
&got.noChargebacks,
395+
&got.sumAmount,
396+
&got.ordersCount,
397+
&got.minAmount,
398+
&got.maxAmount,
399+
&got.avgAmount,
400+
)
401+
require.NoError(t, err)
402+
require.False(t, rows.Next())
404403
require.NoError(t, rows.Err())
405404

406-
require.Len(t, results, 1)
407-
408-
assert.Equal(t, int64(1), results[0].prefID)
409-
assert.Equal(t, "Alice", results[0].prefName)
410-
assert.Equal(t, "Alice <alice@example.com>", results[0].displayName)
411-
assert.Equal(t, "alice@work.com", results[0].emailLabel)
412-
assert.Equal(t, 1, results[0].statusRank)
413-
assert.True(t, results[0].hasRefunds)
414-
assert.True(t, results[0].noChargebacks)
415-
assert.InEpsilon(t, 120.0, results[0].sumAmount, 0.0001)
416-
assert.Equal(t, int64(2), results[0].ordersCount)
417-
assert.InEpsilon(t, 20.0, results[0].minAmount, 0.0001)
418-
assert.InEpsilon(t, 100.0, results[0].maxAmount, 0.0001)
419-
assert.InEpsilon(t, 60.0, results[0].avgAmount, 0.0001)
420-
421-
422-
cleanupQuery := sq.Select("id").
405+
assert.Equal(t, int64(1), got.prefID)
406+
assert.Equal(t, "Alice", got.prefName)
407+
assert.Equal(t, "Alice <alice@example.com>", got.displayName)
408+
assert.Equal(t, "alice@work.com", got.emailLabel)
409+
assert.Equal(t, 1, got.statusRank)
410+
assert.True(t, got.hasRefunds)
411+
assert.True(t, got.noChargebacks)
412+
assert.InEpsilon(t, 120.0, got.sumAmount, 0.0001)
413+
assert.Equal(t, int64(2), got.ordersCount)
414+
assert.InEpsilon(t, 20.0, got.minAmount, 0.0001)
415+
assert.InEpsilon(t, 100.0, got.maxAmount, 0.0001)
416+
assert.InEpsilon(t, 60.0, got.avgAmount, 0.0001)
417+
418+
builderResetQuery := sq.Select("id").
423419
Distinct().
424420
FromSelect(activeUsers, "au").
425421
RemoveColumns().
@@ -432,45 +428,16 @@ INSERT INTO user_groups_all (user_id, group_id) VALUES
432428
Prefix("/* cleanup */").
433429
PlaceholderFormat(sq.Dollar)
434430

435-
sql, args, err = cleanupQuery.ToSql()
436-
require.NoError(t, err)
437-
438-
rows, err = pool.Query(ctx, sql, args...)
439-
require.NoError(t, err)
440-
t.Cleanup(rows.Close)
441-
442-
cleaned := make([]int64, 0)
443-
for rows.Next() {
444-
var id int64
445-
var name string
446-
err := rows.Scan(&id, &name)
447-
require.NoError(t, err)
448-
cleaned = append(cleaned, id)
449-
}
450-
require.NoError(t, rows.Err())
451-
assert.Len(t, cleaned, 4)
431+
cleanedIDs, _ := queryInt64StringPairs(t, pool, ctx, builderResetQuery)
432+
assert.Len(t, cleanedIDs, 4)
452433

453434
paginateByID := sq.Select("id").
454435
From("users_all").
455436
OrderBy("id").
456437
PaginateByID(2, 1, "id").
457438
PlaceholderFormat(sq.Dollar)
458439

459-
sql, args, err = paginateByID.ToSql()
460-
require.NoError(t, err)
461-
462-
rows, err = pool.Query(ctx, sql, args...)
463-
require.NoError(t, err)
464-
t.Cleanup(rows.Close)
465-
466-
pageByID := make([]int64, 0)
467-
for rows.Next() {
468-
var id int64
469-
err := rows.Scan(&id)
470-
require.NoError(t, err)
471-
pageByID = append(pageByID, id)
472-
}
473-
require.NoError(t, rows.Err())
440+
pageByID := queryInt64s(t, pool, ctx, paginateByID)
474441
assert.Equal(t, []int64{2, 3}, pageByID)
475442

476443
paginateByPage := sq.Select("id").
@@ -479,21 +446,7 @@ INSERT INTO user_groups_all (user_id, group_id) VALUES
479446
PaginateByPage(2, 2).
480447
PlaceholderFormat(sq.Dollar)
481448

482-
sql, args, err = paginateByPage.ToSql()
483-
require.NoError(t, err)
484-
485-
rows, err = pool.Query(ctx, sql, args...)
486-
require.NoError(t, err)
487-
t.Cleanup(rows.Close)
488-
489-
pageByPage := make([]int64, 0)
490-
for rows.Next() {
491-
var id int64
492-
err := rows.Scan(&id)
493-
require.NoError(t, err)
494-
pageByPage = append(pageByPage, id)
495-
}
496-
require.NoError(t, rows.Err())
449+
pageByPage := queryInt64s(t, pool, ctx, paginateByPage)
497450
assert.Equal(t, []int64{3, 4}, pageByPage)
498451

499452
paginateByPaginator := sq.Select("id").
@@ -503,21 +456,7 @@ INSERT INTO user_groups_all (user_id, group_id) VALUES
503456
SetIDColumn("id").
504457
PlaceholderFormat(sq.Dollar)
505458

506-
sql, args, err = paginateByPaginator.ToSql()
507-
require.NoError(t, err)
508-
509-
rows, err = pool.Query(ctx, sql, args...)
510-
require.NoError(t, err)
511-
t.Cleanup(rows.Close)
512-
513-
pageByPaginator := make([]int64, 0)
514-
for rows.Next() {
515-
var id int64
516-
err := rows.Scan(&id)
517-
require.NoError(t, err)
518-
pageByPaginator = append(pageByPaginator, id)
519-
}
520-
require.NoError(t, rows.Err())
459+
pageByPaginator := queryInt64s(t, pool, ctx, paginateByPaginator)
521460
assert.Equal(t, []int64{2, 3}, pageByPaginator)
522461

523462
paginateByPagePaginator := sq.Select("id").
@@ -526,21 +465,7 @@ INSERT INTO user_groups_all (user_id, group_id) VALUES
526465
Paginate(sq.PaginatorByPage(2, 2)).
527466
PlaceholderFormat(sq.Dollar)
528467

529-
sql, args, err = paginateByPagePaginator.ToSql()
530-
require.NoError(t, err)
531-
532-
rows, err = pool.Query(ctx, sql, args...)
533-
require.NoError(t, err)
534-
t.Cleanup(rows.Close)
535-
536-
pageByPagePaginator := make([]int64, 0)
537-
for rows.Next() {
538-
var id int64
539-
err := rows.Scan(&id)
540-
require.NoError(t, err)
541-
pageByPagePaginator = append(pageByPagePaginator, id)
542-
}
543-
require.NoError(t, rows.Err())
468+
pageByPagePaginator := queryInt64s(t, pool, ctx, paginateByPagePaginator)
544469
assert.Equal(t, []int64{3, 4}, pageByPagePaginator)
545470

546471
recursiveQuery := sq.WithRecursive("category_tree").As(
@@ -552,24 +477,7 @@ INSERT INTO user_groups_all (user_id, group_id) VALUES
552477
From("category_tree"),
553478
).PlaceholderFormat(sq.Dollar)
554479

555-
sql, args, err = recursiveQuery.ToSql()
556-
require.NoError(t, err)
557-
558-
rows, err = pool.Query(ctx, sql, args...)
559-
require.NoError(t, err)
560-
t.Cleanup(rows.Close)
561-
562-
recursiveIDs := make([]int64, 0)
563-
recursiveNames := make([]string, 0)
564-
for rows.Next() {
565-
var id int64
566-
var name string
567-
err := rows.Scan(&id, &name)
568-
require.NoError(t, err)
569-
recursiveIDs = append(recursiveIDs, id)
570-
recursiveNames = append(recursiveNames, name)
571-
}
572-
require.NoError(t, rows.Err())
480+
recursiveIDs, recursiveNames := queryInt64StringPairs(t, pool, ctx, recursiveQuery)
573481
assert.Equal(t, []int64{1}, recursiveIDs)
574482
assert.Equal(t, []string{"Engineering"}, recursiveNames)
575483
}

0 commit comments

Comments
 (0)