Skip to content

Commit f197e45

Browse files
author
Gerit Wagner
committed
query: move selects logic to subclasses
1 parent ecd7ebc commit f197e45

File tree

12 files changed

+245
-75
lines changed

12 files changed

+245
-75
lines changed

search_query/ebsco/parser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from search_query.parser_base import QueryStringParser
1414
from search_query.query import Query
1515
from search_query.query import SearchField
16+
from search_query.query_near import NEARQuery
1617
from search_query.query_term import Term
1718

1819

@@ -244,11 +245,12 @@ def parse_query_tree(
244245

245246
elif token.type == TokenTypes.PROXIMITY_OPERATOR:
246247
distance = self._extract_proximity_distance(token)
247-
proximity_node = Query(
248+
proximity_node = NEARQuery(
248249
value=token.value,
250+
distance=distance,
251+
children=[],
249252
position=token.position,
250253
search_field=search_field or field_context,
251-
distance=distance,
252254
platform="deactivated",
253255
)
254256
parent, current_operator = self.append_operator(parent, proximity_node)

search_query/ebsco/serializer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
def to_string_ebsco(query: Query) -> str:
1515
"""Convert the query to a string representation for EBSCO."""
16+
# pylint: disable=import-outside-toplevel
17+
from search_query.query_near import NEARQuery
1618

1719
query = query.copy()
1820

@@ -27,7 +29,7 @@ def to_string_ebsco(query: Query) -> str:
2729
for i, child in enumerate(query.children):
2830
child_str = to_string_ebsco(child)
2931

30-
if query.value in {"NEAR", "WITHIN"}:
32+
if isinstance(query, NEARQuery):
3133
# Convert proximity operator to EBSCO format
3234
proximity_operator = (
3335
f"{'N' if query.value == 'NEAR' else 'W'}{query.distance}"

search_query/query.py

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
from __future__ import annotations
44

55
import copy
6-
import re
76
import typing
87

9-
from search_query.constants import Fields
108
from search_query.constants import Operators
119
from search_query.constants import PLATFORM
1210
from search_query.constants import SearchField
@@ -31,18 +29,15 @@ def __init__(
3129
search_field: typing.Optional[SearchField] = None,
3230
children: typing.Optional[typing.List[typing.Union[str, Query]]] = None,
3331
position: typing.Optional[typing.Tuple[int, int]] = None,
34-
distance: typing.Optional[int] = None,
3532
platform: str = "generic",
3633
) -> None:
3734
self._value: str = ""
3835
self._operator = False
39-
self._distance = None
4036
self._children: typing.List[Query] = []
4137
self._search_field = None
4238

4339
self.operator = operator
4440
self.value = value
45-
self.distance = distance
4641
if isinstance(search_field, str):
4742
self.search_field = SearchField(search_field)
4843
else:
@@ -198,24 +193,6 @@ def operator(self, is_op: bool) -> None:
198193
raise TypeError("operator must be a boolean")
199194
self._operator = is_op
200195

201-
@property
202-
def distance(self) -> typing.Optional[int]:
203-
"""Distance property."""
204-
return self._distance
205-
206-
@distance.setter
207-
def distance(self, dist: typing.Optional[int]) -> None:
208-
"""Set distance property."""
209-
210-
if self.operator and self.value in {Operators.NEAR, Operators.WITHIN}:
211-
if dist is None:
212-
raise ValueError(f"{self.value} operator requires a distance")
213-
else:
214-
if dist is not None:
215-
raise ValueError(f"{self.value} operator cannot have a distance")
216-
217-
self._distance = dist
218-
219196
@property
220197
def children(self) -> typing.List[Query]:
221198
"""Children property."""
@@ -284,38 +261,13 @@ def selects(self, *, record_dict: dict) -> bool:
284261
QueryTranslator.move_fields_to_terms(query_with_term_fields)
285262

286263
# pylint: disable=protected-access
287-
return query_with_term_fields._selects(record_dict=record_dict)
288-
289-
def _selects(self, record_dict: dict) -> bool:
290-
if self.value == Operators.NOT:
291-
return not self.children[0].selects(record_dict=record_dict)
292-
293-
if self.value == Operators.AND:
294-
return all(x.selects(record_dict=record_dict) for x in self.children)
264+
return query_with_term_fields.selects_record(record_dict=record_dict)
295265

296-
if self.value == Operators.OR:
297-
return any(x.selects(record_dict=record_dict) for x in self.children)
298-
299-
assert not self.operator
300-
301-
assert self.search_field is not None, "Search field must be set for terms"
302-
if self.search_field.value == Fields.TITLE:
303-
field_value = record_dict.get("title", "").lower()
304-
elif self.search_field.value == Fields.ABSTRACT:
305-
field_value = record_dict.get("abstract", "").lower()
306-
else:
307-
raise ValueError(f"Unsupported search field: {self.search_field}")
308-
309-
value = self.value.lower().lstrip('"').rstrip('"')
310-
311-
# Handle wildcards
312-
if "*" in value:
313-
pattern = re.compile(value.replace("*", ".*").lower())
314-
match = pattern.search(field_value)
315-
return match is not None
316-
317-
# Match exact word
318-
return value.lower() in field_value
266+
def selects_record(self, record_dict: dict) -> bool:
    """Decide whether this query node selects the given record.

    This base implementation is only a placeholder: it never returns and
    always raises, so each concrete query node type has to supply its own
    version.

    Args:
        record_dict: The record to evaluate (unused here).

    Raises:
        NotImplementedError: Always.
    """
    message = "This method should be implemented by subclasses of Query"
    raise NotImplementedError(message)
319271

320272
def _get_confusion_matrix(self, records_dict: dict) -> dict:
321273
relevant_ids = set()

search_query/query_and.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,6 @@ def children(self, children: typing.List[Query]) -> None:
5757
# Add each new child using add_child (ensures parent is set)
5858
for child in children or []:
5959
self.add_child(child)
60+
61+
def selects_record(self, record_dict: dict) -> bool:
    """Return True when every child query selects the record.

    Children are evaluated in order through their ``selects`` method and
    evaluation stops at the first child that rejects the record. A node
    without children selects every record.

    Args:
        record_dict: The record to evaluate.
    """
    for child in self.children:
        if not child.selects(record_dict=record_dict):
            return False
    return True

search_query/query_near.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
#!/usr/bin/env python
22
"""NEAR Query"""
33
import typing
4+
from typing import cast
5+
from typing import List
6+
from typing import Union
47

8+
from search_query.constants import Operators
59
from search_query.query import Query
610
from search_query.query import SearchField
11+
from search_query.query_term import Term
712

813

914
class NEARQuery(Query):
@@ -19,7 +24,7 @@ def __init__(
1924
*,
2025
search_field: typing.Optional[typing.Union[SearchField, str]] = None,
2126
position: typing.Optional[typing.Tuple[int, int]] = None,
22-
distance: typing.Optional[int] = None,
27+
distance: int,
2328
platform: str = "generic",
2429
) -> None:
2530
"""init method
@@ -29,18 +34,41 @@ def __init__(
2934
search field: search field to which the query should be applied
3035
"""
3136

37+
query_children = [
38+
c if isinstance(c, Query) else Term(value=c) for c in children
39+
]
40+
3241
super().__init__(
3342
value=value,
34-
children=children,
43+
children=cast(List[Union[str, Query]], query_children),
3544
search_field=search_field
3645
if isinstance(search_field, SearchField)
3746
else SearchField(search_field)
3847
if search_field is not None
3948
else None,
4049
position=position,
41-
distance=distance,
4250
platform=platform,
4351
)
52+
self.children = query_children
53+
self.distance: int = distance
54+
55+
@property
def distance(self) -> typing.Optional[int]:
    """Return the stored proximity distance (``None`` if unset)."""
    return self._distance

@distance.setter
def distance(self, dist: typing.Optional[int]) -> None:
    """Validate and store the proximity distance.

    A distance is mandatory when this node is an operator whose value is
    NEAR or WITHIN, and forbidden in every other case.

    Args:
        dist: The distance to store, or ``None``.

    Raises:
        ValueError: If a NEAR/WITHIN operator gets no distance, or any
            other node gets one.
    """
    takes_distance = self.operator and self.value in {
        Operators.NEAR,
        Operators.WITHIN,
    }

    # Guard clauses: reject the two inconsistent combinations up front.
    if takes_distance and dist is None:
        raise ValueError(f"{self.value} operator requires a distance")
    if not takes_distance and dist is not None:
        raise ValueError(f"{self.value} operator cannot have a distance")

    self._distance = dist
4472

4573
@property
4674
def children(self) -> typing.List[Query]:
@@ -51,13 +79,55 @@ def children(self) -> typing.List[Query]:
5179
def children(self, children: typing.List[Query]) -> None:
5280
"""Set the children of NEAR query, updating parent pointers."""
5381
# Clear existing children and reset parent links (if necessary)
82+
5483
self._children.clear()
84+
5585
if not isinstance(children, list):
5686
raise TypeError("children must be a list of Query instances or strings")
5787

58-
if len(children) != 2:
59-
raise ValueError("A NEAR query must have two children")
88+
if self.platform != "deactivated": # Note: temporary for EBSCO parser
89+
if len(children) != 2:
90+
raise ValueError("A NEAR query must have two children")
6091

6192
# Add each new child using add_child (ensures parent is set)
6293
for child in children or []:
6394
self.add_child(child)
95+
96+
def selects_record(self, record_dict: dict) -> bool:
    """Check if the record matches the NEAR query.

    The record matches when an occurrence of the first child's term and an
    occurrence of the second child's term are at most ``self.distance``
    token positions apart (in either order) in the record field named by
    the children's shared search field.

    Matching is case-insensitive. Surrounding double quotes are removed
    from the terms, and punctuation attached to a token in the text
    (e.g. ``"work."``) is ignored, so that ``work`` still matches.

    Args:
        record_dict: The record to evaluate; the text is looked up under
            the key given by the children's search field value.

    Returns:
        True if a pair of occurrences is close enough, False otherwise
        (including when the field is missing or not a string).
    """
    # pylint: disable=import-outside-toplevel
    import string

    assert len(self.children) == 2, "NEAR query must have two children"
    first = self.children[0]
    second = self.children[1]
    assert first.search_field, "First child must have a search field"
    assert second.search_field, "Second child must have a search field"
    assert self.distance is not None, "NEAR query must have a distance"
    assert (
        first.search_field.value == second.search_field.value
    ), "Both children of NEAR query must have the same search field"

    text = record_dict.get(first.search_field.value, "")
    if not isinstance(text, str):
        return False

    # Simple whitespace tokenizer; can be replaced with a smarter one.
    raw_tokens = text.lower().split()
    # Punctuation glued to a word must not prevent a match.
    bare_tokens = [token.strip(string.punctuation) for token in raw_tokens]

    def _positions(term_value: str) -> list:
        """Return the token indexes at which the term occurs."""
        # Quotes are query syntax, not part of the searched word.
        term = term_value.lower().strip('"')
        return [
            index
            for index, (raw, bare) in enumerate(zip(raw_tokens, bare_tokens))
            # Compare the raw token too, so terms that themselves end in
            # punctuation (e.g. "c++") keep matching.
            if term in (raw, bare)
        ]

    positions_term1 = _positions(first.value)
    positions_term2 = _positions(second.value)

    # Check if any pair is within the allowed distance
    return any(
        abs(p1 - p2) <= self.distance
        for p1 in positions_term1
        for p2 in positions_term2
    )

search_query/query_not.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,8 @@ def children(self, children: typing.List[Query]) -> None:
5757
# Add each new child using add_child (ensures parent is set)
5858
for child in children or []:
5959
self.add_child(child)
60+
61+
def selects_record(self, record_dict: dict) -> bool:
    """Check if the record matches the NOT query.

    With two children the node is evaluated as "first AND NOT second":
    the record must be selected by the first child and rejected by the
    second. With a single child the node is evaluated as a plain negation
    of that child, instead of failing with an IndexError.

    Args:
        record_dict: The record to evaluate.

    Returns:
        True if the record is selected by this NOT node.
    """
    children = self.children

    # Unary form: NOT <child>.
    if len(children) == 1:
        return not children[0].selects(record_dict=record_dict)

    # Binary form: <included> NOT <excluded>.
    included = children[0]
    excluded = children[1]
    return included.selects(record_dict=record_dict) and not excluded.selects(
        record_dict=record_dict
    )

search_query/query_or.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,6 @@ def children(self, children: typing.List[Query]) -> None:
5858
# Add each new child using add_child (ensures parent is set)
5959
for child in children or []:
6060
self.add_child(child)
61+
62+
def selects_record(self, record_dict: dict) -> bool:
    """Return True when at least one child query selects the record.

    Children are evaluated in order through their ``selects`` method and
    evaluation stops at the first child that accepts the record. A node
    without children selects nothing.

    Args:
        record_dict: The record to evaluate.
    """
    for child in self.children:
        if child.selects(record_dict=record_dict):
            return True
    return False

search_query/query_range.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,28 @@ def children(self, children: typing.List[Query]) -> None:
5757
# Add each new child using add_child (ensures parent is set)
5858
for child in children or []:
5959
self.add_child(child)
60+
61+
def selects_record(self, record_dict: dict) -> bool:
    """Check if the record matches the range query.

    The two children hold the inclusive lower and upper bound. The record
    value is read from the key named by the children's shared search field,
    falling back to the ``"year"`` key.

    The record value may be stored as a string or as a number. A record
    whose value is missing or not numeric is simply not selected; only
    non-numeric *bounds* are reported as an error.

    Args:
        record_dict: The record to evaluate.

    Returns:
        True if ``lower <= record value <= upper``.

    Raises:
        ValueError: If either bound is not a numeric value.
    """
    assert len(self.children) == 2, "RANGE query must have two children"
    lower = self.children[0]
    upper = self.children[1]
    assert lower.search_field, "First child must have a search field"
    assert upper.search_field, "Second child must have a search field"
    assert (
        lower.search_field.value == upper.search_field.value
    ), "Both children of RANGE query must have the same search field"

    term1 = str(lower.value).strip()
    term2 = str(upper.value).strip()
    if not (term1.isdigit() and term2.isdigit()):
        # Match other cases here (e.g., dates)
        raise ValueError("Both children of RANGE query must be numeric values")

    record_field = record_dict.get(
        lower.search_field.value, record_dict.get("year", "")
    )
    # Records may store the value as int (e.g. year=2015); normalize to text
    # before the digit check instead of assuming a str.
    record_text = str(record_field).strip()
    if not record_text.isdigit():
        return False

    return int(term1) <= int(record_text) <= int(term2)

search_query/query_term.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
"""Query class."""
33
from __future__ import annotations
44

5+
import re
56
import typing
67

8+
from search_query.constants import Fields
79
from search_query.constants import SearchField
810
from search_query.query import Query
911

@@ -27,3 +29,23 @@ def __init__(
2729
position=position,
2830
platform=platform,
2931
)
32+
33+
def selects_record(self, record_dict: dict) -> bool:
    """Check if the record matches this term.

    The term is looked up case-insensitively in the record's ``"title"`` or
    ``"abstract"`` entry, depending on the term's search field. Surrounding
    double quotes are removed from the term. A ``*`` in the term acts as a
    wildcard for any run of characters; every other character is matched
    literally.

    Args:
        record_dict: The record to evaluate.

    Returns:
        True if the term occurs in the field, False otherwise (including
        when the field is missing or not a string).

    Raises:
        ValueError: If the search field is neither title nor abstract.
    """
    assert self.search_field is not None, "Search field must be set for terms"
    if self.search_field.value == Fields.TITLE:
        field_value = record_dict.get("title", "")
    elif self.search_field.value == Fields.ABSTRACT:
        field_value = record_dict.get("abstract", "")
    else:
        raise ValueError(f"Unsupported search field: {self.search_field}")

    # A record may carry None (or another non-text value) for the field.
    if not isinstance(field_value, str):
        return False
    field_value = field_value.lower()

    value = self.value.lower().strip('"')

    # Handle wildcards
    if "*" in value:
        # Escape first so characters such as "+", "(" or "." in the term are
        # matched literally; only the (escaped) "*" becomes a wildcard.
        pattern = re.escape(value).replace(r"\*", ".*")
        return re.search(pattern, field_value) is not None

    # Plain substring match (not restricted to word boundaries)
    return value in field_value

search_query/wos/parser.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from search_query.parser_base import QueryStringParser
1717
from search_query.query import Query
1818
from search_query.query import SearchField
19+
from search_query.query_near import NEARQuery
1920
from search_query.query_term import Term
2021
from search_query.wos.constants import search_field_general_to_syntax
2122
from search_query.wos.linter import WOSQueryListLinter
@@ -241,12 +242,26 @@ def _handle_closing_parenthesis(
241242

242243
# Return the operator and children if there is an operator
243244
if current_operator:
245+
if distance:
246+
# If there is a distance, it must be a proximity operator
247+
if current_operator not in {"NEAR", "WITHIN"}:
248+
raise ValueError(
249+
f"Distance {distance} "
250+
"is only allowed for NEAR or WITHIN operators, "
251+
f"not {current_operator}"
252+
)
253+
return NEARQuery(
254+
value=current_operator,
255+
children=children,
256+
search_field=search_field,
257+
platform="deactivated",
258+
distance=distance,
259+
)
244260
return Query(
245261
value=current_operator,
246262
children=children,
247263
search_field=search_field,
248264
platform="deactivated",
249-
distance=distance,
250265
)
251266

252267
# Multiple children without operator are not allowed

0 commit comments

Comments
 (0)