Skip to content

Commit 1da4817

Browse files
authored
feat(search): Implement GEO index and support for RADIUS search (#5854)
The GEO index uses boost::geometry::index::rtree as the structure to store, search, and remove points. A radius search constructs a polygon of 4 points around the search point (90 degrees apart, each at the search distance) and then a bounding box around that polygon, so that the box covers the whole area of points that could fall within the radius. For each point found in the box we calculate the haversine distance to the search point and keep only the points within the radius. Fixes #4536 Signed-off-by: mkaruza <[email protected]>
1 parent cba733e commit 1da4817

File tree

11 files changed

+275
-19
lines changed

11 files changed

+275
-19
lines changed

src/core/search/ast_expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ AstRangeNode::AstRangeNode(double lo, bool lo_excl, double hi, bool hi_excl)
2020
: lo{lo_excl ? nextafter(lo, hi) : lo}, hi{hi_excl ? nextafter(hi, lo) : hi} {
2121
}
2222

23+
AstGeoNode::AstGeoNode(double lon, double lat, double radius, std::string unit)
24+
: lon(lon), lat(lat), radius(radius), unit(std::move(unit)) {
25+
}
26+
2327
AstNegateNode::AstNegateNode(AstNode&& node) : node{make_unique<AstNode>(std::move(node))} {
2428
}
2529

src/core/search/ast_expr.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ struct AstRangeNode {
4545
double lo, hi;
4646
};
4747

48+
// GEO radius filter: matches points within `radius` (expressed in `unit`)
// of the centre point (lon, lat).
struct AstGeoNode {
  AstGeoNode(double lon, double lat, double radius, std::string unit);

  double lon;        // centre longitude
  double lat;        // centre latitude
  double radius;     // search radius, expressed in `unit`
  std::string unit;  // unit name as produced by the parser (M, KM, MI, FT)
};
54+
4855
// Negates subtree
4956
struct AstNegateNode {
5057
AstNegateNode(AstNode&& node);
@@ -112,7 +119,7 @@ struct AstKnnNode {
112119
using NodeVariants =
113120
std::variant<std::monostate, AstStarNode, AstStarFieldNode, AstTermNode, AstPrefixNode,
114121
AstSuffixNode, AstInfixNode, AstRangeNode, AstNegateNode, AstLogicalNode,
115-
AstFieldNode, AstTagsNode, AstKnnNode>;
122+
AstFieldNode, AstTagsNode, AstKnnNode, AstGeoNode>;
116123

117124
struct AstNode : public NodeVariants {
118125
using variant::variant;

src/core/search/indices.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <absl/strings/str_join.h>
1111
#include <absl/strings/str_split.h>
1212

13+
#include <boost/iterator/function_output_iterator.hpp>
14+
1315
#define UNI_ALGO_DISABLE_NFKC_NFKD
1416

1517
#include <hnswlib/hnswalg.h>
@@ -90,6 +92,42 @@ void IterateAllSuffixes(const absl::flat_hash_set<string>& words,
9092
}
9193
}
9294

95+
// Haversine with earth radius in meters. Used to calculate distance.
96+
boost::geometry::strategy::distance::haversine haversine_(6372797.560856);
97+
98+
double ConvertToRadiusInMeters(size_t radius, std::string_view arg) {
99+
const std::string unit = absl::AsciiStrToUpper(arg);
100+
if (unit == "M") {
101+
return radius * 1;
102+
} else if (unit == "KM") {
103+
return radius * 1000;
104+
} else if (unit == "FT") {
105+
return radius * 0.3048;
106+
} else if (unit == "MI") {
107+
return radius * 1609.34;
108+
} else {
109+
return -1;
110+
}
111+
}
112+
113+
std::optional<GeoIndex::point> GetGeoPoint(const DocumentAccessor& doc, string_view field) {
114+
auto element = doc.GetStrings(field);
115+
116+
if (!element)
117+
return std::nullopt;
118+
119+
absl::InlinedVector<string_view, 2> coordinates = absl::StrSplit(element.value()[0], ",");
120+
121+
if (coordinates.size() != 2)
122+
return std::nullopt;
123+
124+
double lon, lat;
125+
if (!absl::SimpleAtod(coordinates[0], &lon) || !absl::SimpleAtod(coordinates[1], &lat))
126+
return nullopt;
127+
128+
return GeoIndex::point{lon, lat};
129+
}
130+
93131
}; // namespace
94132

95133
class RangeTreeAdapter : public NumericIndex::RangeTreeBase {
@@ -614,4 +652,73 @@ void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view
614652
adapter_->Remove(id);
615653
}
616654

655+
GeoIndex::GeoIndex(PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
656+
}
657+
658+
// Nothing to release by hand: rtree_ is owned by a unique_ptr.
GeoIndex::~GeoIndex() = default;
660+
661+
// Indexes the geo point stored in `field` of `doc` under `id`.
// Returns false, leaving the index untouched, when the field does not hold a
// parseable point.
bool GeoIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
  const auto parsed_point = GetGeoPoint(doc, field);
  if (parsed_point.has_value()) {
    rtree_->insert(index_entry{*parsed_point, id});
    return true;
  }
  return false;
}
669+
670+
// Removes the entry for `id` whose point is read from `field` of `doc`.
//
// When the field does not hold a parseable point, Add() rejected the document
// and nothing was inserted, so there is nothing to remove. Calling value() on
// the empty optional would throw std::bad_optional_access instead.
void GeoIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
  auto doc_point = GetGeoPoint(doc, field);
  if (!doc_point)
    return;

  rtree_->remove(index_entry{*doc_point, id});
}
674+
675+
std::vector<DocId> GeoIndex::RadiusSearch(double lon, double lat, double radius,
676+
std::string_view unit) {
677+
std::vector<DocId> results;
678+
679+
// Get radius in meters
680+
double converted_radius = ConvertToRadiusInMeters(radius, unit);
681+
682+
// Declare the geographic_point_circle strategy with 4 points
683+
boost::geometry::strategy::buffer::geographic_point_circle<> point_strategy(4);
684+
685+
// Declare the distance strategy in meters around the point
686+
boost::geometry::strategy::buffer::distance_symmetric<double> distance_strategy(converted_radius);
687+
688+
// Declare other necessary strategies, unused for point
689+
boost::geometry::strategy::buffer::join_round join_strategy;
690+
boost::geometry::strategy::buffer::end_round end_strategy;
691+
boost::geometry::strategy::buffer::side_straight side_strategy;
692+
693+
point p{lon, lat};
694+
695+
// Create polygon with 4 point around point
696+
boost::geometry::model::multi_polygon<boost::geometry::model::polygon<point>> buffer_polygon;
697+
698+
boost::geometry::buffer(p, buffer_polygon, distance_strategy, side_strategy, join_strategy,
699+
end_strategy, point_strategy);
700+
701+
// Create bouding box around polygon to include all possible points
702+
boost::geometry::model::box<point> box;
703+
boost::geometry::envelope(buffer_polygon, box);
704+
705+
rtree_->query(
706+
boost::geometry::index::within(box),
707+
boost::make_function_output_iterator([&results, &p, &converted_radius](auto const& val) {
708+
if (haversine_.apply(val.first, p) <= converted_radius) {
709+
results.push_back(val.second);
710+
}
711+
}));
712+
713+
// TODO: we should return sorted results by radius distance
714+
return results;
715+
}
716+
717+
std::vector<DocId> GeoIndex::GetAllDocsWithNonNullValues() const {
718+
std::vector<DocId> results;
719+
std::for_each(boost::geometry::index::begin(*rtree_), boost::geometry::index::end(*rtree_),
720+
[&results](auto const& val) { results.push_back(val.second); });
721+
return results;
722+
}
723+
617724
} // namespace dfly::search

src/core/search/indices.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <absl/container/flat_hash_map.h>
99
#include <absl/container/flat_hash_set.h>
1010

11+
#include <boost/geometry.hpp>
1112
#include <map>
1213
#include <memory>
1314
#include <optional>
@@ -205,4 +206,23 @@ struct HnswVectorIndex : public BaseVectorIndex {
205206
std::unique_ptr<HnswlibAdapter> adapter_;
206207
};
207208

209+
struct GeoIndex : public BaseIndex {
210+
using point =
211+
boost::geometry::model::point<double, 2,
212+
boost::geometry::cs::geographic<boost::geometry::degree>>;
213+
using index_entry = std::pair<point, DocId>;
214+
215+
explicit GeoIndex(PMR_NS::memory_resource* mr);
216+
~GeoIndex();
217+
218+
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
219+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
220+
std::vector<DocId> RadiusSearch(double lon, double lat, double radius, std::string_view arg);
221+
std::vector<DocId> GetAllDocsWithNonNullValues() const override;
222+
223+
private:
224+
using rtree = boost::geometry::index::rtree<index_entry, boost::geometry::index::linear<16>>;
225+
std::unique_ptr<rtree> rtree_;
226+
};
227+
208228
} // namespace dfly::search

src/core/search/lexer.lex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ astrsk_ch \*
6868
"KNN" return Parser::make_KNN (loc());
6969
"AS" return Parser::make_AS (loc());
7070
"EF_RUNTIME" return Parser::make_EF_RUNTIME (loc());
71+
"M" return Parser::make_GEOUNIT_M (loc());
72+
"KM" return Parser::make_GEOUNIT_KM (loc());
73+
"MI" return Parser::make_GEOUNIT_MI (loc());
74+
"FT" return Parser::make_GEOUNIT_FT (loc());
7175

7276
[0-9]{1,9} return Parser::make_UINT32(str(), loc());
7377
[+-]?(([0-9]*[.])?[0-9]+|inf) return Parser::make_DOUBLE(str(), loc());

src/core/search/parser.y

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ double toDouble(string_view src);
6161
KNN "KNN"
6262
AS "AS"
6363
EF_RUNTIME "EF_RUNTIME"
64+
GEOUNIT_M "GEOUNIT_M"
65+
GEOUNIT_KM "GEOUNIT_KM"
66+
GEOUNIT_MI "GEOUNIT_MI"
67+
GEOUNIT_FT "GEOUNIT_FT"
6468
;
6569

6670
%token AND_OP
@@ -77,14 +81,13 @@ double toDouble(string_view src);
7781

7882
%token <std::string> DOUBLE "double"
7983
%token <std::string> UINT32 "uint32"
80-
%nterm <double> generic_number
81-
%nterm <bool> opt_lparen
82-
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr numeric_filter_expr
84+
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr bracket_filter_expr
8385
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
8486
%nterm <AstTagsNode::TagValueProxy> tag_list_element
8587

8688
%nterm <AstKnnNode> knn_query
8789
%nterm <std::string> opt_knn_alias
90+
%nterm <std::string> geounit
8891
%nterm <std::optional<size_t>> opt_ef_runtime
8992

9093
%printer { yyo << $$; } <*>;
@@ -149,23 +152,55 @@ field_cond:
149152
| STAR { $$ = AstStarFieldNode(); }
150153
| NOT_OP field_cond { $$ = AstNegateNode(std::move($2)); }
151154
| LPAREN field_cond_expr RPAREN { $$ = std::move($2); }
152-
| LBRACKET numeric_filter_expr RBRACKET { $$ = std::move($2); }
155+
| LBRACKET bracket_filter_expr RBRACKET { $$ = std::move($2); }
153156
| LCURLBR tag_list RCURLBR { $$ = std::move($2); }
154157
| PREFIX { $$ = AstPrefixNode(std::move($1)); }
155158
| SUFFIX { $$ = AstSuffixNode(std::move($1)); }
156159
| INFIX { $$ = AstInfixNode(std::move($1)); }
157160

158-
numeric_filter_expr:
159-
opt_lparen generic_number opt_lparen generic_number { $$ = AstRangeNode($2, $1, $4, $3); }
160-
| opt_lparen generic_number COMMA opt_lparen generic_number { $$ = AstRangeNode($2, $1, $5, $4); }
161-
162-
generic_number:
163-
DOUBLE { $$ = toDouble($1); }
164-
| UINT32 { $$ = toUint32($1); }
165-
166-
opt_lparen:
167-
/* empty */ { $$ = false; }
168-
| LPAREN { $$ = true; }
161+
bracket_filter_expr:
162+
/* Numeric filter has form [(] UINT32|DOUBLE [COMMA] [(] UINT32|DOUBLE */
163+
DOUBLE DOUBLE { $$ = AstRangeNode(toDouble($1), false, toDouble($2), false); }
164+
| LPAREN DOUBLE DOUBLE { $$ = AstRangeNode(toDouble($2), true, toDouble($3), false); }
165+
| DOUBLE LPAREN DOUBLE { $$ = AstRangeNode(toDouble($1), false, toDouble($3), true); }
166+
| LPAREN DOUBLE LPAREN DOUBLE { $$ = AstRangeNode(toDouble($2), true, toDouble($4), true); }
167+
| DOUBLE UINT32 { $$ = AstRangeNode(toDouble($1), false, toUint32($2), false); }
168+
| LPAREN DOUBLE UINT32 { $$ = AstRangeNode(toDouble($2), true, toUint32($3), false); }
169+
| DOUBLE LPAREN UINT32 { $$ = AstRangeNode(toDouble($1), false, toUint32($3), true); }
170+
| LPAREN DOUBLE LPAREN UINT32 { $$ = AstRangeNode(toDouble($2), true, toUint32($4), true); }
171+
| UINT32 DOUBLE { $$ = AstRangeNode(toUint32($1), false, toDouble($2), false); }
172+
| LPAREN UINT32 DOUBLE { $$ = AstRangeNode(toUint32($2), true, toDouble($3), false); }
173+
| UINT32 LPAREN DOUBLE { $$ = AstRangeNode(toUint32($1), false, toDouble($3), true); }
174+
| LPAREN UINT32 LPAREN DOUBLE { $$ = AstRangeNode(toUint32($2), true, toDouble($4), true); }
175+
| UINT32 UINT32 { $$ = AstRangeNode(toUint32($1), false, toUint32($2), false); }
176+
| LPAREN UINT32 UINT32 { $$ = AstRangeNode(toUint32($2), true, toUint32($3), false); }
177+
| UINT32 LPAREN UINT32 { $$ = AstRangeNode(toUint32($1), false, toUint32($3), true); }
178+
| LPAREN UINT32 LPAREN UINT32 { $$ = AstRangeNode(toUint32($2), true, toUint32($4), true); }
179+
| DOUBLE COMMA DOUBLE { $$ = AstRangeNode(toDouble($1), false, toDouble($3), false); }
180+
| DOUBLE COMMA UINT32 { $$ = AstRangeNode(toDouble($1), false, toUint32($3), false); }
181+
| UINT32 COMMA DOUBLE { $$ = AstRangeNode(toUint32($1), false, toDouble($3), false); }
182+
| UINT32 COMMA UINT32 { $$ = AstRangeNode(toUint32($1), false, toUint32($3), false); }
183+
| LPAREN DOUBLE COMMA DOUBLE { $$ = AstRangeNode(toDouble($2), true, toDouble($4), false); }
184+
| DOUBLE COMMA LPAREN DOUBLE { $$ = AstRangeNode(toDouble($1), false, toDouble($4), true); }
185+
| LPAREN DOUBLE COMMA LPAREN DOUBLE { $$ = AstRangeNode(toDouble($2), true, toDouble($5), true); }
186+
| LPAREN DOUBLE COMMA UINT32 { $$ = AstRangeNode(toDouble($2), true, toUint32($4), false); }
187+
| DOUBLE COMMA LPAREN UINT32 { $$ = AstRangeNode(toDouble($1), false, toUint32($4), true); }
188+
| LPAREN DOUBLE COMMA LPAREN UINT32 { $$ = AstRangeNode(toDouble($2), true, toUint32($5), true); }
189+
| LPAREN UINT32 COMMA DOUBLE { $$ = AstRangeNode(toUint32($2), true, toDouble($4), false); }
190+
| UINT32 COMMA LPAREN DOUBLE { $$ = AstRangeNode(toUint32($1), false, toDouble($4), true); }
191+
| LPAREN UINT32 COMMA LPAREN DOUBLE { $$ = AstRangeNode(toUint32($2), true, toDouble($5), true); }
192+
| LPAREN UINT32 COMMA UINT32 { $$ = AstRangeNode(toUint32($2), true, toUint32($4), false); }
193+
| UINT32 COMMA LPAREN UINT32 { $$ = AstRangeNode(toUint32($1), false, toUint32($4), true); }
194+
| LPAREN UINT32 COMMA LPAREN UINT32 { $$ = AstRangeNode(toUint32($2), true, toUint32($5), true); }
195+
/* GEO filter */
196+
| DOUBLE DOUBLE UINT32 geounit { $$ = AstGeoNode(toDouble($1), toDouble($2), toUint32($3), std::move($4)); }
197+
| DOUBLE DOUBLE DOUBLE geounit { $$ = AstGeoNode(toDouble($1), toDouble($2), toDouble($3), std::move($4)); }
198+
199+
geounit:
200+
GEOUNIT_M { $$ = "M"; }
201+
| GEOUNIT_KM { $$ = "KM"; }
202+
| GEOUNIT_MI { $$ = "MI"; }
203+
| GEOUNIT_FT { $$ = "FT"; }
169204

170205
field_cond_expr:
171206
field_unary_expr { $$ = std::move($1); }

src/core/search/search.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ struct ProfileBuilder {
7979
[](const AstNegateNode& n) { return absl::StrCat("Negate{}"); },
8080
[](const AstStarNode& n) { return absl::StrCat("Star{}"); },
8181
[](const AstStarFieldNode& n) { return absl::StrCat("StarField{}"); },
82+
[](const AstGeoNode& n) {
83+
return absl::StrCat("Geo{", n.lat, " ", n.lon, " ", n.radius, " ", n.unit, "}");
84+
},
8285
};
8386
return visit(node_info, node.Variant());
8487
}
@@ -276,6 +279,14 @@ struct BasicSearch {
276279
return IndexResult{};
277280
}
278281

282+
// geo radius filter: delegate to the GEO index of the active field; a field
// without a GEO index produces an empty result.
IndexResult Search(const AstGeoNode& node, string_view active_field) {
  DCHECK(!active_field.empty());

  auto* geo_index = GetIndex<GeoIndex>(active_field);
  if (!geo_index)
    return IndexResult{};

  return IndexResult{geo_index->RadiusSearch(node.lon, node.lat, node.radius, node.unit)};
}
289+
279290
// negate -(*subquery*): explicitly compute result complement. Needs further optimizations
280291
IndexResult Search(const AstNegateNode& node, string_view active_field) {
281292
vector<DocId> matched = SearchGeneric(*node.node, active_field).Take();
@@ -396,6 +407,7 @@ struct BasicSearch {
396407
// used by knn
397408

398409
DCHECK(top_level || holds_alternative<AstKnnNode>(node.Variant()) ||
410+
holds_alternative<AstGeoNode>(node.Variant()) ||
399411
visit([](auto* set) { return is_sorted(set->begin(), set->end()); }, result.Borrowed()));
400412

401413
if (profile_builder_)
@@ -500,6 +512,10 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) {
500512
indices_[field_ident] = std::move(vector_index);
501513
break;
502514
}
515+
case SchemaField::GEO: {
516+
indices_[field_ident] = make_unique<GeoIndex>(mr);
517+
break;
518+
}
503519
}
504520
}
505521
}
@@ -518,6 +534,7 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) {
518534
sort_indices_[field_ident] = make_unique<NumericSortIndex>(mr);
519535
break;
520536
case SchemaField::VECTOR:
537+
case SchemaField::GEO:
521538
break;
522539
}
523540
}

src/core/search/search.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct OptionalNumericFilter : public OptionalFilterBase {
5555

5656
// Describes a specific index field
5757
struct SchemaField {
58-
enum FieldType { TAG, TEXT, NUMERIC, VECTOR };
58+
enum FieldType { TAG, TEXT, NUMERIC, VECTOR, GEO };
5959
enum FieldFlags : uint8_t { NOINDEX = 1 << 0, SORTABLE = 1 << 1 };
6060

6161
struct VectorParams {

0 commit comments

Comments
 (0)