Skip to content

Commit 52c771d

Browse files
committed
Make writer close channel. Simplify concurrency model.
1 parent 286269f commit 52c771d

File tree

2 files changed

+15
-25
lines changed

2 files changed

+15
-25
lines changed

internal/servers/plugin/v3/plugin.go

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,6 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
404404
var (
405405
recvRecords = make(chan arrow.Record)
406406
sendRecords = make(chan arrow.Record)
407-
pluginStops = make(chan error)
408407
ctx = stream.Context()
409408
eg, gctx = errgroup.WithContext(ctx)
410409
)
@@ -414,11 +413,8 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
414413
// When the plugin is done, it must return with either an error or nil.
415414
// The plugin must not close either channel.
416415
eg.Go(func() error {
417-
err := s.Plugin.Transform(gctx, recvRecords, sendRecords)
418-
if err != nil {
419-
err = status.Error(codes.Internal, err.Error())
420-
pluginStops <- err
421-
return err
416+
if err := s.Plugin.Transform(gctx, recvRecords, sendRecords); err != nil {
417+
return status.Error(codes.Internal, err.Error())
422418
}
423419
return nil
424420
})
@@ -431,21 +427,16 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
431427
// The reading never closes the writer, because it's up to the Plugin to decide when to finish
432428
// writing, regardless of if the reading finished.
433429
eg.Go(func() error {
434-
for {
435-
select {
436-
case record := <-sendRecords:
437-
recordBytes, err := pb.RecordToBytes(record)
438-
if err != nil {
439-
return status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err)
440-
}
441-
442-
if err := stream.Send(&pb.Transform_Response{Record: recordBytes}); err != nil {
443-
return status.Errorf(codes.Internal, "error sending response: %v", err)
444-
}
445-
case err := <-pluginStops:
446-
return err
430+
for record := range sendRecords {
431+
recordBytes, err := pb.RecordToBytes(record)
432+
if err != nil {
433+
return status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err)
434+
}
435+
if err := stream.Send(&pb.Transform_Response{Record: recordBytes}); err != nil {
436+
return status.Errorf(codes.Internal, "error sending response: %v", err)
447437
}
448438
}
439+
return nil
449440
})
450441

451442
// Read records from source to transformer
@@ -475,15 +466,12 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
475466

476467
select {
477468
case recvRecords <- record:
478-
case err := <-pluginStops:
479-
close(recvRecords)
480-
return err
481469
case <-gctx.Done():
482470
close(recvRecords)
483-
return gctx.Err()
471+
return status.Errorf(codes.Canceled, "context done: %v", gctx.Err())
484472
case <-ctx.Done():
485473
close(recvRecords)
486-
return ctx.Err()
474+
return status.Errorf(codes.Canceled, "context done: %v", ctx.Err())
487475
}
488476
}
489477
})

plugin/plugin_transformer.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ type TransformerClient interface {
1212
}
1313

1414
func (p *Plugin) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
15-
return p.client.Transform(ctx, recvRecords, sendRecords)
15+
err := p.client.Transform(ctx, recvRecords, sendRecords)
16+
close(sendRecords)
17+
return err
1618
}
1719
func (p *Plugin) TransformSchema(ctx context.Context, old *arrow.Schema) (*arrow.Schema, error) {
1820
return p.client.TransformSchema(ctx, old)

0 commit comments

Comments
 (0)