Skip to content

Commit 8b02c95

Browse files
committed
mutex to avoid concurrent reads and writes Recorder maps
1 parent 8df13e2 commit 8b02c95

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

cmd/sshproxy/recorder.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"os"
1919
"path"
2020
"strings"
21+
"sync"
2122
"time"
2223

2324
"github.com/cea-hpc/sshproxy/pkg/record"
@@ -103,6 +104,7 @@ type Recorder struct {
103104
dumpLimitSize uint64 // number of bytes beyond which records are no longer dumped
104105
dumpLimitWindow time.Duration // time window in which dump size is accounted
105106
leaseID clientv3.LeaseID // etcd lease ID used for updating stats
107+
lock sync.RWMutex // mutex to avoid concurrent reads and writes in bandwidth and totals maps
106108
writer *record.Writer // *record.Writer where the raw records are dumped
107109
}
108110

@@ -129,6 +131,7 @@ func NewRecorder(conninfo *ConnInfo, dumpfile, command string, etcdStatsInterval
129131
dumpLimitSize: dumpLimitSize,
130132
dumpLimitWindow: dumpLimitWindow,
131133
leaseID: leaseID,
134+
lock: sync.RWMutex{},
132135
writer: nil,
133136
}
134137
}
@@ -137,7 +140,10 @@ func NewRecorder(conninfo *ConnInfo, dumpfile, command string, etcdStatsInterval
137140
func (r *Recorder) updateStats(cli *utils.Client, etcdPath string) {
138141
if cli != nil {
139142
if cli.IsAlive() {
140-
err := cli.UpdateStats(etcdPath, r.bandwidth, r.leaseID)
143+
r.lock.RLock()
144+
stats := r.bandwidth
145+
r.lock.RUnlock()
146+
err := cli.UpdateStats(etcdPath, stats, r.leaseID)
141147
if err != nil {
142148
log.Errorf("updating stats: %v", err)
143149
}
@@ -149,9 +155,11 @@ func (r *Recorder) updateStats(cli *utils.Client, etcdPath string) {
149155
func (r *Recorder) log(ctx context.Context, step string) {
150156
fds := []string{"stdin", "stdout", "stderr"}
151157
t := []string{}
158+
r.lock.RLock()
152159
for fd, name := range fds {
153160
t = append(t, fmt.Sprintf("%s=%d", name, r.totals[fd]))
154161
}
162+
r.lock.RUnlock()
155163
// round to second
156164
elapsed := time.Duration((time.Since(r.start) / time.Second) * time.Second)
157165
log.Infof("bytes transferred (%s): duration=%s %s", step, elapsed, strings.Join(t, " "))
@@ -237,6 +245,7 @@ func (r *Recorder) Run(ctx context.Context, cli *utils.Client, etcdPath string)
237245
}
238246
if r.etcdStatsInterval != 0 {
239247
go func() {
248+
time.Sleep(time.Second)
240249
for {
241250
select {
242251
case <-time.After(r.etcdStatsInterval):
@@ -252,17 +261,21 @@ func (r *Recorder) Run(ctx context.Context, cli *utils.Client, etcdPath string)
252261
select {
253262
case <-timeout:
254263
timeout = time.After(r.etcdStatsInterval)
264+
r.lock.Lock()
255265
for i := 0; i <= 2; i++ {
256266
r.bandwidth[i] = buf[i] / uint64(r.etcdStatsInterval.Seconds())
257267
buf[i] = 0
258268
}
269+
r.lock.Unlock()
259270
case <-bwTimeout:
260271
bwTimeout = time.After(r.dumpLimitWindow)
261272
bw = bwBuf
262273
bwBuf = 0
263274
case rec := <-r.ch:
264275
buf[rec.Fd] += uint64(rec.Size)
276+
r.lock.Lock()
265277
r.totals[rec.Fd] += uint64(rec.Size)
278+
r.lock.Unlock()
266279
if r.writer != nil {
267280
if r.dumpLimitSize == 0 || (bw < r.dumpLimitSize && bwBuf < r.dumpLimitSize) {
268281
r.dump(rec)
@@ -285,7 +298,9 @@ func (r *Recorder) Run(ctx context.Context, cli *utils.Client, etcdPath string)
285298
bw = bwBuf
286299
bwBuf = 0
287300
case rec := <-r.ch:
301+
r.lock.Lock()
288302
r.totals[rec.Fd] += uint64(rec.Size)
303+
r.lock.Unlock()
289304
if r.writer != nil {
290305
if r.dumpLimitSize == 0 || (bw < r.dumpLimitSize && bwBuf < r.dumpLimitSize) {
291306
r.dump(rec)

0 commit comments

Comments
 (0)