Skip to content

Commit 9a767c4

Browse files
authored
Fix Table.count(filter:), etc. (pointfreeco#66)
While technically a breaking change for folks specifying the parameter, the parameter wasn't really useful on its own without reaching out to global columns state, so I think it's safe to fix these APIs and consider the previous behavior a bug.
1 parent fc09db9 commit 9a767c4

File tree

5 files changed

+76
-29
lines changed

5 files changed

+76
-29
lines changed

Sources/StructuredQueriesCore/AggregateFunctions.swift

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ extension QueryExpression where QueryValue: QueryBindable {
2020
distinct isDistinct: Bool = false,
2121
filter: (some QueryExpression<Bool>)? = Bool?.none
2222
) -> some QueryExpression<Int> {
23-
AggregateFunction("count", isDistinct: isDistinct, self, filter: filter)
23+
AggregateFunction(
24+
"count",
25+
isDistinct: isDistinct,
26+
[queryFragment],
27+
filter: filter?.queryFragment
28+
)
2429
}
2530
}
2631

@@ -46,11 +51,12 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin
4651
order: (some QueryExpression)? = Bool?.none,
4752
filter: (some QueryExpression<Bool>)? = Bool?.none
4853
) -> some QueryExpression<String?> {
49-
if let separator {
50-
return AggregateFunction("group_concat", self, separator, order: order, filter: filter)
51-
} else {
52-
return AggregateFunction("group_concat", self, order: order, filter: filter)
53-
}
54+
AggregateFunction(
55+
"group_concat",
56+
separator.map { [queryFragment, $0.queryFragment] } ?? [queryFragment],
57+
order: order?.queryFragment,
58+
filter: filter?.queryFragment
59+
)
5460
}
5561

5662
/// A string concatenation aggregate of this expression.
@@ -68,7 +74,13 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin
6874
order: (some QueryExpression)? = Bool?.none,
6975
filter: (some QueryExpression<Bool>)? = Bool?.none
7076
) -> some QueryExpression<String?> {
71-
AggregateFunction("group_concat", isDistinct: isDistinct, self, order: order, filter: filter)
77+
AggregateFunction(
78+
"group_concat",
79+
isDistinct: isDistinct,
80+
[queryFragment],
81+
order: order?.queryFragment,
82+
filter: filter?.queryFragment
83+
)
7284
}
7385
}
7486

@@ -85,7 +97,7 @@ extension QueryExpression where QueryValue: QueryBindable {
8597
public func max(
8698
filter: (some QueryExpression<Bool>)? = Bool?.none
8799
) -> some QueryExpression<Int?> {
88-
AggregateFunction("max", self, filter: filter)
100+
AggregateFunction("max", [queryFragment], filter: filter?.queryFragment)
89101
}
90102

91103
/// A minimum aggregate of this expression.
@@ -100,7 +112,7 @@ extension QueryExpression where QueryValue: QueryBindable {
100112
public func min(
101113
filter: (some QueryExpression<Bool>)? = Bool?.none
102114
) -> some QueryExpression<Int?> {
103-
AggregateFunction("min", self, filter: filter)
115+
AggregateFunction("min", [queryFragment], filter: filter?.queryFragment)
104116
}
105117
}
106118

@@ -122,7 +134,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
122134
distinct isDistinct: Bool = false,
123135
filter: (some QueryExpression<Bool>)? = Bool?.none
124136
) -> some QueryExpression<Double?> {
125-
AggregateFunction("avg", isDistinct: isDistinct, self, filter: filter)
137+
AggregateFunction("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment)
126138
}
127139

128140
/// An sum aggregate of this expression.
@@ -145,7 +157,10 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
145157
// TODO: Report issue to Swift team.
146158
SQLQueryExpression(
147159
AggregateFunction<QueryValue._Optionalized>(
148-
"sum", isDistinct: isDistinct, self, filter: filter
160+
"sum",
161+
isDistinct: isDistinct,
162+
[queryFragment],
163+
filter: filter?.queryFragment
149164
)
150165
.queryFragment
151166
)
@@ -167,7 +182,12 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
167182
distinct isDistinct: Bool = false,
168183
filter: (some QueryExpression<Bool>)? = Bool?.none
169184
) -> some QueryExpression<QueryValue> {
170-
AggregateFunction("total", isDistinct: isDistinct, self, filter: filter)
185+
AggregateFunction(
186+
"total",
187+
isDistinct: isDistinct,
188+
[queryFragment],
189+
filter: filter?.queryFragment
190+
)
171191
}
172192
}
173193

@@ -182,9 +202,9 @@ extension QueryExpression where Self == AggregateFunction<Int> {
182202
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
183203
/// - Returns: A `count(*)` aggregate.
184204
public static func count(
185-
filter: (some QueryExpression<Bool>)? = Bool?.none
205+
filter: (any QueryExpression<Bool>)? = nil
186206
) -> Self {
187-
AggregateFunction("count", SQLQueryExpression("*"), filter: filter)
207+
AggregateFunction("count", ["*"], filter: filter?.queryFragment)
188208
}
189209
}
190210

@@ -196,18 +216,18 @@ public struct AggregateFunction<QueryValue>: QueryExpression {
196216
var order: QueryFragment?
197217
var filter: QueryFragment?
198218

199-
init<each Argument: QueryExpression>(
219+
init(
200220
_ name: QueryFragment,
201221
isDistinct: Bool = false,
202-
_ arguments: repeat each Argument,
203-
order: (some QueryExpression)? = Bool?.none,
204-
filter: (some QueryExpression)? = Bool?.none
222+
_ arguments: [QueryFragment] = [],
223+
order: QueryFragment? = nil,
224+
filter: QueryFragment? = nil
205225
) {
206226
self.name = name
207227
self.isDistinct = isDistinct
208-
self.arguments = Array(repeat each arguments)
209-
self.order = order?.queryFragment
210-
self.filter = filter?.queryFragment
228+
self.arguments = arguments
229+
self.order = order
230+
self.filter = filter
211231
}
212232

213233
public var queryFragment: QueryFragment {

Sources/StructuredQueriesCore/SQLite/JSONFunctions.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ extension QueryExpression where QueryValue: Codable & QueryBindable & Sendable {
3737
filter: (some QueryExpression<Bool>)? = Bool?.none
3838
) -> some QueryExpression<[QueryValue].JSONRepresentation> {
3939
AggregateFunction(
40-
"json_group_array", isDistinct: isDistinct, self, order: order, filter: filter)
40+
"json_group_array",
41+
isDistinct: isDistinct,
42+
[queryFragment],
43+
order: order?.queryFragment,
44+
filter: filter?.queryFragment
45+
)
4146
}
4247
}
4348

@@ -98,7 +103,12 @@ extension PrimaryKeyedTableDefinition where QueryValue: Codable & Sendable {
98103
filter: (some QueryExpression<Bool>)? = Bool?.none
99104
) -> some QueryExpression<[QueryValue].JSONRepresentation> {
100105
AggregateFunction(
101-
"json_group_array", isDistinct: isDistinct, jsonObject, order: order, filter: filter)
106+
"json_group_array",
107+
isDistinct: isDistinct,
108+
[jsonObject.queryFragment],
109+
order: order?.queryFragment,
110+
filter: filter?.queryFragment
111+
)
102112
}
103113

104114
private var jsonObject: some QueryExpression<QueryValue> {

Sources/StructuredQueriesCore/Statements/Select.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ extension Table {
295295
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
296296
/// - Returns: A select statement that selects `count(*)`.
297297
public static func count(
298-
filter: (some QueryExpression<Bool>)? = Bool?.none
298+
filter: ((TableColumns) -> any QueryExpression<Bool>)? = nil
299299
) -> Select<Int, Self, ()> {
300300
Where().count(filter: filter)
301301
}
@@ -1321,23 +1321,25 @@ extension Select {
13211321
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
13221322
/// - Returns: A new select statement that selects `count(*)`.
13231323
public func count<each J: Table>(
1324-
filter: (some QueryExpression<Bool>)? = Bool?.none
1324+
filter: ((From.TableColumns, repeat (each J).TableColumns) -> any QueryExpression<Bool>)? = nil
13251325
) -> Select<Int, From, (repeat each J)>
13261326
where Columns == (), Joins == (repeat each J) {
1327-
select { _ in .count(filter: filter) }
1327+
let filter = filter?(From.columns, repeat (each J).columns)
1328+
return select { _ in .count(filter: filter) }
13281329
}
13291330

13301331
/// Creates a new select statement from this one by appending `count(*)` to its selection.
13311332
///
13321333
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
13331334
/// - Returns: A new select statement that selects `count(*)`.
13341335
public func count<each C: QueryRepresentable, each J: Table>(
1335-
filter: (some QueryExpression<Bool>)? = Bool?.none
1336+
filter: ((From.TableColumns, repeat (each J).TableColumns) -> any QueryExpression<Bool>)? = nil
13361337
) -> Select<
13371338
(repeat each C, Int), From, (repeat each J)
13381339
>
13391340
where Columns == (repeat each C), Joins == (repeat each J) {
1340-
select { _ in .count(filter: filter) }
1341+
let filter = filter?(From.columns, repeat (each J).columns)
1342+
return select { _ in .count(filter: filter) }
13411343
}
13421344

13431345
/// Creates a new select statement from this one by transforming its selected columns to a new

Sources/StructuredQueriesCore/Statements/Where.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ extension Where: SelectStatement {
462462
/// - Parameter filter: A `FILTER` clause to apply to the aggregation.
463463
/// - Returns: A select statement that selects `count(*)`.
464464
public func count(
465-
filter: (some QueryExpression<Bool>)? = Bool?.none
465+
filter: ((From.TableColumns) -> any QueryExpression<Bool>)? = nil
466466
) -> Select<Int, From, ()> {
467467
asSelect().count(filter: filter)
468468
}

Tests/StructuredQueriesTests/SelectTests.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,21 @@ extension SnapshotTests {
857857
}
858858
}
859859

860+
@Test func countFilter() {
861+
assertQuery(Reminder.count { !$0.isCompleted }) {
862+
"""
863+
SELECT count(*) FILTER (WHERE NOT ("reminders"."isCompleted"))
864+
FROM "reminders"
865+
"""
866+
} results: {
867+
"""
868+
┌───┐
869+
│ 7 │
870+
└───┘
871+
"""
872+
}
873+
}
874+
860875
@Test func map() {
861876
assertQuery(Reminder.limit(1).select { ($0.id, $0.title) }.map { ($1, $0) }) {
862877
"""

0 commit comments

Comments
 (0)