@@ -402,12 +402,11 @@ func (s *Server) Write(stream pb.Plugin_WriteServer) error {
402402
403403func (s * Server ) Transform (stream pb.Plugin_TransformServer ) error {
404404 var (
405- recvRecords = make (chan arrow.Record )
406- sendRecords = make (chan arrow.Record )
407- pluginStopsWriter = make (chan struct {})
408- doneReading = false
409- ctx = stream .Context ()
410- eg , gctx = errgroup .WithContext (ctx )
405+ recvRecords = make (chan arrow.Record )
406+ sendRecords = make (chan arrow.Record )
407+ pluginStops = make (chan error )
408+ ctx = stream .Context ()
409+ eg , gctx = errgroup .WithContext (ctx )
411410 )
412411
413412 // Run the plugin's transform with both channels.
@@ -416,10 +415,10 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
416415 // The plugin must not close either channel.
417416 eg .Go (func () error {
418417 err := s .Plugin .Transform (gctx , recvRecords , sendRecords )
419- close (pluginStopsWriter )
420- doneReading = true
421418 if err != nil {
422- return status .Error (codes .Internal , err .Error ())
419+ err = status .Error (codes .Internal , err .Error ())
420+ pluginStops <- err
421+ return err
423422 }
424423 return nil
425424 })
@@ -443,8 +442,8 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
443442 if err := stream .Send (& pb.Transform_Response {Record : recordBytes }); err != nil {
444443 return status .Errorf (codes .Internal , "error sending response: %v" , err )
445444 }
446- case <- pluginStopsWriter :
447- return nil
445+ case err := <- pluginStops :
446+ return err
448447 }
449448 }
450449 })
@@ -468,16 +467,30 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
468467 close (recvRecords )
469468 return status .Errorf (codes .Internal , "Error receiving request: %v" , err )
470469 }
471- if doneReading {
472- return nil
473- }
474470 record , err := pb .NewRecordFromBytes (req .Record )
475471 if err != nil {
476472 close (recvRecords )
477473 return status .Errorf (codes .InvalidArgument , "failed to create record: %v" , err )
478474 }
479475
480- recvRecords <- record
476+ select {
477+ case recvRecords <- record :
478+ case err := <- pluginStops :
479+ close (recvRecords )
480+ return err
481+ case <- gctx .Done ():
482+ close (recvRecords )
483+ if err := eg .Wait (); err != nil {
484+ return status .Errorf (codes .Canceled , "plugin returned error: %v" , err )
485+ }
486+ return status .Errorf (codes .Internal , "transform failed for unknown reason" )
487+ case <- ctx .Done ():
488+ close (recvRecords )
489+ if err := eg .Wait (); err != nil {
490+ return status .Errorf (codes .Internal , "context done: %v and failed to wait for plugin: %v" , ctx .Err (), err )
491+ }
492+ return status .Errorf (codes .Canceled , "context done: %v" , ctx .Err ())
493+ }
481494 }
482495 })
483496
0 commit comments