Skip to content

Commit 6463ca2

Browse files
committed
Use struct for cancellation callbacks
1 parent 7be202d commit 6463ca2

File tree

2 files changed

+84
-9
lines changed

2 files changed

+84
-9
lines changed

internal/sync/channel.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@ type Channel[T any] interface {
1212
Close()
1313
}
1414

15+
type Receiver[T any] struct {
16+
Receive func(v T, ok bool)
17+
}
18+
1519
type ChannelInternal[T any] interface {
1620
Closed() bool
1721

1822
ReceiveNonBlocking() (v T, ok bool)
1923

2024
// AddReceiveCallback adds a callback that is called once when a value is sent to the channel. This is similar
2125
// to the blocking `Receive` method, but is not blocking a coroutine.
22-
AddReceiveCallback(cb func(v T, ok bool))
26+
AddReceiveCallback(rcb *Receiver[T])
27+
28+
RemoveReceiveCallback(rcb *Receiver[T])
2329
}
2430

2531
// Ensure channel implementation support internal interface
@@ -40,7 +46,7 @@ func NewBufferedChannel[T any](size int) Channel[T] {
4046

4147
type channel[T any] struct {
4248
c []T
43-
receivers []func(value T, ok bool)
49+
receivers []*Receiver[T]
4450
senders []func() T
4551
closed bool
4652
size int
@@ -62,7 +68,7 @@ func (c *channel[T]) Close() {
6268

6369
// Send zero value to pending receiver
6470
var v T
65-
r(v, false)
71+
r.Receive(v, false)
6672
}
6773
}
6874

@@ -119,10 +125,12 @@ func (c *channel[T]) Receive(ctx Context) (v T, ok bool) {
119125

120126
// Register handler to receive value once
121127
if !addedListener {
122-
cb := func(rv T, rok bool) {
123-
receivedValue = true
124-
v = rv
125-
ok = rok
128+
cb := &Receiver[T]{
129+
Receive: func(rv T, rok bool) {
130+
receivedValue = true
131+
v = rv
132+
ok = rok
133+
},
126134
}
127135

128136
c.receivers = append(c.receivers, cb)
@@ -176,7 +184,7 @@ func (c *channel[T]) trySend(v T) bool {
176184
c.receivers[0] = nil
177185
c.receivers = c.receivers[1:]
178186

179-
r(v, true)
187+
r.Receive(v, true)
180188

181189
return true
182190
}
@@ -223,10 +231,20 @@ func (c *channel[T]) hasCapacity() bool {
223231
return len(c.c) < c.size
224232
}
225233

226-
func (c *channel[T]) AddReceiveCallback(cb func(v T, ok bool)) {
234+
func (c *channel[T]) AddReceiveCallback(cb *Receiver[T]) {
227235
c.receivers = append(c.receivers, cb)
228236
}
229237

238+
func (c *channel[T]) RemoveReceiveCallback(cb *Receiver[T]) {
239+
for i, r := range c.receivers {
240+
if r == cb {
241+
c.receivers[i] = nil
242+
c.receivers = append(c.receivers[:i], c.receivers[i+1:]...)
243+
return
244+
}
245+
}
246+
}
247+
230248
func (c *channel[T]) Closed() bool {
231249
return c.closed
232250
}

internal/sync/channel_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,60 @@ func Test_Channel_Buffered(t *testing.T) {
413413
})
414414
}
415415
}
416+
417+
func Test_CancellationHandler_Add(t *testing.T) {
418+
ctx, cancel := WithCancel(Background())
419+
420+
f := 1
421+
422+
c := NewChannel[int]()
423+
ic := c.(*channel[int])
424+
425+
cr := NewCoroutine(ctx, func(ctx Context) error {
426+
r := &Receiver[int]{
427+
Receive: func(_ int, ok bool) {
428+
f++
429+
},
430+
}
431+
432+
ic.AddReceiveCallback(r)
433+
434+
c.Send(ctx, 42)
435+
return nil
436+
})
437+
438+
cr.Execute()
439+
440+
cancel()
441+
442+
require.Equal(t, 2, f)
443+
}
444+
445+
func Test_CancellationHandler_Remove(t *testing.T) {
446+
ctx, cancel := WithCancel(Background())
447+
448+
f := 1
449+
450+
c := NewChannel[int]()
451+
ic := c.(*channel[int])
452+
453+
cr := NewCoroutine(ctx, func(ctx Context) error {
454+
r := &Receiver[int]{
455+
Receive: func(_ int, ok bool) {
456+
f++
457+
},
458+
}
459+
460+
ic.AddReceiveCallback(r)
461+
ic.RemoveReceiveCallback(r)
462+
463+
c.Send(ctx, 42)
464+
return nil
465+
})
466+
467+
cr.Execute()
468+
469+
cancel()
470+
471+
require.Equal(t, 1, f)
472+
}

0 commit comments

Comments
 (0)