diff --git a/pkg/portfwd/client.go b/pkg/portfwd/client.go index 81c5c21e17b..944032945a4 100644 --- a/pkg/portfwd/client.go +++ b/pkg/portfwd/client.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net" + "sync/atomic" "time" "github.com/containers/gvisor-tap-vsock/pkg/services/forwarder" @@ -40,33 +41,35 @@ func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgen } func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.PacketConn, guestAddr string) { - id := fmt.Sprintf("udp-%s", conn.LocalAddr().String()) - - stream, err := client.Tunnel(ctx) - if err != nil { - logrus.Errorf("could not open udp tunnel for id: %s error:%v", id, err) - return - } - - // Handshake message to start tunnel - if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil { - logrus.Errorf("could not start udp tunnel for id: %s error:%v", id, err) - return - } + var udpConnectionCounter atomic.Uint32 + initialID := fmt.Sprintf("udp-%s", conn.LocalAddr().String()) + // gvisor-tap-vsock's UDPProxy demultiplexes client connections internally based on their source address. + // It calls this dialer function only when it receives a datagram from a new, unrecognized client. + // For each new client, we must return a new net.Conn, which in our case is a new gRPC stream. + // The atomic counter ensures that each stream has a unique ID to distinguish them on the server side. proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) { + id := fmt.Sprintf("%s-%d", initialID, udpConnectionCounter.Add(1)) + stream, err := client.Tunnel(ctx) + if err != nil { + return nil, fmt.Errorf("could not open udp tunnel for id: %s error:%w", id, err) + } + // Handshake message to start tunnel + if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil { + return nil, fmt.Errorf("could not start udp tunnel for id: %s error:%w", id, err) + } rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "udp"} return rw, nil }) if err != nil { - logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", id, err) + logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", initialID, err) return } defer func() { err := proxy.Close() if err != nil { - logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", id, err) + logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", initialID, err) } }() proxy.Run()