@@ -66,11 +66,17 @@ type sessionsDict struct {
6666 // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
6767 // the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
6868 deleteSessionToDeviceID map [string ]string
69+ // allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
70+ // cross-signing keys without UIA.
71+ allowedForCrossSigningKeysReplacement map [string ]* time.Timer
6972}
7073
7174// defaultTimeout is the timeout used to clean up sessions
7275const defaultTimeOut = time .Minute * 5
7376
77+ // allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
78+ const allowedForCrossSigningKeysReplacementDuration = time .Minute * 10
79+
7480// getCompletedStages returns the completed stages for a session.
7581func (d * sessionsDict ) getCompletedStages (sessionID string ) []authtypes.LoginType {
7682 d .RLock ()
@@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) {
119125 }
120126}
121127
128+ func (d * sessionsDict ) allowCrossSigningKeysReplacement (userID string ) int64 {
129+ d .Lock ()
130+ defer d .Unlock ()
131+ ts := time .Now ().Add (allowedForCrossSigningKeysReplacementDuration ).UnixMilli ()
132+ t , ok := d .allowedForCrossSigningKeysReplacement [userID ]
133+ if ok {
134+ t .Reset (allowedForCrossSigningKeysReplacementDuration )
135+ return ts
136+ }
137+ d .allowedForCrossSigningKeysReplacement [userID ] = time .AfterFunc (
138+ allowedForCrossSigningKeysReplacementDuration ,
139+ func () {
140+ d .restrictCrossSigningKeysReplacement (userID )
141+ },
142+ )
143+ return ts
144+ }
145+
146+ func (d * sessionsDict ) isCrossSigningKeysReplacementAllowed (userID string ) bool {
147+ d .RLock ()
148+ defer d .RUnlock ()
149+ _ , ok := d .allowedForCrossSigningKeysReplacement [userID ]
150+ return ok
151+ }
152+
153+ func (d * sessionsDict ) restrictCrossSigningKeysReplacement (userID string ) {
154+ d .Lock ()
155+ defer d .Unlock ()
156+ t , ok := d .allowedForCrossSigningKeysReplacement [userID ]
157+ if ok {
158+ if ! t .Stop () {
159+ select {
160+ case <- t .C :
161+ default :
162+ }
163+ }
164+ delete (d .allowedForCrossSigningKeysReplacement , userID )
165+ }
166+ }
167+
122168func newSessionsDict () * sessionsDict {
123169 return & sessionsDict {
124- sessions : make (map [string ][]authtypes.LoginType ),
125- sessionCompletedResult : make (map [string ]registerResponse ),
126- params : make (map [string ]registerRequest ),
127- timer : make (map [string ]* time.Timer ),
128- deleteSessionToDeviceID : make (map [string ]string ),
170+ sessions : make (map [string ][]authtypes.LoginType ),
171+ sessionCompletedResult : make (map [string ]registerResponse ),
172+ params : make (map [string ]registerRequest ),
173+ timer : make (map [string ]* time.Timer ),
174+ deleteSessionToDeviceID : make (map [string ]string ),
175+ allowedForCrossSigningKeysReplacement : make (map [string ]* time.Timer ),
129176 }
130177}
131178
0 commit comments