diff --git a/credentials/alts/internal/conn/record_test.go b/credentials/alts/internal/conn/record_test.go index e4992489a189..29475df67ac5 100644 --- a/credentials/alts/internal/conn/record_test.go +++ b/credentials/alts/internal/conn/record_test.go @@ -27,8 +27,11 @@ import ( "net" "reflect" "strings" + "syscall" "testing" + "time" + "golang.org/x/sys/unix" core "google.golang.org/grpc/credentials/alts/internal" "google.golang.org/grpc/internal/grpctest" ) @@ -105,6 +108,94 @@ func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (cli return clientConn, serverConn } +// newTCPConnPair returns a pair of conns backed by TCP over loopback. +func newTCPConnPair(rp string, clientProtected []byte, serverProtected []byte) (*conn, *conn, error) { + const address = "localhost:50935" + + // Start the server. + serverChan := make(chan net.Conn) + listenChan := make(chan struct{}) + go func() { + listener, err := net.Listen("tcp4", address) + if err != nil { + panic(fmt.Sprintf("failed to listen: %v", err)) + } + defer listener.Close() + listenChan <- struct{}{} + conn, err := listener.Accept() + if err != nil { + panic(fmt.Sprintf("failed to aceept: %v", err)) + } + serverChan <- conn + }() + + // Ensure the server is listening before trying to connect. + <-listenChan + clientTCP, err := net.DialTimeout("tcp4", address, 5*time.Second) + if err != nil { + return nil, nil, fmt.Errorf("failed to Dial: %w", err) + } + + // Get the server-side connection returned by Accept(). + var serverTCP net.Conn + select { + case serverTCP = <-serverChan: + case <-time.After(5 * time.Second): + return nil, nil, fmt.Errorf("timed out waiting for server conn") + } + + // Make the connection behave a little bit like a real one by imposing + // an MTU. + clientTCP = &mtuConn{clientTCP, 1500} + + // 16 arbitrary bytes. + key := []byte{ + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, + 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, + } + + client, err := NewConn(clientTCP, core.ClientSide, rp, key, clientProtected) + if err != nil { + panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) + } + server, err := NewConn(serverTCP, core.ServerSide, rp, key, serverProtected) + if err != nil { + panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) + } + + return client.(*conn), server.(*conn), nil +} + +// mtuConn imposes an MTU on writes. It simulates an important quality of real +// network traffic that is lost when using loopback devices. On loopback, even +// large messages (e.g. 512 KiB) when written often arrive at the receiver +// instantaneously as a single payload. By explicitly splitting such writes into +// smaller, MTU-sized paylaods we give the receiver a chance to respond to +// smaller message sizes. +type mtuConn struct { + net.Conn + mtu int +} + +// Write implements net.Conn. +func (rc *mtuConn) Write(buf []byte) (int, error) { + var written int + for len(buf) > 0 { + n, err := rc.Conn.Write(buf[:min(rc.mtu, len(buf))]) + written += n + if err != nil { + return written, err + } + buf = buf[n:] + } + return written, nil +} + +// SyscallConn implements syscall.Conn. +func (rc *mtuConn) SycallConn() (syscall.RawConn, error) { + return rc.Conn.(syscall.Conn).SyscallConn() +} + func testPingPong(t *testing.T, rp string) { clientConn, serverConn := newConnPair(rp, nil, nil) clientMsg := []byte("Client Message") @@ -231,6 +322,115 @@ func BenchmarkLargeMessage(b *testing.B) { } } +// BenchmarkTCP is a simple throughput test that sends payloads over a local TCP +// connection. +func BenchmarkTCP(b *testing.B) { + tcs := []struct { + name string + size int + }{ + {"1 KiB", 1024}, + {"4 KiB", 4 * 1024}, + {"64 KiB", 64 * 1024}, + {"512 KiB", 512 * 1024}, + {"1 MiB", 1024 * 1024}, + {"4 MiB", 4 * 1024 * 1024}, + } + for _, tc := range tcs { + b.Run("size="+tc.name, func(b *testing.B) { + benchmarkTCP(b, tc.size) + }) + } +} + +// sum makes unwanted compiler optimizations in benchmarkTCP's loop less likely. +var sum int + +func benchmarkTCP(b *testing.B, size int) { + // Initialize the connection. + client, server, err := newTCPConnPair(rekeyRecordProtocol, nil, nil) + if err != nil { + b.Fatalf("failed to create TCP conn pair: %v", err) + } + defer client.Close() + defer server.Close() + + rcvBuf := make([]byte, size) + sndBuf := make([]byte, size) + done := make(chan struct{}) + errChan := make(chan error) + + // Launch a writer goroutine. + go func() { + for { + select { + case <-done: + return + default: + } + n, err := client.Write(sndBuf) + if n != size || err != nil { + errChan <- fmt.Errorf("Write() = %v, %v; want %v, ", n, err, size) + return + } + // Act a bit like a real workload that can't just fill + // every buffer immediately. + time.Sleep(10 * time.Millisecond) + } + }() + + // Get the initial rusage so we can measure CPU time. + var startUsage unix.Rusage + if err := unix.Getrusage(unix.RUSAGE_SELF, &startUsage); err != nil { + b.Fatalf("failed to get initial rusage: %v", err) + } + + // Read as much as possible. + var rcvd uint64 + for b.Loop() { + n, err := io.ReadFull(server, rcvBuf) + rcvd += uint64(n) + if n != size || err != nil { + b.Fatalf("Read() = %v, %v; want %v, ", n, err, size) + } + // Act a bit like a real workload and utilize received bytes. + for _, b := range rcvBuf[:n] { + sum += int(b) + } + } + + // Turn off the writer. + done <- struct{}{} + + // Get the ending rusage. + var endUsage unix.Rusage + if err := unix.Getrusage(unix.RUSAGE_SELF, &endUsage); err != nil { + b.Fatalf("failed to get final rusage: %v", err) + } + + // Error check the writer goroutine. + select { + case err := <-errChan: + b.Fatal(err) + default: + } + + // Emit extra metrics. + utime := timevalDiffUsec(&startUsage.Utime, &endUsage.Utime) + stime := timevalDiffUsec(&startUsage.Stime, &endUsage.Stime) + b.ReportMetric(float64(utime)/float64(b.N), "usr-usec/op") + b.ReportMetric(float64(stime)/float64(b.N), "sys-usec/op") + b.ReportMetric(float64(stime+utime)/float64(b.N), "cpu-usec/op") + b.ReportMetric(float64(rcvd*8/(1024*1024))/float64(b.Elapsed().Seconds()), "Mbps") +} + +// timevalDiffUsec returns the difference in microseconds between start and end. +func timevalDiffUsec(start, end *unix.Timeval) int64 { + // Note: the int64 type conversion is needed because unix.Timeval uses + // 32 bit values on some architectures. + return int64(1_000_000*(end.Sec-start.Sec) + end.Usec - start.Usec) +} + func testIncorrectMsgType(t *testing.T, rp string) { // framedMsg is an empty ciphertext with correct framing but wrong // message type.