Skip to content

Commit 2d90985

Browse files
committed
Make sure Transform finishes.
1 parent 106a0bb commit 2d90985

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

internal/servers/plugin/v3/plugin.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -402,12 +402,11 @@ func (s *Server) Write(stream pb.Plugin_WriteServer) error {
402402

403403
func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
404404
var (
405-
recvRecords = make(chan arrow.Record)
406-
sendRecords = make(chan arrow.Record)
407-
pluginStopsWriter = make(chan struct{})
408-
doneReading = false
409-
ctx = stream.Context()
410-
eg, gctx = errgroup.WithContext(ctx)
405+
recvRecords = make(chan arrow.Record)
406+
sendRecords = make(chan arrow.Record)
407+
pluginStops = make(chan error)
408+
ctx = stream.Context()
409+
eg, gctx = errgroup.WithContext(ctx)
411410
)
412411

413412
// Run the plugin's transform with both channels.
@@ -416,10 +415,10 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
416415
// The plugin must not close either channel.
417416
eg.Go(func() error {
418417
err := s.Plugin.Transform(gctx, recvRecords, sendRecords)
419-
close(pluginStopsWriter)
420-
doneReading = true
421418
if err != nil {
422-
return status.Error(codes.Internal, err.Error())
419+
err = status.Error(codes.Internal, err.Error())
420+
pluginStops <- err
421+
return err
423422
}
424423
return nil
425424
})
@@ -443,8 +442,8 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
443442
if err := stream.Send(&pb.Transform_Response{Record: recordBytes}); err != nil {
444443
return status.Errorf(codes.Internal, "error sending response: %v", err)
445444
}
446-
case <-pluginStopsWriter:
447-
return nil
445+
case err := <-pluginStops:
446+
return err
448447
}
449448
}
450449
})
@@ -468,16 +467,30 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
468467
close(recvRecords)
469468
return status.Errorf(codes.Internal, "Error receiving request: %v", err)
470469
}
471-
if doneReading {
472-
return nil
473-
}
474470
record, err := pb.NewRecordFromBytes(req.Record)
475471
if err != nil {
476472
close(recvRecords)
477473
return status.Errorf(codes.InvalidArgument, "failed to create record: %v", err)
478474
}
479475

480-
recvRecords <- record
476+
select {
477+
case recvRecords <- record:
478+
case err := <-pluginStops:
479+
close(recvRecords)
480+
return err
481+
case <-gctx.Done():
482+
close(recvRecords)
483+
if err := eg.Wait(); err != nil {
484+
return status.Errorf(codes.Canceled, "plugin returned error: %v", err)
485+
}
486+
return status.Errorf(codes.Internal, "transform failed for unknown reason")
487+
case <-ctx.Done():
488+
close(recvRecords)
489+
if err := eg.Wait(); err != nil {
490+
return status.Errorf(codes.Internal, "context done: %v and failed to wait for plugin: %v", ctx.Err(), err)
491+
}
492+
return status.Errorf(codes.Canceled, "context done: %v", ctx.Err())
493+
}
481494
}
482495
})
483496

0 commit comments

Comments
 (0)