Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions namegraph/xcollections/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def search_by_collection(
fields = [
'data.collection_name', 'template.collection_rank', 'metadata.owner',
'metadata.members_count', 'template.top10_names.normalized_name', 'template.top10_names.namehash',
'template.collection_types', 'metadata.modified', 'data.avatar_emoji', 'data.avatar_image'
'template.collection_types', 'metadata.modified', 'data.avatar_emoji', 'data.avatar_image',
'data.archived'
]

# find collection with specified collection_id
Expand All @@ -171,6 +172,13 @@ def search_by_collection(
logger.error(f'could not find collection with id {collection_id}', exc_info=True)
raise HTTPException(status_code=404, detail=f"Collection with id={collection_id} not found.")

if found_collection.archived != False:
if found_collection.archived is None:
logger.error(f'collection with id {collection_id} has no archived field')
else:
logger.error(f'collection with id {collection_id} is archived')
raise HTTPException(status_code=410, detail=f"Collection with id={collection_id} is archived.")

es_time_first = es_response_metadata['took']
es_comm_time_first = es_response_metadata['elasticsearch_communication_time']
if es_response_metadata['n_total_hits'] > 1:
Expand Down Expand Up @@ -290,7 +298,8 @@ def get_collections_by_id_list(self, id_list: list[str]) -> list[Collection]:
fields = [
'data.collection_name', 'template.collection_rank', 'metadata.owner',
'metadata.members_count', 'template.top10_names.normalized_name', 'template.top10_names.namehash',
'template.collection_types', 'metadata.modified', 'data.avatar_emoji', 'data.avatar_image'
'template.collection_types', 'metadata.modified', 'data.avatar_emoji', 'data.avatar_image',
'data.archived'
]

try:
Expand All @@ -307,4 +316,9 @@ def get_collections_by_id_list(self, id_list: list[str]) -> list[Collection]:
logger.error(f'Elasticsearch search failed [by-id_list]', exc_info=True)
raise HTTPException(status_code=503, detail=str(ex)) from ex

return collections
filtered_collections = [c for c in collections if c.archived == False]

if len(filtered_collections) < len(collections):
logger.warning(f'{len(collections) - len(filtered_collections)} collections were archived')

return filtered_collections
4 changes: 4 additions & 0 deletions namegraph/xcollections/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
avatar_emoji: Optional[str],
avatar_image: Optional[str],
related_collections: Optional[list[dict[str, Any]]] = None,
archived: Optional[bool] = None,
# TODO do we need those above? and do we need anything else?
):
self.score = score
Expand All @@ -38,6 +39,7 @@ def __init__(
self.avatar_emoji = avatar_emoji
self.avatar_image = avatar_image
self.related_collections = related_collections
self.archived = archived

# FIXME make more universal or split into multiple methods
# FIXME should we move limit_names somewhere else?
Expand Down Expand Up @@ -76,6 +78,7 @@ def from_elasticsearch_hit(cls, hit: dict[str, Any], limit_names: int = 10) -> C
avatar_emoji=fields['data.avatar_emoji'][0] if 'data.avatar_emoji' in fields else None,
avatar_image=fields['data.avatar_image'][0] if 'data.avatar_image' in fields else None,
related_collections=_source.get('name_generator', {}).get('related_collections', None),
archived=fields.get('data.archived', [None])[0],
)

@classmethod
Expand Down Expand Up @@ -110,4 +113,5 @@ def from_elasticsearch_hit_script_names(cls, hit: dict[str, Any], limit_names: i
avatar_emoji=fields['data.avatar_emoji'][0] if 'data.avatar_emoji' in fields else None,
avatar_image=fields['data.avatar_image'][0] if 'data.avatar_image' in fields else None,
related_collections=_source.get('name_generator', {}).get('related_collections', None),
archived=fields.get('data.archived', [None])[0],
)
15 changes: 10 additions & 5 deletions namegraph/xcollections/generator_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def sample_members_from_collection(
max_sample_size: int = 10,
) -> tuple[dict, dict]:

fields = ['data.collection_name']
fields = ['data.collection_name', 'data.archived']

sampling_script = """
def number_of_names = params._source.data.names.size();
Expand Down Expand Up @@ -275,6 +275,9 @@ def index = random.nextInt(number_of_names);
except IndexError as ex:
raise HTTPException(status_code=404, detail=f'Collection with id={collection_id} not found') from ex

if hit['fields']['data.archived'][0]:
raise HTTPException(status_code=410, detail=f'Collection with id={collection_id} is archived')

result = {
'collection_id': hit['_id'],
'collection_title': hit['fields']['data.collection_name'][0],
Expand Down Expand Up @@ -302,9 +305,8 @@ def fetch_top10_members_from_collection(
logger.error(f'Elasticsearch search failed [fetch top10 collection members]', exc_info=True)
raise HTTPException(status_code=503, detail=str(ex)) from ex

# TODO: as quick fix for filtering collections (e.g. nazi) with archived=True; needs to be activated when those collection will be marked in separated field
# if response['_source']['data']['archived']:
# raise HTTPException(status_code=410, detail=f'Collection with id={collection_id} is archived')
if response['_source']['data']['archived']:
raise HTTPException(status_code=410, detail=f'Collection with id={collection_id} is archived')

es_response_metadata = {
'n_total_hits': 1,
Expand Down Expand Up @@ -340,7 +342,7 @@ def scramble_tokens_from_collection(
seed: int
) -> tuple[dict, dict]:

fields = ['data.collection_name']
fields = ['data.collection_name', 'data.archived']

query_params = ElasticsearchQueryBuilder() \
.set_term('_id', collection_id) \
Expand All @@ -367,6 +369,9 @@ def scramble_tokens_from_collection(
except IndexError as ex:
raise HTTPException(status_code=404, detail=f'Collection with id={collection_id} not found') from ex

if hit['fields']['data.archived'][0]:
raise HTTPException(status_code=410, detail=f'Collection with id={collection_id} is archived')

name_tokens_tuples = [(r['normalized_name'], r['tokenized_name']) for r in hit['fields']['names_with_tokens']]
token_scramble_tokenized_suggestions = self._get_suggestions_by_scrambling_tokens(
name_tokens_tuples, method, seed, n_suggestions=max_suggestions
Expand Down
94 changes: 94 additions & 0 deletions tests/test_collections_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pytest import mark
from fastapi.testclient import TestClient
from hydra import compose, initialize

from namegraph.domains import Domains
from namegraph.generation.categories_generator import Categories
Expand Down Expand Up @@ -540,6 +541,99 @@ def test_elasticsearch_template_collections_search_tokenization_with_spaces(self
assert 'Yu-Gi-Oh! video games' not in titles
assert 'Oh Yeon-seo filmography' not in titles

@pytest.mark.integration_test
def test_find_collections_by_string_archived(self, test_test_client):
    """String search results must never include an archived collection."""
    archived_id = 'qKRMjGsAicbq'  # Waffen-SS personnel (archived)
    payload = {
        "query": 'Waffen-SS',
        "mode": "instant",
        "max_related_collections": 15,
        "max_total_collections": 15,
    }

    response = test_test_client.post("/find_collections_by_string", json=payload)
    assert response.status_code == 200

    returned_ids = [c['collection_id'] for c in response.json()['related_collections']]
    assert archived_id not in returned_ids

@pytest.mark.integration_test
def test_get_collections_by_id_list_archived(self):
    """get_collections_by_id_list must filter archived collections out of its result."""
    archived_id = 'qKRMjGsAicbq'  # Waffen-SS personnel (archived)
    with initialize(version_base=None, config_path="../conf/"):
        cfg = compose(config_name="test_config_new")
        matcher = CollectionMatcherForAPI(cfg)
        # the only requested id is archived, so nothing should come back
        assert not matcher.get_collections_by_id_list([archived_id])

@pytest.mark.integration_test
def test_find_collections_by_collection_archived_main(self, test_test_client):
    """Requesting related collections FOR an archived collection must return 410 Gone."""
    request_body = {
        "collection_id": 'qKRMjGsAicbq',  # Waffen-SS personnel (archived)
        "max_related_collections": 10,
        "min_other_collections": 0,
        "max_other_collections": 0,
        "max_total_collections": 10,
        "name_diversity_ratio": 0.5,
        "max_per_type": 3,
        "limit_names": 10,
        "sort_order": 'Relevance',
    }

    response = test_test_client.post("/find_collections_by_collection", json=request_body)
    assert response.status_code == 410

@pytest.mark.integration_test
def test_find_collections_by_collection_archived_other(self, test_test_client):
    """Related-collection results for a live collection must omit archived ones."""
    archived_id = 'qKRMjGsAicbq'  # Waffen-SS personnel (archived)
    request_body = {
        "collection_id": 'bobVn_JSavxm',  # Military units and formations of the Waffen-SS
        "max_related_collections": 10,
        "min_other_collections": 0,
        "max_other_collections": 10,
        "max_total_collections": 10,
        "name_diversity_ratio": 0.5,
        "max_per_type": 3,
        "limit_names": 10,
        "sort_order": 'Relevance',
    }

    response = test_test_client.post("/find_collections_by_collection", json=request_body)
    assert response.status_code == 200

    related_ids = [c['collection_id'] for c in response.json()['related_collections']]
    assert archived_id not in related_ids

@pytest.mark.integration_test
@pytest.mark.parametrize("member", ["rudolfhoss", "wilhelmmohnke", "karlwolff",
                                    "kurtmeyer", "karlbrandt", "guntergrass"])
def test_find_collections_by_member_archived(self, test_test_client, member):
    """Collections found for a member label must not include the archived collection."""
    archived_id = 'qKRMjGsAicbq'  # Waffen-SS personnel (archived)

    response = test_test_client.post("/find_collections_by_member", json={
        "label": member,
        "sort_order": "AI",
        "mode": 'domain_detail',
        "offset": 0,
        'max_results': 10,
    })
    assert response.status_code == 200

    assert all(c['collection_id'] != archived_id
               for c in response.json()['collections'])

@mark.integration_test
def test_fetch_collection_members_pagination(self, test_test_client):
# Test fetching first page
Expand Down
101 changes: 101 additions & 0 deletions tests/test_web_api_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,107 @@ def test_fetching_top_collection_members(prod_test_client):
assert name['metadata']['pipeline_name'] == 'fetch_top_collection_members'
assert name['metadata']['collection_id'] == collection_id

@pytest.mark.integration_test
def test_suggestions_by_category_archived(prod_test_client):
    """Suggestions in the 'related' category must never come from archived collections."""
    archived_ids = {
        'qKRMjGsAicbq',  # Waffen-SS personnel (archived)
        '7Y4V_MMzHJw8',  # SS personnel (archived)
    }
    # most categories share the same small request configuration
    small_category = {"max_suggestions": 10, "min_suggestions": 2}

    response = prod_test_client.post("/suggestions_by_category", json={
        "label": "Waffen-SS personnel",
        "params": {
            "user_info": {
                "user_wallet_addr": "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa",
                "user_ip_addr": "192.168.0.1",
                "session_id": "d6374908-94c3-420f-b2aa-6dd41989baef",
                "user_ip_country": "us"
            },
            "mode": "full",
            "metadata": True
        },
        "categories": {
            "related": {
                "enable_learning_to_rank": True,
                "max_names_per_related_collection": 10,
                "max_per_type": 2,
                "max_recursive_related_collections": 3,
                "max_related_collections": 10,
                "name_diversity_ratio": 0.5
            },
            "wordplay": dict(small_category),
            "alternates": dict(small_category),
            "emojify": dict(small_category),
            "community": dict(small_category),
            "expand": dict(small_category),
            "gowild": dict(small_category),
            "other": {
                "max_suggestions": 10,
                "min_suggestions": 6,
                "min_total_suggestions": 50
            }
        }
    })
    assert response.status_code == 200

    response_json = response.json()
    assert 'categories' in response_json
    categories = response_json['categories']
    assert len(categories) > 0

    # only the 'related' category carries collection_id metadata to check
    for cat in categories:
        if cat['type'] != 'related':
            continue
        for suggestion in cat['suggestions']:
            assert suggestion['metadata']['collection_id'] not in archived_ids

@pytest.mark.integration_test
def test_sample_members_from_collection_archived(prod_test_client):
    """Sampling members from an archived collection must return 410 Gone."""
    request_body = {
        "collection_id": 'qKRMjGsAicbq',  # Waffen-SS personnel (archived)
        "max_sample_size": 10,
        "seed": 42,
    }

    response = prod_test_client.post("/sample_collection_members", json=request_body)
    assert response.status_code == 410

@pytest.mark.integration_test
def test_fetch_top_collection_members_archived(prod_test_client):
    """Fetching top members of an archived collection must return 410 Gone."""
    archived_id = 'qKRMjGsAicbq'  # Waffen-SS personnel (archived)

    response = prod_test_client.post("/fetch_top_collection_members",
                                     json={"collection_id": archived_id})
    assert response.status_code == 410

@pytest.mark.integration_test
def test_scramble_collection_tokens_archived(prod_test_client):
    """Scrambling tokens of an archived collection must return 410 Gone."""
    request_body = {
        "collection_id": 'qKRMjGsAicbq',  # Waffen-SS personnel (archived)
        "method": 'left-right-shuffle',
        "n_top_members": 25,
        "max_suggestions": 1000,
    }

    response = prod_test_client.post("/scramble_collection_tokens", json=request_body)
    assert response.status_code == 410


class TestTokenScramble:
@pytest.mark.integration_test
Expand Down