Skip to content

Commit 65532e7

Browse files
committed
fix: Pass down context.Context for proper context management
1 parent 31e0b36 commit 65532e7

25 files changed

+200
-292
lines changed

checksum.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package getter
66
import (
77
"bufio"
88
"bytes"
9+
"context"
910
"crypto/md5"
1011
"crypto/sha1"
1112
"crypto/sha256"
@@ -80,23 +81,27 @@ func (c *FileChecksum) checksum(source string) error {
8081
// extractChecksum will return a FileChecksum based on the 'checksum'
8182
// parameter of u.
8283
// ex:
83-
// http://hashicorp.com/terraform?checksum=<checksumValue>
84-
// http://hashicorp.com/terraform?checksum=<checksumType>:<checksumValue>
85-
// http://hashicorp.com/terraform?checksum=file:<checksum_url>
84+
//
85+
// http://hashicorp.com/terraform?checksum=<checksumValue>
86+
// http://hashicorp.com/terraform?checksum=<checksumType>:<checksumValue>
87+
// http://hashicorp.com/terraform?checksum=file:<checksum_url>
88+
//
8689
// when checksumming from a file, extractChecksum will go get checksum_url
8790
// in a temporary directory, parse the content of the file then delete it.
8891
// Content of files are expected to be BSD style or GNU style.
8992
//
9093
// BSD-style checksum:
91-
// MD5 (file1) = <checksum>
92-
// MD5 (file2) = <checksum>
94+
//
95+
// MD5 (file1) = <checksum>
96+
// MD5 (file2) = <checksum>
9397
//
9498
// GNU-style:
95-
// <checksum> file1
96-
// <checksum> *file2
99+
//
100+
// <checksum> file1
101+
// <checksum> *file2
97102
//
98103
// see parseChecksumLine for more detail on checksum file parsing
99-
func (c *Client) extractChecksum(u *url.URL) (*FileChecksum, error) {
104+
func (c *Client) extractChecksum(ctx context.Context, u *url.URL) (*FileChecksum, error) {
100105
q := u.Query()
101106
v := q.Get("checksum")
102107

@@ -118,7 +123,7 @@ func (c *Client) extractChecksum(u *url.URL) (*FileChecksum, error) {
118123

119124
switch checksumType {
120125
case "file":
121-
return c.ChecksumFromFile(checksumValue, u)
126+
return c.ChecksumFromFile(ctx, checksumValue, u)
122127
default:
123128
return newChecksumFromType(checksumType, checksumValue, filepath.Base(u.EscapedPath()))
124129
}
@@ -193,7 +198,7 @@ func newChecksumFromValue(checksumValue, filename string) (*FileChecksum, error)
193198
//
194199
// ChecksumFromFile will only return checksums for files that match file
195200
// behind src
196-
func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileChecksum, error) {
201+
func (c *Client) ChecksumFromFile(ctx context.Context, checksumFile string, src *url.URL) (*FileChecksum, error) {
197202
checksumFileURL, err := urlhelper.Parse(checksumFile)
198203
if err != nil {
199204
return nil, err
@@ -206,7 +211,6 @@ func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileCheck
206211
defer os.Remove(tempfile)
207212

208213
c2 := &Client{
209-
Ctx: c.Ctx,
210214
Getters: c.Getters,
211215
Decompressors: c.Decompressors,
212216
Detectors: c.Detectors,
@@ -216,7 +220,7 @@ func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileCheck
216220
Dst: tempfile,
217221
ProgressListener: c.ProgressListener,
218222
}
219-
if err = c2.Get(); err != nil {
223+
if err = c2.Get(ctx); err != nil {
220224
return nil, fmt.Errorf(
221225
"Error downloading checksum file: %s", err)
222226
}

client.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled")
2626
// Using a client directly allows more fine-grained control over how downloading
2727
// is done, as well as customizing the protocols supported.
2828
type Client struct {
29-
// Ctx for cancellation
30-
Ctx context.Context
3129

3230
// Src is the source URL to get.
3331
//
@@ -104,7 +102,7 @@ func (c *Client) mode(mode os.FileMode) os.FileMode {
104102
}
105103

106104
// Get downloads the configured source to the destination.
107-
func (c *Client) Get() error {
105+
func (c *Client) Get(ctx context.Context) error {
108106
if err := c.Configure(c.Options...); err != nil {
109107
return err
110108
}
@@ -221,7 +219,7 @@ func (c *Client) Get() error {
221219
}
222220

223221
// Determine checksum if we have one
224-
checksum, err := c.extractChecksum(u)
222+
checksum, err := c.extractChecksum(ctx, u)
225223
if err != nil {
226224
return fmt.Errorf("invalid checksum: %s", err)
227225
}
@@ -232,7 +230,7 @@ func (c *Client) Get() error {
232230

233231
if mode == ClientModeAny {
234232
// Ask the getter which client mode to use
235-
mode, err = g.ClientMode(u)
233+
mode, err = g.ClientMode(ctx, u)
236234
if err != nil {
237235
return err
238236
}
@@ -270,7 +268,7 @@ func (c *Client) Get() error {
270268
}
271269
}
272270
if getFile {
273-
err := g.GetFile(dst, u)
271+
err := g.GetFile(ctx, dst, u)
274272
if err != nil {
275273
return err
276274
}
@@ -321,7 +319,7 @@ func (c *Client) Get() error {
321319

322320
// We're downloading a directory, which might require a bit more work
323321
// if we're specifying a subdir.
324-
err := g.Get(dst, u)
322+
err := g.Get(ctx, dst, u)
325323
if err != nil {
326324
err = fmt.Errorf("error downloading '%s': %s", RedactURL(u), err)
327325
return err
@@ -343,7 +341,7 @@ func (c *Client) Get() error {
343341
return err
344342
}
345343

346-
return copyDir(c.Ctx, realDst, subDir, false, c.DisableSymlinks, c.umask())
344+
return copyDir(ctx, realDst, subDir, false, c.DisableSymlinks, c.umask())
347345
}
348346

349347
return nil

client_option.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package getter
55

66
import (
7-
"context"
87
"os"
98
)
109

@@ -15,10 +14,6 @@ type ClientOption func(*Client) error
1514
// behavior including context, decompressors, detectors, and getters used by
1615
// the client.
1716
func (c *Client) Configure(opts ...ClientOption) error {
18-
// If the context has not been configured use the background context.
19-
if c.Ctx == nil {
20-
c.Ctx = context.Background()
21-
}
2217

2318
// Store the options used to configure this client.
2419
c.Options = opts
@@ -52,15 +47,6 @@ func (c *Client) Configure(opts ...ClientOption) error {
5247
return nil
5348
}
5449

55-
// WithContext allows to pass a context to operation
56-
// in order to be able to cancel a download in progress.
57-
func WithContext(ctx context.Context) ClientOption {
58-
return func(c *Client) error {
59-
c.Ctx = ctx
60-
return nil
61-
}
62-
}
63-
6450
// WithDecompressors specifies which Decompressor are available.
6551
func WithDecompressors(decompressors map[string]Decompressor) ClientOption {
6652
return func(c *Client) error {

client_option_progress_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package getter
55

66
import (
7+
"context"
78
"io"
89
"net/http"
910
"net/http/httptest"
@@ -42,7 +43,7 @@ func TestGet_progress(t *testing.T) {
4243
{ // dl without tracking
4344
dst := tempTestFile(t)
4445
defer os.RemoveAll(filepath.Dir(dst))
45-
if err := GetFile(dst, s.URL+"/file?thig=this&that"); err != nil {
46+
if err := GetFile(context.Background(), dst, s.URL+"/file?thig=this&that"); err != nil {
4647
t.Fatalf("download failed: %v", err)
4748
}
4849
}
@@ -51,10 +52,10 @@ func TestGet_progress(t *testing.T) {
5152
p := &MockProgressTracking{}
5253
dst := tempTestFile(t)
5354
defer os.RemoveAll(filepath.Dir(dst))
54-
if err := GetFile(dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil {
55+
if err := GetFile(context.Background(), dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil {
5556
t.Fatalf("download failed: %v", err)
5657
}
57-
if err := GetFile(dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil {
58+
if err := GetFile(context.Background(), dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil {
5859
t.Fatalf("download failed: %v", err)
5960
}
6061

cmd/go-getter/main.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ func main() {
5858
ctx, cancel := context.WithCancel(context.Background())
5959
// Build the client
6060
client := &getter.Client{
61-
Ctx: ctx,
6261
Src: args[0],
6362
Dst: args[1],
6463
Pwd: pwd,
@@ -72,7 +71,7 @@ func main() {
7271
go func() {
7372
defer wg.Done()
7473
defer cancel()
75-
if err := client.Get(); err != nil {
74+
if err := client.Get(ctx); err != nil {
7675
errChan <- err
7776
}
7877
}()

folder_storage.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package getter
55

66
import (
7+
"context"
78
"crypto/md5"
89
"encoding/hex"
910
"fmt"
@@ -42,7 +43,7 @@ func (s *FolderStorage) Dir(key string) (d string, e bool, err error) {
4243
}
4344

4445
// Get implements Storage.Get
45-
func (s *FolderStorage) Get(key string, source string, update bool) error {
46+
func (s *FolderStorage) Get(ctx context.Context, key string, source string, update bool) error {
4647
dir := s.dir(key)
4748
if !update {
4849
if _, err := os.Stat(dir); err == nil {
@@ -57,7 +58,7 @@ func (s *FolderStorage) Get(key string, source string, update bool) error {
5758
}
5859

5960
// Get the source. This always forces an update.
60-
return Get(dir, source)
61+
return Get(ctx, dir, source)
6162
}
6263

6364
// dir returns the directory name internally that we'll use to map to

folder_storage_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package getter
55

66
import (
7+
"context"
78
"os"
89
"path/filepath"
910
"testing"
@@ -30,7 +31,7 @@ func TestFolderStorage(t *testing.T) {
3031
key := "foo"
3132

3233
// We can get it
33-
err = s.Get(key, module, false)
34+
err = s.Get(context.Background(), key, module, false)
3435
if err != nil {
3536
t.Fatalf("err: %s", err)
3637
}

get.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package getter
1616

1717
import (
1818
"bytes"
19+
"context"
1920
"fmt"
2021
"net/url"
2122
"os/exec"
@@ -33,16 +34,16 @@ type Getter interface {
3334
// The directory may already exist (if we're updating). If it is in a
3435
// format that isn't understood, an error should be returned. Get shouldn't
3536
// simply nuke the directory.
36-
Get(string, *url.URL) error
37+
Get(context.Context, string, *url.URL) error
3738

3839
// GetFile downloads the give URL into the given path. The URL must
3940
// reference a single file. If possible, the Getter should check if
4041
// the remote end contains the same file and no-op this operation.
41-
GetFile(string, *url.URL) error
42+
GetFile(context.Context, string, *url.URL) error
4243

4344
// ClientMode returns the mode based on the given URL. This is used to
4445
// allow clients to let the getters decide which mode to use.
45-
ClientMode(*url.URL) (ClientMode, error)
46+
ClientMode(context.Context, *url.URL) (ClientMode, error)
4647

4748
// SetClient allows a getter to know it's client
4849
// in order to access client's Get functions or
@@ -82,13 +83,13 @@ func init() {
8283
//
8384
// src is a URL, whereas dst is always just a file path to a folder. This
8485
// folder doesn't need to exist. It will be created if it doesn't exist.
85-
func Get(dst, src string, opts ...ClientOption) error {
86+
func Get(ctx context.Context, dst, src string, opts ...ClientOption) error {
8687
return (&Client{
8788
Src: src,
8889
Dst: dst,
8990
Dir: true,
9091
Options: opts,
91-
}).Get()
92+
}).Get(ctx)
9293
}
9394

9495
// GetAny downloads a URL into the given destination. Unlike Get or
@@ -97,24 +98,24 @@ func Get(dst, src string, opts ...ClientOption) error {
9798
// dst must be a directory. If src is a file, it will be downloaded
9899
// into dst with the basename of the URL. If src is a directory or
99100
// archive, it will be unpacked directly into dst.
100-
func GetAny(dst, src string, opts ...ClientOption) error {
101+
func GetAny(ctx context.Context, dst, src string, opts ...ClientOption) error {
101102
return (&Client{
102103
Src: src,
103104
Dst: dst,
104105
Mode: ClientModeAny,
105106
Options: opts,
106-
}).Get()
107+
}).Get(ctx)
107108
}
108109

109110
// GetFile downloads the file specified by src into the path specified by
110111
// dst.
111-
func GetFile(dst, src string, opts ...ClientOption) error {
112+
func GetFile(ctx context.Context, dst, src string, opts ...ClientOption) error {
112113
return (&Client{
113114
Src: src,
114115
Dst: dst,
115116
Dir: false,
116117
Options: opts,
117-
}).Get()
118+
}).Get(ctx)
118119
}
119120

120121
// getRunCommand is a helper that will run a command and capture the output

get_base.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,10 @@
33

44
package getter
55

6-
import "context"
7-
86
// getter is our base getter; it regroups
97
// fields all getters have in common.
108
type getter struct {
119
client *Client
1210
}
1311

1412
func (g *getter) SetClient(c *Client) { g.client = c }
15-
16-
// Context tries to returns the Contex from the getter's
17-
// client. otherwise context.Background() is returned.
18-
func (g *getter) Context() context.Context {
19-
if g == nil || g.client == nil {
20-
return context.Background()
21-
}
22-
return g.client.Ctx
23-
}

get_file.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package getter
55

66
import (
7+
"context"
78
"net/url"
89
"os"
910
)
@@ -19,7 +20,7 @@ type FileGetter struct {
1920
Copy bool
2021
}
2122

22-
func (g *FileGetter) ClientMode(u *url.URL) (ClientMode, error) {
23+
func (g *FileGetter) ClientMode(_ context.Context, u *url.URL) (ClientMode, error) {
2324
path := u.Path
2425
if u.RawPath != "" {
2526
path = u.RawPath

0 commit comments

Comments
 (0)