Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion fclient/federationclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,13 @@ func (ac *federationClient) ClaimKeys(ctx context.Context, origin, s spec.Server
func (ac *federationClient) QueryKeys(ctx context.Context, origin, s spec.ServerName, keys map[string][]string) (res RespQueryKeys, err error) {
path := federationPathPrefixV1 + "/user/keys/query"
req := NewFederationRequest("POST", origin, s, path)
if err = req.SetContent(map[string]interface{}{
// Ensure that the keys map has empty slices for any nil values.
for k, v := range keys {
if v == nil {
keys[k] = []string{}
}
}
if err = req.SetContent(map[string]map[string][]string{
"device_keys": keys,
}); err != nil {
return
Expand Down
122 changes: 122 additions & 0 deletions fclient/federationclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,125 @@ func jsonify(x interface{}) string {
b, _ := json.Marshal(x)
return string(b)
}

func TestQueryKeysReturnsDeviceKeys(t *testing.T) {
serverName := spec.ServerName("local.server.name")
targetServerName := spec.ServerName("target.server.name")
keyID := gomatrixserverlib.KeyID("ed25519:auto")
_, privateKey, _ := ed25519.GenerateKey(nil)

respQueryKeysJSON := []byte(`{
"device_keys": {
"@user:target.server.name": {
"device1": {
"algorithms": ["m.olm.curve25519-aes-sha2"],
"keys": {
"curve25519:device1": "key1",
"ed25519:device1": "key2"
}
}
}
}
}`)

fc := fclient.NewFederationClient(
[]*fclient.SigningIdentity{
{
ServerName: serverName,
KeyID: keyID,
PrivateKey: privateKey,
},
},
fclient.WithSkipVerify(true),
fclient.WithTransport(
&roundTripper{
fn: func(req *http.Request) (*http.Response, error) {
if strings.HasPrefix(req.URL.Path, "/_matrix/federation/v1/user/keys/query") {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
if bytes.Contains(body, []byte("null")) {
t.Fatalf("QueryKeys request body should not contain 'null': %s", string(body))
}
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(respQueryKeysJSON)),
}, nil
}
return &http.Response{
StatusCode: 404,
Body: io.NopCloser(strings.NewReader("404 not found")),
}, nil
},
},
),
)

keys := map[string][]string{
"@user:target.server.name": {"device1"},
}
res, err := fc.QueryKeys(context.Background(), serverName, targetServerName, keys)
if err != nil {
t.Fatalf("QueryKeys returned an error: %s", err)
}
if len(res.DeviceKeys["@user:target.server.name"]["device1"].Keys) == 0 {
t.Fatalf("QueryKeys response missing device keys")
}
}

func TestQueryKeysHandlesNilDeviceKeys(t *testing.T) {
serverName := spec.ServerName("local.server.name")
targetServerName := spec.ServerName("target.server.name")
keyID := gomatrixserverlib.KeyID("ed25519:auto")
_, privateKey, _ := ed25519.GenerateKey(nil)

respQueryKeysJSON := []byte(`{
"device_keys": {}
}`)

fc := fclient.NewFederationClient(
[]*fclient.SigningIdentity{
{
ServerName: serverName,
KeyID: keyID,
PrivateKey: privateKey,
},
},
fclient.WithSkipVerify(true),
fclient.WithTransport(
&roundTripper{
fn: func(req *http.Request) (*http.Response, error) {
if strings.HasPrefix(req.URL.Path, "/_matrix/federation/v1/user/keys/query") {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
if bytes.Contains(body, []byte("null")) {
t.Fatalf("QueryKeys request body should not contain 'null': %s", string(body))
}
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(respQueryKeysJSON)),
}, nil
}
return &http.Response{
StatusCode: 404,
Body: io.NopCloser(strings.NewReader("404 not found")),
}, nil
},
},
),
)

keys := map[string][]string{
"@user:target.server.name": nil,
}
res, err := fc.QueryKeys(context.Background(), serverName, targetServerName, keys)
if err != nil {
t.Fatalf("QueryKeys returned an error: %s", err)
}
if len(res.DeviceKeys) != 0 {
t.Fatalf("QueryKeys response should be empty for nil device keys")
}
}
Loading