diff --git a/kadai3-2/shuheiktgw/gget/.gitignore b/kadai3-2/shuheiktgw/gget/.gitignore new file mode 100644 index 0000000..e6a58c1 --- /dev/null +++ b/kadai3-2/shuheiktgw/gget/.gitignore @@ -0,0 +1 @@ +gget \ No newline at end of file diff --git a/kadai3-2/shuheiktgw/gget/README.md b/kadai3-2/shuheiktgw/gget/README.md new file mode 100644 index 0000000..039fd3a --- /dev/null +++ b/kadai3-2/shuheiktgw/gget/README.md @@ -0,0 +1,21 @@ +gget +==== + +gget is a wget like command to download file, but downloads a file in parallel. + +## Usage +``` +gget [options...] URL + +OPTIONS: + --parallel value, -p value specifies the amount of parallelism (default: the number of CPU) + --help, -h prints help + +``` + +## Install + +``` +go build +./gget [options...] URL +``` \ No newline at end of file diff --git a/kadai3-2/shuheiktgw/gget/cli.go b/kadai3-2/shuheiktgw/gget/cli.go new file mode 100644 index 0000000..4436196 --- /dev/null +++ b/kadai3-2/shuheiktgw/gget/cli.go @@ -0,0 +1,75 @@ +package main + +import ( + "flag" + "fmt" + "io" + "runtime" +) + +const ( + ExitCodeOK = iota + ExitCodeError + ExitCodeBadArgsError + ExitCodeParseFlagsError + ExitCodeInvalidFlagError +) + +const name = "gget" + +// CLI represents CLI interface for gget +type CLI struct { + outStream, errStream io.Writer +} + +// Run runs gget command +func (cli *CLI) Run(args []string) int { + var parallel int + + flags := flag.NewFlagSet(name, flag.ContinueOnError) + flags.Usage = func() { + fmt.Fprint(cli.outStream, usage) + } + + numCPU := runtime.NumCPU() + flags.IntVar(¶llel, "parallel", numCPU, "") + flags.IntVar(¶llel, "p", numCPU, "") + + if err := flags.Parse(args[1:]); err != nil { + return ExitCodeParseFlagsError + } + + if parallel < 1 { + fmt.Fprintf(cli.errStream, "Failed to set up gget: The number of parallels cannot be less than one\n") + return ExitCodeInvalidFlagError + } + + parsedArgs := flags.Args() + if len(parsedArgs) != 1 { + fmt.Fprintf(cli.errStream, "Invalid arguments: you need to set exactly one URL\n") + return ExitCodeBadArgsError + } + + request, err := NewRequest(parsedArgs[0], parallel) + if err != nil { + fmt.Fprintf(cli.errStream, "Error occurred while initializing a request: %s\n", err) + return ExitCodeError + } + + if err := request.Do(); err != nil { + fmt.Fprintf(cli.errStream, "Error occurred while downloading the file: %s\n", err) + return ExitCodeError + } + + return ExitCodeOK +} + +var usage = `Usage: gget [options...] URL + +gget is a wget like command to download file, but downloads a file in parallel + +OPTIONS: + --parallel value, -p value specifies the amount of parallelism (default: the number of CPU) + --help, -h prints help + +` diff --git a/kadai3-2/shuheiktgw/gget/main.go b/kadai3-2/shuheiktgw/gget/main.go new file mode 100644 index 0000000..9c033c1 --- /dev/null +++ b/kadai3-2/shuheiktgw/gget/main.go @@ -0,0 +1,8 @@ +package main + +import "os" + +func main() { + cli := &CLI{outStream: os.Stdout, errStream: os.Stderr} + os.Exit(cli.Run(os.Args)) +} diff --git a/kadai3-2/shuheiktgw/gget/request.go b/kadai3-2/shuheiktgw/gget/request.go new file mode 100644 index 0000000..07ed441 --- /dev/null +++ b/kadai3-2/shuheiktgw/gget/request.go @@ -0,0 +1,176 @@ +package main + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + + "golang.org/x/sync/errgroup" +) + +// RangeRequest represents a request with a range access +type RangeRequest struct { + URL string + FName string + Ranges []*Range +} + +// NonRangeRequest represents a request without a range access +type NonRangeRequest struct { + URL string + FName string +} + +// Request represents a request +type Request interface { + Do() error +} + +// Range tells tha range of file to download +type Range struct { + start int64 + end int64 +} + +// NewRequest initializes Request object +func NewRequest(rawURL string, parallel int) (Request, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + ss := strings.Split(u.Path, "/") + fname := ss[len(ss)-1] + + res, err := http.Head(rawURL) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.Header.Get("Accept-Ranges") != "bytes" { + return &NonRangeRequest{URL: rawURL, FName: fname}, nil + } + + total := res.ContentLength + unit := total / int64(parallel) + ranges := make([]*Range, parallel) + + for i := 0; i < parallel; i++ { + var start int64 + if i == 0 { + start = 0 + } else { + start = int64(i)*unit + 1 + } + + var end int64 + if i == parallel-1 { + end = total + } else { + end = int64(i+1) * unit + } + + ranges[i] = &Range{start: start, end: end} + } + + return &RangeRequest{URL: rawURL, FName: fname, Ranges: ranges}, nil +} + +// Do sends a real HTTP requests in parallel +func (r *NonRangeRequest) Do() error { + req, err := http.NewRequest(http.MethodGet, r.URL, nil) + + client := http.DefaultClient + res, err := client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + return saveResponseBody(r.FName, res) +} + +// Do sends a real HTTP requests in parallel +func (r *RangeRequest) Do() error { + eg, ctx := errgroup.WithContext(context.TODO()) + + for idx := range r.Ranges { + // DO NOT refer to idx directly since function below + // is a closure and idx changes for each iterations + i := idx + eg.Go(func() error { + return r.do(i, ctx) + }) + } + + if err := eg.Wait(); err != nil { + return err + } + + return r.mergeFiles() +} + +func (r *RangeRequest) do(idx int, ctx context.Context) error { + req, err := http.NewRequest(http.MethodGet, r.URL, nil) + if err != nil { + return err + } + req = req.WithContext(ctx) + + ran := r.Ranges[idx] + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", ran.start, ran.end)) + + client := http.DefaultClient + + res, err := client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + tmpFName := fmt.Sprintf("%s.%d", r.FName, idx) + return saveResponseBody(tmpFName, res) +} + +func (r *RangeRequest) mergeFiles() error { + f, err := os.Create(r.FName) + if err != nil { + return err + } + defer f.Close() + + for idx := range r.Ranges { + tmpFName := fmt.Sprintf("%s.%d", r.FName, idx) + tmpFile, err := os.Open(tmpFName) + if err != nil { + return err + } + + io.Copy(f, tmpFile) + tmpFile.Close() + if err := os.Remove(tmpFName); err != nil { + return err + } + } + + return nil +} + +func saveResponseBody(fname string, response *http.Response) error { + file, err := os.Create(fname) + if err != nil { + return err + } + defer file.Close() + + if _, err := io.Copy(file, response.Body); err != nil { + return err + } + + return nil +}