Skip to content

Commit 01e8989

Browse files
committed
reexamine socket and start/stop handling
1 parent 44e849d commit 01e8989

File tree

4 files changed

+117
-62
lines changed

4 files changed

+117
-62
lines changed

pkg/agent/agent.go

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ var errAgentStopped = errors.New("agent has been stopped")
2222
type Agent struct {
2323
keyring agent.Agent
2424
caClient *caclient.Client
25-
ctx context.Context
2625

2726
agentSocketPath string
2827
agentListener net.Listener
@@ -35,40 +34,39 @@ type Agent struct {
3534
closeOnce sync.Once
3635
}
3736

38-
// Start creates and starts an SSH agent that listens on the specified socket path.
39-
// If agentSocketPath is empty, a temporary socket will be created.
40-
// The agent will automatically close when the context is cancelled.
41-
func Start(ctx context.Context, caClient *caclient.Client, agentSocketPath string) (*Agent, error) {
42-
keyring := agent.NewKeyring()
43-
a := &Agent{
44-
ctx: ctx,
45-
agentSocketPath: agentSocketPath,
46-
keyring: keyring,
47-
caClient: caClient,
48-
done: make(chan struct{}),
49-
}
50-
37+
// New creates a new SSH agent. This does not start listening - call Serve() to begin accepting connections.
38+
// If agentSocketPath is empty, a temporary socket will be created when Serve() is called.
39+
func New(caClient *caclient.Client, agentSocketPath string) (*Agent, error) {
5140
pub, priv, err := sshcert.GenerateKeys()
5241
if err != nil {
5342
return nil, err
5443
}
55-
a.privateKey = priv
56-
a.publicKey = pub
5744

58-
if a.ctx == nil {
59-
a.ctx = context.Background()
60-
}
45+
return &Agent{
46+
agentSocketPath: agentSocketPath,
47+
keyring: agent.NewKeyring(),
48+
caClient: caClient,
49+
publicKey: pub,
50+
privateKey: priv,
51+
done: make(chan struct{}),
52+
}, nil
53+
}
6154

62-
err = a.startAgentListener()
63-
if err != nil {
64-
return nil, err
55+
// Serve starts the agent listening on the configured socket and blocks until the context is cancelled.
56+
// Returns an error if the listener cannot be started, otherwise returns ctx.Err() when shutdown completes.
57+
func (a *Agent) Serve(ctx context.Context) error {
58+
if err := a.startAgentListener(); err != nil {
59+
return err
6560
}
6661

67-
context.AfterFunc(ctx, func() {
68-
a.Close()
69-
})
62+
// Serve connections in background
63+
go a.serve(ctx)
64+
65+
// Block until context cancelled
66+
<-ctx.Done()
67+
a.Close()
7068

71-
return a, nil
69+
return ctx.Err()
7270
}
7371

7472
// Credential contains the private key and certificate in PEM format
@@ -143,23 +141,35 @@ func (a *Agent) startAgentListener() error {
143141
return fmt.Errorf("unable to set permissions on agent socket: %w", err)
144142
}
145143
a.agentListener = agentListener
146-
go a.listenAndServeAgent()
147144
return nil
148145
}
149146

150-
func (a *Agent) listenAndServeAgent() {
151-
for a.Running() {
147+
func (a *Agent) serve(ctx context.Context) {
148+
for {
149+
// Check if context is done
150+
select {
151+
case <-ctx.Done():
152+
return
153+
default:
154+
}
155+
152156
conn, err := a.agentListener.Accept()
153157
if err != nil {
154158
if conn != nil {
155159
conn.Close()
156160
}
157-
if !a.Running() {
158-
// Agent is shutting down
161+
// Check if error is from listener being closed
162+
if errors.Is(err, net.ErrClosed) {
163+
return
164+
}
165+
// Check context again before logging
166+
select {
167+
case <-ctx.Done():
159168
return
169+
default:
170+
log.Warnf("error on accept from SSH_AUTH_SOCK listener: %v", err)
171+
continue
160172
}
161-
log.Warnf("error on accept from SSH_AUTH_SOCK listener: %v", err)
162-
continue
163173
}
164174
go a.serveAgent(conn)
165175
}

pkg/agent/agent_test.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,22 @@ func TestBasics(t *testing.T) {
3030
require.NoError(t, err)
3131

3232
ctx, cancel := context.WithCancel(t.Context())
33+
defer cancel()
3334

34-
a, err := agent.Start(ctx, nil, "")
35+
a, err := agent.New(nil, "")
3536
require.NoError(t, err)
3637

38+
// Serve in background
39+
go func() {
40+
err := a.Serve(ctx)
41+
if err != nil && !errors.Is(err, context.Canceled) {
42+
t.Errorf("agent.Serve error: %v", err)
43+
}
44+
}()
45+
46+
// Give agent time to start listening
47+
time.Sleep(10 * time.Millisecond)
48+
3749
server, err := sshd.Start(caPub)
3850
require.NoError(t, err)
3951
defer server.Close()

pkg/broker/broker.go

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,31 @@ type MatchResponse struct {
4040
Error string
4141
}
4242

43-
func New(ctx context.Context, log slog.Logger, socketPath string, authCommand string) (*Broker, error) {
44-
b := Broker{
43+
// New creates a new Broker instance. This does not start listening - call Serve() to begin accepting connections.
44+
func New(log slog.Logger, socketPath string, authCommand string) *Broker {
45+
return &Broker{
4546
auth: NewAuth(authCommand),
4647
brokerSocketPath: socketPath,
4748
done: make(chan struct{}),
4849
log: log,
4950
}
51+
}
5052

51-
err := b.startBrokerListener()
52-
if err != nil {
53-
return nil, fmt.Errorf("Unable to start broker socket: %w", err)
53+
// Serve starts the broker listening on the configured socket and blocks until the context is cancelled.
54+
// Returns an error if the listener cannot be started, otherwise returns ctx.Err() when shutdown completes.
55+
func (b *Broker) Serve(ctx context.Context) error {
56+
if err := b.startBrokerListener(); err != nil {
57+
return fmt.Errorf("unable to start broker socket: %w", err)
5458
}
5559

56-
context.AfterFunc(ctx, func() {
57-
b.Close()
58-
})
59-
return &b, nil
60+
// Serve connections in background
61+
go b.serve(ctx)
62+
63+
// Block until context cancelled
64+
<-ctx.Done()
65+
b.Close()
66+
67+
return ctx.Err()
6068
}
6169

6270
func (b *Broker) startBrokerListener() error {
@@ -67,7 +75,6 @@ func (b *Broker) startBrokerListener() error {
6775
}
6876

6977
b.brokerListener = brokerListener
70-
go b.listenAndServe()
7178
return nil
7279
}
7380

@@ -77,26 +84,37 @@ func (b *Broker) Match(input MatchRequest, output *MatchResponse) error {
7784
return nil
7885
}
7986

80-
func (b *Broker) listenAndServe() {
87+
func (b *Broker) serve(ctx context.Context) {
8188
server := rpc.NewServer()
8289
server.Register(b)
8390
for {
91+
// Check if context is done
92+
select {
93+
case <-ctx.Done():
94+
return
95+
default:
96+
}
97+
8498
conn, err := b.brokerListener.Accept()
8599
if err != nil {
86100
// Check if error is from listener being closed
87101
if errors.Is(err, net.ErrClosed) {
88102
// Listener closed, exit gracefully
89103
return
90104
}
91-
// Log other errors only if still running
92-
if b.Running() {
105+
// Check context again before logging
106+
select {
107+
case <-ctx.Done():
108+
return
109+
default:
93110
b.log.Warn("Unable to accept connection", "error", err)
94111
continue
95112
}
96-
return
97113
}
98-
defer conn.Close()
99-
go server.ServeConn(conn)
114+
go func() {
115+
defer conn.Close()
116+
server.ServeConn(conn)
117+
}()
100118
}
101119
}
102120

pkg/broker/broker_test.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"log/slog"
55
"net/rpc"
66
"testing"
7+
"time"
78

89
"github.com/lmittmann/tint"
910
"github.com/stretchr/testify/require"
@@ -13,14 +14,21 @@ func Test_RpcBasics(t *testing.T) {
1314
ctx := t.Context()
1415
authCommand := "echo '6:thello,'"
1516
socketPath := t.TempDir() + "/broker.sock"
16-
b, err := New(
17-
ctx,
18-
*testLogger(t),
19-
socketPath,
20-
authCommand)
21-
require.NoError(t, err)
17+
18+
b := New(*testLogger(t), socketPath, authCommand)
19+
20+
// Serve in background
21+
go func() {
22+
err := b.Serve(ctx)
23+
if err != nil && err != ctx.Err() {
24+
t.Errorf("broker.Serve error: %v", err)
25+
}
26+
}()
2227
defer b.Close()
2328

29+
// Give broker time to start listening
30+
time.Sleep(10 * time.Millisecond)
31+
2432
client, err := rpc.Dial("unix", socketPath)
2533
require.NoError(t, err)
2634

@@ -35,14 +43,21 @@ func Test_MatchRequestFields(t *testing.T) {
3543
ctx := t.Context()
3644
authCommand := "echo '6:thello,'"
3745
socketPath := t.TempDir() + "/broker.sock"
38-
b, err := New(
39-
ctx,
40-
*testLogger(t),
41-
socketPath,
42-
authCommand)
43-
require.NoError(t, err)
46+
47+
b := New(*testLogger(t), socketPath, authCommand)
48+
49+
// Serve in background
50+
go func() {
51+
err := b.Serve(ctx)
52+
if err != nil && err != ctx.Err() {
53+
t.Errorf("broker.Serve error: %v", err)
54+
}
55+
}()
4456
defer b.Close()
4557

58+
// Give broker time to start listening
59+
time.Sleep(10 * time.Millisecond)
60+
4661
client, err := rpc.Dial("unix", socketPath)
4762
require.NoError(t, err)
4863

0 commit comments

Comments
 (0)