@@ -137,6 +137,70 @@ func (s *QUICSocket) WaitForIncomingConn() (packets.UDPConn, error) {
137137 }
138138}
139139
140+ func (s * QUICSocket ) WaitForIncomingConnWithContext (ctx context.Context ) (packets.UDPConn , error ) {
141+ if s .options == nil || ! s .options .MultiportMode {
142+ log .Debugf ("Waiting for new connection" )
143+ stream , err := s .listenConns [0 ].AcceptStreamWithContext (ctx )
144+ if err != nil {
145+ log .Fatalf ("QUIC Accept err %s" , err .Error ())
146+ }
147+
148+ log .Debugf ("Accepted new Stream on listen socket" )
149+
150+ bts := make ([]byte , packets .PACKET_SIZE )
151+ _ , err = stream .Read (bts )
152+
153+ if s .listenConns [0 ].GetInternalConn () == nil {
154+ s .listenConns [0 ].SetStream (stream )
155+ select {
156+ case s .listenConns [0 ].Ready <- true :
157+ default :
158+ }
159+
160+ return s .listenConns [0 ], nil
161+ } else {
162+ newConn := & packets.QUICReliableConn {}
163+ id := RandStringBytes (32 )
164+ newConn .SetId (id )
165+ newConn .SetLocal (* s .localAddr )
166+ newConn .SetRemote (s .listenConns [0 ].GetRemote ())
167+ newConn .SetStream (stream )
168+ s .listenConns = append (s .listenConns , newConn )
169+
170+ _ , err = stream .Read (bts )
171+ if err != nil {
172+ return nil , err
173+ }
174+ return newConn , nil
175+ }
176+ } else {
177+ addr := s .localAddr .Copy ()
178+ addr .Host .Port = s .localAddr .Host .Port + len (s .listenConns )
179+ conn := & packets.QUICReliableConn {}
180+ err := conn .Listen (* addr )
181+ if err != nil {
182+ return nil , err
183+ }
184+
185+ stream , err := conn .AcceptStreamWithContext (ctx )
186+ if err != nil {
187+ return nil , err
188+ }
189+
190+ id := RandStringBytes (32 )
191+ conn .SetId (id )
192+
193+ conn .SetStream (stream )
194+ s .listenConns = append (s .listenConns , conn )
195+ bts := make ([]byte , packets .PACKET_SIZE )
196+ _ , err = stream .Read (bts )
197+ if err != nil {
198+ return nil , err
199+ }
200+ return conn , nil
201+ }
202+ }
203+
140204func (s * QUICSocket ) WaitForDialIn () (* snet.UDPAddr , error ) {
141205 bts := make ([]byte , packets .PACKET_SIZE )
142206 log .Debugf ("Wait for Dial In" )
@@ -210,7 +274,7 @@ func (s *QUICSocket) WaitForDialInWithContext(ctx context.Context) (*snet.UDPAdd
210274 log .Debugf ("Waiting for %d more connections" , p .NumPaths - 1 )
211275
212276 for i := 1 ; i < p .NumPaths ; i ++ {
213- _ , err := s .WaitForIncomingConn ( )
277+ _ , err := s .WaitForIncomingConnWithContext ( ctx )
214278 if err != nil {
215279 return nil , err
216280 }
0 commit comments