Skip to content

Commit 335a3b1

Browse files
revise base Query class (#41)
1 parent 315423f commit 335a3b1

File tree

13 files changed

+166
-118
lines changed

13 files changed

+166
-118
lines changed

search_query/ebsco/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def parse_query_tree(
256256
parent, current_operator = self.append_operator(parent, proximity_node)
257257

258258
elif token.type == TokenTypes.LOGIC_OPERATOR:
259-
new_operator_node = Query(
259+
new_operator_node = Query.create(
260260
value=token.value.upper(),
261261
position=token.position,
262262
search_field=search_field or field_context,

search_query/pubmed/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _parse_compound_query(self, tokens: list) -> Query:
186186
query_start_pos = tokens[0].position[0]
187187
query_end_pos = tokens[-1].position[1]
188188

189-
return Query(
189+
return Query.create(
190190
value=operator_type,
191191
search_field=None,
192192
children=list(children),
@@ -313,7 +313,7 @@ def _parse_operator_node(self, token_nr: int) -> Query:
313313
if operator.upper() in {"|", "OR"}:
314314
operator = Operators.OR
315315

316-
return Query(
316+
return Query.create(
317317
value=operator,
318318
search_field=None,
319319
children=children,

search_query/pubmed/translator.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
"""Pubmed query translator."""
33
from search_query.constants import Fields
44
from search_query.constants import Operators
5+
from search_query.constants import PLATFORM
56
from search_query.pubmed.constants import generic_search_field_to_syntax_field
67
from search_query.pubmed.constants import syntax_str_to_generic_search_field_set
78
from search_query.query import Query
89
from search_query.query import SearchField
10+
from search_query.query_or import OrQuery
11+
from search_query.query_term import Term
912
from search_query.translator_base import QueryTranslator
1013

1114

@@ -35,11 +38,9 @@ def _expand_flat_or_chains(cls, query: "Query") -> bool:
3538
if not child.search_field: # pragma: no cover
3639
continue
3740
child.search_field.value = Fields.TITLE
38-
new_child = Query(
41+
new_child = Term(
3942
value=child.value,
40-
operator=False,
4143
search_field=SearchField(value=Fields.ABSTRACT),
42-
children=None,
4344
)
4445
query.add_child(new_child)
4546

@@ -146,18 +147,16 @@ def _expand_combined_fields(cls, query: Query, search_fields: set) -> None:
146147
# Note: sorted list for deterministic order of fields
147148
for search_field in sorted(list(search_fields)):
148149
query_children.append(
149-
Query(
150+
Term(
150151
value=query.value,
151-
operator=False,
152152
search_field=SearchField(value=search_field),
153-
children=None,
154153
)
155154
)
156-
157-
query.value = Operators.OR
158-
query.operator = True
159-
query.search_field = None
160-
query.children = query_children # type: ignore
155+
query.replace(
156+
OrQuery(
157+
children=query_children,
158+
)
159+
)
161160

162161
@classmethod
163162
def to_generic_syntax(cls, query: "Query") -> "Query":

search_query/query.py

Lines changed: 81 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ def __init__(
3131
position: typing.Optional[typing.Tuple[int, int]] = None,
3232
platform: str = "generic",
3333
) -> None:
34+
if type(self) is Query:
35+
raise TypeError(
36+
"The base Query type cannot be instantiated directly. "
37+
"Use Query.create() or the appropriate query subclass."
38+
)
3439
self._value: str = ""
35-
self._operator = False
40+
self._operator = operator
3641
self._children: typing.List[Query] = []
3742
self._search_field = None
3843

39-
self.operator = operator
4044
self.value = value
4145
if isinstance(search_field, str):
4246
self.search_field = SearchField(search_field)
@@ -62,6 +66,58 @@ def __init__(
6266
# when queries are created programmatically
6367
self._validate_platform_constraints()
6468

69+
@classmethod
70+
def create(
71+
cls,
72+
value: str,
73+
*,
74+
operator: bool = True,
75+
search_field: typing.Optional[SearchField] = None,
76+
children: typing.Optional[typing.List[typing.Union[str, Query]]] = None,
77+
position: typing.Optional[typing.Tuple[int, int]] = None,
78+
platform: str = "generic",
79+
distance: int = 0
80+
) -> Query:
81+
"""Factory method for query creation."""
82+
if not operator:
83+
from search_query.query_term import Term
84+
return Term(
85+
value=value,
86+
search_field=search_field,
87+
position=position,
88+
platform=platform
89+
)
90+
91+
args = {
92+
"search_field": search_field,
93+
"children": children,
94+
"position": position,
95+
"platform": platform
96+
}
97+
98+
if value == Operators.AND:
99+
from search_query.query_and import AndQuery
100+
return AndQuery(**args)
101+
102+
elif value == Operators.OR:
103+
from search_query.query_or import OrQuery
104+
return OrQuery(**args)
105+
106+
elif value == Operators.NOT:
107+
from search_query.query_not import NotQuery
108+
return NotQuery(**args)
109+
110+
elif value in {Operators.NEAR, Operators.WITHIN}:
111+
from search_query.query_near import NEARQuery
112+
return NEARQuery(value=value, distance=distance, **args)
113+
114+
elif value == Operators.RANGE:
115+
from search_query.query_range import RangeQuery
116+
return RangeQuery(**args)
117+
118+
else:
119+
raise ValueError(f"Invalid operator value: {value}")
120+
65121
def _validate_platform_constraints(self) -> None:
66122
if self.platform == "deactivated":
67123
return
@@ -170,15 +226,18 @@ def value(self, v: str) -> None:
170226
"""Set value property."""
171227
if not isinstance(v, str):
172228
raise TypeError("value must be a string")
173-
if self.operator and v not in [
174-
Operators.AND,
175-
Operators.OR,
176-
Operators.NOT,
177-
Operators.NEAR,
178-
Operators.WITHIN,
179-
Operators.RANGE,
180-
]:
181-
raise ValueError(f"Invalid operator value: {v}")
229+
if self.operator:
230+
if self._value:
231+
raise AttributeError("operator value can only be set once")
232+
if v not in [
233+
Operators.AND,
234+
Operators.OR,
235+
Operators.NOT,
236+
Operators.NEAR,
237+
Operators.WITHIN,
238+
Operators.RANGE,
239+
]:
240+
raise ValueError(f"Invalid operator value: {v}")
182241
self._value = v
183242

184243
@property
@@ -191,6 +250,8 @@ def operator(self, is_op: bool) -> None:
191250
"""Set operator property."""
192251
if not isinstance(is_op, bool):
193252
raise TypeError("operator must be a boolean")
253+
if is_op != self._operator:
254+
raise AttributeError("operator property can only be set once")
194255
self._operator = is_op
195256

196257
@property
@@ -252,6 +313,15 @@ def search_field(self, sf: typing.Optional[SearchField]) -> None:
252313
"""Set search field property."""
253314
self._search_field = copy.deepcopy(sf) if sf else None
254315

316+
def replace(self, new_query) -> None:
317+
if self.get_parent():
318+
for index, child in enumerate(self.get_parent().children):
319+
if child is self:
320+
self.get_parent().children[index] = new_query
321+
return
322+
else:
323+
raise RuntimeError("Root node of a query cannot be replaced")
324+
255325
def selects(self, *, record_dict: dict) -> bool:
256326
"""Indicates whether the query selects a given record."""
257327
# pylint: disable=import-outside-toplevel

search_query/query_not.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Union
77

88
from search_query.constants import Operators
9+
from search_query.constants import PLATFORM
910
from search_query.query import Query
1011
from search_query.query import SearchField
1112
from search_query.query_term import Term
@@ -32,7 +33,7 @@ def __init__(
3233

3334
query_children = [
3435
c if isinstance(c, Query) else Term(value=c) for c in children
35-
]
36+
] if children else None
3637

3738
super().__init__(
3839
value=Operators.NOT,
@@ -58,10 +59,10 @@ def children(self, children: typing.List[Query]) -> None:
5859
"""Set the children of NOT query, updating parent pointers."""
5960
# Clear existing children and reset parent links (if necessary)
6061
self._children.clear()
61-
if not isinstance(children, list):
62+
if self.platform != "deactivated" and not isinstance(children, list):
6263
raise TypeError("children must be a list of Query instances or strings")
6364

64-
if len(children) != 2:
65+
if self.platform not in {"deactivated", PLATFORM.WOS} and len(children) != 2:
6566
raise ValueError("A NOT query must have two children")
6667

6768
# Add each new child using add_child (ensures parent is set)

search_query/wos/parser.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from search_query.query import SearchField
1919
from search_query.query_near import NEARQuery
2020
from search_query.query_term import Term
21+
from search_query.query_not import NotQuery
2122
from search_query.wos.constants import search_field_general_to_syntax
2223
from search_query.wos.linter import WOSQueryListLinter
2324
from search_query.wos.linter import WOSQueryStringLinter
@@ -133,9 +134,7 @@ def parse_query_tree(
133134

134135
if current_negation:
135136
# If the current operator is NOT, wrap the sub_query in a NOT
136-
not_part = Query(
137-
value="NOT",
138-
operator=True,
137+
not_part = NotQuery(
139138
children=[sub_query],
140139
search_field=search_field,
141140
platform="deactivated",
@@ -181,9 +180,7 @@ def parse_query_tree(
181180
# Handle terms
182181
elif token.type == TokenTypes.SEARCH_TERM:
183182
if current_negation:
184-
not_part = Query(
185-
value="NOT",
186-
operator=True,
183+
not_part = NotQuery(
187184
children=[
188185
Term(
189186
value=token.value,
@@ -219,7 +216,7 @@ def parse_query_tree(
219216

220217
# Return the operator and children if there is an operator
221218
return (
222-
Query(
219+
Query.create(
223220
value=current_operator,
224221
children=list(children),
225222
search_field=search_field,
@@ -257,7 +254,7 @@ def _handle_closing_parenthesis(
257254
platform="deactivated",
258255
distance=distance,
259256
)
260-
return Query(
257+
return Query.create(
261258
value=current_operator,
262259
children=children,
263260
search_field=search_field,
@@ -398,7 +395,7 @@ def _build_query_from_operator_node(self, tokens: list) -> Query:
398395

399396
assert operator, "[ERROR] No operator found in combining query."
400397

401-
operator_query = Query(
398+
operator_query = Query.create(
402399
value=operator,
403400
children=children,
404401
platform="deactivated",

search_query/wos/translator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from search_query.constants import Operators
77
from search_query.query import Query
88
from search_query.query import SearchField
9+
from search_query.query_or import OrQuery
10+
from search_query.query_term import Term
911
from search_query.translator_base import QueryTranslator
1012
from search_query.wos.constants import generic_search_field_to_syntax_field
1113
from search_query.wos.constants import syntax_str_to_generic_search_field_set
@@ -44,18 +46,17 @@ def _expand_combined_fields(cls, query: Query, search_fields: set) -> None:
4446
# Note: sorted list for deterministic order of fields
4547
for search_field in sorted(list(search_fields)):
4648
query_children.append(
47-
Query(
49+
Term(
4850
value=query.value,
49-
operator=False,
5051
search_field=SearchField(value=search_field),
51-
children=None,
5252
)
5353
)
5454

55-
query.value = Operators.OR
56-
query.operator = True
57-
query.search_field = None
58-
query.children = query_children # type: ignore
55+
query.replace(
56+
OrQuery(
57+
children=query_children,
58+
)
59+
)
5960

6061
@classmethod
6162
def combine_equal_search_fields(cls, query: Query) -> None:

0 commit comments

Comments
 (0)