Skip to content

Commit 021431c

Browse files
committed
mas: fix key_crosssigning_test.go
1 parent b44f899 commit 021431c

File tree

2 files changed

+61
-18
lines changed

2 files changed

+61
-18
lines changed

clientapi/routing/key_crosssigning_test.go

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"strings"
1111
"testing"
1212

13+
"github.com/element-hq/dendrite/userapi/types"
14+
1315
"github.com/element-hq/dendrite/setup/config"
1416
"github.com/element-hq/dendrite/test"
1517
"github.com/element-hq/dendrite/test/testrig"
@@ -20,19 +22,28 @@ import (
2022
)
2123

2224
type mockKeyAPI struct {
23-
t *testing.T
24-
userResponses map[string]api.QueryKeysResponse
25+
t *testing.T
26+
queryKeysData map[string]api.QueryKeysResponse
27+
queryMasterKeysData map[string]api.QueryMasterKeysResponse
2528
}
2629

2730
func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
28-
res.MasterKeys = m.userResponses[req.UserID].MasterKeys
29-
res.SelfSigningKeys = m.userResponses[req.UserID].SelfSigningKeys
30-
res.UserSigningKeys = m.userResponses[req.UserID].UserSigningKeys
31+
res.MasterKeys = m.queryKeysData[req.UserID].MasterKeys
32+
res.SelfSigningKeys = m.queryKeysData[req.UserID].SelfSigningKeys
33+
res.UserSigningKeys = m.queryKeysData[req.UserID].UserSigningKeys
3134
if m.t != nil {
3235
m.t.Logf("QueryKeys: %+v => %+v", req, res)
3336
}
3437
}
3538

39+
func (m mockKeyAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) {
40+
res.Key = m.queryMasterKeysData[req.UserID].Key
41+
res.Error = m.queryMasterKeysData[req.UserID].Error
42+
if m.t != nil {
43+
m.t.Logf("QueryMasterKeys: %+v => %+v", req, res)
44+
}
45+
}
46+
3647
func (m mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
3748
// Just a dummy upload which always succeeds
3849
}
@@ -53,13 +64,19 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) {
5364
req.Header.Set("Content-Type", "application/json")
5465

5566
keyserverAPI := &mockKeyAPI{
56-
userResponses: map[string]api.QueryKeysResponse{
67+
queryKeysData: map[string]api.QueryKeysResponse{
68+
"@user:example.com": {},
69+
},
70+
queryMasterKeysData: map[string]api.QueryMasterKeysResponse{
5771
"@user:example.com": {},
5872
},
5973
}
6074
device := &api.Device{UserID: "@user:example.com", ID: "device"}
61-
cfg := &config.ClientAPI{}
62-
75+
cfg := &config.ClientAPI{
76+
MSCs: &config.MSCs{
77+
MSCs: []string{},
78+
},
79+
}
6380
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
6481
if res.Code != http.StatusOK {
6582
t.Fatalf("expected status %d, got %d", http.StatusOK, res.Code)
@@ -101,18 +118,32 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) {
101118

102119
keyserverAPI := &mockKeyAPI{
103120
t: t,
104-
userResponses: map[string]api.QueryKeysResponse{
121+
queryKeysData: map[string]api.QueryKeysResponse{
105122
"@user:example.com": {
106123
MasterKeys: map[string]fclient.CrossSigningKey{
107-
"@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}},
124+
"@user:example.com": {
125+
UserID: "@user:example.com",
126+
Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster},
127+
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}},
108128
},
109129
SelfSigningKeys: nil,
110130
UserSigningKeys: nil,
111131
},
112132
},
133+
queryMasterKeysData: map[string]api.QueryMasterKeysResponse{
134+
"@user:example.com": {
135+
Key: &types.CrossSigningKey{
136+
KeyData: spec.Base64Bytes("key1"),
137+
},
138+
},
139+
},
113140
}
114141
device := &api.Device{UserID: "@user:example.com", ID: "device"}
115-
cfg := &config.ClientAPI{}
142+
cfg := &config.ClientAPI{
143+
MSCs: &config.MSCs{
144+
MSCs: []string{},
145+
},
146+
}
116147

117148
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
118149
if res.Code != http.StatusUnauthorized {
@@ -132,8 +163,11 @@ func Test_UploadCrossSigningDeviceKeys_InvalidJSON(t *testing.T) {
132163

133164
keyserverAPI := &mockKeyAPI{}
134165
device := &api.Device{UserID: "@user:example.com", ID: "device"}
135-
cfg := &config.ClientAPI{}
136-
166+
cfg := &config.ClientAPI{
167+
MSCs: &config.MSCs{
168+
MSCs: []string{},
169+
},
170+
}
137171
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
138172
if res.Code != http.StatusBadRequest {
139173
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, res.Code)
@@ -151,10 +185,21 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) {
151185
req.Header.Set("Content-Type", "application/json")
152186

153187
keyserverAPI := &mockKeyAPI{
154-
userResponses: map[string]api.QueryKeysResponse{
188+
queryKeysData: map[string]api.QueryKeysResponse{
155189
"@user:example.com": {
156190
MasterKeys: map[string]fclient.CrossSigningKey{
157-
"@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}},
191+
"@user:example.com": {
192+
UserID: "@user:example.com",
193+
Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster},
194+
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")},
195+
},
196+
},
197+
},
198+
},
199+
queryMasterKeysData: map[string]api.QueryMasterKeysResponse{
200+
"@user:example.com": {
201+
Key: &types.CrossSigningKey{
202+
KeyData: spec.Base64Bytes("different_key"),
158203
},
159204
},
160205
},

setup/config/config_mscs.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package config
22

3-
import "slices"
4-
53
type MSCs struct {
64
Matrix *Global `yaml:"-"`
75

@@ -46,7 +44,7 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) {
4644
}
4745

4846
func (c *MSCs) MSC3861Enabled() bool {
49-
return slices.Contains(c.MSCs, "msc3861") && c.MSC3861 != nil
47+
return c.Enabled("msc3861") && c.MSC3861 != nil
5048
}
5149

5250
type MSC3861 struct {

0 commit comments

Comments
 (0)