11"""Helper to deal with querystring parameters according to jsonapi specification."""
2+ from collections import defaultdict
23from functools import cached_property
34from typing import (
45 TYPE_CHECKING ,
78 List ,
89 Optional ,
910 Type ,
10- Union ,
1111)
1212from urllib .parse import unquote
1313
2222)
2323from starlette .datastructures import QueryParams
2424
25+ from fastapi_jsonapi .api import RoutersJSONAPI
2526from fastapi_jsonapi .exceptions import (
2627 BadRequest ,
2728 InvalidField ,
2829 InvalidFilters ,
2930 InvalidInclude ,
3031 InvalidSort ,
32+ InvalidType ,
3133)
3234from fastapi_jsonapi .schema import (
3335 get_model_field ,
3436 get_relationships ,
35- get_schema_from_type ,
3637)
3738from fastapi_jsonapi .splitter import SPLIT_REL
3839
@@ -89,33 +90,45 @@ def __init__(self, request: Request) -> None:
8990 self .MAX_INCLUDE_DEPTH : int = self .config .get ("MAX_INCLUDE_DEPTH" , 3 )
9091 self .headers : HeadersQueryStringManager = HeadersQueryStringManager (** dict (self .request .headers ))
9192
92- def _get_key_values (self , name : str ) -> Dict [str , Union [List [str ], str ]]:
93+ def _extract_item_key (self , key : str ) -> str :
94+ try :
95+ key_start = key .index ("[" ) + 1
96+ key_end = key .index ("]" )
97+ return key [key_start :key_end ]
98+ except Exception :
99+ msg = "Parse error"
100+ raise BadRequest (msg , parameter = key )
101+
102+ def _get_unique_key_values (self , name : str ) -> Dict [str , str ]:
93103 """
94104 Return a dict containing key / values items for a given key, used for items like filters, page, etc.
95105
96106 :param name: name of the querystring parameter
97107 :return: a dict of key / values items
98108 :raises BadRequest: if an error occurred while parsing the querystring.
99109 """
100- results : Dict [ str , Union [ List [ str ], str ]] = {}
110+ results = {}
101111
102112 for raw_key , value in self .qs .multi_items ():
103113 key = unquote (raw_key )
104- try :
105- if not key .startswith (name ):
106- continue
114+ if not key .startswith (name ):
115+ continue
107116
108- key_start = key .index ("[" ) + 1
109- key_end = key .index ("]" )
110- item_key = key [key_start :key_end ]
117+ item_key = self ._extract_item_key (key )
118+ results [item_key ] = value
111119
112- if "," in value :
113- results .update ({item_key : value .split ("," )})
114- else :
115- results .update ({item_key : value })
116- except Exception :
117- msg = "Parse error"
118- raise BadRequest (msg , parameter = key )
120+ return results
121+
122+ def _get_multiple_key_values (self , name : str ) -> Dict [str , List ]:
123+ results = defaultdict (list )
124+
125+ for raw_key , value in self .qs .multi_items ():
126+ key = unquote (raw_key )
127+ if not key .startswith (name ):
128+ continue
129+
130+ item_key = self ._extract_item_key (key )
131+ results [item_key ].extend (value .split ("," ))
119132
120133 return results
121134
@@ -134,7 +147,7 @@ def querystring(self) -> Dict[str, str]:
134147 return {
135148 key : value
136149 for (key , value ) in self .qs .multi_items ()
137- if key .startswith (self .managed_keys ) or self ._get_key_values ("filter[" )
150+ if key .startswith (self .managed_keys ) or self ._get_unique_key_values ("filter[" )
138151 }
139152
140153 @property
@@ -159,8 +172,8 @@ def filters(self) -> List[dict]:
159172 raise InvalidFilters (msg )
160173
161174 results .extend (loaded_filters )
162- if self ._get_key_values ("filter[" ):
163- results .extend (self ._simple_filters (self . _get_key_values ( "filter[" ) ))
175+ if filter_key_values := self ._get_unique_key_values ("filter[" ):
176+ results .extend (self ._simple_filters (filter_key_values ))
164177 return results
165178
166179 @cached_property
@@ -186,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager:
186199 :raises BadRequest: if the client is not allowed to disable pagination.
187200 """
188201 # check values type
189- pagination_data : Dict [str , Union [ List [ str ], str ]] = self ._get_key_values ("page" )
202+ pagination_data : Dict [str , str ] = self ._get_unique_key_values ("page" )
190203 pagination = PaginationQueryStringManager (** pagination_data )
191204 if pagination_data .get ("size" ) is None :
192205 pagination .size = None
@@ -199,8 +212,6 @@ def pagination(self) -> PaginationQueryStringManager:
199212
200213 return pagination
201214
202- # TODO: finally use this! upgrade Sqlachemy Data Layer
203- # and add to all views (get list/detail, create, patch)
204215 @property
205216 def fields (self ) -> Dict [str , List [str ]]:
206217 """
@@ -216,26 +227,32 @@ def fields(self) -> Dict[str, List[str]]:
216227
217228 :raises InvalidField: if result field not in schema.
218229 """
219- if self .request .method != "GET" :
220- msg = "attribute 'fields' allowed only for GET-method"
221- raise InvalidField (msg )
222- fields = self ._get_key_values ("fields" )
223- for key , value in fields .items ():
224- if not isinstance (value , list ):
225- value = [value ] # noqa: PLW2901
226- fields [key ] = value
230+ fields = self ._get_multiple_key_values ("fields" )
231+ for resource_type , field_names in fields .items ():
227232 # TODO: we have registry for models (BaseModel)
228233 # TODO: create `type to schemas` registry
229- schema : Type [BaseModel ] = get_schema_from_type (key , self .app )
230- for field in value :
231- if field not in schema .__fields__ :
234+
235+ if resource_type not in RoutersJSONAPI .all_jsonapi_routers :
236+ msg = f"Application has no resource with type { resource_type !r} "
237+ raise InvalidType (msg )
238+
239+ schema : Type [BaseModel ] = self ._get_schema (resource_type )
240+
241+ for field_name in field_names :
242+ if field_name == "" :
243+ continue
244+
245+ if field_name not in schema .__fields__ :
232246 msg = "{schema} has no attribute {field}" .format (
233247 schema = schema .__name__ ,
234- field = field ,
248+ field = field_name ,
235249 )
236250 raise InvalidField (msg )
237251
238- return fields
252+ return {resource_type : set (field_names ) for resource_type , field_names in fields .items ()}
253+
254+ def _get_schema (self , resource_type : str ) -> Type [BaseModel ]:
255+ return RoutersJSONAPI .all_jsonapi_routers [resource_type ]._schema
239256
240257 def get_sorts (self , schema : Type ["TypeSchema" ]) -> List [Dict [str , str ]]:
241258 """
0 commit comments