Skip to content

Commit 0608b22

Browse files
committed
Normalize multipart form data, fix imports
1 parent f580492 commit 0608b22

File tree

6 files changed

+295
-166
lines changed

6 files changed

+295
-166
lines changed

pygeoapi/api/__init__.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import re
5151
import sys
5252
from typing import Any, Tuple, Union, Self
53+
from io import BytesIO
5354

5455
from babel import Locale
5556
from dateutil.parser import parse as dateparse
@@ -64,9 +65,8 @@
6465
from pygeoapi.provider.base import (
6566
ProviderConnectionError, ProviderGenericError, ProviderTypeError)
6667

67-
from pygeoapi.join_util import initialize_joins
6868
from pygeoapi.util import (
69-
TEMPLATESDIR, UrlPrefetcher, dategetter,
69+
TEMPLATESDIR, UrlPrefetcher, dategetter, FileObject,
7070
filter_dict_by_key_value, filter_providers_by_type, get_api_rules,
7171
get_base_url, get_provider_by_type, get_provider_default, get_typed_value,
7272
render_j2_template, to_json, get_choice_from_headers, get_from_headers
@@ -214,12 +214,12 @@ def __init__(self, request, supported_locales):
214214
# Set default request data
215215
self._data = b''
216216

217-
# Form data
218-
self._form = {}
219-
220217
# Copy request query parameters
221218
self._args = self._get_params(request)
222219

220+
# Form data: populate in from_* factory methods
221+
self._form = {}
222+
223223
# Get path info
224224
if hasattr(request, 'scope'):
225225
self._path_info = request.scope['path'].strip('/')
@@ -243,13 +243,12 @@ def from_flask(cls, request, supported_locales) -> 'APIRequest':
243243
"""Factory class similar to with_data, but only for flask requests"""
244244
api_req = cls(request, supported_locales)
245245
api_req._data = request.data
246-
# TODO: quick hack to retrieve multipart form data
247-
if hasattr(request, 'form'):
248-
for key, value in request.form.items():
249-
api_req._form[key] = value
250-
if hasattr(request, 'files'):
251-
for key, value in request.files.items():
252-
api_req._form[key] = value
246+
for key, value in cls._formdata_flask(request):
247+
LOGGER.debug(f"Setting form field '{key}'")
248+
if key in api_req._form:
249+
LOGGER.debug(f"Skipping duplicate form field '{key}'")
250+
continue
251+
api_req._form[key] = value
253252
return api_req
254253

255254
@classmethod
@@ -258,15 +257,66 @@ async def from_starlette(cls, request, supported_locales) -> 'APIRequest':
258257
"""
259258
api_req = cls(request, supported_locales)
260259
api_req._data = await request.body()
260+
async for key, value in cls._formdata_starlette(request):
261+
LOGGER.debug(f"Setting form field '{key}'")
262+
if key in api_req._form:
263+
LOGGER.debug(f"Skipping duplicate form field '{key}'")
264+
continue
265+
api_req._form[key] = value
261266
return api_req
262267

263268
@classmethod
264269
def from_django(cls, request, supported_locales) -> 'APIRequest':
265270
"""Factory class similar to with_data, but only for django requests"""
266271
api_req = cls(request, supported_locales)
267272
api_req._data = request.body
273+
for key, value in cls._formdata_django(request):
274+
LOGGER.debug(f"Setting form field '{key}'")
275+
if key in api_req._form:
276+
LOGGER.debug(f"Skipping duplicate form field '{key}'")
277+
continue
278+
api_req._form[key] = value
268279
return api_req
269280

281+
@staticmethod
282+
def _formdata_flask(request):
283+
""" Normalize Flask/Werkzeug form data. """
284+
285+
for key, value in getattr(request, 'form', {}).items():
286+
yield key, value
287+
288+
for key, file_obj in getattr(request, 'files', {}).items():
289+
yield key, FileObject(file_obj.filename, file_obj.content_type,
290+
BytesIO(file_obj.read()))
291+
292+
@staticmethod
293+
def _formdata_django(request):
294+
""" Normalize Django form data. """
295+
296+
for key, value in getattr(request, 'POST', {}).items():
297+
yield key, value
298+
299+
for key, file_obj in getattr(request, 'FILES', {}).items():
300+
yield key, FileObject(file_obj.name, file_obj.content_type,
301+
BytesIO(file_obj.read()))
302+
303+
@staticmethod
304+
async def _formdata_starlette(request):
305+
""" Normalize Starlette/FastAPI form data (async). """
306+
307+
form = await request.form()
308+
309+
for key, value in form.items():
310+
if hasattr(value, 'filename'):
311+
# It's a file: for Starlette, we need to read async
312+
content = await value.read()
313+
file_obj = FileObject(value.filename, value.content_type,
314+
BytesIO(content))
315+
yield key, file_obj
316+
else:
317+
# Regular form field
318+
yield key, value
319+
270320
@staticmethod
271321
def _get_params(request):
272322
"""
@@ -364,7 +414,7 @@ def data(self) -> bytes:
364414

365415
@property
366416
def form(self) -> dict:
367-
"""Returns the Request form data dict"""
417+
"""Returns the Request form data dict (multipart/form-data)"""
368418
return self._form
369419

370420
@property
@@ -556,8 +606,13 @@ def __init__(self, config: dict, openapi: dict) -> Self | None:
556606
self.base_url = get_base_url(self.config)
557607
self.prefetcher = UrlPrefetcher()
558608

559-
# Build reference cache of join tables already/still on the server
560-
initialize_joins(config)
609+
setup_logger(self.config['logging'])
610+
611+
joins_api = all_apis().get('joins')
612+
if joins_api:
613+
# Initialize OGC API - Joins:
614+
# build reference cache of join tables already/still on the server
615+
joins_api.init(config)
561616

562617
CHARSET[0] = config['server'].get('encoding', 'utf-8')
563618
if config['server'].get('gzip'):
@@ -576,8 +631,6 @@ def __init__(self, config: dict, openapi: dict) -> Self | None:
576631

577632
self.pretty_print = self.config['server']['pretty_print']
578633

579-
setup_logger(self.config['logging'])
580-
581634
# Create config clone for HTML templating with modified base URL
582635
self.tpl_config = deepcopy(self.config)
583636
self.tpl_config['server']['url'] = self.base_url

pygeoapi/api/itemtypes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@
4747
from pygeofilter.parsers.cql2_json import parse as parse_cql2_json
4848
from pyproj.exceptions import CRSError
4949

50-
import join_util
51-
from pygeoapi import l10n
50+
from pygeoapi import l10n, join_util
5251
from pygeoapi.api import evaluate_limit
5352
from pygeoapi.crs import (DEFAULT_CRS, DEFAULT_STORAGE_CRS,
5453
create_crs_transform_spec, get_supported_crs_list,
@@ -178,7 +177,7 @@ def get_collection_queryables(api: API, request: Union[APIRequest, Any],
178177
try:
179178
domains, _ = p.get_domains(properties)
180179
except NotImplementedError:
181-
LOGGER.debug('Domains are not suported by this provider')
180+
LOGGER.debug('Domains are not supported by this provider')
182181
domains = {}
183182

184183
for k, v in p.fields.items():

0 commit comments

Comments
 (0)