Skip to content

Commit dd053e9

Browse files
authored
Merge pull request #643 from atlanhq/APP-6766
APP-6766: Added default `typeName=Referenceable` for must clauses
2 parents 4e9e898 + 64d2bc6 commit dd053e9

File tree

2 files changed

+75
-14
lines changed

2 files changed

+75
-14
lines changed

pyatlan/client/asset.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,30 +172,39 @@ def _get_bulk_search_log_message(self, bulk):
172172
@staticmethod
173173
def _ensure_type_filter_present(criteria: IndexSearchRequest) -> None:
174174
"""
175-
Ensures that at least one 'typeName' filter is present in the given search criteria.
176-
If no such filter exists, appends a default filter for 'Referenceable'.
175+
Ensures that at least one 'typeName' filter is present in both 'must' and 'filter' clauses.
176+
If missing in either, appends a default filter for 'Referenceable' to that clause.
177177
"""
178178
if not (
179179
criteria
180180
and criteria.dsl
181181
and criteria.dsl.query
182182
and isinstance(criteria.dsl.query, Bool)
183-
and criteria.dsl.query.filter
184-
and isinstance(criteria.dsl.query.filter, list)
185183
):
186184
return
187185

188-
has_type_filter = any(
189-
isinstance(f, (Term, Terms))
190-
and f.field == Referenceable.TYPE_NAME.keyword_field_name
191-
for f in criteria.dsl.query.filter
192-
)
186+
query = criteria.dsl.query
187+
default_filter = Term.with_super_type_names(Referenceable.__name__)
188+
type_field = Referenceable.TYPE_NAME.keyword_field_name
193189

194-
if not has_type_filter:
195-
criteria.dsl.query.filter.append(
196-
Term.with_super_type_names(Referenceable.__name__)
190+
def needs_type_filter(clause: Optional[List]) -> bool:
191+
return not any(
192+
isinstance(f, (Term, Terms)) and f.field == type_field
193+
for f in clause or []
197194
)
198195

196+
# Update 'filter' clause if needed
197+
if needs_type_filter(query.filter):
198+
if query.filter is None:
199+
query.filter = []
200+
query.filter.append(default_filter)
201+
202+
# Update 'must' clause if needed
203+
if needs_type_filter(query.must):
204+
if query.must is None:
205+
query.must = []
206+
query.must.append(default_filter)
207+
199208
# TODO: Try adding @validate_arguments to this method once
200209
# the issue below is fixed or when we switch to pydantic v2
201210
# https://github.com/atlanhq/atlan-python/pull/88#discussion_r1260892704

tests/unit/test_client.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from pyatlan.model.group import GroupRequest
5656
from pyatlan.model.lineage import LineageListRequest
5757
from pyatlan.model.response import AssetMutationResponse
58-
from pyatlan.model.search import Bool, Term, TermAttributes
58+
from pyatlan.model.search import DSL, Bool, IndexSearchRequest, Term, TermAttributes
5959
from pyatlan.model.search_log import SearchLogRequest
6060
from pyatlan.model.typedef import EnumDef
6161
from pyatlan.model.user import AtlanUser, UserRequest
@@ -1606,7 +1606,7 @@ def test_index_search_with_no_aggregation_results(
16061606
mock_api_caller.reset_mock()
16071607

16081608

1609-
def test_type_name_in_asset_search(mock_api_caller):
1609+
def test_type_name_in_asset_search_bool_filter(mock_api_caller):
16101610
# When the type name is not present in the request
16111611
request = (FluentSearch().where(CompoundQuery.active_assets())).to_request()
16121612
client = AssetClient(mock_api_caller)
@@ -1656,6 +1656,58 @@ def test_type_name_in_asset_search(mock_api_caller):
16561656
assert has_type_filter is False
16571657

16581658

1659+
def test_type_name_in_asset_search_bool_must(mock_api_caller):
1660+
# When the type name is not present in the request
1661+
query = Bool(must=[Term.with_state("ACTIVE")])
1662+
request = IndexSearchRequest(dsl=DSL(query=query))
1663+
1664+
client = AssetClient(mock_api_caller)
1665+
client._ensure_type_filter_present(request)
1666+
1667+
assert request.dsl.query and request.dsl.query.must
1668+
assert isinstance(request.dsl.query.must, list)
1669+
1670+
has_type_filter = any(
1671+
isinstance(f, Term) and f.field == TermAttributes.SUPER_TYPE_NAMES.value
1672+
for f in request.dsl.query.must
1673+
)
1674+
assert has_type_filter is True
1675+
1676+
# When the type name is present in the request (no need to add super type filter)
1677+
query = Bool(must=[Term.with_state("ACTIVE"), Term.with_type_name("AtlasGlossary")])
1678+
request = IndexSearchRequest(dsl=DSL(query=query))
1679+
client._ensure_type_filter_present(request)
1680+
1681+
assert request.dsl.query and request.dsl.query.must
1682+
assert isinstance(request.dsl.query.must, list)
1683+
1684+
has_type_filter = any(
1685+
isinstance(f, Term) and f.field == TermAttributes.SUPER_TYPE_NAMES.value
1686+
for f in request.dsl.query.must
1687+
)
1688+
assert has_type_filter is False
1689+
1690+
# When multiple type name(s) is present in the request (no need to add super type filter)
1691+
query = Bool(
1692+
must=[
1693+
Term.with_state("ACTIVE"),
1694+
Term.with_type_name("AtlasGlossary"),
1695+
Term.with_type_name("AtlasGlossaryTerm"),
1696+
]
1697+
)
1698+
request = IndexSearchRequest(dsl=DSL(query=query))
1699+
client._ensure_type_filter_present(request)
1700+
1701+
assert request.dsl.query and request.dsl.query.must
1702+
assert isinstance(request.dsl.query.must, list)
1703+
1704+
has_type_filter = any(
1705+
isinstance(f, Term) and f.field == TermAttributes.SUPER_TYPE_NAMES.value
1706+
for f in request.dsl.query.must
1707+
)
1708+
assert has_type_filter is False
1709+
1710+
16591711
def _assert_search_results(results, response_json, sorts, bulk=False):
16601712
for i, result in enumerate(results):
16611713
assert result and response_json["entities"][i]

0 commit comments

Comments
 (0)