Skip to content

Commit b8ea41b

Browse files
committed
tests for sessionsDict.crossSigningKeysReplacement
1 parent c1ad175 commit b8ea41b

File tree

2 files changed

+72
-18
lines changed

2 files changed

+72
-18
lines changed

clientapi/routing/register.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,16 @@ type sessionsDict struct {
6666
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
6767
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
6868
deleteSessionToDeviceID map[string]string
69-
// allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
69+
// crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
7070
// cross-signing keys without UIA.
71-
allowedForCrossSigningKeysReplacement map[string]*time.Timer
71+
crossSigningKeysReplacement map[string]*time.Timer
7272
}
7373

7474
// defaultTimeout is the timeout used to clean up sessions
7575
const defaultTimeOut = time.Minute * 5
7676

77-
// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
78-
const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10
77+
// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
78+
const crossSigningKeysReplacementDuration = time.Minute * 10
7979

8080
// getCompletedStages returns the completed stages for a session.
8181
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
@@ -128,14 +128,14 @@ func (d *sessionsDict) deleteSession(sessionID string) {
128128
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
129129
d.Lock()
130130
defer d.Unlock()
131-
ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli()
132-
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
131+
ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli()
132+
t, ok := d.crossSigningKeysReplacement[userID]
133133
if ok {
134-
t.Reset(allowedForCrossSigningKeysReplacementDuration)
134+
t.Reset(crossSigningKeysReplacementDuration)
135135
return ts
136136
}
137-
d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc(
138-
allowedForCrossSigningKeysReplacementDuration,
137+
d.crossSigningKeysReplacement[userID] = time.AfterFunc(
138+
crossSigningKeysReplacementDuration,
139139
func() {
140140
d.restrictCrossSigningKeysReplacement(userID)
141141
},
@@ -146,33 +146,33 @@ func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
146146
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
147147
d.RLock()
148148
defer d.RUnlock()
149-
_, ok := d.allowedForCrossSigningKeysReplacement[userID]
149+
_, ok := d.crossSigningKeysReplacement[userID]
150150
return ok
151151
}
152152

153153
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
154154
d.Lock()
155155
defer d.Unlock()
156-
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
156+
t, ok := d.crossSigningKeysReplacement[userID]
157157
if ok {
158158
if !t.Stop() {
159159
select {
160160
case <-t.C:
161161
default:
162162
}
163163
}
164-
delete(d.allowedForCrossSigningKeysReplacement, userID)
164+
delete(d.crossSigningKeysReplacement, userID)
165165
}
166166
}
167167

168168
func newSessionsDict() *sessionsDict {
169169
return &sessionsDict{
170-
sessions: make(map[string][]authtypes.LoginType),
171-
sessionCompletedResult: make(map[string]registerResponse),
172-
params: make(map[string]registerRequest),
173-
timer: make(map[string]*time.Timer),
174-
deleteSessionToDeviceID: make(map[string]string),
175-
allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer),
170+
sessions: make(map[string][]authtypes.LoginType),
171+
sessionCompletedResult: make(map[string]registerResponse),
172+
params: make(map[string]registerRequest),
173+
timer: make(map[string]*time.Timer),
174+
deleteSessionToDeviceID: make(map[string]string),
175+
crossSigningKeysReplacement: make(map[string]*time.Timer),
176176
}
177177
}
178178

clientapi/routing/register_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
669669
assert.Equal(t, expectedDisplayName, profile.DisplayName)
670670
})
671671
}
672+
673+
func TestCrossSigningKeysReplacement(t *testing.T) {
674+
userID := "@user:example.com"
675+
676+
t.Run("Can add new session", func(t *testing.T) {
677+
s := newSessionsDict()
678+
assert.Empty(t, s.crossSigningKeysReplacement)
679+
s.allowCrossSigningKeysReplacement(userID)
680+
assert.Len(t, s.crossSigningKeysReplacement, 1)
681+
assert.Contains(t, s.crossSigningKeysReplacement, userID)
682+
})
683+
684+
t.Run("Can check if session exists or not", func(t *testing.T) {
685+
s := newSessionsDict()
686+
t.Run("exists", func(t *testing.T) {
687+
s.allowCrossSigningKeysReplacement(userID)
688+
assert.Len(t, s.crossSigningKeysReplacement, 1)
689+
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
690+
})
691+
692+
t.Run("not exists", func(t *testing.T) {
693+
assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com"))
694+
})
695+
})
696+
697+
t.Run("Can deactivate session", func(t *testing.T) {
698+
s := newSessionsDict()
699+
assert.Empty(t, s.crossSigningKeysReplacement)
700+
t.Run("not exists", func(t *testing.T) {
701+
s.restrictCrossSigningKeysReplacement("@random:test.com")
702+
assert.Empty(t, s.crossSigningKeysReplacement)
703+
})
704+
705+
t.Run("exists", func(t *testing.T) {
706+
s.allowCrossSigningKeysReplacement(userID)
707+
s.restrictCrossSigningKeysReplacement(userID)
708+
assert.Empty(t, s.crossSigningKeysReplacement)
709+
})
710+
})
711+
712+
t.Run("Can erase expired sessions", func(t *testing.T) {
713+
s := newSessionsDict()
714+
s.allowCrossSigningKeysReplacement(userID)
715+
assert.Len(t, s.crossSigningKeysReplacement, 1)
716+
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
717+
timer := s.crossSigningKeysReplacement[userID]
718+
719+
// pretending the timer is expired
720+
timer.Reset(time.Millisecond)
721+
time.Sleep(time.Millisecond * 500)
722+
723+
assert.Empty(t, s.crossSigningKeysReplacement)
724+
})
725+
}

0 commit comments

Comments
 (0)