Skip to content

Commit ad4c48e

Browse files
authored
http: Initialize http.Request with ctx so cancelation interrupts the request (#321)
* Add test to verify canceled context.Context aborts http request * http: Initialize http.Request with ctx so cancelation interrupts the request
1 parent 9e42df5 commit ad4c48e

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

get_http.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
8989
u.RawQuery = q.Encode()
9090

9191
// Get the URL
92-
req, err := http.NewRequest("GET", u.String(), nil)
92+
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
9393
if err != nil {
9494
return err
9595
}
@@ -176,7 +176,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
176176
// We first make a HEAD request so we can check
177177
// if the server supports range queries. If the server/URL doesn't
178178
// support HEAD requests, we just fall back to GET.
179-
req, err := http.NewRequest("HEAD", src.String(), nil)
179+
req, err := http.NewRequestWithContext(ctx, "HEAD", src.String(), nil)
180180
if err != nil {
181181
return err
182182
}
@@ -203,7 +203,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
203203
}
204204
}
205205

206-
req, err = http.NewRequest("GET", src.String(), nil)
206+
req, err = http.NewRequestWithContext(ctx, "GET", src.String(), nil)
207207
if err != nil {
208208
return err
209209
}

get_http_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package getter
22

33
import (
4+
"context"
45
"crypto/sha256"
56
"encoding/hex"
67
"errors"
@@ -404,6 +405,42 @@ func TestHttpGetter_cleanhttp(t *testing.T) {
404405
}
405406
}
406407

408+
func TestHttpGetter__RespectsContextCanceled(t *testing.T) {
409+
ctx, cancel := context.WithCancel(context.Background())
410+
cancel() // cancel immediately
411+
412+
ln := testHttpServer(t)
413+
414+
var u url.URL
415+
u.Scheme = "http"
416+
u.Host = ln.Addr().String()
417+
u.Path = "/file"
418+
dst := tempDir(t)
419+
420+
rt := hookableHTTPRoundTripper{
421+
before: func(req *http.Request) {
422+
err := req.Context().Err()
423+
if !errors.Is(err, context.Canceled) {
424+
t.Fatalf("Expected http.Request with canceled.Context, got: %v", err)
425+
}
426+
},
427+
RoundTripper: http.DefaultTransport,
428+
}
429+
430+
g := new(HttpGetter)
431+
g.client = &Client{
432+
Ctx: ctx,
433+
}
434+
g.Client = &http.Client{
435+
Transport: &rt,
436+
}
437+
438+
err := g.Get(dst, &u)
439+
if !errors.Is(err, context.Canceled) {
440+
t.Fatalf("expected context.Canceled, got: %v", err)
441+
}
442+
}
443+
407444
func testHttpServer(t *testing.T) net.Listener {
408445
ln, err := net.Listen("tcp", "127.0.0.1:0")
409446
if err != nil {
@@ -531,3 +568,15 @@ machine %s
531568
login foo
532569
password bar
533570
`
571+
572+
type hookableHTTPRoundTripper struct {
573+
before func(req *http.Request)
574+
http.RoundTripper
575+
}
576+
577+
func (m *hookableHTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
578+
if m.before != nil {
579+
m.before(req)
580+
}
581+
return m.RoundTripper.RoundTrip(req)
582+
}

0 commit comments

Comments
 (0)