@@ -139,9 +139,9 @@ func (v *pktDescsManager) buffersForWritingToConn(packetCount int) (net.Buffers,
139139 return nil , fmt .Errorf ("vm_pkt_size %d exceeds maxPacketSize %d" , vmPktDesc .vm_pkt_size , v .maxPacketSize )
140140 }
141141 // Write packet size to the 4-byte header
142- binary .BigEndian .PutUint32 (v .backingBuffers [ i ][: headerSize ] , uint32 (vmPktDesc .vm_pkt_size ))
142+ binary .BigEndian .PutUint32 (v .headerBufferAt ( i ) , uint32 (vmPktDesc .vm_pkt_size ))
143143 // Resize buffer to include header and packet size
144- v .writingBuffers [i ] = v .backingBuffers [i ][:headerSize + uintptr ( vmPktDesc .vm_pkt_size )]
144+ v .writingBuffers [i ] = v .backingBuffers [i ][:headerSize + vmPktDesc .GetPacketSize ( )]
145145 }
146146 return v .writingBuffers [:packetCount ], nil
147147}
@@ -165,6 +165,28 @@ func (v *pktDescsManager) writePacketsToConn(conn net.Conn, packetCount int) err
165165 return nil
166166}
167167
168+ // buffersForReadingFromConn returns [net.Buffers] to read from the [net.Conn]
169+ // for the given index and offset.
170+ // It prepares buffer for the next header read as well if possible.
171+ func (v * pktDescsManager ) buffersForReadingFromConn (index , offset int ) net.Buffers {
172+ if offset < v .at (index ).GetPacketSize () {
173+ if index + 1 < v .maxPacketCount {
174+ // prepare next header read as well
175+ return net.Buffers {
176+ v .packetBufferAt (index , offset ),
177+ v .headerBufferAt (index + 1 ),
178+ }
179+ }
180+ return net.Buffers {
181+ v .packetBufferAt (index , offset ),
182+ }
183+ }
184+ headerOffset := offset - v .at (index ).GetPacketSize ()
185+ return net.Buffers {
186+ v .headerBufferAt (index )[headerOffset :],
187+ }
188+ }
189+
168190// readPacketsFromConn reads packets from the [net.Conn] into [VMPktDesc]s.
169191// - It returns the number of packets read.
170192// - The packets are expected to come one by one with 4-byte big-endian header indicating the packet size.
@@ -173,63 +195,51 @@ func (v *pktDescsManager) writePacketsToConn(conn net.Conn, packetCount int) err
173195func (v * pktDescsManager ) readPacketsFromConn (conn net.Conn ) (int , error ) {
174196 var packetCount int
175197 // Wait until 4-byte header is read
176- if _ , err := conn .Read (v .backingBuffers [ packetCount ][: headerSize ] ); err != nil {
198+ if _ , err := conn .Read (v .headerBufferAt ( packetCount ) ); err != nil {
177199 return 0 , fmt .Errorf ("conn.Read failed: %w" , err )
178200 }
179201 // Get rawConn for Readv
180202 rawConn , _ := conn .(syscall.Conn ).SyscallConn ()
181203 // Read available packets
182- var packetLen uint32
183- var bufs net.Buffers
184204 for {
185- packetLen = binary .BigEndian .Uint32 (v .backingBuffers [ packetCount ][: headerSize ] )
205+ packetLen := int ( binary .BigEndian .Uint32 (v .headerBufferAt ( packetCount )) )
186206 if packetLen == 0 || uint64 (packetLen ) > v .maxPacketSize {
187207 return 0 , fmt .Errorf ("invalid packetLen: %d (max %d)" , packetLen , v .maxPacketSize )
188208 }
209+ v .at (packetCount ).SetPacketSize (packetLen )
189210
190- // prepare buffers for reading packet and next header if any
191- if packetCount + 1 < v .maxPacketCount {
192- // prepare next header read as well
193- bufs = net.Buffers {
194- v .backingBuffers [packetCount ][headerSize : headerSize + uintptr (packetLen )],
195- v .backingBuffers [packetCount + 1 ][:headerSize ],
196- }
197- } else {
198- // prepare only packet read to avoid exceeding maxPacketCount
199- bufs = net.Buffers {
200- v .backingBuffers [packetCount ][headerSize : headerSize + uintptr (packetLen )],
201- }
202- }
203-
204- // Read packet from the connection
211+ // Read packet from the connection until full packet is read, including next header if possible.
205212 var bytesHasBeenRead int
206- var err error
207- rawConnReadErr := rawConn .Read (func (fd uintptr ) (done bool ) {
213+ var readErr error
214+ if rawConnReadErr := rawConn .Read (func (fd uintptr ) (done bool ) {
215+ bufs := v .buffersForReadingFromConn (packetCount , bytesHasBeenRead )
208216 // read packet into buffers
209- bytesHasBeenRead , err = unix .Readv (int (fd ), bufs )
210- if bytesHasBeenRead <= 0 {
217+ n , err : = unix .Readv (int (fd ), bufs )
218+ if n <= 0 {
211219 if errors .Is (err , syscall .EAGAIN ) {
212220 return false // try again later
213221 }
214- err = fmt .Errorf ("unix.Readv failed: %w" , err )
222+ readErr = fmt .Errorf ("unix.Readv failed: %w" , err )
223+ return true
224+ }
225+ bytesHasBeenRead += n
226+ if bytesHasBeenRead == packetLen + headerSize || bytesHasBeenRead == packetLen {
215227 return true
216228 }
217- // assumes partial read of a packet does not happen since packet len is already known
218- return true
219- })
220- if rawConnReadErr != nil {
229+ // Partial read, read again
230+ return false
231+ }); rawConnReadErr != nil {
221232 return 0 , fmt .Errorf ("rawConn.Read failed: %w" , rawConnReadErr )
222233 }
223- if err != nil {
224- return 0 , fmt .Errorf ("closure in rawConn.Read failed: %w" , err )
234+ if readErr != nil {
235+ return 0 , fmt .Errorf ("closure in rawConn.Read failed: %w" , readErr )
225236 }
226- v .at (packetCount ).SetPacketSize (int (packetLen ))
227237 packetCount ++
228- if bytesHasBeenRead == int (packetLen ) {
238+ if bytesHasBeenRead == packetLen + headerSize {
239+ // next packet header is also read, continue to read next packet
240+ } else if bytesHasBeenRead == packetLen {
229241 // next packet seems not available now, or reached maxPacketCount
230242 break
231- } else if bytesHasBeenRead != int (packetLen )+ int (headerSize ) {
232- return 0 , fmt .Errorf ("unexpected bytesHasBeenRead: %d (expected %d or %d)" , bytesHasBeenRead , packetLen , packetLen + uint32 (headerSize ))
233243 }
234244 }
235245 return packetCount , nil
0 commit comments