@@ -3,10 +3,11 @@ package work
33import (
44 "errors"
55 "sync"
6+ "sync/atomic"
67)
78
89type Batch struct {
9- batchPosition int
10+ batchPosition atomic. Value
1011 batchSize int
1112 itemsToSave []interface {}
1213 pushHandler BatchHandler
@@ -46,7 +47,7 @@ type BytesSource interface {
4647type BatchHandler func ([]interface {}) error
4748
4849func (b * Batch ) Init (batchSize int , pushHandler BatchHandler , flushHandler ... BatchHandler ) {
49- b .batchPosition = 0
50+ b .batchPosition . Store ( 0 )
5051
5152 // grab the batch size - default to 100
5253 b .batchSize = batchSize
@@ -82,19 +83,20 @@ func (b *Batch) Push(record interface{}) error {
8283 b .itemsToSave = make ([]interface {}, b .batchSize , b .batchSize )
8384 }
8485
86+ batchPosition , _ := b .batchPosition .Load ().(int )
87+
8588 // if our batch is full
86- if b . batchPosition >= b .batchSize {
89+ if batchPosition >= b .batchSize {
8790 batch := b .itemsToSave
8891
8992 // allocate a new buffer, put the inbound record as the first item
9093 b .itemsToSave = make ([]interface {}, b .batchSize , b .batchSize )
9194 b .itemsToSave [0 ] = record
92- b .batchPosition = 1
95+ b .batchPosition . Store ( 1 )
9396
9497 // release the lock
9598 b .mutex .Unlock ()
9699
97- // TODO: review impact of making this call from a goroutine - definitely faster, but would bugs arise from timing changes?
98100 if err := b .pushHandler (batch ); err != nil {
99101 return err
100102 }
@@ -103,8 +105,9 @@ func (b *Batch) Push(record interface{}) error {
103105 } else {
104106
105107 // our batch is not full - if the batch size
106- b .itemsToSave [b .batchPosition ] = record
107- b .batchPosition ++
108+ b .itemsToSave [batchPosition ] = record
109+ batchPosition ++
110+ b .batchPosition .Store (batchPosition )
108111 b .mutex .Unlock ()
109112 }
110113
@@ -113,7 +116,7 @@ func (b *Batch) Push(record interface{}) error {
113116
114117func (b * Batch ) GetPosition () int {
115118 b .mutex .Lock ()
116- pos := b .batchPosition
119+ pos , _ := b .batchPosition . Load ().( int )
117120 b .mutex .Unlock ()
118121 return pos
119122}
@@ -125,12 +128,13 @@ func (b *Batch) Flush() error {
125128
126129 // lock around batch processing
127130 b .mutex .Lock ()
128- if b .batchPosition > 0 {
131+ batchPosition , _ := b .batchPosition .Load ().(int )
132+ if batchPosition > 0 {
129133
130134 // snag the rest of the buffer as a slice, reset buffer
131- subSlice := (b .itemsToSave )[0 :b . batchPosition ]
135+ subSlice := (b .itemsToSave )[0 :batchPosition ]
132136 b .itemsToSave = make ([]interface {}, b .batchSize , b .batchSize )
133- b .batchPosition = 0
137+ b .batchPosition . Store ( 0 )
134138
135139 // we've finished batch processing, unlock
136140 b .mutex .Unlock ()
0 commit comments