11"""Tests for Jinja2 CQL2 filter."""
22
3- from dataclasses import dataclass
4- from typing import Generator
5- from unittest .mock import AsyncMock , MagicMock , patch
63from urllib .parse import parse_qs
74
85import httpx
96import pytest
107from fastapi .testclient import TestClient
118from utils import AppFactory
129
10+ from tests .utils import single_chunk_async_stream_response
11+
1312app_factory = AppFactory (
1413 oidc_discovery_url = "https://example-stac-api.com/.well-known/openid-configuration" ,
1514 default_public = False ,
1615)
1716
1817
19- @pytest .fixture
20- def mock_send () -> Generator [MagicMock , None , None ]:
21- """Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API."""
22- with patch (
23- "stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send" ,
24- new_callable = AsyncMock ,
25- ) as mock_send_method :
26- yield mock_send_method
27-
28-
29- @dataclass
30- class SingleChunkAsyncStream (httpx .AsyncByteStream ):
31- """Mock async stream that returns a single chunk of data."""
32-
33- body : bytes
34-
35- async def __aiter__ (self ):
36- """Return a single chunk of data."""
37- yield self .body
38-
39-
4018def test_collections_filter_contained_by_token (
41- mock_send , source_api_server , token_builder
19+ mock_upstream , source_api_server , token_builder
4220):
4321 """Test that the collections filter is applied correctly."""
4422 # Mock response from upstream API
45- mock_send .return_value = httpx .Response (
46- 200 ,
47- stream = SingleChunkAsyncStream (b"{}" ),
48- headers = {"content-type" : "application/json" },
49- )
23+ mock_upstream .return_value = single_chunk_async_stream_response (b"{}" )
5024
5125 app = app_factory (
5226 upstream_url = source_api_server ,
@@ -59,15 +33,49 @@ def test_collections_filter_contained_by_token(
5933 )
6034
6135 auth_token = token_builder ({"collections" : ["foo" , "bar" ]})
62- client = TestClient (
63- app ,
64- headers = {"Authorization" : f"Bearer { auth_token } " },
65- )
66-
36+ client = TestClient (app , headers = {"Authorization" : f"Bearer { auth_token } " })
6737 response = client .get ("/collections" )
38+
6839 assert response .status_code == 200
69- assert mock_send .call_count == 1
70- [r ] = mock_send .call_args [0 ]
40+ assert mock_upstream .call_count == 1
41+ [r ] = mock_upstream .call_args [0 ]
7142 assert parse_qs (r .url .query .decode ()) == {
7243 "filter" : ["a_containedby(id, ('foo', 'bar'))" ]
7344 }
45+
46+
47+ @pytest .mark .parametrize (
48+ "authenticated, expected_filter" ,
49+ [
50+ (True , "true" ),
51+ (False , "(private = false)" ),
52+ ],
53+ )
54+ def test_collections_filter_private_and_public (
55+ mock_upstream , source_api_server , token_builder , authenticated , expected_filter
56+ ):
57+ """Test that filter can be used for private/public collections."""
58+ # Mock response from upstream API
59+ mock_upstream .return_value = single_chunk_async_stream_response (b"{}" )
60+
61+ app = app_factory (
62+ upstream_url = source_api_server ,
63+ collections_filter = {
64+ "cls" : "stac_auth_proxy.filters.Template" ,
65+ "args" : ["{{ '(private = false)' if token is none else true }}" ],
66+ },
67+ default_public = True ,
68+ )
69+
70+ client = TestClient (
71+ app ,
72+ headers = (
73+ {"Authorization" : f"Bearer { token_builder ({})} " } if authenticated else {}
74+ ),
75+ )
76+ response = client .get ("/collections" )
77+
78+ assert response .status_code == 200
79+ assert mock_upstream .call_count == 1
80+ [r ] = mock_upstream .call_args [0 ]
81+ assert parse_qs (r .url .query .decode ()) == {"filter" : [expected_filter ]}
0 commit comments