@@ -3,6 +3,7 @@ package channel
33import (
44 "container/list"
55 "sync"
6+ "sync/atomic"
67)
78
89// Sizable is the interface implemented by types that support to compute memory size.
@@ -31,7 +32,8 @@ type MemoryBoundedChannel[T Sizable] struct {
3132 sendCh chan T // channel for sending data
3233 recvCh chan T // channel for receiving data
3334
34- closed bool // indicates whether the channel is closed
35+ sendClosed atomic.Bool // indicates whether the send channel is closed
36+ recvClosed atomic.Bool // indicates whether the recv channel is closed
3537}
3638
3739// NewMemoryBoundedChannel creates a new MemoryBoundedChannel with the specified capacity and maximum memory size.
@@ -84,27 +86,27 @@ func (ch *MemoryBoundedChannel[T]) Close() {
8486 defer ch .mu .Unlock ()
8587
8688 // already closed
87- if ch .closed {
89+ if ! ch .sendClosed . CompareAndSwap ( false , true ) {
8890 return
8991 }
9092
91- ch .closed = true
92-
9393 // Close sendCh to stop loopEnqueue goroutine.
9494 // Note, recvCh cannot be closed here, otherwise loopDequeue
9595 // goroutine may panic to write data into recvCh.
9696 close (ch .sendCh )
9797
98- // wake up all waiting goroutines
98+ // wake up loopEnqueue goroutine
9999 ch .notFullCond .Broadcast ()
100- ch .notEmptyCond .Broadcast ()
101100}
102101
103102func (ch * MemoryBoundedChannel [T ]) loopEnqueue (stopCh chan <- struct {}) {
104103 for item := range ch .sendCh {
105104 ch .enqueue (item )
106105 }
107106
107+ ch .recvClosed .Store (true )
108+ ch .notEmptyCond .Broadcast ()
109+
108110 if stopCh != nil {
109111 stopCh <- struct {}{}
110112 }
@@ -127,7 +129,7 @@ func (ch *MemoryBoundedChannel[T]) enqueue(item T) {
127129 ch .notEmptyCond .Broadcast ()
128130
129131 // blocking sending channel if buffer is full and this channel not closed yet
130- for ! ch .closed && (ch .buffer .Len () >= ch .capacity || ch .curBytes >= ch .maxBytes ) {
132+ for ! ch .sendClosed . Load () && (ch .buffer .Len () >= ch .capacity || ch .curBytes >= ch .maxBytes ) {
131133 ch .notFullCond .Wait ()
132134 }
133135}
@@ -157,7 +159,7 @@ func (ch *MemoryBoundedChannel[T]) peek() (val Sized[T], ok bool) {
157159 defer ch .mu .Unlock ()
158160
159161 // wait until there is data
160- for ! ch .closed && ch .buffer .Len () == 0 {
162+ for ! ch .recvClosed . Load () && ch .buffer .Len () == 0 {
161163 ch .notEmptyCond .Wait ()
162164 }
163165
0 commit comments