Skip to content

Commit 2db8578

Browse files
authored
Merge pull request #2 from hidetatz/hidetatz/bug/incr_decr_lostupdate
2 parents 7d1c297 + 8330298 commit 2db8578

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

cache.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ func (c *Cache[K, _]) Keys() []K {
122122
// NumberCache is a in-memory cache which is able to store only Number constraint.
123123
type NumberCache[K comparable, V Number] struct {
124124
*Cache[K, V]
125+
// nmu is used to do lock in Increment/Decrement process.
126+
// Note that this must be here as a separate mutex because mu in Cache struct is Locked in GetItem,
127+
// and if we call mu.Lock in Increment/Decrement, it will cause deadlock.
128+
nmu sync.Mutex
125129
}
126130

127131
// NewNumber creates a new cache for Number constraint.
@@ -135,6 +139,9 @@ func NewNumber[K comparable, V Number]() *NumberCache[K, V] {
135139
// Returns an error if the item was not found or expired. If there is no error, the
136140
// incremented value is returned.
137141
func (nc *NumberCache[K, V]) Increment(k K, n V) (val V, err error) {
142+
// In order to avoid lost update, we must lock whole Increment/Decrement process.
143+
nc.nmu.Lock()
144+
defer nc.nmu.Unlock()
138145
got, err := nc.Cache.GetItem(k)
139146
if err != nil {
140147
return val, err
@@ -151,6 +158,8 @@ func (nc *NumberCache[K, V]) Increment(k K, n V) (val V, err error) {
151158
// Returns an error if the item was not found or expired. If there is no error, the
152159
// decremented value is returned.
153160
func (nc *NumberCache[K, V]) Decrement(k K, n V) (val V, err error) {
161+
nc.nmu.Lock()
162+
defer nc.nmu.Unlock()
154163
got, err := nc.Cache.GetItem(k)
155164
if err != nil {
156165
return val, err

cache_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cache
22

33
import (
44
"errors"
5+
"sync"
56
"testing"
67
"time"
78
)
@@ -82,3 +83,51 @@ func TestGetItemExpired(t *testing.T) {
8283
}
8384

8485
}
86+
87+
func TestMultiThreadIncr(t *testing.T) {
88+
nc := NewNumber[string, int]()
89+
nc.Set("counter", 0)
90+
91+
var wg sync.WaitGroup
92+
93+
for i := 0; i < 100; i++ {
94+
wg.Add(1)
95+
go func() {
96+
_, err := nc.Increment("counter", 1)
97+
if err != nil {
98+
t.Logf("err: %v", err)
99+
}
100+
wg.Done()
101+
}()
102+
}
103+
104+
wg.Wait()
105+
106+
if counter, _ := nc.Get("counter"); counter != 100 {
107+
t.Errorf("want %v but got %v", 100, counter)
108+
}
109+
}
110+
111+
func TestMultiThreadDecr(t *testing.T) {
112+
nc := NewNumber[string, int]()
113+
nc.Set("counter", 100)
114+
115+
var wg sync.WaitGroup
116+
117+
for i := 0; i < 100; i++ {
118+
wg.Add(1)
119+
go func() {
120+
_, err := nc.Decrement("counter", 1)
121+
if err != nil {
122+
t.Logf("err: %v", err)
123+
}
124+
wg.Done()
125+
}()
126+
}
127+
128+
wg.Wait()
129+
130+
if counter, _ := nc.Get("counter"); counter != 0 {
131+
t.Errorf("want %v but got %v", 0, counter)
132+
}
133+
}

0 commit comments

Comments
 (0)