@@ -8,10 +8,11 @@ import (
88 "sync"
99 "time"
1010
11+ "github.com/ethereum/go-ethereum/common/hexutil"
1112 "github.com/golang/snappy"
1213 "github.com/libp2p/go-libp2p/core/network"
1314 dynssz "github.com/pk910/dynamic-ssz"
14- errors "github.com/pkg/errors"
15+ "github.com/pkg/errors"
1516 log "github.com/sirupsen/logrus"
1617)
1718
@@ -127,12 +128,19 @@ func readVarint(r io.Reader) (uint64, error) {
127128
128129// writeRequest writes a request to the stream with SSZ+Snappy encoding
129130func (r * ReqResp ) writeRequest (stream network.Stream , req any ) error {
131+ defer stream .CloseWrite ()
132+
130133 // Set write deadline
131134 if err := stream .SetWriteDeadline (time .Now ().Add (r .cfg .WriteTimeout )); err != nil {
132135 return fmt .Errorf ("failed to set write deadline: %w" , err )
133136 }
134137
135- // Marshal to SSZ
138+ if req == nil {
139+ // we close the write side of the stream immediately, communicating we have no data to send
140+ return nil
141+ }
142+
143+ // Marshal to SSZ if the request is not nil
136144 data , err := sszCodec .MarshalSSZ (req )
137145 if err != nil {
138146 return fmt .Errorf ("failed to marshal SSZ: %w" , err )
@@ -157,6 +165,14 @@ func (r *ReqResp) writeRequest(stream network.Stream, req any) error {
157165 return fmt .Errorf ("failed to compress data: %w" , err )
158166 }
159167
168+ log .WithFields (log.Fields {
169+ "protocol" : stream .Protocol (),
170+ "data" : hexutil .Encode (data ),
171+ "data_len" : len (data ),
172+ "wire_data" : hexutil .Encode (buf .Bytes ()),
173+ "wire_len" : buf .Len (),
174+ }).Debug ("writing request" )
175+
160176 // Write buffer to the stream
161177 if _ , err := io .Copy (stream , & buf ); err != nil {
162178 return fmt .Errorf ("failed to write final payload to stream: %w" , err )
@@ -271,7 +287,8 @@ func (r *ReqResp) readResponse(stream network.Stream, resp any) error {
271287 }).Debug ("Raw response code received" )
272288 }
273289
274- if code [0 ] != ResponseCodeSuccess {
290+ success := code [0 ] == ResponseCodeSuccess
291+ if ! success {
275292 if log .GetLevel () >= log .DebugLevel {
276293 errorType := getResponseCodeName (code [0 ])
277294 log .WithFields (log.Fields {
@@ -280,7 +297,6 @@ func (r *ReqResp) readResponse(stream network.Stream, resp any) error {
280297 "error_type" : errorType ,
281298 }).Debug ("Non-success response code received" )
282299 }
283- return fmt .Errorf ("RPC error code: %d" , code [0 ])
284300 }
285301
286302 // Read uncompressed length prefix
@@ -330,6 +346,21 @@ func (r *ReqResp) readResponse(stream network.Stream, resp any) error {
330346 }).Debug ("Raw response data received" )
331347 }
332348
349+ if ! success {
350+ var errorMessage ErrorMessage
351+ l := log .WithFields (log.Fields {
352+ "response_type" : responseType ,
353+ "raw_data_hex" : fmt .Sprintf ("0x%x" , data ),
354+ })
355+ if err := sszCodec .UnmarshalSSZ (& errorMessage , data ); err != nil {
356+ l .WithError (err ).Error ("failed to unmarshal SSZ error message" )
357+ return fmt .Errorf ("failed to unmarshal SSZ error message: %w" , err )
358+ }
359+ msg := string (errorMessage )
360+ l .Warnf ("RPC failed; error message: %s" , msg )
361+ return fmt .Errorf ("RPC failed: %s" , msg )
362+ }
363+
333364 // Unmarshal from SSZ
334365 if err := sszCodec .UnmarshalSSZ (resp , data ); err != nil {
335366 if log .GetLevel () >= log .DebugLevel {
0 commit comments