Skip to content

Commit 20ed755

Browse files
committed
TUN-6679: Allow client side of quic request to close body
In a previous commit, we fixed a bug where the client roundtrip code could close the request body, which in fact would be the quic.Stream, thus closing the write-side. The way that was fixed, prevented the client roundtrip code from closing also read-side (the body). This fixes that, by allowing close to only close the read side, which will guarantee that any subsquent will fail with an error or EOF it occurred before the close.
1 parent 8e9e1d9 commit 20ed755

File tree

2 files changed

+86
-4
lines changed

2 files changed

+86
-4
lines changed

connection/quic.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"strconv"
1111
"strings"
12+
"sync/atomic"
1213
"time"
1314

1415
"github.com/google/uuid"
@@ -156,9 +157,10 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) {
156157
defer stream.Close()
157158

158159
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
159-
// code executed in the code path of handleStream don't trigger an earlier close to the downstream stream.
160-
// So, we wrap the stream with a no-op closer and only this method can actually close the stream.
161-
noCloseStream := &nopCloserReadWriter{stream}
160+
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
161+
// So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
162+
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
163+
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
162164
if err := q.handleStream(ctx, noCloseStream); err != nil {
163165
q.logger.Err(err).Msg("Failed to handle QUIC stream")
164166
}
@@ -408,10 +410,39 @@ func isTransferEncodingChunked(req *http.Request) bool {
408410
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
409411
}
410412

413+
// A helper struct that guarantees a call to close only affects read side, but not write side.
411414
type nopCloserReadWriter struct {
412415
io.ReadWriteCloser
416+
417+
// for use by Read only
418+
// we don't need a memory barrier here because there is an implicit assumption that
419+
// Read calls can't happen concurrently by different go-routines.
420+
sawEOF bool
421+
// should be updated and read using atomic primitives.
422+
// value is read in Read method and written in Close method, which could be done by different
423+
// go-routines.
424+
closed uint32
425+
}
426+
427+
func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
428+
if np.sawEOF {
429+
return 0, io.EOF
430+
}
431+
432+
if atomic.LoadUint32(&np.closed) > 0 {
433+
return 0, fmt.Errorf("closed by handler")
434+
}
435+
436+
n, err = np.ReadWriteCloser.Read(p)
437+
if err == io.EOF {
438+
np.sawEOF = true
439+
}
440+
441+
return
413442
}
414443

415-
func (n *nopCloserReadWriter) Close() error {
444+
func (np *nopCloserReadWriter) Close() error {
445+
atomic.StoreUint32(&np.closed, 1)
446+
416447
return nil
417448
}

connection/quic_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/http"
1111
"net/url"
1212
"os"
13+
"strings"
1314
"sync"
1415
"testing"
1516
"time"
@@ -527,6 +528,44 @@ func TestServeUDPSession(t *testing.T) {
527528
cancel()
528529
}
529530

531+
func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
532+
readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}}
533+
buffer := make([]byte, 5)
534+
535+
n, err := readerWriter.Read(buffer)
536+
require.NoError(t, err)
537+
require.Equal(t, n, 5)
538+
539+
// close
540+
require.NoError(t, readerWriter.Close())
541+
542+
// read should get error
543+
n, err = readerWriter.Read(buffer)
544+
require.Equal(t, n, 0)
545+
require.Equal(t, err, fmt.Errorf("closed by handler"))
546+
}
547+
548+
func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
549+
readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}}
550+
buffer := make([]byte, 20)
551+
552+
n, err := readerWriter.Read(buffer)
553+
require.NoError(t, err)
554+
require.Equal(t, n, 9)
555+
556+
// force another read to read eof
557+
n, err = readerWriter.Read(buffer)
558+
require.Equal(t, err, io.EOF)
559+
560+
// close
561+
require.NoError(t, readerWriter.Close())
562+
563+
// read should get EOF still
564+
n, err = readerWriter.Read(buffer)
565+
require.Equal(t, n, 0)
566+
require.Equal(t, err, io.EOF)
567+
}
568+
530569
func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) {
531570
var (
532571
payload = []byte(t.Name())
@@ -647,3 +686,15 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
647686
require.NoError(t, err)
648687
return qc
649688
}
689+
690+
type mockReaderNoopWriter struct {
691+
io.Reader
692+
}
693+
694+
func (m *mockReaderNoopWriter) Write(p []byte) (n int, err error) {
695+
return len(p), nil
696+
}
697+
698+
func (m *mockReaderNoopWriter) Close() error {
699+
return nil
700+
}

0 commit comments

Comments
 (0)