Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,14 +585,16 @@ func (es *EventSource) processEvent(scanner *bufio.Scanner) error {
}

func (es *EventSource) handleCallback(e *Event) {
es.lock.RLock()
defer es.lock.RUnlock()

eventName := e.Name
if len(eventName) == 0 {
eventName = defaultEventName
}
if cb, found := es.onEvent[eventName]; found {

es.lock.RLock()
cb, found := es.onEvent[eventName]
es.lock.RUnlock()

if found {
if cb.Result == nil {
cb.Func(e)
return
Expand Down
30 changes: 23 additions & 7 deletions sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@ import (
)

func TestEventSourceSimpleFlow(t *testing.T) {
es := createEventSource(t, "", nil, nil)

messageCounter := 0
messageFunc := func(e any) {
event := e.(*Event)
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
assertEqual(t, true, strings.HasPrefix(event.Data, "The time is"))
messageCounter++
if messageCounter == 100 {
es.Close()
}
}
es.OnMessage(messageFunc, nil)

counter := 0
es := createEventSource(t, "", messageFunc, nil)
ts := createSSETestServer(
t,
10*time.Millisecond,
func(w io.Writer) error {
if counter == 100 {
es.Close()
return fmt.Errorf("stop sending events")
}
_, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate))
Expand Down Expand Up @@ -129,22 +133,25 @@ func TestEventSourceOverwriteFuncs(t *testing.T) {
messageFunc1 := func(e any) {
assertNotNil(t, e)
}
es := createEventSource(t, "", messageFunc1, nil)

message2Counter := 0
messageFunc2 := func(e any) {
event := e.(*Event)
assertEqual(t, strconv.Itoa(message2Counter), event.ID)
assertEqual(t, true, strings.HasPrefix(event.Data, "The time is"))
message2Counter++
if message2Counter == 50 {
es.Close()
}
}

counter := 0
es := createEventSource(t, "", messageFunc1, nil)
ts := createSSETestServer(
t,
10*time.Millisecond,
func(w io.Writer) error {
if counter == 50 {
es.Close()
return fmt.Errorf("stop sending events")
}
_, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate))
Expand Down Expand Up @@ -177,16 +184,21 @@ func TestEventSourceOverwriteFuncs(t *testing.T) {
}

func TestEventSourceRetry(t *testing.T) {
es := createEventSource(t, "", nil, nil)

messageCounter := 2 // 0 & 1 connection failure
messageFunc := func(e any) {
event := e.(*Event)
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
assertEqual(t, true, strings.HasPrefix(event.Data, "The time is"))
messageCounter++
if messageCounter == 15 {
es.Close()
}
}
es.OnMessage(messageFunc, nil)

counter := 0
es := createEventSource(t, "", messageFunc, nil)
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
if counter == 1 && r.URL.Query().Get("reconnect") == "1" {
w.WriteHeader(http.StatusTooManyRequests)
Expand Down Expand Up @@ -445,19 +457,24 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
es := createEventSource(t, "", nil, nil)

messageCounter := 0
messageFunc := func(e any) {
event := e.(*Event)
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
messageCounter++
if messageCounter == 20 {
es.Close()
}
}
es.OnMessage(messageFunc, nil)

counter := 0
methodVerified := false
bodyVerified := false

es := createEventSource(t, "", messageFunc, nil)
ts := createMethodVerifyingSSETestServer(
t,
10*time.Millisecond,
Expand All @@ -467,7 +484,6 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
&bodyVerified,
func(w io.Writer) error {
if counter == 20 {
es.Close()
return fmt.Errorf("stop sending events")
}
_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))
Expand Down