Skip to content

Commit 3c8af53

Browse files
User-defined aggregate functions (#207)
* wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * Added a test for mode aggregation. * more docs * wip --------- Co-authored-by: Brandon Williams <[email protected]>
1 parent b4fadef commit 3c8af53

File tree

10 files changed

+1247
-127
lines changed

10 files changed

+1247
-127
lines changed

Sources/StructuredQueriesCore/AggregateFunctions.swift

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extension QueryExpression where QueryValue: QueryBindable {
2020
distinct isDistinct: Bool = false,
2121
filter: (some QueryExpression<Bool>)? = Bool?.none
2222
) -> some QueryExpression<Int> {
23-
AggregateFunction(
23+
AggregateFunctionExpression(
2424
"count",
2525
isDistinct: isDistinct,
2626
[queryFragment],
@@ -51,7 +51,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin
5151
order: (some QueryExpression)? = Bool?.none,
5252
filter: (some QueryExpression<Bool>)? = Bool?.none
5353
) -> some QueryExpression<String?> {
54-
AggregateFunction(
54+
AggregateFunctionExpression(
5555
"group_concat",
5656
separator.map { [queryFragment, $0.queryFragment] } ?? [queryFragment],
5757
order: order?.queryFragment,
@@ -74,7 +74,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin
7474
order: (some QueryExpression)? = Bool?.none,
7575
filter: (some QueryExpression<Bool>)? = Bool?.none
7676
) -> some QueryExpression<String?> {
77-
AggregateFunction(
77+
AggregateFunctionExpression(
7878
"group_concat",
7979
isDistinct: isDistinct,
8080
[queryFragment],
@@ -97,7 +97,7 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalPromotable
9797
public func max(
9898
filter: (some QueryExpression<Bool>)? = Bool?.none
9999
) -> some QueryExpression<QueryValue._Optionalized.Wrapped?> {
100-
AggregateFunction("max", [queryFragment], filter: filter?.queryFragment)
100+
AggregateFunctionExpression("max", [queryFragment], filter: filter?.queryFragment)
101101
}
102102

103103
/// A minimum aggregate of this expression.
@@ -112,7 +112,7 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalPromotable
112112
public func min(
113113
filter: (some QueryExpression<Bool>)? = Bool?.none
114114
) -> some QueryExpression<QueryValue._Optionalized.Wrapped?> {
115-
AggregateFunction("min", [queryFragment], filter: filter?.queryFragment)
115+
AggregateFunctionExpression("min", [queryFragment], filter: filter?.queryFragment)
116116
}
117117
}
118118

@@ -134,7 +134,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
134134
distinct isDistinct: Bool = false,
135135
filter: (some QueryExpression<Bool>)? = Bool?.none
136136
) -> some QueryExpression<Double?> {
137-
AggregateFunction("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment)
137+
AggregateFunctionExpression("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment)
138138
}
139139

140140
/// A sum aggregate of this expression.
@@ -156,7 +156,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
156156
// NB: We must explicitly erase here to avoid a runtime crash with opaque return types
157157
// TODO: Report issue to Swift team.
158158
SQLQueryExpression(
159-
AggregateFunction<QueryValue._Optionalized>(
159+
AggregateFunctionExpression<QueryValue._Optionalized>(
160160
"sum",
161161
isDistinct: isDistinct,
162162
[queryFragment],
@@ -182,7 +182,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
182182
distinct isDistinct: Bool = false,
183183
filter: (some QueryExpression<Bool>)? = Bool?.none
184184
) -> some QueryExpression<QueryValue> {
185-
AggregateFunction(
185+
AggregateFunctionExpression(
186186
"total",
187187
isDistinct: isDistinct,
188188
[queryFragment],
@@ -191,7 +191,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric
191191
}
192192
}
193193

194-
extension QueryExpression where Self == AggregateFunction<Int> {
194+
extension QueryExpression where Self == AggregateFunctionExpression<Int> {
195195
/// A `count(*)` aggregate.
196196
///
197197
/// ```swift
@@ -204,18 +204,34 @@ extension QueryExpression where Self == AggregateFunction<Int> {
204204
public static func count(
205205
filter: (any QueryExpression<Bool>)? = nil
206206
) -> Self {
207-
AggregateFunction("count", ["*"], filter: filter?.queryFragment)
207+
AggregateFunctionExpression("count", ["*"], filter: filter?.queryFragment)
208208
}
209209
}
210210

211211
/// A query expression of an aggregate function.
212-
public struct AggregateFunction<QueryValue>: QueryExpression, Sendable {
212+
public struct AggregateFunctionExpression<QueryValue>: QueryExpression, Sendable {
213213
var name: QueryFragment
214214
var isDistinct: Bool
215215
var arguments: [QueryFragment]
216216
var order: QueryFragment?
217217
var filter: QueryFragment?
218218

219+
public init<each Argument: QueryExpression>(
220+
_ name: String,
221+
distinct isDistinct: Bool = false,
222+
_ arguments: repeat each Argument,
223+
order: (some QueryExpression)? = Bool?.none,
224+
filter: (some QueryExpression<Bool>)? = Bool?.none
225+
) {
226+
self.init(
227+
QueryFragment(quote: name),
228+
isDistinct: isDistinct,
229+
Array(repeat each arguments),
230+
order: order?.queryFragment,
231+
filter: filter?.queryFragment
232+
)
233+
}
234+
219235
package init(
220236
_ name: QueryFragment,
221237
isDistinct: Bool = false,

Sources/StructuredQueriesCore/ScalarFunctions.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,7 @@ extension QueryExpression where QueryValue == [UInt8] {
319319
}
320320
}
321321

322-
/// A query expression of a generalized query function.
323-
public struct QueryFunction<QueryValue>: QueryExpression {
322+
package struct QueryFunction<QueryValue>: QueryExpression {
324323
let name: QueryFragment
325324
let arguments: [QueryFragment]
326325

Sources/StructuredQueriesSQLite/Macros.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,41 @@ public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression>(
5454
module: "StructuredQueriesSQLiteMacros",
5555
type: "DatabaseFunctionMacro"
5656
)
57+
58+
/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction``
59+
/// protocol.
60+
///
61+
/// - Parameters:
62+
/// - name: The function's name. Defaults to the name of the function the macro is applied to.
63+
/// - representableFunctionType: The function as represented in a query.
64+
/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially
65+
/// transparent"), _i.e._ given an input it will always return the same output.
66+
@attached(peer, names: overloaded, prefixed(`$`))
67+
public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression, R: QueryBindable>(
68+
_ name: String = "",
69+
as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type,
70+
isDeterministic: Bool = false
71+
) =
72+
#externalMacro(
73+
module: "StructuredQueriesSQLiteMacros",
74+
type: "DatabaseFunctionMacro"
75+
)
76+
77+
/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction``
78+
/// protocol.
79+
///
80+
/// - Parameters:
81+
/// - name: The function's name. Defaults to the name of the function the macro is applied to.
82+
/// - representableFunctionType: The function as represented in a query.
83+
/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially
84+
/// transparent"), _i.e._ given an input it will always return the same output.
85+
@attached(peer, names: overloaded, prefixed(`$`))
86+
public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression>(
87+
_ name: String = "",
88+
as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type,
89+
isDeterministic: Bool = false
90+
) =
91+
#externalMacro(
92+
module: "StructuredQueriesSQLiteMacros",
93+
type: "DatabaseFunctionMacro"
94+
)

Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/// A type representing a database function.
22
///
3-
/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate
4-
/// a conformance.
3+
/// Don't conform to this protocol directly. Instead, use the
4+
/// [`@DatabaseFunction`](<doc:CustomFunctions>) macro to generate a conformance.
55
public protocol DatabaseFunction<Input, Output> {
66
/// A type representing the function's arguments.
77
associatedtype Input
@@ -22,8 +22,8 @@ public protocol DatabaseFunction<Input, Output> {
2222

2323
/// A type representing a scalar database function.
2424
///
25-
/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate
26-
/// a conformance.
25+
/// Don't conform to this protocol directly. Instead, use the
26+
/// [`@DatabaseFunction`](<doc:CustomFunctions#Scalar-functions>) macro to generate a conformance.
2727
public protocol ScalarDatabaseFunction<Input, Output>: DatabaseFunction {
2828
/// The function body. Uses a query decoder to process the input of a database function into a
2929
/// bindable value.
@@ -50,3 +50,68 @@ extension ScalarDatabaseFunction {
5050
}
5151
}
5252
}
53+
54+
/// A type representing an aggregate database function.
55+
///
56+
/// Don't conform to this protocol directly. Instead, use the
57+
/// [`@DatabaseFunction`](<doc:CustomFunctions#Aggregate-functions>) macro to generate a
58+
/// conformance.
59+
public protocol AggregateDatabaseFunction<Input, Output>: DatabaseFunction {
60+
/// A type representing one row of input to the aggregate function.
61+
associatedtype Element = Input
62+
63+
/// Decodes a row into an element to aggregate a result from.
64+
///
65+
/// - Parameter decoder: A query decoder.
66+
/// - Returns: An element to append to the sequence sent to the aggregate function.
67+
func step(_ decoder: inout some QueryDecoder) throws -> Element
68+
69+
/// Aggregates elements into a bindable value.
70+
///
71+
/// - Parameter arguments: A sequence of elements to aggregate from.
72+
/// - Returns: A binding returned from the aggregate function.
73+
func invoke(_ arguments: some Sequence<Element>) throws -> QueryBinding
74+
}
75+
76+
extension AggregateDatabaseFunction {
77+
/// An aggregate function call expression.
78+
///
79+
/// - Parameters:
80+
/// - input: Expressions representing the arguments of the function.
81+
/// - isDistinct: Whether or not to include a `DISTINCT` clause, which filters duplicates from
82+
/// the aggregation.
83+
/// - order: An `ORDER BY` clause to apply to the aggregation.
84+
/// - filter: A `FILTER` clause to apply to the aggregation.
85+
/// - Returns: An expression representing the function call.
86+
@_disfavoredOverload
87+
public func callAsFunction(
88+
_ input: some QueryExpression<Input>,
89+
distinct isDistinct: Bool = false,
90+
order: (some QueryExpression)? = Bool?.none,
91+
filter: (some QueryExpression<Bool>)? = Bool?.none
92+
) -> some QueryExpression<Output>
93+
where Input: QueryBindable {
94+
$_isSelecting.withValue(false) {
95+
AggregateFunctionExpression(name, distinct: isDistinct, input, order: order, filter: filter)
96+
}
97+
}
98+
99+
/// An aggregate function call expression.
100+
///
101+
/// - Parameters:
102+
/// - input: Expressions representing the arguments of the function.
103+
/// - order: An `ORDER BY` clause to apply to the aggregation.
104+
/// - filter: A `FILTER` clause to apply to the aggregation.
105+
/// - Returns: An expression representing the function call.
106+
@_disfavoredOverload
107+
public func callAsFunction<each T: QueryExpression>(
108+
_ input: repeat each T,
109+
order: (some QueryExpression)? = Bool?.none,
110+
filter: (some QueryExpression<Bool>)? = Bool?.none
111+
) -> some QueryExpression<Output>
112+
where Input == (repeat (each T).QueryValue) {
113+
$_isSelecting.withValue(false) {
114+
AggregateFunctionExpression(name, repeat each input, order: order, filter: filter)
115+
}
116+
}
117+
}

Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ from SQLite.
55

66
## Overview
77

8+
### Scalar functions
9+
810
StructuredQueries defines a macro specifically for defining Swift functions that can be called from
911
a query. It's called `@DatabaseFunction`, and can annotate any function that works with
1012
query-representable types.
@@ -18,11 +20,14 @@ func exclaim(_ string: String) -> String {
1820
}
1921
```
2022

23+
This defines a "scalar" function, which is called on a value for each row in a query, returning its
24+
result.
25+
2126
> Note: If your project is using [default main actor isolation] then you further need to annotate
2227
> your function as `nonisolated`.
2328
[default main actor isolation]: https://github.com/swiftlang/swift-evolution/blob/main/proposals/0466-control-default-actor-isolation.md
2429

25-
And will be immediately callable in a query by prefixing the function with `$`:
30+
Once defined, the function is immediately callable in a query by prefixing the function with `$`:
2631

2732
```swift
2833
Reminder.select { $exclaim($0.title) }
@@ -52,9 +57,52 @@ configuration.prepareDatabase { db in
5257
> }
5358
> ```
5459
60+
### Aggregate functions
61+
62+
It is also possible to define a Swift function that builds a single result from multiple rows of a
63+
query. The function must simply take a _sequence_ of query-representable types.
64+
65+
For example, suppose you want to compute the most common priority used across all reminders. This
66+
computation is called the "mode" in statistics, and unfortunately SQLite does not supply such
67+
a function. But it is quite easy to write this function in plain Swift:
68+
69+
```swift
70+
@DatabaseFunction
71+
func mode(priority priorities: some Sequence<Priority?>) -> Priority? {
72+
var occurrences: [Priority: Int] = [:]
73+
for priority in priorities {
74+
guard let priority
75+
else { continue }
76+
occurrences[priority, default: 0] += 1
77+
}
78+
return occurrences.max { $0.value < $1.value }?.key
79+
}
80+
```
81+
82+
This defines an "aggregate" function, and the sequence `priorities` that is passed to it represents
83+
all of the data the database supplies to the function while aggregating. It is now straightforward
84+
to compute the mode of priorities across all reminders:
85+
86+
```swift
87+
Reminder
88+
.select { $mode(priority: $0.priority) }
89+
```
90+
91+
> Tip: Be sure to install the function in the database connection as discussed in
92+
> <doc:CustomFunctions#Scalar-functions> above.
93+
94+
You can also compute the mode of priorities inside each reminders list:
95+
96+
```swift
97+
RemindersList
98+
.group(by: \.id)
99+
.leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) }
100+
.select { ($0.title, $mode(priority: $1.priority)) }
101+
```
102+
55103
### Custom representations
56104

57-
To define a type that works with a custom representation, i.e. anytime you use `@Column(as:)` in
105+
To define a type that works with a custom representation, _i.e._ anytime you use `@Column(as:)` in
58106
your data type, you can use the `as` parameter of the macro to specify those types. For example,
59107
if your model holds onto a date and you want to store that date as a
60108
[unix timestamp](<doc:Foundation/Date/UnixTimeRepresentation-struct>) (i.e. double),
@@ -93,9 +141,22 @@ func jsonArrayExclaim(_ strings: [String]) -> [String] {
93141
}
94142
```
95143

144+
It is also possible to do this with aggregate functions, but you must describe the sequence as
145+
an `any Sequence` instead of a `some Sequence`:
146+
147+
```swift
148+
@DatabaseFunction(
149+
as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self
150+
)
151+
func jsonJoined(_ arrays: some Sequence<[String]>) -> [String] {
152+
arrays.flatMap(\.self)
153+
}
154+
```
155+
96156
## Topics
97157

98158
### Custom functions
99159

100160
- ``DatabaseFunction``
101161
- ``ScalarDatabaseFunction``
162+
- ``AggregateDatabaseFunction``

Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ extension QueryExpression where QueryValue: Codable & QueryBindable {
4646
order: (some QueryExpression)? = Bool?.none,
4747
filter: (some QueryExpression<Bool>)? = Bool?.none
4848
) -> some QueryExpression<[QueryValue].JSONRepresentation> {
49-
AggregateFunction(
49+
AggregateFunctionExpression(
5050
"json_group_array",
5151
isDistinct: isDistinct,
5252
[queryFragment],
@@ -112,7 +112,7 @@ extension PrimaryKeyedTableDefinition where QueryValue: Codable {
112112
order: (some QueryExpression)? = Bool?.none,
113113
filter: (some QueryExpression<Bool>)? = Bool?.none
114114
) -> some QueryExpression<[QueryValue].JSONRepresentation> {
115-
AggregateFunction(
115+
AggregateFunctionExpression(
116116
"json_group_array",
117117
isDistinct: isDistinct,
118118
[jsonObject().queryFragment],
@@ -200,7 +200,7 @@ where
200200
} else {
201201
primaryKeyFilter.queryFragment
202202
}
203-
return AggregateFunction(
203+
return AggregateFunctionExpression(
204204
"json_group_array",
205205
isDistinct: isDistinct,
206206
[QueryValue.columns.jsonObject().queryFragment],

0 commit comments

Comments
 (0)