Skip to content

Commit 8780152

Browse files
committed
Refactor tests
1 parent d308bc4 commit 8780152

File tree

1 file changed

+49
-40
lines changed

1 file changed

+49
-40
lines changed

tests/test_filters_jinja2.py

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
from typing import cast
5+
from unittest.mock import MagicMock
56

67
import cql2
78
import pytest
@@ -233,28 +234,17 @@ def test_search_post(
233234
input_query,
234235
token_builder,
235236
):
236-
"""Test filter is applied to search with full-featured filtering."""
237-
# Setup app
238-
app = app_factory(
239-
upstream_url=source_api_server,
240-
items_filter={
241-
"cls": "stac_auth_proxy.filters.Template",
242-
"args": [filter_template_expr.strip()],
243-
},
244-
default_public=True,
245-
)
246-
247-
# Query API
248-
headers = (
249-
{"Authorization": f"Bearer {token_builder({})}"} if is_authenticated else {}
250-
)
251-
response = TestClient(app, headers=headers).post("/search", json=input_query)
237+
"""Append body with generated CQL2 query."""
238+
response = _build_client(
239+
src_api_server=source_api_server,
240+
template_expr=filter_template_expr,
241+
is_authenticated=is_authenticated,
242+
token_builder=token_builder,
243+
).post("/search", json=input_query)
252244
response.raise_for_status()
253245

254246
# Retrieve query from upstream
255-
assert mock_upstream.call_count == 1
256-
[r] = cast(list[Request], mock_upstream.call_args[0])
257-
output_query = json.loads(r.read().decode())
247+
upstream_body = json.loads(_get_upstream_request(mock_upstream)[0])
258248

259249
# Parse query from upstream
260250
input_filter = input_query.get("filter")
@@ -271,7 +261,7 @@ def test_search_post(
271261
}
272262

273263
assert (
274-
output_query == expected_output_query
264+
upstream_body == expected_output_query
275265
), "Query should be combined with the filter expression."
276266

277267

@@ -348,29 +338,18 @@ def test_search_get(
348338
input_query,
349339
token_builder,
350340
):
351-
"""Test filter is applied to search with fimple filtering."""
352-
# Setup app
353-
app = app_factory(
354-
upstream_url=source_api_server,
355-
items_filter={
356-
"cls": "stac_auth_proxy.filters.Template",
357-
"args": [filter_template_expr.strip()],
358-
},
359-
default_public=True,
360-
)
361-
362-
# Query API
363-
headers = (
364-
{"Authorization": f"Bearer {token_builder({})}"} if is_authenticated else {}
365-
)
366-
response = TestClient(app, headers=headers).get("/search", params=input_query)
341+
"""Append query params with generated CQL2 query."""
342+
response = _build_client(
343+
src_api_server=source_api_server,
344+
template_expr=filter_template_expr,
345+
is_authenticated=is_authenticated,
346+
token_builder=token_builder,
347+
).get("/search", params=input_query)
367348
response.raise_for_status()
368349

369350
# Retrieve query from upstream
370-
assert mock_upstream.call_count == 1
371-
[r] = cast(list[Request], mock_upstream.call_args[0])
372-
assert r.read().decode() == ""
373-
upstream_querystring = dict(r.url.params)
351+
upstream_body, upstream_querystring = _get_upstream_request(mock_upstream)
352+
assert upstream_body == ""
374353

375354
# Parse query from upstream
376355
input_filter = input_query.get("filter")
@@ -390,3 +369,33 @@ def test_search_get(
390369
assert (
391370
upstream_querystring == expected_output_query
392371
), "Query should be combined with the filter expression."
372+
373+
374+
def _build_client(
375+
*,
376+
src_api_server: str,
377+
template_expr: str,
378+
is_authenticated: bool,
379+
token_builder,
380+
):
381+
# Setup app
382+
app = app_factory(
383+
upstream_url=src_api_server,
384+
items_filter={
385+
"cls": "stac_auth_proxy.filters.Template",
386+
"args": [template_expr.strip()],
387+
},
388+
default_public=True,
389+
)
390+
391+
# Query API
392+
headers = (
393+
{"Authorization": f"Bearer {token_builder({})}"} if is_authenticated else {}
394+
)
395+
return TestClient(app, headers=headers)
396+
397+
398+
def _get_upstream_request(mock_upstream: MagicMock):
399+
assert mock_upstream.call_count == 1
400+
[r] = cast(list[Request], mock_upstream.call_args[0])
401+
return (r.read().decode(), dict(r.url.params))

0 commit comments

Comments
 (0)