@@ -168,22 +168,23 @@ func (c *impl) exchangeOne(ctx context.Context, hostname string, qtype uint16) (
168168 // present.
169169 req .SetEdns0 (4096 , false )
170170
171- var resp * dns.Msg
172-
173171 servers , err := c .servers .Addrs ()
174172 if err != nil {
175173 return nil , "" , fmt .Errorf ("failed to list DNS servers: %w" , err )
176174 }
177175
178176 // Prepare to increment a latency metric no matter whether we succeed or fail.
179- // The deferred function closes over result , chosenServerIP, and tries, which
177+ // The deferred function closes over resp , chosenServerIP, and tries, which
180178 // are all modified in the loop below.
181179 start := c .clk .Now ()
182180 qtypeStr := dns .TypeToString [qtype ]
183- result := "failed"
184- chosenServerIP := ""
185- tries := 0
181+ var (
182+ resp * dns.Msg
183+ chosenServerIP string
184+ tries int
185+ )
186186 defer func () {
187+ result := "failed"
187188 if resp != nil {
188189 result = dns .RcodeToString [resp .Rcode ]
189190 }
@@ -195,12 +196,6 @@ func (c *impl) exchangeOne(ctx context.Context, hostname string, qtype uint16) (
195196 }).Observe (c .clk .Since (start ).Seconds ())
196197 }()
197198
198- type dnsRes struct {
199- resp * dns.Msg
200- err error
201- }
202- ch := make (chan dnsRes , 1 )
203-
204199 for i := range c .maxTries {
205200 tries = i + 1
206201 chosenServer := servers [i % len (servers )]
@@ -212,64 +207,61 @@ func (c *impl) exchangeOne(ctx context.Context, hostname string, qtype uint16) (
212207 // and ensures that chosenServer can't be a bare port, e.g. ":1337"
213208 chosenServerIP , _ , err = net .SplitHostPort (chosenServer )
214209 if err != nil {
215- return nil , "" , err
210+ return nil , chosenServer , err
216211 }
217212
218- go func () {
219- resp , rtt , err := c .exchanger .Exchange (req , chosenServer )
220- result := "failed"
221- if resp != nil {
222- result = dns .RcodeToString [resp .Rcode ]
223- }
224- if err != nil {
225- c .log .Infof ("logDNSError chosenServer=[%s] hostname=[%s] queryType=[%s] err=[%s]" , chosenServer , hostname , qtypeStr , err )
226- }
227- c .queryTime .With (prometheus.Labels {
228- "qtype" : qtypeStr ,
229- "result" : result ,
230- "resolver" : chosenServerIP ,
231- }).Observe (rtt .Seconds ())
232- ch <- dnsRes {resp : resp , err : err }
233- }()
234- select {
235- case <- ctx .Done ():
236- switch ctx .Err () {
237- case context .DeadlineExceeded :
238- result = "deadline exceeded"
239- case context .Canceled :
240- result = "canceled"
241- default :
242- result = "unknown"
243- }
244- c .timeoutCounter .With (prometheus.Labels {
245- "qtype" : qtypeStr ,
246- "result" : result ,
247- "resolver" : chosenServerIP ,
248- "isTLD" : fmt .Sprintf ("%t" , ! strings .Contains (hostname , "." )),
249- }).Inc ()
250- return nil , "" , ctx .Err ()
251- case r := <- ch :
252- if r .err != nil {
253- // Check if the error is a timeout error, which we want to retry.
254- // Network errors that can timeout implement the net.Error interface.
255- var netErr net.Error
256- isRetryable := errors .As (r .err , & netErr ) && netErr .Timeout ()
257- hasRetriesLeft := tries < c .maxTries
258- if isRetryable && hasRetriesLeft {
259- continue
260- } else if isRetryable && ! hasRetriesLeft {
261- c .timeoutCounter .With (prometheus.Labels {
262- "qtype" : qtypeStr ,
263- "result" : "out of retries" ,
264- "resolver" : chosenServerIP ,
265- "isTLD" : fmt .Sprintf ("%t" , ! strings .Contains (hostname , "." )),
266- }).Inc ()
267- }
213+ // Do a bare assignment (not :=) to populate the `resp` used by the defer above.
214+ var rtt time.Duration
215+ resp , rtt , err = c .exchanger .ExchangeContext (ctx , req , chosenServer )
216+
217+ // Do some metrics handling before we do error handling.
218+ result := "failed"
219+ if resp != nil {
220+ result = dns .RcodeToString [resp .Rcode ]
221+ }
222+ c .queryTime .With (prometheus.Labels {
223+ "qtype" : qtypeStr ,
224+ "result" : result ,
225+ "resolver" : chosenServerIP ,
226+ }).Observe (rtt .Seconds ())
227+
228+ if err != nil {
229+ c .log .Infof ("logDNSError chosenServer=[%s] hostname=[%s] queryType=[%s] err=[%s]" , chosenServer , hostname , qtypeStr , err )
230+
231+ // Check if the error is a timeout error, which we want to retry.
232+ // Network errors that can timeout implement the net.Error interface.
233+ var netErr net.Error
234+ isRetryable := errors .As (err , & netErr ) && netErr .Timeout () && ! errors .Is (err , context .DeadlineExceeded )
235+ hasRetriesLeft := tries < c .maxTries
236+ if isRetryable && hasRetriesLeft {
237+ continue
238+ } else if isRetryable && ! hasRetriesLeft {
239+ c .timeoutCounter .With (prometheus.Labels {
240+ "qtype" : qtypeStr ,
241+ "result" : "out of retries" ,
242+ "resolver" : chosenServerIP ,
243+ "isTLD" : fmt .Sprintf ("%t" , ! strings .Contains (hostname , "." )),
244+ }).Inc ()
245+ } else if errors .Is (err , context .DeadlineExceeded ) {
246+ c .timeoutCounter .With (prometheus.Labels {
247+ "qtype" : qtypeStr ,
248+ "result" : "deadline exceeded" ,
249+ "resolver" : chosenServerIP ,
250+ "isTLD" : fmt .Sprintf ("%t" , ! strings .Contains (hostname , "." )),
251+ }).Inc ()
252+ } else if errors .Is (err , context .Canceled ) {
253+ c .timeoutCounter .With (prometheus.Labels {
254+ "qtype" : qtypeStr ,
255+ "result" : "canceled" ,
256+ "resolver" : chosenServerIP ,
257+ "isTLD" : fmt .Sprintf ("%t" , ! strings .Contains (hostname , "." )),
258+ }).Inc ()
268259 }
269260
270- // This is either a success or a non-retryable error; return either way.
271- return r .resp , chosenServer , r .err
261+ return nil , chosenServer , err
272262 }
263+
264+ return resp , chosenServer , nil
273265 }
274266
275267 // It's impossible to get past the bottom of the loop: on the last attempt
@@ -286,7 +278,7 @@ func (c *impl) LookupA(ctx context.Context, hostname string) (*Result[*dns.A], s
286278 return nil , resolver , err
287279 }
288280
289- return resultFromMsg [* dns.A ](resp ), resolver , wrapErr ( dns . TypeA , hostname , resp , err )
281+ return resultFromMsg [* dns.A ](resp ), resolver , nil
290282}
291283
292284// LookupAAAA sends a DNS query to find all AAAA records associated with the
@@ -339,7 +331,7 @@ func (c *impl) LookupTXT(ctx context.Context, hostname string) (*Result[*dns.TXT
339331// exchanger represents an underlying DNS client. This interface exists solely
340332// so that its implementation can be swapped out in unit tests.
341333type exchanger interface {
342- Exchange ( m * dns.Msg , a string ) (* dns.Msg , time.Duration , error )
334+ ExchangeContext ( ctx context. Context , m * dns.Msg , a string ) (* dns.Msg , time.Duration , error )
343335}
344336
345337// dohExchanger implements the exchanger interface. It routes all of its DNS
@@ -351,16 +343,16 @@ type dohExchanger struct {
351343 userAgent string
352344}
353345
354- // Exchange sends a DoH query to the provided DoH server and returns the response.
355- func (d * dohExchanger ) Exchange ( query * dns.Msg , server string ) (* dns.Msg , time.Duration , error ) {
346+ // ExchangeContext sends a DoH query to the provided DoH server and returns the response.
347+ func (d * dohExchanger ) ExchangeContext ( ctx context. Context , query * dns.Msg , server string ) (* dns.Msg , time.Duration , error ) {
356348 q , err := query .Pack ()
357349 if err != nil {
358350 return nil , 0 , err
359351 }
360352
361353 // The default Unbound URL template
362354 url := fmt .Sprintf ("https://%s/dns-query" , server )
363- req , err := http .NewRequest ( "POST" , url , strings .NewReader (string (q )))
355+ req , err := http .NewRequestWithContext ( ctx , "POST" , url , strings .NewReader (string (q )))
364356 if err != nil {
365357 return nil , 0 , err
366358 }
0 commit comments