11/*
22* SPDX-License-Identifier: AGPL-3.0-only
33* Copyright (c) 2022-2025, daeuniverse Organization <dae@v2raya.org>
4- */
4+ */
55
66package control
77
@@ -80,11 +80,11 @@ func (d *DoH) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, err
8080 if d .client == nil {
8181 d .client = d .getClient ()
8282 }
83- msg , err := sendHttpDNS (d .client , d .dialArgument .bestTarget .String (), & d .Upstream , data )
83+ msg , err := sendHttpDNS (ctx , d .client , d .dialArgument .bestTarget .String (), & d .Upstream , data )
8484 if err != nil {
8585 // If failed to send DNS request, we should try to create a new client.
8686 d .client = d .getClient ()
87- msg , err = sendHttpDNS (d .client , d .dialArgument .bestTarget .String (), & d .Upstream , data )
87+ msg , err = sendHttpDNS (ctx , d .client , d .dialArgument .bestTarget .String (), & d .Upstream , data )
8888 if err != nil {
8989 return nil , err
9090 }
@@ -196,7 +196,7 @@ func (d *DoQ) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, err
196196 // thanks https://github.com/natesales/q/blob/1cb2639caf69bd0a9b46494a3c689130df8fb24a/transport/quic.go#L97
197197 binary .BigEndian .PutUint16 (data [0 :2 ], 0 )
198198
199- msg , err := sendStreamDNS (stream , data )
199+ msg , err := sendStreamDNS (ctx , stream , data )
200200 if err != nil {
201201 return nil , err
202202 }
@@ -259,7 +259,7 @@ func (d *DoTLS) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, e
259259 }
260260 d .conn = tlsConn
261261
262- return sendStreamDNS (tlsConn , data )
262+ return sendStreamDNS (ctx , tlsConn , data )
263263}
264264
265265func (d * DoTLS ) Close () error {
@@ -287,7 +287,7 @@ func (d *DoTCP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, e
287287 }
288288
289289 d .conn = conn
290- return sendStreamDNS (conn , data )
290+ return sendStreamDNS (ctx , conn , data )
291291}
292292
293293func (d * DoTCP ) Close () error {
@@ -313,6 +313,7 @@ func (d *DoUDP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, e
313313 if err != nil {
314314 return nil , err
315315 }
316+ d .conn = conn
316317
317318 timeout := 5 * time .Second
318319 _ = conn .SetDeadline (time .Now ().Add (timeout ))
@@ -362,12 +363,14 @@ func (d *DoUDP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, e
362363
363364func (d * DoUDP ) Close () error {
364365 if d .conn != nil {
365- return d .conn .Close ()
366+ err := d .conn .Close ()
367+ d .conn = nil
368+ return err
366369 }
367370 return nil
368371}
369372
370- func sendHttpDNS (client * http.Client , target string , upstream * dns.Upstream , data []byte ) (respMsg * dnsmessage.Msg , err error ) {
373+ func sendHttpDNS (ctx context. Context , client * http.Client , target string , upstream * dns.Upstream , data []byte ) (respMsg * dnsmessage.Msg , err error ) {
371374 // disable redirect https://github.com/daeuniverse/dae/pull/649#issuecomment-2379577896
372375 client .CheckRedirect = func (req * http.Request , via []* http.Request ) error {
373376 return fmt .Errorf ("do not use a server that will redirect, upstream: %v" , upstream .String ())
@@ -384,7 +387,7 @@ func sendHttpDNS(client *http.Client, target string, upstream *dns.Upstream, dat
384387 q .Set ("dns" , base64 .RawURLEncoding .EncodeToString (data ))
385388 serverURL .RawQuery = q .Encode ()
386389
387- req , err := http .NewRequest ( http .MethodGet , serverURL .String (), nil )
390+ req , err := http .NewRequestWithContext ( ctx , http .MethodGet , serverURL .String (), nil )
388391 if err != nil {
389392 return nil , err
390393 }
@@ -406,7 +409,19 @@ func sendHttpDNS(client *http.Client, target string, upstream *dns.Upstream, dat
406409 return & msg , nil
407410}
408411
409- func sendStreamDNS (stream io.ReadWriter , data []byte ) (respMsg * dnsmessage.Msg , err error ) {
412+ func sendStreamDNS (ctx context.Context , stream io.ReadWriter , data []byte ) (respMsg * dnsmessage.Msg , err error ) {
413+ type streamDeadliner interface {
414+ SetDeadline (t time.Time ) error
415+ }
416+ if deadliner , ok := stream .(streamDeadliner ); ok {
417+ if deadline , ok := ctx .Deadline (); ok {
418+ _ = deadliner .SetDeadline (deadline )
419+ }
420+ }
421+ if err = ctx .Err (); err != nil {
422+ return nil , err
423+ }
424+
410425 // We should write two byte length in the front of stream DNS request.
411426 bReq := pool .Get (2 + len (data ))
412427 defer pool .Put (bReq )
@@ -416,11 +431,17 @@ func sendStreamDNS(stream io.ReadWriter, data []byte) (respMsg *dnsmessage.Msg,
416431 if err != nil {
417432 return nil , fmt .Errorf ("failed to write DNS req: %w" , err )
418433 }
434+ if err = ctx .Err (); err != nil {
435+ return nil , err
436+ }
419437
420438 // Read two byte length.
421439 if _ , err = io .ReadFull (stream , bReq [:2 ]); err != nil {
422440 return nil , fmt .Errorf ("failed to read DNS resp payload length: %w" , err )
423441 }
442+ if err = ctx .Err (); err != nil {
443+ return nil , err
444+ }
424445 respLen := int (binary .BigEndian .Uint16 (bReq ))
425446 // Try to reuse the buf.
426447 var buf []byte
0 commit comments