22
33import json
44from typing import cast
5+ from unittest .mock import MagicMock
56
67import cql2
78import pytest
@@ -233,28 +234,17 @@ def test_search_post(
233234 input_query ,
234235 token_builder ,
235236):
236- """Test filter is applied to search with full-featured filtering."""
237- # Setup app
238- app = app_factory (
239- upstream_url = source_api_server ,
240- items_filter = {
241- "cls" : "stac_auth_proxy.filters.Template" ,
242- "args" : [filter_template_expr .strip ()],
243- },
244- default_public = True ,
245- )
246-
247- # Query API
248- headers = (
249- {"Authorization" : f"Bearer { token_builder ({})} " } if is_authenticated else {}
250- )
251- response = TestClient (app , headers = headers ).post ("/search" , json = input_query )
237+ """Append body with generated CQL2 query."""
238+ response = _build_client (
239+ src_api_server = source_api_server ,
240+ template_expr = filter_template_expr ,
241+ is_authenticated = is_authenticated ,
242+ token_builder = token_builder ,
243+ ).post ("/search" , json = input_query )
252244 response .raise_for_status ()
253245
254246 # Retrieve query from upstream
255- assert mock_upstream .call_count == 1
256- [r ] = cast (list [Request ], mock_upstream .call_args [0 ])
257- output_query = json .loads (r .read ().decode ())
247+ upstream_body = json .loads (_get_upstream_request (mock_upstream )[0 ])
258248
259249 # Parse query from upstream
260250 input_filter = input_query .get ("filter" )
@@ -271,7 +261,7 @@ def test_search_post(
271261 }
272262
273263 assert (
274- output_query == expected_output_query
264+ upstream_body == expected_output_query
275265 ), "Query should be combined with the filter expression."
276266
277267
@@ -348,29 +338,18 @@ def test_search_get(
348338 input_query ,
349339 token_builder ,
350340):
351- """Test filter is applied to search with fimple filtering."""
352- # Setup app
353- app = app_factory (
354- upstream_url = source_api_server ,
355- items_filter = {
356- "cls" : "stac_auth_proxy.filters.Template" ,
357- "args" : [filter_template_expr .strip ()],
358- },
359- default_public = True ,
360- )
361-
362- # Query API
363- headers = (
364- {"Authorization" : f"Bearer { token_builder ({})} " } if is_authenticated else {}
365- )
366- response = TestClient (app , headers = headers ).get ("/search" , params = input_query )
341+ """Append query params with generated CQL2 query."""
342+ response = _build_client (
343+ src_api_server = source_api_server ,
344+ template_expr = filter_template_expr ,
345+ is_authenticated = is_authenticated ,
346+ token_builder = token_builder ,
347+ ).get ("/search" , params = input_query )
367348 response .raise_for_status ()
368349
369350 # Retrieve query from upstream
370- assert mock_upstream .call_count == 1
371- [r ] = cast (list [Request ], mock_upstream .call_args [0 ])
372- assert r .read ().decode () == ""
373- upstream_querystring = dict (r .url .params )
351+ upstream_body , upstream_querystring = _get_upstream_request (mock_upstream )
352+ assert upstream_body == ""
374353
375354 # Parse query from upstream
376355 input_filter = input_query .get ("filter" )
@@ -390,3 +369,33 @@ def test_search_get(
390369 assert (
391370 upstream_querystring == expected_output_query
392371 ), "Query should be combined with the filter expression."
372+
373+
374+ def _build_client (
375+ * ,
376+ src_api_server : str ,
377+ template_expr : str ,
378+ is_authenticated : bool ,
379+ token_builder ,
380+ ):
381+ # Setup app
382+ app = app_factory (
383+ upstream_url = src_api_server ,
384+ items_filter = {
385+ "cls" : "stac_auth_proxy.filters.Template" ,
386+ "args" : [template_expr .strip ()],
387+ },
388+ default_public = True ,
389+ )
390+
391+ # Query API
392+ headers = (
393+ {"Authorization" : f"Bearer { token_builder ({})} " } if is_authenticated else {}
394+ )
395+ return TestClient (app , headers = headers )
396+
397+
398+ def _get_upstream_request (mock_upstream : MagicMock ):
399+ assert mock_upstream .call_count == 1
400+ [r ] = cast (list [Request ], mock_upstream .call_args [0 ])
401+ return (r .read ().decode (), dict (r .url .params ))
0 commit comments