|
| 1 | +package common |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "io" |
| 6 | + "os" |
| 7 | + "time" |
| 8 | +) |
| 9 | + |
| 10 | +// ProgressOption configures the behaviour of ProgressWriter. |
| 11 | +type ProgressOption func(*ProgressWriter) |
| 12 | + |
| 13 | +// WithMessage sets the status prefix shown before percentage/KiB. |
| 14 | +func WithMessage(msg string) ProgressOption { |
| 15 | + return func(p *ProgressWriter) { p.msg = msg } |
| 16 | +} |
| 17 | + |
| 18 | +// WithFrequency sets how often updates are printed. |
| 19 | +func WithFrequency(freq time.Duration) ProgressOption { |
| 20 | + return func(p *ProgressWriter) { p.freq = freq } |
| 21 | +} |
| 22 | + |
| 23 | +// WithOutput sets the destination writer (defaults to os.Stderr). |
| 24 | +func WithOutput(w io.Writer) ProgressOption { |
| 25 | + return func(p *ProgressWriter) { p.out = w } |
| 26 | +} |
| 27 | + |
| 28 | +// ProgressWriter is an io.Writer that prints a simple progress bar. |
| 29 | +// It is safe to use with io.TeeReader. |
| 30 | +// |
| 31 | +// pw := common.NewProgressWriter(resp.ContentLength, common.WithMessage("Downloading...")) |
| 32 | +// io.Copy(dst, io.TeeReader(src, pw)) |
| 33 | +type ProgressWriter struct { |
| 34 | + written int64 |
| 35 | + lastPrint time.Time |
| 36 | + contentLen int64 |
| 37 | + |
| 38 | + msg string |
| 39 | + freq time.Duration |
| 40 | + out io.Writer |
| 41 | +} |
| 42 | + |
| 43 | +const defaultFreq = 500 * time.Millisecond |
| 44 | + |
| 45 | +// NewProgressWriter returns a configured ProgressWriter. |
| 46 | +func NewProgressWriter(contentLen int64, opts ...ProgressOption) *ProgressWriter { |
| 47 | + pw := &ProgressWriter{ |
| 48 | + contentLen: contentLen, |
| 49 | + msg: "Progress...", |
| 50 | + freq: defaultFreq, |
| 51 | + out: os.Stderr, |
| 52 | + } |
| 53 | + for _, o := range opts { |
| 54 | + o(pw) |
| 55 | + } |
| 56 | + return pw |
| 57 | +} |
| 58 | + |
| 59 | +// Write implements io.Writer. |
| 60 | +func (p *ProgressWriter) Write(b []byte) (int, error) { |
| 61 | + n := len(b) |
| 62 | + p.written += int64(n) |
| 63 | + now := time.Now() |
| 64 | + |
| 65 | + if now.Sub(p.lastPrint) >= p.freq || p.written == p.contentLen { |
| 66 | + if p.contentLen > 0 { |
| 67 | + percent := float64(p.written) / float64(p.contentLen) * 100 |
| 68 | + fmt.Fprintf(p.out, "\r%s %5.1f%%", p.msg, percent) |
| 69 | + if p.written == p.contentLen { |
| 70 | + fmt.Fprintln(p.out) |
| 71 | + } |
| 72 | + } else { |
| 73 | + fmt.Fprintf(p.out, "\r%s %d KiB", p.msg, p.written/1024) |
| 74 | + } |
| 75 | + p.lastPrint = now |
| 76 | + } |
| 77 | + return n, nil |
| 78 | +} |
0 commit comments