@@ -30,9 +30,10 @@ type testStreamingBatchClient struct {
3030 committed map [messageType ]int
3131 open map [messageType ][]string
3232
33- writeErr error
34- writeErrAfter int64
35- writeCounter map [string ]int64 // table name to write counter
33+ writeErr error
34+ writeErrAfter int64
35+ writeCounter map [string ]int64 // table name to write counter
36+ writeCommitErr error
3637}
3738
3839func newClient () * testStreamingBatchClient {
@@ -84,6 +85,11 @@ func (c *testStreamingBatchClient) WriteTable(ctx context.Context, msgs <-chan *
8485 return c .writeErr // leave msgs open
8586 }
8687 }
88+
89+ if c .writeCommitErr != nil {
90+ return c .writeCommitErr
91+ }
92+
8793 return c .handleTypeCommit (ctx , messageTypeInsert , key )
8894}
8995
@@ -528,6 +534,41 @@ func TestErrorCleanUpSecondMessage(t *testing.T) {
528534 waitForLength (t , testClient .MessageLen , messageTypeInsert , 1 ) // batch size 1
529535}
530536
537+ func TestErrorCleanUpAfterClose (t * testing.T ) {
538+ t .Parallel ()
539+ ctx := context .Background ()
540+ ch := make (chan message.WriteMessage )
541+
542+ testClient := newClient ()
543+ testClient .writeCommitErr = errors .New ("test error" )
544+
545+ wr , err := New (testClient , WithBatchTimeout (0 ), WithBatchSizeRows (100 ))
546+ if err != nil {
547+ t .Fatal (err )
548+ }
549+
550+ errCh := make (chan error )
551+ go func () {
552+ errCh <- wr .Write (ctx , ch )
553+ }()
554+
555+ table := schema.Table {Name : "table1" , Columns : []schema.Column {{Name : "id" , Type : arrow .PrimitiveTypes .Int64 }}}
556+ record := getRecord (table .ToArrowSchema (), 1 )
557+
558+ for i := 0 ; i < 10 ; i ++ {
559+ ch <- & message.WriteInsert {
560+ Record : record ,
561+ }
562+ }
563+
564+ waitForLength (t , testClient .InflightLen , messageTypeInsert , 10 )
565+ close (ch )
566+
567+ requireErrorCount (t , 1 , errCh )
568+
569+ waitForLength (t , testClient .MessageLen , messageTypeInsert , 0 ) // batch size 1
570+ }
571+
531572func waitForLength (t * testing.T , checkLen func (messageType ) int , msgType messageType , want int ) {
532573 t .Helper ()
533574 lastValue := - 1
0 commit comments