diff --git a/adapter/grpc_test.go b/adapter/grpc_test.go index 850506e..1d16f68 100644 --- a/adapter/grpc_test.go +++ b/adapter/grpc_test.go @@ -3,7 +3,6 @@ package adapter import ( "context" "strconv" - "strings" "sync" "testing" @@ -164,8 +163,7 @@ func Test_grpc_transaction(t *testing.T) { } func rawKVClient(t *testing.T, hosts []string) pb.RawKVClient { - dials := "multi:///" + strings.Join(hosts, ",") - conn, err := grpc.NewClient(dials, + conn, err := grpc.NewClient(hosts[0], grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), ) @@ -175,8 +173,7 @@ func rawKVClient(t *testing.T, hosts []string) pb.RawKVClient { } func transactionalKVClient(t *testing.T, hosts []string) pb.TransactionalKVClient { - dials := "multi:///" + strings.Join(hosts, ",") - conn, err := grpc.NewClient(dials, + conn, err := grpc.NewClient(hosts[0], grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), ) diff --git a/adapter/test_util.go b/adapter/test_util.go index cec1a29..8c2abb9 100644 --- a/adapter/test_util.go +++ b/adapter/test_util.go @@ -1,6 +1,8 @@ package adapter import ( + "context" + "log" "net" "strconv" "sync" @@ -26,6 +28,14 @@ func shutdown(nodes []Node) { for _, n := range nodes { n.grpcServer.Stop() n.redisServer.Stop() + if n.raft != nil { + n.raft.Shutdown() + } + if n.tm != nil { + if err := n.tm.Close(); err != nil { + log.Printf("transport close: %v", err) + } + } } } @@ -80,15 +90,19 @@ type Node struct { redisAddress string grpcServer *grpc.Server redisServer *RedisServer + raft *raft.Raft + tm *transport.Manager } -func newNode(grpcAddress, raftAddress, redisAddress string, grpcs *grpc.Server, rd *RedisServer) Node { +func newNode(grpcAddress, raftAddress, redisAddress string, r *raft.Raft, tm *transport.Manager, grpcs *grpc.Server, rd *RedisServer) Node { return Node{ grpcAddress: grpcAddress, raftAddress: raftAddress, redisAddress: redisAddress, grpcServer: grpcs, redisServer: rd, + raft: r, + tm: tm, } } @@ -98,9 +112,17 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) { var redisAdders []string var nodes []Node + const ( + waitTimeout = 5 * time.Second + waitInterval = 100 * time.Millisecond + ) + cfg := raft.Configuration{} ports := make([]portsAdress, n) + ctx := context.Background() + var lc net.ListenConfig + // port assign for i := 0; i < n; i++ { ports[i] = portAssigner() @@ -145,7 +167,7 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) { leaderhealth.Setup(r, s, []string{"Example"}) raftadmin.Register(s, r) - grpcSock, err := net.Listen("tcp", port.grpcAddress) + grpcSock, err := lc.Listen(ctx, "tcp", port.grpcAddress) assert.NoError(t, err) grpcAdders = append(grpcAdders, port.grpcAddress) @@ -154,7 +176,7 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) { assert.NoError(t, s.Serve(grpcSock)) }() - l, err := net.Listen("tcp", port.redisAddress) + l, err := lc.Listen(ctx, "tcp", port.redisAddress) assert.NoError(t, err) rd := NewRedisServer(l, st, coordinator) go func() { @@ -165,14 +187,34 @@ func createNode(t *testing.T, n int) ([]Node, []string, []string) { port.grpcAddress, port.raftAddress, port.redisAddress, + r, + tm, s, - rd), - ) + rd, + )) } - //nolint:mnd - time.Sleep(10 * time.Second) + d := &net.Dialer{Timeout: time.Second} + for _, n := range nodes { + assert.Eventually(t, func() bool { + conn, err := d.DialContext(ctx, "tcp", n.grpcAddress) + if err != nil { + return false + } + _ = conn.Close() + conn, err = d.DialContext(ctx, "tcp", n.redisAddress) + if err != nil { + return false + } + _ = conn.Close() + return true + }, waitTimeout, waitInterval) + } + + assert.Eventually(t, func() bool { + return nodes[0].raft.State() == raft.Leader + }, waitTimeout, waitInterval) return nodes, grpcAdders, redisAdders } diff --git a/cmd/server/demo.go b/cmd/server/demo.go index fdfd4db..d4d4c3b 100644 --- a/cmd/server/demo.go +++ b/cmd/server/demo.go @@ -1,6 +1,7 @@ package main import ( + "context" "log/slog" "net" "os" @@ -55,6 +56,8 @@ func main() { func run(eg *errgroup.Group) error { cfg := raft.Configuration{} + ctx := context.Background() + var lc net.ListenConfig for i := 0; i < 3; i++ { var suffrage raft.ServerSuffrage @@ -93,7 +96,7 @@ func run(eg *errgroup.Group) error { leaderhealth.Setup(r, s, []string{"RawKV"}) raftadmin.Register(s, r) - grpcSock, err := net.Listen("tcp", grpcAdders[i]) + grpcSock, err := lc.Listen(ctx, "tcp", grpcAdders[i]) if err != nil { return errors.WithStack(err) } @@ -102,7 +105,7 @@ func run(eg *errgroup.Group) error { return errors.WithStack(s.Serve(grpcSock)) }) - l, err := net.Listen("tcp", redisAdders[i]) + l, err := lc.Listen(ctx, "tcp", redisAdders[i]) if err != nil { return errors.WithStack(err) } diff --git a/main.go b/main.go index 4ffc1b7..f3ff849 100644 --- a/main.go +++ b/main.go @@ -41,12 +41,14 @@ func main() { } ctx := context.Background() + var lc net.ListenConfig + _, port, err := net.SplitHostPort(*myAddr) if err != nil { log.Fatalf("failed to parse local address (%q): %v", *myAddr, err) } - grpcSock, err := net.Listen("tcp", fmt.Sprintf(":%s", port)) + grpcSock, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%s", port)) if err != nil { log.Fatalf("failed to listen: %v", err) } @@ -72,7 +74,7 @@ func main() { raftadmin.Register(gs, r) reflection.Register(gs) - redisL, err := net.Listen("tcp", *redisAddr) + redisL, err := lc.Listen(ctx, "tcp", *redisAddr) if err != nil { log.Fatalf("failed to listen: %v", err) }