Skip to content

Commit 6bd8700

Browse files
Better support for optionals in queries (#61)
* Better support for optionals in queries * Predicates (`where` and `having`) can return optional booleans. * Add `QueryExpression<Optional>.map` for optionally building queries on non-optional values. For example, a comparison that might use the `#sql` macro as an escape hatch can more safely and succinctly use `map`: ```diff Reminder.where { - #sql("\($0.dueDate) < \(Date())") + $0.dueDate.map { $0 < Date() } } ``` * wip * wip * wip * wip * wip * Add test for where clause with optional * wip --------- Co-authored-by: Brandon Williams <[email protected]>
1 parent 0bd0c10 commit 6bd8700

File tree

8 files changed

+70
-16
lines changed

8 files changed

+70
-16
lines changed

Sources/StructuredQueriesCore/QueryFragmentBuilder.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ extension QueryFragmentBuilder<Bool> {
2727
) -> [QueryFragment] {
2828
[expression.queryFragment]
2929
}
30+
31+
public static func buildExpression(
32+
_ expression: some QueryExpression<some _OptionalPromotable<Bool?>>
33+
) -> [QueryFragment] {
34+
[expression.queryFragment]
35+
}
3036
}
3137

3238
extension QueryFragmentBuilder<()> {

Sources/StructuredQueriesCore/ScalarFunctions.swift

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ extension QueryExpression where QueryValue: FloatingPoint {
120120
}
121121
}
122122

123-
extension QueryExpression where QueryValue: Numeric {
123+
extension QueryExpression
124+
where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric {
124125
/// Wraps this numeric query expression with the `abs` function.
125126
///
126127
/// - Returns: An expression wrapped with the `abs` function.
@@ -251,14 +252,18 @@ extension QueryExpression where QueryValue == String {
251252
public func instr(_ occurrence: some QueryExpression<QueryValue>) -> some QueryExpression<Int> {
252253
QueryFunction("instr", self, occurrence)
253254
}
255+
}
254256

257+
extension QueryExpression where QueryValue: _OptionalPromotable<String?> {
255258
/// Wraps this string expression with the `lower` function.
256259
///
257260
/// - Returns: An expression wrapped with the `lower` function.
258261
public func lower() -> some QueryExpression<QueryValue> {
259262
QueryFunction("lower", self)
260263
}
264+
}
261265

266+
extension QueryExpression where QueryValue == String {
262267
/// Wraps this string expression with the `ltrim` function.
263268
///
264269
/// - Parameter characters: Characters to trim.
@@ -279,14 +284,18 @@ extension QueryExpression where QueryValue == String {
279284
public func octetLength() -> some QueryExpression<Int> {
280285
QueryFunction("octet_length", self)
281286
}
287+
}
282288

289+
extension QueryExpression where QueryValue: _OptionalPromotable<String?> {
283290
/// Wraps this string expression with the `quote` function.
284291
///
285292
/// - Returns: An expression wrapped with the `quote` function.
286293
public func quote() -> some QueryExpression<QueryValue> {
287294
QueryFunction("quote", self)
288295
}
296+
}
289297

298+
extension QueryExpression where QueryValue == String {
290299
/// Creates an expression invoking the `replace` function.
291300
///
292301
/// - Parameters:
@@ -346,13 +355,15 @@ extension QueryExpression where QueryValue == String {
346355
return QueryFunction("trim", self)
347356
}
348357
}
358+
}
349359

360+
extension QueryExpression where QueryValue: _OptionalPromotable<String?> {
350361
/// Wraps this string query expression with the `unhex` function.
351362
///
352363
/// - Parameter characters: Non-hexadecimal characters to skip.
353364
/// - Returns: An optional blob expression of the `unhex` function wrapping this expression.
354365
public func unhex(
355-
_ characters: (some QueryExpression<QueryValue>)? = QueryValue?.none
366+
_ characters: (some QueryExpression<String>)? = String?.none
356367
) -> some QueryExpression<[UInt8]?> {
357368
if let characters {
358369
return QueryFunction("unhex", self, characters)

Sources/StructuredQueriesCore/Statements/Delete.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ public struct Delete<From: Table, Returning> {
4848
///
4949
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
5050
/// - Returns: A statement with the added predicate.
51-
public func `where`(_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>) -> Self {
51+
public func `where`(
52+
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
53+
) -> Self {
5254
var update = self
5355
update.where.append(From.columns[keyPath: keyPath].queryFragment)
5456
return update
@@ -64,7 +66,9 @@ public struct Delete<From: Table, Returning> {
6466
/// - Parameter predicate: A closure that returns a Boolean expression to filter by.
6567
/// - Returns: A statement with the added predicate.
6668
@_disfavoredOverload
67-
public func `where`(_ predicate: (From.TableColumns) -> some QueryExpression<Bool>) -> Self {
69+
public func `where`(
70+
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
71+
) -> Self {
6872
var update = self
6973
update.where.append(predicate(From.columns).queryFragment)
7074
return update

Sources/StructuredQueriesCore/Statements/Select.swift

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ extension Table {
236236
/// columns.
237237
/// - Returns: A select statement that is filtered by the given predicate.
238238
public static func having(
239-
_ predicate: (TableColumns) -> some QueryExpression<Bool>
239+
_ predicate: (TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
240240
) -> SelectOf<Self> {
241241
Where().having(predicate)
242242
}
@@ -1110,7 +1110,7 @@ extension Select {
11101110
/// - Parameter keyPath: A key path from this select's table to a Boolean expression to filter by.
11111111
/// - Returns: A new select statement that appends the given predicate to its `WHERE` clause.
11121112
public func `where`(
1113-
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>
1113+
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
11141114
) -> Self
11151115
where Joins == () {
11161116
var select = self
@@ -1125,7 +1125,9 @@ extension Select {
11251125
/// - Returns: A new select statement that appends the given predicate to its `WHERE` clause.
11261126
@_disfavoredOverload
11271127
public func `where`<each J: Table>(
1128-
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<Bool>
1128+
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<
1129+
some _OptionalPromotable<Bool?>
1130+
>
11291131
) -> Self
11301132
where Joins == (repeat each J) {
11311133
var select = self
@@ -1218,7 +1220,9 @@ extension Select {
12181220
/// - Returns: A new select statement that appends the given predicate to its `HAVING` clause.
12191221
@_disfavoredOverload
12201222
public func having<each J: Table>(
1221-
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<Bool>
1223+
_ predicate: (From.TableColumns, repeat (each J).TableColumns) -> some QueryExpression<
1224+
some _OptionalPromotable<Bool?>
1225+
>
12221226
) -> Self
12231227
where Joins == (repeat each J) {
12241228
var select = self

Sources/StructuredQueriesCore/Statements/SelectStatement.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public typealias SelectStatementOf<From: Table, each Join: Table> =
4747

4848
extension SelectStatement {
4949
public static func `where`<From>(
50-
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
50+
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
5151
) -> Self
5252
where Self == Where<From> {
5353
Self(predicates: [predicate(From.columns).queryFragment])

Sources/StructuredQueriesCore/Statements/Update.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ public struct Update<From: Table, Returning> {
105105
///
106106
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
107107
/// - Returns: A statement with the added predicate.
108-
public func `where`(_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>) -> Self {
108+
public func `where`(
109+
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
110+
) -> Self {
109111
var update = self
110112
update.where.append(From.columns[keyPath: keyPath].queryFragment)
111113
return update
@@ -117,7 +119,7 @@ public struct Update<From: Table, Returning> {
117119
/// - Returns: A statement with the added predicate.
118120
@_disfavoredOverload
119121
public func `where`(
120-
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
122+
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
121123
) -> Self {
122124
var update = self
123125
update.where.append(predicate(From.columns).queryFragment)

Sources/StructuredQueriesCore/Statements/Where.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extension Table {
2020
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
2121
/// - Returns: A `WHERE` clause.
2222
public static func `where`(
23-
_ keyPath: KeyPath<TableColumns, some QueryExpression<Bool>>
23+
_ keyPath: KeyPath<TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
2424
) -> Where<Self> {
2525
Where(predicates: [columns[keyPath: keyPath].queryFragment])
2626
}
@@ -33,7 +33,7 @@ extension Table {
3333
/// - Returns: A `WHERE` clause.
3434
@_disfavoredOverload
3535
public static func `where`(
36-
_ predicate: (TableColumns) -> some QueryExpression<Bool>
36+
_ predicate: (TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
3737
) -> Where<Self> {
3838
Where(predicates: [predicate(columns).queryFragment])
3939
}
@@ -292,7 +292,7 @@ extension Where: SelectStatement {
292292
/// - Parameter keyPath: A key path to a Boolean expression to filter by.
293293
/// - Returns: A where clause with the added predicate.
294294
public func `where`(
295-
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<Bool>>
295+
_ keyPath: KeyPath<From.TableColumns, some QueryExpression<some _OptionalPromotable<Bool?>>>
296296
) -> Self {
297297
var `where` = self
298298
`where`.predicates.append(From.columns[keyPath: keyPath].queryFragment)
@@ -305,7 +305,7 @@ extension Where: SelectStatement {
305305
/// - Returns: A where clause with the added predicate.
306306
@_disfavoredOverload
307307
public func `where`(
308-
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
308+
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
309309
) -> Self {
310310
var `where` = self
311311
`where`.predicates.append(predicate(From.columns).queryFragment)
@@ -409,7 +409,7 @@ extension Where: SelectStatement {
409409

410410
/// A select statement for the filtered table with the given `HAVING` clause.
411411
public func having(
412-
_ predicate: (From.TableColumns) -> some QueryExpression<Bool>
412+
_ predicate: (From.TableColumns) -> some QueryExpression<some _OptionalPromotable<Bool?>>
413413
) -> SelectOf<From> {
414414
asSelect().having(predicate)
415415
}

Tests/StructuredQueriesTests/WhereTests.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import Dependencies
12
import Foundation
23
import InlineSnapshotTesting
34
import StructuredQueries
5+
import StructuredQueriesSQLite
46
import Testing
57

68
extension SnapshotTests {
@@ -109,5 +111,30 @@ extension SnapshotTests {
109111
"""
110112
}
111113
}
114+
115+
@Test func optionalBoolean() throws {
116+
@Dependency(\.defaultDatabase) var db
117+
let remindersListIDs = try db.execute(
118+
RemindersList.insert {
119+
RemindersList.Draft(title: "New list")
120+
}
121+
.returning(\.id)
122+
)
123+
let remindersListID = try #require(remindersListIDs.first)
124+
125+
assertQuery(
126+
RemindersList
127+
.find(remindersListID)
128+
.leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) }
129+
.where { $1.isCompleted }
130+
) {
131+
"""
132+
SELECT "remindersLists"."id", "remindersLists"."color", "remindersLists"."title", "reminders"."id", "reminders"."assignedUserID", "reminders"."dueDate", "reminders"."isCompleted", "reminders"."isFlagged", "reminders"."notes", "reminders"."priority", "reminders"."remindersListID", "reminders"."title"
133+
FROM "remindersLists"
134+
LEFT JOIN "reminders" ON ("remindersLists"."id" = "reminders"."remindersListID")
135+
WHERE ("remindersLists"."id" = 4) AND "reminders"."isCompleted"
136+
"""
137+
}
138+
}
112139
}
113140
}

0 commit comments

Comments
 (0)