@@ -18,6 +18,7 @@ import (
1818 "os"
1919 "path"
2020 "strings"
21+ "sync"
2122 "time"
2223
2324 "github.com/cea-hpc/sshproxy/pkg/record"
@@ -103,6 +104,7 @@ type Recorder struct {
103104 dumpLimitSize uint64 // number of bytes beyond which records are no longer dumped
104105 dumpLimitWindow time.Duration // time window in which dump size is accounted
105106 leaseID clientv3.LeaseID // etcd lease ID used for updating stats
107+ lock sync.RWMutex // mutex to avoid concurrent reads and writes in bandwidth and totals maps
106108 writer * record.Writer // *record.Writer where the raw records are dumped
107109}
108110
@@ -129,6 +131,7 @@ func NewRecorder(conninfo *ConnInfo, dumpfile, command string, etcdStatsInterval
129131 dumpLimitSize : dumpLimitSize ,
130132 dumpLimitWindow : dumpLimitWindow ,
131133 leaseID : leaseID ,
134+ lock : sync.RWMutex {},
132135 writer : nil ,
133136 }
134137}
@@ -137,7 +140,10 @@ func NewRecorder(conninfo *ConnInfo, dumpfile, command string, etcdStatsInterval
137140func (r * Recorder ) updateStats (cli * utils.Client , etcdPath string ) {
138141 if cli != nil {
139142 if cli .IsAlive () {
140- err := cli .UpdateStats (etcdPath , r .bandwidth , r .leaseID )
143+ r .lock .RLock ()
144+ stats := r .bandwidth
145+ r .lock .RUnlock ()
146+ err := cli .UpdateStats (etcdPath , stats , r .leaseID )
141147 if err != nil {
142148 log .Errorf ("updating stats: %v" , err )
143149 }
@@ -149,9 +155,11 @@ func (r *Recorder) updateStats(cli *utils.Client, etcdPath string) {
149155func (r * Recorder ) log (ctx context.Context , step string ) {
150156 fds := []string {"stdin" , "stdout" , "stderr" }
151157 t := []string {}
158+ r .lock .RLock ()
152159 for fd , name := range fds {
153160 t = append (t , fmt .Sprintf ("%s=%d" , name , r .totals [fd ]))
154161 }
162+ r .lock .RUnlock ()
155163 // round to second
156164 elapsed := time .Duration ((time .Since (r .start ) / time .Second ) * time .Second )
157165 log .Infof ("bytes transferred (%s): duration=%s %s" , step , elapsed , strings .Join (t , " " ))
@@ -237,6 +245,7 @@ func (r *Recorder) Run(ctx context.Context, cli *utils.Client, etcdPath string)
237245 }
238246 if r .etcdStatsInterval != 0 {
239247 go func () {
248+ time .Sleep (time .Second )
240249 for {
241250 select {
242251 case <- time .After (r .etcdStatsInterval ):
@@ -252,17 +261,21 @@ func (r *Recorder) Run(ctx context.Context, cli *utils.Client, etcdPath string)
252261 select {
253262 case <- timeout :
254263 timeout = time .After (r .etcdStatsInterval )
264+ r .lock .Lock ()
255265 for i := 0 ; i <= 2 ; i ++ {
256266 r .bandwidth [i ] = buf [i ] / uint64 (r .etcdStatsInterval .Seconds ())
257267 buf [i ] = 0
258268 }
269+ r .lock .Unlock ()
259270 case <- bwTimeout :
260271 bwTimeout = time .After (r .dumpLimitWindow )
261272 bw = bwBuf
262273 bwBuf = 0
263274 case rec := <- r .ch :
264275 buf [rec .Fd ] += uint64 (rec .Size )
276+ r .lock .Lock ()
265277 r .totals [rec .Fd ] += uint64 (rec .Size )
278+ r .lock .Unlock ()
266279 if r .writer != nil {
267280 if r .dumpLimitSize == 0 || (bw < r .dumpLimitSize && bwBuf < r .dumpLimitSize ) {
268281 r .dump (rec )
@@ -285,7 +298,9 @@ func (r *Recorder) Run(ctx context.Context, cli *utils.Client, etcdPath string)
285298 bw = bwBuf
286299 bwBuf = 0
287300 case rec := <- r .ch :
301+ r .lock .Lock ()
288302 r .totals [rec .Fd ] += uint64 (rec .Size )
303+ r .lock .Unlock ()
289304 if r .writer != nil {
290305 if r .dumpLimitSize == 0 || (bw < r .dumpLimitSize && bwBuf < r .dumpLimitSize ) {
291306 r .dump (rec )
0 commit comments