@@ -2,8 +2,12 @@ package plugin
22
33import (
44 "context"
5+ "errors"
56 "io"
7+ "strings"
8+ "sync/atomic"
69 "testing"
10+ "time"
711
812 "github.com/apache/arrow-go/v18/arrow"
913 "github.com/apache/arrow-go/v18/arrow/array"
@@ -248,3 +252,143 @@ func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old
248252 return old .AddField (1 , arrow.Field {Name : "source" , Type : arrow .BinaryTypes .String })
249253}
250254func (* mockSourceColumnAdderPluginClient ) Close (context.Context ) error { return nil }
255+
256+ type testTransformPluginClient struct {
257+ plugin.UnimplementedDestination
258+ plugin.UnimplementedSource
259+ recordsSent int32
260+ }
261+
262+ func (c * testTransformPluginClient ) Transform (ctx context.Context , recvRecords <- chan arrow.Record , sendRecords chan <- arrow.Record ) error {
263+ for record := range recvRecords {
264+ select {
265+ default :
266+ time .Sleep (1 * time .Second )
267+ sendRecords <- record
268+ atomic .AddInt32 (& c .recordsSent , 1 )
269+ case <- ctx .Done ():
270+ return ctx .Err ()
271+ }
272+ }
273+ return nil
274+ }
275+
276+ func (* testTransformPluginClient ) TransformSchema (_ context.Context , old * arrow.Schema ) (* arrow.Schema , error ) {
277+ return old , nil
278+ }
279+
280+ func (* testTransformPluginClient ) Close (context.Context ) error {
281+ return nil
282+ }
283+
284+ func TestTransformNoDeadlockOnSendError (t * testing.T ) {
285+ client := & testTransformPluginClient {}
286+ p := plugin .NewPlugin ("test" , "development" , func (context.Context , zerolog.Logger , []byte , plugin.NewClientOptions ) (plugin.Client , error ) {
287+ return client , nil
288+ })
289+ s := Server {
290+ Plugin : p ,
291+ }
292+ _ , err := s .Init (context .Background (), & pb.Init_Request {})
293+ require .NoError (t , err )
294+
295+ // Create a channel to signal when Send was called
296+ sendCalled := make (chan struct {})
297+ // Create a channel to signal when we should return from the test
298+ done := make (chan struct {})
299+ defer close (done )
300+
301+ stream := & mockTransformServerWithBlockingSend {
302+ incomingMessages : makeRequests (3 ), // Multiple messages to ensure Transform tries to keep sending
303+ sendCalled : sendCalled ,
304+ done : done ,
305+ }
306+
307+ // Run Transform in a goroutine with a timeout
308+ errCh := make (chan error )
309+ go func () {
310+ errCh <- s .Transform (stream )
311+ }()
312+
313+ // Wait for the first Send to be called
314+ select {
315+ case <- sendCalled :
316+ // Send was called, good
317+ case <- time .After (5 * time .Second ):
318+ t .Fatal ("timeout waiting for Send to be called" )
319+ }
320+
321+ // Now wait for Transform to complete or timeout
322+ select {
323+ case err := <- errCh :
324+ require .Error (t , err )
325+ // Check for either the simulated error or context cancellation
326+ if ! strings .Contains (err .Error (), "simulated stream send error" ) &&
327+ ! strings .Contains (err .Error (), "context canceled" ) {
328+ t .Fatalf ("unexpected error: %v" , err )
329+ }
330+ case <- time .After (5 * time .Second ):
331+ t .Fatal ("Transform got deadlocked" )
332+ }
333+ }
334+
335+ type mockTransformServerWithBlockingSend struct {
336+ grpc.ServerStream
337+ incomingMessages []* pb.Transform_Request
338+ sendCalled chan struct {}
339+ done chan struct {}
340+ sendCount int32
341+ }
342+
343+ func (s * mockTransformServerWithBlockingSend ) Recv () (* pb.Transform_Request , error ) {
344+ if len (s .incomingMessages ) > 0 {
345+ msg := s .incomingMessages [0 ]
346+ s .incomingMessages = s .incomingMessages [1 :]
347+ return msg , nil
348+ }
349+ return nil , io .EOF
350+ }
351+
352+ func (s * mockTransformServerWithBlockingSend ) Send (* pb.Transform_Response ) error {
353+ // Signal that Send was called
354+ select {
355+ case s .sendCalled <- struct {}{}:
356+ default :
357+ }
358+
359+ // Return error on first send
360+ if atomic .AddInt32 (& s .sendCount , 1 ) == 1 {
361+ return errors .New ("simulated stream send error" )
362+ }
363+
364+ // Block until test is done
365+ <- s .done
366+ return nil
367+ }
368+
369+ func (* mockTransformServerWithBlockingSend ) Context () context.Context {
370+ return context .Background ()
371+ }
372+
373+ func makeRequests (i int ) []* pb.Transform_Request {
374+ requests := make ([]* pb.Transform_Request , i )
375+ for i := range i {
376+ requests [i ] = makeRequestFromString ("test" )
377+ }
378+ return requests
379+ }
380+
381+ func makeRequestFromString (s string ) * pb.Transform_Request {
382+ record := makeRecordFromString (s )
383+ bs , _ := pb .RecordToBytes (record )
384+ return & pb.Transform_Request {Record : bs }
385+ }
386+
387+ func makeRecordFromString (s string ) arrow.Record {
388+ str := array .NewStringBuilder (memory .DefaultAllocator )
389+ str .AppendString (s )
390+ arr := str .NewStringArray ()
391+ sch := arrow .NewSchema ([]arrow.Field {{Name : "col1" , Type : arrow .BinaryTypes .String }}, nil )
392+
393+ return array .NewRecord (sch , []arrow.Array {arr }, 1 )
394+ }
0 commit comments