Skip to content

Commit 3d21de2

Browse files
authored
Merge branch 'main' into more-association-docs
2 parents 00d5cf3 + 87ebc7c commit 3d21de2

File tree

9 files changed

+121
-30
lines changed

9 files changed

+121
-30
lines changed

Sources/StructuredQueriesCore/AggregateFunctions.swift

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ extension QueryExpression where QueryValue: QueryBindable {
2020
distinct isDistinct: Bool = false,
2121
filter: (some QueryExpression<Bool>)? = Bool?.none
2222
) -> some QueryExpression<Int> {
23-
AggregateFunction("count", isDistinct: isDistinct, self, filter: filter)
23+
AggregateFunction(
24+
"count",
25+
isDistinct: isDistinct,
26+
[queryFragment],
27+
filter: filter?.queryFragment
28+
)
2429
}
2530
}
2631

@@ -46,11 +51,12 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin
4651
order: (some QueryExpression)? = Bool?.none,
4752
filter: (some QueryExpression<Bool>)? = Bool?.none
4853
) -> some QueryExpression<String?> {
49-
if let separator {
50-
return AggregateFunction("group_concat", self, separator, order: order, filter: filter)
51-
} else {
52-
return AggregateFunction("group_concat", self, order: order, filter: filter)
53-
}
54+
AggregateFunction(
55+
"group_concat",
56+
separator.map { [queryFragment, $0.queryFragment] } ?? [queryFragment],
57+
order: order?.queryFragment,
58+
filter: filter?.queryFragment
59+
)
5460
}
5561

5662
/// A string concatenation aggregate of this expression.
@@ -68,7 +74,13 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin
6874
order: (some QueryExpression)? = Bool?.none,
6975
filter: (some QueryExpression<Bool>)? = Bool?.none
7076
) -> some QueryExpression<String?> {
71-
AggregateFunction("group_concat", isDistinct: isDistinct, self, order: order, filter: filter)
77+
AggregateFunction(
78+
"group_concat",
79+
isDistinct: isDistinct,
80+
[queryFragment],
81+
order: order?.queryFragment,
82+
filter: filter?.queryFragment
83+
)
7284
}
7385
}
7486

@@ -85,7 +97,7 @@ extension QueryExpression where QueryValue: QueryBindable {
8597
public func max(
8698
filter: (some QueryExpression<Bool>)? = Bool?.none
8799
) -> some QueryExpression<Int?> {
88-
AggregateFunction("max", self, filter: filter)
100+
AggregateFunction("max", [queryFragment], filter: filter?.queryFragment)
89101
}
90102

91103
/// A minimum aggregate of this expression.
@@ -100,7 +112,7 @@ extension QueryExpression where QueryValue: QueryBindable {
100112
public func min(
101113
filter: (some QueryExpression<Bool>)? = Bool?.none
102114
) -> some QueryExpression<Int?> {
103-
AggregateFunction("min", self, filter: filter)
115+
AggregateFunction("min", [queryFragment], filter: filter?.queryFragment)
104116
}
105117
}
106118

@@ -122,7 +134,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
122134
distinct isDistinct: Bool = false,
123135
filter: (some QueryExpression<Bool>)? = Bool?.none
124136
) -> some QueryExpression<Double?> {
125-
AggregateFunction("avg", isDistinct: isDistinct, self, filter: filter)
137+
AggregateFunction("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment)
126138
}
127139

128140
/// An sum aggregate of this expression.
@@ -145,7 +157,10 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
145157
// TODO: Report issue to Swift team.
146158
SQLQueryExpression(
147159
AggregateFunction<QueryValue._Optionalized>(
148-
"sum", isDistinct: isDistinct, self, filter: filter
160+
"sum",
161+
isDistinct: isDistinct,
162+
[queryFragment],
163+
filter: filter?.queryFragment
149164
)
150165
.queryFragment
151166
)
@@ -167,7 +182,12 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
167182
distinct isDistinct: Bool = false,
168183
filter: (some QueryExpression<Bool>)? = Bool?.none
169184
) -> some QueryExpression<QueryValue> {
170-
AggregateFunction("total", isDistinct: isDistinct, self, filter: filter)
185+
AggregateFunction(
186+
"total",
187+
isDistinct: isDistinct,
188+
[queryFragment],
189+
filter: filter?.queryFragment
190+
)
171191
}
172192
}
173193

@@ -182,9 +202,9 @@ extension QueryExpression where Self == AggregateFunction<Int> {
182202
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
183203
/// - Returns: A `count(*)` aggregate.
184204
public static func count(
185-
filter: (some QueryExpression<Bool>)? = Bool?.none
205+
filter: (any QueryExpression<Bool>)? = nil
186206
) -> Self {
187-
AggregateFunction("count", SQLQueryExpression("*"), filter: filter)
207+
AggregateFunction("count", ["*"], filter: filter?.queryFragment)
188208
}
189209
}
190210

@@ -196,18 +216,18 @@ public struct AggregateFunction<QueryValue>: QueryExpression {
196216
var order: QueryFragment?
197217
var filter: QueryFragment?
198218

199-
init<each Argument: QueryExpression>(
219+
init(
200220
_ name: QueryFragment,
201221
isDistinct: Bool = false,
202-
_ arguments: repeat each Argument,
203-
order: (some QueryExpression)? = Bool?.none,
204-
filter: (some QueryExpression)? = Bool?.none
222+
_ arguments: [QueryFragment] = [],
223+
order: QueryFragment? = nil,
224+
filter: QueryFragment? = nil
205225
) {
206226
self.name = name
207227
self.isDistinct = isDistinct
208-
self.arguments = Array(repeat each arguments)
209-
self.order = order?.queryFragment
210-
self.filter = filter?.queryFragment
228+
self.arguments = arguments
229+
self.order = order
230+
self.filter = filter
211231
}
212232

213233
public var queryFragment: QueryFragment {

Sources/StructuredQueriesCore/Documentation.docc/Articles/DefiningYourSchema.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ that represent those database definitions.
2121
* [Custom data types](#Custom-data-types)
2222
* [RawRepresentable](#RawRepresentable)
2323
* [JSON](#JSON)
24+
* [Tagged identifiers](#Tagged-identifiers)
2425
* [Default representations for dates and UUIDs](#Default-representations-for-dates-and-UUIDs)
2526
* [Primary keyed tables](#Primary-keyed-tables)
2627
* [Ephemeral columns](#Ephemeral-columns)
@@ -298,10 +299,11 @@ This adds a new layer of type-safety when constructing queries. Previously compa
298299
`RemindersList.ID` to a `Reminder.ID` would compile just fine, even though it is a nonsensical thing
299300
to do. But now, such a comparison is a compile time error:
300301

301-
```
302+
```swift
302303
RemindersList.leftJoin(Reminder.all) {
303304
$0.id == $1.id // 🛑 Requires the types 'Reminder.ID' and 'RemindersList.ID' be equivalent
304305
}
306+
```
305307

306308
#### Default representations for dates and UUIDs
307309

Sources/StructuredQueriesCore/Operators.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ public func != <QueryValue>(
187187

188188
// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
189189
@_documentation(visibility: private)
190+
@_disfavoredOverload
190191
public func == <QueryValue: _OptionalProtocol>(
191192
lhs: any QueryExpression<QueryValue>,
192193
rhs: some QueryExpression<QueryValue.Wrapped>
@@ -196,6 +197,7 @@ public func == <QueryValue: _OptionalProtocol>(
196197

197198
// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
198199
@_documentation(visibility: private)
200+
@_disfavoredOverload
199201
public func != <QueryValue: _OptionalProtocol>(
200202
lhs: any QueryExpression<QueryValue>,
201203
rhs: some QueryExpression<QueryValue.Wrapped>

Sources/StructuredQueriesCore/Optional.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
public protocol _OptionalProtocol<Wrapped> {
22
associatedtype Wrapped
33
var _wrapped: Wrapped? { get }
4+
static var none: Self { get }
5+
static func some(_ wrapped: Wrapped) -> Self
46
}
57

68
extension Optional: _OptionalProtocol {

Sources/StructuredQueriesCore/SQLite/JSONFunctions.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ extension QueryExpression where QueryValue: Codable & QueryBindable & Sendable {
3737
filter: (some QueryExpression<Bool>)? = Bool?.none
3838
) -> some QueryExpression<[QueryValue].JSONRepresentation> {
3939
AggregateFunction(
40-
"json_group_array", isDistinct: isDistinct, self, order: order, filter: filter)
40+
"json_group_array",
41+
isDistinct: isDistinct,
42+
[queryFragment],
43+
order: order?.queryFragment,
44+
filter: filter?.queryFragment
45+
)
4146
}
4247
}
4348

@@ -98,7 +103,12 @@ extension PrimaryKeyedTableDefinition where QueryValue: Codable & Sendable {
98103
filter: (some QueryExpression<Bool>)? = Bool?.none
99104
) -> some QueryExpression<[QueryValue].JSONRepresentation> {
100105
AggregateFunction(
101-
"json_group_array", isDistinct: isDistinct, jsonObject, order: order, filter: filter)
106+
"json_group_array",
107+
isDistinct: isDistinct,
108+
[jsonObject.queryFragment],
109+
order: order?.queryFragment,
110+
filter: filter?.queryFragment
111+
)
102112
}
103113

104114
private var jsonObject: some QueryExpression<QueryValue> {

Sources/StructuredQueriesCore/Statements/Select.swift

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ extension Table {
295295
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
296296
/// - Returns: A select statement that selects `count(*)`.
297297
public static func count(
298-
filter: (some QueryExpression<Bool>)? = Bool?.none
298+
filter: ((TableColumns) -> any QueryExpression<Bool>)? = nil
299299
) -> Select<Int, Self, ()> {
300300
Where().count(filter: filter)
301301
}
@@ -444,6 +444,18 @@ extension Select {
444444
_select(selection)
445445
}
446446

447+
/// Creates a new select statement from this one by selecting the given result column.
448+
///
449+
/// - Parameter selection: A closure that selects a result column from this select's tables.
450+
/// - Returns: A new select statement that selects the given column.
451+
@_disfavoredOverload
452+
public func select<C: QueryExpression, each J: Table>(
453+
_ selection: (From.TableColumns, repeat (each J).TableColumns) -> C
454+
) -> Select<C.QueryValue, From, (repeat each J)>
455+
where Columns == (), C.QueryValue: QueryRepresentable, Joins == (repeat each J) {
456+
_select(selection)
457+
}
458+
447459
/// Creates a new select statement from this one by appending the given result column to its
448460
/// selection.
449461
///
@@ -1309,23 +1321,25 @@ extension Select {
13091321
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
13101322
/// - Returns: A new select statement that selects `count(*)`.
13111323
public func count<each J: Table>(
1312-
filter: (some QueryExpression<Bool>)? = Bool?.none
1324+
filter: ((From.TableColumns, repeat (each J).TableColumns) -> any QueryExpression<Bool>)? = nil
13131325
) -> Select<Int, From, (repeat each J)>
13141326
where Columns == (), Joins == (repeat each J) {
1315-
select { _ in .count(filter: filter) }
1327+
let filter = filter?(From.columns, repeat (each J).columns)
1328+
return select { _ in .count(filter: filter) }
13161329
}
13171330

13181331
/// Creates a new select statement from this one by appending `count(*)` to its selection.
13191332
///
13201333
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
13211334
/// - Returns: A new select statement that selects `count(*)`.
13221335
public func count<each C: QueryRepresentable, each J: Table>(
1323-
filter: (some QueryExpression<Bool>)? = Bool?.none
1336+
filter: ((From.TableColumns, repeat (each J).TableColumns) -> any QueryExpression<Bool>)? = nil
13241337
) -> Select<
13251338
(repeat each C, Int), From, (repeat each J)
13261339
>
13271340
where Columns == (repeat each C), Joins == (repeat each J) {
1328-
select { _ in .count(filter: filter) }
1341+
let filter = filter?(From.columns, repeat (each J).columns)
1342+
return select { _ in .count(filter: filter) }
13291343
}
13301344

13311345
/// Creates a new select statement from this one by transforming its selected columns to a new

Sources/StructuredQueriesCore/Statements/Where.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ extension Where: SelectStatement {
462462
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
463463
/// - Returns: A select statement that selects `count(*)`.
464464
public func count(
465-
filter: (some QueryExpression<Bool>)? = Bool?.none
465+
filter: ((From.TableColumns) -> any QueryExpression<Bool>)? = nil
466466
) -> Select<Int, From, ()> {
467467
asSelect().count(filter: filter)
468468
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import StructuredQueries
2+
3+
// NB: This is a compile-time test for a 'select' overload.
4+
@Selection
5+
private struct ReminderRow {
6+
let reminder: Reminder
7+
let isPastDue: Bool
8+
@Column(as: [String].JSONRepresentation.self)
9+
let tags: [String]
10+
}
11+
private var remindersQuery: some Statement<ReminderRow> {
12+
Reminder
13+
.limit(1)
14+
.select {
15+
ReminderRow.Columns(
16+
reminder: $0,
17+
isPastDue: true,
18+
tags: #sql("[]")
19+
)
20+
}
21+
}

Tests/StructuredQueriesTests/SelectTests.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ extension SnapshotTests {
1616
_ = Reminder.where(\.isCompleted).select(\.id)
1717
_ = Reminder.where(\.isCompleted).select { $0.id }
1818
_ = Reminder.where(\.isCompleted).select { ($0.id, $0.isCompleted) }
19+
20+
let condition1 = Int?.some(1) == 2
21+
#expect(condition1 == false)
22+
let condition2 = Int?.some(1) != 2
23+
#expect(condition2 == true)
1924
}
2025

2126
@Test func selectAll() {
@@ -852,6 +857,21 @@ extension SnapshotTests {
852857
}
853858
}
854859

860+
@Test func countFilter() {
861+
assertQuery(Reminder.count { !$0.isCompleted }) {
862+
"""
863+
SELECT count(*) FILTER (WHERE NOT ("reminders"."isCompleted"))
864+
FROM "reminders"
865+
"""
866+
} results: {
867+
"""
868+
┌───┐
869+
│ 7 │
870+
└───┘
871+
"""
872+
}
873+
}
874+
855875
@Test func map() {
856876
assertQuery(Reminder.limit(1).select { ($0.id, $0.title) }.map { ($1, $0) }) {
857877
"""

0 commit comments

Comments
 (0)