Skip to content

Commit 6440450

Browse files
Use reflect.Select to process dynamic list of channels correctly
1 parent 4860086 commit 6440450

File tree

2 files changed

+57
-32
lines changed

2 files changed

+57
-32
lines changed

internal/libvirt/libvirt.go

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"errors"
2323
"fmt"
2424
"os"
25+
"reflect"
2526
"sync"
2627
"time"
2728

@@ -143,30 +144,49 @@ func (l *LibVirt) Close() error {
143144
func (l *LibVirt) runEventLoop(ctx context.Context) {
144145
log := logger.FromContext(ctx, "libvirt", "event-loop")
145146
for {
147+
// The reflect.Select function works the same way as a
148+
// regular select statement, but allows selecting over
149+
// a dynamic set of channels.
150+
var cases []reflect.SelectCase
151+
var eventIds []libvirt.DomainEventID
146152
for eventId, ch := range l.domEventChs {
147-
select {
148-
case <-ctx.Done():
149-
return
150-
case <-l.virt.Disconnected():
151-
log.Error(errors.New("libvirt disconnected"), "waiting for reconnection")
152-
time.Sleep(5 * time.Second)
153-
case eventPayload, ok := <-ch:
154-
if !ok {
155-
err := errors.New("libvirt event channel closed")
156-
log.Error(err, "eventId", eventId)
157-
continue
158-
}
159-
handlers, exists := l.domEventChangeHandlers[eventId]
160-
if !exists {
161-
continue
162-
}
163-
for _, handler := range handlers {
164-
// Process each handler sequentially.
165-
handler(ctx, eventPayload)
166-
}
167-
default:
168-
// No event available, continue
169-
}
153+
cases = append(cases, reflect.SelectCase{
154+
Dir: reflect.SelectRecv,
155+
Chan: reflect.ValueOf(ch),
156+
})
157+
eventIds = append(eventIds, eventId)
158+
}
159+
160+
cases = append(cases, reflect.SelectCase{
161+
Dir: reflect.SelectRecv,
162+
Chan: reflect.ValueOf(ctx.Done()),
163+
})
164+
caseCtxDone := len(cases) - 1
165+
166+
chosen, value, ok := reflect.Select(cases)
167+
if !ok {
168+
// This should never happen. If it does, give the
169+
// service a chance to restart and reconnect.
170+
panic("libvirt connection closed")
171+
}
172+
if chosen == caseCtxDone {
173+
log.Info("shutting down libvirt event loop")
174+
return
175+
}
176+
if chosen >= len(eventIds) {
177+
msg := "no handler for selected channel"
178+
log.Error(errors.New("invalid event channel selected"), msg)
179+
continue
180+
}
181+
182+
// Distribute the event to all registered handlers.
183+
eventId := eventIds[chosen] // safe as chosen < len(eventIds)
184+
handlers, exists := l.domEventChangeHandlers[eventId]
185+
if !exists {
186+
continue
187+
}
188+
for _, handler := range handlers {
189+
handler(ctx, value.Interface())
170190
}
171191
}
172192
}

internal/libvirt/libvirt_test.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,13 @@ func TestWatchDomainChanges_OverwriteHandler(t *testing.T) {
10251025
}
10261026

10271027
func TestRunEventLoop_ProcessesEvents(t *testing.T) {
1028-
// Create a channel for events
1029-
eventCh := make(chan any, 1)
1028+
// Create a buffered channel for events that won't be closed during the test
1029+
eventChInternal := make(chan any, 10)
1030+
1031+
// Wrap it in a read-only channel to prevent accidental closure
1032+
var eventCh <-chan any = eventChInternal
10301033

1031-
mockConn := &mockLibvirtConnection{}
1034+
mockConn := newMockLibvirtConnection()
10321035

10331036
l := &LibVirt{
10341037
virt: &mockConn.Libvirt,
@@ -1048,16 +1051,12 @@ func TestRunEventLoop_ProcessesEvents(t *testing.T) {
10481051

10491052
l.WatchDomainChanges(libvirt.DomainEventIDLifecycle, "test-handler", handler)
10501053

1051-
// Create a context with timeout
1052-
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
1053-
defer cancel()
1054-
10551054
// Start the event loop in a goroutine
1056-
go l.runEventLoop(ctx)
1055+
go l.runEventLoop(t.Context())
10571056

10581057
// Send an event
10591058
testPayload := "test-event-payload"
1060-
eventCh <- testPayload
1059+
eventChInternal <- testPayload
10611060

10621061
// Give some time for the event to be processed
10631062
time.Sleep(50 * time.Millisecond)
@@ -1078,6 +1077,12 @@ type mockLibvirtConnection struct {
10781077
disconnectedCh chan struct{}
10791078
}
10801079

1080+
func newMockLibvirtConnection() *mockLibvirtConnection {
1081+
return &mockLibvirtConnection{
1082+
disconnectedCh: make(chan struct{}),
1083+
}
1084+
}
1085+
10811086
func (m *mockLibvirtConnection) Disconnected() <-chan struct{} {
10821087
return m.disconnectedCh
10831088
}

0 commit comments

Comments
 (0)