Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions pkg/protocol/http1/resp/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,6 @@ type clientRespStream struct {
closeCallback func(shouldClose bool) error
}

// ForceClose closes underlying conn. It enables `Read` call to return instead of blocking.
//
// This method is ONLY used by hertz internally.
// Normally, users call `Close` when the body is no longer used.
func (c *clientRespStream) ForceClose() (err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closeCallback != nil {
err = c.closeCallback(true)
c.closeCallback = nil
}
// NOTE: DO NOT put back to pool here,
// user may still use clientRespStream and call Close() like `defer body.Close()`
return
}

// Close closes response stream gracefully.
//
// NOTE:
Expand Down
117 changes: 83 additions & 34 deletions pkg/protocol/sse/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"strconv"
"strings"
"sync/atomic"
"time"

"github.com/cloudwego/hertz/internal/bytestr"
Expand All @@ -43,6 +44,10 @@ type Reader struct {
events int32

lastEventID string

// lower 32 bit: 0: not reading, n: reading count
// higher 32 bit: 0: not closing, 1: closed
state uint64
}

// NewReader creates a new SSE reader from the given response.
Expand Down Expand Up @@ -75,10 +80,6 @@ func (r *Reader) SetMaxBufferSize(max int) {
r.s.Buffer(nil, max)
}

type forceCloseIf interface {
ForceClose() error // implemented by *clientRespStream
}

// ForEach iterates over all SSE events in the response body,
// invoking the provided handler function for each event.
//
Expand All @@ -91,40 +92,34 @@ type forceCloseIf interface {
// - Context is cancelled (if ctx.Done() != nil)
// - All events are processed (returns nil)
func (r *Reader) ForEach(ctx context.Context, f func(e *Event) error) error {
if ctx.Done() != nil {
ch := make(chan struct{})
defer close(ch)
go func() {
select {
case <-ctx.Done():
// force close the underlying connection to release resource
// or r.Read may block until remote server ends
if s, ok := r.r.(forceCloseIf); ok {
s.ForceClose()
ch := make(chan error, 1)
go func() {
e := NewEvent()
defer e.Release()
for {
if err := r.Read(e); err != nil {
if err == io.EOF {
err = nil
}
case <-ch:
ch <- err
return
}
}()
}
e := NewEvent()
defer e.Release()
for {
if err := ctx.Err(); err != nil {
return err
}
if err := r.Read(e); err != nil {
if err == io.EOF {
return nil
select {
case <-ctx.Done():
return
default:
}
if er := ctx.Err(); er != nil {
err = er
if err := f(e); err != nil {
ch <- err
return
}
return err
}
if err := f(e); err != nil {
return err
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-ch:
return err
}
}

Expand All @@ -147,6 +142,14 @@ func (r *Reader) onEventRead(e *Event) {
// (e.g., bufio.ErrTooLong if an event line exceeds the buffer size).
// Use SetMaxBufferSize to handle larger events.
func (r *Reader) Read(e *Event) error {
if !r.incref() {
return errors.New("use of closed file")
}
defer func() {
if r.decref() {
_ = r.resp.CloseBodyStream()
}
}()
e.Reset()
for i := 0; r.s.Scan(); i++ {
line := r.s.Bytes()
Expand Down Expand Up @@ -224,7 +227,53 @@ func (r *Reader) Read(e *Event) error {
// Close closes the underlying response body.
//
// NOTE:
// * MUST NOT call Close() and Read() / ForEach() concurrently to avoid race issue.
// * It's OK to call Close() and Read() / ForEach() concurrently.
func (r *Reader) Close() error {
return r.resp.CloseBodyStream()
if !r.increfAndClose() {
return errors.New("closing a closed file")
}
if r.decref() {
return r.resp.CloseBodyStream()
}
return nil
}

// returns true if it's available for reading or writing.
func (r *Reader) incref() bool {
for {
old := atomic.LoadUint64(&r.state)
if (old >> 32) != 0 {
// closed
return false
}
if atomic.CompareAndSwapUint64(&r.state, old, old+1) {
return true
}
}
}

// returns false if the file was already closed.
func (r *Reader) increfAndClose() bool {
for {
old := atomic.LoadUint64(&r.state)
if (old >> 32) != 0 {
return false
}
// Mark as closed and acquire a reference.
new_ := (old | (1 << 32)) + 1
if atomic.CompareAndSwapUint64(&r.state, old, new_) {
return true
}
}
}

// returns true if there's no remaining reference and closed.
func (r *Reader) decref() bool {
for {
old := atomic.LoadUint64(&r.state)
new_ := old - 1
if atomic.CompareAndSwapUint64(&r.state, old, new_) {
return new_ == 1<<32
}
}
}
175 changes: 165 additions & 10 deletions pkg/protocol/sse/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"errors"
"io"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -281,18 +282,13 @@ func TestReader_ReadEvent_WithBodyStream(t *testing.T) {
}

type mockReadForceClose struct {
readFunc func(b []byte) (int, error)
closeFunc func() error
readFunc func(b []byte) (int, error)
}

func (m *mockReadForceClose) Read(b []byte) (int, error) {
return m.readFunc(b)
}

func (m *mockReadForceClose) ForceClose() error {
return m.closeFunc()
}

func TestReader_ReadEvent_Error(t *testing.T) {
// Create a reader that will return an error
errReader := &bytes.Reader{}
Expand Down Expand Up @@ -321,10 +317,6 @@ func TestReader_ForEach(t *testing.T) {
mr.readFunc = func(b []byte) (int, error) {
return 0, <-ch
}
mr.closeFunc = func() error {
ch <- errors.New("closed")
return nil
}

// create protocol.Response
resp := &protocol.Response{}
Expand Down Expand Up @@ -400,3 +392,166 @@ func TestReader_SetMaxBufferSize(t *testing.T) {
r.SetMaxBufferSize(80 * 1024)
})
}

// blockingReadStream simulates a stream that blocks on Read until signaled.
type blockingReadStream struct {
ch chan struct{}
raceAddr int32 // race detected addr
}

func (b *blockingReadStream) Read(p []byte) (n int, err error) {
b.raceAddr += 1
<-b.ch // Block until signaled
return 0, io.EOF
}

func (b *blockingReadStream) Close() error {
b.raceAddr += 1
return nil
}

// TestReader_ConcurrentReadAndClose tests that Read and Close can be called
// concurrently without race conditions.
func TestReader_ConcurrentReadAndClose(t *testing.T) {
bs := &blockingReadStream{ch: make(chan struct{})}

resp := &protocol.Response{}
resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
resp.SetBodyStream(bs, -1)

r, err := NewReader(resp)
assert.Assert(t, err == nil)

// Use a WaitGroup to wait for both goroutines to complete
var wg sync.WaitGroup
wg.Add(2)

// Goroutine 1: Try to read (will block)
go func() {
defer wg.Done()
e := NewEvent()
defer e.Release()
_ = r.Read(e) // Will block until bs.ch is closed
}()

// Goroutine 2: Close the reader while read is in progress
go func() {
defer wg.Done()
time.Sleep(200 * time.Millisecond) // Give read time to start
err := r.Close()
assert.Assert(t, err == nil)
}()

// Wait a bit, then unblock the read
time.Sleep(500 * time.Millisecond)
close(bs.ch)

// Wait for both goroutines to complete
wg.Wait()

// Close should be idempotent
err = r.Close()
assert.Assert(t, err != nil, "closing already closed reader should return error")
}

// TestReader_MultipleClose tests that calling Close multiple times is safe.
func TestReader_MultipleClose(t *testing.T) {
input := "id: 123\nevent: update\ndata: test data\n\n"

resp := &protocol.Response{}
resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
resp.SetBody([]byte(input))

r, err := NewReader(resp)
assert.Assert(t, err == nil)

// First close should succeed
err = r.Close()
assert.Assert(t, err == nil)

// Second close should fail with "closing a closed file"
err = r.Close()
assert.Assert(t, err != nil)
assert.DeepEqual(t, "closing a closed file", err.Error())
}

// TestReader_ReadAfterClose tests that reading after Close returns appropriate error.
func TestReader_ReadAfterClose(t *testing.T) {
input := "id: 123\nevent: update\ndata: test data\n\n"

resp := &protocol.Response{}
resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
resp.SetBody([]byte(input))

r, err := NewReader(resp)
assert.Assert(t, err == nil)

// Close the reader first
err = r.Close()
assert.Assert(t, err == nil)

// Try to read after close - should get "use of closed file" error
e := NewEvent()
defer e.Release()
err = r.Read(e)
assert.Assert(t, err != nil)
assert.DeepEqual(t, "use of closed file", err.Error())
}

// TestReader_ConcurrentForEachAndClose tests that ForEach and Close can be called
// concurrently without race conditions.
func TestReader_ConcurrentForEachAndClose(t *testing.T) {
bs := &blockingReadStream{ch: make(chan struct{})}

resp := &protocol.Response{}
resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
resp.SetBodyStream(bs, -1)

r, err := NewReader(resp)
assert.Assert(t, err == nil)

ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()

// Use a WaitGroup to wait for both goroutines to complete
var wg sync.WaitGroup
wg.Add(2)

// Goroutine 1: ForEach (will block)
forEachErr := make(chan error, 1)
go func() {
defer wg.Done()
err := r.ForEach(ctx, func(e *Event) error {
return nil
})
forEachErr <- err
}()

// Goroutine 2: Close the reader while ForEach is in progress
go func() {
defer wg.Done()
time.Sleep(400 * time.Millisecond) // Give forEach time to start
err := r.Close()
assert.Assert(t, err == nil)
}()

// Unblock the blocking read
time.Sleep(600 * time.Millisecond)
close(bs.ch)

// Wait for both goroutines to complete
wg.Wait()

// Check that ForEach exited (either with context error or EOF)
select {
case err := <-forEachErr:
// Either context canceled or EOF, both are acceptable
assert.Assert(t, err == nil || err == ctx.Err() || err == context.DeadlineExceeded)
default:
t.Error("ForEach should have completed")
}

// Close should be idempotent
err = r.Close()
assert.Assert(t, err != nil, "closing already closed reader should return error")
}