Skip to content

Commit d3e0dc7

Browse files
committed
Normalize multipart form data, fix imports
1 parent f580492 commit d3e0dc7

File tree

5 files changed

+188
-94
lines changed

5 files changed

+188
-94
lines changed

pygeoapi/api/__init__.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import re
5151
import sys
5252
from typing import Any, Tuple, Union, Self
53+
from io import BytesIO
5354

5455
from babel import Locale
5556
from dateutil.parser import parse as dateparse
@@ -64,9 +65,8 @@
6465
from pygeoapi.provider.base import (
6566
ProviderConnectionError, ProviderGenericError, ProviderTypeError)
6667

67-
from pygeoapi.join_util import initialize_joins
6868
from pygeoapi.util import (
69-
TEMPLATESDIR, UrlPrefetcher, dategetter,
69+
TEMPLATESDIR, UrlPrefetcher, dategetter, FileObject,
7070
filter_dict_by_key_value, filter_providers_by_type, get_api_rules,
7171
get_base_url, get_provider_by_type, get_provider_default, get_typed_value,
7272
render_j2_template, to_json, get_choice_from_headers, get_from_headers
@@ -214,12 +214,12 @@ def __init__(self, request, supported_locales):
214214
# Set default request data
215215
self._data = b''
216216

217-
# Form data
218-
self._form = {}
219-
220217
# Copy request query parameters
221218
self._args = self._get_params(request)
222219

220+
# Form data: populate in from_* factory methods
221+
self._form = {}
222+
223223
# Get path info
224224
if hasattr(request, 'scope'):
225225
self._path_info = request.scope['path'].strip('/')
@@ -243,13 +243,12 @@ def from_flask(cls, request, supported_locales) -> 'APIRequest':
243243
"""Factory class similar to with_data, but only for flask requests"""
244244
api_req = cls(request, supported_locales)
245245
api_req._data = request.data
246-
# TODO: quick hack to retrieve multipart form data
247-
if hasattr(request, 'form'):
248-
for key, value in request.form.items():
249-
api_req._form[key] = value
250-
if hasattr(request, 'files'):
251-
for key, value in request.files.items():
252-
api_req._form[key] = value
246+
for key, value in cls._formdata_flask(request):
247+
LOGGER.debug(f"Setting form field '{key}'")
248+
if key in api_req._form:
249+
LOGGER.debug(f"Skipping duplicate form field '{key}'")
250+
continue
251+
api_req._form[key] = value
253252
return api_req
254253

255254
@classmethod
@@ -258,15 +257,66 @@ async def from_starlette(cls, request, supported_locales) -> 'APIRequest':
258257
"""
259258
api_req = cls(request, supported_locales)
260259
api_req._data = await request.body()
260+
async for key, value in cls._formdata_starlette(request):
261+
LOGGER.debug(f"Setting form field '{key}'")
262+
if key in api_req._form:
263+
LOGGER.debug(f"Skipping duplicate form field '{key}'")
264+
continue
265+
api_req._form[key] = value
261266
return api_req
262267

263268
@classmethod
264269
def from_django(cls, request, supported_locales) -> 'APIRequest':
265270
"""Factory class similar to with_data, but only for django requests"""
266271
api_req = cls(request, supported_locales)
267272
api_req._data = request.body
273+
for key, value in cls._formdata_django(request):
274+
LOGGER.debug(f"Setting form field '{key}'")
275+
if key in api_req._form:
276+
LOGGER.debug(f"Skipping duplicate form field '{key}'")
277+
continue
278+
api_req._form[key] = value
268279
return api_req
269280

281+
@staticmethod
282+
def _formdata_flask(request):
283+
""" Normalize Flask/Werkzeug form data. """
284+
285+
for key, value in getattr(request, 'form', {}).items():
286+
yield key, value
287+
288+
for key, file_obj in getattr(request, 'files', {}).items():
289+
yield key, FileObject(file_obj.filename, file_obj.content_type,
290+
BytesIO(file_obj.read()))
291+
292+
@staticmethod
293+
def _formdata_django(request):
294+
""" Normalize Django form data. """
295+
296+
for key, value in getattr(request, 'POST', {}).items():
297+
yield key, value
298+
299+
for key, file_obj in getattr(request, 'FILES', {}).items():
300+
yield key, FileObject(file_obj.name, file_obj.content_type,
301+
BytesIO(file_obj.read()))
302+
303+
@staticmethod
304+
async def _formdata_starlette(request):
305+
""" Normalize Starlette/FastAPI form data (async). """
306+
307+
form = await request.form()
308+
309+
for key, value in form.items():
310+
if hasattr(value, 'filename'):
311+
# It's a file: for Starlette, we need to read async
312+
content = await value.read()
313+
file_obj = FileObject(value.filename, value.content_type,
314+
BytesIO(content))
315+
yield key, file_obj
316+
else:
317+
# Regular form field
318+
yield key, value
319+
270320
@staticmethod
271321
def _get_params(request):
272322
"""
@@ -364,7 +414,7 @@ def data(self) -> bytes:
364414

365415
@property
366416
def form(self) -> dict:
367-
"""Returns the Request form data dict"""
417+
"""Returns the Request form data dict (multipart/form-data)"""
368418
return self._form
369419

370420
@property
@@ -556,8 +606,13 @@ def __init__(self, config: dict, openapi: dict) -> Self | None:
556606
self.base_url = get_base_url(self.config)
557607
self.prefetcher = UrlPrefetcher()
558608

559-
# Build reference cache of join tables already/still on the server
560-
initialize_joins(config)
609+
setup_logger(self.config['logging'])
610+
611+
joins_api = all_apis().get('joins')
612+
if joins_api:
613+
# Initialize OGC API - Joins:
614+
# build reference cache of join tables already/still on the server
615+
joins_api.init(config)
561616

562617
CHARSET[0] = config['server'].get('encoding', 'utf-8')
563618
if config['server'].get('gzip'):
@@ -576,8 +631,6 @@ def __init__(self, config: dict, openapi: dict) -> Self | None:
576631

577632
self.pretty_print = self.config['server']['pretty_print']
578633

579-
setup_logger(self.config['logging'])
580-
581634
# Create config clone for HTML templating with modified base URL
582635
self.tpl_config = deepcopy(self.config)
583636
self.tpl_config['server']['url'] = self.base_url

pygeoapi/api/itemtypes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@
4747
from pygeofilter.parsers.cql2_json import parse as parse_cql2_json
4848
from pyproj.exceptions import CRSError
4949

50-
import join_util
51-
from pygeoapi import l10n
50+
from pygeoapi import l10n, join_util
5251
from pygeoapi.api import evaluate_limit
5352
from pygeoapi.crs import (DEFAULT_CRS, DEFAULT_STORAGE_CRS,
5453
create_crs_transform_spec, get_supported_crs_list,
@@ -178,7 +177,7 @@ def get_collection_queryables(api: API, request: Union[APIRequest, Any],
178177
try:
179178
domains, _ = p.get_domains(properties)
180179
except NotImplementedError:
181-
LOGGER.debug('Domains are not suported by this provider')
180+
LOGGER.debug('Domains are not supported by this provider')
182181
domains = {}
183182

184183
for k, v in p.fields.items():

pygeoapi/api/joins.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,17 @@
2929
import logging
3030
from typing import Any
3131

32-
import util
3332
from pygeoapi import l10n, join_util
33+
from pygeoapi.api import (
34+
APIRequest, API, SYSTEM_LOCALE, FORMAT_TYPES,
35+
F_JSON, F_JSONLD, F_HTML, HTTPStatus
36+
)
37+
from pygeoapi.openapi import get_visible_collections
3438
from pygeoapi.plugin import load_plugin, PLUGINS
3539
from pygeoapi.provider.base import ProviderTypeError
3640
from pygeoapi.util import (
3741
get_provider_by_type, to_json, filter_providers_by_type,
38-
filter_dict_by_key_value
39-
)
40-
from pygeoapi.openapi import get_visible_collections
41-
from pygeoapi.api import (
42-
APIRequest, API, SYSTEM_LOCALE, FORMAT_TYPES,
43-
F_JSON, F_JSONLD, F_HTML, HTTPStatus
42+
filter_dict_by_key_value, get_current_datetime
4443
)
4544

4645
LOGGER = logging.getLogger(__name__)
@@ -69,6 +68,17 @@
6968
]
7069

7170

71+
def init(cfg: dict) -> bool:
72+
"""
73+
Shortcut to initialize join utility with config.
74+
75+
:param cfg: pygeoapi configuration dict
76+
77+
:returns: True if the join utility was initialized successfully
78+
"""
79+
return join_util.init(cfg)
80+
81+
7282
def get_oas_30(cfg: dict, locale: str) -> tuple[list[dict[str, str]], dict[str, dict]]: # noqa
7383
"""
7484
Get OpenAPI fragments
@@ -81,7 +91,7 @@ def get_oas_30(cfg: dict, locale: str) -> tuple[list[dict[str, str]], dict[str,
8191

8292
paths = {}
8393

84-
if not join_util.CONFIG.enabled:
94+
if not join_util.init(cfg):
8595
LOGGER.info('OpenAPI: skipping OGC API - Joins endpoints setup')
8696
return [], {'paths': paths}
8797

@@ -692,18 +702,19 @@ def list_joins(api: API, request: APIRequest, dataset: str) -> tuple[dict, int,
692702

693703
try:
694704
sources = join_util.list_sources(dataset)
695-
except KeyError:
696-
return api.get_exception(
697-
HTTPStatus.NOT_FOUND, headers, request.format,
698-
'NotFound', 'Collection not found'
699-
)
700705
except Exception as e:
701706
LOGGER.error(str(e), exc_info=True)
702707
return api.get_exception(
703708
HTTPStatus.INTERNAL_SERVER_ERROR, headers, request.format,
704709
'NoApplicableCode', str(e)
705710
)
706711

712+
if not sources:
713+
return api.get_exception(
714+
HTTPStatus.NOT_FOUND, headers, request.format, 'NotFound',
715+
f'No joins found for collection: {dataset}'
716+
)
717+
707718
# Build the joins list with proper structure
708719
joins_list = []
709720
for source_id, source_obj in sources:
@@ -741,7 +752,7 @@ def list_joins(api: API, request: APIRequest, dataset: str) -> tuple[dict, int,
741752
'joins': joins_list,
742753
'numberMatched': len(joins_list),
743754
'numberReturned': len(joins_list),
744-
'timeStamp': util.get_current_datetime()
755+
'timeStamp': get_current_datetime()
745756
}
746757

747758
return headers, HTTPStatus.OK, to_json(response, api.pretty_print)
@@ -907,18 +918,16 @@ def delete_join(api: API, request: APIRequest,
907918
)
908919

909920
try:
910-
join_util.remove_source(dataset, join_id)
921+
if not join_util.remove_source(dataset, join_id):
922+
msg = f'Join source {join_id} not found for collection {dataset}'
923+
return api.get_exception(
924+
HTTPStatus.NOT_FOUND, headers, request.format,
925+
'NotFound', msg)
911926
except ValueError as e:
912927
LOGGER.error(f'Invalid request parameter: {e}', exc_info=True)
913928
return api.get_exception(
914929
HTTPStatus.BAD_REQUEST, headers, request.format,
915930
'InvalidParameterValue', str(e))
916-
except KeyError as e:
917-
msg = 'Collection or join source not found'
918-
LOGGER.error(f'Invalid parameter value: {e}', exc_info=True)
919-
return api.get_exception(
920-
HTTPStatus.NOT_FOUND, headers, request.format,
921-
'NotFound', msg)
922931
except Exception as e:
923932
LOGGER.error(f'Failed to delete join source: {e}',
924933
exc_info=True)

0 commit comments

Comments
 (0)