Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions credentials/alts/internal/conn/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ import (
"net"
"reflect"
"strings"
"syscall"
"testing"
"time"

"golang.org/x/sys/unix"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/internal/grpctest"
)
Expand Down Expand Up @@ -105,6 +108,94 @@ func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (cli
return clientConn, serverConn
}

// newTCPConnPair returns a pair of conns backed by TCP over loopback.
func newTCPConnPair(rp string, clientProtected []byte, serverProtected []byte) (*conn, *conn, error) {
const address = "localhost:50935"

// Start the server.
serverChan := make(chan net.Conn)
listenChan := make(chan struct{})
go func() {
listener, err := net.Listen("tcp4", address)
if err != nil {
panic(fmt.Sprintf("failed to listen: %v", err))
}
defer listener.Close()
listenChan <- struct{}{}
conn, err := listener.Accept()
if err != nil {
panic(fmt.Sprintf("failed to aceept: %v", err))
}
serverChan <- conn
}()

// Ensure the server is listening before trying to connect.
<-listenChan
clientTCP, err := net.DialTimeout("tcp4", address, 5*time.Second)
if err != nil {
return nil, nil, fmt.Errorf("failed to Dial: %w", err)
}

// Get the server-side connection returned by Accept().
var serverTCP net.Conn
select {
case serverTCP = <-serverChan:
case <-time.After(5 * time.Second):
return nil, nil, fmt.Errorf("timed out waiting for server conn")
}

// Make the connection behave a little bit like a real one by imposing
// an MTU.
clientTCP = &mtuConn{clientTCP, 1500}

// 16 arbitrary bytes.
key := []byte{
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88,
0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
}

client, err := NewConn(clientTCP, core.ClientSide, rp, key, clientProtected)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}
server, err := NewConn(serverTCP, core.ServerSide, rp, key, serverProtected)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}

return client.(*conn), server.(*conn), nil
}

// mtuConn imposes an MTU on writes. It simulates an important quality of real
// network traffic that is lost when using loopback devices. On loopback, even
// large messages (e.g. 512 KiB) when written often arrive at the receiver
// instantaneously as a single payload. By explicitly splitting such writes into
// smaller, MTU-sized paylaods we give the receiver a chance to respond to
// smaller message sizes.
type mtuConn struct {
net.Conn
mtu int
}

// Write implements net.Conn.
func (rc *mtuConn) Write(buf []byte) (int, error) {
var written int
for len(buf) > 0 {
n, err := rc.Conn.Write(buf[:min(rc.mtu, len(buf))])
written += n
if err != nil {
return written, err
}
buf = buf[n:]
}
return written, nil
}

// SyscallConn implements syscall.Conn.
func (rc *mtuConn) SycallConn() (syscall.RawConn, error) {
return rc.Conn.(syscall.Conn).SyscallConn()
}

func testPingPong(t *testing.T, rp string) {
clientConn, serverConn := newConnPair(rp, nil, nil)
clientMsg := []byte("Client Message")
Expand Down Expand Up @@ -231,6 +322,115 @@ func BenchmarkLargeMessage(b *testing.B) {
}
}

// BenchmarkTCP is a simple throughput test that sends payloads over a local TCP
// connection.
func BenchmarkTCP(b *testing.B) {
tcs := []struct {
name string
size int
}{
{"1 KiB", 1024},
{"4 KiB", 4 * 1024},
{"64 KiB", 64 * 1024},
{"512 KiB", 512 * 1024},
{"1 MiB", 1024 * 1024},
{"4 MiB", 4 * 1024 * 1024},
}
for _, tc := range tcs {
b.Run("size="+tc.name, func(b *testing.B) {
benchmarkTCP(b, tc.size)
})
}
}

// sum makes unwanted compiler optimizations in benchmarkTCP's loop less likely.
var sum int

func benchmarkTCP(b *testing.B, size int) {
// Initialize the connection.
client, server, err := newTCPConnPair(rekeyRecordProtocol, nil, nil)
if err != nil {
b.Fatalf("failed to create TCP conn pair: %v", err)
}
defer client.Close()
defer server.Close()

rcvBuf := make([]byte, size)
sndBuf := make([]byte, size)
done := make(chan struct{})
errChan := make(chan error)

// Launch a writer goroutine.
go func() {
for {
select {
case <-done:
return
default:
}
n, err := client.Write(sndBuf)
if n != size || err != nil {
errChan <- fmt.Errorf("Write() = %v, %v; want %v, <nil>", n, err, size)
return
}
// Act a bit like a real workload that can't just fill
// every buffer immediately.
time.Sleep(10 * time.Millisecond)
}
}()

// Get the initial rusage so we can measure CPU time.
var startUsage unix.Rusage
if err := unix.Getrusage(unix.RUSAGE_SELF, &startUsage); err != nil {
b.Fatalf("failed to get initial rusage: %v", err)
}

// Read as much as possible.
var rcvd uint64
for b.Loop() {
n, err := io.ReadFull(server, rcvBuf)
rcvd += uint64(n)
if n != size || err != nil {
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, size)
}
// Act a bit like a real workload and utilize received bytes.
for _, b := range rcvBuf[:n] {
sum += int(b)
}
}

// Turn off the writer.
done <- struct{}{}

// Get the ending rusage.
var endUsage unix.Rusage
if err := unix.Getrusage(unix.RUSAGE_SELF, &endUsage); err != nil {
b.Fatalf("failed to get final rusage: %v", err)
}

// Error check the writer goroutine.
select {
case err := <-errChan:
b.Fatal(err)
default:
}

// Emit extra metrics.
utime := timevalDiffUsec(&startUsage.Utime, &endUsage.Utime)
stime := timevalDiffUsec(&startUsage.Stime, &endUsage.Stime)
b.ReportMetric(float64(utime)/float64(b.N), "usr-usec/op")
b.ReportMetric(float64(stime)/float64(b.N), "sys-usec/op")
b.ReportMetric(float64(stime+utime)/float64(b.N), "cpu-usec/op")
b.ReportMetric(float64(rcvd*8/(1024*1024))/float64(b.Elapsed().Seconds()), "Mbps")
}

// timevalDiffUsec returns the difference in microseconds between start and end.
func timevalDiffUsec(start, end *unix.Timeval) int64 {
// Note: the int64 type conversion is needed because unix.Timeval uses
// 32 bit values on some architectures.
return int64(1_000_000*(end.Sec-start.Sec) + end.Usec - start.Usec)
}

func testIncorrectMsgType(t *testing.T, rp string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
Expand Down
Loading