Skip to content

Commit 4534456

Browse files
committed
mas: store crossSigngingKeysReplacement period in sessionsDict struct instead of db
1 parent b5f34df commit 4534456

File tree

3 files changed

+61
-39
lines changed

3 files changed

+61
-39
lines changed

clientapi/routing/admin.go

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ import (
3535
"github.com/element-hq/dendrite/userapi/storage/shared"
3636
)
3737

38-
const (
39-
replacementPeriod time.Duration = 10 * time.Minute
40-
)
41-
4238
var (
4339
validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
4440
deviceDisplayName = "OIDC-native client"
@@ -807,27 +803,10 @@ func AdminAllowCrossSigningReplacementWithoutUIA(
807803

808804
switch req.Method {
809805
case http.MethodPost:
810-
rq := userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest{
811-
UserID: userID.String(),
812-
Duration: replacementPeriod,
813-
}
814-
var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse
815-
err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs)
816-
if err != nil && !errors.Is(err, sql.ErrNoRows) {
817-
util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA")
818-
return util.JSONResponse{
819-
Code: http.StatusInternalServerError,
820-
JSON: spec.Unknown(err.Error()),
821-
}
822-
} else if errors.Is(err, sql.ErrNoRows) {
823-
return util.JSONResponse{
824-
Code: http.StatusNotFound,
825-
JSON: spec.NotFound("User not found."),
826-
}
827-
}
806+
ts := sessions.allowCrossSigningKeysReplacement(userID.String())
828807
return util.JSONResponse{
829808
Code: http.StatusOK,
830-
JSON: map[string]int64{"updatable_without_uia_before_ms": rs.Timestamp},
809+
JSON: map[string]int64{"updatable_without_uia_before_ms": ts},
831810
}
832811
default:
833812
return util.JSONResponse{

clientapi/routing/key_crosssigning.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ func UploadCrossSigningDeviceKeys(
4949
return *resErr
5050
}
5151

52-
sessionID := uploadReq.Auth.Session
53-
if sessionID == "" {
54-
sessionID = util.RandomString(sessionIDLength)
55-
}
56-
5752
// Query existing keys to determine if UIA is required
5853
keyResp := api.QueryKeysResponse{}
5954
keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
@@ -68,7 +63,6 @@ func UploadCrossSigningDeviceKeys(
6863
}
6964

7065
existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID]
71-
requireUIA := true
7266

7367
if hasMasterKey {
7468
if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) {
@@ -89,10 +83,8 @@ func UploadCrossSigningDeviceKeys(
8983
logger.WithError(masterKeyResp.Error).Error("Failed to query master key")
9084
return convertKeyError(masterKeyResp.Error)
9185
}
92-
if k := masterKeyResp.Key; k != nil && k.UpdatableWithoutUIABeforeMs != nil {
93-
requireUIA = !(time.Now().UnixMilli() < *k.UpdatableWithoutUIABeforeMs)
94-
}
9586

87+
requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) && masterKeyResp.Key != nil
9688
if requireUIA {
9789
url := ""
9890
if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" {
@@ -122,9 +114,13 @@ func UploadCrossSigningDeviceKeys(
122114
),
123115
}
124116
}
125-
// XXX: is it necessary?
126-
sessions.addCompletedSessionStage(sessionID, CrossSigningResetStage)
117+
sessions.restrictCrossSigningKeysReplacement(device.UserID)
127118
} else {
119+
sessionID := uploadReq.Auth.Session
120+
if sessionID == "" {
121+
sessionID = util.RandomString(sessionIDLength)
122+
}
123+
128124
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
129125
return util.JSONResponse{
130126
Code: http.StatusUnauthorized,

clientapi/routing/register.go

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,17 @@ type sessionsDict struct {
6666
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
6767
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
6868
deleteSessionToDeviceID map[string]string
69+
// allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
70+
// cross-signing keys without UIA.
71+
allowedForCrossSigningKeysReplacement map[string]*time.Timer
6972
}
7073

7174
// defaultTimeout is the timeout used to clean up sessions
7275
const defaultTimeOut = time.Minute * 5
7376

77+
// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
78+
const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10
79+
7480
// getCompletedStages returns the completed stages for a session.
7581
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
7682
d.RLock()
@@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) {
119125
}
120126
}
121127

128+
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
129+
d.Lock()
130+
defer d.Unlock()
131+
ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli()
132+
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
133+
if ok {
134+
t.Reset(allowedForCrossSigningKeysReplacementDuration)
135+
return ts
136+
}
137+
d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc(
138+
allowedForCrossSigningKeysReplacementDuration,
139+
func() {
140+
d.restrictCrossSigningKeysReplacement(userID)
141+
},
142+
)
143+
return ts
144+
}
145+
146+
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
147+
d.RLock()
148+
defer d.RUnlock()
149+
_, ok := d.allowedForCrossSigningKeysReplacement[userID]
150+
return ok
151+
}
152+
153+
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
154+
d.Lock()
155+
defer d.Unlock()
156+
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
157+
if ok {
158+
if !t.Stop() {
159+
select {
160+
case <-t.C:
161+
default:
162+
}
163+
}
164+
delete(d.allowedForCrossSigningKeysReplacement, userID)
165+
}
166+
}
167+
122168
func newSessionsDict() *sessionsDict {
123169
return &sessionsDict{
124-
sessions: make(map[string][]authtypes.LoginType),
125-
sessionCompletedResult: make(map[string]registerResponse),
126-
params: make(map[string]registerRequest),
127-
timer: make(map[string]*time.Timer),
128-
deleteSessionToDeviceID: make(map[string]string),
170+
sessions: make(map[string][]authtypes.LoginType),
171+
sessionCompletedResult: make(map[string]registerResponse),
172+
params: make(map[string]registerRequest),
173+
timer: make(map[string]*time.Timer),
174+
deleteSessionToDeviceID: make(map[string]string),
175+
allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer),
129176
}
130177
}
131178

0 commit comments

Comments
 (0)