Skip to content

Commit f641b0a

Browse files
authored
fix(remote): race during flush (#154)
1 parent 0391768 commit f641b0a

File tree

2 files changed

+89
-16
lines changed

2 files changed

+89
-16
lines changed

upstream/remote/remote.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,16 @@ const (
3030
)
3131

3232
type Remote struct {
33+
mu sync.Mutex
3334
cfg Config
34-
jobs chan *upstream.UploadJob
35+
jobs chan job
3536
client HTTPClient
3637
logger Logger
3738

3839
done chan struct{}
3940
wg sync.WaitGroup
4041

41-
flushWG sync.WaitGroup
42+
flushWG *sync.WaitGroup
4243
}
4344

4445
type HTTPClient interface {
@@ -69,7 +70,7 @@ type Logger interface {
6970
func NewRemote(cfg Config) (*Remote, error) {
7071
r := &Remote{
7172
cfg: cfg,
72-
jobs: make(chan *upstream.UploadJob, 20),
73+
jobs: make(chan job, 20),
7374
client: &http.Client{
7475
Transport: &http.Transport{
7576
MaxConnsPerHost: cfg.Threads,
@@ -84,8 +85,9 @@ func NewRemote(cfg Config) (*Remote, error) {
8485
},
8586
Timeout: cfg.Timeout,
8687
},
87-
logger: cfg.Logger,
88-
done: make(chan struct{}),
88+
logger: cfg.Logger,
89+
done: make(chan struct{}),
90+
flushWG: new(sync.WaitGroup),
8991
}
9092
if cfg.HTTPClient != nil {
9193
r.client = cfg.HTTPClient
@@ -121,21 +123,28 @@ func (r *Remote) Stop() {
121123
r.wg.Wait()
122124
}
123125

124-
func (r *Remote) Upload(j *upstream.UploadJob) {
126+
func (r *Remote) Upload(uj *upstream.UploadJob) {
127+
r.mu.Lock()
128+
defer r.mu.Unlock()
125129
r.flushWG.Add(1)
130+
j := job{
131+
upload: uj,
132+
flush: r.flushWG,
133+
}
126134
select {
127135
case r.jobs <- j:
128136
default:
129-
r.flushWG.Done()
137+
j.flush.Done()
130138
r.logger.Errorf("remote upload queue is full, dropping a profile job")
131139
}
132140
}
133141

134142
func (r *Remote) Flush() {
135-
if r.done == nil {
136-
return
137-
}
138-
r.flushWG.Wait()
143+
r.mu.Lock()
144+
flush := r.flushWG
145+
r.flushWG = new(sync.WaitGroup)
146+
r.mu.Unlock()
147+
flush.Wait()
139148
}
140149

141150
func (r *Remote) uploadProfile(j *upstream.UploadJob) error {
@@ -174,7 +183,6 @@ func (r *Remote) uploadProfile(j *upstream.UploadJob) error {
174183

175184
q := u.Query()
176185
q.Set("name", j.Name)
177-
// TODO: I think these should be renamed to startTime / endTime
178186
q.Set("from", strconv.FormatInt(j.StartTime.UnixNano(), 10))
179187
q.Set("until", strconv.FormatInt(j.EndTime.UnixNano(), 10))
180188
q.Set("spyName", j.SpyName)
@@ -238,9 +246,9 @@ func (r *Remote) handleJobs() {
238246
case <-r.done:
239247
r.wg.Done()
240248
return
241-
case job := <-r.jobs:
242-
r.safeUpload(job)
243-
r.flushWG.Done()
249+
case j := <-r.jobs:
250+
r.safeUpload(j.upload)
251+
j.flush.Done()
244252
}
245253
}
246254
}
@@ -262,3 +270,8 @@ func (r *Remote) safeUpload(job *upstream.UploadJob) {
262270
r.logger.Errorf("upload profile: %v", err)
263271
}
264272
}
273+
274+
type job struct {
275+
upload *upstream.UploadJob
276+
flush *sync.WaitGroup
277+
}

upstream/remote/remote_test.go

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package remote
22

33
import (
44
"bytes"
5+
"fmt"
6+
"github.com/stretchr/testify/require"
57
"io"
68
"net/http"
9+
"sync"
710
"testing"
811
"time"
912

@@ -98,5 +101,62 @@ type MockHTTPClient struct {
98101

99102
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
100103
args := m.Called(req)
101-
return args.Get(0).(*http.Response), args.Error(1)
104+
err := args.Error(1)
105+
a0 := args.Get(0)
106+
switch typed := a0.(type) {
107+
case *http.Response:
108+
return typed, err
109+
case func() *http.Response:
110+
return typed(), err
111+
default:
112+
return nil, fmt.Errorf("unknown mock arg type arg %+v %w", a0, err)
113+
}
114+
}
115+
116+
func TestConcurrentUploadFlushRace(t *testing.T) {
117+
mockClient := new(MockHTTPClient)
118+
mockClient.On("Do", mock.Anything).Return(func() *http.Response {
119+
return &http.Response{
120+
StatusCode: 200,
121+
Body: io.NopCloser(bytes.NewBufferString("OK")),
122+
}
123+
}, nil)
124+
r, err := NewRemote(Config{
125+
Threads: 2,
126+
Logger: testutil.NewTestLogger(),
127+
HTTPClient: mockClient,
128+
})
129+
require.NoError(t, err)
130+
r.Start()
131+
defer r.Stop()
132+
133+
var wg sync.WaitGroup
134+
wg.Add(2)
135+
loop := func(f func()) {
136+
timeout := time.After(10 * time.Millisecond)
137+
go func() {
138+
defer wg.Done()
139+
for {
140+
select {
141+
case <-timeout:
142+
return
143+
default:
144+
f()
145+
}
146+
}
147+
}()
148+
}
149+
loop(func() {
150+
r.Upload(newJob("job1"))
151+
})
152+
loop(func() {
153+
r.Flush()
154+
})
155+
wg.Wait()
156+
}
157+
158+
func newJob(name string) *upstream.UploadJob {
159+
return &upstream.UploadJob{
160+
Name: name,
161+
}
102162
}

0 commit comments

Comments
 (0)