11#!/usr/bin/env python
22"""NEAR Query"""
33import typing
4+ from typing import cast
5+ from typing import List
6+ from typing import Union
47
8+ from search_query .constants import Operators
59from search_query .query import Query
610from search_query .query import SearchField
11+ from search_query .query_term import Term
712
813
914class NEARQuery (Query ):
@@ -19,7 +24,7 @@ def __init__(
1924 * ,
2025 search_field : typing .Optional [typing .Union [SearchField , str ]] = None ,
2126 position : typing .Optional [typing .Tuple [int , int ]] = None ,
22- distance : typing . Optional [ int ] = None ,
27+ distance : int ,
2328 platform : str = "generic" ,
2429 ) -> None :
2530 """init method
@@ -29,18 +34,41 @@ def __init__(
2934 search field: search field to which the query should be applied
3035 """
3136
37+ query_children = [
38+ c if isinstance (c , Query ) else Term (value = c ) for c in children
39+ ]
40+
3241 super ().__init__ (
3342 value = value ,
34- children = children ,
43+ children = cast ( List [ Union [ str , Query ]], query_children ) ,
3544 search_field = search_field
3645 if isinstance (search_field , SearchField )
3746 else SearchField (search_field )
3847 if search_field is not None
3948 else None ,
4049 position = position ,
41- distance = distance ,
4250 platform = platform ,
4351 )
52+ self .children = query_children
53+ self .distance : int = distance
54+
55+ @property
56+ def distance (self ) -> typing .Optional [int ]:
57+ """Distance property."""
58+ return self ._distance
59+
60+ @distance .setter
61+ def distance (self , dist : typing .Optional [int ]) -> None :
62+ """Set distance property."""
63+
64+ if self .operator and self .value in {Operators .NEAR , Operators .WITHIN }:
65+ if dist is None :
66+ raise ValueError (f"{ self .value } operator requires a distance" )
67+ else :
68+ if dist is not None :
69+ raise ValueError (f"{ self .value } operator cannot have a distance" )
70+
71+ self ._distance = dist
4472
4573 @property
4674 def children (self ) -> typing .List [Query ]:
@@ -51,13 +79,55 @@ def children(self) -> typing.List[Query]:
5179 def children (self , children : typing .List [Query ]) -> None :
5280 """Set the children of NEAR query, updating parent pointers."""
5381 # Clear existing children and reset parent links (if necessary)
82+
5483 self ._children .clear ()
84+
5585 if not isinstance (children , list ):
5686 raise TypeError ("children must be a list of Query instances or strings" )
5787
58- if len (children ) != 2 :
59- raise ValueError ("A NEAR query must have two children" )
88+ if self .platform != "deactivated" : # Note: temporary for EBSCO parser
89+ if len (children ) != 2 :
90+ raise ValueError ("A NEAR query must have two children" )
6091
6192 # Add each new child using add_child (ensures parent is set)
6293 for child in children or []:
6394 self .add_child (child )
95+
96+ def selects_record (self , record_dict : dict ) -> bool :
97+ """Check if the record matches the NEAR query."""
98+ assert len (self .children ) == 2 , "NEAR query must have two children"
99+ assert self .children [0 ].search_field , "First child must have a search field"
100+ assert self .children [1 ].search_field , "Second child must have a search field"
101+ assert self .distance is not None , "NEAR query must have a distance"
102+ assert (
103+ self .children [0 ].search_field .value == self .children [1 ].search_field .value
104+ ), "Both children of NEAR query must have the same search field"
105+
106+ # the self.children[0].value
107+ # must be in self.distance words of self.children[1].value
108+ field = self .children [0 ].search_field .value
109+ text = record_dict .get (field , "" )
110+ if not isinstance (text , str ):
111+ return False
112+
113+ term1 = self .children [0 ].value .lower ()
114+ term2 = self .children [1 ].value .lower ()
115+
116+ tokens = (
117+ text .split ()
118+ ) # Simple whitespace tokenizer; can be replaced with a smarter one
119+ # Get all positions of term1 and term2
120+ positions_term1 = [
121+ i for i , token in enumerate (tokens ) if token .lower () == term1
122+ ]
123+ positions_term2 = [
124+ i for i , token in enumerate (tokens ) if token .lower () == term2
125+ ]
126+
127+ # Check if any pair is within the allowed distance
128+ for p1 in positions_term1 :
129+ for p2 in positions_term2 :
130+ if abs (p1 - p2 ) <= self .distance :
131+ return True
132+
133+ return False
0 commit comments