1515""" This module contains base REST classes for constructing REST servlets. """
1616
1717import logging
18- from typing import Iterable , List , Optional , Union , overload
18+ from typing import Dict , Iterable , List , Optional , overload
1919
2020from typing_extensions import Literal
2121
22+ from twisted .web .server import Request
23+
2224from synapse .api .errors import Codes , SynapseError
2325from synapse .util import json_decoder
2426
@@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False):
108110 return default
109111
110112
113+ @overload
114+ def parse_bytes_from_args (
115+ args : Dict [bytes , List [bytes ]],
116+ name : str ,
117+ default : Literal [None ] = None ,
118+ required : Literal [True ] = True ,
119+ ) -> bytes :
120+ ...
121+
122+
123+ @overload
124+ def parse_bytes_from_args (
125+ args : Dict [bytes , List [bytes ]],
126+ name : str ,
127+ default : Optional [bytes ] = None ,
128+ required : bool = False ,
129+ ) -> Optional [bytes ]:
130+ ...
131+
132+
133+ def parse_bytes_from_args (
134+ args : Dict [bytes , List [bytes ]],
135+ name : str ,
136+ default : Optional [bytes ] = None ,
137+ required : bool = False ,
138+ ) -> Optional [bytes ]:
139+ """
140+ Parse a string parameter as bytes from the request query string.
141+
142+ Args:
143+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
144+ name: the name of the query parameter.
145+ default: value to use if the parameter is absent,
146+ defaults to None. Must be bytes if encoding is None.
147+ required: whether to raise a 400 SynapseError if the
148+ parameter is absent, defaults to False.
149+ Returns:
150+ Bytes or the default value.
151+
152+ Raises:
153+ SynapseError if the parameter is absent and required.
154+ """
155+ name_bytes = name .encode ("ascii" )
156+
157+ if name_bytes in args :
158+ return args [name_bytes ][0 ]
159+ elif required :
160+ message = "Missing string query parameter %s" % (name ,)
161+ raise SynapseError (400 , message , errcode = Codes .MISSING_PARAM )
162+
163+ return default
164+
165+
111166def parse_string (
112- request ,
113- name : Union [ bytes , str ] ,
167+ request : Request ,
168+ name : str ,
114169 default : Optional [str ] = None ,
115170 required : bool = False ,
116171 allowed_values : Optional [Iterable [str ]] = None ,
117- encoding : Optional [ str ] = "ascii" ,
172+ encoding : str = "ascii" ,
118173):
119174 """
120175 Parse a string parameter from the request query string.
@@ -125,66 +180,65 @@ def parse_string(
125180 Args:
126181 request: the twisted HTTP request.
127182 name: the name of the query parameter.
128- default: value to use if the parameter is absent,
129- defaults to None. Must be bytes if encoding is None.
183+ default: value to use if the parameter is absent, defaults to None.
130184 required: whether to raise a 400 SynapseError if the
131185 parameter is absent, defaults to False.
132186 allowed_values: List of allowed values for the
133187 string, or None if any value is allowed, defaults to None. Must be
134188 the same type as name, if given.
135- encoding : The encoding to decode the string content with.
189+ encoding: The encoding to decode the string content with.
190+
136191 Returns:
137- A string value or the default. Unicode if encoding
138- was given, bytes otherwise.
192+ A string value or the default.
139193
140194 Raises:
141195 SynapseError if the parameter is absent and required, or if the
142196 parameter is present, must be one of a list of allowed values and
143197 is not one of those allowed values.
144198 """
199+ args = request .args # type: Dict[bytes, List[bytes]] # type: ignore
145200 return parse_string_from_args (
146- request . args , name , default , required , allowed_values , encoding
201+ args , name , default , required , allowed_values , encoding
147202 )
148203
149204
150205def _parse_string_value (
151- value : Union [ str , bytes ] ,
206+ value : bytes ,
152207 allowed_values : Optional [Iterable [str ]],
153208 name : str ,
154- encoding : Optional [str ],
155- ) -> Union [str , bytes ]:
156- if encoding :
157- try :
158- value = value .decode (encoding )
159- except ValueError :
160- raise SynapseError (400 , "Query parameter %r must be %s" % (name , encoding ))
209+ encoding : str ,
210+ ) -> str :
211+ try :
212+ value_str = value .decode (encoding )
213+ except ValueError :
214+ raise SynapseError (400 , "Query parameter %r must be %s" % (name , encoding ))
161215
162- if allowed_values is not None and value not in allowed_values :
216+ if allowed_values is not None and value_str not in allowed_values :
163217 message = "Query parameter %r must be one of [%s]" % (
164218 name ,
165219 ", " .join (repr (v ) for v in allowed_values ),
166220 )
167221 raise SynapseError (400 , message )
168222 else :
169- return value
223+ return value_str
170224
171225
172226@overload
173227def parse_strings_from_args (
174- args : List [str ],
175- name : Union [ bytes , str ] ,
228+ args : Dict [ bytes , List [bytes ] ],
229+ name : str ,
176230 default : Optional [List [str ]] = None ,
177- required : bool = False ,
231+ required : Literal [ True ] = True ,
178232 allowed_values : Optional [Iterable [str ]] = None ,
179- encoding : Literal [ None ] = None ,
180- ) -> Optional [ List [bytes ] ]:
233+ encoding : str = "ascii" ,
234+ ) -> List [str ]:
181235 ...
182236
183237
184238@overload
185239def parse_strings_from_args (
186- args : List [str ],
187- name : Union [ bytes , str ] ,
240+ args : Dict [ bytes , List [bytes ] ],
241+ name : str ,
188242 default : Optional [List [str ]] = None ,
189243 required : bool = False ,
190244 allowed_values : Optional [Iterable [str ]] = None ,
@@ -194,83 +248,71 @@ def parse_strings_from_args(
194248
195249
196250def parse_strings_from_args (
197- args : List [str ],
198- name : Union [ bytes , str ] ,
251+ args : Dict [ bytes , List [bytes ] ],
252+ name : str ,
199253 default : Optional [List [str ]] = None ,
200254 required : bool = False ,
201255 allowed_values : Optional [Iterable [str ]] = None ,
202- encoding : Optional [ str ] = "ascii" ,
203- ) -> Optional [List [Union [ bytes , str ] ]]:
256+ encoding : str = "ascii" ,
257+ ) -> Optional [List [str ]]:
204258 """
205259 Parse a string parameter from the request query string list.
206260
207- If encoding is not None, the content of the query param will be
208- decoded to Unicode using the encoding, otherwise it will be encoded
261+ The content of the query param will be decoded to Unicode using the encoding.
209262
210263 Args:
211- args: the twisted HTTP request. args list.
264+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args) .
212265 name: the name of the query parameter.
213- default: value to use if the parameter is absent,
214- defaults to None. Must be bytes if encoding is None.
215- required : whether to raise a 400 SynapseError if the
266+ default: value to use if the parameter is absent, defaults to None.
267+ required: whether to raise a 400 SynapseError if the
216268 parameter is absent, defaults to False.
217- allowed_values (list[bytes|unicode]): List of allowed values for the
218- string, or None if any value is allowed, defaults to None. Must be
219- the same type as name, if given.
269+ allowed_values: List of allowed values for the
270+ string, or None if any value is allowed, defaults to None.
220271 encoding: The encoding to decode the string content with.
221272
222273 Returns:
223- A string value or the default. Unicode if encoding
224- was given, bytes otherwise.
274+ A string value or the default.
225275
226276 Raises:
227277 SynapseError if the parameter is absent and required, or if the
228278 parameter is present, must be one of a list of allowed values and
229279 is not one of those allowed values.
230280 """
281+ name_bytes = name .encode ("ascii" )
231282
232- if not isinstance (name , bytes ):
233- name = name .encode ("ascii" )
234-
235- if name in args :
236- values = args [name ]
283+ if name_bytes in args :
284+ values = args [name_bytes ]
237285
238286 return [
239287 _parse_string_value (value , allowed_values , name = name , encoding = encoding )
240288 for value in values
241289 ]
242290 else :
243291 if required :
244- message = "Missing string query parameter %r" % (name )
292+ message = "Missing string query parameter %r" % (name , )
245293 raise SynapseError (400 , message , errcode = Codes .MISSING_PARAM )
246- else :
247-
248- if encoding and isinstance (default , bytes ):
249- return default .decode (encoding )
250294
251- return default
295+ return default
252296
253297
254298def parse_string_from_args (
255- args : List [str ],
256- name : Union [ bytes , str ] ,
299+ args : Dict [ bytes , List [bytes ] ],
300+ name : str ,
257301 default : Optional [str ] = None ,
258302 required : bool = False ,
259303 allowed_values : Optional [Iterable [str ]] = None ,
260- encoding : Optional [ str ] = "ascii" ,
261- ) -> Optional [Union [ bytes , str ] ]:
304+ encoding : str = "ascii" ,
305+ ) -> Optional [str ]:
262306 """
263307 Parse the string parameter from the request query string list
264308 and return the first result.
265309
266- If encoding is not None, the content of the query param will be
267- decoded to Unicode using the encoding, otherwise it will be encoded
310+ The content of the query param will be decoded to Unicode using the encoding.
268311
269312 Args:
270- args: the twisted HTTP request. args list.
313+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args) .
271314 name: the name of the query parameter.
272- default: value to use if the parameter is absent,
273- defaults to None. Must be bytes if encoding is None.
315+ default: value to use if the parameter is absent, defaults to None.
274316 required: whether to raise a 400 SynapseError if the
275317 parameter is absent, defaults to False.
276318 allowed_values: List of allowed values for the
@@ -279,8 +321,7 @@ def parse_string_from_args(
279321 encoding: The encoding to decode the string content with.
280322
281323 Returns:
282- A string value or the default. Unicode if encoding
283- was given, bytes otherwise.
324+ A string value or the default.
284325
285326 Raises:
286327 SynapseError if the parameter is absent and required, or if the
@@ -291,12 +332,15 @@ def parse_string_from_args(
291332 strings = parse_strings_from_args (
292333 args ,
293334 name ,
294- default = [default ],
335+ default = [default ] if default is not None else None ,
295336 required = required ,
296337 allowed_values = allowed_values ,
297338 encoding = encoding ,
298339 )
299340
341+ if strings is None :
342+ return None
343+
300344 return strings [0 ]
301345
302346
@@ -388,9 +432,8 @@ class attribute containing a pre-compiled regular expression. The automatic
388432
389433 def register (self , http_server ):
390434 """ Register this servlet with the given HTTP server. """
391- if hasattr (self , "PATTERNS" ):
392- patterns = self .PATTERNS
393-
435+ patterns = getattr (self , "PATTERNS" , None )
436+ if patterns :
394437 for method in ("GET" , "PUT" , "POST" , "DELETE" ):
395438 if hasattr (self , "on_%s" % (method ,)):
396439 servlet_classname = self .__class__ .__name__
0 commit comments