Skip to content

Commit 9c859a3

Browse files
authored
ensure boolean search queries correctly generate (#210)
* tests: add cases for bool search query func * try to protect from bad keyterm synonyms data * tests: add more complicated bool query case
1 parent 17d2dd3 commit 9c859a3

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

colandr/api/v1/schemas.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ class ReviewPlanPICO(af.Schema):
8686
comparison = af.fields.String(validate=af.validators.Length(max=300))
8787
outcome = af.fields.String(validate=af.validators.Length(max=300))
8888

89+
@ma.pre_load
90+
def coerce_synonyms_to_list(self, data: dict, **kwargs) -> dict:
91+
synonyms = data.get("synonyms")
92+
if isinstance(synonyms, str):
93+
data["synonyms"] = [s.strip() for s in synonyms.split(",") if s.strip()]
94+
return data
95+
8996

9097
class ReviewPlanKeyterm(af.Schema):
9198
group = af.fields.String(required=True, validate=af.validators.Length(max=100))

colandr/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ def get_boolean_search_query(keyterms: Iterable[dict]) -> str:
4646

4747

4848
def _boolify_term_set(term_set: dict) -> str:
49-
if term_set.get("synonyms"):
49+
synonyms = term_set.get("synonyms")
50+
if isinstance(synonyms, str): # corrupted data guard
51+
synonyms = [s.strip() for s in synonyms.split(",") if s.strip()]
52+
if synonyms:
5053
return (
5154
"("
5255
+ " OR ".join(

tests/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from colandr import utils
4+
5+
6+
@pytest.mark.parametrize(
7+
["keyterms", "exp_result"],
8+
[
9+
(
10+
[{"term": "foo", "group": "test", "synonyms": ["bar", "bat"]}],
11+
'("foo" OR "bar" OR "bat")',
12+
),
13+
(
14+
[
15+
{"term": "foo", "group": "test1"},
16+
{"term": "spam", "group": "test2", "synonyms": ["eggs"]},
17+
],
18+
'"foo"\nAND\n("spam" OR "eggs")',
19+
),
20+
(
21+
[
22+
{"term": "foo", "group": "test1"},
23+
{"term": "bar", "group": "test1", "synonyms": ["bat", "baz"]},
24+
{"term": "spam", "group": "test2", "synonyms": ["eggs"]},
25+
],
26+
'("foo" OR ("bar" OR "bat" OR "baz"))\nAND\n("spam" OR "eggs")',
27+
),
28+
],
29+
)
30+
def test_get_boolean_search_query(keyterms, exp_result):
31+
result = utils.get_boolean_search_query(keyterms)
32+
assert result == exp_result

0 commit comments

Comments
 (0)