diff --git a/plugins/transport/socket/main.go b/plugins/transport/socket/main.go index bfac8fa7..1290cf9d 100644 --- a/plugins/transport/socket/main.go +++ b/plugins/transport/socket/main.go @@ -20,11 +20,13 @@ import ( ) const ( - maxBufferSize = 65535 - udp = "udp" - unix = "unix" - tcp = "tcp" - msgLengthSize = 8 + maxBufferSize = 65535 // 64KB - initial buffer size for all socket types and max for UDP (OS datagram limit) + maxBufferSizeUnix = 10485760 // 10MB - max buffer size for Unix domain sockets + maxBufferSizeTCP = 104857600 // 100MB - max buffer size for TCP (stream-based, can handle very large messages) + udp = "udp" + unix = "unix" + tcp = "tcp" + msgLengthSize = 8 ) var ( @@ -138,6 +140,17 @@ func (s *Socket) initTCPSocket() *net.TCPListener { return pc } +func (s *Socket) getMaxBufferSize() int64 { + switch s.conf.Type { + case udp: + return maxBufferSize + case tcp: + return maxBufferSizeTCP + default: + return maxBufferSizeUnix + } +} + func (s *Socket) WriteTCPMsg(w transport.WriteFn, msgBuffer []byte, n int) (int64, error) { var pos int64 var length int64 @@ -165,10 +178,13 @@ func (s *Socket) WriteTCPMsg(w transport.WriteFn, msgBuffer []byte, n int) (int6 return pos, nil } -func (s *Socket) ReceiveData(maxBuffSize int64, done chan bool, pc net.Conn, w transport.WriteFn) { +func (s *Socket) ReceiveData(initialBuffSize int64, done chan bool, pc net.Conn, w transport.WriteFn) { defer pc.Close() - msgBuffer := make([]byte, maxBuffSize) + currentBuffSize := initialBuffSize + maxBuffSize := s.getMaxBufferSize() + msgBuffer := make([]byte, currentBuffSize) var remainingMsg []byte + for { n, err := pc.Read(msgBuffer) if err != nil || n < 1 { @@ -180,17 +196,40 @@ func (s *Socket) ReceiveData(maxBuffSize int64, done chan bool, pc net.Conn, w t } return } - msgBuffer = append(remainingMsg, msgBuffer...) - // whole buffer was used, so we are potentially handling larger message - if n == len(msgBuffer) { - s.logger.Warnf("full read buffer used") + // Combine remaining data from previous iteration with newly read data + var data []byte + if len(remainingMsg) > 0 { + data = make([]byte, len(remainingMsg)+n) + copy(data, remainingMsg) + copy(data[len(remainingMsg):], msgBuffer[:n]) + } else { + data = msgBuffer[:n] + } + totalSize := len(data) + + // Check if buffer was completely filled - message may have been truncated + if n == int(currentBuffSize) { + if s.conf.Type == tcp { + s.logger.Debugf("full read buffer used (%d bytes), TCP will handle continuation if needed", n) + } else { + // For UDP/Unix sockets, buffer being full means message was likely truncated + if currentBuffSize < maxBuffSize { + newSize := currentBuffSize * 2 + if newSize > maxBuffSize { + newSize = maxBuffSize + } + s.logger.Warnf("message may have been truncated (buffer filled with %d bytes), growing buffer from %d to %d bytes for next message", currentBuffSize, currentBuffSize, newSize) + currentBuffSize = newSize + msgBuffer = make([]byte, currentBuffSize) + } else { + s.logger.Errorf(nil, "message truncated: buffer size (%d bytes) exceeded for %s socket and already at maximum buffer size (%d bytes)", currentBuffSize, s.conf.Type, maxBuffSize) + } + } } - - n += len(remainingMsg) if s.conf.DumpMessages.Enabled { - _, err := s.dumpBuf.Write(msgBuffer[:n]) + _, err := s.dumpBuf.Write(data) if err != nil { s.logger.Errorf(err, "writing to dump buffer") } @@ -202,16 +241,17 @@ func (s *Socket) ReceiveData(maxBuffSize int64, done chan bool, pc net.Conn, w t } if s.conf.Type == tcp { - parsed, err := s.WriteTCPMsg(w, msgBuffer, n) + parsed, err := s.WriteTCPMsg(w, data, totalSize) if err != nil { s.logger.Errorf(err, "error, while parsing messages") return } - remainingMsg = make([]byte, int64(n)-parsed) - copy(remainingMsg, msgBuffer[parsed:n]) + remainingMsg = make([]byte, int64(totalSize)-parsed) + copy(remainingMsg, data[parsed:totalSize]) } else { - w(msgBuffer[:n]) + w(data) msgCount++ + remainingMsg = nil } } } diff --git a/plugins/transport/socket/main_test.go b/plugins/transport/socket/main_test.go index d734248c..ac996892 100644 --- a/plugins/transport/socket/main_test.go +++ b/plugins/transport/socket/main_test.go @@ -16,7 +16,7 @@ import ( "gopkg.in/go-playground/assert.v1" ) -const regularBuffSize = 16384 +const regularBuffSize = 65535 // default buffer size const addition = "wubba lubba dub dub" func TestUnixSocketTransport(t *testing.T) { @@ -28,40 +28,39 @@ func TestUnixSocketTransport(t *testing.T) { logger, err := logging.NewLogger(logging.DEBUG, logpath) require.NoError(t, err) - sktpath := path.Join(tmpdir, "socket") - skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) - require.NoError(t, err) - defer skt.Close() - - trans := Socket{ - conf: configT{ - Path: sktpath, - }, - logger: &logWrapper{ - l: logger, - }, - } + t.Run("test normal message", func(t *testing.T) { + // Create a normal-sized message (5KB) + msg := make([]byte, 5000) + for i := 0; i < len(msg); i++ { + msg[i] = byte('A') + } + marker := []byte("--END--") + copy(msg[len(msg)-len(marker):], marker) - t.Run("test large message transport", func(t *testing.T) { - msg := make([]byte, regularBuffSize) - for i := 0; i < regularBuffSize; i++ { - msg[i] = byte('X') + sktpath := path.Join(tmpdir, "socket1") + skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) + require.NoError(t, err) + defer skt.Close() + + trans := Socket{ + conf: configT{ + Path: sktpath, + }, + logger: &logWrapper{ + l: logger, + }, } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) - // verify transport ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} + wg.Add(1) + var receivedMsg []byte go trans.Run(ctx, func(mess []byte) { - wg.Add(1) - strmsg := string(mess) - assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message - assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct + receivedMsg = mess wg.Done() }, make(chan bool)) - // wait for socket file to be created + // Wait for socket file to be created for { stat, err := os.Stat(sktpath) require.NoError(t, err) @@ -71,71 +70,151 @@ func TestUnixSocketTransport(t *testing.T) { time.Sleep(250 * time.Millisecond) } - // write to socket wskt, err := net.DialUnix("unixgram", nil, &net.UnixAddr{Name: sktpath, Net: "unixgram"}) require.NoError(t, err) _, err = wskt.Write(msg) require.NoError(t, err) - cancel() wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) wskt.Close() - }) -} - -func TestUdpSocketTransport(t *testing.T) { - tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) - - logpath := path.Join(tmpdir, "test.log") - logger, err := logging.NewLogger(logging.DEBUG, logpath) - require.NoError(t, err) - trans := Socket{ - conf: configT{ - Socketaddr: "127.0.0.1:8642", - Type: "udp", - }, - logger: &logWrapper{ - l: logger, - }, - } + // Verify we received the complete message + assert.Equal(t, len(msg), len(receivedMsg)) + // Verify the end marker is present + endMarkerPos := len(receivedMsg) - len(marker) + assert.Equal(t, string(marker), string(receivedMsg[endMarkerPos:])) + }) t.Run("test large message transport", func(t *testing.T) { - msg := make([]byte, regularBuffSize) - for i := 0; i < regularBuffSize; i++ { + // Create a message larger than initial buffer to test dynamic buffer growth + largeBuffSize := regularBuffSize * 2 // 131070 bytes + msg := make([]byte, largeBuffSize) + for i := 0; i < largeBuffSize; i++ { msg[i] = byte('X') } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) + msg[largeBuffSize-1] = byte('$') + msg = append(msg, []byte(addition)...) // Total: 131089 bytes + + // Setup socket using same pattern as sendUnixSocketMessage + sktpath := path.Join(tmpdir, "socket2") + skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) + require.NoError(t, err) + defer skt.Close() + + trans := Socket{ + conf: configT{ + Path: sktpath, + }, + logger: &logWrapper{ + l: logger, + }, + } - // verify transport ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var receivedMsgs [][]byte + var mutex sync.Mutex wg := sync.WaitGroup{} + wg.Add(3) // Expecting 3 messages + go trans.Run(ctx, func(mess []byte) { - wg.Add(1) - strmsg := string(mess) - assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message - assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct + mutex.Lock() + receivedMsgs = append(receivedMsgs, mess) + mutex.Unlock() wg.Done() }, make(chan bool)) - // write to socket - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8642") + // Wait for socket file to be created + for { + stat, err := os.Stat(sktpath) + require.NoError(t, err) + if stat.Mode()&os.ModeType == os.ModeSocket { + break + } + time.Sleep(250 * time.Millisecond) + } + + wskt, err := net.DialUnix("unixgram", nil, &net.UnixAddr{Name: sktpath, Net: "unixgram"}) + require.NoError(t, err) + defer wskt.Close() + + // Send the same message 3 times + _, err = wskt.Write(msg) require.NoError(t, err) - wskt, err := net.DialUDP("udp", nil, addr) + time.Sleep(100 * time.Millisecond) + + _, err = wskt.Write(msg) require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + _, err = wskt.Write(msg) require.NoError(t, err) - cancel() wg.Wait() - wskt.Close() + + // Verify we received 3 messages + require.Equal(t, 3, len(receivedMsgs)) + + // First message: the message is truncated to the maximum 64KB (65535 bytes) + require.Equal(t, len(receivedMsgs[0]), regularBuffSize) + + // Second message: check for 128KB (131070 bytes) with '$' at position 131069 + require.Equal(t, len(receivedMsgs[1]), largeBuffSize) + assert.Equal(t, byte('$'), receivedMsgs[1][131069]) + + // Third message: check for > 128KB (131070 bytes) with "wubba lubba dub dub" at the end + require.GreaterOrEqual(t, len(receivedMsgs[2]), largeBuffSize+len(addition)) + endStr := string(receivedMsgs[2][len(receivedMsgs[2])-len(addition):]) + assert.Equal(t, addition, endStr) }) } -func TestTcpSocketTransport(t *testing.T) { +// Helper function to send and receive UDP socket message +func sendUDPSocketMessage(t *testing.T, logger *logging.Logger, addr string, msg []byte) ([]byte, error) { + trans := Socket{ + conf: configT{ + Socketaddr: addr, + Type: "udp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + var receivedMsg []byte + messageReceived := false + go trans.Run(ctx, func(mess []byte) { + receivedMsg = mess + messageReceived = true + wg.Done() + }, make(chan bool)) + + // Wait for socket to be ready + time.Sleep(100 * time.Millisecond) + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + require.NoError(t, err) + wskt, err := net.DialUDP("udp", nil, udpAddr) + require.NoError(t, err) + _, writeErr := wskt.Write(msg) + + if writeErr == nil && messageReceived { + wg.Wait() + } + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() + + return receivedMsg, writeErr +} + +func TestUdpSocketTransport(t *testing.T) { tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") require.NoError(t, err) defer os.RemoveAll(tmpdir) @@ -144,9 +223,69 @@ func TestTcpSocketTransport(t *testing.T) { logger, err := logging.NewLogger(logging.DEBUG, logpath) require.NoError(t, err) + t.Run("test normal message", func(t *testing.T) { + // Create a normal message (5KB) + msg := make([]byte, 5000) + for i := 0; i < len(msg); i++ { + msg[i] = byte('U') + } + marker := []byte("--UDP-END--") + copy(msg[len(msg)-len(marker):], marker) + + receivedMsg, err := sendUDPSocketMessage(t, logger, "127.0.0.1:8650", msg) + require.NoError(t, err) + + // Verify we received the complete message + assert.Equal(t, len(msg), len(receivedMsg)) + // Verify the end marker is present + endMarkerPos := len(receivedMsg) - len(marker) + assert.Equal(t, string(marker), string(receivedMsg[endMarkerPos:])) + }) + + t.Run("test large message transport", func(t *testing.T) { + // Create message that exceeds UDP datagram limits + // UDP max payload is ~65507 bytes, we're trying to send 65535 + 19 = 65554 bytes + largeBuffSize := regularBuffSize - len(addition) + msg := make([]byte, largeBuffSize) + for i := 0; i < largeBuffSize; i++ { + msg[i] = byte('X') + } + msg[largeBuffSize-1] = byte('$') + msg = append(msg, []byte(addition)...) + + _, err := sendUDPSocketMessage(t, logger, "127.0.0.1:8652", msg) + + // Verify that sending a message that's too large for UDP fails + require.Error(t, err) + }) +} + +// Helper function to connect to TCP with retries +func connectTCPWithRetry(t *testing.T, addr string) net.Conn { + wskt, err := net.Dial("tcp", addr) + if err != nil { + for retries := 0; err != nil && retries < 3; retries++ { + time.Sleep(500 * time.Millisecond) + wskt, err = net.Dial("tcp", addr) + } + } + require.NoError(t, err) + return wskt +} + +// Helper function to create a TCP message with length header +func createTCPMessage(t *testing.T, content []byte) []byte { + msgLength := new(bytes.Buffer) + err := binary.Write(msgLength, binary.LittleEndian, uint64(len(content))) + require.NoError(t, err) + return append(msgLength.Bytes(), content...) +} + +// Helper function to send and verify TCP socket message with marker +func sendTCPSocketMessage(t *testing.T, logger *logging.Logger, addr string, msgSize int, fillByte byte, marker []byte) { trans := Socket{ conf: configT{ - Socketaddr: "127.0.0.1:8642", + Socketaddr: addr, Type: "tcp", }, logger: &logWrapper{ @@ -154,83 +293,171 @@ func TestTcpSocketTransport(t *testing.T) { }, } - t.Run("test large message transport single connection", func(t *testing.T) { - msg := make([]byte, regularBuffSize) - for i := 0; i < regularBuffSize; i++ { - msg[i] = byte('X') + msgContent := make([]byte, msgSize) + for i := 0; i < msgSize; i++ { + msgContent[i] = fillByte + } + copy(msgContent[len(msgContent)-len(marker):], marker) + + fullMsg := createTCPMessage(t, msgContent) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + go trans.Run(ctx, func(mess []byte) { + assert.Equal(t, msgSize, len(mess)) + endMarkerPos := len(mess) - len(marker) + assert.Equal(t, string(marker), string(mess[endMarkerPos:])) + wg.Done() + }, make(chan bool)) + + time.Sleep(100 * time.Millisecond) + + wskt := connectTCPWithRetry(t, addr) + _, err := wskt.Write(fullMsg) + require.NoError(t, err) + + wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() +} + +func TestTcpSocketTransport(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("test normal message", func(t *testing.T) { + // Create a normal message (5KB) + sendTCPSocketMessage(t, logger, "127.0.0.1:8660", 5000, 'T', []byte("--TCP-END--")) + }) + + t.Run("test message exceeding initial buffer", func(t *testing.T) { + // Create a message larger than initial buffer (100KB) + sendTCPSocketMessage(t, logger, "127.0.0.1:8661", 100000, 'B', []byte("--LARGE-TCP--")) + }) + + t.Run("test multiple large messages", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8663", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) - msgLength := new(bytes.Buffer) - err := binary.Write(msgLength, binary.LittleEndian, uint64(len(msg))) - require.NoError(t, err) - msg = append(msgLength.Bytes(), msg...) - // verify transport + numMessages := 3 + messageSizes := []int{80000, 120000, 90000} + var combinedMsg bytes.Buffer + + // Create multiple large messages + for i := 0; i < numMessages; i++ { + msgContent := make([]byte, messageSizes[i]) + fillByte := byte('0' + i) + for j := 0; j < messageSizes[i]; j++ { + msgContent[j] = fillByte + } + combinedMsg.Write(createTCPMessage(t, msgContent)) + } + + // Setup message verification ctx, cancel := context.WithCancel(context.Background()) + receivedCount := 0 + var mutex sync.Mutex wg := sync.WaitGroup{} + wg.Add(numMessages) + go trans.Run(ctx, func(mess []byte) { - wg.Add(1) - strmsg := string(mess) - assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message - assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct - wg.Done() + mutex.Lock() + defer mutex.Unlock() + + // Verify message size matches one of our expected sizes + found := false + for i, expectedSize := range messageSizes { + if len(mess) == expectedSize { + expectedByte := byte('0' + i) + allMatch := true + for _, b := range mess { + if b != expectedByte { + allMatch = false + break + } + } + if allMatch { + found = true + receivedCount++ + wg.Done() + break + } + } + } + assert.Equal(t, true, found) }, make(chan bool)) - // write to socket - wskt, err := net.Dial("tcp", "127.0.0.1:8642") - if err != nil { - // The socket might not be listening yet, wait a little bit and try to connect again - for retries := 0; err != nil && retries < 3; retries++ { - time.Sleep(2 * time.Second) - wskt, err = net.Dial("tcp", "127.0.0.1:8642") - } - } - require.NoError(t, err) - _, err = wskt.Write(msg) + // Wait for socket to be ready + time.Sleep(100 * time.Millisecond) + + // Connect and send all messages + wskt := connectTCPWithRetry(t, "127.0.0.1:8663") + _, err = wskt.Write(combinedMsg.Bytes()) require.NoError(t, err) - cancel() wg.Wait() + + mutex.Lock() + assert.Equal(t, numMessages, receivedCount) + mutex.Unlock() + + cancel() + time.Sleep(100 * time.Millisecond) wskt.Close() }) t.Run("test large message transport multiple connections", func(t *testing.T) { - msg := make([]byte, regularBuffSize) + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8665", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + msgContent := make([]byte, regularBuffSize) for i := 0; i < regularBuffSize; i++ { - msg[i] = byte('X') + msgContent[i] = byte('X') } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) - msgLength := new(bytes.Buffer) - err := binary.Write(msgLength, binary.LittleEndian, uint64(len(msg))) - require.NoError(t, err) - msg = append(msgLength.Bytes(), msg...) + msgContent[regularBuffSize-1] = byte('$') + msgContent = append(msgContent, []byte(addition)...) + msg := createTCPMessage(t, msgContent) // verify transport ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} + wg.Add(2) go trans.Run(ctx, func(mess []byte) { - wg.Add(1) strmsg := string(mess) assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct wg.Done() }, make(chan bool)) + // Wait for socket to be ready + time.Sleep(100 * time.Millisecond) + // write to socket - wskt1, err := net.Dial("tcp", "127.0.0.1:8642") - if err != nil { - // The socket might not be listening yet, wait a little bit and try to connect again - for retries := 0; err != nil && retries < 3; retries++ { - time.Sleep(2 * time.Second) - wskt1, err = net.Dial("tcp", "127.0.0.1:8642") - } - } - require.NoError(t, err) + wskt1 := connectTCPWithRetry(t, "127.0.0.1:8665") // We shouldn't need to retry the second connection, if this fails, then something is wrong - wskt2, err := net.Dial("tcp", "127.0.0.1:8642") + wskt2, err := net.Dial("tcp", "127.0.0.1:8665") require.NoError(t, err) _, err = wskt1.Write(msg) @@ -238,8 +465,9 @@ func TestTcpSocketTransport(t *testing.T) { _, err = wskt2.Write(msg) require.NoError(t, err) - cancel() wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) wskt1.Close() wskt2.Close() })