@@ -9,41 +9,11 @@ import (
99 "time"
1010)
1111
12- // safeStream wraps a BidirectionalStream with lifecycle protection.
13- // It ensures Cancel and Destroy don't race, preventing use-after-free.
14- type safeStream struct {
15- stream BidirectionalStream
16- mutex sync.Mutex
17- destroyed bool
18- }
19-
20- // Cancel cancels the stream if it hasn't been destroyed.
21- // Safe to call concurrently with Destroy.
22- func (s * safeStream ) Cancel () {
23- s .mutex .Lock ()
24- defer s .mutex .Unlock ()
25- if ! s .destroyed {
26- s .stream .Cancel ()
27- }
28- }
29-
30- // Destroy marks the stream as destroyed and destroys it.
31- // Safe to call concurrently with Cancel. Returns false if already destroyed.
32- func (s * safeStream ) Destroy () bool {
33- s .mutex .Lock ()
34- defer s .mutex .Unlock ()
35- if s .destroyed {
36- return false
37- }
38- s .destroyed = true
39- s .stream .Destroy ()
40- return true
41- }
42-
4312// BidirectionalConn is a wrapper from BidirectionalStream to net.Conn
4413type BidirectionalConn struct {
45- stream * safeStream
46- rawStream BidirectionalStream // for non-lifecycle operations
14+ stream BidirectionalStream
15+ cancelOnce sync.Once // Ensures Cancel is called at most once
16+ destroyOnce sync.Once // Ensures Destroy is called at most once
4717 readWaitHeaders bool
4818 writeWaitHeaders bool
4919 access sync.Mutex
@@ -94,8 +64,7 @@ func (e StreamEngine) CreateConn(readWaitHeaders bool, writeWaitHeaders bool) *B
9464 }
9565 conn .readSemaphore <- struct {}{}
9666 conn .writeSemaphore <- struct {}{}
97- conn .rawStream = e .CreateStream (& bidirectionalHandler {BidirectionalConn : conn })
98- conn .stream = & safeStream {stream : conn .rawStream }
67+ conn .stream = e .CreateStream (& bidirectionalHandler {BidirectionalConn : conn })
9968 return conn
10069}
10170
@@ -109,7 +78,7 @@ func (c *BidirectionalConn) Start(method string, url string, headers map[string]
10978 return net .ErrClosed
11079 default :
11180 }
112- if ! c .rawStream .Start (method , url , headers , priority , endOfStream ) {
81+ if ! c .stream .Start (method , url , headers , priority , endOfStream ) {
11382 return os .ErrInvalid
11483 }
11584 return nil
@@ -169,7 +138,7 @@ func (c *BidirectionalConn) Read(p []byte) (n int, err error) {
169138 c .readBuffer = make ([]byte , len (p ))
170139 }
171140 readBuffer := c .readBuffer [:len (p )]
172- c .rawStream .Read (readBuffer )
141+ c .stream .Read (readBuffer )
173142 c .access .Unlock ()
174143
175144 select {
@@ -248,7 +217,7 @@ func (c *BidirectionalConn) Write(p []byte) (n int, err error) {
248217 }
249218 writeBuffer := c .writeBuffer [:len (p )]
250219 copy (writeBuffer , p )
251- c .rawStream .Write (writeBuffer , false )
220+ c .stream .Write (writeBuffer , false )
252221 c .access .Unlock ()
253222
254223 select {
@@ -286,14 +255,20 @@ func (c *BidirectionalConn) Close() error {
286255 return net .ErrClosed
287256 case <- c .done :
288257 c .access .Unlock ()
289- return net . ErrClosed
258+ return nil // Stream already terminated normally
290259 default :
291260 }
292261
293262 close (c .close )
294263 c .access .Unlock ()
295264
296- c .stream .Cancel ()
265+ // Use cancelOnce to ensure Cancel is only called once and doesn't race
266+ // with terminal callbacks. If a terminal callback has already consumed
267+ // cancelOnce, this will be a no-op, preventing use-after-free crashes
268+ // when Cancel races with the async Destroy.
269+ c .cancelOnce .Do (func () {
270+ c .stream .Cancel ()
271+ })
297272 return nil
298273}
299274
@@ -438,18 +413,25 @@ func (c *bidirectionalHandler) OnResponseTrailersReceived(stream BidirectionalSt
438413func (c * bidirectionalHandler ) OnSucceeded (stream BidirectionalStream ) {
439414 c .signalReadDone ()
440415 c .signalWriteDone ()
416+ // Consume cancelOnce to prevent Close() from calling Cancel after this
417+ // terminal callback. This prevents Cancel from racing with Destroy.
418+ c .cancelOnce .Do (func () {})
441419 c .Close (io .EOF )
442420}
443421
444422func (c * bidirectionalHandler ) OnFailed (stream BidirectionalStream , netError int ) {
445423 c .signalReadDone ()
446424 c .signalWriteDone ()
425+ c .cancelOnce .Do (func () {})
447426 c .Close (NetError (netError ))
448427}
449428
450429func (c * bidirectionalHandler ) OnCanceled (stream BidirectionalStream ) {
451430 c .signalReadDone ()
452431 c .signalWriteDone ()
432+ // OnCanceled is triggered by Cancel(), so cancelOnce is already consumed.
433+ // But call it anyway for consistency and safety.
434+ c .cancelOnce .Do (func () {})
453435 c .Close (context .Canceled )
454436}
455437
@@ -460,6 +442,8 @@ func (c *bidirectionalHandler) Close(err error) {
460442 close (c .done )
461443 c .access .Unlock ()
462444
463- c .stream .Destroy ()
445+ c .destroyOnce .Do (func () {
446+ c .stream .Destroy ()
447+ })
464448 })
465449}
0 commit comments