|
16 | 16 | import re |
17 | 17 | from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple |
18 | 18 |
|
| 19 | +from pydantic import Extra, StrictInt, StrictStr |
19 | 20 | from signedjson.sign import sign_json |
20 | 21 |
|
21 | 22 | from twisted.web.server import Request |
|
24 | 25 | from synapse.http.server import HttpServer |
25 | 26 | from synapse.http.servlet import ( |
26 | 27 | RestServlet, |
| 28 | + parse_and_validate_json_object_from_request, |
27 | 29 | parse_integer, |
28 | | - parse_json_object_from_request, |
29 | 30 | ) |
| 31 | +from synapse.rest.models import RequestBodyModel |
30 | 32 | from synapse.storage.keys import FetchKeyResultForRemote |
31 | 33 | from synapse.types import JsonDict |
32 | 34 | from synapse.util import json_decoder |
|
38 | 40 | logger = logging.getLogger(__name__) |
39 | 41 |
|
40 | 42 |
|
| 43 | +class _KeyQueryCriteriaDataModel(RequestBodyModel): |
| 44 | + class Config: |
| 45 | + extra = Extra.allow |
| 46 | + |
| 47 | + minimum_valid_until_ts: Optional[StrictInt] |
| 48 | + |
| 49 | + |
41 | 50 | class RemoteKey(RestServlet): |
42 | 51 | """HTTP resource for retrieving the TLS certificate and NACL signature |
43 | 52 | verification keys for a collection of servers. Checks that the reported |
@@ -96,6 +105,9 @@ class RemoteKey(RestServlet): |
96 | 105 |
|
97 | 106 | CATEGORY = "Federation requests" |
98 | 107 |
|
| 108 | + class PostBody(RequestBodyModel): |
| 109 | + server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]] |
| 110 | + |
99 | 111 | def __init__(self, hs: "HomeServer"): |
100 | 112 | self.fetcher = ServerKeyFetcher(hs) |
101 | 113 | self.store = hs.get_datastores().main |
@@ -137,24 +149,29 @@ async def on_GET( |
137 | 149 | ) |
138 | 150 |
|
139 | 151 | minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") |
140 | | - arguments = {} |
141 | | - if minimum_valid_until_ts is not None: |
142 | | - arguments["minimum_valid_until_ts"] = minimum_valid_until_ts |
143 | | - query = {server: {key_id: arguments}} |
| 152 | + query = { |
| 153 | + server: { |
| 154 | + key_id: _KeyQueryCriteriaDataModel( |
| 155 | + minimum_valid_until_ts=minimum_valid_until_ts |
| 156 | + ) |
| 157 | + } |
| 158 | + } |
144 | 159 | else: |
145 | 160 | query = {server: {}} |
146 | 161 |
|
147 | 162 | return 200, await self.query_keys(query, query_remote_on_cache_miss=True) |
148 | 163 |
|
149 | 164 | async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: |
150 | | - content = parse_json_object_from_request(request) |
| 165 | + content = parse_and_validate_json_object_from_request(request, self.PostBody) |
151 | 166 |
|
152 | | - query = content["server_keys"] |
| 167 | + query = content.server_keys |
153 | 168 |
|
154 | 169 | return 200, await self.query_keys(query, query_remote_on_cache_miss=True) |
155 | 170 |
|
156 | 171 | async def query_keys( |
157 | | - self, query: JsonDict, query_remote_on_cache_miss: bool = False |
| 172 | + self, |
| 173 | + query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]], |
| 174 | + query_remote_on_cache_miss: bool = False, |
158 | 175 | ) -> JsonDict: |
159 | 176 | logger.info("Handling query for keys %r", query) |
160 | 177 |
|
@@ -196,8 +213,10 @@ async def query_keys( |
196 | 213 | else: |
197 | 214 | ts_added_ms = key_result.added_ts |
198 | 215 | ts_valid_until_ms = key_result.valid_until_ts |
199 | | - req_key = query.get(server_name, {}).get(key_id, {}) |
200 | | - req_valid_until = req_key.get("minimum_valid_until_ts") |
| 216 | + req_key = query.get(server_name, {}).get( |
| 217 | + key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None) |
| 218 | + ) |
| 219 | + req_valid_until = req_key.minimum_valid_until_ts |
201 | 220 | if req_valid_until is not None: |
202 | 221 | if ts_valid_until_ms < req_valid_until: |
203 | 222 | logger.debug( |
|
0 commit comments