@@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
233233 sessionPolicy = s .policyManager .ForLevel (user .Level )
234234
235235 if destination .Network == net .Network_UDP { // handle udp request
236- return s .handleUDPPayload (ctx , & PacketReader {Reader : clientReader }, & PacketWriter {Writer : conn }, dispatcher )
236+ return s .handleUDPPayload (ctx , sessionPolicy , & PacketReader {Reader : clientReader }, & PacketWriter {Writer : conn }, dispatcher )
237237 }
238238
239239 ctx = log .ContextWithAccessMessage (ctx , & log.AccessMessage {
@@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
248248 return s .handleConnection (ctx , sessionPolicy , destination , clientReader , buf .NewWriter (conn ), dispatcher )
249249}
250250
251- func (s * Server ) handleUDPPayload (ctx context.Context , clientReader * PacketReader , clientWriter * PacketWriter , dispatcher routing.Dispatcher ) error {
251+ func (s * Server ) handleUDPPayload (ctx context.Context , sessionPolicy policy.Session , clientReader * PacketReader , clientWriter * PacketWriter , dispatcher routing.Dispatcher ) error {
252+ ctx , cancel := context .WithCancel (ctx )
253+ defer cancel ()
254+ timer := signal .CancelAfterInactivity (ctx , cancel , sessionPolicy .Timeouts .ConnectionIdle )
255+ defer timer .SetTimeout (0 )
252256 udpServer := udp .NewDispatcher (dispatcher , func (ctx context.Context , packet * udp_proto.Packet ) {
253257 udpPayload := packet .Payload
254258 if udpPayload .UDP == nil {
@@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
257261
258262 if err := clientWriter .WriteMultiBuffer (buf.MultiBuffer {udpPayload }); err != nil {
259263 errors .LogWarningInner (ctx , err , "failed to write response" )
264+ cancel ()
265+ } else {
266+ timer .Update ()
260267 }
261268 })
262269 defer udpServer .RemoveRay ()
@@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
266273
267274 var dest * net.Destination
268275
269- for {
270- select {
271- case <- ctx .Done ():
272- return nil
273- default :
274- mb , err := clientReader .ReadMultiBuffer ()
275- if err != nil {
276- if errors .Cause (err ) != io .EOF {
277- return errors .New ("unexpected EOF" ).Base (err )
278- }
276+ requestDone := func () error {
277+ for {
278+ select {
279+ case <- ctx .Done ():
279280 return nil
280- }
281+ default :
282+ mb , err := clientReader .ReadMultiBuffer ()
283+ if err != nil {
284+ if errors .Cause (err ) != io .EOF {
285+ return errors .New ("unexpected EOF" ).Base (err )
286+ }
287+ return nil
288+ }
281289
282- mb2 , b := buf .SplitFirst (mb )
283- if b == nil {
284- continue
285- }
286- destination := * b .UDP
287-
288- currentPacketCtx := ctx
289- if inbound .Source .IsValid () {
290- currentPacketCtx = log .ContextWithAccessMessage (ctx , & log.AccessMessage {
291- From : inbound .Source ,
292- To : destination ,
293- Status : log .AccessAccepted ,
294- Reason : "" ,
295- Email : user .Email ,
296- })
297- }
298- errors .LogInfo (ctx , "tunnelling request to " , destination )
290+ mb2 , b := buf .SplitFirst (mb )
291+ if b == nil {
292+ continue
293+ }
294+ timer .Update ()
295+ destination := * b .UDP
296+
297+ currentPacketCtx := ctx
298+ if inbound .Source .IsValid () {
299+ currentPacketCtx = log .ContextWithAccessMessage (ctx , & log.AccessMessage {
300+ From : inbound .Source ,
301+ To : destination ,
302+ Status : log .AccessAccepted ,
303+ Reason : "" ,
304+ Email : user .Email ,
305+ })
306+ }
307+ errors .LogInfo (ctx , "tunnelling request to " , destination )
299308
300- if ! s .cone || dest == nil {
301- dest = & destination
302- }
309+ if ! s .cone || dest == nil {
310+ dest = & destination
311+ }
303312
304- udpServer .Dispatch (currentPacketCtx , * dest , b ) // first packet
305- for _ , payload := range mb2 {
306- udpServer .Dispatch (currentPacketCtx , * dest , payload )
313+ udpServer .Dispatch (currentPacketCtx , * dest , b ) // first packet
314+ for _ , payload := range mb2 {
315+ udpServer .Dispatch (currentPacketCtx , * dest , payload )
316+ }
307317 }
308318 }
319+
320+ }
321+
322+ if err := task .Run (ctx , requestDone ); err != nil {
323+ return err
309324 }
325+ return nil
310326}
311327
312328func (s * Server ) handleConnection (ctx context.Context , sessionPolicy policy.Session ,
0 commit comments