Skip to content

Commit e7b7f80

Browse files
ncodeclaude
andcommitted
test: add comprehensive tests for recent bug fixes
Add tests covering: - New() validation for missing key and id fields - runCommand() empty command error handling - handleServiceCriticalState() filtering by ServiceID and ServiceChecks - session() renewal cancellation when creating new sessions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent f8275de commit e7b7f80

File tree

1 file changed

+219
-24
lines changed

1 file changed

+219
-24
lines changed

internal/ballot/ballot_test.go

Lines changed: 219 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,36 @@ func TestNew(t *testing.T) {
4949
assert.Error(t, err)
5050
assert.Nil(t, b)
5151
})
52+
53+
t.Run("failure due to missing key", func(t *testing.T) {
54+
// Set up configuration without the key field
55+
viper.Set("election.services.nokey.id", "test_service_id")
56+
viper.Set("election.services.nokey.primaryTag", "primary")
57+
58+
defer func() {
59+
viper.Reset()
60+
}()
61+
62+
b, err := New(context.Background(), "nokey")
63+
assert.Error(t, err)
64+
assert.Nil(t, b)
65+
assert.Contains(t, err.Error(), "key is required")
66+
})
67+
68+
t.Run("failure due to missing id", func(t *testing.T) {
69+
// Set up configuration without the id field
70+
viper.Set("election.services.noid.key", "election/test/leader")
71+
viper.Set("election.services.noid.primaryTag", "primary")
72+
73+
defer func() {
74+
viper.Reset()
75+
}()
76+
77+
b, err := New(context.Background(), "noid")
78+
assert.Error(t, err)
79+
assert.Nil(t, b)
80+
assert.Contains(t, err.Error(), "service ID is required")
81+
})
5282
}
5383

5484
func TestCopyServiceToRegistration(t *testing.T) {
@@ -126,36 +156,64 @@ func (m *MockCommandExecutor) CommandContext(ctx context.Context, name string, a
126156
}
127157

128158
func TestRunCommand(t *testing.T) {
129-
// Create a mock CommandExecutor
130-
mockExecutor := new(MockCommandExecutor)
159+
t.Run("successful command execution", func(t *testing.T) {
160+
// Create a mock CommandExecutor
161+
mockExecutor := new(MockCommandExecutor)
131162

132-
// Create a Ballot instance with the mock executor
133-
b := &Ballot{
134-
executor: mockExecutor,
135-
ctx: context.Background(),
136-
}
163+
// Create a Ballot instance with the mock executor
164+
b := &Ballot{
165+
executor: mockExecutor,
166+
ctx: context.Background(),
167+
}
137168

138-
// Define the command to run
139-
command := "echo hello"
140-
payload := &ElectionPayload{
141-
Address: "127.0.0.1",
142-
Port: 8080,
143-
SessionID: "session",
144-
}
169+
// Define the command to run
170+
command := "echo hello"
171+
payload := &ElectionPayload{
172+
Address: "127.0.0.1",
173+
Port: 8080,
174+
SessionID: "session",
175+
}
145176

146-
// Set up the expectation
147-
// Here, we're using a command that just outputs "mocked" when run
148-
mockCmd := exec.Command("echo", "mocked")
149-
mockExecutor.On("CommandContext", b.ctx, "echo", []string{"hello"}).Return(mockCmd)
177+
// Set up the expectation
178+
// Here, we're using a command that just outputs "mocked" when run
179+
mockCmd := exec.Command("echo", "mocked")
180+
mockExecutor.On("CommandContext", b.ctx, "echo", []string{"hello"}).Return(mockCmd)
150181

151-
// Call the method under test
152-
_, err := b.runCommand(command, payload)
182+
// Call the method under test
183+
_, err := b.runCommand(command, payload)
153184

154-
// Assert that the expectations were met
155-
mockExecutor.AssertExpectations(t)
185+
// Assert that the expectations were met
186+
mockExecutor.AssertExpectations(t)
156187

157-
// Assert that the method did not return an error
158-
assert.NoError(t, err)
188+
// Assert that the method did not return an error
189+
assert.NoError(t, err)
190+
})
191+
192+
t.Run("empty command returns error", func(t *testing.T) {
193+
// Create a mock CommandExecutor
194+
mockExecutor := new(MockCommandExecutor)
195+
196+
// Create a Ballot instance with the mock executor
197+
b := &Ballot{
198+
executor: mockExecutor,
199+
ctx: context.Background(),
200+
}
201+
202+
// Define an empty command
203+
command := ""
204+
payload := &ElectionPayload{
205+
Address: "127.0.0.1",
206+
Port: 8080,
207+
SessionID: "session",
208+
}
209+
210+
// Call the method under test with empty command
211+
_, err := b.runCommand(command, payload)
212+
213+
// Assert that an error is returned for empty command
214+
assert.Error(t, err)
215+
assert.Contains(t, err.Error(), "empty command")
216+
})
159217
}
160218

161219
func TestIsLeader(t *testing.T) {
@@ -372,6 +430,47 @@ func TestSession(t *testing.T) {
372430
assert.Error(t, err)
373431
assert.Equal(t, expectedErr, err)
374432
})
433+
434+
t.Run("session cancels previous renewal when creating new session", func(t *testing.T) {
435+
firstSessionID := "session1"
436+
secondSessionID := "session2"
437+
438+
mockSession := new(MockSession)
439+
// First session creation
440+
mockSession.On("Create", mock.Anything, (*api.WriteOptions)(nil)).Return(firstSessionID, nil, nil).Once()
441+
mockSession.On("RenewPeriodic", mock.Anything, firstSessionID, (*api.WriteOptions)(nil), mock.Anything).Return(nil)
442+
// Info check for first session (returns nil to indicate session expired, forcing new session creation)
443+
mockSession.On("Info", firstSessionID, (*api.QueryOptions)(nil)).Return((*api.SessionEntry)(nil), &api.QueryMeta{}, nil)
444+
// Second session creation
445+
mockSession.On("Create", mock.Anything, (*api.WriteOptions)(nil)).Return(secondSessionID, nil, nil).Once()
446+
mockSession.On("RenewPeriodic", mock.Anything, secondSessionID, (*api.WriteOptions)(nil), mock.Anything).Return(nil)
447+
448+
mockClient := &MockConsulClient{}
449+
mockClient.On("Session").Return(mockSession)
450+
451+
b := &Ballot{
452+
client: mockClient,
453+
TTL: 10 * time.Second,
454+
ctx: context.Background(),
455+
}
456+
457+
// Create first session
458+
err := b.session()
459+
assert.NoError(t, err)
460+
storedSessionID, ok := b.getSessionID()
461+
assert.True(t, ok)
462+
assert.Equal(t, firstSessionID, *storedSessionID)
463+
464+
// Verify sessionRenewalCancel is set
465+
assert.NotNil(t, b.sessionRenewalCancel)
466+
467+
// Create second session - should cancel the first renewal
468+
err = b.session()
469+
assert.NoError(t, err)
470+
storedSessionID, ok = b.getSessionID()
471+
assert.True(t, ok)
472+
assert.Equal(t, secondSessionID, *storedSessionID)
473+
})
375474
}
376475

377476
func TestHandleServiceCriticalState(t *testing.T) {
@@ -455,6 +554,102 @@ func TestHandleServiceCriticalState(t *testing.T) {
455554
assert.Error(t, err)
456555
assert.ErrorContains(t, err, expectedErr.Error())
457556
})
557+
558+
t.Run("filters checks by service ID", func(t *testing.T) {
559+
serviceID := "test_service_id"
560+
mockHealth := new(MockHealth)
561+
// Return multiple health checks, some for this instance and some for others
562+
mockHealth.On("Checks", "test_service", (*api.QueryOptions)(nil)).Return([]*api.HealthCheck{
563+
{ServiceID: serviceID, CheckID: "check1", Status: "passing"},
564+
{ServiceID: "other_service_id", CheckID: "check2", Status: "critical"}, // Different instance - should be ignored
565+
}, nil, nil)
566+
567+
mockClient := &MockConsulClient{}
568+
mockClient.On("Health").Return(mockHealth)
569+
570+
b := &Ballot{
571+
client: mockClient,
572+
ID: serviceID,
573+
Name: "test_service",
574+
}
575+
576+
// Should pass because only our instance's checks are considered
577+
err := b.handleServiceCriticalState()
578+
assert.NoError(t, err)
579+
})
580+
581+
t.Run("filters checks by service ID with critical state", func(t *testing.T) {
582+
serviceID := "test_service_id"
583+
serviceName := "test_service"
584+
primaryTag := "primary"
585+
586+
mockHealth := new(MockHealth)
587+
mockSession := new(MockSession)
588+
mockAgent := new(MockAgent)
589+
mockCatalog := new(MockCatalog)
590+
591+
// Return health checks where this instance is critical
592+
mockHealth.On("Checks", serviceName, (*api.QueryOptions)(nil)).Return([]*api.HealthCheck{
593+
{ServiceID: serviceID, CheckID: "check1", Status: "critical"},
594+
{ServiceID: "other_service_id", CheckID: "check2", Status: "passing"},
595+
}, nil, nil)
596+
597+
sessionID := "session_id"
598+
mockSession.On("Destroy", sessionID, (*api.WriteOptions)(nil)).Return(nil, nil)
599+
600+
// Mock Agent and Catalog for updateServiceTags
601+
mockAgent.On("Service", serviceID, mock.Anything).Return(&api.AgentService{
602+
ID: serviceID,
603+
Service: serviceName,
604+
Tags: []string{},
605+
}, nil, nil)
606+
mockCatalog.On("Service", serviceName, primaryTag, mock.Anything).Return([]*api.CatalogService{}, nil, nil)
607+
608+
mockClient := &MockConsulClient{}
609+
mockClient.On("Health").Return(mockHealth)
610+
mockClient.On("Session").Return(mockSession)
611+
mockClient.On("Agent").Return(mockAgent)
612+
mockClient.On("Catalog").Return(mockCatalog)
613+
614+
b := &Ballot{
615+
client: mockClient,
616+
ID: serviceID,
617+
Name: serviceName,
618+
PrimaryTag: primaryTag,
619+
}
620+
b.sessionID.Store(&sessionID)
621+
622+
err := b.handleServiceCriticalState()
623+
// Should return error because service is in critical state
624+
assert.Error(t, err)
625+
assert.Contains(t, err.Error(), "service is in critical state")
626+
// Verify session was destroyed due to critical state
627+
mockSession.AssertCalled(t, "Destroy", sessionID, (*api.WriteOptions)(nil))
628+
})
629+
630+
t.Run("filters checks by ServiceChecks list", func(t *testing.T) {
631+
serviceID := "test_service_id"
632+
mockHealth := new(MockHealth)
633+
// Return multiple health checks
634+
mockHealth.On("Checks", "test_service", (*api.QueryOptions)(nil)).Return([]*api.HealthCheck{
635+
{ServiceID: serviceID, CheckID: "check1", Status: "passing"},
636+
{ServiceID: serviceID, CheckID: "check2", Status: "critical"}, // Not in ServiceChecks - should be ignored
637+
}, nil, nil)
638+
639+
mockClient := &MockConsulClient{}
640+
mockClient.On("Health").Return(mockHealth)
641+
642+
b := &Ballot{
643+
client: mockClient,
644+
ID: serviceID,
645+
Name: "test_service",
646+
ServiceChecks: []string{"check1"}, // Only consider check1
647+
}
648+
649+
// Should pass because check2 (which is critical) is not in ServiceChecks
650+
err := b.handleServiceCriticalState()
651+
assert.NoError(t, err)
652+
})
458653
}
459654

460655
func TestUpdateServiceTags(t *testing.T) {

0 commit comments

Comments
 (0)