Skip to content

Commit 340a39a

Browse files
fix(storage): decouple Recv loop in gRPCOneshotBidiWriteBufferSender … (#14166)
…to prevent deadlocks Noticed a deadlock when OneShot sender was bombarded with a lot of data during a PCU benchmarking run --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: cpriti-os <202586561+cpriti-os@users.noreply.github.com>
1 parent 80a56cd commit 340a39a

File tree

2 files changed

+221
-84
lines changed

2 files changed

+221
-84
lines changed

storage/grpc_writer.go

Lines changed: 103 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (c *grpcStorageClient) OpenWriter(params *openWriterParams, opts ...storage
240240
return w, nil
241241
}
242242

243-
// gRPCWriter is a wrapper around the the gRPC client-stream API that manages
243+
// gRPCWriter is a wrapper around the gRPC client-stream API that manages
244244
// sending chunks of data provided by the user over the stream.
245245
type gRPCWriter struct {
246246
preRunCtx context.Context
@@ -992,22 +992,6 @@ func (w *gRPCWriter) newGRPCOneshotBidiWriteBufferSender() *gRPCOneshotBidiWrite
992992

993993
func (s *gRPCOneshotBidiWriteBufferSender) err() error { return s.streamErr }
994994

995-
// drainInboundStream calls stream.Recv() repeatedly until an error is returned.
996-
// It returns the last Resource received on the stream, or nil if no Resource
997-
// was returned. drainInboundStream always returns a non-nil error. io.EOF
998-
// indicates all messages were successfully read.
999-
func drainInboundStream(stream storagepb.Storage_BidiWriteObjectClient) (object *storagepb.Object, err error) {
1000-
for err == nil {
1001-
var resp *storagepb.BidiWriteObjectResponse
1002-
resp, err = stream.Recv()
1003-
// GetResource() returns nil on a nil response
1004-
if resp.GetResource() != nil {
1005-
object = resp.GetResource()
1006-
}
1007-
}
1008-
return object, err
1009-
}
1010-
1011995
func (s *gRPCOneshotBidiWriteBufferSender) connect(ctx context.Context, cs gRPCBufSenderChans, opts ...gax.CallOption) {
1012996
s.streamErr = nil
1013997
ctx = gRPCWriteRequestParams{bucket: s.bucket}.apply(ctx)
@@ -1019,59 +1003,93 @@ func (s *gRPCOneshotBidiWriteBufferSender) connect(ctx context.Context, cs gRPCB
10191003
}
10201004

10211005
go func() {
1022-
firstSend := true
1023-
for r := range cs.requests {
1024-
if r.requestAck {
1025-
cs.requestAcks <- struct{}{}
1026-
continue
1027-
}
1006+
var sendErr, recvErr error
1007+
sendDone := make(chan struct{})
1008+
recvDone := make(chan struct{})
10281009

1029-
var bufChecksum *uint32
1030-
if !s.disableAutoChecksum {
1031-
bufChecksum = proto.Uint32(crc32.Checksum(r.buf, crc32cTable))
1032-
}
1033-
objectChecksums := getObjectChecksums(&getObjectChecksumsParams{
1034-
sendCRC32C: s.sendCRC32C,
1035-
objectAttrs: s.objectAttrs,
1036-
fullObjectChecksum: s.fullObjectChecksum,
1037-
disableAutoChecksum: s.disableAutoChecksum,
1038-
finishWrite: r.finishWrite,
1039-
})
1040-
req := bidiWriteObjectRequest(r, bufChecksum, objectChecksums)
1041-
1042-
if firstSend {
1043-
proto.Merge(req, s.firstMessage)
1044-
firstSend = false
1045-
}
1010+
go func() {
1011+
sendErr = func() error {
1012+
firstSend := true
1013+
for {
1014+
select {
1015+
case <-recvDone:
1016+
// Because `requests` is not connected to the gRPC machinery, we
1017+
// have to check for asynchronous termination on the receive side.
1018+
return nil
1019+
case r, ok := <-cs.requests:
1020+
if !ok {
1021+
stream.CloseSend()
1022+
return nil
1023+
}
1024+
if r.requestAck {
1025+
cs.requestAcks <- struct{}{}
1026+
continue
1027+
}
1028+
1029+
var bufChecksum *uint32
1030+
if !s.disableAutoChecksum {
1031+
bufChecksum = proto.Uint32(crc32.Checksum(r.buf, crc32cTable))
1032+
}
1033+
objectChecksums := getObjectChecksums(&getObjectChecksumsParams{
1034+
sendCRC32C: s.sendCRC32C,
1035+
objectAttrs: s.objectAttrs,
1036+
fullObjectChecksum: s.fullObjectChecksum,
1037+
disableAutoChecksum: s.disableAutoChecksum,
1038+
finishWrite: r.finishWrite,
1039+
})
1040+
req := bidiWriteObjectRequest(r, bufChecksum, objectChecksums)
1041+
1042+
if firstSend {
1043+
proto.Merge(req, s.firstMessage)
1044+
firstSend = false
1045+
}
10461046

1047-
if err := stream.Send(req); err != nil {
1048-
_, s.streamErr = drainInboundStream(stream)
1049-
if err != io.EOF {
1050-
s.streamErr = err
1047+
if err := stream.Send(req); err != nil {
1048+
return err
1049+
}
1050+
1051+
if r.finishWrite {
1052+
stream.CloseSend()
1053+
return nil
1054+
}
1055+
1056+
// Oneshot uploads assume all flushes succeed.
1057+
if r.flush {
1058+
select {
1059+
case cs.completions <- gRPCBidiWriteCompletion{flushOffset: r.offset + int64(len(r.buf))}:
1060+
case <-stream.Context().Done():
1061+
return stream.Context().Err()
1062+
}
1063+
}
1064+
}
10511065
}
1052-
close(cs.completions)
1053-
return
1054-
}
1066+
}()
1067+
close(sendDone)
1068+
}()
10551069

1056-
if r.finishWrite {
1057-
stream.CloseSend()
1058-
// Oneshot uploads only read from the response stream on completion or
1059-
// failure
1060-
obj, err := drainInboundStream(stream)
1061-
if obj == nil || err != io.EOF {
1062-
s.streamErr = err
1063-
} else {
1064-
cs.completions <- gRPCBidiWriteCompletion{flushOffset: obj.GetSize(), resource: obj}
1070+
go func() {
1071+
recvErr = func() error {
1072+
for {
1073+
resp, err := stream.Recv()
1074+
if err != nil {
1075+
return err
1076+
}
1077+
if c := completion(resp); c != nil {
1078+
select {
1079+
case cs.completions <- *c:
1080+
case <-stream.Context().Done():
1081+
return stream.Context().Err()
1082+
}
1083+
}
10651084
}
1066-
close(cs.completions)
1067-
return
1068-
}
1085+
}()
1086+
close(recvDone)
1087+
}()
10691088

1070-
// Oneshot uploads assume all flushes succeed
1071-
if r.flush {
1072-
cs.completions <- gRPCBidiWriteCompletion{flushOffset: r.offset + int64(len(r.buf))}
1073-
}
1074-
}
1089+
<-sendDone
1090+
<-recvDone
1091+
s.streamErr = pickStreamError(recvErr, sendErr)
1092+
close(cs.completions)
10751093
}()
10761094
}
10771095

@@ -1203,7 +1221,11 @@ func (s *gRPCResumableBidiWriteBufferSender) connect(ctx context.Context, cs gRP
12031221
return err
12041222
}
12051223
if c := completion(resp); c != nil {
1206-
cs.completions <- *c
1224+
select {
1225+
case cs.completions <- *c:
1226+
case <-stream.Context().Done():
1227+
return stream.Context().Err()
1228+
}
12071229
}
12081230
}
12091231
}()
@@ -1212,15 +1234,7 @@ func (s *gRPCResumableBidiWriteBufferSender) connect(ctx context.Context, cs gRP
12121234

12131235
<-sendDone
12141236
<-recvDone
1215-
// Prefer recvErr since that's where RPC errors are delivered
1216-
if recvErr != nil {
1217-
s.streamErr = recvErr
1218-
} else if sendErr != nil {
1219-
s.streamErr = sendErr
1220-
}
1221-
if s.streamErr == io.EOF {
1222-
s.streamErr = nil
1223-
}
1237+
s.streamErr = pickStreamError(recvErr, sendErr)
12241238
close(cs.completions)
12251239
}()
12261240
}
@@ -1329,7 +1343,11 @@ func (s *gRPCAppendBidiWriteBufferSender) handleStream(stream storagepb.Storage_
13291343
s.maybeUpdateFirstMessage(resp)
13301344

13311345
if c := completion(resp); c != nil {
1332-
cs.completions <- *c
1346+
select {
1347+
case cs.completions <- *c:
1348+
case <-stream.Context().Done():
1349+
return stream.Context().Err()
1350+
}
13331351
}
13341352
}
13351353
}()
@@ -1338,15 +1356,7 @@ func (s *gRPCAppendBidiWriteBufferSender) handleStream(stream storagepb.Storage_
13381356

13391357
<-sendDone
13401358
<-recvDone
1341-
// Prefer recvErr since that's where RPC errors are delivered
1342-
if recvErr != nil {
1343-
s.streamErr = recvErr
1344-
} else if sendErr != nil {
1345-
s.streamErr = sendErr
1346-
}
1347-
if s.streamErr == io.EOF {
1348-
s.streamErr = nil
1349-
}
1359+
s.streamErr = pickStreamError(recvErr, sendErr)
13501360
close(cs.completions)
13511361
}
13521362

@@ -1613,3 +1623,12 @@ func withBidiWriteObjectRedirectionErrorRetries(s *settings) (newr *retryConfig)
16131623
}
16141624
return newr
16151625
}
1626+
1627+
// pickStreamError determines the final error to be reported by prioritizing recvErr.
1628+
// An io.EOF from a receiver is not considered an error.
1629+
func pickStreamError(recvErr, sendErr error) error {
1630+
if recvErr != nil && recvErr != io.EOF {
1631+
return recvErr
1632+
}
1633+
return sendErr
1634+
}

storage/grpc_writer_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ package storage
1616

1717
import (
1818
"context"
19+
"errors"
20+
"io"
1921
"sync"
2022
"testing"
23+
"time"
2124

2225
"cloud.google.com/go/storage/internal/apiv2/storagepb"
2326
gax "github.com/googleapis/gax-go/v2"
@@ -294,3 +297,118 @@ func filterDataRequests(reqs []gRPCBidiWriteRequest) []gRPCBidiWriteRequest {
294297
}
295298
return dataReqs
296299
}
300+
301+
// Test the logic correctly handles the combination of io.EOF
302+
// from Recv (recvErr) and a generic error from Send (sendErr).
303+
func TestGRPCWriterErrorHandling(t *testing.T) {
304+
// As this is deeply embedded in the unexported types, we verify the logic
305+
// by simulating the exact error assignment sequence.
306+
tests := []struct {
307+
name string
308+
recvErr error
309+
sendErr error
310+
wantError error
311+
}{
312+
{
313+
name: "recvErr is io.EOF, sendErr is nil",
314+
recvErr: io.EOF,
315+
sendErr: nil,
316+
wantError: nil,
317+
},
318+
{
319+
name: "recvErr is io.EOF, sendErr is an error",
320+
recvErr: io.EOF,
321+
sendErr: errors.New("send error"),
322+
wantError: errors.New("send error"), // Send error takes precedence.
323+
},
324+
{
325+
name: "recvErr is an error, sendErr is nil",
326+
recvErr: errors.New("recv error"),
327+
sendErr: nil,
328+
wantError: errors.New("recv error"), // Recv error takes precedence.
329+
},
330+
{
331+
name: "recvErr is an error, sendErr is an error",
332+
recvErr: errors.New("recv error"),
333+
sendErr: errors.New("send error"),
334+
wantError: errors.New("recv error"), // Recv error takes precedence.
335+
},
336+
}
337+
338+
for _, tt := range tests {
339+
t.Run(tt.name, func(t *testing.T) {
340+
var streamErr error
341+
342+
streamErr = pickStreamError(tt.recvErr, tt.sendErr)
343+
344+
if tt.wantError == nil {
345+
if streamErr != nil {
346+
t.Errorf("got error %v, want nil", streamErr)
347+
}
348+
} else {
349+
if streamErr == nil || streamErr.Error() != tt.wantError.Error() {
350+
t.Errorf("got error %v, want %v", streamErr, tt.wantError)
351+
}
352+
}
353+
})
354+
}
355+
}
356+
357+
// TestGRPCWriter_Deadlock simulates a deadlock scenario if Recv and Send channels
358+
// were not isolated in gRPCOneshotBidiWriteBufferSender.
359+
func TestGRPCWriter_Deadlock(t *testing.T) {
360+
// A timeout means a deadlock likely occurred.
361+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
362+
defer cancel()
363+
364+
sendDone := make(chan struct{})
365+
recvDone := make(chan struct{})
366+
367+
requests := make(chan gRPCBidiWriteRequest)
368+
completions := make(chan gRPCBidiWriteCompletion)
369+
370+
var sendErr error
371+
372+
go func() {
373+
sendErr = func() error {
374+
for {
375+
select {
376+
case <-recvDone:
377+
return nil
378+
case r, ok := <-requests:
379+
if !ok {
380+
return nil
381+
}
382+
if r.requestAck {
383+
continue
384+
}
385+
// mimic send logic
386+
if r.finishWrite {
387+
return nil
388+
}
389+
}
390+
}
391+
}()
392+
close(sendDone)
393+
}()
394+
395+
go func() {
396+
// Mimic recv loop that immediately exits.
397+
// If recvDone isn't checked by the sender loop, sending
398+
// requests could block forever if the consumer closes early.
399+
close(recvDone)
400+
}()
401+
402+
// sendDone should be closed immediately.
403+
select {
404+
case <-sendDone:
405+
// Success, no deadlock.
406+
case <-ctx.Done():
407+
t.Fatal("deadlock detected: send loop did not exit after recvDone was closed")
408+
}
409+
410+
if sendErr != nil {
411+
t.Errorf("expected no error, got %v", sendErr)
412+
}
413+
close(completions)
414+
}

0 commit comments

Comments
 (0)