Skip to content

Commit 8e0a8cc

Browse files
committed
Add tests for simple CRUD, complex transactions and LockForUpdate()
1 parent 2d01907 commit 8e0a8cc

File tree

4 files changed

+367
-30
lines changed

4 files changed

+367
-30
lines changed

tests/database_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package pgkit_test
2+
3+
import (
4+
"context"
5+
6+
"github.com/goware/pgkit/v2"
7+
"github.com/jackc/pgx/v5"
8+
)
9+
10+
type Database struct {
11+
*pgkit.DB
12+
13+
Accounts *accountsTable
14+
Articles *articlesTable
15+
Reviews *reviewsTable
16+
}
17+
18+
func initDB(db *pgkit.DB) *Database {
19+
return &Database{
20+
DB: db,
21+
Accounts: &accountsTable{Table: &pgkit.Table[Account, *Account, int64]{DB: db, Name: "accounts", IDColumn: "id"}},
22+
Articles: &articlesTable{Table: &pgkit.Table[Article, *Article, uint64]{DB: db, Name: "articles", IDColumn: "id"}},
23+
Reviews: &reviewsTable{Table: &pgkit.Table[Review, *Review, uint64]{DB: db, Name: "reviews", IDColumn: "id"}},
24+
}
25+
}
26+
27+
func (db *Database) BeginTx(ctx context.Context, fn func(tx *Database) error) error {
28+
return pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error {
29+
tx := db.WithTxQuery(pgTx)
30+
return fn(tx)
31+
})
32+
}
33+
34+
func (db *Database) WithTxQuery(tx pgx.Tx) *Database {
35+
pgkitDB := &pgkit.DB{
36+
Conn: db.Conn,
37+
SQL: db.SQL,
38+
Query: db.TxQuery(tx),
39+
}
40+
41+
return initDB(pgkitDB)
42+
}
43+
44+
type accountsTable struct {
45+
*pgkit.Table[Account, *Account, int64]
46+
}
47+
48+
type articlesTable struct {
49+
*pgkit.Table[Article, *Article, uint64]
50+
}
51+
52+
type reviewsTable struct {
53+
*pgkit.Table[Review, *Review, uint64]
54+
}

tests/schema_test.go

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pgkit_test
22

33
import (
4+
"fmt"
45
"time"
56

67
"github.com/goware/pgkit/v2/dbtype"
@@ -11,19 +12,83 @@ type Account struct {
1112
Name string `db:"name"`
1213
Disabled bool `db:"disabled"`
1314
CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT
15+
UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT
1416
}
1517

16-
func (a *Account) DBTableName() string {
17-
return "accounts"
18+
func (a *Account) DBTableName() string { return "accounts" }
19+
func (a *Account) GetID() int64 { return a.ID }
20+
func (a *Account) SetUpdatedAt(t time.Time) { a.UpdatedAt = t }
21+
22+
func (a *Account) Validate() error {
23+
if a.Name == "" {
24+
return fmt.Errorf("name is required")
25+
}
26+
27+
return nil
28+
}
29+
30+
type Article struct {
31+
ID uint64 `db:"id,omitempty"`
32+
Author string `db:"author"`
33+
Alias *string `db:"alias"`
34+
Content Content `db:"content"` // using JSONB postgres datatype
35+
AccountID int64 `db:"account_id"`
36+
CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT
37+
UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT
38+
DeletedAt *time.Time `db:"deleted_at"`
39+
}
40+
41+
func (a *Article) GetID() uint64 { return a.ID }
42+
func (a *Article) SetUpdatedAt(t time.Time) { a.UpdatedAt = t }
43+
func (a *Article) SetDeletedAt(t time.Time) { a.DeletedAt = &t }
44+
45+
func (a *Article) Validate() error {
46+
if a.Author == "" {
47+
return fmt.Errorf("author is required")
48+
}
49+
50+
return nil
51+
}
52+
53+
type Content struct {
54+
Title string `json:"title"`
55+
Body string `json:"body"`
56+
Views int64 `json:"views"`
1857
}
1958

2059
type Review struct {
21-
ID int64 `db:"id,omitempty"`
22-
Name string `db:"name"`
23-
Comments string `db:"comments"`
24-
CreatedAt time.Time `db:"created_at"` // if unset, will store Go zero-value
60+
ID uint64 `db:"id,omitempty"`
61+
Comment string `db:"comment"`
62+
Status ReviewStatus `db:"status"`
63+
Sentiment int64 `db:"sentiment"`
64+
AccountID int64 `db:"account_id"`
65+
ArticleID uint64 `db:"article_id"`
66+
CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT
67+
UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT
68+
DeletedAt *time.Time `db:"deleted_at"`
2569
}
2670

71+
func (r *Review) GetID() uint64 { return r.ID }
72+
func (r *Review) SetUpdatedAt(t time.Time) { r.UpdatedAt = t }
73+
func (r *Review) SetDeletedAt(t time.Time) { r.DeletedAt = &t }
74+
75+
func (r *Review) Validate() error {
76+
if len(r.Comment) < 3 {
77+
return fmt.Errorf("comment too short")
78+
}
79+
80+
return nil
81+
}
82+
83+
type ReviewStatus int64
84+
85+
const (
86+
ReviewStatusPending ReviewStatus = iota
87+
ReviewStatusProcessing
88+
ReviewStatusApproved
89+
ReviewStatusRejected
90+
)
91+
2792
type Log struct {
2893
ID int64 `db:"id,omitempty"`
2994
Message string `db:"message"`
@@ -38,16 +103,3 @@ type Stat struct {
38103
Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype
39104
Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype
40105
}
41-
42-
type Article struct {
43-
ID int64 `db:"id,omitempty"`
44-
Author string `db:"author"`
45-
Alias *string `db:"alias"`
46-
Content Content `db:"content"` // using JSONB postgres datatype
47-
}
48-
49-
type Content struct {
50-
Title string `json:"title"`
51-
Body string `json:"body"`
52-
Views int64 `json:"views"`
53-
}

tests/table_test.go

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
package pgkit_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"slices"
7+
"sync"
8+
"testing"
9+
"time"
10+
11+
sq "github.com/Masterminds/squirrel"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestTable(t *testing.T) {
16+
truncateAllTables(t)
17+
18+
ctx := t.Context()
19+
db := initDB(DB)
20+
21+
t.Run("Simple CRUD", func(t *testing.T) {
22+
account := &Account{
23+
Name: "Save Account",
24+
}
25+
26+
// Create.
27+
err := db.Accounts.Save(ctx, account)
28+
require.NoError(t, err, "Create failed")
29+
require.NotZero(t, account.ID, "ID should be set")
30+
require.NotZero(t, account.CreatedAt, "CreatedAt should be set")
31+
require.NotZero(t, account.UpdatedAt, "UpdatedAt should be set")
32+
33+
// Check count.
34+
count, err := db.Accounts.Count(ctx, nil)
35+
require.NoError(t, err, "FindAll failed")
36+
require.Equal(t, uint64(1), count, "Expected 1 account")
37+
38+
// Read from DB & check for equality.
39+
accountCheck, err := db.Accounts.GetByID(ctx, account.ID)
40+
require.NoError(t, err, "FindByID failed")
41+
require.Equal(t, account.ID, accountCheck.ID, "account ID should match")
42+
require.Equal(t, account.Name, accountCheck.Name, "account name should match")
43+
44+
// Update.
45+
account.Name = "Updated account"
46+
err = db.Accounts.Save(ctx, account)
47+
require.NoError(t, err, "Save failed")
48+
49+
// Read from DB & check for equality again.
50+
accountCheck, err = db.Accounts.GetByID(ctx, account.ID)
51+
require.NoError(t, err, "FindByID failed")
52+
require.Equal(t, account.ID, accountCheck.ID, "account ID should match")
53+
require.Equal(t, account.Name, accountCheck.Name, "account name should match")
54+
55+
// Check count again.
56+
count, err = db.Accounts.Count(ctx, nil)
57+
require.NoError(t, err, "FindAll failed")
58+
require.Equal(t, uint64(1), count, "Expected 1 account")
59+
})
60+
61+
t.Run("Complex Transaction", func(t *testing.T) {
62+
t.Parallel()
63+
ctx := t.Context()
64+
65+
err := db.BeginTx(ctx, func(tx *Database) error {
66+
// Create account.
67+
account := &Account{Name: "Complex Transaction Account"}
68+
err := tx.Accounts.Save(ctx, account)
69+
require.NoError(t, err, "Create account failed")
70+
71+
articles := []*Article{
72+
{Author: "First", AccountID: account.ID},
73+
{Author: "Second", AccountID: account.ID},
74+
{Author: "Third", AccountID: account.ID},
75+
}
76+
77+
// Save articles (3x insert).
78+
err = tx.Articles.SaveAll(ctx, articles)
79+
require.NoError(t, err, "SaveAll failed")
80+
81+
for _, article := range articles {
82+
require.NotZero(t, article.ID, "ID should be set")
83+
require.NotZero(t, article.CreatedAt, "CreatedAt should be set")
84+
require.NotZero(t, article.UpdatedAt, "UpdatedAt should be set")
85+
}
86+
87+
firstArticle := articles[0]
88+
89+
// Save articles (3x update, 1x insert).
90+
articles = append(articles, &Article{Author: "Fourth", AccountID: account.ID})
91+
err = tx.Articles.SaveAll(ctx, articles)
92+
require.NoError(t, err, "SaveAll failed")
93+
94+
for _, article := range articles {
95+
require.NotZero(t, article.ID, "ID should be set")
96+
require.NotZero(t, article.CreatedAt, "CreatedAt should be set")
97+
require.NotZero(t, article.UpdatedAt, "UpdatedAt should be set")
98+
}
99+
require.Equal(t, firstArticle.ID, articles[0].ID, "First article ID should be the same")
100+
101+
// Verify we can load all articles with .GetById()
102+
for _, article := range articles {
103+
articleCheck, err := tx.Articles.GetByID(ctx, article.ID)
104+
require.NoError(t, err, "GetByID failed")
105+
require.Equal(t, article.ID, articleCheck.ID, "Article ID should match")
106+
require.Equal(t, article.Author, articleCheck.Author, "Article Author should match")
107+
require.Equal(t, article.AccountID, articleCheck.AccountID, "Article AccountID should match")
108+
require.Equal(t, article.CreatedAt, articleCheck.CreatedAt, "Article CreatedAt should match")
109+
//require.Equal(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt should match")
110+
//require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .SaveAll() aboe updates the timestamp.
111+
require.Equal(t, article.DeletedAt, articleCheck.DeletedAt, "Article DeletedAt should match")
112+
}
113+
114+
// Verify we can load all articles with .GetByIDs()
115+
articleIDs := make([]uint64, len(articles))
116+
for _, article := range articles {
117+
articleIDs = append(articleIDs, article.ID)
118+
}
119+
articlesCheck, err := tx.Articles.GetByIDs(ctx, articleIDs)
120+
require.NoError(t, err, "GetByIDs failed")
121+
require.Equal(t, len(articles), len(articlesCheck), "Number of articles should match")
122+
for i, _ := range articlesCheck {
123+
require.Equal(t, articles[i].ID, articlesCheck[i].ID, "Article ID should match")
124+
require.Equal(t, articles[i].Author, articlesCheck[i].Author, "Article Author should match")
125+
require.Equal(t, articles[i].AccountID, articlesCheck[i].AccountID, "Article AccountID should match")
126+
require.Equal(t, articles[i].CreatedAt, articlesCheck[i].CreatedAt, "Article CreatedAt should match")
127+
//require.Equal(t, articles[i].UpdatedAt, articlesCheck[i].UpdatedAt, "Article UpdatedAt should match")
128+
require.Equal(t, articles[i].DeletedAt, articlesCheck[i].DeletedAt, "Article DeletedAt should match")
129+
}
130+
131+
// Soft-delete first article.
132+
err = tx.Articles.DeleteByID(ctx, firstArticle.ID)
133+
require.NoError(t, err, "DeleteByID failed")
134+
135+
// Check if article is soft-deleted.
136+
article, err := tx.Articles.GetByID(ctx, firstArticle.ID)
137+
require.NoError(t, err, "GetByID failed")
138+
require.Equal(t, firstArticle.ID, article.ID, "DeletedAt should be set")
139+
require.NotNil(t, article.DeletedAt, "DeletedAt should be set")
140+
141+
// Hard-delete first article.
142+
err = tx.Articles.HardDeleteByID(ctx, firstArticle.ID)
143+
require.NoError(t, err, "HardDeleteByID failed")
144+
145+
// Check if article is hard-deleted.
146+
article, err = tx.Articles.GetByID(ctx, firstArticle.ID)
147+
require.Error(t, err, "article was not hard-deleted")
148+
require.Nil(t, article, "article is not nil")
149+
150+
return nil
151+
})
152+
require.NoError(t, err, "SaveTx transaction failed")
153+
})
154+
}
155+
156+
func TestLockForUpdate(t *testing.T) {
157+
truncateAllTables(t)
158+
159+
ctx := t.Context()
160+
db := initDB(DB)
161+
162+
t.Run("TestLockForUpdate", func(t *testing.T) {
163+
// Create account.
164+
account := &Account{Name: "LockForUpdate Account"}
165+
err := db.Accounts.Save(ctx, account)
166+
require.NoError(t, err, "Create account failed")
167+
168+
// Create article.
169+
article := &Article{AccountID: account.ID, Author: "Author", Content: Content{Title: "Title", Body: "Body"}}
170+
err = db.Articles.Save(ctx, article)
171+
require.NoError(t, err, "Create article failed")
172+
173+
// Create 1000 reviews.
174+
reviews := make([]*Review, 100)
175+
for i := range 100 {
176+
reviews[i] = &Review{
177+
Comment: fmt.Sprintf("Test comment %d", i),
178+
AccountID: account.ID,
179+
ArticleID: article.ID,
180+
Status: ReviewStatusPending,
181+
}
182+
}
183+
err = db.Reviews.SaveAll(ctx, reviews)
184+
require.NoError(t, err, "create review")
185+
186+
cond := sq.Eq{
187+
"status": ReviewStatusPending,
188+
"deleted_at": nil,
189+
}
190+
orderBy := []string{"created_at ASC"}
191+
192+
var uniqueIDs [][]uint64 = make([][]uint64, 10)
193+
var wg sync.WaitGroup
194+
195+
for range 10 {
196+
wg.Go(func() {
197+
198+
err := db.Reviews.LockForUpdate(ctx, cond, orderBy, 10, func(reviews []*Review) {
199+
for _, review := range reviews {
200+
review.Status = ReviewStatusProcessing
201+
go processReviewAsynchronously(ctx, db, review)
202+
}
203+
})
204+
require.NoError(t, err, "lock for update")
205+
206+
})
207+
}
208+
wg.Wait()
209+
210+
ids := slices.Concat(uniqueIDs...)
211+
slices.Sort(ids)
212+
ids = slices.Compact(ids)
213+
214+
require.Equal(t, 100, len(ids), "number of processed unique reviews should be 100")
215+
})
216+
}
217+
218+
// TODO: defer() save status (success/failure) or put back to queue for processing.
219+
func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) {
220+
time.Sleep(1 * time.Second)
221+
review.Status = ReviewStatusApproved
222+
db.Reviews.Save(ctx, review)
223+
}

0 commit comments

Comments
 (0)