Skip to content

Commit f68e4a0

Browse files
committed
Merge branch 'main' into custom-functions
2 parents d595615 + 6ee63c9 commit f68e4a0

File tree

3 files changed

+113
-4
lines changed

3 files changed

+113
-4
lines changed

Sources/StructuredQueriesCore/QueryFragment.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ extension QueryFragment: ExpressibleByStringInterpolation {
150150
}
151151

152152
public mutating func appendLiteral(_ literal: String) {
153+
guard !literal.isEmpty else { return }
153154
segments.append(.sql(literal))
154155
}
155156

@@ -183,7 +184,7 @@ extension QueryFragment: ExpressibleByStringInterpolation {
183184
///
184185
/// - Parameter sql: A raw query string.
185186
public mutating func appendInterpolation(raw sql: String) {
186-
segments.append(.sql(sql))
187+
appendLiteral(sql)
187188
}
188189

189190
/// Append a raw lossless string to the interpolation.
@@ -200,7 +201,7 @@ extension QueryFragment: ExpressibleByStringInterpolation {
200201
///
201202
/// - Parameter sql: A raw query string.
202203
public mutating func appendInterpolation(raw sql: some LosslessStringConvertible) {
203-
segments.append(.sql(sql.description))
204+
appendLiteral(sql.description)
204205
}
205206

206207
/// Append a query binding to the interpolation.

Sources/StructuredQueriesCore/SQLite/FTS5.swift

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,73 @@ import IssueReporting
22

33
/// A virtual table using the FTS5 extension.
44
///
5-
/// Apply this protocol to a `@Table` declaration to introduce FTS5 helpers.
5+
/// Apply this protocol to a `@Table` declaration to introduce [FTS5] helpers.
6+
///
7+
/// [FTS5]: https://www.sqlite.org/fts5.html
68
public protocol FTS5: Table {}
79

810
extension TableDefinition where QueryValue: FTS5 {
11+
/// A BM25 ranking function for the given column-accuracy mapping.
12+
///
13+
/// - Parameter rankings: A dictionary mapping columns to accuracy of a match.
14+
/// - Returns: A BM25 ranking function.
15+
public func bm25(
16+
_ rankings: KeyValuePairs<PartialKeyPath<Self>, Double> = [:]
17+
) -> some QueryExpression<Double> {
18+
var queryFragments: [QueryFragment] = ["\(QueryValue.self)"]
19+
if !rankings.isEmpty {
20+
var columnNameToRanking: QueryFragment = """
21+
CASE "name"
22+
"""
23+
for (keyPath, ranking) in rankings {
24+
guard let column = self[keyPath: keyPath] as? any WritableTableColumnExpression
25+
else {
26+
reportIssue(
27+
"""
28+
Key path cannot be used in 'bm25' function: \(keyPath)
29+
30+
Must be a key path to a table column on '\(QueryValue.self)'.
31+
"""
32+
)
33+
continue
34+
}
35+
columnNameToRanking.append(
36+
"""
37+
WHEN \(bind: column.name) THEN \(ranking)
38+
"""
39+
)
40+
}
41+
columnNameToRanking.append(" ELSE 1 END")
42+
for offset in Self.writableColumns.indices {
43+
queryFragments.append(
44+
"""
45+
(SELECT \(columnNameToRanking) \
46+
FROM pragma_table_info(\(quote: QueryValue.tableName, delimiter: .text)) \
47+
WHERE "cid" = \(offset))
48+
"""
49+
)
50+
}
51+
}
52+
return SQLQueryExpression("bm25(\(queryFragments.joined(separator: ", ")))")
53+
}
54+
955
/// A predicate expression from this table matched against another _via_ the `MATCH` operator.
1056
///
1157
/// ```swift
1258
/// ReminderText.where { $0.match("get") }
1359
/// // SELECT … FROM "reminderTexts" WHERE ("reminderTexts" MATCH 'get')
1460
/// ```
1561
///
62+
/// > Important: Avoid passing a string entered by the user directly to this operator. FTS5
63+
/// > queries have a distinct [syntax] that can specify particular columns and refine a search in
64+
/// > various ways. If FTS5 is given a query with invalid syntax, it can even throw SQL errors at
65+
/// > runtime.
66+
/// >
67+
/// > Instead, consider transforming the user's input into a query by quoting, prefixing, and/or
68+
/// > combining inputs from your UI into a valid query before handing it off to SQLite.
69+
///
70+
/// [syntax]: https://www.sqlite.org/fts5.html#full_text_query_syntax
71+
///
1672
/// - Parameter pattern: A string expression describing the `MATCH` pattern.
1773
/// - Returns: A predicate expression.
1874
public func match(_ pattern: some StringProtocol) -> some QueryExpression<Bool> {
@@ -60,7 +116,7 @@ extension TableColumnExpression where Root: FTS5 {
60116
///
61117
/// ```swift
62118
/// ReminderText.where { $0.title.match("get") }
63-
/// // SELECT … FROM "reminderTexts" WHERE ("reminderTexts"."title" MATCH 'get')
119+
/// // SELECT … FROM "reminderTexts" WHERE ("reminderTexts" MATCH 'title:"get"')
64120
/// ```
65121
///
66122
/// - Parameter pattern: A string expression describing the `MATCH` pattern.

Tests/StructuredQueriesTests/FTSTests.swift

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,57 @@ extension SnapshotTests {
8787
}
8888
}
8989

90+
@Test func bm25() {
91+
assertQuery(
92+
ReminderText
93+
.where { $0.match("Week") }
94+
.order { $0.bm25([\.title: 10, \.notes: 5, \.tags: 2]) }
95+
) {
96+
"""
97+
SELECT "reminderTexts"."reminderID", "reminderTexts"."title", "reminderTexts"."notes", "reminderTexts"."listID", "reminderTexts"."listTitle", "reminderTexts"."tags"
98+
FROM "reminderTexts"
99+
WHERE ("reminderTexts" MATCH 'Week')
100+
ORDER BY bm25("reminderTexts", (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 0), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 1), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 2), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 3), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 4), (SELECT CASE "name" WHEN 'title' THEN 10.0 WHEN 'notes' THEN 5.0 WHEN 'tags' THEN 2.0 ELSE 1 END FROM pragma_table_info('reminderTexts') WHERE "cid" = 5))
101+
"""
102+
} results: {
103+
"""
104+
┌────────────────────────────────┐
105+
│ ReminderText( │
106+
│ reminderID: 10, │
107+
│ title: "Send weekly emails", │
108+
│ notes: "", │
109+
│ listID: 3, │
110+
│ listTitle: "Business", │
111+
│ tags: ""
112+
│ ) │
113+
└────────────────────────────────┘
114+
"""
115+
}
116+
assertQuery(
117+
ReminderText
118+
.where { $0.match("Week") }
119+
.order { $0.bm25() }
120+
) {
121+
"""
122+
SELECT "reminderTexts"."reminderID", "reminderTexts"."title", "reminderTexts"."notes", "reminderTexts"."listID", "reminderTexts"."listTitle", "reminderTexts"."tags"
123+
FROM "reminderTexts"
124+
WHERE ("reminderTexts" MATCH 'Week')
125+
ORDER BY bm25("reminderTexts")
126+
"""
127+
} results: {
128+
"""
129+
┌────────────────────────────────┐
130+
│ ReminderText( │
131+
│ reminderID: 10, │
132+
│ title: "Send weekly emails", │
133+
│ notes: "", │
134+
│ listID: 3, │
135+
│ listTitle: "Business", │
136+
│ tags: ""
137+
│ ) │
138+
└────────────────────────────────┘
139+
"""
140+
}
141+
}
90142
}
91143
}

0 commit comments

Comments
 (0)