Skip to content

Commit 1c8f32a

Browse files
authored
wait to peek from buf until loop enqueue completed (#90)
1 parent 4aced6f commit 1c8f32a

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

channel/channel.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package channel
33
import (
44
"container/list"
55
"sync"
6+
"sync/atomic"
67
)
78

89
// Sizable is the interface implemented by types that support to compute memory size.
@@ -31,7 +32,8 @@ type MemoryBoundedChannel[T Sizable] struct {
3132
sendCh chan T // channel for sending data
3233
recvCh chan T // channel for receiving data
3334

34-
closed bool // indicates whether the channel is closed
35+
sendClosed atomic.Bool // indicates whether the send channel is closed
36+
recvClosed atomic.Bool // indicates whether the recv channel is closed
3537
}
3638

3739
// NewMemoryBoundedChannel creates a new MemoryBoundedChannel with the specified capacity and maximum memory size.
@@ -84,27 +86,27 @@ func (ch *MemoryBoundedChannel[T]) Close() {
8486
defer ch.mu.Unlock()
8587

8688
// already closed
87-
if ch.closed {
89+
if !ch.sendClosed.CompareAndSwap(false, true) {
8890
return
8991
}
9092

91-
ch.closed = true
92-
9393
// Close sendCh to stop loopEnqueue goroutine.
9494
// Note, recvCh cannot be closed here, otherwise loopDequeue
9595
// goroutine may panic to write data into recvCh.
9696
close(ch.sendCh)
9797

98-
// wake up all waiting goroutines
98+
// wake up loopEnqueue goroutine
9999
ch.notFullCond.Broadcast()
100-
ch.notEmptyCond.Broadcast()
101100
}
102101

103102
func (ch *MemoryBoundedChannel[T]) loopEnqueue(stopCh chan<- struct{}) {
104103
for item := range ch.sendCh {
105104
ch.enqueue(item)
106105
}
107106

107+
ch.recvClosed.Store(true)
108+
ch.notEmptyCond.Broadcast()
109+
108110
if stopCh != nil {
109111
stopCh <- struct{}{}
110112
}
@@ -127,7 +129,7 @@ func (ch *MemoryBoundedChannel[T]) enqueue(item T) {
127129
ch.notEmptyCond.Broadcast()
128130

129131
// blocking sending channel if buffer is full and this channel not closed yet
130-
for !ch.closed && (ch.buffer.Len() >= ch.capacity || ch.curBytes >= ch.maxBytes) {
132+
for !ch.sendClosed.Load() && (ch.buffer.Len() >= ch.capacity || ch.curBytes >= ch.maxBytes) {
131133
ch.notFullCond.Wait()
132134
}
133135
}
@@ -157,7 +159,7 @@ func (ch *MemoryBoundedChannel[T]) peek() (val Sized[T], ok bool) {
157159
defer ch.mu.Unlock()
158160

159161
// wait until there is data
160-
for !ch.closed && ch.buffer.Len() == 0 {
162+
for !ch.recvClosed.Load() && ch.buffer.Len() == 0 {
161163
ch.notEmptyCond.Wait()
162164
}
163165

0 commit comments

Comments
 (0)