Skip to content

Commit 32c90e1

Browse files
authored
fix: context data race (#4)
1 parent d9d1f1d commit 32c90e1

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

context.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,12 @@ type Context interface {
9393

9494
type sshContext struct {
9595
context.Context
96-
*sync.Mutex
96+
*sync.RWMutex
9797
}
9898

9999
func newContext(srv *Server) (*sshContext, context.CancelFunc) {
100100
innerCtx, cancel := context.WithCancel(context.Background())
101-
ctx := &sshContext{innerCtx, &sync.Mutex{}}
101+
ctx := &sshContext{innerCtx, &sync.RWMutex{}}
102102
ctx.SetValue(ContextKeyServer, srv)
103103
perms := &Permissions{&gossh.Permissions{}}
104104
ctx.SetValue(ContextKeyPermissions, perms)
@@ -120,9 +120,23 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
120120
}
121121

122122
func (ctx *sshContext) SetValue(key, value interface{}) {
123+
ctx.RWMutex.Lock()
124+
defer ctx.RWMutex.Unlock()
123125
ctx.Context = context.WithValue(ctx.Context, key, value)
124126
}
125127

128+
func (ctx *sshContext) Value(key interface{}) interface{} {
129+
ctx.RWMutex.RLock()
130+
defer ctx.RWMutex.RUnlock()
131+
return ctx.Context.Value(key)
132+
}
133+
134+
func (ctx *sshContext) Done() <-chan struct{} {
135+
ctx.RWMutex.RLock()
136+
defer ctx.RWMutex.RUnlock()
137+
return ctx.Context.Done()
138+
}
139+
126140
func (ctx *sshContext) User() string {
127141
return ctx.Value(ContextKeyUser).(string)
128142
}

context_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,27 @@ func TestSetValue(t *testing.T) {
4545
t.Fatal(err)
4646
}
4747
}
48+
49+
func TestRaceRWIssue160(t *testing.T) {
50+
value := "foo"
51+
key := "bar"
52+
session, _, cleanup := newTestSessionWithOptions(t, &Server{
53+
Handler: func(s Session) {
54+
t.Run("test done", func(t *testing.T) {
55+
t.Parallel()
56+
go func() {
57+
s.Context().SetValue(key, value)
58+
}()
59+
go func() {
60+
select {
61+
case <-s.Context().Done():
62+
}
63+
}()
64+
})
65+
},
66+
}, nil)
67+
defer cleanup()
68+
if err := session.Run(""); err != nil {
69+
t.Fatal(err)
70+
}
71+
}

0 commit comments

Comments
 (0)