Skip to content

Commit 6ee63c9

Browse files
authored
FTS5: Add bm25 function (#147)
* Add type-safe bm25 function * wip * wip * wip * Fix docs
1 parent ddaa35e commit 6ee63c9

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

Sources/StructuredQueriesCore/SQLite/FTS5.swift

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,50 @@ import IssueReporting
88
public protocol FTS5: Table {}
99

1010
extension TableDefinition where QueryValue: FTS5 {
11+
/// A BM25 ranking function weighted by the given column-to-weight mapping.
///
/// ```swift
/// ReminderText
///   .where { $0.match("Week") }
///   .order { $0.bm25([\.title: 10, \.notes: 5, \.tags: 2]) }
/// ```
///
/// When `rankings` is empty only the table is passed to `bm25`. Otherwise one weight argument is
/// generated per writable column, and any column not mentioned in `rankings` gets a weight of 1.
///
/// - Parameter rankings: Pairs of column key paths and the weight to give a match in that
///   column. A key path that does not resolve to a writable table column is reported as an issue
///   and skipped.
/// - Returns: A BM25 ranking function.
public func bm25(
  _ rankings: KeyValuePairs<PartialKeyPath<Self>, Double> = [:]
) -> some QueryExpression<Double> {
  // The first argument to 'bm25' is always the table itself.
  var queryFragments: [QueryFragment] = ["\(QueryValue.self)"]
  if !rankings.isEmpty {
    // A single 'CASE' expression that translates a column name into its weight.
    var columnNameToRanking: QueryFragment = """
      CASE "name"
      """
    for (keyPath, ranking) in rankings {
      guard let column = self[keyPath: keyPath] as? any WritableTableColumnExpression
      else {
        reportIssue(
          """
          Key path cannot be used in 'bm25' function: \(keyPath)

          Must be a key path to a table column on '\(QueryValue.self)'.
          """
        )
        continue
      }
      // Single-line literal on purpose: the space separating this arm from the preceding fragment
      // is explicit rather than depending on how a multi-line literal happens to be indented.
      columnNameToRanking.append(" WHEN \(bind: column.name) THEN \(ranking)")
    }
    columnNameToRanking.append(" ELSE 1 END")
    // Emit one weight argument per column position. Each is a subquery that reads the column
    // name at that position ("cid") from 'pragma_table_info' and maps it through the 'CASE'.
    for offset in Self.writableColumns.indices {
      queryFragments.append(
        """
        (SELECT \(columnNameToRanking) \
        FROM pragma_table_info(\(quote: QueryValue.tableName, delimiter: .text)) \
        WHERE "cid" = \(offset))
        """
      )
    }
  }
  return SQLQueryExpression("bm25(\(queryFragments.joined(separator: ", ")))")
}
54+
1155
/// A predicate expression from this table matched against another _via_ the `MATCH` operator.
1256
///
1357
/// ```swift
@@ -72,7 +116,7 @@ extension TableColumnExpression where Root: FTS5 {
72116
///
73117
/// ```swift
74118
/// ReminderText.where { $0.title.match("get") }
75-
/// // SELECT … FROM "reminderTexts" WHERE ("reminderTexts" MATCH 'title:\"get\"')
119+
/// // SELECT … FROM "reminderTexts" WHERE ("reminderTexts" MATCH 'title:"get"')
76120
/// ```
77121
///
78122
/// - Parameter pattern: A string expression describing the `MATCH` pattern.

Tests/StructuredQueriesTests/FTSTests.swift

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,57 @@ extension SnapshotTests {
8787
}
8888
}
8989

90+
// Snapshot test for `bm25` used as an `ORDER BY` term on an FTS5 `MATCH` query, once with
// explicit column weights and once with no arguments.
@Test func bm25() {
91+
// Weighted call: the expected SQL contains one `WHEN … THEN` arm per listed column and one
// weight subquery per "cid" from 0 through 5.
assertQuery(
92+
ReminderText
93+
.where { $0.match("Week") }
94+
.order { $0.bm25([\.title: 10, \.notes: 5, \.tags: 2]) }
95+
) {
96+
"""
97+
SELECT "reminderTexts"."reminderID", "reminderTexts"."title", "reminderTexts"."notes", "reminderTexts"."listID", "reminderTexts"."listTitle", "reminderTexts"."tags"
98+
FROM "reminderTexts"
99+
WHERE ("reminderTexts" MATCH 'Week')
100+
ORDER BY bm25("reminderTexts", (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 0), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 1), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 2), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 3), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 4), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 5))
101+
"""
102+
} results: {
103+
"""
104+
┌────────────────────────────────┐
105+
│ ReminderText( │
106+
│ reminderID: 10, │
107+
│ title: "Send weekly emails", │
108+
│ notes: "", │
109+
│ listID: 3, │
110+
│ listTitle: "Business", │
111+
│ tags: ""
112+
│ ) │
113+
└────────────────────────────────┘
114+
"""
115+
}
116+
// Unweighted call: with no rankings the expected SQL is just `bm25("reminderTexts")`.
assertQuery(
117+
ReminderText
118+
.where { $0.match("Week") }
119+
.order { $0.bm25() }
120+
) {
121+
"""
122+
SELECT "reminderTexts"."reminderID", "reminderTexts"."title", "reminderTexts"."notes", "reminderTexts"."listID", "reminderTexts"."listTitle", "reminderTexts"."tags"
123+
FROM "reminderTexts"
124+
WHERE ("reminderTexts" MATCH 'Week')
125+
ORDER BY bm25("reminderTexts")
126+
"""
127+
} results: {
128+
"""
129+
┌────────────────────────────────┐
130+
│ ReminderText( │
131+
│ reminderID: 10, │
132+
│ title: "Send weekly emails", │
133+
│ notes: "", │
134+
│ listID: 3, │
135+
│ listTitle: "Business", │
136+
│ tags: ""
137+
│ ) │
138+
└────────────────────────────────┘
139+
"""
140+
}
141+
}
90142
}
91143
}

0 commit comments

Comments
 (0)