From 43dfdd13efcf8f55b3d72a7ee6e8e723bb4a0163 Mon Sep 17 00:00:00 2001 From: Tianzhi Jin Date: Fri, 15 Aug 2025 13:56:54 +0800 Subject: [PATCH] fix: be able to invoke Close in SSE callback --- sse.go | 10 ++++++---- sse_test.go | 30 +++++++++++++++++++++++------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/sse.go b/sse.go index 3b9b8e5a..86b380c3 100644 --- a/sse.go +++ b/sse.go @@ -585,14 +585,16 @@ func (es *EventSource) processEvent(scanner *bufio.Scanner) error { } func (es *EventSource) handleCallback(e *Event) { - es.lock.RLock() - defer es.lock.RUnlock() - eventName := e.Name if len(eventName) == 0 { eventName = defaultEventName } - if cb, found := es.onEvent[eventName]; found { + + es.lock.RLock() + cb, found := es.onEvent[eventName] + es.lock.RUnlock() + + if found { if cb.Result == nil { cb.Func(e) return diff --git a/sse_test.go b/sse_test.go index 323b0304..c1a1e512 100644 --- a/sse_test.go +++ b/sse_test.go @@ -20,22 +20,26 @@ import ( ) func TestEventSourceSimpleFlow(t *testing.T) { + es := createEventSource(t, "", nil, nil) + messageCounter := 0 messageFunc := func(e any) { event := e.(*Event) assertEqual(t, strconv.Itoa(messageCounter), event.ID) assertEqual(t, true, strings.HasPrefix(event.Data, "The time is")) messageCounter++ + if messageCounter == 100 { + es.Close() + } } + es.OnMessage(messageFunc, nil) counter := 0 - es := createEventSource(t, "", messageFunc, nil) ts := createSSETestServer( t, 10*time.Millisecond, func(w io.Writer) error { if counter == 100 { - es.Close() return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) @@ -129,22 +133,25 @@ func TestEventSourceOverwriteFuncs(t *testing.T) { messageFunc1 := func(e any) { assertNotNil(t, e) } + es := createEventSource(t, "", messageFunc1, nil) + message2Counter := 0 messageFunc2 := func(e any) { event := e.(*Event) assertEqual(t, strconv.Itoa(message2Counter), event.ID) assertEqual(t, true, strings.HasPrefix(event.Data, "The time is")) message2Counter++ + if message2Counter == 50 { + es.Close() + } } counter := 0 - es := createEventSource(t, "", messageFunc1, nil) ts := createSSETestServer( t, 10*time.Millisecond, func(w io.Writer) error { if counter == 50 { - es.Close() return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) @@ -177,16 +184,21 @@ func TestEventSourceOverwriteFuncs(t *testing.T) { } func TestEventSourceRetry(t *testing.T) { + es := createEventSource(t, "", nil, nil) + messageCounter := 2 // 0 & 1 connection failure messageFunc := func(e any) { event := e.(*Event) assertEqual(t, strconv.Itoa(messageCounter), event.ID) assertEqual(t, true, strings.HasPrefix(event.Data, "The time is")) messageCounter++ + if messageCounter == 15 { + es.Close() + } } + es.OnMessage(messageFunc, nil) counter := 0 - es := createEventSource(t, "", messageFunc, nil) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { if counter == 1 && r.URL.Query().Get("reconnect") == "1" { w.WriteHeader(http.StatusTooManyRequests) @@ -445,19 +457,24 @@ func TestEventSourceWithDifferentMethods(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + es := createEventSource(t, "", nil, nil) + messageCounter := 0 messageFunc := func(e any) { event := e.(*Event) assertEqual(t, strconv.Itoa(messageCounter), event.ID) assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method))) messageCounter++ + if messageCounter == 20 { + es.Close() + } } + es.OnMessage(messageFunc, nil) counter := 0 methodVerified := false bodyVerified := false - es := createEventSource(t, "", messageFunc, nil) ts := createMethodVerifyingSSETestServer( t, 10*time.Millisecond, @@ -467,7 +484,6 @@ func TestEventSourceWithDifferentMethods(t *testing.T) { &bodyVerified, func(w io.Writer) error { if counter == 20 { - es.Close() return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))